50#include "EST_FMatrix.h"
51#include "EST_multistats.h"
// Global tunables read by the tree-building and scoring code in this file.

// Compared against ds.n()/wgn_balance when a question is scored.
// NOTE(review): presumably the minimum number of samples allowed in a
// node produced by a split -- confirm against the scoring code.
64int wgn_min_cluster_size = 50;
// Tested with `< 1` before find_best_question() is called during tree
// construction.  NOTE(review): looks like a budget on the number of
// questions that may still be asked -- confirm.
65int wgn_max_questions = 2000000;
// Each candidate feature is tested with
// wgn_random_number(1.0) < wgn_dropout_feats while questions are scanned.
// NOTE(review): presumably the probability of skipping a feature -- confirm.
67float wgn_dropout_feats = 0.0;
// Each sample is tested with wgn_random_number(1.0) < wgn_dropout_samples
// while a question is scored; also checked for > 0.0 when data is split.
// NOTE(review): presumably the probability of skipping a sample -- confirm.
68float wgn_dropout_samples = 0.0;
// Off (FALSE) by default.  NOTE(review): presumably enables extra
// diagnostic output -- no use is visible in this part of the file.
72int wgn_verbose = FALSE;
// Index of the field from which a per-sample count is read with
// get_flt_val(wgn_count_field); the value -1 is tested as a special case
// before that read.  While loading data, the field with this index is
// parsed as a float regardless of its declared type.
73int wgn_count_field = -1;
// Divisor applied to a float feature's (max - min) range to obtain the
// step between candidate split points; also the loop bound on the number
// of candidate points tried.
77float wgn_float_range_split = 10;
102#if defined(INSTANTIATE_TEMPLATES)
104#include "../base_class/EST_TList.cc"
105#include "../base_class/EST_TVector.cc"
129 wagon_error(
EST_String(
"unable to open data file \"")+
131 ts.set_PunctuationSymbols(
"");
132 ts.set_PrePunctuationSymbols(
"");
133 ts.set_SingleCharSymbols(
"");
142 if ((type == wndt_float) ||
143 (type == wndt_ols) ||
144 (wgn_count_field == i))
147 float f =
atof(
ts.get().string());
154 dataset.feat_name(i) <<
" vector " <<
156 v->set_flt_val(i,0.0);
159 else if (type == wndt_binary)
160 v->set_int_val(i,
atoi(
ts.get().string()));
161 else if (type == wndt_cluster)
162 v->set_int_val(i,
atoi(
ts.get().string()));
163 else if (type == wndt_vector)
164 v->set_int_val(i,
atoi(
ts.get().string()));
165 else if (type == wndt_trajectory)
171 v->set_int_val(i,
atoi(
ts.get().string()));
173 else if (type == wndt_ignore)
181 int n = wgn_discretes.discrete(type).
name(s);
184 cout <<
fname <<
": bad value " << s <<
" in field " <<
185 dataset.feat_name(i) <<
" vector " <<
197 wagon_error(
fname+
": data vector "+itoString(
nvec)+
" contains "
198 +itoString(i)+
" parameters instead of "+
204 " contains too many parameters instead of "
206 wagon_error(
EST_String(
"extra parameter(s) from ")+
212 cout <<
"Dataset of " <<
dataset.samples() <<
" vectors of " <<
219 if (wgn_test_dataset.samples() != 0)
220 return do_summary(tree,wgn_test_dataset,
output);
222 return do_summary(tree,wgn_dataset,
output);
227 if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
228 return test_tree_cluster(tree,
ds,
output);
229 else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
230 return test_tree_vector(tree,
ds,
output);
231 else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
232 return test_tree_trajectory(tree,
ds,
output);
233 else if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
234 return test_tree_ols(tree,
ds,
output);
235 else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
236 return test_tree_class(tree,
ds,
output);
238 return test_tree_float(tree,
ds,
output);
241WNode *wgn_build_tree(
float &score)
247 wgn_set_up_data(top->get_data(),wgn_dataset,wgn_held_out,TRUE);
252 if (wgn_held_out > 0)
254 wgn_set_up_data(top->get_data(),wgn_dataset,wgn_held_out,FALSE);
255 top->held_out_prune();
261 score = summary_results(*top,0);
276 for (
j=i=0,d=
ds.head(); d != 0; d=d->next(),
j++)
308 for (p=
dataset.head(); p != 0; p=p->next())
312 if (wgn_count_field == -1)
315 count =
dataset(p)->get_flt_val(wgn_count_field);
316 prob =
pnode->get_impurity().pd().probability(predict);
317 H += (
log(prob))*count;
318 type =
dataset.ftype(wgn_predictee);
319 real = wgn_discretes[type].name(
dataset(p)->get_int_val(wgn_predictee));
321 if (wgn_opt_param ==
"B_NB_F1")
337 pairs.add_item(real,predict,1);
339 for (i=0; i<wgn_discretes[
dataset.ftype(wgn_predictee)].length(); i++)
340 lex.append(wgn_discretes[
dataset.ftype(wgn_predictee)].name(i));
347 *
output <<
";; entropy " << (-1*(H/
total)) <<
" perplexity " <<
353 if (wgn_opt_param ==
"entropy")
355 else if(wgn_opt_param ==
"B_NB_F1")
390 for (p=
dataset.head(); p != 0; p=p->next())
392 leaf = tree.predict_node((*
dataset(p)));
393 pos =
dataset(p)->get_int_val(wgn_predictee);
395 if (wgn_VertexFeats.
a(0,
j) > 0.0)
398 for (
pp=leaf->get_impurity().members.head();
pp != 0;
pp=
pp->next())
400 i = leaf->get_impurity().members.
item(
pp);
401 b += wgn_VertexTrack.
a(i,
j);
405 if (wgn_count_field == -1)
408 count =
dataset(p)->get_flt_val(wgn_count_field);
409 x.cumulate(predict,count);
417 se.cumulate((error*error),count);
418 e.cumulate(
fabs(error),count);
419 xx.cumulate(predict*predict,count);
431 double v1 =
xx.mean()-(
x.mean()*
x.mean());
432 double v2 =
yy.mean()-(
y.mean()*
y.mean());
446 <<
";; RMSE " << ftoString(
sqrt(
se.mean()),4,1)
447 <<
" Correlation is " << ftoString(
cor,4,1)
448 <<
" Mean (abs) Error " << ftoString(
e.mean(),4,1)
449 <<
" (" << ftoString(
e.stddev(),4,1) <<
")" <<
endl;
451 cout <<
"RMSE " << ftoString(
sqrt(
se.mean()),4,1)
452 <<
" Correlation is " << ftoString(
cor,4,1)
453 <<
" Mean (abs) Error " << ftoString(
e.mean(),4,1)
454 <<
" (" << ftoString(
e.stddev(),4,1) <<
")" <<
endl;
457 if (wgn_opt_param ==
"rmse")
479 for (p=
dataset.head(); p != 0; p=p->next())
481 leaf = tree.predict_node((*
dataset(p)));
482 pos =
dataset(p)->get_int_val(wgn_predictee);
484 if (wgn_VertexFeats.
a(0,
j) > 0.0)
487 for (
pp=leaf->get_impurity().members.head();
pp != 0;
pp=
pp->next())
489 i = leaf->get_impurity().members.
item(
pp);
490 b += wgn_VertexTrack.
a(i,
j);
494 if (wgn_count_field == -1)
497 count =
dataset(p)->get_flt_val(wgn_count_field);
498 x.cumulate(predict,count);
506 se.cumulate((error*error),count);
507 e.cumulate(
fabs(error),count);
508 xx.cumulate(predict*predict,count);
520 double v1 =
xx.mean()-(
x.mean()*
x.mean());
521 double v2 =
yy.mean()-(
y.mean()*
y.mean());
535 <<
";; RMSE " << ftoString(
sqrt(
se.mean()),4,1)
536 <<
" Correlation is " << ftoString(
cor,4,1)
537 <<
" Mean (abs) Error " << ftoString(
e.mean(),4,1)
538 <<
" (" << ftoString(
e.stddev(),4,1) <<
")" <<
endl;
540 cout <<
"RMSE " << ftoString(
sqrt(
se.mean()),4,1)
541 <<
" Correlation is " << ftoString(
cor,4,1)
542 <<
" Mean (abs) Error " << ftoString(
e.mean(),4,1)
543 <<
" (" << ftoString(
e.stddev(),4,1) <<
")" <<
endl;
546 if (wgn_opt_param ==
"rmse")
561 for (p=
dataset.head(); p != 0; p=p->next())
563 leaf = tree.predict_node((*
dataset(p)));
564 real =
dataset(p)->get_int_val(wgn_predictee);
565 meandist += leaf->get_impurity().cluster_distance(real);
567 ranking += leaf->get_impurity().cluster_ranking(real);
576 "%) mean ranking " <<
ranking.mean() <<
" mean distance "
580 "%) mean ranking " <<
ranking.mean() <<
" mean distance "
596 for (p=
dataset.head(); p != 0; p=p->next())
598 predict = tree.predict((*
dataset(p)));
599 real =
dataset(p)->get_flt_val(wgn_predictee);
600 if (wgn_count_field == -1)
603 count =
dataset(p)->get_flt_val(wgn_count_field);
604 x.cumulate(predict,count);
605 y.cumulate(real,count);
606 error = predict-real;
607 se.cumulate((error*error),count);
608 e.cumulate(
fabs(error),count);
609 xx.cumulate(predict*predict,count);
610 yy.cumulate(real*real,count);
611 xy.cumulate(predict*real,count);
620 double v1 =
xx.mean()-(
x.mean()*
x.mean());
621 double v2 =
yy.mean()-(
y.mean()*
y.mean());
635 <<
";; RMSE " << ftoString(
sqrt(
se.mean()),4,1)
636 <<
" Correlation is " << ftoString(
cor,4,1)
637 <<
" Mean (abs) Error " << ftoString(
e.mean(),4,1)
638 <<
" (" << ftoString(
e.stddev(),4,1) <<
")" <<
endl;
640 cout <<
"RMSE " << ftoString(
sqrt(
se.mean()),4,1)
641 <<
" Correlation is " << ftoString(
cor,4,1)
642 <<
" Mean (abs) Error " << ftoString(
e.mean(),4,1)
643 <<
" (" << ftoString(
e.stddev(),4,1) <<
")" <<
endl;
646 if (wgn_opt_param ==
"rmse")
662 for (p=
dataset.head(); p != 0; p=p->next())
664 leaf = tree.predict_node((*
dataset(p)));
667 real =
dataset(p)->get_flt_val(wgn_predictee);
668 if (wgn_count_field == -1)
671 count =
dataset(p)->get_flt_val(wgn_count_field);
672 x.cumulate(predict,count);
673 y.cumulate(real,count);
674 error = predict-real;
675 se.cumulate((error*error),count);
676 e.cumulate(
fabs(error),count);
677 xx.cumulate(predict*predict,count);
678 yy.cumulate(real*real,count);
679 xy.cumulate(predict*real,count);
688 double v1 =
xx.mean()-(
x.mean()*
x.mean());
689 double v2 =
yy.mean()-(
y.mean()*
y.mean());
703 <<
";; RMSE " << ftoString(
sqrt(
se.mean()),4,1)
704 <<
" Correlation is " << ftoString(
cor,4,1)
705 <<
" Mean (abs) Error " << ftoString(
e.mean(),4,1)
706 <<
" (" << ftoString(
e.stddev(),4,1) <<
")" <<
endl;
708 cout <<
"RMSE " << ftoString(
sqrt(
se.mean()),4,1)
709 <<
" Correlation is " << ftoString(
cor,4,1)
710 <<
" Mean (abs) Error " << ftoString(
e.mean(),4,1)
711 <<
" (" << ftoString(
e.stddev(),4,1) <<
")" <<
endl;
714 if (wgn_opt_param ==
"rmse")
727 if (wgn_max_questions < 1)
730 q = find_best_question(
node.get_data());
745 wgn_find_split(
q,
node.get_data(),l->get_data(),r->get_data());
746 node.set_subnodes(l,r);
747 node.set_question(
q);
751 for (i=0; i <
margin; i++)
768 for (i=0; i <
margin; i++)
770 cout <<
"stopped samples: " <<
node.samples() <<
" impurity: "
783 if (wgn_dropout_samples > 0.0)
786 for (
iy=
in=i=0; i <
ds.n(); i++)
787 if (
q.ask(*
ds(i)) == TRUE)
802 for (
iy=
in=i=0; i <
ds.n(); i++)
803 if (
q.ask(*
ds(i)) == TRUE)
810static float wgn_random_number(
float x)
813 return (((
float)
random())/RAND_MAX)*
x;
824 float*
scores =
new float[wgn_dataset.width()];
829 for (i=0;i < wgn_dataset.width(); i++)
835 for (i=0;i < wgn_dataset.width(); i++)
837 if ((wgn_dataset.ignore(i) == TRUE) ||
838 (i == wgn_predictee))
840 else if (wgn_random_number(1.0) < wgn_dropout_feats)
842 else if (wgn_dataset.ftype(i) == wndt_binary)
847 else if (wgn_dataset.ftype(i) == wndt_float)
851 else if (wgn_dataset.ftype(i) == wndt_ignore)
855 else if (
wgn_csubset && (wgn_dataset.ftype(i) >= wndt_class))
857 wagon_error(
"subset selection temporarily deleted");
861 else if (wgn_dataset.ftype(i) >= wndt_class)
864 for (i=0;i < wgn_dataset.width(); i++)
890 for (i=0;i < wgn_dataset.width(); i++)
892 if ((wgn_dataset.ignore(i) == TRUE) ||
893 (i == wgn_predictee))
895 else if (wgn_random_number(1.0) < wgn_dropout_feats)
897 else if (wgn_dataset.ftype(i) == wndt_binary)
902 else if (wgn_dataset.ftype(i) == wndt_float)
906 else if (wgn_dataset.ftype(i) == wndt_ignore)
910 else if (
wgn_csubset && (wgn_dataset.ftype(i) >= wndt_class))
912 wagon_error(
"subset selection temporarily deleted");
916 else if (wgn_dataset.ftype(i) >= wndt_class)
941 for (
cl=0;
cl < wgn_discretes[wgn_dataset.ftype(
feat)].length();
cl++)
969 ques.set_oper(wnop_is);
970 float *
scores =
new float[wgn_discretes[wgn_dataset.ftype(
feat)].length()];
973 for (
cl=0;
cl < wgn_discretes[wgn_dataset.ftype(
feat)].length();
cl++)
975 ques.set_operand(flocons(
cl));
982 if (siod_llength(order) == 1)
984 ques.set_oper(wnop_is);
985 ques.set_operand(car(order));
986 return scores[get_c_int(car(order))];
989 ques.set_oper(wnop_in);
991 for (l=cdr(order); CDR(l) != NIL; l = cdr(l))
1005 if (siod_llength(
best_l) == 1)
1007 ques.set_oper(wnop_is);
1010 else if (equal(cdr(order),
best_l) != NIL)
1012 ques.set_oper(wnop_is);
1013 ques.set_operand(car(order));
1017 cout <<
"Found a good subset" <<
endl;
1031 for (i=0; i < wgn_discretes[wgn_dataset.ftype(
feat)].length(); i++)
1033 if (
scores[i] != WGN_HUGE_VAL)
1036 items = cons(flocons(i),NIL);
1039 for (l=
items; l != NIL; l=cdr(l))
1043 CDR(l) = cons(car(l),cdr(l));
1044 CAR(l) = flocons(i);
1065 float max,min,val,
incr;
1068 test_q.set_oper(wnop_lessthan);
1071 min = max =
ds(0)->get_flt_val(
feat);
1072 for (d=0; d <
ds.n(); d++)
1074 val =
ds(d)->get_flt_val(
feat);
1081 return WGN_HUGE_VAL;
1082 incr = (max-min)/wgn_float_range_split;
1086 for (i=0,p=min+
incr; i < wgn_float_range_split; i++,p +=
incr )
1123 for (d=0; d <
ds.n(); d++)
1125 if (wgn_random_number(1.0) < wgn_dropout_samples)
1133 if (wgn_count_field == -1)
1136 count = (*wv)[wgn_count_field];
1138 if (
q.ask(*
wv) == TRUE)
1141 if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
1142 y.cumulate(d,count);
1144 y.cumulate((*
wv)[wgn_predictee],count);
1149 if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
1150 n.cumulate(d,count);
1152 n.cumulate((*
wv)[wgn_predictee],count);
1162 if ((wgn_balance == 0.0) ||
1163 (
ds.n()/wgn_balance < wgn_min_cluster_size))
1170 return WGN_HUGE_VAL;
1190 return score_question_set(
q,
ds,1);
1207 for (i=0; i < wgn_dataset.width(); i++)
1208 wgn_dataset.set_ignore(i,TRUE);
1210 for (i=0; i < wgn_dataset.width(); i++)
1212 if ((wgn_dataset.ftype(i) == wndt_ignore) || (i == wgn_predictee))
1233 wgn_dataset.set_ignore(
best_feat,FALSE);
1238 (
const char *)wgn_dataset.feat_name(
best_feat),
1258 for (i=0; i < wgn_dataset.width(); i++)
1260 if (wgn_dataset.ftype(i) == wndt_ignore)
1262 else if (i == wgn_predictee)
1264 else if (wgn_dataset.ignore(i) == TRUE)
1270 wgn_dataset.set_ignore(i,FALSE);
1272 current = wgn_build_tree(score);
1290 wgn_dataset.set_ignore(i,TRUE);
const EST_String & name(const int n) const
The name for the given index.
double stddev(void) const
standard deviation of currently cumulated values
double mean(void) const
mean of currently cumulated values
void reset(void)
reset internal values
T & item(const EST_Litem *p)
void resize(int n, int set=1)
float & a(int i, int c=0)
int num_channels() const
return number of channels in track