1 module des.mc.multitrack.model.simple;
2 
3 import des.mc.multitrack.model;
4 import std.algorithm;
5 import des.mc.multitrack.model.util;
6 
/// Settings bundle accepted by `SimpleHeuristic`; it currently has no fields.
struct SimpleHeuristicParams {}
11 
/// `Heuristic` implementation that, for now, hands every skeleton back untouched.
class SimpleHeuristic : Heuristic
{
    /// Settings captured at construction time.
    SimpleHeuristicParams params;

    /// Keeps a copy of `shp` as this object's settings.
    this( SimpleHeuristicParams shp ) { this.params = shp; }

    /// Returns `skel` as is; the real heuristic has not been written yet.
    Skeleton opCall( in Skeleton skel )
    {
        // TODO
        return skel;
    }
}
27 
/// Thresholds consumed by `SimpleClassifier`.
struct SimpleClassifierParams
{
    /// Handed to the constructor of every `ClassifierClass` the classifier creates.
    float min_point_quality     = 0.5f;

    /// Upper bound for the first component of a class/skeleton difference;
    /// above it a new class is opened.
    float class_offset_limit    = 100.0f;

    /// Upper bound for the second component of a class/skeleton difference;
    /// above it a new class is opened.
    float class_deviation_limit = 200.0f;
}
34 
/// `Classifier` that places each skeleton into the best-matching existing
/// class, opening a new class when no existing one is within the limits.
class SimpleClassifier : Classifier
{
    /// Thresholds used when deciding whether a skeleton fits a class.
    SimpleClassifierParams params;

    /// Stores a copy of `scp`; default parameters are used when omitted.
    this( in SimpleClassifierParams scp = SimpleClassifierParams() )
    {
        params = scp;
    }

    /// Flattens `skel_arr` with `plainArray`, assigns every skeleton to a
    /// class and returns the skeletons of each class as one inner array.
    Skeleton[][] opCall( in Skeleton[][] skel_arr )
    {
        ClassifierClass[] known;
        foreach( item; plainArray( skel_arr ) )
        {
            auto best = findClass( known, item );
            processResult( known, best[0], best[1] ).append( item );
        }
        return getSkeletons( known );
    }

    /// Scans `classes` for the one whose `diff(skel)` is smaller than the
    /// best seen so far in BOTH components.
    ///
    /// Returns: a tuple of the chosen class (null when none qualified) and
    /// its two-component difference (`float.max` twice when none qualified).
    static auto findClass( ClassifierClass[] classes, in Skeleton skel )
    {
        float[2] best_diff = [ float.max, float.max ];
        ClassifierClass best_cls;
        foreach( candidate; classes )
        {
            auto cur = candidate.diff( skel );
            // a candidate wins only when it improves both components at once
            const improves = cur[0] < best_diff[0] && cur[1] < best_diff[1];
            if( !improves ) continue;
            best_diff = cur;
            best_cls = candidate;
        }
        return tuple( best_cls, best_diff );
    }

    /// Accepts `cls` when it is non-null and `diff` does not exceed the
    /// configured limits; otherwise creates a new class, appends it to
    /// `classes` and returns that one instead.
    ClassifierClass processResult( ref ClassifierClass[] classes,
                                   ClassifierClass cls, float[2] diff )
    {
        const acceptable = cls !is null
            && !( diff[0] > params.class_offset_limit )
            && !( diff[1] > params.class_deviation_limit );
        if( acceptable ) return cls;

        auto fresh = newClassifierClass();
        classes ~= fresh;
        return fresh;
    }

    /// Creates an empty class configured with `params.min_point_quality`.
    auto newClassifierClass()
    {
        return new ClassifierClass( params.min_point_quality );
    }

    /// Collects the `array` of every class, one inner array per class,
    /// in the order the classes appear in `classes`.
    static auto getSkeletons( ClassifierClass[] classes )
    {
        Skeleton[][] groups;
        foreach( c; classes )
            groups ~= c.array;
        return groups;
    }
}
99 
unittest
{
    // Classifier with min_point_quality 0.5 and class_offset_limit 2.
    auto classifier = new SimpleClassifier( SimpleClassifierParams(0.5,2) );

    // Two skeleton sets built with slightly different offsets.
    Skeleton[][] per_tracker;
    per_tracker ~= getFakeSkeletons(vec3(0,0,.1));
    per_tracker ~= getFakeSkeletons(vec3(0,0.1,0));

    auto grouped = classifier( per_tracker );

    //printSkeletonsDArray( per_tracker );
    //printSkeletonsDArray( grouped );

    // Group g must hold, in tracker order, the g-th skeleton of each tracker.
    foreach( g; 0 .. 2 )
        foreach( t; 0 .. 2 )
            assert( grouped[g][t] == per_tracker[t][g] );
}
118 
/// Settings bundle accepted by `SimpleComplexer`; no fields are defined yet.
struct SimpleComplexerParams
{
    // TODO: params
}
123 
/// `Complexer` that merges every group of skeletons into a single skeleton
/// by folding the group members, one after another, into a running mean.
class SimpleComplexer : Complexer
{
    /// Settings captured at construction time (currently unused by `opCall`).
    SimpleComplexerParams params;

