View Javadoc

1   // FindBestDomainOrder.java, created Aug 21, 2004 1:17:30 AM by joewhaley
2   // Copyright (C) 2004 John Whaley <jwhaley@alum.mit.edu>
3   // Licensed under the terms of the GNU LGPL; see COPYING for details.
4   package net.sf.bddbddb;
5   
6   import java.util.ArrayList;
7   import java.util.Arrays;
8   import java.util.Collection;
9   import java.util.Collections;
10  import java.util.Comparator;
11  import java.util.Enumeration;
12  import java.util.HashMap;
13  import java.util.HashSet;
14  import java.util.Iterator;
15  import java.util.LinkedHashSet;
16  import java.util.LinkedList;
17  import java.util.List;
18  import java.util.Map;
19  import java.util.Set;
20  import java.util.SortedSet;
21  import java.util.TreeSet;
22  import java.io.BufferedWriter;
23  import java.io.File;
24  import java.io.FileOutputStream;
25  import java.io.FileWriter;
26  import java.io.IOException;
27  import java.io.PrintStream;
28  import java.net.URL;
29  import java.text.NumberFormat;
30  import java.text.SimpleDateFormat;
31  import jwutil.collections.FlattenedCollection;
32  import jwutil.collections.GenericMultiMap;
33  import jwutil.collections.MultiMap;
34  import jwutil.collections.Pair;
35  import jwutil.io.SystemProperties;
36  import jwutil.strings.Strings;
37  import jwutil.util.Assert;
38  import net.sf.bddbddb.order.AttribToDomainTranslator;
39  import net.sf.bddbddb.order.CandidateSampler;
40  import net.sf.bddbddb.order.ConstraintInfo;
41  import net.sf.bddbddb.order.Discretization;
42  import net.sf.bddbddb.order.EpisodeCollection;
43  import net.sf.bddbddb.order.MapBasedTranslator;
44  import net.sf.bddbddb.order.MyId3;
45  import net.sf.bddbddb.order.Order;
46  import net.sf.bddbddb.order.OrderConstraint;
47  import net.sf.bddbddb.order.OrderConstraintSet;
48  import net.sf.bddbddb.order.OrderTranslator;
49  import net.sf.bddbddb.order.Queue;
50  import net.sf.bddbddb.order.StackQueue;
51  import net.sf.bddbddb.order.TrialDataRepository;
52  import net.sf.bddbddb.order.TrialGuess;
53  import net.sf.bddbddb.order.TrialInfo;
54  import net.sf.bddbddb.order.TrialInstance;
55  import net.sf.bddbddb.order.TrialInstances;
56  import net.sf.bddbddb.order.TrialPrediction;
57  import net.sf.bddbddb.order.VarToAttribTranslator;
58  import net.sf.bddbddb.order.WekaInterface;
59  import net.sf.bddbddb.order.CandidateSampler.UncertaintySampler;
60  import net.sf.bddbddb.order.EpisodeCollection.Episode;
61  import net.sf.bddbddb.order.TrialDataRepository.TrialDataGroup;
62  import net.sf.bddbddb.order.WekaInterface.OrderAttribute;
63  import net.sf.bddbddb.order.WekaInterface.OrderInstance;
64  import net.sf.javabdd.BDD;
65  import net.sf.javabdd.BDDFactory;
66  import net.sf.javabdd.BDDVarSet;
67  import net.sf.javabdd.FindBestOrder;
68  import org.jdom.Document;
69  import org.jdom.Element;
70  import org.jdom.input.SAXBuilder;
71  import weka.classifiers.Classifier;
72  import weka.core.Instance;
73  import weka.core.Instances;
74  
75  /***
76   * FindBestDomainOrder
77   * 
78   * Design:
79   * 
80   * TrialInfo : order, cost
81   * EpisodeCollection : collection of TrialInfo, best time, worst time
82   * Constraint: a<b or axb or a_b
83   * Order : collection of ordering constraints
84   * ConstraintInfo : map from a single constraint to score/confidence
85   * OrderInfo : order, predicted score and confidence
86   * 
87   * Maps:
88   *  Relation -> ConstraintInfo collection
89   *   Rule -> ConstraintInfo collection
90   * EpisodeCollection -> ConstraintInfo collection
91   * 
92   * Algorithm to compute best order:
93   * - Combine and sort single constraints from relation, rule, trials so far.
94   *   Sort by score*confidence (?)
95   *   Combine and adjust opposite constraints (?)
96   *   Sort by difference between opposite constraints (?)
97   * - Do an A* search.
98   *   Keep track of the current score/confidence as we add constraints.
99   *   As we add new constraints, flag conflicting ones.
100  *   Predict final score by combining top n non-conflicting constraints (?)
101  *   If prediction is worse than current best score, return immediately.
102  * 
103  * @author John Whaley
104  * @version $Id: FindBestDomainOrder.java 645 2006-07-17 05:20:20Z joewhaley $
105  */
106 public class FindBestDomainOrder {
107 
108 
    /**
     * Verbosity of diagnostic output written to {@link #out}. Code in this file
     * prints additional information when TRACE exceeds 0, 1 or 2.
     */
    public static int TRACE = 2;

    /** Stream for diagnostic output; assigned by the constructors. */
    public static PrintStream out;
    // NOTE(review): out_t is not referenced in the visible part of this file.
    public static PrintStream out_t;

    /**
     * Date/time format {@code yyMMdd-HHmmss}.
     * NOTE(review): SimpleDateFormat is not thread-safe and this instance is shared
     * static state; confirm it is only used from a single thread.
     */
    public static final SimpleDateFormat dateFormat = new SimpleDateFormat("yyMMdd-HHmmss");

    /**
     * Link back to the solver. Only set when the solver handed to the constructor
     * is a {@link BDDSolver}; otherwise it stays {@code null}.
     */
    BDDSolver solver;

    /**
     * Collection of all EpisodeCollections that have been done so far, including
     * ones that have been loaded from disk.
     */
    Collection allTrials;

    /**
     * Repository of trial data from which classifier training instances are built.
     * Only created when the solver is a {@link BDDSolver}; may be {@code null}.
     */
    TrialDataRepository dataRepository;

    /**
     * Whether we should keep track of per-rule constraints, in addition to global
     * constraints. When set, {@code ConstraintInfoCollection.addTrials} also records
     * the constraints of the untranslated order.
     */
    public static boolean PER_RULE_CONSTRAINTS = true;

    // NOTE(review): not referenced in the visible part of this file; presumably
    // gates calls to dumpClassifierInfo - confirm.
    public static boolean DUMP_CLASSIFIER_INFO = true;

    
    /**
     * Per-constraint statistics gathered from all incorporated trials.
     */
    ConstraintInfoCollection constraintInfo;
142 
143     /***
144      * Construct a new empty FindBestDomainOrder object.
145      */
146     public FindBestDomainOrder(Solver s) {
147         constraintInfo = new ConstraintInfoCollection(s);
148         allTrials = new LinkedList();
149         if (s instanceof BDDSolver){
150             solver = (BDDSolver) s;
151             dataRepository = new TrialDataRepository(allTrials, solver);
152         }
153         out = solver.out;
154     }
155         
156 
157     /***
158      * Construct a new FindBestDomainOrder object with the given info.
159      */
160     FindBestDomainOrder(ConstraintInfoCollection c) {
161         constraintInfo = c;
162         allTrials = new LinkedList();
163         if (c.solver instanceof BDDSolver)
164             solver = (BDDSolver) c.solver;
165         out = solver.out;
166     }
167 
168     /***
169      * Load and incorporate trials from the given XML file.
170      * 
171      * @param filename  filename
172      */
173     void loadTrials(String filename) {
174         out.println("Trials filename=" + filename);
175         File file = new File(filename);
176         if (file.exists()) {
177             try {
178                 URL url = file.toURL();
179                 SAXBuilder builder = new SAXBuilder();
180                 Document doc = builder.build(url);
181                 XMLFactory f = new XMLFactory(solver);
182                 Element e = doc.getRootElement();
183                 List list = (List) f.fromXML(e);
184                 if (TRACE > 0) {
185                     out.println("Loaded " + list.size() + " trial collections from file.");
186                     if (TRACE > 2) {
187                         for (Iterator i = list.iterator(); i.hasNext();) {
188                             out.println("Loaded from file: " + i.next());
189                         }
190                     }
191                 }
192                 allTrials.addAll(list);
193             } catch (Exception e) {
194                 solver.err.println("Error occurred loading " + filename + ": " + e);
195                 e.printStackTrace();
196             }
197         }
198         incorporateTrials();
199     }
200 
201     /***
202      * Incorporate all of the trials in allTrials.
203      */
204     void incorporateTrials() {
205         for (Iterator i = allTrials.iterator(); i.hasNext();) {
206             EpisodeCollection tc = (EpisodeCollection) i.next();
207             constraintInfo.addTrials(tc);
208         }
209     }
210 
211     void incorporateTrial(Episode ep) {
212         constraintInfo.addTrials(ep.getEpisodeCollection());
213 
214         if (TRACE > 2)
215             dump();
216     }
217 
218     /***
219      * Dump the collected order info for rules and relations to standard output.
220      */
221     public void dump() {
222         SortedSet set = new TreeSet(new Comparator() {
223             public int compare(Object o1, Object o2) {
224                 ConstraintInfo info1 = (ConstraintInfo) ((Map.Entry) o1).getValue();
225                 ConstraintInfo info2 = (ConstraintInfo) ((Map.Entry) o2).getValue();
226                 return info1.compareTo(info2);
227             }
228         });
229         set.addAll(constraintInfo.infos.entrySet());
230         for (Iterator i = set.iterator(); i.hasNext();) {
231             Map.Entry e = (Map.Entry) i.next();
232             OrderConstraint ir = (OrderConstraint) e.getKey();
233             out.println("Order feature: " + ir);
234             ConstraintInfo info = (ConstraintInfo) e.getValue();
235             info.dump();
236         }
237     }
238 
239     public int getNumberOfTrials() {
240         int sum = 0;
241         for (Iterator i = allTrials.iterator(); i.hasNext();) {
242             EpisodeCollection ec = (EpisodeCollection) i.next();
243             sum += ec.getNumTrials();
244         }
245         return sum;
246     }
247     
248     /***
249      * Starts a new trial collection and returns it.
250      * 
251      * @param rule  inference rule of trial collection
252      * @param opNumber  operation number of trial collection
253      * @param timeStamp  time of trial collection
254      * @param newCollection  whether to always return a new collection
255      * @return new trial collection
256      */
257     public Episode getNewEpisode(BDDInferenceRule rule, int opNumber, long timeStamp, boolean newCollection) {
258         EpisodeCollection c = newCollection ? findEpisodeCollection(rule, opNumber) : null;
259 
260         if(c == null){
261            c = new EpisodeCollection(rule, opNumber);
262            allTrials.add(c);
263        }
264        return c.startNewEpisode(timeStamp);
265     }
266    
267     
268     public EpisodeCollection findEpisodeCollection(BDDInferenceRule rule, int opNumber){
269         for(Iterator it = allTrials.iterator(); it.hasNext(); ){
270             EpisodeCollection tc = (EpisodeCollection) it.next();
271             if(tc.getRule(solver) == rule && tc.getUpdateCount() == rule.updateCount && tc.getOpNumber() == opNumber){
272                 if(TRACE > 1) out.println("Found a tc for:  rule " + rule.id + " on update " + rule.updateCount + " and op " + opNumber);
273                 return tc;
274             }
275         }
276         return null;
277     }
278 
279     /***
280      * Calculated information about an order.  This consists of a score
281      * and an estimated information gain.
282      * 
283      * @author John Whaley
284      * @version $Id: FindBestDomainOrder.java 645 2006-07-17 05:20:20Z joewhaley $
285 
286      */
287     public static class OrderInfo implements Comparable {
288 
289         /***
290          * The order this information is about.
291          */
292         Order order;
293 
294         /***
295          * A measure of how good this order is.
296          */
297         double score;
298 
299         /***
300          * A measure of the expected information gain from running this order.
301          */
302         double infoGain;
303 
304         /***
305          * Construct a new OrderInfo.
306          * 
307          * @param o  order
308          * @param s  score
309          * @param c  info gain
310          */
311         public OrderInfo(Order o, double s, double c) {
312             this.order = o;
313             this.score = s;
314             this.infoGain = c;
315         }
316 
317         /***
318          * Construct a new OrderInfo that is a clone of another.
319          * 
320          * @param that  other OrderInfo to clone from
321          */
322         public OrderInfo(OrderInfo that) {
323             this.order = that.order;
324             this.score = that.score;
325             this.infoGain = that.infoGain;
326         }
327 
328         /* (non-Javadoc)
329          * @see java.lang.Object#toString()
330          */
331         public String toString() {
332             return order + ": score " + format(score) + " info gain " + format(infoGain);
333         }
334 
335         /* (non-Javadoc)
336          * @see java.lang.Comparable#compareTo(java.lang.Object)
337          */
338         public int compareTo(Object arg0) {
339             return compareTo((OrderInfo) arg0);
340         }
341 
342         /***
343          * Comparison operator for OrderInfo objects.  Score is most important, followed
344          * by info gain.  If both are equal, we compare the order lexigraphically.
345          * 
346          * @param that  OrderInfo to compare to
347          * @return  -1, 0, or 1 if this OrderInfo is less than, equal to, or greater than the other
348          */
349         public int compareTo(OrderInfo that) {
350             if (this == that) return 0;
351             int result = signum(that.score - this.score);
352             if (result == 0) {
353                 result = (int) signum(this.infoGain - that.infoGain);
354                 if (result == 0) {
355                     result = this.order.compareTo(that.order);
356                 }
357             }
358             return result;
359         }
360 
361         /***
362          * Returns this OrderInfo as an XML element.
363          * 
364          * @return XML element
365          */
366         public Element toXMLElement() {
367             Element dis = new Element("orderInfo");
368             dis.setAttribute("order", order.toString());
369             dis.setAttribute("score", Double.toString(score));
370             dis.setAttribute("infoGain", Double.toString(infoGain));
371             return dis;
372         }
373 
374         public static OrderInfo fromXMLElement(Element e, Map nameToVar) {
375             Order o = Order.parse(e.getAttributeValue("order"), nameToVar);
376             double s = Double.parseDouble(e.getAttributeValue("score"));
377             double c = Double.parseDouble(e.getAttributeValue("infoGain"));
378             return new OrderInfo(o, s, c);
379         }
380     }
381     
382     /***
383      * Generate all orders of a given list of variables.
384      * 
385      * @param vars  list of variables
386      * @return  list of all orders of those variables
387      */
388     public static List/*<Order>*/ generateAllOrders(List vars) {
389         if (vars.size() == 0) return null;
390         LinkedList result = new LinkedList();
391         if (vars.size() == 1) {
392             result.add(new Order(vars));
393             return result;
394         }
395         Object car = vars.get(0);
396         List recurse = generateAllOrders(vars.subList(1, vars.size()));
397         for (Iterator i = recurse.iterator(); i.hasNext();) {
398             Order order = (Order) i.next();
399             for (int j = 0; j <= order.size(); ++j) {
400                 Order myOrder = new Order(order);
401                 myOrder.add(j, car);
402                 result.add(myOrder);
403             }
404         }
405         for (Iterator i = recurse.iterator(); i.hasNext();) {
406             Order order = (Order) i.next();
407             for (int j = 0; j < order.size(); ++j) {
408                 Order myOrder = new Order(order);
409                 Object o = myOrder.get(j);
410                 List c = new LinkedList();
411                 c.add(car);
412                 if (o instanceof Collection) {
413                     c.addAll((Collection) o);
414                 } else {
415                     c.add(o);
416                 }
417                 myOrder.set(j, c);
418                 result.add(myOrder);
419             }
420         }
421         return result;
422     }
423 
424     transient static NumberFormat nf;
425 
426     /***
427      * Format a double in a nice way.
428      * 
429      * @param d  double
430      * @return string representation
431      */
432     public static String format(double d) {
433         if (nf == null) {
434             nf = NumberFormat.getNumberInstance();
435             //nf.setMinimumFractionDigits(3);
436             nf.setMaximumFractionDigits(3);
437         }
438         if (d == Double.MAX_VALUE) return "max";
439         return nf.format(d);
440     }
441     
442     public static String format(double d, int numFracDigits) {
443         if (nf == null) {
444             nf = NumberFormat.getNumberInstance();
445             //nf.setMinimumFractionDigits(3);
446             nf.setMaximumFractionDigits(numFracDigits);
447         }
448         if (d == Double.MAX_VALUE) return "max";
449         return nf.format(d);
450     }
451 
452     // Only present in JDK1.5
453     public static int signum(long d) {
454         if (d < 0) return -1;
455         if (d > 0) return 1;
456         return 0;
457     }
458 
459     // Only present in JDK1.5
460     public static int signum(double d) {
461         if (d < 0) return -1;
462         if (d > 0) return 1;
463         return 0;
464     }
465 
466     public static class ConstraintInfoCollection {
467 
468         Solver solver;
469 
470         /***
471          * Map from orders to their info.
472          */
473         Map/* <OrderConstraint,ConstraintInfo> */infos;
474 
475         public ConstraintInfoCollection(Solver s) {
476             this.solver = s;
477             this.infos = new HashMap();
478         }
479 
480         public ConstraintInfo getInfo(OrderConstraint c) {
481             return (ConstraintInfo) infos.get(c);
482         }
483 
484         public ConstraintInfo getOrCreateInfo(OrderConstraint c) {
485             ConstraintInfo ci = (ConstraintInfo) infos.get(c);
486             if (ci == null) infos.put(c, ci = new ConstraintInfo(c));
487             return ci;
488         }
489 
490         private void addTrials(EpisodeCollection tc, OrderTranslator trans) {
491             MultiMap c2Trials = new GenericMultiMap();
492 
493             for (Iterator i = tc.getTrials().iterator(); i.hasNext();) {
494                 TrialInfo ti = (TrialInfo) i.next();
495                 Order o = ti.order;
496                 if (ti.cost >= BDDInferenceRule.LONG_TIME)
497                     ((BDDSolver) solver).fbo.neverTryAgain(tc.getRule(solver), o);
498                 if (trans != null) o = trans.translate(o);
499                 Collection ocs = o.getConstraints();
500                 for (Iterator j = ocs.iterator(); j.hasNext();) {
501                     OrderConstraint oc = (OrderConstraint) j.next();
502                     c2Trials.add(oc, ti);
503                 }
504 
505             }
506 
507             for (Iterator i = c2Trials.keySet().iterator(); i.hasNext();) {
508                 OrderConstraint oc = (OrderConstraint) i.next();
509                 ConstraintInfo info = getOrCreateInfo(oc);
510                 info.registerTrials(c2Trials.getValues(oc));
511             }
512             
513         }
514 
515         public void addTrials(EpisodeCollection tc) {
516             InferenceRule ir = tc.getRule(solver);
517             OrderTranslator varToAttrib = new VarToAttribTranslator(ir);
518             addTrials(tc, varToAttrib);
519             if (PER_RULE_CONSTRAINTS) {
520                 addTrials(tc, null);
521             }
522             if (TRACE > 2) {
523                 out.println("Added trial collection: " + tc);
524             }
525         }
526 
527         public OrderInfo predict(Order o, OrderTranslator trans) {
528             if (TRACE > 2) out.println("Predicting order "+o);
529             if (trans != null) o = trans.translate(o);
530             if (TRACE > 2) out.println("Translated into order "+o);
531             double score = 0.;
532             int numTrialCollections = 0, numTrials = 0;
533             Collection cinfos = new LinkedList();
534             for (Iterator i = o.getConstraints().iterator(); i.hasNext();) {
535                 OrderConstraint c = (OrderConstraint) i.next();
536                 ConstraintInfo ci = getInfo(c);
537                 if (ci == null || ci.getNumberOfTrials() == 0) continue;
538                 cinfos.add(ci);
539                 score += ci.getWeightedMean();
540                 numTrialCollections++;
541                 numTrials += ci.getNumberOfTrials();
542             }
543             if (numTrialCollections == 0)
544                 score = 0.;
545             else
546                 score = score / numTrialCollections;
547             double infoGain = ConstraintInfo.getVariance(cinfos) / numTrials;
548             if (TRACE > 2) out.println("Prediction for "+o+": score "+format(score)+" infogain "+format(infoGain));
549             return new OrderInfo(o, score, infoGain);
550         }
551 
552     }
553 
554     public OrderInfo predict(Order o, OrderTranslator trans) {
555         return constraintInfo.predict(o, trans);
556     }
557 
558     /***
559      * Returns this FindBestDomainOrder as an XML element.
560      * 
561      * @return XML element
562      */
563     public Element toXMLElement() {
564         Element constraintInfoCollection = new Element("constraintInfoCollection");
565         for (Iterator i = constraintInfo.infos.entrySet().iterator(); i.hasNext();) {
566             Map.Entry e = (Map.Entry) i.next();
567             OrderConstraint oc = (OrderConstraint) e.getKey();
568             ConstraintInfo c = (ConstraintInfo) e.getValue();
569             Element constraintInfo = c.toXMLElement(solver);
570             constraintInfoCollection.addContent(constraintInfo);
571         }
572 
573         Element fbo = new Element("findBestOrder");
574         fbo.addContent(constraintInfoCollection);
575 
576         return fbo;
577     }
578 
579     /***
580      * Returns the set of all trials performed so far as an XML element.
581      * 
582      * @return XML element
583      */
584     public Element trialsToXMLElement() {
585         Element trialCollections = new Element("episodeCollections");
586         if (solver.inputFilename != null)
587             trialCollections.setAttribute("datalog", solver.inputFilename);
588         for (Iterator i = allTrials.iterator(); i.hasNext();) {
589             EpisodeCollection c = (EpisodeCollection) i.next();
590             trialCollections.addContent(c.toXMLElement());
591         }
592         return trialCollections;
593     }
594 
595     /***
596      */
597     public boolean hasOrdersToTry(List allVars, BDDInferenceRule ir) {
598         // TODO: improve this code.
599         int nTrials = getNumberOfTrials();
600         if (nTrials != ir.lastTrialNum) {
601             ir.lastTrialNum = nTrials;
602             TrialGuess g = this.tryNewGoodOrder(null, allVars, ir, -2, null, false);
603             return g != null;
604         } else {
605             return false;
606         }
607     }
608 
609     // Since JDK1.4 only.
610     public static final int compare(double d1, double d2) {
611         if (d1 < d2)
612             return -1; // Neither val is NaN, thisVal is smaller
613         if (d1 > d2)
614             return 1; // Neither val is NaN, thisVal is larger
615 
616         long thisBits = Double.doubleToLongBits(d1);
617         long anotherBits = Double.doubleToLongBits(d2);
618 
619         return (thisBits == anotherBits ? 0 : // Values are equal
620                 (thisBits < anotherBits ? -1 : // (-0.0, 0.0) or (!NaN, NaN)
621                         1)); // (0.0, -0.0) or (NaN, !NaN)
622     }
623 
624 
625 
626     public final static int WEIGHT_WINDOW_SIZE = Integer.MAX_VALUE;
627 
628     public final static double DECAY_FACTOR = -.1;
629 
630     public double computeWeight(int type, TrialInstances instances) {
631         int numTrials = 0;
632         double total = 0;
633         int losses = 0;
634         double weight = 1;
635         for (int i = instances.numInstances() - 1; i >= 0 && numTrials < WEIGHT_WINDOW_SIZE; --i) {
636             TrialInstance instance = (TrialInstance) instances.instance(i);
637             double trueCost = instance.getCost();
638 
639             TrialPrediction pred = instance.getTrialInfo().pred;
640             double predCost = pred.predictions[type][TrialPrediction.LOW];
641             double dev = pred.predictions[type][TrialPrediction.HIGH];
642 
643             if(predCost == -1) continue;
644             double trialWeight = Math.exp(DECAY_FACTOR * numTrials);
645             if (trueCost < predCost - dev || trueCost > predCost + dev) {
646                 losses += trialWeight;
647             }
648             total += trialWeight;
649             ++numTrials;
650         }
651         if (numTrials != 0) {
652             weight = 1 - losses / (double) total;
653         }
654         return weight;
655     }
656 
657     public void adjustWeights(TrialInstances vData, TrialInstances aData, TrialInstances dData) {
658         if (vData != null)
659             varClassWeight = computeWeight(TrialPrediction.VARIABLE, vData);
660         if (aData != null)
661             attrClassWeight = computeWeight(TrialPrediction.ATTRIBUTE, aData);
662         if (dData != null)
663             domClassWeight = computeWeight(TrialPrediction.DOMAIN, dData);
664     }
665 
666     public static int NUM_CV_FOLDS = 10;
667 
668      /***
669      * @param data
670      * @param cClassName
671      * @return Cross validation with number of folds as set by NUM_CV_FOLDS;
672      */
673     public double constFoldCV(Instances data, String cClassName) {
674         return WekaInterface.cvError(NUM_CV_FOLDS, data, cClassName);
675     }
676 
677 
    // NOTE(review): in the visible part of tryNewGoodOrder the DISCRETIZE* flags are only
    // read from commented-out code; the discretization calls there run unconditionally.
    public static boolean DISCRETIZE1 = true;
    public static boolean DISCRETIZE2 = true;
    public static boolean DISCRETIZE3 = true;
    // Classifier class names for the variable (1), attribute (2) and domain (3) data sets.
    // NOTE(review): in the visible code they only appear in commented-out CV calls.
    public static String CLASSIFIER1 = "net.sf.bddbddb.order.MyId3";
    public static String CLASSIFIER2 = "net.sf.bddbddb.order.MyId3";
    public static String CLASSIFIER3 = "net.sf.bddbddb.order.MyId3";
684 
685     
686     public void neverTryAgain(InferenceRule ir, Order o) {
687           if (true) {
688             if (TRACE > 2) out.println("For rule"+ir.id+", never trying order "+o+" again.");
689             neverAgain.add(ir, o);
690         }
691     }
692 
693     MultiMap neverAgain = new GenericMultiMap();
694 
    /** Weight of variable-level predictions; recomputed by adjustWeights. */
    public double varClassWeight = 1;
    /** Weight of attribute-level predictions; recomputed by adjustWeights. */
    public double attrClassWeight = 1;
    /** Weight of domain-level predictions; recomputed by adjustWeights. */
    public double domClassWeight = 1;
    /** Threshold passed to {@code threshold(...)} on the domain data group in tryNewGoodOrder. */
    public static int DOMAIN_THRESHOLD = 1000;
    // NOTE(review): NO_CLASS / NO_CLASS_SCORE are not referenced in the visible part of this
    // file; presumably sentinel values for "no class predicted" - confirm.
    public static int NO_CLASS = -1;
    public static int NO_CLASS_SCORE = -1;
    // NOTE(review): instance field named like a constant; not referenced in the visible code.
    public boolean PROBABILITY = false;
702     void dumpClassifierInfo(String name, Classifier c, Instances data) {
703         BufferedWriter w = null;
704         try {
705             w = new BufferedWriter(new FileWriter(name));
706             w.write("Classifier \"name\":\n");
707             w.write("Attributes: \n");
708             for(Enumeration e = data.enumerateAttributes(); e.hasMoreElements(); ){
709                 w.write(e.nextElement() +"\n");
710             }
711             w.write("\n");
712             w.write("Based on data from "+data.numInstances()+" instances:\n");
713             for (Enumeration e = data.enumerateInstances(); e.hasMoreElements(); ) {
714                 Instance i = (Instance) e.nextElement();
715 
716                 if (i instanceof TrialInstance) {
717                     TrialInstance ti = (TrialInstance) i;
718                     InferenceRule ir = ti.ti.getCollection().getRule(solver);
719                     w.write("    "+ti.ti.getCollection().name+" "+ti.getOrder());
720                     if (!ti.getOrder().equals(ti.ti.order))
721                         w.write(" ("+ti.ti.order+")");
722                     if (ti.isMaxTime()) {
723                         w.write(" MAX TIME\n");
724                     } else {
725                         w.write(" "+format(ti.getCost())+" ("+ti.ti.cost+" ms)\n");
726                     }
727                 } else {
728                     w.write("    "+i+"\n");
729                 }
730             }
731             w.write(c.toString());
732             w.write("\n");
733         } catch (IOException x) { 
734             solver.err.println("IO Exception occurred writing \""+name+"\": "+x);
735         } finally {
736             if (w != null) try { w.close(); } catch (IOException _) { }
737         }
738     }
739     
740     void dumpTrialGuessInfo(String name) {
741         BufferedWriter w = null;
742         try {
743             w = new BufferedWriter(new FileWriter(name, true));
744             w.write("Classifier \"name\":\n");
745             w.write("\n");
746         } catch (IOException x) { 
747             solver.err.println("IO Exception occurred writing \""+name+"\": "+x);
748         } finally {
749             if (w != null) try { w.close(); } catch (IOException _) { }
750         }
751     }
752 
753     
754     private void addTrial(InferenceRule rule, List variables, Episode ep, Order o, TrialPrediction prediction, long time, long timestamp) {
755         
756         TrialInfo info = new TrialInfo(o,prediction,time,ep, timestamp);
757       /*  tc.addTrial(o,guess.prediction, time);
758        */
759         ep.addTrial(info);
760         dataRepository.addTrial(rule, variables, info);
761     }
762     
    // NOTE(review): INITIAL_VAR_SET / INITIAL_ATTRIB_SET / INITIAL_DOM_SET are not referenced
    // in the visible part of this file; the names suggest initial set sizes for the variable,
    // attribute and domain levels - confirm against their users.
    public static int INITIAL_VAR_SET = 10;
    public static int INITIAL_ATTRIB_SET = 16;
    public static int INITIAL_DOM_SET = 10;
766     public TrialGuess tryNewGoodOrder(Episode ep, List allVars, InferenceRule ir, int opNum,
767             boolean returnBest) {
768         return tryNewGoodOrder(ep.getEpisodeCollection(), allVars, ir, opNum, null, returnBest);
769     }
770     
771     public TrialGuess tryNewGoodOrder(EpisodeCollection ec, List allVars, InferenceRule ir, int opNum,
772             Order chosenOne,
773             boolean returnBest) {
774 
775         out.println("Variables: " + allVars);
776         TrialDataGroup vDataGroup = this.dataRepository.getVariableDataGroup(ir, allVars);
777         TrialDataGroup aDataGroup = dataRepository.getAttribDataGroup(ir,allVars);
778         TrialDataGroup dDataGroup = dataRepository.getDomainDataGroup(ir,allVars);
779         
780         // Build instances based on the experimental data.
781         TrialInstances vData, aData, dData;
782         vData = vDataGroup.getTrialInstances();
783         aData = aDataGroup.getTrialInstances();
784         dData = dDataGroup.getTrialInstances();
785 /* 
786         TrialInstances vTest = dataRepository.buildVarInstances(ir, allVars);
787 
788         Assert._assert(vData.numInstances() == vTest.numInstances(),"vGot " + vData.numInstances() + " Wanted: " + vTest.numInstances());
789         TrialInstances aTest = dataRepository.buildAttribInstances(ir, allVars);
790   
791         Assert._assert(aData.numInstances() == aTest.numInstances(), "aGot: " + aData.numInstances() + " Wanted: " + aTest.numInstances());
792     
793         TrialInstances dTest =dataRepository.buildDomainInstances(ir, allVars);
794       
795         Assert._assert(dData.numInstances() == dTest.numInstances(), "dGot: " + dData.numInstances() + " Wanted: " + dTest.numInstances());
796         out.println(aData);
797         out.println(vData);
798         out.println(dData);
799 */
800         // Readjust the weights using an exponential decay factor.
801         adjustWeights(vData, aData, dData);
802         Discretization vDis = null, aDis = null, dDis = null;
803 
804         /*
805        // Discretize the experimental data.  null if there is no data.
806         if (DISCRETIZE1) vDis = vData.discretize(.5);
807         if (DISCRETIZE2) aDis = aData.discretize(.25);
808         if (DISCRETIZE3) dDis = dData.threshold(DOMAIN_THRESHOLD);
809  */
810         vDis = vDataGroup.discretize(.5);
811         aDis = aDataGroup.discretize(.25);
812         dDis = dDataGroup.threshold(DOMAIN_THRESHOLD);
813         // Calculate the accuracy of each classifier using cv folds.
814         long vCTime = System.currentTimeMillis();
815         double vConstCV = -1;//constFoldCV(vData, CLASSIFIER1);
816         vCTime = System.currentTimeMillis() - vCTime;
817 
818         long aCTime = System.currentTimeMillis();
819         double aConstCV = -1;//constFoldCV(aData, CLASSIFIER2);
820         aCTime = System.currentTimeMillis() - aCTime;
821         
822         long dCTime = System.currentTimeMillis();
823         double dConstCV = -1;//constFoldCV(dData, CLASSIFIER3);
824         dCTime = System.currentTimeMillis() - dCTime;
825         
826         long vLTime = System.currentTimeMillis();
827         double vLeaveCV = -1; //leaveOneOutCV(vData, CLASSIFIER1);
828         vLTime = System.currentTimeMillis() - vLTime;
829         
830         long aLTime = System.currentTimeMillis();
831         double aLeaveCV = -1; //leaveOneOutCV(aData, CLASSIFIER2);
832         aLTime = System.currentTimeMillis() - aLTime;
833         
834         long dLTime = System.currentTimeMillis();
835         double dLeaveCV = -1; //leaveOneOutCV(dData, CLASSIFIER3);
836         dLTime = System.currentTimeMillis() - dLTime;
837         
838         if (TRACE > 1) {
839             out.println(" Var data points: " + vData.numInstances());
840             //out.println(" Var Classifier " + NUM_CV_FOLDS + " fold CV Score: " + vConstCV + " took " + vCTime + " ms");
841            // out.println(" Var Classifier leave one out CV Score: " + vLeaveCV + " took " + vLTime + " ms");
842             out.println(" Var Classifier Weight: " + varClassWeight);
843             //out.println(" Var data points: "+vData);
844             out.println(" Attrib data points: " + aData.numInstances());
845            // out.println(" Attrib Classifier " + NUM_CV_FOLDS + " fold CV Score : " + aConstCV + " took " + aCTime + " ms");
846             //out.println(" Attrib Classifier leave one out CV Score: " + aLeaveCV + " took " + aLTime + " ms");
847             out.println(" Attrib Classifier Weight: " + attrClassWeight);
848             //out.println(" Attrib data points: "+aData);
849             out.println(" Domain data points: " + dData.numInstances());
850             //out.println(" Domain Classifier " + NUM_CV_FOLDS + " fold CV Score: " + dConstCV + " took " + dCTime + " ms");
851             //out.println(" Attrib Classifier leave one out CV Score: " + dLeaveCV + " took " + dLTime + " ms");
852             out.println(" Domain Classifier Weight: " + domClassWeight);
853             //out.println(" Domain data points: "+dData);
854 
855         }
856 
857         Classifier vClassifier = null, aClassifier = null, dClassifier = null;
858         // Build the classifiers.
859    /*    
860         if (vData.numInstances() > 0)
861             vClassifier = WekaInterface.buildClassifier(CLASSIFIER1, vData);
862         if (aData.numInstances() > 0)
863             aClassifier = WekaInterface.buildClassifier(CLASSIFIER2, aData);
864         if (dData.numInstances() > 0)
865             dClassifier = WekaInterface.buildClassifier(CLASSIFIER3, dData);
866  */
867         vClassifier = vDataGroup.classify();
868         aClassifier = aDataGroup.classify();
869         dClassifier = dDataGroup.classify();
870        
871         
872         if (DUMP_CLASSIFIER_INFO) {
873             String baseName = solver.getBaseName()+"_rule"+ir.id;
874             if (vClassifier != null)
875                 dumpClassifierInfo(baseName+"_vclassifier", vClassifier, vData);
876             if (aClassifier != null)
877                 dumpClassifierInfo(baseName+"_aclassifier", aClassifier, aData);
878             if (dClassifier != null)
879                 dumpClassifierInfo(baseName+"_dclassifier", dClassifier, dData);
880             try {
881                 out_t = new PrintStream(new FileOutputStream(baseName+"_trials"));
882             } catch (IOException x) {
883                 solver.err.println("Error while opening file: "+x);
884             }
885         } else {
886             out_t = null;
887         }
888         
889         if (TRACE > 2) {
890             out.println("Var classifier: " + vClassifier);
891             out.println("Attrib classifier: " + aClassifier);
892             out.println("Domain classifier: " + dClassifier);
893         }
894 
895        double [][] bucketmeans = getBucketMeans(vDis, aDis, dDis);
896 
897        Collection sel = null;
898        Collection candidates = null;
899         if(chosenOne == null){
900             Collection triedOrders = returnBest ? new LinkedList() : getTriedOrders((BDDInferenceRule) ir, opNum);
901             if(ec != null){
902                 triedOrders.addAll(ec.trials.keySet());
903               
904             }
905             Object object = generateCandidateSet( ir, allVars, bucketmeans, 
906                     vDataGroup, aDataGroup, dDataGroup,
907                     triedOrders, returnBest);
908             /*vClassifier,
909             aClassifier, dClassifier, vData,
910             aData, dData, vDis, aDis,
911             dDis,*/
912             if(object == null) return null;
913             else if(object instanceof Collection)
914                 candidates = (Collection) object;
915             else if(object instanceof TrialGuess)
916                 return (TrialGuess) object;
917         }else {
918             sel = Collections.singleton(chosenOne);
919         }
920         boolean force = (ec != null && ec.getNumTrials() < 2) ||
921             vData.numInstances() < INITIAL_VAR_SET ||
922             aData.numInstances() < INITIAL_ATTRIB_SET ||
923             dData.numInstances() < INITIAL_DOM_SET;
924         
925         if (!returnBest)
926             sel = selectOrder(candidates, vData, aData, dData, ir, force);
927         
928         if (sel == null || sel.isEmpty()) return null;
929         Order o_v = (Order) sel.iterator().next();
930         try {
931             OrderTranslator v2a = new VarToAttribTranslator(ir);
932             OrderTranslator a2d = AttribToDomainTranslator.INSTANCE;
933             double vClass = 0, aClass = 0, dClass = 0;
934             if (vClassifier != null) {
935                 OrderInstance vInst = OrderInstance.construct(o_v, vData);
936                 vClass = vClassifier.classifyInstance(vInst);
937             }
938             Order o_a = v2a.translate(o_v);
939             if (aClassifier != null) {
940                 OrderInstance aInst = OrderInstance.construct(o_a, aData);
941                 aClass = aClassifier.classifyInstance(aInst);
942             }
943             Order o_d = a2d.translate(o_a);
944             if (dClassifier != null) {
945                 OrderInstance dInst = OrderInstance.construct(o_d, dData);
946                 dClass = dClassifier.classifyInstance(dInst);
947             }
948             int vi = (int) vClass, ai = (int) aClass, di = (int) dClass;
949             double vScore = 0, aScore = 0, dScore = 0;
950             if (vi < bucketmeans[VMEAN_INDEX].length) vScore = bucketmeans[VMEAN_INDEX][vi];
951             if (ai < bucketmeans[AMEAN_INDEX].length) aScore = bucketmeans[AMEAN_INDEX][ai];
952             if (di < bucketmeans[DMEAN_INDEX].length) dScore = bucketmeans[DMEAN_INDEX][di];
953             double score = varClassWeight * vScore;
954             score += attrClassWeight * aScore;
955             score += domClassWeight * dScore;
956             return genGuess(o_v, score, vClass, aClass, dClass, vDis, aDis, dDis);
957         } catch (Exception x) {
958             x.printStackTrace();
959             Assert.UNREACHABLE(x.toString());
960             return null;
961         }
962     }
963     
964     public final static int VMEAN_INDEX = 0;
965     public final static int AMEAN_INDEX = 1;
966     public final static int DMEAN_INDEX = 2;
967     public double[][] getBucketMeans(Discretization vDis, Discretization aDis, Discretization dDis){
968         // Calculate the mean value of each of the discretized buckets.
969 
970         double[] vBucketMeans = new double[vDis == null ? 0 : vDis.buckets.length];
971         double[] aBucketMeans = new double[aDis == null ? 0 : aDis.buckets.length];
972         double[] dBucketMeans = new double[dDis == null ? 0 : dDis.buckets.length];
973         if(TRACE > 2) out.print("Var Bucket Means: ");
974         for (int i = 0; i < vBucketMeans.length; ++i) {
975             if (vDis.buckets[i].numInstances() == 0)
976                 vBucketMeans[i] = Double.MAX_VALUE;
977             else
978                 vBucketMeans[i] = vDis.buckets[i].meanOrMode(vDis.buckets[i].classIndex());
979             if(TRACE > 2) out.print(vBucketMeans[i] + " ");
980         }
981         if (TRACE > 2) {
982             out.println();
983             out.print("Attr Bucket Means: ");
984         }
985         for (int i = 0; i < aBucketMeans.length; ++i) {
986             if (aDis.buckets[i].numInstances() == 0)
987                 aBucketMeans[i] = Double.MAX_VALUE;
988             else
989                 aBucketMeans[i] = aDis.buckets[i].meanOrMode(aDis.buckets[i].classIndex());
990             if(TRACE > 2) out.print(aBucketMeans[i] + " ");
991         }
992         if (TRACE > 2) {
993             out.println();
994             out.print("Domain Bucket Means: ");
995         }
996         for (int i = 0; i < dBucketMeans.length; ++i) {
997             if (dDis.buckets[i].numInstances() == 0)
998                 dBucketMeans[i] = Double.MAX_VALUE;
999             else
1000                 dBucketMeans[i] = dDis.buckets[i].meanOrMode(dDis.buckets[i].classIndex());
1001             if(TRACE > 2) out.print(dBucketMeans[i] + " ");
1002         }
1003         if(TRACE > 2) out.println();
1004         double [][] means = new double[3][];
1005         
1006         means[VMEAN_INDEX] = vBucketMeans;
1007         means[AMEAN_INDEX] = aBucketMeans;
1008         means[DMEAN_INDEX] = dBucketMeans;
1009         
1010         return means;
1011     }
1012     
1013     public int getCombos(double[][] combos, int start, int numV, int numA, int numD,
1014             double vBuckets, double aBuckets, double dBuckets,
1015             double [][] means,
1016             double maxScore){
1017         double [] vBucketMeans = means[VMEAN_INDEX];
1018         double [] aBucketMeans = means[AMEAN_INDEX];
1019         double [] dBucketMeans = means[DMEAN_INDEX];
1020         
1021         int p = 0;
1022         for (int vi = start; vi < numV; ++vi) {
1023             for (int ai = start; ai < numA; ++ai) {
1024                 for (int di = 0; di < numD; ++di) { // don't do nulls for domain classifier.
1025                     double vScore, aScore, dScore;
1026                     double nullScore = 1;
1027                     if (vi == -1) vScore = nullScore;
1028                     else if (vi < vBuckets && vi < vBucketMeans.length) vScore = vBucketMeans[vi];
1029                     else vScore = maxScore;
1030                     if (ai == -1) aScore = nullScore;
1031                     else if (ai < aBuckets && ai < aBucketMeans.length) aScore = aBucketMeans[ai];
1032                     else aScore = maxScore;
1033                     if (di == -1) dScore = nullScore;
1034                     else if (di < dBuckets && di < dBucketMeans.length) dScore = dBucketMeans[di];
1035                     else dScore = maxScore;
1036                     double score = varClassWeight * vScore;
1037                     score += attrClassWeight * aScore;
1038                     score += domClassWeight * dScore;
1039                     double[] result = new double[] { score, vi==-1?Double.NaN:vi,
1040                                                             ai==-1?Double.NaN:ai,
1041                                                             di==-1?Double.NaN:di };
1042                     if (TRACE > 2) {
1043                         out.println("Score for v="+vi+" a="+ai+" d="+di+": "+format(score));
1044                     }
1045                     combos[p++] = result;
1046                 }
1047             }
1048         }
1049         Arrays.sort(combos, 0, p, new Comparator() {
1050             public int compare(Object arg0, Object arg1) {
1051                 double[] a = (double[]) arg0;
1052                 double[] b = (double[]) arg1;
1053                 return FindBestDomainOrder.compare(a[0], b[0]);
1054             }
1055         });
1056         
1057         return p;
1058     }
1059 
1060     
    /**
     * Walks classifier bucket combinations in score order (via getCombos) and
     * turns each into order constraint sets (via tryConstraints), producing
     * either a single guess or a set of candidate orders.
     *
     * Each round considers at most {@code end} buckets per classifier (starting
     * at 5), further limited to {@code buckets.length / 2 + 1} per classifier.
     * When a combination hits the current cutoff, the cutoff is doubled and the
     * enumeration restarts; combinations already expanded are skipped via
     * {@code done}.
     *
     * @param ir           rule whose variables are being ordered
     * @param allVars      variables to order
     * @param means        bucket means as produced by getBucketMeans
     * @param vDataGroup   variable-level trial data (discretization + classifier)
     * @param aDataGroup   attribute-level trial data
     * @param dDataGroup   domain-level trial data
     * @param triedOrders  orders to exclude because they were already tried
     * @param returnBest   if true, return the first guess found instead of
     *                     collecting candidates
     * @return when returnBest: a TrialGuess, or null if none could be made;
     *         otherwise the (insertion-ordered) Set of candidate Orders, which
     *         may hold fewer than CANDIDATE_SET_SIZE entries
     */
    public Object generateCandidateSet(InferenceRule ir, List allVars, double [][] means, /* double [] vBucketMeans,
            double[] aBucketMeans, double [] dBucketMeans, */TrialDataGroup vDataGroup,
            TrialDataGroup aDataGroup, TrialDataGroup dDataGroup, Collection triedOrders, boolean returnBest){
        
        double [] vBucketMeans = means[VMEAN_INDEX];
        double [] aBucketMeans = means[AMEAN_INDEX];
        double [] dBucketMeans = means[DMEAN_INDEX];
        
        // Build multi-map from attributes/domains to variables.
        MultiMap a2v, d2v;
        a2v = new GenericMultiMap();
        d2v = new GenericMultiMap();
        for (Iterator i = allVars.iterator(); i.hasNext();) {
            Variable v = (Variable) i.next();
            Attribute a = (Attribute) ir.getAttribute(v);
            if (a != null) {
                a2v.add(a, v);
                d2v.add(a.getDomain(), v);
            }
        }
        
        // candidates stays null when returnBest: the loop then runs until a
        // guess is returned or the cutoff exceeds every bucket count.
        Set candidates = null;
            // Null (unclassified) combinations are only offered when collecting
            // candidates, and only until the first one has been processed.
            boolean addNullValues = !returnBest;
            // Grab the best from the classifiers and try to build an optimal order.
            if (!returnBest) candidates = new LinkedHashSet();
            Collection never = neverAgain.getValues(ir);
            //MyId3 v = (MyId3) vClassifier, a = (MyId3) aClassifier, d = (MyId3) dClassifier;
            Discretization vDis = vDataGroup.getDiscretization();
            Discretization aDis = aDataGroup.getDiscretization();
            Discretization dDis = dDataGroup.getDiscretization();
            
            // Initial per-classifier bucket cutoff; doubled after every round.
            int end = 5;
            // Use only top half of buckets.
            int vBuckets = vDis == null ? 1 : vDis.buckets.length / 2 + 1;
            int aBuckets = aDis == null ? 1 : aDis.buckets.length / 2 + 1;
            int dBuckets = dDis == null ? 1 : dDis.buckets.length / 2 + 1;
            // Score getCombos assigns to indices outside the usable range: the
            // sum of the last bucket mean of each classifier.
            double max = (vBucketMeans.length != 0 ? vBucketMeans[vBucketMeans.length-1] : 0);
            max += (aBucketMeans.length != 0 ? aBucketMeans[aBucketMeans.length-1] : 0);
            max += (dBucketMeans.length != 0 ? dBucketMeans[dBucketMeans.length-1] : 0);
            // Marks (v,a,d) combinations already expanded, so wider rounds skip them.
            // The +1 leaves room for the extra "empty" index added below.
            boolean[][][] done = new boolean[vBuckets+1][aBuckets+1][dBuckets+1];
        outermost:
            while (candidates == null || candidates.size() < CANDIDATE_SET_SIZE) {
                int numV = Math.min(end, vBuckets);
                int numA = Math.min(end, aBuckets);
                int numD = Math.min(end, dBuckets);
                if (true && end > vBuckets && end > aBuckets && end > dBuckets) {
                    // Also include the "empty" classification for all of them.
                    numV++; numA++; numD++;
                }
                // Upper bound on the entries getCombos can write; with null
                // values the v and a ranges also include index -1.
                int maxNum = addNullValues ? ((numV+1)*(numA+1)*(numD+1)) : (numV*numA*numD);
                double[][] combos = new double[maxNum][];
                int start = addNullValues ? -1 : 0; 
                int p = getCombos(combos,start,numV, numA, numD, vBuckets, aBuckets, dBuckets,
                        means, max);
              
                for (int z = 0; z < p; ++z) {
                    double bestScore = combos[z][0];
                    double vClass = combos[z][1]; int vi = (int) vClass;
                    double aClass = combos[z][2]; int ai = (int) aClass;
                    double dClass = combos[z][3]; int di = (int) dClass;
                    // If one of them reaches the highest index, we need to break.
                    // (The outer loop then doubles the cutoff and retries.)
                    if (vi == numV-1 && end <= vBuckets ||
                        ai == numA-1 && end <= aBuckets ||
                        di == numD-1 && end <= dBuckets) {
                        if (TRACE > 1) out.println("reached end ("+vi+","+ai+","+di+"), trying again with a higher cutoff.");
                        break;
                    }
                    if (!Double.isNaN(vClass) && !Double.isNaN(aClass) && !Double.isNaN(dClass)) {
                        if (done[vi][ai][di]) continue;
                        done[vi][ai][di] = true;
                    } else {
                        // A null combination was processed; later rounds omit them.
                        addNullValues = false;
                    }
                    // The extra "empty" index is passed on as class -1.
                    if (vi == vBuckets) vClass = -1; // any
                    if (ai == aBuckets) aClass = -1;
                    if (di == dBuckets) dClass = -1;
                    if (out_t != null) out_t.println("v="+vClass+" a="+aClass+" d="+dClass+": "+format(bestScore));
                    Collection ocss = tryConstraints(vDataGroup, vClass, aDataGroup, aClass, dDataGroup, dClass, a2v, d2v);
                        //tryConstraints(v, vClass, vData, a, aClass, aData, d, dClass, dData, a2v, d2v);
                    if (ocss == null || ocss.isEmpty()) {
                        if (out_t != null) out_t.println("Constraints cannot be combined.");
                        continue;
                    }
                   // Collection triedOrders = getTriedOrders((BDDInferenceRule)ir);
                    for (Iterator i = ocss.iterator(); i.hasNext(); ) {
                        OrderConstraintSet ocs = (OrderConstraintSet) i.next();
                        if (out_t != null) out_t.println("Constraints: "+ocs);
                        if (returnBest) {
                            TrialGuess guess = genGuess(ocs, bestScore, allVars, bestScore, vClass, aClass, dClass,
                                vDis, aDis, dDis, triedOrders, /*tc,*/ never);
                            if (guess != null) {
                                if (TRACE > 1) out.println("Best Guess: "+guess);
                                return guess;
                            }
                        } else {
                            // Add these orders to the collection.
                            //genOrders(ocs, allVars, tc == null ? null : tc.trials.keySet(), never, candidates);
                            genOrders(ocs, allVars, triedOrders, never, candidates);
                            if (candidates.size() >= CANDIDATE_SET_SIZE) break outermost;
                        }
                    }
                }
                // Once the cutoff exceeds every bucket count, nothing new can be found.
                if (end > vBuckets && end > aBuckets && end > dBuckets) {
                    if (TRACE > 1) out.println("Reached end, no more possible guesses!");
                    // NOTE: the commented-out code below is an abandoned
                    // exhaustive-enumeration fallback; it refers to variables
                    // (sel, tc) that do not exist in this method.
                 /*   if (false) {
                        // TODO: we can do something better here!
                        OrderIterator i = new OrderIterator(allVars);
                        while (i.hasNext()) {
                            Order o_v = i.nextOrder();
                            if (tc != null && tc.contains(o_v)) continue;
                            if (never != null && never.contains(o_v)) continue;
                            if (TRACE > 1) out.println("Just trying "+o_v);
                            if (returnBest) {
                                sel = Collections.singleton(o_v);
                                break outermost;
                            } else {
                                // Add this order to the collection.
                                if (TRACE > 1) out.println("Adding to candidate set: "+o_v);
                                candidates.add(o_v);
                                if (candidates.size() >= CANDIDATE_SET_SIZE) break outermost;
                            }
                        }
                    }
                    */
                    if (returnBest) {
                        return null;
                    }
                    break outermost;
                }
                end *= 2;
                if (TRACE > 1) out.println("Cutoff is now "+end);
            }
        
        
        return candidates;
    }
1197     public static int CANDIDATE_SET_SIZE = Integer.parseInt(SystemProperties.getProperty("candidateset", "500"));
1198     public static int SAMPLE_SIZE = 1;
1199     public static double UNCERTAINTY_THRESHOLD = Double.parseDouble(SystemProperties.getProperty("uncertainty", ".25"));
1200     public static boolean WEIGHT_UNCERTAINTY_SAMPLE = false;
1201     public static double VCENT = .5, ACENT = .5, DCENT = 1;
1202     static CandidateSampler candidateSetSampler = new UncertaintySampler(SAMPLE_SIZE, UNCERTAINTY_THRESHOLD, VCENT, ACENT, DCENT);
1203     
1204     public Collection selectOrder(Collection orders,
1205             TrialInstances vData, TrialInstances aData, TrialInstances dData, InferenceRule ir, boolean force) {
1206         Assert._assert(orders != null); //catch error if happens
1207         if(orders.size() == 0){
1208             if(TRACE > 1) out.println("Size of candidate set is 0. No orders to select from");
1209             return null;
1210         }
1211         if (TRACE > 1) out.println("Selecting an order from a candidate set of "+orders.size()+" orders.");
1212         if (TRACE > 2) out.println("Orders: "+orders);
1213        
1214         return candidateSetSampler.sample(orders, vData, aData, dData, ir, force);  
1215     }
1216     
1217 
1218     
1219     /***
1220      * Returns all the orders that have been tried on a particular rule update
1221      * (including those in previous runs).
1222      * 
1223      * @param rule
1224      * @return  collection of tried orders
1225      */
1226     Collection getTriedOrders(BDDInferenceRule rule, int opNumber){
1227         Collection triedOrders = new LinkedList();
1228       for(Iterator it = allTrials.iterator(); it.hasNext(); ){
1229           EpisodeCollection ec = (EpisodeCollection) it.next();
1230           if(ec.getRule(solver) == rule && ec.getUpdateCount() == rule.updateCount && ec.getOpNumber() == opNumber)
1231               triedOrders.addAll(ec.trials.keySet());
1232       }
1233       if(TRACE > 2) out.println("Tried Orders: " + triedOrders);
1234       return triedOrders;  
1235     }
1236     static void genOrders(OrderConstraintSet ocs, List allVars, Collection already, Collection never, Collection result) {
1237         if (out_t != null) out_t.println("Generating orders for "+allVars);
1238         List orders;
1239         int nOrders = ocs.approxNumOrders(allVars.size());
1240         if (nOrders > CANDIDATE_SET_SIZE*20) {
1241             if (out_t != null) out_t.println("Too many possible orders ("+nOrders+")!  Using random sampling.");
1242             orders = new LinkedList();
1243             for (int i = 0; i < CANDIDATE_SET_SIZE; ++i) {
1244                 orders.add(ocs.generateRandomOrder(allVars));
1245             }
1246         } else {
1247             if (out_t != null) out_t.println("Estimated "+nOrders+" orders.");
1248             orders = ocs.generateAllOrders(allVars);
1249         }
1250         for (Iterator m = orders.iterator(); m.hasNext(); ) {
1251             Order best = (Order) m.next();
1252             if (never.contains(best)) {
1253                 if (out_t != null) out_t.println("Skipped order "+best+" because it has blown up before.");
1254                 continue;
1255             }
1256             if (already == null || !already.contains(best)) {
1257                 if (out_t != null) out_t.println("Adding to candidate set: "+best);
1258                 result.add(best);
1259                 if (result.size() > CANDIDATE_SET_SIZE) {
1260                     if (out_t != null) out_t.println("Candidate set full.");
1261                     return;
1262                 }
1263             } else {
1264                 if (out_t != null) out_t.println("We have already tried order "+best);
1265             }
1266         }
1267     }
1268      
1269     Collection /*Pair*/ genConstaints( int num, MultiMap a2v, MultiMap d2v,InferenceRule ir,TrialDataGroup vDataGroup,
1270             TrialDataGroup aDataGroup,  TrialDataGroup dDataGroup){
1271         Discretization  vDis = vDataGroup.getDiscretization();
1272         Discretization aDis = aDataGroup.getDiscretization();
1273         Discretization dDis = dDataGroup.getDiscretization();
1274  
1275         double [][] means = getBucketMeans(vDis,aDis,dDis);
1276         int numVBuckets = means[VMEAN_INDEX].length;
1277         int numABuckets = means[AMEAN_INDEX].length;
1278         int numDBuckets = means[DMEAN_INDEX].length;
1279         int maxNum = numVBuckets * numABuckets * numDBuckets;
1280         double [][] combos = new double[maxNum][];
1281         double [] vBucketMeans = means[VMEAN_INDEX]; 
1282         double [] aBucketMeans = means[AMEAN_INDEX];
1283         double [] dBucketMeans = means[DMEAN_INDEX];
1284         double maxScore = (vBucketMeans.length != 0 ? vBucketMeans[vBucketMeans.length-1] : 0);
1285         maxScore += (aBucketMeans.length != 0 ? aBucketMeans[aBucketMeans.length-1] : 0);
1286         maxScore += (dBucketMeans.length != 0 ? dBucketMeans[dBucketMeans.length-1] : 0);
1287         int numCombos = getCombos(combos, 0, numVBuckets, numABuckets, numDBuckets,
1288                           numVBuckets, numABuckets, numDBuckets, means, maxScore);
1289         
1290         Collection allPairs = new LinkedList();
1291         int numAdded = 0;
1292         for(int i = 0; i < numCombos && numAdded < num; ++i){
1293            double [] combo = combos[i];
1294            Collection constraints = tryConstraints(vDataGroup, combo[1], aDataGroup, combo[2], dDataGroup,combo[3],a2v,d2v);
1295            if(constraints == null) continue;
1296            for(Iterator jt = constraints.iterator(); jt.hasNext(); ){
1297                allPairs.add(new Pair(new Double(combo[0]) , jt.next()));
1298                ++numAdded;
1299            }
1300           
1301         }
1302         
1303         return allPairs;
1304     }
1305   
1306     static TrialGuess genGuess(Order best, double score,
1307             double vClass, double aClass, double dClass,
1308             Discretization vDis, Discretization aDis, Discretization dDis) {
1309         double vLowerBound, vUpperBound, aLowerBound, aUpperBound, dLowerBound, dUpperBound;
1310         vLowerBound = vUpperBound = aLowerBound = aUpperBound = dLowerBound = dUpperBound = -1;
1311 
1312         if (vDis != null && !Double.isNaN(vClass) && vClass != NO_CLASS) {
1313             vLowerBound = vDis.cutPoints == null || vClass <= 0 ? 0 : vDis.cutPoints[(int) vClass - 1];
1314             vUpperBound = vDis.cutPoints == null || vClass == vDis.cutPoints.length ? Double.MAX_VALUE : vDis.cutPoints[(int) vClass];
1315         }
1316         if (aDis != null && !Double.isNaN(aClass) && aClass != NO_CLASS) {
1317             aLowerBound = aDis.cutPoints == null || aClass <= 0 ? 0 : aDis.cutPoints[(int) aClass - 1];
1318             aUpperBound = aDis.cutPoints == null || aClass == aDis.cutPoints.length ? Double.MAX_VALUE : aDis.cutPoints[(int) aClass];
1319         }
1320         if (dDis != null && !Double.isNaN(dClass) && dClass != NO_CLASS) {
1321             dLowerBound = dDis.cutPoints == null || dClass <= 0 ? 0 : dDis.cutPoints[(int) dClass - 1];
1322             dUpperBound = dDis.cutPoints != null || dClass == dDis.cutPoints.length ? Double.MAX_VALUE : dDis.cutPoints[(int) dClass];
1323         }
1324         TrialPrediction prediction = new TrialPrediction(score, vLowerBound,vUpperBound,aLowerBound, aUpperBound,dLowerBound,dUpperBound);
1325         return new TrialGuess(best, prediction);
1326     }
1327     
1328     static TrialGuess genGuess(OrderConstraintSet ocs, double score, List allVars, double bestScore,
1329         double vClass, double aClass, double dClass,
1330         Discretization vDis, Discretization aDis, Discretization dDis,
1331         /* EpisodeCollection tc,*/ Collection triedOrders, Collection never) {
1332         if (out_t != null) out_t.println("Generating orders for "+allVars);
1333         // Choose a random one first.
1334         Order best = ocs.generateRandomOrder(allVars);
1335         Iterator m = Collections.singleton(best).iterator();
1336         boolean exhaustive = true;
1337         while (m.hasNext()) {
1338             best = (Order) m.next();
1339             if (never.contains(best)) {
1340                 if (out_t != null) out_t.println("Skipped order "+best+" because it has blown up before.");
1341                 continue;
1342             }   
1343                 
1344             //if (tc == null || !tc.contains(best)) {
1345             if(!triedOrders.contains(best)){
1346                 if (out_t != null) out_t.println("Using order "+best);
1347                 return genGuess(best, score, vClass, aClass, dClass, vDis, aDis, dDis);
1348             } else {
1349                 if (out_t != null) out.println("We have already tried order "+best);
1350             }
1351             if (exhaustive) {
1352                 List orders = ocs.generateAllOrders(allVars);
1353                 m = orders.iterator();
1354                 exhaustive = false;
1355             }
1356        }
1357         return null;
1358     }
1359 
    /**
     * Builds candidate order-constraint sets by layering the attribute
     * combinations suggested by three classifiers: one trained on variable
     * data, one on attribute data and one on domain data.
     * <p>
     * For every variable-level combination, every attribute-level combination
     * is applied on top of it, and then every domain-level combination; each
     * combination that applies without conflict contributes one
     * OrderConstraintSet to the result.  A level whose class value is negative
     * (and not NaN) or whose classifier is null contributes a single empty
     * constraint vector, i.e. adds no constraints.
     *
     * @param vDataGroup  trial instances and classifier for the variable level
     * @param vClass      class value queried in the variable classifier
     * @param aDataGroup  trial instances and classifier for the attribute level
     * @param aClass      class value queried in the attribute classifier
     * @param dDataGroup  trial instances and classifier for the domain level
     * @param dClass      class value queried in the domain classifier
     * @param a2v  translates attribute-level constraint endpoints into variables
     *             (passed as the map to constrainOrder)
     * @param d2v  translates domain-level constraint endpoints into variables
     * @return the consistent constraint sets found, or null if the variable
     *         classifier returned no combination collection at all
     */
    static Collection/*OrderConstraintSet*/ tryConstraints(
            /* MyId3 v, double vClass, Instances vData,
            MyId3 a, double aClass, Instances aData,
            MyId3 d, double dClass, Instances dData,
            */
            TrialDataGroup vDataGroup, double vClass,
            TrialDataGroup aDataGroup, double aClass,
            TrialDataGroup dDataGroup, double dClass,
            MultiMap a2v, MultiMap d2v) {
        Collection results = new LinkedList();
        Instances vData = vDataGroup.getTrialInstances();
        Instances aData = aDataGroup.getTrialInstances();
        Instances dData = dDataGroup.getTrialInstances();
        MyId3 v = (MyId3) vDataGroup.getClassifier();
        MyId3 a = (MyId3) aDataGroup.getClassifier();
        MyId3 d = (MyId3) dDataGroup.getClassifier();
        Collection vBestAttribs;
        if ((vClass >= 0 || Double.isNaN(vClass)) && v != null)
            vBestAttribs = v.getAttribCombos(vData.numAttributes(), vClass);
        else
            vBestAttribs = Collections.singleton(makeEmptyConstraint());
        // Unlike the attribute and domain levels below, a null here aborts the whole search.
        if (vBestAttribs == null) return null;
        for (Iterator v_i = vBestAttribs.iterator(); v_i.hasNext(); ) {
            double[] v_c = (double[]) v_i.next();
            // Each variable-level combination starts from an empty constraint set.
            OrderConstraintSet ocs = new OrderConstraintSet();
            boolean v_r = constrainOrder(ocs, v_c, vData, null);
            if (!v_r) {
                continue;
            }
            if (out_t != null) out_t.println(" Order constraints (var="+(int)vClass+"): "+ocs);

            Collection aBestAttribs;
            if ((aClass >= 0 || Double.isNaN(aClass)) && a != null)
                aBestAttribs = a.getAttribCombos(aData.numAttributes(), aClass);
            else
                aBestAttribs = Collections.singleton(makeEmptyConstraint());
            if (aBestAttribs == null) continue;
            for (Iterator a_i = aBestAttribs.iterator(); a_i.hasNext(); ) {
                double[] a_c = (double[]) a_i.next();
                // Snapshot the pre-attribute state so the next alternative can start
                // from it; the last alternative needs no snapshot (ocsBackup stays null,
                // which is harmless because the loop ends right after it).
                OrderConstraintSet ocsBackup = null;
                if (a_i.hasNext()) ocsBackup = ocs.copy();
                boolean a_r = constrainOrder(ocs, a_c, aData, a2v);
                if (!a_r) {
                    // constrainOrder may have partially modified ocs; restore the snapshot.
                    ocs = ocsBackup;
                    continue;
                }
                if (out_t != null) out_t.println("  Order constraints (attrib="+(int)aClass+"): "+ocs);

                Collection dBestAttribs;
                if ((dClass >= 0 || Double.isNaN(dClass)) && d != null)
                    dBestAttribs = d.getAttribCombos(dData.numAttributes(), dClass);
                else
                    dBestAttribs = Collections.singleton(makeEmptyConstraint());
                if (dBestAttribs != null) {
                    for (Iterator d_i = dBestAttribs.iterator(); d_i.hasNext(); ) {
                        double[] d_c = (double[]) d_i.next();
                        // Same snapshot scheme as the attribute level.
                        OrderConstraintSet ocsBackup2 = null;
                        if (d_i.hasNext()) ocsBackup2 = ocs.copy();
                        boolean d_r = constrainOrder(ocs, d_c, dData, d2v);
                        if (d_r) {
                            if (out_t != null) out_t.println("   Order constraints (domain="+(int)dClass+"): "+ocs);
                            // ocs is replaced below, so the object stored here is never mutated again.
                            results.add(ocs);
                        }
                        ocs = ocsBackup2;
                    }
                }
                // Return to the state before this attribute combination was applied.
                ocs = ocsBackup;
            }
        }
        return results;
    }
1431 
1432     static double computeScore(int vC, int aC, int dC,
1433         double[] vMeans, double[] aMeans, double[] dMeans,
1434         double vWeight, double aWeight, double dWeight) {
1435         double score = vMeans[vC] * vWeight;
1436         score += aMeans[aC] * aWeight;
1437         score += dMeans[dC] * dWeight;
1438         return score;
1439     }
1440 
1441     static double[] makeEmptyConstraint() {
1442         int size = 0;
1443         double[] d = new double[size];
1444         for (int i = 0; i < d.length; ++i) {
1445             d[i] = Double.NaN;
1446         }
1447         return d;
1448     }
1449     
1450     static boolean constrainOrder(OrderConstraintSet ocs, double[] c, Instances data, MultiMap map) {
1451         for (int iii = 0; iii < c.length; ++iii) {
1452             if (Double.isNaN(c[iii])) continue;
1453             int k = (int) c[iii];
1454             OrderAttribute oa = (OrderAttribute) data.attribute(iii);
1455             OrderConstraint oc = oa.getConstraint(k);
1456             if (map != null) {
1457                 Collection c1 = map.getValues(oc.getFirst());
1458                 Collection c2 = map.getValues(oc.getSecond());
1459                 boolean any = false;
1460                 for (Iterator ii = c1.iterator(); ii.hasNext();) {
1461                     Object a = ii.next();
1462                     for (Iterator jj = c2.iterator(); jj.hasNext();) {
1463                         Object b = jj.next();
1464                         if(a.equals(b)) continue;
1465                         OrderConstraint cc = OrderConstraint.makeConstraint(oc.getType(), a, b);
1466                         boolean r = ocs.constrain(cc);
1467                         if (r) {
1468                             any = true;
1469                         }
1470                     }
1471                 }
1472                 if (!any) {
1473                     if (TRACE > 3) out.println("Constraint "+oc+" conflicts with "+ocs);
1474                     return false;
1475                 }
1476             } else {
1477                 boolean r = ocs.constrain(oc);
1478                 if (!r) {
1479                     if (TRACE > 3) out.println("Constraint "+oc+" conflicts with "+ocs);
1480                     return false;
1481                 }
1482             }
1483         }
1484         return true;
1485     }
1486 
1487     void printGoodOrder(Collection allVars, Instances inst, MyId3 v) {
1488         Collection vBestAttribs = v.getAttribCombos(inst.numAttributes(), 0.);
1489         if (vBestAttribs != null) {
1490             outer:
1491                 for (Iterator ii = vBestAttribs.iterator(); ii.hasNext(); ) {
1492                 double[] c = (double[]) ii.next();
1493                 OrderConstraintSet ocs = new OrderConstraintSet();
1494                 for (int iii = 0; iii < c.length; ++iii) {
1495                     if (Double.isNaN(c[iii])) continue;
1496                     int k = (int) c[iii];
1497                     OrderAttribute oa = (OrderAttribute) inst.attribute(iii);
1498                     out.println(oa);
1499                     OrderConstraint oc = oa.getConstraint(k);
1500                     out.println(oc);
1501                     boolean r = ocs.constrain(oc);
1502                     if (!r) {
1503                         if (TRACE > 1) out.println("Constraint "+oc+" conflicts with "+ocs);
1504                         continue outer;
1505                     }
1506                 }
1507                 Order o = ocs.generateRandomOrder(allVars);
1508                 out.println("Good order: " + o);
1509             }
1510         }
1511     }
1512 
1513     TrialGuess evalOrder(Order o, InferenceRule ir) {
1514         List allVars = o.getFlattened();
1515         return tryNewGoodOrder(null, allVars, ir, -2,  o, false);
1516     }
1517     
1518     public static class OrderSearchElem implements Comparable{
1519         public double pathScore;
1520         public double pathCost;
1521 
1522         public OrderConstraintSet ocs;
1523         public int nextRule;
1524         public Collection rulesLeft;
1525         public OrderSearchElem(OrderSearchElem that){
1526             this.pathScore = that.pathScore;
1527             this.ocs = new OrderConstraintSet(that.ocs);
1528             this.nextRule = that.nextRule;
1529             this.pathCost = that.pathCost;
1530         }
1531         public OrderSearchElem(double score, double cost, OrderConstraintSet ocs, int nextRule){
1532             this.pathScore = score;
1533             this.pathCost = cost;
1534             this.ocs = ocs;
1535             this.nextRule= nextRule;
1536         }
1537         
1538         public String toString(){
1539             return "[" + pathScore + ", " + ocs + "]";
1540         }
1541         /* (non-Javadoc)
1542          * @see java.lang.Comparable#compareTo(java.lang.Object)
1543          */
1544         public int compareTo(Object arg0) {
1545             OrderSearchElem that = (OrderSearchElem) arg0;
1546             return Double.compare(this.pathScore, that.pathScore);
1547         }
1548     }
1549         
1550     void cache(int ruleNum, BDDInferenceRule rule, OrderConstraintSet[][] cachedConstraints, double[][] cachedScores){
1551             List vars = new LinkedList(rule.getNecessaryVariables());
1552             Object[] arr = vars.toArray();
1553             Arrays.sort(arr, rule.new VarOrderComparator(solver.VARORDER));
1554             vars = Arrays.asList(arr);
1555 
1556             if(true){
1557                 if(TRACE > 1) out.println("Finding Constraints for: " + rule);
1558             OrderTranslator t = new MapBasedTranslator(rule.variableToBDDDomain);
1559             EpisodeCollection tc = new EpisodeCollection(rule, 0);
1560             
1561             boolean initialized = false;
1562             /* put these in backwards, so we can a stack */
1563             for (int i = 0; i < NUM_BEST_ORDERS_PER_RULE; ++i) {
1564                 if(!initialized){
1565                     cachedConstraints[ruleNum] = new OrderConstraintSet[NUM_BEST_ORDERS_PER_RULE];
1566                     cachedScores[ruleNum] = new double[NUM_BEST_ORDERS_PER_RULE];
1567                     initialized = true;
1568                 }
1569                 TrialGuess tg = tryNewGoodOrder(tc, vars, rule, -2, null, true);
1570                 if (tg == null) break;
1571                 OrderConstraintSet newOcs = new OrderConstraintSet();
1572                 newOcs.constrain(t.translate(tg.order), null);
1573                 cachedConstraints[ruleNum][i] = newOcs;
1574                 cachedScores[ruleNum][i] = tg.prediction.score; // * (rule.totalTime+1) / 1000;
1575                 tc.addTrial(tg.order, null, 0, System.currentTimeMillis());
1576             }
1577         }else{
1578                 MultiMap a2v, d2v;
1579                 a2v = new GenericMultiMap();
1580                 d2v = new GenericMultiMap();
1581                 for (Iterator i = vars.iterator(); i.hasNext();) {
1582                     Variable v = (Variable) i.next();
1583                     Attribute a = (Attribute) rule.getAttribute(v);
1584                     if (a != null) {
1585                         a2v.add(a, v);
1586                         d2v.add(a.getDomain(), v);
1587                     }
1588                 }
1589                 TrialDataGroup vDataGroup = dataRepository.getVariableDataGroup(rule,vars);
1590                 vDataGroup.discretize(.5);
1591                 vDataGroup.classify();
1592                 TrialDataGroup aDataGroup = dataRepository.getAttribDataGroup(rule,vars);
1593                 aDataGroup.discretize(.25);
1594                 vDataGroup.classify();
1595                 TrialDataGroup dDataGroup = dataRepository.getDomainDataGroup(rule,vars);
1596                 dDataGroup.threshold(2);
1597                 dDataGroup.classify();
1598                 Collection pairs = genConstaints(NUM_BEST_ORDERS_PER_RULE,a2v,d2v,rule,vDataGroup,aDataGroup,dDataGroup);
1599                 cachedConstraints[ruleNum] = new OrderConstraintSet[NUM_BEST_ORDERS_PER_RULE];
1600                 cachedScores[ruleNum] = new double[NUM_BEST_ORDERS_PER_RULE];
1601                 int i = 0;
1602                 for(Iterator it = pairs.iterator(); it.hasNext() && i < NUM_BEST_ORDERS_PER_RULE; ++i){
1603                     Pair pair = (Pair) it.next();
1604                     double score = ((Double) pair.get(0)).doubleValue();
1605                     OrderConstraintSet constraints = (OrderConstraintSet) pair.get(1);
1606                     cachedConstraints[ruleNum][i] = constraints.translate(rule.variableToBDDDomain);
1607                     cachedScores[ruleNum][i] = score;
1608                 }
1609         }
1610     }
1611     
    /**
     * Value of the "considertrials" system property, read when the class is
     * initialized.  When non-null, myPrintBestBDDOrders parses it as an int
     * and passes it to dataRepository.reduceByNumTrials.
     */
    static String MAX_CON_ORDERS = System.getProperty("considertrials");
    /**
     * Upper bound on the number of orders myPrintBestBDDOrders emits; it is also
     * the approximate order count above which a leaf's orders are randomly
     * sampled instead of enumerated.
     */
    static int MAX_GEN_ORDERS = 100;
    /** Number of candidate constraint sets cached per rule and expanded per search node. */
    static final int NUM_BEST_ORDERS_PER_RULE = 3;
1615     void myPrintBestBDDOrders(StringBuffer sb, Collection domains,List rules) {
1616        if(rules.size() == 0) return;
1617        Collection visitedElems = new LinkedList();
1618        double [][] cachedScores = new double[rules.size()][] ;
1619        OrderConstraintSet [][] cachedConstraints = new OrderConstraintSet[rules.size()][];
1620        Queue queue = new StackQueue(); // PriorityQueue(); 
1621        int numPrintedOrders = 0;
1622        int nodes = 0, maxQueueSize = 0;
1623        long allRulesTime = 0;
1624        TrialDataRepository repository = MAX_CON_ORDERS == null ? 
1625          dataRepository : dataRepository.reduceByNumTrials(Integer.parseInt(MAX_CON_ORDERS));
1626        
1627        for(Iterator it = rules.iterator(); it.hasNext(); )
1628            allRulesTime += ((BDDInferenceRule) it.next()).totalTime;
1629        
1630        OrderSearchElem first = new OrderSearchElem(0,0, new OrderConstraintSet(),0);
1631        queue.offer(first);
1632        
1633        while(!queue.isEmpty() && numPrintedOrders < MAX_GEN_ORDERS){
1634            maxQueueSize = Math.max(maxQueueSize, queue.size());
1635            OrderSearchElem elem = (OrderSearchElem) queue.poll();
1636            Assert._assert(elem != null);
1637            if(elem.nextRule >= rules.size() || (elem.rulesLeft != null && elem.rulesLeft.isEmpty()) || elem.ocs.onlyOneOrder(domains.size())){
1638             if(TRACE > 1){
1639                 out.println("No more rules or constraints on this path");
1640                 out.println("Generating orders for: " + elem.ocs);
1641             } 
1642              Collection orders;
1643              if (elem.ocs.approxNumOrders(domains.size()) > MAX_GEN_ORDERS) {
1644                  if(TRACE > 1) out.println("More than " + MAX_GEN_ORDERS + " orders. Dumping...random sample");
1645                  orders = new HashSet();
1646                  for (int i = 0; i < 5; ++i) {
1647                      orders.add(elem.ocs.generateRandomOrder(domains));
1648                  }
1649              } else {
1650                  orders = elem.ocs.generateAllOrders(domains);
1651              }
1652              for (Iterator j = orders.iterator(); j.hasNext(); ) {
1653                  Order o = (Order) j.next();
1654                  sb.append("Score "+format(elem.pathScore, 5)+": "+o.toVarOrderString(null));
1655                  sb.append('\n');
1656              }
1657              sb.append("-\n");
1658              numPrintedOrders += orders.size();
1659              continue;
1660            }
1661            ++nodes;
1662           // LinkedList elems = new LinkedList();
1663            if(TRACE > 3) out.println("Expanding: " + elem);
1664            
1665            BDDInferenceRule rule = (BDDInferenceRule) rules.get(elem.nextRule);
1666            boolean cached  = cachedConstraints[elem.nextRule] != null;
1667            if(!cached) cache(elem.nextRule, rule, cachedConstraints, cachedScores);
1668            for (int i = 0; i < NUM_BEST_ORDERS_PER_RULE; ++i) {
1669                OrderConstraintSet constraints = cachedConstraints[elem.nextRule][i];
1670                OrderSearchElem newElem = new OrderSearchElem(elem);
1671                if(constraints == null){
1672                    if(i == 0) {
1673                        ++newElem.nextRule;
1674                        queue.offer(newElem);
1675                        //elems.add(elem);
1676                    }
1677                    break;
1678                }
1679                double constraintScore = cachedScores[elem.nextRule][i];
1680                ++newElem.nextRule;
1681                Collection invalidConstraints = new LinkedList();
1682                if(TRACE > 3) out.println("Adding constraints: " + constraints);
1683                boolean worked = newElem.ocs.constrain(constraints, invalidConstraints);
1684                //newElem.pathCost += (rule.totalTime * constraintScore) * (1 + invalidConstraints.size() / constraints.size()) ;
1685                newElem.pathCost +=  invalidConstraints.size() * (rule.totalTime / constraintScore) ;
1686                newElem.pathScore = newElem.pathCost;
1687                
1688                //if(!worked) continue;//newElem.ocs = backupOcs;
1689                if(TRACE > 3)  out.println("Couldn't add: " + invalidConstraints);
1690                queue.offer(newElem);
1691                //elems.add(newElem);
1692            
1693            }
1694            /*if we're using the stack, push them in reverse priority */
1695       /*     if(elems.size() > 0)
1696            for(ListIterator it = elems.listIterator(elems.size() - 1);  it.hasPrevious();  )
1697                queue.offer(it.previous());   
1698         */     
1699        }
1700        out.println("Max queue size:  " + maxQueueSize + " Nodes expanded: " + nodes);
1701     }
1702 
1703    void printBestBDDOrders(StringBuffer sb, double score, Collection domains, OrderConstraintSet ocs,
1704             MultiMap rulesToTrials, List rules) {
1705         if (rules == null || rules.isEmpty()) {
1706             if(TRACE > 1) out.println("No more rules, Generating orders");
1707             Collection orders;
1708             if (ocs.approxNumOrders(domains.size()) > 1000) {
1709                 if(TRACE > 1) out.println("More than " + MAX_GEN_ORDERS + " orders. Dumping...random sample");
1710                 orders = new LinkedList();
1711                 for (int i = 0; i < 5; ++i) {
1712                     orders.add(ocs.generateRandomOrder(domains));
1713                 }
1714             } else {
1715                 if(TRACE > 1) out.println("Generating orders from constraints: " + ocs);
1716                 orders = ocs.generateAllOrders(domains);
1717             }
1718             for (Iterator j = orders.iterator(); j.hasNext(); ) {
1719                 Order o = (Order) j.next();
1720                 sb.append("Score "+format(score)+": "+o.toVarOrderString(null));
1721                 sb.append('\n');
1722             }
1723             return;
1724         }
1725         if (!ocs.onlyOneOrder(domains.size())) {
1726             InferenceRule ir = (InferenceRule) rules.get(0);
1727             List rest = rules.subList(1, rules.size());
1728             if (ir instanceof BDDInferenceRule && rulesToTrials.containsKey(ir)) {
1729              
1730                 BDDInferenceRule bddir = (BDDInferenceRule) ir;
1731                 if(TRACE > 1) {
1732                     out.println("Generating constraints for rule:\n" + ir.toString());
1733                     out.println("Total rule run time: " + bddir.totalTime);
1734                 }
1735                 OrderTranslator t = new MapBasedTranslator(bddir.variableToBDDDomain);
1736                 EpisodeCollection tc = new EpisodeCollection(bddir, 0);
1737                 for (int i = 0; i < 5; ++i) {
1738                     TrialGuess tg = tryNewGoodOrder(tc, new ArrayList(bddir.necessaryVariables), bddir, -2, null, true);
1739                     if (tg == null) break;
1740                     OrderConstraintSet ocs2 = new OrderConstraintSet(ocs);
1741                     Order bddOrder = t.translate(tg.order);
1742                     out.println("Adding constraints for: " + bddOrder);
1743                     Collection invalidConstraints = new LinkedList();
1744                     boolean worked = ocs2.constrain(bddOrder, invalidConstraints);
1745                     double score2 = tg.prediction.score * (bddir.totalTime+1) / 1000;
1746     
1747                     /*tc.addTrial(tg.order, null, 0);
1748                     if (!worked) 
1749                         out.println("Couldn't add constraints: " + invalidConstraints);
1750                     */
1751                   /*  printBestBDDOrders(sb, score + score2, domains, ocs2, rulesToTrials, rest); */
1752                     
1753                    if (worked) {
1754                         
1755                         printBestBDDOrders(sb, score + score2, domains, ocs2, rulesToTrials, rest);
1756                     }else{
1757                         out.println("Couldn't add constraints: " + invalidConstraints);
1758                     }
1759                     tc.addTrial(tg.order, null, 0, System.currentTimeMillis());
1760                 
1761                 }
1762             } else {
1763                 printBestBDDOrders(sb, score, domains, ocs, rulesToTrials, rest);
1764             }
1765         }
1766         
1767         //Only one order
1768     
1769         out.println("Can't add more constraints: " + ocs);
1770         out.println("Left over rules (" + rules.size() + ": " + rules);
1771         Order o = ocs.generateRandomOrder(domains);
1772         for (Iterator k = rules.iterator(); k.hasNext(); ) {
1773            InferenceRule ir = (InferenceRule) k.next();
1774            if(!(ir instanceof BDDInferenceRule)) continue;
1775             BDDInferenceRule bddir = (BDDInferenceRule) ir;
1776             Order o2;
1777             if (false) {
1778                 MultiMap d2v = new GenericMultiMap();
1779                 for (Iterator a = bddir.variableToBDDDomain.entrySet().iterator(); a.hasNext(); ) {
1780                     Map.Entry e = (Map.Entry) a.next();
1781                     d2v.add(e.getValue(), e.getKey());
1782                 }
1783                 o2 = new MapBasedTranslator(d2v).translate(o);
1784             } else {
1785                 Map d2v = new HashMap();
1786                 for (Iterator a = bddir.variableToBDDDomain.entrySet().iterator(); a.hasNext(); ) {
1787                     Map.Entry e = (Map.Entry) a.next();
1788                     d2v.put(e.getValue(), e.getKey());
1789                 }
1790                 o2 = new MapBasedTranslator(d2v).translate(o);
1791             }
1792             TrialGuess tg = tryNewGoodOrder(null, new ArrayList(bddir.necessaryVariables), bddir, -2, o2, true);
1793             score += tg.prediction.score * (bddir.totalTime+1) / 1000;
1794         }
1795         sb.append("Score "+format(score)+": "+o.toVarOrderString(null));
1796         sb.append('\n');
1797     }
1798     
1799     public Set getVisitedRules(){
1800         Set visitedRules = new HashSet();
1801         for(Iterator it = allTrials.iterator(); it.hasNext(); ){
1802             EpisodeCollection tc = (EpisodeCollection) it.next();
1803             visitedRules.add(tc.getRule(solver));
1804         }
1805     
1806         if(TRACE > 2) out.println("Visited Rules: " + visitedRules);
1807         return visitedRules;  
1808     }
1809     
1810     public void printBestBDDOrders() {
1811         MultiMap ruleToTrials = new GenericMultiMap();
1812         for (Iterator i = allTrials.iterator(); i.hasNext(); ) {
1813             EpisodeCollection tc = (EpisodeCollection) i.next();
1814             ruleToTrials.add(tc.getRule(solver), tc);
1815         }
1816         
1817         // Sort rules by their run time.
1818         SortedSet sortedRules = new TreeSet(new Comparator() {
1819             public int compare(Object o1, Object o2) {
1820                 if (o1 == o2) return 0;
1821                 if (o1 instanceof NumberingRule) return -1;
1822                 if (o2 instanceof NumberingRule) return 1;
1823                 BDDInferenceRule r1 = (BDDInferenceRule) o1;
1824                 BDDInferenceRule r2 = (BDDInferenceRule) o2;
1825                 long diff = r2.totalTime - r1.totalTime; //descending
1826                 //long diff = r1.totalTime - r2.totalTime;  //ascending 
1827                 if (diff != 0)
1828                     return (int) diff;
1829                 return r1.id - r2.id;
1830             }
1831         });
1832         sortedRules.addAll(filterRules(solver.rules));
1833         ArrayList list = new ArrayList(sortedRules);
1834         for (Iterator i = list.iterator (); i.hasNext(); ) {
1835             BDDInferenceRule rule = (BDDInferenceRule) i.next();
1836             System.out.println(bestOrders(rule, 5));
1837         }
1838         Collection domains = new FlattenedCollection(solver.getBDDDomains().values());
1839         out.println("BDD Domains: "+domains);
1840         OrderConstraintSet ocs = new OrderConstraintSet();
1841         StringBuffer sb = new StringBuffer();
1842      //   printBestBDDOrders(sb, 0, domains, ocs, ruleToTrials, list);
1843         myPrintBestBDDOrders(sb, domains, list);
1844         out.println(sb);
1845     }
1846     
1847     /***
1848      * Generate the k best orders for the given inference rule and
1849      * put the result in a string.
1850      * 
1851      * @param rule  inference rule
1852      * @param k  number of orders
1853      * @return  string result
1854      */
1855     public String bestOrders(BDDInferenceRule rule, int k) {
1856         List vars = new LinkedList(rule.getNecessaryVariables());
1857         Object[] arr = vars.toArray();
1858         Arrays.sort(arr, rule.new VarOrderComparator(solver.VARORDER));
1859         vars = Arrays.asList(arr);
1860         EpisodeCollection tc = new EpisodeCollection(rule, 0);
1861         StringBuffer sb = new StringBuffer();
1862         sb.append(rule.toString());
1863         sb.append(Strings.lineSep);
1864         for (int i = 1; i <= k; ++i) {
1865             TrialGuess tg = tryNewGoodOrder(tc, vars, rule, -2, null, true);
1866             if (tg == null) break;
1867             sb.append("    Order #").append(i).append(": ");
1868             sb.append(tg.toString());
1869         }
1870         return sb.toString();
1871     }
1872     
1873     static Collection filterRules(Collection rules){
1874         Collection filteredRules = new LinkedList();
1875         for(Iterator it = rules.iterator(); it.hasNext(); ){
1876             InferenceRule rule = (InferenceRule) it.next();
1877             if(rule instanceof BDDInferenceRule) filteredRules.add(rule);
1878         }
1879         return filteredRules;
1880     }
1881     public void printBestTrials() {
1882         MultiMap ruleToTrials = new GenericMultiMap();
1883         for (Iterator i = allTrials.iterator(); i.hasNext(); ) {
1884             EpisodeCollection tc = (EpisodeCollection) i.next();
1885             ruleToTrials.add(tc.getRule(solver), tc);
1886         }
1887         // Sort rules by their run time.
1888         SortedSet sortedRules = new TreeSet(new Comparator() {
1889             public int compare(Object o1, Object o2) {
1890                 if (o1 == o2) return 0;
1891                 BDDInferenceRule r1 = (BDDInferenceRule) o1;
1892                 BDDInferenceRule r2 = (BDDInferenceRule) o2;
1893                 
1894                 long diff = r2.totalTime - r1.totalTime;
1895                
1896                 if (diff != 0)
1897                     return (int) diff;
1898                 return r1.id - r2.id;
1899             }
1900         });
1901         sortedRules.addAll(ruleToTrials.keySet());
1902         
1903         for (Iterator i = sortedRules.iterator(); i.hasNext(); ) {
1904             BDDInferenceRule ir = (BDDInferenceRule) i.next();
1905             Map scoreboard = new HashMap();
1906             for (Iterator j = ruleToTrials.getValues(ir).iterator(); j.hasNext(); ) {
1907                 EpisodeCollection tc = (EpisodeCollection) j.next();
1908                 TrialInfo ti = tc.getMinimum();
1909                 if (ti == null || ti.isMax()) continue;
1910                 long[] score = (long[]) scoreboard.get(ti.order);
1911                 if (score == null) scoreboard.put(ti.order, score = new long[2]);
1912                 score[0]++;
1913                 score[1] += ti.cost;
1914             }
1915             
1916             if (scoreboard.isEmpty()) continue;
1917             
1918             SortedSet sortedTrials = new TreeSet(new Comparator() {
1919                 public int compare(Object o1, Object o2) {
1920                     long[] counts1 = (long[]) ((Map.Entry) o1).getValue();
1921                     long[] counts2 = (long[]) ((Map.Entry) o2).getValue();
1922                     long diff = counts2[0] - counts1[0];
1923                     if (diff != 0)
1924                         return (int) diff;
1925                     diff = counts2[1] - counts1[1];
1926                     if (diff != 0)
1927                         return (int) diff;
1928                     Order order1 = (Order) ((Map.Entry) o1).getKey();
1929                     Order order2 = (Order) ((Map.Entry) o2).getKey();
1930                     return order1.compareTo(order2);
1931                 }
1932             });
1933             sortedTrials.addAll(scoreboard.entrySet());
1934             
1935             out.println("For rule"+ir.id+": "+ir);
1936             for (Iterator it = sortedTrials.iterator(); it.hasNext();) {
1937                 Map.Entry entry = (Map.Entry) it.next();
1938                 Order order = (Order) entry.getKey();
1939                 long[] counts = (long[]) entry.getValue();
1940                 double aveTime = (double)counts[1] / (double) counts[0];
1941                 String bddString = order.toVarOrderString(ir.variableToBDDDomain);
1942                 out.println(order + " won " + counts[0] + " time(s), average winning time of "+format(aveTime)+" ms");
1943                 out.println("   BDD order: "+bddString);
1944             }
1945             out.println();
1946         }
1947         
1948     }
1949     
1950     public void printTrialsDistro() {
1951         printTrialsDistro(allTrials, solver);
1952     }
1953 
1954     public static void printTrialsDistro(Collection trials, Solver solver) {
1955         Map orderToCounts = new HashMap();
1956         final int numRules = solver.getRules().size();
1957         int total = 0, distinct = 0;
1958         for (Iterator it = trials.iterator(); it.hasNext();) {
1959             EpisodeCollection tc = (EpisodeCollection) it.next();
1960             Assert._assert(tc != null);
1961             for (Iterator jt = tc.getTrials().iterator(); jt.hasNext();) {
1962                 TrialInfo ti = (TrialInfo) jt.next();
1963                 Order order = ti.order;
1964                 int[] counts = (int[]) orderToCounts.get(order);
1965                 if (counts == null) {
1966                     counts = new int[numRules + 1];
1967                     orderToCounts.put(order, counts);
1968                     ++distinct;
1969                 }
1970                 ++counts[tc.getRule(solver).id];
1971                 //one extra int at the end to count the total number of trials
1972                 ++counts[numRules];
1973             }
1974             total += tc.getNumTrials();
1975         }
1976 
1977         SortedSet sortedTrials = new TreeSet(new Comparator() {
1978             public int compare(Object o1, Object o2) {
1979                 int[] counts1 = (int[]) ((Map.Entry) o1).getValue();
1980                 int[] counts2 = (int[]) ((Map.Entry) o2).getValue();
1981                 int diff = counts2[numRules] - counts1[numRules];
1982                 if (diff != 0)
1983                     return diff;
1984                 Order order1 = (Order) ((Map.Entry) o1).getKey();
1985                 Order order2 = (Order) ((Map.Entry) o2).getKey();
1986                 return order1.compareTo(order2);
1987             }
1988         });
1989 
1990         sortedTrials.addAll(orderToCounts.entrySet());
1991         out.println(total + " trials  of " + distinct + " distinct orders");
1992         out.println("tried Orders: ");
1993         for (Iterator it = sortedTrials.iterator(); it.hasNext();) {
1994             Map.Entry entry = (Map.Entry) it.next();
1995             Order order = (Order) entry.getKey();
1996             int[] counts = (int[]) entry.getValue();
1997             out.println(order + " tried a total of " + counts[numRules] + " time(s) :");
1998             for (int i = 0; i < counts.length - 1; ++i) {
1999                 int count = counts[i];
2000                 if (count != 0) {
2001                     out.println("    " + count + " time(s) on \n    " + solver.getRule(i));
2002                 }
2003             }
2004             out.println();
2005         }
2006     }
2007 
2008     public static void main(String[] args) throws Exception {
2009         String inputFilename = SystemProperties.getProperty("datalog");
2010         if (args.length > 0) inputFilename = args[0];
2011         if (inputFilename == null) {
2012             return;
2013         }
2014         String solverName = SystemProperties.getProperty("solver", "net.sf.bddbddb.BDDSolver");
2015         Solver s;
2016         s = (Solver) Class.forName(solverName).newInstance();
2017         s.load(inputFilename);
2018 
2019         FindBestDomainOrder dis = ((BDDSolver) s).fbo;
2020         //dis.loadTrials("trials.xml");
2021         //dis.dump();
2022 /*
2023         for (Iterator i = s.rules.iterator(); i.hasNext();) {
2024             InferenceRule ir = (InferenceRule) i.next();
2025             if (ir.necessaryVariables == null) continue;
2026             out.println("Computing for rule " + ir);
2027 
2028             List allVars = new LinkedList();
2029             allVars.addAll(ir.necessaryVariables);
2030             out.println("Variables = " + allVars);
2031 
2032             TrialGuess guess = dis.tryNewGoodOrder(null, allVars, ir, false);
2033 
2034             out.println("Resulting guess: "+guess);
2035         }
2036 */
2037         //printTrialsDistro(dis.allTrials, s);
2038         //dis.printBestTrials();
2039         dis.printBestBDDOrders();
2040     }
2041 
2042     
2043     /***
2044      * Run the find best domain order on the given inputs.
2045      * 
2046      * @param bdd  BDD factory
2047      * @param b1   first input to relprod
2048      * @param b2   second input to relprod
2049      * @param b3   third input to relprod
2050      * @param r1   first rule term
2051      * @param r2   second rule term
2052      * @param vars1  variables of b1
2053      * @param vars2  variables of b2
2054      */
2055     static void findBestDomainOrder(BDDSolver solver, BDDInferenceRule rule, int opNum, BDDFactory bdd, BDD b1, BDD b2, BDDVarSet b3, RuleTerm r1, RuleTerm r2, Collection vars1, Collection vars2) {
2056         Set allVarSet = new HashSet(vars1); allVarSet.addAll(vars2);
2057         allVarSet.removeAll(rule.unnecessaryVariables);
2058         Object[] a = allVarSet.toArray();
2059         // Sort the variables by domain so that we will first try orders that are close
2060         // to the default one.
2061         Arrays.sort(a, rule.new VarOrderComparator(solver.VARORDER));
2062         List allVars = Arrays.asList(a);
2063         
2064         FindBestDomainOrder fbdo = solver.fbo;
2065         if (!fbdo.hasOrdersToTry(allVars, rule)) {
2066             out.println("No more orders to try, skipping find best order for "+vars1+","+vars2);
2067             return;
2068         }
2069         out.println("Finding best order for "+vars1+","+vars2);
2070         long time = System.currentTimeMillis();
2071         Episode ep = fbdo.getNewEpisode(rule, opNum, time, true);
2072         FindBestOrder fbo = new FindBestOrder(solver.BDDNODES, solver.BDDCACHE, solver.BDDNODES / 2, Long.MAX_VALUE, 5000);
2073         try {
2074             fbo.init(b1, b2, b3, BDDFactory.and);
2075         } catch (IOException x) {
2076             solver.err.println("IO Exception occurred: " + x);
2077             fbo.cleanup();
2078             return;
2079         }
2080         out.println("Time to initialize FindBestOrder: "+(System.currentTimeMillis()-time));
2081         int count = BDDInferenceRule.MAX_FBO_TRIALS;
2082         boolean first = true;
2083         long bestTime = Long.MAX_VALUE;
2084         while (--count >= 0) {
2085             //Order o = fbdo.tryNewGoodOrder(tc, allVars, t);
2086             TrialGuess guess = fbdo.tryNewGoodOrder(ep, allVars, rule,opNum, first);
2087             if (guess == null || guess.order == null) break;
2088             Order o = guess.order;
2089             String vOrder = o.toVarOrderString(rule.variableToBDDDomain);
2090             out.println("Trying order "+vOrder);
2091             vOrder = solver.fixVarOrder(vOrder, false);
2092             out.println("Complete order "+vOrder);
2093             time = fbo.tryOrder(true, vOrder);
2094             time = Math.min(time, BDDInferenceRule.LONG_TIME);
2095             bestTime = Math.min(time, bestTime);
2096             fbdo.addTrial(rule, allVars,ep, o,guess.prediction, time, System.currentTimeMillis());
2097             
2098             if (time >= BDDInferenceRule.LONG_TIME)
2099                 fbdo.neverTryAgain(rule, o);
2100             first = false;
2101         }
2102         fbo.cleanup();
2103         
2104         fbdo.incorporateTrial(ep);
2105         
2106         XMLFactory.dumpXML("fbo.xml", fbdo.toXMLElement());
2107         XMLFactory.dumpXML(solver.TRIALFILE, fbdo.trialsToXMLElement());
2108     }
2109 
2110 
2111 }