module des.mc.multitrack.model.simple;

import des.mc.multitrack.model;
import std.algorithm;
// `tuple` is used by SimpleClassifier.findClass; import it explicitly instead
// of relying on it leaking through another module's imports.
import std.typecons : tuple;
import des.mc.multitrack.model.util;

/// Parameters for `SimpleHeuristic` (currently none).
struct SimpleHeuristicParams
{

}

/++
 + Placeholder heuristic: returns the skeleton it receives unchanged.
 +/
class SimpleHeuristic : Heuristic
{
    SimpleHeuristicParams params;

    /++
     + Params:
     +   shp = heuristic parameters; defaulted like the other Simple* components.
     +/
    this( in SimpleHeuristicParams shp = SimpleHeuristicParams() )
    {
        params = shp;
    }

    /// Returns `skel` as-is; no heuristic is applied yet.
    Skeleton opCall( in Skeleton skel )
    {
        // TODO
        return skel;
    }
}

/// Tuning knobs for `SimpleClassifier`.
struct SimpleClassifierParams
{
    /// Passed to the constructor of every `ClassifierClass` the classifier creates.
    float min_point_quality=0.5;
    /// If the offset component of the best class's diff exceeds this, a new class is started.
    float class_offset_limit=100;
    /// If the deviation component of the best class's diff exceeds this, a new class is started.
    float class_deviation_limit=200;
}

/++
 + Groups the skeletons reported by several trackers into classes
 + (presumably one class per tracked person — confirm against callers).
 +/
class SimpleClassifier : Classifier
{
    SimpleClassifierParams params;

    this( in SimpleClassifierParams scp = SimpleClassifierParams() )
    {
        params = scp;
    }

    /++
     + Reduces `skel_arr` with `plainArray` (presumably flattening the
     + per-tracker arrays) and appends each skeleton to the closest existing
     + class, creating a new class when none is close enough.
     +
     + Returns: the skeletons of each class, one array per class, in the order
     + the classes were created.
     +/
    Skeleton[][] opCall( in Skeleton[][] skel_arr )
    {
        auto red = plainArray( skel_arr );
        ClassifierClass[] classes;
        foreach( skel; red )
        {
            auto res = findClass( classes, skel );
            auto cur = processResult( classes, res[0], res[1] );
            cur.append( skel );
        }
        return getSkeletons( classes );
    }

    /++
     + Looks for the class whose `diff(skel)` is smallest.
     +
     + A candidate replaces the current best only when BOTH components of its
     + diff are strictly smaller, so the winner can depend on class order.
     + NOTE(review): a class with a better offset but a worse deviation is never
     + chosen over an earlier one — confirm this is intended.
     +
     + Returns: `tuple(cls, diff)`. If no class qualifies, `cls` is `null` and
     + `diff` is `[float.max, float.max]`.
     +/
    static auto findClass( ClassifierClass[] classes, in Skeleton skel )
    {
        ClassifierClass fnd;
        float[2] min_diff = [ float.max, float.max ];
        foreach( cls; classes )
        {
            auto df = cls.diff(skel);
            if( df[0] < min_diff[0] && df[1] < min_diff[1] )
            {
                min_diff = df;
                fnd = cls;
            }
        }
        return tuple( fnd, min_diff );
    }

    /++
     + Decides which class receives the skeleton.
     +
     + Returns: `cls` when it exists and both diff components are within the
     + configured limits. Otherwise a freshly created class, which is also
     + appended to `classes`.
     +/
    ClassifierClass processResult( ref ClassifierClass[] classes,
                                   ClassifierClass cls, float[2] diff )
    {
        auto ret = cls;
        if( cls is null || diff[0] > params.class_offset_limit ||
                           diff[1] > params.class_deviation_limit )
        {
            ret = newClassifierClass();
            classes ~= ret;
        }
        return ret;
    }

    /// Creates an empty class configured with `params.min_point_quality`.
    auto newClassifierClass()
    {
        return new ClassifierClass( params.min_point_quality );
    }