    /// Stores a copy of `scp`; default parameters are used when omitted.
    this( in SimpleComplexerParams scp = SimpleComplexerParams() )
    {
        params = scp;
    }

    /// Produces one merged skeleton per non-empty group in `skels`.
    ///
    /// Empty groups are skipped; a group with a single skeleton yields that
    /// skeleton unchanged. For larger groups the first skeleton seeds the
    /// result and each further skeleton is merged joint by joint:
    ///   - both joints reliable: position moves towards the new joint by
    ///     `1/(1+n)` of their difference, quality becomes 1.0;
    ///   - only the accumulated joint reliable: position kept, quality 0.75;
    ///   - only the new joint reliable: position taken from it, quality 0.5;
    ///   - neither reliable: quality 0.0.
    /// A joint counts as reliable when its `qual` is greater than 0.5.
    Skeleton[] opCall( in Skeleton[][] skels )
    {
        // quality threshold above which a joint is treated as reliable
        immutable min_qual = 0.5;
        Skeleton[] result;
        foreach( group; skels )
        {
            if( group.length == 0 ) continue;

            Skeleton mean = group[0];
            if( group.length < 2 ) { result ~= mean; continue; }

            // number of skeletons already folded into `mean`
            size_t n = 1;
            auto mj = mean.allJoints();
            foreach( s; group[1 .. $] )
            {
                auto sj = s.allJoints();
                // Every joint is handled independently of the others, so a
                // single pass is enough; no per-skeleton scratch arrays and
                // no (previously unused) mean-offset accumulator are needed.
                foreach( i; 0 .. sj.length )
                {
                    const mean_ok = mj[i].qual > min_qual;
                    const cur_ok  = sj[i].qual > min_qual;

                    if( mean_ok && cur_ok )
                    {
                        // incremental mean update
                        mj[i].pos = mj[i].pos + ( sj[i].pos - mj[i].pos ) / ( 1.0f + n );
                        mj[i].qual = 1.0f;
                    }
                    else if( mean_ok )
                    {
                        mj[i].qual = 0.75f;
                    }
                    else if( cur_ok )
                    {
                        mj[i].pos = sj[i].pos;
                        // NOTE(review): 0.5 does not pass the `> min_qual`
                        // test on the next skeleton - confirm this is intended.
                        mj[i].qual = 0.5f;
                    }
                    else
                    {
                        mj[i].qual = 0.0f;
                    }
                }
                n++;
                mean.setJoints( mj );
            }
            result ~= mean;
        }
        return result;
    }
}
192 
unittest
{
    auto complexer = new SimpleComplexer;

    // Reference set plus two sets shifted by +0.1 and -0.1 along z.
    auto expected = getFakeSkeletons(vec3(0,0,0));
    auto plus_z   = getFakeSkeletons(vec3(0,0,.1));
    auto minus_z  = getFakeSkeletons(vec3(0,0,-.1));

    auto groups = [ [ plus_z[0], minus_z[0] ],
                    [ plus_z[1], minus_z[1] ] ];

    // Merging each opposite-shifted pair must reproduce the reference set.
    assert( expected == complexer( groups ) );
}
206 
/// Settings consumed by `SimpleUserHandler`.
struct SimpleUserHandlerParams
{
    /// Largest torso-to-torso distance at which a skeleton is still
    /// considered a possible continuation of the handled user.
    ///
    /// An explicit default is required: a bare `float` field is initialised
    /// to NaN in D, which makes the `dist2 > max_dist2` check in
    /// `calcTransformPossibility` always false and silently disables the
    /// limit for default-constructed parameters.
    float max_transform_dist = 1.0f;
}
211 
/// `UserHandler` that tracks one user and rates candidate skeletons by the
/// squared distance between torso positions.
class SimpleUserHandler: UserHandler
{
protected:
    /// True until a user is supplied, and again after `setOverdue`.
    bool is_overdue = true;
    /// The user this handler is responsible for.
    User self_user;
    /// Settings captured at construction time.
    SimpleUserHandlerParams params;
public:

    /// Takes ownership of `fuser`, copies `suhp` and clears the overdue flag.
    this( User fuser, in SimpleUserHandlerParams suhp )
    {
        self_user = fuser;
        is_overdue = false;
        params = suhp;
    }

    /// True while the handler is not overdue.
    @property bool respectable() const { return !is_overdue; }

    /// Read-only access to the handled user.
    @property ref const(User) user() const { return self_user; }

    /// True once the handler has been marked overdue.
    @property bool isOverdue() const { return is_overdue; }

    /// Marks the handler as overdue.
    void setOverdue() { is_overdue = true; }

    /// Replaces the user's skeleton with `s` and clears the overdue flag.
    void setSkeleton( in Skeleton s )
    {
        self_user.skel = s;
        is_overdue = false;
    }

    /// Rates how plausible it is that `s` is the next pose of this user.
    ///
    /// Returns: 0 when the squared torso distance exceeds the squared
    /// `params.max_transform_dist`; otherwise `1 / (dist2 + 0.0001)`, so
    /// closer skeletons score higher.
    float calcTransformPossibility( in Skeleton s ) const
    {
        const limit2 = params.max_transform_dist ^^ 2;
        const dist2 = ( self_user.skel.torso.pos - s.torso.pos ).len2;
        return dist2 > limit2 ? 0.0f : 1.0f / ( dist2 + 0.0001f );
    }
}
250 
unittest
{
    auto base = getFakeSkeletons(vec3(0,0,0),[vec3(0,0,0)])[0];
    auto handler = new SimpleUserHandler( User(0,base), SimpleUserHandlerParams(1.0f) );

    // A freshly built handler is usable and reports the user it was given.
    assert( handler.respectable );
    assert( handler.user == User(0,base) );
    assert( !handler.isOverdue );

    // Same pose: score is close to the 1/0.0001 maximum.
    assert( handler.calcTransformPossibility(base) >= 0.9f / 0.0001f );

    // Beyond the distance limit of 1.0: score is exactly zero.
    assert( handler.calcTransformPossibility(skeleton_offset(base,vec3(1.1,0,0))) == 0.0f );

    // Inside the limit but displaced: strictly between zero and the maximum.
    auto nearby = handler.calcTransformPossibility(skeleton_offset(base,vec3(0.5,0,0)));
    assert( nearby > 0.0f );
    assert( nearby < 1.0f / 0.0001f );
}
265 
/// Settings bundle accepted by `SimpleDestributor`; no fields are defined yet.
struct SimpleDestributorParams
{
    // TODO: params
}
270 
/// `Destributor` that greedily hands skeletons to user handlers, always
/// picking the remaining (skeleton, handler) pair with the highest score.
class SimpleDestributor : Destributor
{
    /// Settings captured at construction time (currently unused).
    SimpleDestributorParams params;

    /// Stores a copy of `sdp`; default parameters are used when omitted.
    this( in SimpleDestributorParams sdp = SimpleDestributorParams() )
    {
        params = sdp;
    }

    /// Assigns skeletons to handlers and returns the skeletons left over.
    ///
    /// In each round the unassigned pair with the largest strictly positive
    /// score is chosen and the handler receives the skeleton through
    /// `setSkeleton`. Each skeleton and each handler is used at most once.
    ///
    /// Returns: the skeletons that were not given to any handler, in their
    /// original order.
    Skeleton[] opCall( UserHandler[] handlers, in Skeleton[] skeletons )
    {
        auto table = calcPossibility( handlers, skeletons );
        auto destributed = new bool[]( skeletons.length );
        auto updated = new bool[]( handlers.length );

        foreach( k; 0 .. min( skeletons.length, handlers.length ) )
        {
            float max_possibility = 0;
            ptrdiff_t max_i = -1;
            ptrdiff_t max_j = -1;

            foreach( i, skel_line; table )
            {
                if( destributed[i] ) continue;
                foreach( j, coef; skel_line )
                {
                    if( updated[j] ) continue;

                    if( coef > max_possibility )
                    {
                        max_possibility = coef;
                        max_i = i;
                        max_j = j;
                    }
                }
            }

            // No positive score is left among the unassigned pairs. Nothing
            // changes between rounds in that case, so further scans of the
            // table would find nothing either.
            if( max_i < 0 ) break;

            handlers[max_j].setSkeleton( skeletons[max_i] );
            updated[max_j] = true;
            destributed[max_i] = true;
        }

        Skeleton[] not_destributed;
        foreach( i, skel; skeletons )
            if( !destributed[i] )
                not_destributed ~= skel;
        return not_destributed;
    }

protected:

