50#include "EST_SCFG_Chart.h"
51#include "EST_simplestats.h"
53#include "EST_TVector.h"
61#if defined(INSTANTIATE_TEMPLATES)
62#include "../base_class/EST_TVector.cc"
78void EST_bracketed_string::init()
87EST_bracketed_string::EST_bracketed_string()
92EST_bracketed_string::EST_bracketed_string(
LISP string)
96 set_bracketed_string(
string);
99EST_bracketed_string::~EST_bracketed_string()
105 for (i=0; i < p_length; i++)
106 delete [] valid_spans[i];
107 delete [] valid_spans;
110void EST_bracketed_string::set_bracketed_string(
LISP string)
116 p_length = find_num_nodes(
string);
117 symbols =
new LISP[p_length];
119 set_leaf_indices(
string,0,symbols);
124 valid_spans =
new int*[length()];
125 for (i=0; i < length(); i++)
127 valid_spans[i] =
new int[length()+1];
128 for (
j=i+1;
j <= length();
j++)
129 valid_spans[i][
j] = 0;
138int EST_bracketed_string::find_num_nodes(
LISP string)
143 else if (CONSP(
string))
144 return find_num_nodes(car(
string))+
145 find_num_nodes(cdr(
string));
150int EST_bracketed_string::set_leaf_indices(
LISP string,
int i,
LISP *
syms)
154 else if (!CONSP(car(
string)))
157 return set_leaf_indices(cdr(
string),i+1,
syms);
161 return set_leaf_indices(cdr(
string),
162 set_leaf_indices(car(
string),i,
syms),
167void EST_bracketed_string::find_valid(
int s,
LISP t)
const
174 for (c=s,l=t; l != NIL; l=cdr(l))
176 c += num_leafs(car(l));
177 valid_spans[s][c] = 1;
179 find_valid(s,car(t));
180 find_valid(s+num_leafs(car(t)),cdr(t));
184int EST_bracketed_string::num_leafs(
LISP t)
const
191 return num_leafs(car(t)) + num_leafs(cdr(t));
194EST_SCFG_traintest::EST_SCFG_traintest(
void) :
EST_SCFG()
202EST_SCFG_traintest::~EST_SCFG_traintest(
void)
209 set_corpus(corpus,vload(filename,1));
213double EST_SCFG_traintest::f_I_cal(
int c,
int p,
int i,
int k)
228 else if (corpus.
a_no_check(c).valid(i,k) == TRUE)
239 for (
j=i+1;
j < k;
j++)
241 double in = f_I(c,
q,i,
j);
251 inside[p][i][k] = res;
259double EST_SCFG_traintest::f_O_cal(
int c,
int p,
int i,
int k)
264 if ((i == 0) && (k == corpus.
a_no_check(c).length()))
266 if (p == distinguished_symbol())
271 else if (corpus.
a_no_check(c).valid(i,k) == TRUE)
288 double out = f_O(c,
q,
j,k);
290 s2 +=
out * f_I(c,r,
j,i);
299 double out = f_O(c,
q,i,
j);
312 outside[p][i][k] = res;
317void EST_SCFG_traintest::reestimate_rule_prob_B(
int c,
int ri,
int p,
int q,
int r)
327 for (i=0; i <= corpus.
a_no_check(c).length()-2; i++)
330 double d1 = f_I(c,
q,i,
j);
331 if (d1 == 0)
continue;
332 for (k=
j+1; k <= corpus.
a_no_check(c).length(); k++)
334 double d2 = f_I(c,r,
j,k);
335 if (d2 == 0)
continue;
336 double d3 = f_O(c,p,i,k);
337 if (
d3 == 0)
continue;
357void EST_SCFG_traintest::reestimate_rule_prob_U(
int c,
int ri,
int p,
int m)
369 for (i=1; i < corpus.
a_no_check(c).length(); i++)
377 d[
ri] += f_P(c,p) /
fP;
381double EST_SCFG_traintest::f_P(
int c)
383 return f_I(c,distinguished_symbol(),0,corpus.
a_no_check(c).length());
386double EST_SCFG_traintest::f_P(
int c,
int p)
391 for (i=0; i < corpus.
a_no_check(c).length(); i++)
394 double d1 = f_O(c,p,i,
j);
395 if (d1 == 0)
continue;
396 db += f_I(c,p,i,
j)*d1;
402void EST_SCFG_traintest::reestimate_grammar_probs(
int passes,
435 if (corpus.
a_no_check(c).length() == 0)
continue;
437 for (
ri=0,r=
rules.head(); r != 0; r=r->next(),
ri++)
439 if (
rules(r).type() == est_scfg_binary_rule)
440 reestimate_rule_prob_B(c,
ri,
442 rules(r).daughter1(),
443 rules(r).daughter2());
445 reestimate_rule_prob_U(c,
448 rules(r).daughter1());
450 lPc += safe_log(f_P(c));
456 for (
se=0.0,
ri=0,r=
rules.head(); r != 0; r=r->next(),
ri++)
464 printf(
"pass %d cross entropy %g RMSE %f %f %d\n",
494void EST_SCFG_traintest::init_io_cache(
int c,
int nt)
500 inside =
new double**[
nt];
501 outside =
new double**[
nt];
502 for (i=0; i <
nt; i++)
504 inside[i] =
new double*[
mc];
505 outside[i] =
new double*[
mc];
506 for (
j=0;
j <
mc;
j++)
508 inside[i][
j] =
new double[
mc];
509 outside[i][
j] =
new double[
mc];
510 for (k=0; k <
mc; k++)
512 inside[i][
j][k] = -1;
513 outside[i][
j][k] = -1;
519void EST_SCFG_traintest::clear_io_cache(
int c)
529 for (
j=0;
j <
mc;
j++)
531 delete [] inside[i][
j];
532 delete [] outside[i][
j];
535 delete [] outside[i];
545double EST_SCFG_traintest::cross_entropy()
550 for (c=0; c < corpus.
length(); c++)
570 for (i=0; i <
rules.length(); i++)
594 cout <<
"cross entropy " << -(
lPc/
mC) <<
" (" <<
failed <<
" failed out of " <<
void train_inout(int passes, int startpass, int checkpoint, int spread, const EST_String &outfile)
void load_corpus(const EST_String &filename)
int num_nonterminals() const
Number of nonterminals.
double prob_B(int p, int q, int r) const
The rule probability of given binary rule.
void set_rule_prob_cache()
(re-)set rule probability caches
SCFGRuleList rules
The rules themselves.
EST_write_status save(const EST_String &filename)
Save current grammar to named file.
EST_String terminal(int m) const
Convert terminal index to string form.
double prob_U(int p, int m) const
The rule probability of given unary rule.
void resize(int n, int set=1)
resize vector
void resize(int n, int set=1)
INLINE int length() const
number of items in vector.
void fill(const T &v)
Fill entire array with value <parameter>v</parameter>.
INLINE const T & a_no_check(int n) const
read-only const access operator: without bounds checking