    /// Collects `cls.array` of every class, preserving class order.
    static auto getSkeletons( ClassifierClass[] classes )
    {
        Skeleton[][] ret;
        foreach( cls; classes )
            ret ~= cls.array;
        return ret;
    }
}
unittest
{
    auto tsc = new SimpleClassifier( SimpleClassifierParams(0.5,2) );

    Skeleton[][] by_tracker;
    by_tracker ~= getFakeSkeletons(vec3(0,0,.1));
    by_tracker ~= getFakeSkeletons(vec3(0,0.1,0));

    auto by_group = tsc( by_tracker );

    //printSkeletonsDArray( by_tracker );
    //printSkeletonsDArray( by_group );

    assert( by_group[0][0] == by_tracker[0][0] );
    assert( by_group[0][1] == by_tracker[1][0] );
    assert( by_group[1][0] == by_tracker[0][1] );
    assert( by_group[1][1] == by_tracker[1][1] );
}

/// Tuning knobs for `SimpleComplexer`.
struct SimpleComplexerParams
{
    /// A joint counts as reliable when its `qual` is strictly greater than this.
    float min_quality = 0.5f;
}

/++
 + Merges each group of skeletons into a single skeleton, joint by joint.
 +/
class SimpleComplexer : Complexer
{
    SimpleComplexerParams params;

    this( in SimpleComplexerParams scp = SimpleComplexerParams() )
    {
        params = scp;
    }

    /++
     + Merges every non-empty group; empty groups are skipped and a
     + one-element group is returned unchanged.
     +
     + Starting from the group's first skeleton, each following skeleton is
     + folded in joint by joint. In each step, `reliable` means
     + `qual > params.min_quality`:
     +   - both reliable: the position joins the running mean; qual = 1.0
     +   - only current mean reliable: mean kept; qual = 0.75
     +   - only new joint reliable: new position taken; qual = 0.5
     +   - neither reliable: position kept; qual = 0.0
     +
     + NOTE(review): with the default min_quality of 0.5, the 0.5 assigned in
     + the "only new" case is not reliable in the next step — confirm intended.
     +
     + Returns: one merged skeleton per non-empty group, in group order.
     +/
    Skeleton[] opCall( in Skeleton[][] skels )
    {
        immutable min_qual = params.min_quality;
        Skeleton[] result;
        foreach( group; skels )
        {
            if( group.length == 0 ) continue;

            Skeleton mean = group[0];
            if( group.length < 2 ) { result ~= mean; continue; }

            auto mj = mean.allJoints();

            // Number of samples already averaged into mj[i].pos. The
            // incremental mean must divide by this per-joint count rather than
            // by the number of skeletons seen: skeletons whose joint was
            // unreliable did not contribute to that joint's position.
            auto samples = new size_t[]( mj.length );
            samples[] = 1;

            foreach( s; group[1 .. $] )
            {
                auto sj = s.allJoints();
                foreach( i; 0 .. sj.length )
                {
                    immutable mean_ok = mj[i].qual > min_qual;
                    immutable new_ok = sj[i].qual > min_qual;

                    if( mean_ok && new_ok )
                    {
                        samples[i]++;
                        mj[i].pos = mj[i].pos + ( sj[i].pos - mj[i].pos ) / cast(float)samples[i];
                        mj[i].qual = 1.0f;
                    }
                    else if( mean_ok )
                    {
                        mj[i].qual = 0.75f;
                    }
                    else if( new_ok )
                    {
                        // The position now comes from this single sample.
                        mj[i].pos = sj[i].pos;
                        mj[i].qual = 0.5f;
                        samples[i] = 1;
                    }
                    else
                    {
                        mj[i].qual = 0.0f;
                    }
                }
                mean.setJoints( mj );
            }
            result ~= mean;
        }
        return result;
    }
}

unittest
{
    auto tsc = new SimpleComplexer;

    auto us0 = getFakeSkeletons(vec3(0,0,0));
    auto us1 = getFakeSkeletons(vec3(0,0,.1));
    auto us2 = getFakeSkeletons(vec3(0,0,-.1));
    auto by_group = [ [ us1[0], us2[0] ],
                      [ us1[1], us2[1] ] ];

    auto cmpl = tsc( by_group );
    assert( us0 == cmpl );
}