    /// Builds the score table: one row per skeleton, one column per handler,
    /// each cell being `handler.calcTransformPossibility( skeleton )`.
    float[][] calcPossibility( UserHandler[] handlers, in Skeleton[] skeletons )
    {
        // both dimensions are known up front, so size the table once
        auto ret = new float[][]( skeletons.length );
        foreach( i, skel; skeletons )
        {
            auto row = new float[]( handlers.length );
            foreach( j, uh; handlers )
                row[j] = uh.calcTransformPossibility( skel );
            ret[i] = row;
        }
        return ret;
    }
}
339 
unittest
{
    auto initial = getFakeSkeletons(vec3(0,0,0));
    auto destributor = new SimpleDestributor;

    // With no handlers every skeleton is returned as not distributed.
    assert( initial == destributor([],initial) );

    // One handler per initial skeleton.
    UserHandler[] handlers;
    foreach( idx, sk; initial )
        handlers ~= new SimpleUserHandler( User(idx,sk), SimpleUserHandlerParams(1.0f) );

    // Slightly moved skeletons must all be taken by the handlers ...
    auto moved = getFakeSkeletons(vec3(0,0.2,0));
    assert( [] == destributor(handlers,moved) );

    // ... and each handler must now store the matching moved skeleton.
    import std.array;
    auto stored = handlers.map!(h => h.user.skel).array;
    assert( stored == moved );
}
354 
/// Aggregates the settings of every component that
/// `SimpleMultiTrackerFactory` creates. Field order is significant for
/// positional construction.
struct SimpleMultiTrackerFactoryParams
{
    SimpleHeuristicParams   heuristic;   /// passed to `SimpleHeuristic`
    SimpleClassifierParams  classifier;  /// passed to `SimpleClassifier`
    SimpleComplexerParams   complexer;   /// passed to `SimpleComplexer`
    SimpleDestributorParams destributor; /// passed to `SimpleDestributor`
    SimpleUserHandlerParams user;        /// passed to every `SimpleUserHandler`
}
363 
/// `MultiTrackerFactory` that wires together the "simple" implementations
/// of every pipeline component defined in this module.
class SimpleMultiTrackerFactory : MultiTrackerFactory
{
protected:
    // Component instances, created once in the constructor.
    Heuristic _heuristic;
    Classifier _classifier;
    Complexer _complexer;
    Destributor _destributor;

    // Not read anywhere in this class; handlers take their limit from
    // `params.user` instead.
    float max_user_transform_dist = 1.0f;

    /// Copy of the settings the factory was built with.
    SimpleMultiTrackerFactoryParams params;

public:
    /// Copies `smtfp` and builds each component from its own parameter set.
    this( in SimpleMultiTrackerFactoryParams smtfp )
    {
        params = smtfp;

        _heuristic   = new SimpleHeuristic( params.heuristic );
        _classifier  = new SimpleClassifier( params.classifier );
        _complexer   = new SimpleComplexer( params.complexer );
        _destributor = new SimpleDestributor( params.destributor );
    }

    /// The heuristic instance created by the constructor.
    @property Heuristic heuristic() { return _heuristic; }

    /// The classifier instance created by the constructor.
    @property Classifier classifier() { return _classifier; }

    /// The complexer instance created by the constructor.
    @property Complexer complexer() { return _complexer; }

    /// The destributor instance created by the constructor.
    @property Destributor destributor() { return _destributor; }

    /// Creates a new `SimpleUserHandler` for `user`, configured with
    /// `params.user`.
    UserHandler newUserHandler( User user )
    {
        return new SimpleUserHandler( user, params.user );
    }
}
397 
version(unittest)
{
    /// Debug helper: writes the torso position of every skeleton in `arr`
    /// on one line, enclosed in square brackets.
    static void printSkeletonsArray( Skeleton[] arr )
    {
        import std.stdio;
        write( "[ " );
        foreach( idx; 0 .. arr.length )
            writef( "Skeleton#%d torso: %s ", idx, arr[idx].torso.pos.data );
        writeln( " ]" );
    }

    /// Debug helper: prints every inner array of `arr` through
    /// `printSkeletonsArray`, framed by separator lines; prints
    /// "empty array" when `arr` has no elements.
    static void printSkeletonsDArray( Skeleton[][] arr )
    {
        import std.stdio;
        if( arr.length > 0 )
        {
            writeln( "[ --------- " );
            foreach( row; arr )
                printSkeletonsArray( row );
            writeln( "  --------- ]" );
        }
        else
        {
            writeln( "empty array" );
        }
    }
}