// -*- C++ -*- // Copyright 2006-2007 Deutsches Forschungszentrum fuer Kuenstliche Intelligenz // or its licensors, as applicable. // // You may not use this file except under the terms of the accompanying license. // // Licensed under the Apache License, Version 2.0 (the "License"); you // may not use this file except in compliance with the License. You may // obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Project: // File: // Purpose: // Responsible: tmb // Reviewer: // Primary Repository: // Web Sites: www.iupr.org, www.dfki.de, www.ocropus.org #include #include #undef CHECK #ifdef GOOGLE_INTERNAL #include "nlp/fst/lib/fst-decl.h" #include "nlp/fst/lib/fst-inl.h" #include "nlp/fst/lib/fstlib-inl.h" #include "nlp/fst/lib/vector-fst-inl.h" #include "nlp/fst/lib/topsort-inl.h" #else #include "fst/lib/fst.h" #include "fst/lib/fstlib.h" #include "fst/lib/vector-fst.h" #include "fst/lib/topsort.h" #endif #undef CHECK #include "fstutil.h" #include "colib.h" #define EPSILON 0 namespace ocropus { using namespace fst; using namespace colib; #if 1 // Compute the best path through the given fst. 
double bestpath(nustring &result,floatarray &costs,intarray &ids,Fst &fst,bool copy_eps) { fst::Verify(fst); StdVectorFst shortest; ShortestPath(fst,&shortest,1); CHECK_ARG(shortest.NumStates() > 0); result.clear(); costs.clear(); ids.clear(); int currentState = shortest.Start(); for (int i=0; i < shortest.NumStates()-1; i++) { CHECK_ARG(shortest.NumArcs(currentState)==1); ArcIterator aiter(shortest, currentState); const StdArc &arc = aiter.Value(); if(arc.olabel!=EPSILON || copy_eps) { ids.push(arc.ilabel); result.push(nuchar(arc.olabel)); costs.push(arc.weight.Value()); } currentState = arc.nextstate; } // IM: added the final accept cost here return sum(costs) + fst.Final(shortest.NumStates() - 1).Value(); } // Compute the best path through the given fst. double bestpath(nustring &result,Fst &fst,bool copy_eps) { fst::Verify(fst); StdVectorFst shortest; ShortestPath(fst,&shortest,1); CHECK_ARG(shortest.NumStates() > 0); result.clear(); int currentState = shortest.Start(); double total = 0.0; for (int i=0; i < shortest.NumStates()-1; i++) { CHECK_ARG(shortest.NumArcs(currentState)==1); ArcIterator aiter(shortest, currentState); const StdArc &arc = aiter.Value(); if(arc.olabel!=EPSILON || copy_eps) { result.push(nuchar(arc.olabel)); total += arc.weight.Value(); } currentState = arc.nextstate; } return total; } // Compute the best path through the composition of the given fsts. double bestpath2(nustring &result,floatarray &costs,intarray &ids,StdVectorFst &fst,StdVectorFst &fst2,bool copy_eps) { ArcSort(&fst,StdOLabelCompare()); ArcSort(&fst2,StdILabelCompare()); ComposeFst composition(fst,fst2); return bestpath(result,costs,ids,composition,copy_eps); } // Compute the best path through the composition of the given fsts. 
double bestpath2(nustring &result,StdVectorFst &fst,StdVectorFst &fst2,bool copy_eps) { ArcSort(&fst,StdOLabelCompare()); ArcSort(&fst2,StdILabelCompare()); ComposeFst composition(fst,fst2); return bestpath(result,composition,copy_eps); } #else // FIXME debug these and get them working; TopSort and sequential readout is the preferred way void bestpath(nustring &result, floatarray &costs, intarray &ids,StdVectorFst &fst) { fst::Verify(fst); StdVectorFst shortest; ShortestPath(fst,&shortest,1); fst::TopSort(&fst); CHECK_ARG(shortest.NumStates() > 0); result.resize(shortest.NumStates()-1); fill(result,nuchar('*')); costs.resize(shortest.NumStates()-1); fill(costs,999999); ids.resize(shortest.NumStates()-1); fill(ids,-1); int i=0; for (StateIterator siter(fst); !siter.Done(); siter.Next(),i++) { StdArc::StateId state_id = siter.Value(); ArcIterator aiter(fst, state_id); const StdArc &arc = aiter.Value(); ids[i] = arc.ilabel; result[i] = nuchar(arc.olabel); costs[i] = arc.weight.Value(); } } void bestpath(nustring &result,StdVectorFst &fst) { fst::Verify(fst); StdVectorFst shortest; ShortestPath(fst,&shortest,1); fst::TopSort(&fst); CHECK_ARG(shortest.NumStates() > 0); result.resize(shortest.NumStates()-1); fill(result,nuchar('*')); int i = 0; for (StateIterator siter(fst); !siter.Done(); siter.Next(),i++) { StdArc::StateId state_id = siter.Value(); ArcIterator aiter(fst, state_id); const StdArc &arc = aiter.Value(); result[i] = nuchar(arc.olabel); } } #endif // Convert a string to an fst. 
StdVectorFst *as_fst(intarray &a,float cost,float skip_cost,float junk_cost) { autodel fst; fst = new StdVectorFst(); int start = fst->AddState(); fst->SetStart(start); int current = start; for(int i=0;iAddState(); check_valid_symbol(a[i]); fst->AddArc(current,StdArc(a[i],a[i],0.0,next)); if(skip_cost<1000) fst->AddArc(current,StdArc(EPSILON,a[i],skip_cost,next)); if(junk_cost<1000) fst->AddArc(current,StdArc(kSigmaLabel,EPSILON,junk_cost,current)); current = next; } fst->SetFinal(current,cost); Verify(*fst); return fst.move(); } // Convert a string to an fst. StdVectorFst *as_fst(const char *s,float cost,float skip_cost,float junk_cost) { intarray a; int n = strlen(s); for(int i=0;i str(as_fst(in)); autodel result(compose(*str,fst)); nustring out; floatarray costs; intarray ids; bestpath(out,costs,ids,*result,true); return sum(costs); } // Score a string against an fst. double score(StdVectorFst &fst,const char *s) { autodel str(as_fst(s)); autodel result(compose(*str,fst)); nustring out; floatarray costs; intarray ids; bestpath(out,costs,ids,*result,true); return sum(costs); } // Score a pair of strings against an fst. double score(intarray &out,StdVectorFst &fst,intarray &in) { autodel in_fst(as_fst(in)); autodel out_fst(as_fst(out)); autodel left_fst(compose(*in_fst,fst)); autodel result(compose(*left_fst,*out_fst)); nustring temp; floatarray costs; intarray ids; bestpath(temp,costs,ids,*result,true); return sum(costs); } // Score a pair of strings against an fst. double score(const char *out,StdVectorFst &fst,const char *in) { autodel in_fst(as_fst(in)); autodel out_fst(as_fst(out)); autodel left_fst(compose(*in_fst,fst)); autodel result(compose(*left_fst,*out_fst)); nustring temp; floatarray costs; intarray ids; bestpath(temp,costs,ids,*result,true); return sum(costs); } // Translate a string using an fst. 
double translate(intarray &out,StdVectorFst &fst,intarray &in) { autodel str(as_fst(in)); autodel result(compose(*str,fst)); nustring nout; floatarray costs; intarray ids; bestpath(nout,costs,ids,*result); out.clear(); for(int i=0;i str(as_fst(in)); autodel result(compose(*str,fst)); nustring nout; floatarray costs; intarray ids; bestpath(nout,costs,ids,*result); return nout.mallocUtf8Encode(); } // Reverse translate a string using an fst. double reverse_translate(intarray &out,StdVectorFst &fst,intarray &in) { autodel str(as_fst(in)); autodel result(compose(*str,fst)); nustring nout; floatarray costs; bestpath(nout,costs,out,*result); return sum(costs); } // Reverse translate a string using an fst. const char *reverse_translate(StdVectorFst &fst,const char *in) { autodel str(as_fst(in)); autodel result(compose(*str,fst)); nustring nout; floatarray costs; intarray ids; bestpath(nout,costs,ids,*result); return malloc_utf8_encode(ids); } // Sample from an fst. double sample(intarray &out,StdVectorFst &fst) { // FIXME throw "unimplemented"; } // Convenience compose function: perform eager composition after arc sorting. StdVectorFst *compose(StdVectorFst &a,StdVectorFst &b) { autodel result(new StdVectorFst()); ArcSort(&a,StdOLabelCompare()); ArcSort(&b,StdILabelCompare()); Compose(a,b,result.ptr()); return result.move(); } static StdVectorFst *fst_minimize(autodel > &composition,bool rmeps,bool det,bool min) { autodel > epsfree; if(rmeps) epsfree = new RmEpsilonFst(*composition); else epsfree = composition.move(); autodel > determinization; if(det) determinization = new DeterminizeFst(*epsfree); else determinization = epsfree.move(); autodel result(new StdVectorFst(*determinization)); Minimize(result.ptr()); return result.move(); } // Perform composition, then minimization. 
// `a` is arc-sorted by output label and `b` by input label (both modified
// in place); a ComposeFst over them is then handed to fst_minimize()
// together with the rmeps/det/min flags.  Returns a newly allocated FST.
StdVectorFst *compose(StdVectorFst &a,StdVectorFst &b,bool rmeps,bool det,bool min) {
    ArcSort(&a,StdOLabelCompare());
    ArcSort(&b,StdILabelCompare());
    autodel<Fst<StdArc> > composition;
    composition = new ComposeFst<StdArc>(a,b);
    return fst_minimize(composition,rmeps,det,min);
}

// Convenience determinization function that returns the result.
//
// Calls fst::Determinize(a, result) and returns the newly allocated result.
StdVectorFst *determinize(StdVectorFst &a) {
    autodel<StdVectorFst> result(new StdVectorFst());
    fst::Determinize(a,result.ptr());
    return result.move();
}

// Convenience difference function that returns the result.
//
// Calls fst::Difference(a, b, result) and returns the newly allocated
// result.  Unlike compose(), the arguments are not arc-sorted here.
StdVectorFst *difference(StdVectorFst &a,StdVectorFst &b) {
    autodel<StdVectorFst> result(new StdVectorFst());
    fst::Difference(a,b,result.ptr());
    return result.move();
}

// Convenience intersection function that returns the result.
//
// Calls fst::Intersect(a, b, result) and returns the newly allocated
// result.  Unlike compose(), the arguments are not arc-sorted here.
StdVectorFst *intersect(StdVectorFst &a,StdVectorFst &b) {
    autodel<StdVectorFst> result(new StdVectorFst());
    fst::Intersect(a,b,result.ptr());
    return result.move();
}

// Convenience reverse function that returns the result.
//
// Calls fst::Reverse(a, result) and returns the newly allocated result.
StdVectorFst *reverse(StdVectorFst &a) {
    autodel<StdVectorFst> result(new StdVectorFst());
    fst::Reverse(a,result.ptr());
    return result.move();
}

// Prune arcs between states.
// NOTE(review): the line below appears to be truncated -- text between a
// '<' and a following '>' seems to have been dropped in several places.
// The body of fst_prune_arcs breaks off inside its first loop header and
// resumes in fragments, and the code then continues inside a later routine
// whose signature is not visible.  Restore the missing text from version
// control before changing this code.
//
// fst_prune_arcs(result,fst,maxarcs,maxratio,keep_eps), visible parts only:
//   - wraps `fst` in an `Arcs` helper (declared elsewhere);
//   - requires `result` to have no states on entry (CHECK_ARG);
//   - tracks, per group of arcs, a destination (current_to), a running
//     count (current_count) and the group's first weight (current_top),
//     and asserts that weights within a group do not decrease.
//   How maxarcs, maxratio and keep_eps enter the decision is not visible.
//
// Symbol-table routine, visible parts only: creates a SymbolTable named
// "ASCII"; if `input` (resp. `output`) is set and `a` has no input (resp.
// output) symbol table, installs it.  It then registers a name for every
// code 0..255 in the selected table(s):
//   0 "EPSILON", 9 "TAB", 10 "NL", 13 "CR", 32 "SPACE";
//   the remaining codes up to 32, and codes 127..255, as the decimal value
//   followed by a dot (e.g. "7.");
//   33..126 as the character itself, except '"', which is named "''".
//
// The last '}' presumably closes namespace ocropus.
void fst_prune_arcs(StdVectorFst &result,StdVectorFst &fst,int maxarcs,float maxratio,bool keep_eps) { Arcs f(fst); CHECK_ARG(result.NumStates()==0); for(int i=0;icurrent_to); current_to = to; current_count = 0; current_top = weight; } ASSERT(weight>=current_top); bool above_threshold = (current_count table(new SymbolTable("ASCII")); if(input && !a.InputSymbols()) a.SetInputSymbols(table.ptr()); if(output && !a.OutputSymbols()) a.SetOutputSymbols(table.ptr()); char buf[100]; for(int i=0;i<=32;i++) { if(i==0) { strcpy(buf,"EPSILON"); } else if(i==9) { strcpy(buf,"TAB"); } else if(i==10) { strcpy(buf,"NL"); } else if(i==13) { strcpy(buf,"CR"); } else if(i==32) { strcpy(buf,"SPACE"); } else { sprintf(buf,"%d.",i); } if(input) a.InputSymbols()->AddSymbol(buf,i); if(output) a.OutputSymbols()->AddSymbol(buf,i); } for(int i=33;i<=126;i++) { if(i=='"') { strcpy(buf,"''"); } else { buf[0] = i; buf[1] = 0; } if(input) a.InputSymbols()->AddSymbol(buf,i); if(output) a.OutputSymbols()->AddSymbol(buf,i); } for(int i=127;i<256;i++) { sprintf(buf,"%d.",i); if(input) a.InputSymbols()->AddSymbol(buf,i); if(output) a.OutputSymbols()->AddSymbol(buf,i); } } }