/// Tuning knobs for `SimpleUserHandler`.
struct SimpleUserHandlerParams
{
    /++
     + Skeletons whose torso is farther than this from the user's current
     + torso get a transform possibility of 0.
     +
     + Defaults to `float.infinity` (no limit). The default is written out
     + explicitly rather than left to D's NaN float default, which would make
     + the distance check silently always false.
     +/
    float max_transform_dist = float.infinity;
}

/++
 + Holds one tracked user and the bookkeeping needed to match new skeletons
 + to that user.
 +/
class SimpleUserHandler: UserHandler
{
protected:
    bool is_overdue = true;
    User self_user;
    SimpleUserHandlerParams params;
public:

    /// Wraps `fuser`; a freshly created handler is not overdue.
    this( User fuser, in SimpleUserHandlerParams suhp )
    {
        self_user = fuser;
        is_overdue = false;
        params = suhp;
    }

    @property
    {
        /// True while the handler is not overdue.
        bool respectable() const { return !is_overdue; }
        /// The tracked user (read-only).
        ref const(User) user() const { return self_user; }
        /// True after `setOverdue` until the next `setSkeleton`.
        bool isOverdue() const { return is_overdue; }
    }

    /// Marks the handler as overdue; `setSkeleton` clears the flag.
    void setOverdue() { is_overdue = true; }

    /// Replaces the user's skeleton and clears the overdue flag.
    void setSkeleton( in Skeleton s )
    {
        self_user.skel = s;
        is_overdue = false;
    }

    /++
     + Scores how plausible it is that `s` is the next pose of this user.
     +
     + Returns: 0 when the squared torso distance exceeds `max_transform_dist`
     + squared. Otherwise `1 / (dist2 + 0.0001)`, so closer skeletons score
     + higher and a zero distance does not divide by zero.
     +/
    float calcTransformPossibility( in Skeleton s ) const
    {
        auto max_dist2 = params.max_transform_dist ^^ 2;
        auto dist2 = (self_user.skel.torso.pos - s.torso.pos).len2;
        if( dist2 > max_dist2 ) return 0.0f;
        return 1.0f / ( dist2 + 0.0001f );
    }
}

unittest
{
    auto us0 = getFakeSkeletons(vec3(0,0,0),[vec3(0,0,0)])[0];
    auto tsuh = new SimpleUserHandler( User(0,us0), SimpleUserHandlerParams(1.0f) );
    assert( tsuh.respectable );
    assert( tsuh.user == User(0,us0) );
    assert( !tsuh.isOverdue );

    assert( tsuh.calcTransformPossibility(us0) >= 0.9f / 0.0001f );
    assert( tsuh.calcTransformPossibility(skeleton_offset(us0,vec3(1.1,0,0))) == 0.0f );
    auto tctpn = tsuh.calcTransformPossibility(skeleton_offset(us0,vec3(0.5,0,0)));
    assert( tctpn > 0.0f );
    assert( tctpn < 1.0f / 0.0001f );
}

/// Parameters for `SimpleDestributor` (currently none).
struct SimpleDestributorParams
{
    // TODO: params
}

/++
 + Hands new skeletons to existing user handlers.
 +/
class SimpleDestributor : Destributor
{
    SimpleDestributorParams params;

    this( in SimpleDestributorParams sdp = SimpleDestributorParams() )
    {
        params = sdp;
    }

    /++
     + Greedy assignment. Each round picks the not-yet-used
     + (skeleton, handler) pair with the highest strictly positive
     + `calcTransformPossibility`, passes the skeleton to that handler via
     + `setSkeleton`, and excludes both from later rounds.
     +
     + Returns: the skeletons no handler received, in input order.
     +/
    Skeleton[] opCall( UserHandler[] handlers, in Skeleton[] skeletons )
    {
        auto table = calcPossibility( handlers, skeletons );
        auto destributed = new bool[]( skeletons.length );
        auto updated = new bool[]( handlers.length );

        foreach( k; 0 .. min( skeletons.length, handlers.length ) )
        {
            float max_possibility = 0;
            ptrdiff_t max_i = -1;
            ptrdiff_t max_j = -1;

            foreach( i, skel_line; table )
            {
                if( destributed[i] ) continue;
                foreach( j, coef; skel_line )
                {
                    if( updated[j] ) continue;

                    if( coef > max_possibility )
                    {
                        max_possibility = coef;
                        max_i = i;
                        max_j = j;
                    }
                }
            }

            // No positive pair is left. Later rounds would see the same state
            // and also find none, so stop early.
            if( max_i < 0 ) break;

            handlers[max_j].setSkeleton( skeletons[max_i] );
            updated[max_j] = true;
            destributed[max_i] = true;
        }

        Skeleton[] not_destributed;
        foreach( i, skel; skeletons )
            if( !destributed[i] )
                not_destributed ~= skel;
        return not_destributed;
    }

protected:

    /++
     + Returns: a table where `ret[i][j]` is
     + `handlers[j].calcTransformPossibility(skeletons[i])`.
     +/
    float[][] calcPossibility( UserHandler[] handlers, in Skeleton[] skeletons )
    {
        float[][] ret;
        foreach( skel; skeletons )
        {
            float[] buf;
            foreach( uh; handlers )
                buf ~= uh.calcTransformPossibility( skel );
            ret ~= buf;
        }
        return ret;
    }
}

unittest
{
    auto skels0 = getFakeSkeletons(vec3(0,0,0));
    auto tsd = new SimpleDestributor;
    assert( skels0 == tsd([],skels0) );
    UserHandler[] uhlist;
    foreach( i, s; skels0 )
        uhlist ~= new SimpleUserHandler( User(i,s), SimpleUserHandlerParams(1.0f) );
    auto skels1 = getFakeSkeletons(vec3(0,0.2,0));
    assert( [] == tsd(uhlist,skels1) );
    import std.array;
    auto uhskels = array( map!(a=>a.user.skel)(uhlist) );
    assert( uhskels == skels1 );
}

/// Aggregated parameters for every component built by `SimpleMultiTrackerFactory`.
struct SimpleMultiTrackerFactoryParams
{
    SimpleHeuristicParams heuristic;
    SimpleClassifierParams classifier;
    SimpleComplexerParams complexer;
    SimpleDestributorParams destributor;
    SimpleUserHandlerParams user;
}

/++
 + Builds the Simple* implementation of every multitracker component from a
 + single `SimpleMultiTrackerFactoryParams`.
 +/
class SimpleMultiTrackerFactory : MultiTrackerFactory
{
protected:
    Heuristic _heuristic;
    Classifier _classifier;
    Complexer _complexer;
    Destributor _destributor;

    // NOTE(review): not read anywhere in this file; the per-handler distance
    // limit actually used is `params.user.max_transform_dist`.
    float max_user_transform_dist = 1.0f;

    SimpleMultiTrackerFactoryParams params;

public:
    /// Creates one instance of each component from the matching sub-params.
    this( in SimpleMultiTrackerFactoryParams smtfp )
    {
        params = smtfp;
        _heuristic = new SimpleHeuristic( params.heuristic );
        _classifier = new SimpleClassifier( params.classifier );
        _complexer = new SimpleComplexer( params.complexer );
        _destributor = new SimpleDestributor( params.destributor );
    }

    @property
    {
        Heuristic heuristic() { return _heuristic; }
        Classifier classifier() { return _classifier; }
        Complexer complexer() { return _complexer; }
        Destributor destributor() { return _destributor; }
    }

    /// Returns a new `SimpleUserHandler` for `user`, configured with `params.user`.
    UserHandler newUserHandler( User user )
    { return new SimpleUserHandler( user, params.user ); }
}

version(unittest)
{
    /// Debug helper: prints the torso position of every skeleton in `arr`.
    static void printSkeletonsArray( Skeleton[] arr )
    {
        import std.stdio;
        write( "[ " );
        foreach( j, sk; arr )
            writef( "Skeleton#%d torso: %s ", j, sk.torso.pos.data );
        writeln( " ]" );
    }

    /// Debug helper: prints each inner array of `arr` with `printSkeletonsArray`.
    static void printSkeletonsDArray( Skeleton[][] arr )
    {
        import std.stdio;
        if( arr.length == 0 )
        {
            writeln( "empty array" );
            return;
        }
        writeln( "[ --------- " );
        foreach( i, list; arr )
            printSkeletonsArray( list );
        writeln( " --------- ]" );
    }
}