1
2
3
4 package net.sf.bddbddb;
5
6 import java.util.ArrayList;
7 import java.util.Arrays;
8 import java.util.Collection;
9 import java.util.Collections;
10 import java.util.Comparator;
11 import java.util.Enumeration;
12 import java.util.HashMap;
13 import java.util.HashSet;
14 import java.util.Iterator;
15 import java.util.LinkedHashSet;
16 import java.util.LinkedList;
17 import java.util.List;
18 import java.util.Map;
19 import java.util.Set;
20 import java.util.SortedSet;
21 import java.util.TreeSet;
22 import java.io.BufferedWriter;
23 import java.io.File;
24 import java.io.FileOutputStream;
25 import java.io.FileWriter;
26 import java.io.IOException;
27 import java.io.PrintStream;
28 import java.net.URL;
29 import java.text.NumberFormat;
30 import java.text.SimpleDateFormat;
31 import jwutil.collections.FlattenedCollection;
32 import jwutil.collections.GenericMultiMap;
33 import jwutil.collections.MultiMap;
34 import jwutil.collections.Pair;
35 import jwutil.io.SystemProperties;
36 import jwutil.strings.Strings;
37 import jwutil.util.Assert;
38 import net.sf.bddbddb.order.AttribToDomainTranslator;
39 import net.sf.bddbddb.order.CandidateSampler;
40 import net.sf.bddbddb.order.ConstraintInfo;
41 import net.sf.bddbddb.order.Discretization;
42 import net.sf.bddbddb.order.EpisodeCollection;
43 import net.sf.bddbddb.order.MapBasedTranslator;
44 import net.sf.bddbddb.order.MyId3;
45 import net.sf.bddbddb.order.Order;
46 import net.sf.bddbddb.order.OrderConstraint;
47 import net.sf.bddbddb.order.OrderConstraintSet;
48 import net.sf.bddbddb.order.OrderTranslator;
49 import net.sf.bddbddb.order.Queue;
50 import net.sf.bddbddb.order.StackQueue;
51 import net.sf.bddbddb.order.TrialDataRepository;
52 import net.sf.bddbddb.order.TrialGuess;
53 import net.sf.bddbddb.order.TrialInfo;
54 import net.sf.bddbddb.order.TrialInstance;
55 import net.sf.bddbddb.order.TrialInstances;
56 import net.sf.bddbddb.order.TrialPrediction;
57 import net.sf.bddbddb.order.VarToAttribTranslator;
58 import net.sf.bddbddb.order.WekaInterface;
59 import net.sf.bddbddb.order.CandidateSampler.UncertaintySampler;
60 import net.sf.bddbddb.order.EpisodeCollection.Episode;
61 import net.sf.bddbddb.order.TrialDataRepository.TrialDataGroup;
62 import net.sf.bddbddb.order.WekaInterface.OrderAttribute;
63 import net.sf.bddbddb.order.WekaInterface.OrderInstance;
64 import net.sf.javabdd.BDD;
65 import net.sf.javabdd.BDDFactory;
66 import net.sf.javabdd.BDDVarSet;
67 import net.sf.javabdd.FindBestOrder;
68 import org.jdom.Document;
69 import org.jdom.Element;
70 import org.jdom.input.SAXBuilder;
71 import weka.classifiers.Classifier;
72 import weka.core.Instance;
73 import weka.core.Instances;
74
75 /***
76 * FindBestDomainOrder
77 *
78 * Design:
79 *
80 * TrialInfo : order, cost
81 * EpisodeCollection : collection of TrialInfo, best time, worst time
82 * Constraint: a<b or axb or a_b
83 * Order : collection of ordering constraints
84 * ConstraintInfo : map from a single constraint to score/confidence
85 * OrderInfo : order, predicted score and confidence
86 *
87 * Maps:
88 * Relation -> ConstraintInfo collection
89 * Rule -> ConstraintInfo collection
90 * EpisodeCollection -> ConstraintInfo collection
91 *
92 * Algorithm to compute best order:
93 * - Combine and sort single constraints from relation, rule, trials so far.
94 * Sort by score*confidence (?)
95 * Combine and adjust opposite constraints (?)
96 * Sort by difference between opposite constraints (?)
97 * - Do an A* search.
98 * Keep track of the current score/confidence as we add constraints.
99 * As we add new constraints, flag conflicting ones.
100 * Predict final score by combining top n non-conflicting constraints (?)
101 * If prediction is worse than current best score, return immediately.
102 *
103 * @author John Whaley
104 * @version $Id: FindBestDomainOrder.java 645 2006-07-17 05:20:20Z joewhaley $
105 */
106 public class FindBestDomainOrder {
107
108
/** Verbosity of diagnostic output; larger values print more detail. */
public static int TRACE = 2;

/**
 * Diagnostic streams. {@code out} is taken from the solver by the
 * constructors. {@code out_t} is the per-rule trials stream opened by
 * tryNewGoodOrder when DUMP_CLASSIFIER_INFO is set; otherwise it is null.
 */
public static PrintStream out, out_t;

/**
 * Timestamp format "yyMMdd-HHmmss".
 * NOTE(review): SimpleDateFormat is not thread-safe; confirm this shared
 * instance is only used from a single thread.
 */
public static final SimpleDateFormat dateFormat = new SimpleDateFormat("yyMMdd-HHmmss");

/**
 * Whether ConstraintInfoCollection.addTrials also records untranslated
 * (per-rule) constraints in addition to the attribute-level ones.
 */
public static boolean PER_RULE_CONSTRAINTS = true;

/** Whether tryNewGoodOrder dumps classifier info and a trials file per rule. */
public static boolean DUMP_CLASSIFIER_INFO = true;

/** The solver this object works for; only set when it is a BDDSolver. */
BDDSolver solver;

/**
 * Every EpisodeCollection seen so far, including those loaded from disk
 * by loadTrials.
 */
Collection allTrials;

/**
 * Trial data used to build the classifiers; created only by the Solver
 * constructor, and only for a BDDSolver.
 */
TrialDataRepository dataRepository;

/** Statistics for each order constraint, accumulated from the trials. */
ConstraintInfoCollection constraintInfo;
142
/**
 * Constructs a new, empty FindBestDomainOrder for the given solver.
 *
 * The trial data repository is created only when the solver is a
 * BDDSolver. For any other solver, {@code solver} and
 * {@code dataRepository} remain null.
 *
 * @param s solver this order finder works for
 */
public FindBestDomainOrder(Solver s) {
    constraintInfo = new ConstraintInfoCollection(s);
    allTrials = new LinkedList();
    if (s instanceof BDDSolver) {
        solver = (BDDSolver) s;
        dataRepository = new TrialDataRepository(allTrials, solver);
    }
    // solver is null for non-BDD solvers; dereferencing it unconditionally
    // threw a NullPointerException. Keep any previously configured stream,
    // and fall back to stdout only when none exists.
    if (solver != null) {
        out = solver.out;
    } else if (out == null) {
        out = System.out;
    }
}
155
156
/**
 * Constructs a new FindBestDomainOrder that reuses an existing constraint
 * info collection.
 *
 * NOTE(review): unlike the Solver constructor, this one does not create a
 * TrialDataRepository, so dataRepository stays null. Confirm that callers
 * of this constructor never reach addTrial or tryNewGoodOrder, which use it.
 *
 * @param c constraint info collection to use
 */
FindBestDomainOrder(ConstraintInfoCollection c) {
    constraintInfo = c;
    allTrials = new LinkedList();
    if (c.solver instanceof BDDSolver)
        solver = (BDDSolver) c.solver;
    // Guard against a non-BDD solver, for which solver stays null.
    if (solver != null) {
        out = solver.out;
    } else if (out == null) {
        out = System.out;
    }
}
167
168 /***
169 * Load and incorporate trials from the given XML file.
170 *
171 * @param filename filename
172 */
173 void loadTrials(String filename) {
174 out.println("Trials filename=" + filename);
175 File file = new File(filename);
176 if (file.exists()) {
177 try {
178 URL url = file.toURL();
179 SAXBuilder builder = new SAXBuilder();
180 Document doc = builder.build(url);
181 XMLFactory f = new XMLFactory(solver);
182 Element e = doc.getRootElement();
183 List list = (List) f.fromXML(e);
184 if (TRACE > 0) {
185 out.println("Loaded " + list.size() + " trial collections from file.");
186 if (TRACE > 2) {
187 for (Iterator i = list.iterator(); i.hasNext();) {
188 out.println("Loaded from file: " + i.next());
189 }
190 }
191 }
192 allTrials.addAll(list);
193 } catch (Exception e) {
194 solver.err.println("Error occurred loading " + filename + ": " + e);
195 e.printStackTrace();
196 }
197 }
198 incorporateTrials();
199 }
200
201 /***
202 * Incorporate all of the trials in allTrials.
203 */
204 void incorporateTrials() {
205 for (Iterator i = allTrials.iterator(); i.hasNext();) {
206 EpisodeCollection tc = (EpisodeCollection) i.next();
207 constraintInfo.addTrials(tc);
208 }
209 }
210
211 void incorporateTrial(Episode ep) {
212 constraintInfo.addTrials(ep.getEpisodeCollection());
213
214 if (TRACE > 2)
215 dump();
216 }
217
218 /***
219 * Dump the collected order info for rules and relations to standard output.
220 */
221 public void dump() {
222 SortedSet set = new TreeSet(new Comparator() {
223 public int compare(Object o1, Object o2) {
224 ConstraintInfo info1 = (ConstraintInfo) ((Map.Entry) o1).getValue();
225 ConstraintInfo info2 = (ConstraintInfo) ((Map.Entry) o2).getValue();
226 return info1.compareTo(info2);
227 }
228 });
229 set.addAll(constraintInfo.infos.entrySet());
230 for (Iterator i = set.iterator(); i.hasNext();) {
231 Map.Entry e = (Map.Entry) i.next();
232 OrderConstraint ir = (OrderConstraint) e.getKey();
233 out.println("Order feature: " + ir);
234 ConstraintInfo info = (ConstraintInfo) e.getValue();
235 info.dump();
236 }
237 }
238
239 public int getNumberOfTrials() {
240 int sum = 0;
241 for (Iterator i = allTrials.iterator(); i.hasNext();) {
242 EpisodeCollection ec = (EpisodeCollection) i.next();
243 sum += ec.getNumTrials();
244 }
245 return sum;
246 }
247
248 /***
249 * Starts a new trial collection and returns it.
250 *
251 * @param rule inference rule of trial collection
252 * @param opNumber operation number of trial collection
253 * @param timeStamp time of trial collection
254 * @param newCollection whether to always return a new collection
255 * @return new trial collection
256 */
257 public Episode getNewEpisode(BDDInferenceRule rule, int opNumber, long timeStamp, boolean newCollection) {
258 EpisodeCollection c = newCollection ? findEpisodeCollection(rule, opNumber) : null;
259
260 if(c == null){
261 c = new EpisodeCollection(rule, opNumber);
262 allTrials.add(c);
263 }
264 return c.startNewEpisode(timeStamp);
265 }
266
267
268 public EpisodeCollection findEpisodeCollection(BDDInferenceRule rule, int opNumber){
269 for(Iterator it = allTrials.iterator(); it.hasNext(); ){
270 EpisodeCollection tc = (EpisodeCollection) it.next();
271 if(tc.getRule(solver) == rule && tc.getUpdateCount() == rule.updateCount && tc.getOpNumber() == opNumber){
272 if(TRACE > 1) out.println("Found a tc for: rule " + rule.id + " on update " + rule.updateCount + " and op " + opNumber);
273 return tc;
274 }
275 }
276 return null;
277 }
278
279 /***
280 * Calculated information about an order. This consists of a score
281 * and an estimated information gain.
282 *
283 * @author John Whaley
284 * @version $Id: FindBestDomainOrder.java 645 2006-07-17 05:20:20Z joewhaley $
285
286 */
287 public static class OrderInfo implements Comparable {
288
289 /***
290 * The order this information is about.
291 */
292 Order order;
293
294 /***
295 * A measure of how good this order is.
296 */
297 double score;
298
299 /***
300 * A measure of the expected information gain from running this order.
301 */
302 double infoGain;
303
304 /***
305 * Construct a new OrderInfo.
306 *
307 * @param o order
308 * @param s score
309 * @param c info gain
310 */
311 public OrderInfo(Order o, double s, double c) {
312 this.order = o;
313 this.score = s;
314 this.infoGain = c;
315 }
316
317 /***
318 * Construct a new OrderInfo that is a clone of another.
319 *
320 * @param that other OrderInfo to clone from
321 */
322 public OrderInfo(OrderInfo that) {
323 this.order = that.order;
324 this.score = that.score;
325 this.infoGain = that.infoGain;
326 }
327
328
329
330
331 public String toString() {
332 return order + ": score " + format(score) + " info gain " + format(infoGain);
333 }
334
335
336
337
338 public int compareTo(Object arg0) {
339 return compareTo((OrderInfo) arg0);
340 }
341
342 /***
343 * Comparison operator for OrderInfo objects. Score is most important, followed
344 * by info gain. If both are equal, we compare the order lexigraphically.
345 *
346 * @param that OrderInfo to compare to
347 * @return -1, 0, or 1 if this OrderInfo is less than, equal to, or greater than the other
348 */
349 public int compareTo(OrderInfo that) {
350 if (this == that) return 0;
351 int result = signum(that.score - this.score);
352 if (result == 0) {
353 result = (int) signum(this.infoGain - that.infoGain);
354 if (result == 0) {
355 result = this.order.compareTo(that.order);
356 }
357 }
358 return result;
359 }
360
361 /***
362 * Returns this OrderInfo as an XML element.
363 *
364 * @return XML element
365 */
366 public Element toXMLElement() {
367 Element dis = new Element("orderInfo");
368 dis.setAttribute("order", order.toString());
369 dis.setAttribute("score", Double.toString(score));
370 dis.setAttribute("infoGain", Double.toString(infoGain));
371 return dis;
372 }
373
374 public static OrderInfo fromXMLElement(Element e, Map nameToVar) {
375 Order o = Order.parse(e.getAttributeValue("order"), nameToVar);
376 double s = Double.parseDouble(e.getAttributeValue("score"));
377 double c = Double.parseDouble(e.getAttributeValue("infoGain"));
378 return new OrderInfo(o, s, c);
379 }
380 }
381
382 /***
383 * Generate all orders of a given list of variables.
384 *
385 * @param vars list of variables
386 * @return list of all orders of those variables
387 */
388 public static List
389 if (vars.size() == 0) return null;
390 LinkedList result = new LinkedList();
391 if (vars.size() == 1) {
392 result.add(new Order(vars));
393 return result;
394 }
395 Object car = vars.get(0);
396 List recurse = generateAllOrders(vars.subList(1, vars.size()));
397 for (Iterator i = recurse.iterator(); i.hasNext();) {
398 Order order = (Order) i.next();
399 for (int j = 0; j <= order.size(); ++j) {
400 Order myOrder = new Order(order);
401 myOrder.add(j, car);
402 result.add(myOrder);
403 }
404 }
405 for (Iterator i = recurse.iterator(); i.hasNext();) {
406 Order order = (Order) i.next();
407 for (int j = 0; j < order.size(); ++j) {
408 Order myOrder = new Order(order);
409 Object o = myOrder.get(j);
410 List c = new LinkedList();
411 c.add(car);
412 if (o instanceof Collection) {
413 c.addAll((Collection) o);
414 } else {
415 c.add(o);
416 }
417 myOrder.set(j, c);
418 result.add(myOrder);
419 }
420 }
421 return result;
422 }
423
424 transient static NumberFormat nf;
425
426 /***
427 * Format a double in a nice way.
428 *
429 * @param d double
430 * @return string representation
431 */
432 public static String format(double d) {
433 if (nf == null) {
434 nf = NumberFormat.getNumberInstance();
435
436 nf.setMaximumFractionDigits(3);
437 }
438 if (d == Double.MAX_VALUE) return "max";
439 return nf.format(d);
440 }
441
442 public static String format(double d, int numFracDigits) {
443 if (nf == null) {
444 nf = NumberFormat.getNumberInstance();
445
446 nf.setMaximumFractionDigits(numFracDigits);
447 }
448 if (d == Double.MAX_VALUE) return "max";
449 return nf.format(d);
450 }
451
452
453 public static int signum(long d) {
454 if (d < 0) return -1;
455 if (d > 0) return 1;
456 return 0;
457 }
458
459
460 public static int signum(double d) {
461 if (d < 0) return -1;
462 if (d > 0) return 1;
463 return 0;
464 }
465
466 public static class ConstraintInfoCollection {
467
468 Solver solver;
469
470 /***
471 * Map from orders to their info.
472 */
473 Map
474
475 public ConstraintInfoCollection(Solver s) {
476 this.solver = s;
477 this.infos = new HashMap();
478 }
479
480 public ConstraintInfo getInfo(OrderConstraint c) {
481 return (ConstraintInfo) infos.get(c);
482 }
483
484 public ConstraintInfo getOrCreateInfo(OrderConstraint c) {
485 ConstraintInfo ci = (ConstraintInfo) infos.get(c);
486 if (ci == null) infos.put(c, ci = new ConstraintInfo(c));
487 return ci;
488 }
489
490 private void addTrials(EpisodeCollection tc, OrderTranslator trans) {
491 MultiMap c2Trials = new GenericMultiMap();
492
493 for (Iterator i = tc.getTrials().iterator(); i.hasNext();) {
494 TrialInfo ti = (TrialInfo) i.next();
495 Order o = ti.order;
496 if (ti.cost >= BDDInferenceRule.LONG_TIME)
497 ((BDDSolver) solver).fbo.neverTryAgain(tc.getRule(solver), o);
498 if (trans != null) o = trans.translate(o);
499 Collection ocs = o.getConstraints();
500 for (Iterator j = ocs.iterator(); j.hasNext();) {
501 OrderConstraint oc = (OrderConstraint) j.next();
502 c2Trials.add(oc, ti);
503 }
504
505 }
506
507 for (Iterator i = c2Trials.keySet().iterator(); i.hasNext();) {
508 OrderConstraint oc = (OrderConstraint) i.next();
509 ConstraintInfo info = getOrCreateInfo(oc);
510 info.registerTrials(c2Trials.getValues(oc));
511 }
512
513 }
514
515 public void addTrials(EpisodeCollection tc) {
516 InferenceRule ir = tc.getRule(solver);
517 OrderTranslator varToAttrib = new VarToAttribTranslator(ir);
518 addTrials(tc, varToAttrib);
519 if (PER_RULE_CONSTRAINTS) {
520 addTrials(tc, null);
521 }
522 if (TRACE > 2) {
523 out.println("Added trial collection: " + tc);
524 }
525 }
526
527 public OrderInfo predict(Order o, OrderTranslator trans) {
528 if (TRACE > 2) out.println("Predicting order "+o);
529 if (trans != null) o = trans.translate(o);
530 if (TRACE > 2) out.println("Translated into order "+o);
531 double score = 0.;
532 int numTrialCollections = 0, numTrials = 0;
533 Collection cinfos = new LinkedList();
534 for (Iterator i = o.getConstraints().iterator(); i.hasNext();) {
535 OrderConstraint c = (OrderConstraint) i.next();
536 ConstraintInfo ci = getInfo(c);
537 if (ci == null || ci.getNumberOfTrials() == 0) continue;
538 cinfos.add(ci);
539 score += ci.getWeightedMean();
540 numTrialCollections++;
541 numTrials += ci.getNumberOfTrials();
542 }
543 if (numTrialCollections == 0)
544 score = 0.;
545 else
546 score = score / numTrialCollections;
547 double infoGain = ConstraintInfo.getVariance(cinfos) / numTrials;
548 if (TRACE > 2) out.println("Prediction for "+o+": score "+format(score)+" infogain "+format(infoGain));
549 return new OrderInfo(o, score, infoGain);
550 }
551
552 }
553
554 public OrderInfo predict(Order o, OrderTranslator trans) {
555 return constraintInfo.predict(o, trans);
556 }
557
558 /***
559 * Returns this FindBestDomainOrder as an XML element.
560 *
561 * @return XML element
562 */
563 public Element toXMLElement() {
564 Element constraintInfoCollection = new Element("constraintInfoCollection");
565 for (Iterator i = constraintInfo.infos.entrySet().iterator(); i.hasNext();) {
566 Map.Entry e = (Map.Entry) i.next();
567 OrderConstraint oc = (OrderConstraint) e.getKey();
568 ConstraintInfo c = (ConstraintInfo) e.getValue();
569 Element constraintInfo = c.toXMLElement(solver);
570 constraintInfoCollection.addContent(constraintInfo);
571 }
572
573 Element fbo = new Element("findBestOrder");
574 fbo.addContent(constraintInfoCollection);
575
576 return fbo;
577 }
578
579 /***
580 * Returns the set of all trials performed so far as an XML element.
581 *
582 * @return XML element
583 */
584 public Element trialsToXMLElement() {
585 Element trialCollections = new Element("episodeCollections");
586 if (solver.inputFilename != null)
587 trialCollections.setAttribute("datalog", solver.inputFilename);
588 for (Iterator i = allTrials.iterator(); i.hasNext();) {
589 EpisodeCollection c = (EpisodeCollection) i.next();
590 trialCollections.addContent(c.toXMLElement());
591 }
592 return trialCollections;
593 }
594
595 /***
596 */
597 public boolean hasOrdersToTry(List allVars, BDDInferenceRule ir) {
598
599 int nTrials = getNumberOfTrials();
600 if (nTrials != ir.lastTrialNum) {
601 ir.lastTrialNum = nTrials;
602 TrialGuess g = this.tryNewGoodOrder(null, allVars, ir, -2, null, false);
603 return g != null;
604 } else {
605 return false;
606 }
607 }
608
609
610 public static final int compare(double d1, double d2) {
611 if (d1 < d2)
612 return -1;
613 if (d1 > d2)
614 return 1;
615
616 long thisBits = Double.doubleToLongBits(d1);
617 long anotherBits = Double.doubleToLongBits(d2);
618
619 return (thisBits == anotherBits ? 0 :
620 (thisBits < anotherBits ? -1 :
621 1));
622 }
623
624
625
626 public final static int WEIGHT_WINDOW_SIZE = Integer.MAX_VALUE;
627
628 public final static double DECAY_FACTOR = -.1;
629
630 public double computeWeight(int type, TrialInstances instances) {
631 int numTrials = 0;
632 double total = 0;
633 int losses = 0;
634 double weight = 1;
635 for (int i = instances.numInstances() - 1; i >= 0 && numTrials < WEIGHT_WINDOW_SIZE; --i) {
636 TrialInstance instance = (TrialInstance) instances.instance(i);
637 double trueCost = instance.getCost();
638
639 TrialPrediction pred = instance.getTrialInfo().pred;
640 double predCost = pred.predictions[type][TrialPrediction.LOW];
641 double dev = pred.predictions[type][TrialPrediction.HIGH];
642
643 if(predCost == -1) continue;
644 double trialWeight = Math.exp(DECAY_FACTOR * numTrials);
645 if (trueCost < predCost - dev || trueCost > predCost + dev) {
646 losses += trialWeight;
647 }
648 total += trialWeight;
649 ++numTrials;
650 }
651 if (numTrials != 0) {
652 weight = 1 - losses / (double) total;
653 }
654 return weight;
655 }
656
657 public void adjustWeights(TrialInstances vData, TrialInstances aData, TrialInstances dData) {
658 if (vData != null)
659 varClassWeight = computeWeight(TrialPrediction.VARIABLE, vData);
660 if (aData != null)
661 attrClassWeight = computeWeight(TrialPrediction.ATTRIBUTE, aData);
662 if (dData != null)
663 domClassWeight = computeWeight(TrialPrediction.DOMAIN, dData);
664 }
665
666 public static int NUM_CV_FOLDS = 10;
667
668 /***
669 * @param data
670 * @param cClassName
671 * @return Cross validation with number of folds as set by NUM_CV_FOLDS;
672 */
673 public double constFoldCV(Instances data, String cClassName) {
674 return WekaInterface.cvError(NUM_CV_FOLDS, data, cClassName);
675 }
676
677
678 public static boolean DISCRETIZE1 = true;
679 public static boolean DISCRETIZE2 = true;
680 public static boolean DISCRETIZE3 = true;
681 public static String CLASSIFIER1 = "net.sf.bddbddb.order.MyId3";
682 public static String CLASSIFIER2 = "net.sf.bddbddb.order.MyId3";
683 public static String CLASSIFIER3 = "net.sf.bddbddb.order.MyId3";
684
685
686 public void neverTryAgain(InferenceRule ir, Order o) {
687 if (true) {
688 if (TRACE > 2) out.println("For rule"+ir.id+", never trying order "+o+" again.");
689 neverAgain.add(ir, o);
690 }
691 }
692
693 MultiMap neverAgain = new GenericMultiMap();
694
695 public double varClassWeight = 1;
696 public double attrClassWeight = 1;
697 public double domClassWeight = 1;
698 public static int DOMAIN_THRESHOLD = 1000;
699 public static int NO_CLASS = -1;
700 public static int NO_CLASS_SCORE = -1;
701 public boolean PROBABILITY = false;
702 void dumpClassifierInfo(String name, Classifier c, Instances data) {
703 BufferedWriter w = null;
704 try {
705 w = new BufferedWriter(new FileWriter(name));
706 w.write("Classifier \"name\":\n");
707 w.write("Attributes: \n");
708 for(Enumeration e = data.enumerateAttributes(); e.hasMoreElements(); ){
709 w.write(e.nextElement() +"\n");
710 }
711 w.write("\n");
712 w.write("Based on data from "+data.numInstances()+" instances:\n");
713 for (Enumeration e = data.enumerateInstances(); e.hasMoreElements(); ) {
714 Instance i = (Instance) e.nextElement();
715
716 if (i instanceof TrialInstance) {
717 TrialInstance ti = (TrialInstance) i;
718 InferenceRule ir = ti.ti.getCollection().getRule(solver);
719 w.write(" "+ti.ti.getCollection().name+" "+ti.getOrder());
720 if (!ti.getOrder().equals(ti.ti.order))
721 w.write(" ("+ti.ti.order+")");
722 if (ti.isMaxTime()) {
723 w.write(" MAX TIME\n");
724 } else {
725 w.write(" "+format(ti.getCost())+" ("+ti.ti.cost+" ms)\n");
726 }
727 } else {
728 w.write(" "+i+"\n");
729 }
730 }
731 w.write(c.toString());
732 w.write("\n");
733 } catch (IOException x) {
734 solver.err.println("IO Exception occurred writing \""+name+"\": "+x);
735 } finally {
736 if (w != null) try { w.close(); } catch (IOException _) { }
737 }
738 }
739
740 void dumpTrialGuessInfo(String name) {
741 BufferedWriter w = null;
742 try {
743 w = new BufferedWriter(new FileWriter(name, true));
744 w.write("Classifier \"name\":\n");
745 w.write("\n");
746 } catch (IOException x) {
747 solver.err.println("IO Exception occurred writing \""+name+"\": "+x);
748 } finally {
749 if (w != null) try { w.close(); } catch (IOException _) { }
750 }
751 }
752
753
754 private void addTrial(InferenceRule rule, List variables, Episode ep, Order o, TrialPrediction prediction, long time, long timestamp) {
755
756 TrialInfo info = new TrialInfo(o,prediction,time,ep, timestamp);
757
758
759 ep.addTrial(info);
760 dataRepository.addTrial(rule, variables, info);
761 }
762
763 public static int INITIAL_VAR_SET = 10;
764 public static int INITIAL_ATTRIB_SET = 16;
765 public static int INITIAL_DOM_SET = 10;
766 public TrialGuess tryNewGoodOrder(Episode ep, List allVars, InferenceRule ir, int opNum,
767 boolean returnBest) {
768 return tryNewGoodOrder(ep.getEpisodeCollection(), allVars, ir, opNum, null, returnBest);
769 }
770
771 public TrialGuess tryNewGoodOrder(EpisodeCollection ec, List allVars, InferenceRule ir, int opNum,
772 Order chosenOne,
773 boolean returnBest) {
774
775 out.println("Variables: " + allVars);
776 TrialDataGroup vDataGroup = this.dataRepository.getVariableDataGroup(ir, allVars);
777 TrialDataGroup aDataGroup = dataRepository.getAttribDataGroup(ir,allVars);
778 TrialDataGroup dDataGroup = dataRepository.getDomainDataGroup(ir,allVars);
779
780
781 TrialInstances vData, aData, dData;
782 vData = vDataGroup.getTrialInstances();
783 aData = aDataGroup.getTrialInstances();
784 dData = dDataGroup.getTrialInstances();
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801 adjustWeights(vData, aData, dData);
802 Discretization vDis = null, aDis = null, dDis = null;
803
804
805
806
807
808
809
810 vDis = vDataGroup.discretize(.5);
811 aDis = aDataGroup.discretize(.25);
812 dDis = dDataGroup.threshold(DOMAIN_THRESHOLD);
813
814 long vCTime = System.currentTimeMillis();
815 double vConstCV = -1;
816 vCTime = System.currentTimeMillis() - vCTime;
817
818 long aCTime = System.currentTimeMillis();
819 double aConstCV = -1;
820 aCTime = System.currentTimeMillis() - aCTime;
821
822 long dCTime = System.currentTimeMillis();
823 double dConstCV = -1;
824 dCTime = System.currentTimeMillis() - dCTime;
825
826 long vLTime = System.currentTimeMillis();
827 double vLeaveCV = -1;
828 vLTime = System.currentTimeMillis() - vLTime;
829
830 long aLTime = System.currentTimeMillis();
831 double aLeaveCV = -1;
832 aLTime = System.currentTimeMillis() - aLTime;
833
834 long dLTime = System.currentTimeMillis();
835 double dLeaveCV = -1;
836 dLTime = System.currentTimeMillis() - dLTime;
837
838 if (TRACE > 1) {
839 out.println(" Var data points: " + vData.numInstances());
840
841
842 out.println(" Var Classifier Weight: " + varClassWeight);
843
844 out.println(" Attrib data points: " + aData.numInstances());
845
846
847 out.println(" Attrib Classifier Weight: " + attrClassWeight);
848
849 out.println(" Domain data points: " + dData.numInstances());
850
851
852 out.println(" Domain Classifier Weight: " + domClassWeight);
853
854
855 }
856
857 Classifier vClassifier = null, aClassifier = null, dClassifier = null;
858
859
860
861
862
863
864
865
866
867 vClassifier = vDataGroup.classify();
868 aClassifier = aDataGroup.classify();
869 dClassifier = dDataGroup.classify();
870
871
872 if (DUMP_CLASSIFIER_INFO) {
873 String baseName = solver.getBaseName()+"_rule"+ir.id;
874 if (vClassifier != null)
875 dumpClassifierInfo(baseName+"_vclassifier", vClassifier, vData);
876 if (aClassifier != null)
877 dumpClassifierInfo(baseName+"_aclassifier", aClassifier, aData);
878 if (dClassifier != null)
879 dumpClassifierInfo(baseName+"_dclassifier", dClassifier, dData);
880 try {
881 out_t = new PrintStream(new FileOutputStream(baseName+"_trials"));
882 } catch (IOException x) {
883 solver.err.println("Error while opening file: "+x);
884 }
885 } else {
886 out_t = null;
887 }
888
889 if (TRACE > 2) {
890 out.println("Var classifier: " + vClassifier);
891 out.println("Attrib classifier: " + aClassifier);
892 out.println("Domain classifier: " + dClassifier);
893 }
894
895 double [][] bucketmeans = getBucketMeans(vDis, aDis, dDis);
896
897 Collection sel = null;
898 Collection candidates = null;
899 if(chosenOne == null){
900 Collection triedOrders = returnBest ? new LinkedList() : getTriedOrders((BDDInferenceRule) ir, opNum);
901 if(ec != null){
902 triedOrders.addAll(ec.trials.keySet());
903
904 }
905 Object object = generateCandidateSet( ir, allVars, bucketmeans,
906 vDataGroup, aDataGroup, dDataGroup,
907 triedOrders, returnBest);
908
909
910
911
912 if(object == null) return null;
913 else if(object instanceof Collection)
914 candidates = (Collection) object;
915 else if(object instanceof TrialGuess)
916 return (TrialGuess) object;
917 }else {
918 sel = Collections.singleton(chosenOne);
919 }
920 boolean force = (ec != null && ec.getNumTrials() < 2) ||
921 vData.numInstances() < INITIAL_VAR_SET ||
922 aData.numInstances() < INITIAL_ATTRIB_SET ||
923 dData.numInstances() < INITIAL_DOM_SET;
924
925 if (!returnBest)
926 sel = selectOrder(candidates, vData, aData, dData, ir, force);
927
928 if (sel == null || sel.isEmpty()) return null;
929 Order o_v = (Order) sel.iterator().next();
930 try {
931 OrderTranslator v2a = new VarToAttribTranslator(ir);
932 OrderTranslator a2d = AttribToDomainTranslator.INSTANCE;
933 double vClass = 0, aClass = 0, dClass = 0;
934 if (vClassifier != null) {
935 OrderInstance vInst = OrderInstance.construct(o_v, vData);
936 vClass = vClassifier.classifyInstance(vInst);
937 }
938 Order o_a = v2a.translate(o_v);
939 if (aClassifier != null) {
940 OrderInstance aInst = OrderInstance.construct(o_a, aData);
941 aClass = aClassifier.classifyInstance(aInst);
942 }
943 Order o_d = a2d.translate(o_a);
944 if (dClassifier != null) {
945 OrderInstance dInst = OrderInstance.construct(o_d, dData);
946 dClass = dClassifier.classifyInstance(dInst);
947 }
948 int vi = (int) vClass, ai = (int) aClass, di = (int) dClass;
949 double vScore = 0, aScore = 0, dScore = 0;
950 if (vi < bucketmeans[VMEAN_INDEX].length) vScore = bucketmeans[VMEAN_INDEX][vi];
951 if (ai < bucketmeans[AMEAN_INDEX].length) aScore = bucketmeans[AMEAN_INDEX][ai];
952 if (di < bucketmeans[DMEAN_INDEX].length) dScore = bucketmeans[DMEAN_INDEX][di];
953 double score = varClassWeight * vScore;
954 score += attrClassWeight * aScore;
955 score += domClassWeight * dScore;
956 return genGuess(o_v, score, vClass, aClass, dClass, vDis, aDis, dDis);
957 } catch (Exception x) {
958 x.printStackTrace();
959 Assert.UNREACHABLE(x.toString());
960 return null;
961 }
962 }
963
964 public final static int VMEAN_INDEX = 0;
965 public final static int AMEAN_INDEX = 1;
966 public final static int DMEAN_INDEX = 2;
967 public double[][] getBucketMeans(Discretization vDis, Discretization aDis, Discretization dDis){
968
969
970 double[] vBucketMeans = new double[vDis == null ? 0 : vDis.buckets.length];
971 double[] aBucketMeans = new double[aDis == null ? 0 : aDis.buckets.length];
972 double[] dBucketMeans = new double[dDis == null ? 0 : dDis.buckets.length];
973 if(TRACE > 2) out.print("Var Bucket Means: ");
974 for (int i = 0; i < vBucketMeans.length; ++i) {
975 if (vDis.buckets[i].numInstances() == 0)
976 vBucketMeans[i] = Double.MAX_VALUE;
977 else
978 vBucketMeans[i] = vDis.buckets[i].meanOrMode(vDis.buckets[i].classIndex());
979 if(TRACE > 2) out.print(vBucketMeans[i] + " ");
980 }
981 if (TRACE > 2) {
982 out.println();
983 out.print("Attr Bucket Means: ");
984 }
985 for (int i = 0; i < aBucketMeans.length; ++i) {
986 if (aDis.buckets[i].numInstances() == 0)
987 aBucketMeans[i] = Double.MAX_VALUE;
988 else
989 aBucketMeans[i] = aDis.buckets[i].meanOrMode(aDis.buckets[i].classIndex());
990 if(TRACE > 2) out.print(aBucketMeans[i] + " ");
991 }
992 if (TRACE > 2) {
993 out.println();
994 out.print("Domain Bucket Means: ");
995 }
996 for (int i = 0; i < dBucketMeans.length; ++i) {
997 if (dDis.buckets[i].numInstances() == 0)
998 dBucketMeans[i] = Double.MAX_VALUE;
999 else
1000 dBucketMeans[i] = dDis.buckets[i].meanOrMode(dDis.buckets[i].classIndex());
1001 if(TRACE > 2) out.print(dBucketMeans[i] + " ");
1002 }
1003 if(TRACE > 2) out.println();
1004 double [][] means = new double[3][];
1005
1006 means[VMEAN_INDEX] = vBucketMeans;
1007 means[AMEAN_INDEX] = aBucketMeans;
1008 means[DMEAN_INDEX] = dBucketMeans;
1009
1010 return means;
1011 }
1012
1013 public int getCombos(double[][] combos, int start, int numV, int numA, int numD,
1014 double vBuckets, double aBuckets, double dBuckets,
1015 double [][] means,
1016 double maxScore){
1017 double [] vBucketMeans = means[VMEAN_INDEX];
1018 double [] aBucketMeans = means[AMEAN_INDEX];
1019 double [] dBucketMeans = means[DMEAN_INDEX];
1020
1021 int p = 0;
1022 for (int vi = start; vi < numV; ++vi) {
1023 for (int ai = start; ai < numA; ++ai) {
1024 for (int di = 0; di < numD; ++di) {
1025 double vScore, aScore, dScore;
1026 double nullScore = 1;
1027 if (vi == -1) vScore = nullScore;
1028 else if (vi < vBuckets && vi < vBucketMeans.length) vScore = vBucketMeans[vi];
1029 else vScore = maxScore;
1030 if (ai == -1) aScore = nullScore;
1031 else if (ai < aBuckets && ai < aBucketMeans.length) aScore = aBucketMeans[ai];
1032 else aScore = maxScore;
1033 if (di == -1) dScore = nullScore;
1034 else if (di < dBuckets && di < dBucketMeans.length) dScore = dBucketMeans[di];
1035 else dScore = maxScore;
1036 double score = varClassWeight * vScore;
1037 score += attrClassWeight * aScore;
1038 score += domClassWeight * dScore;
1039 double[] result = new double[] { score, vi==-1?Double.NaN:vi,
1040 ai==-1?Double.NaN:ai,
1041 di==-1?Double.NaN:di };
1042 if (TRACE > 2) {
1043 out.println("Score for v="+vi+" a="+ai+" d="+di+": "+format(score));
1044 }
1045 combos[p++] = result;
1046 }
1047 }
1048 }
1049 Arrays.sort(combos, 0, p, new Comparator() {
1050 public int compare(Object arg0, Object arg1) {
1051 double[] a = (double[]) arg0;
1052 double[] b = (double[]) arg1;
1053 return FindBestDomainOrder.compare(a[0], b[0]);
1054 }
1055 });
1056
1057 return p;
1058 }
1059
1060
1061 public Object generateCandidateSet(InferenceRule ir, List allVars, double [][] means,
1062 TrialDataGroup vDataGroup,
1063 TrialDataGroup aDataGroup, TrialDataGroup dDataGroup, Collection triedOrders, boolean returnBest){
1064
1065 double [] vBucketMeans = means[VMEAN_INDEX];
1066 double [] aBucketMeans = means[AMEAN_INDEX];
1067 double [] dBucketMeans = means[DMEAN_INDEX];
1068
1069
1070 MultiMap a2v, d2v;
1071 a2v = new GenericMultiMap();
1072 d2v = new GenericMultiMap();
1073 for (Iterator i = allVars.iterator(); i.hasNext();) {
1074 Variable v = (Variable) i.next();
1075 Attribute a = (Attribute) ir.getAttribute(v);
1076 if (a != null) {
1077 a2v.add(a, v);
1078 d2v.add(a.getDomain(), v);
1079 }
1080 }
1081
1082 Set candidates = null;
1083 boolean addNullValues = !returnBest;
1084
1085 if (!returnBest) candidates = new LinkedHashSet();
1086 Collection never = neverAgain.getValues(ir);
1087
1088 Discretization vDis = vDataGroup.getDiscretization();
1089 Discretization aDis = aDataGroup.getDiscretization();
1090 Discretization dDis = dDataGroup.getDiscretization();
1091
1092 int end = 5;
1093
1094 int vBuckets = vDis == null ? 1 : vDis.buckets.length / 2 + 1;
1095 int aBuckets = aDis == null ? 1 : aDis.buckets.length / 2 + 1;
1096 int dBuckets = dDis == null ? 1 : dDis.buckets.length / 2 + 1;
1097 double max = (vBucketMeans.length != 0 ? vBucketMeans[vBucketMeans.length-1] : 0);
1098 max += (aBucketMeans.length != 0 ? aBucketMeans[aBucketMeans.length-1] : 0);
1099 max += (dBucketMeans.length != 0 ? dBucketMeans[dBucketMeans.length-1] : 0);
1100 boolean[][][] done = new boolean[vBuckets+1][aBuckets+1][dBuckets+1];
1101 outermost:
1102 while (candidates == null || candidates.size() < CANDIDATE_SET_SIZE) {
1103 int numV = Math.min(end, vBuckets);
1104 int numA = Math.min(end, aBuckets);
1105 int numD = Math.min(end, dBuckets);
1106 if (true && end > vBuckets && end > aBuckets && end > dBuckets) {
1107
1108 numV++; numA++; numD++;
1109 }
1110 int maxNum = addNullValues ? ((numV+1)*(numA+1)*(numD+1)) : (numV*numA*numD);
1111 double[][] combos = new double[maxNum][];
1112 int start = addNullValues ? -1 : 0;
1113 int p = getCombos(combos,start,numV, numA, numD, vBuckets, aBuckets, dBuckets,
1114 means, max);
1115
1116 for (int z = 0; z < p; ++z) {
1117 double bestScore = combos[z][0];
1118 double vClass = combos[z][1]; int vi = (int) vClass;
1119 double aClass = combos[z][2]; int ai = (int) aClass;
1120 double dClass = combos[z][3]; int di = (int) dClass;
1121
1122 if (vi == numV-1 && end <= vBuckets ||
1123 ai == numA-1 && end <= aBuckets ||
1124 di == numD-1 && end <= dBuckets) {
1125 if (TRACE > 1) out.println("reached end ("+vi+","+ai+","+di+"), trying again with a higher cutoff.");
1126 break;
1127 }
1128 if (!Double.isNaN(vClass) && !Double.isNaN(aClass) && !Double.isNaN(dClass)) {
1129 if (done[vi][ai][di]) continue;
1130 done[vi][ai][di] = true;
1131 } else {
1132 addNullValues = false;
1133 }
1134 if (vi == vBuckets) vClass = -1;
1135 if (ai == aBuckets) aClass = -1;
1136 if (di == dBuckets) dClass = -1;
1137 if (out_t != null) out_t.println("v="+vClass+" a="+aClass+" d="+dClass+": "+format(bestScore));
1138 Collection ocss = tryConstraints(vDataGroup, vClass, aDataGroup, aClass, dDataGroup, dClass, a2v, d2v);
1139
1140 if (ocss == null || ocss.isEmpty()) {
1141 if (out_t != null) out_t.println("Constraints cannot be combined.");
1142 continue;
1143 }
1144
1145 for (Iterator i = ocss.iterator(); i.hasNext(); ) {
1146 OrderConstraintSet ocs = (OrderConstraintSet) i.next();
1147 if (out_t != null) out_t.println("Constraints: "+ocs);
1148 if (returnBest) {
1149 TrialGuess guess = genGuess(ocs, bestScore, allVars, bestScore, vClass, aClass, dClass,
1150 vDis, aDis, dDis, triedOrders,
1151 if (guess != null) {
1152 if (TRACE > 1) out.println("Best Guess: "+guess);
1153 return guess;
1154 }
1155 } else {
1156
1157
1158 genOrders(ocs, allVars, triedOrders, never, candidates);
1159 if (candidates.size() >= CANDIDATE_SET_SIZE) break outermost;
1160 }
1161 }
1162 }
1163 if (end > vBuckets && end > aBuckets && end > dBuckets) {
1164 if (TRACE > 1) out.println("Reached end, no more possible guesses!");
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185 if (returnBest) {
1186 return null;
1187 }
1188 break outermost;
1189 }
1190 end *= 2;
1191 if (TRACE > 1) out.println("Cutoff is now "+end);
1192 }
1193
1194
1195 return candidates;
1196 }
1197 public static int CANDIDATE_SET_SIZE = Integer.parseInt(SystemProperties.getProperty("candidateset", "500"));
1198 public static int SAMPLE_SIZE = 1;
1199 public static double UNCERTAINTY_THRESHOLD = Double.parseDouble(SystemProperties.getProperty("uncertainty", ".25"));
1200 public static boolean WEIGHT_UNCERTAINTY_SAMPLE = false;
1201 public static double VCENT = .5, ACENT = .5, DCENT = 1;
1202 static CandidateSampler candidateSetSampler = new UncertaintySampler(SAMPLE_SIZE, UNCERTAINTY_THRESHOLD, VCENT, ACENT, DCENT);
1203
1204 public Collection selectOrder(Collection orders,
1205 TrialInstances vData, TrialInstances aData, TrialInstances dData, InferenceRule ir, boolean force) {
1206 Assert._assert(orders != null);
1207 if(orders.size() == 0){
1208 if(TRACE > 1) out.println("Size of candidate set is 0. No orders to select from");
1209 return null;
1210 }
1211 if (TRACE > 1) out.println("Selecting an order from a candidate set of "+orders.size()+" orders.");
1212 if (TRACE > 2) out.println("Orders: "+orders);
1213
1214 return candidateSetSampler.sample(orders, vData, aData, dData, ir, force);
1215 }
1216
1217
1218
1219 /***
1220 * Returns all the orders that have been tried on a particular rule update
1221 * (including those in previous runs).
1222 *
1223 * @param rule
1224 * @return collection of tried orders
1225 */
1226 Collection getTriedOrders(BDDInferenceRule rule, int opNumber){
1227 Collection triedOrders = new LinkedList();
1228 for(Iterator it = allTrials.iterator(); it.hasNext(); ){
1229 EpisodeCollection ec = (EpisodeCollection) it.next();
1230 if(ec.getRule(solver) == rule && ec.getUpdateCount() == rule.updateCount && ec.getOpNumber() == opNumber)
1231 triedOrders.addAll(ec.trials.keySet());
1232 }
1233 if(TRACE > 2) out.println("Tried Orders: " + triedOrders);
1234 return triedOrders;
1235 }
1236 static void genOrders(OrderConstraintSet ocs, List allVars, Collection already, Collection never, Collection result) {
1237 if (out_t != null) out_t.println("Generating orders for "+allVars);
1238 List orders;
1239 int nOrders = ocs.approxNumOrders(allVars.size());
1240 if (nOrders > CANDIDATE_SET_SIZE*20) {
1241 if (out_t != null) out_t.println("Too many possible orders ("+nOrders+")! Using random sampling.");
1242 orders = new LinkedList();
1243 for (int i = 0; i < CANDIDATE_SET_SIZE; ++i) {
1244 orders.add(ocs.generateRandomOrder(allVars));
1245 }
1246 } else {
1247 if (out_t != null) out_t.println("Estimated "+nOrders+" orders.");
1248 orders = ocs.generateAllOrders(allVars);
1249 }
1250 for (Iterator m = orders.iterator(); m.hasNext(); ) {
1251 Order best = (Order) m.next();
1252 if (never.contains(best)) {
1253 if (out_t != null) out_t.println("Skipped order "+best+" because it has blown up before.");
1254 continue;
1255 }
1256 if (already == null || !already.contains(best)) {
1257 if (out_t != null) out_t.println("Adding to candidate set: "+best);
1258 result.add(best);
1259 if (result.size() > CANDIDATE_SET_SIZE) {
1260 if (out_t != null) out_t.println("Candidate set full.");
1261 return;
1262 }
1263 } else {
1264 if (out_t != null) out_t.println("We have already tried order "+best);
1265 }
1266 }
1267 }
1268
1269 Collection
1270 TrialDataGroup aDataGroup, TrialDataGroup dDataGroup){
1271 Discretization vDis = vDataGroup.getDiscretization();
1272 Discretization aDis = aDataGroup.getDiscretization();
1273 Discretization dDis = dDataGroup.getDiscretization();
1274
1275 double [][] means = getBucketMeans(vDis,aDis,dDis);
1276 int numVBuckets = means[VMEAN_INDEX].length;
1277 int numABuckets = means[AMEAN_INDEX].length;
1278 int numDBuckets = means[DMEAN_INDEX].length;
1279 int maxNum = numVBuckets * numABuckets * numDBuckets;
1280 double [][] combos = new double[maxNum][];
1281 double [] vBucketMeans = means[VMEAN_INDEX];
1282 double [] aBucketMeans = means[AMEAN_INDEX];
1283 double [] dBucketMeans = means[DMEAN_INDEX];
1284 double maxScore = (vBucketMeans.length != 0 ? vBucketMeans[vBucketMeans.length-1] : 0);
1285 maxScore += (aBucketMeans.length != 0 ? aBucketMeans[aBucketMeans.length-1] : 0);
1286 maxScore += (dBucketMeans.length != 0 ? dBucketMeans[dBucketMeans.length-1] : 0);
1287 int numCombos = getCombos(combos, 0, numVBuckets, numABuckets, numDBuckets,
1288 numVBuckets, numABuckets, numDBuckets, means, maxScore);
1289
1290 Collection allPairs = new LinkedList();
1291 int numAdded = 0;
1292 for(int i = 0; i < numCombos && numAdded < num; ++i){
1293 double [] combo = combos[i];
1294 Collection constraints = tryConstraints(vDataGroup, combo[1], aDataGroup, combo[2], dDataGroup,combo[3],a2v,d2v);
1295 if(constraints == null) continue;
1296 for(Iterator jt = constraints.iterator(); jt.hasNext(); ){
1297 allPairs.add(new Pair(new Double(combo[0]) , jt.next()));
1298 ++numAdded;
1299 }
1300
1301 }
1302
1303 return allPairs;
1304 }
1305
1306 static TrialGuess genGuess(Order best, double score,
1307 double vClass, double aClass, double dClass,
1308 Discretization vDis, Discretization aDis, Discretization dDis) {
1309 double vLowerBound, vUpperBound, aLowerBound, aUpperBound, dLowerBound, dUpperBound;
1310 vLowerBound = vUpperBound = aLowerBound = aUpperBound = dLowerBound = dUpperBound = -1;
1311
1312 if (vDis != null && !Double.isNaN(vClass) && vClass != NO_CLASS) {
1313 vLowerBound = vDis.cutPoints == null || vClass <= 0 ? 0 : vDis.cutPoints[(int) vClass - 1];
1314 vUpperBound = vDis.cutPoints == null || vClass == vDis.cutPoints.length ? Double.MAX_VALUE : vDis.cutPoints[(int) vClass];
1315 }
1316 if (aDis != null && !Double.isNaN(aClass) && aClass != NO_CLASS) {
1317 aLowerBound = aDis.cutPoints == null || aClass <= 0 ? 0 : aDis.cutPoints[(int) aClass - 1];
1318 aUpperBound = aDis.cutPoints == null || aClass == aDis.cutPoints.length ? Double.MAX_VALUE : aDis.cutPoints[(int) aClass];
1319 }
1320 if (dDis != null && !Double.isNaN(dClass) && dClass != NO_CLASS) {
1321 dLowerBound = dDis.cutPoints == null || dClass <= 0 ? 0 : dDis.cutPoints[(int) dClass - 1];
1322 dUpperBound = dDis.cutPoints != null || dClass == dDis.cutPoints.length ? Double.MAX_VALUE : dDis.cutPoints[(int) dClass];
1323 }
1324 TrialPrediction prediction = new TrialPrediction(score, vLowerBound,vUpperBound,aLowerBound, aUpperBound,dLowerBound,dUpperBound);
1325 return new TrialGuess(best, prediction);
1326 }
1327
1328 static TrialGuess genGuess(OrderConstraintSet ocs, double score, List allVars, double bestScore,
1329 double vClass, double aClass, double dClass,
1330 Discretization vDis, Discretization aDis, Discretization dDis,
1331
1332 if (out_t != null) out_t.println("Generating orders for "+allVars);
1333
1334 Order best = ocs.generateRandomOrder(allVars);
1335 Iterator m = Collections.singleton(best).iterator();
1336 boolean exhaustive = true;
1337 while (m.hasNext()) {
1338 best = (Order) m.next();
1339 if (never.contains(best)) {
1340 if (out_t != null) out_t.println("Skipped order "+best+" because it has blown up before.");
1341 continue;
1342 }
1343
1344
1345 if(!triedOrders.contains(best)){
1346 if (out_t != null) out_t.println("Using order "+best);
1347 return genGuess(best, score, vClass, aClass, dClass, vDis, aDis, dDis);
1348 } else {
1349 if (out_t != null) out.println("We have already tried order "+best);
1350 }
1351 if (exhaustive) {
1352 List orders = ocs.generateAllOrders(allVars);
1353 m = orders.iterator();
1354 exhaustive = false;
1355 }
1356 }
1357 return null;
1358 }
1359
1360 static Collection
1361
1362
1363
1364
1365 TrialDataGroup vDataGroup, double vClass,
1366 TrialDataGroup aDataGroup, double aClass,
1367 TrialDataGroup dDataGroup, double dClass,
1368 MultiMap a2v, MultiMap d2v) {
1369 Collection results = new LinkedList();
1370 Instances vData = vDataGroup.getTrialInstances();
1371 Instances aData = aDataGroup.getTrialInstances();
1372 Instances dData = dDataGroup.getTrialInstances();
1373 MyId3 v = (MyId3) vDataGroup.getClassifier();
1374 MyId3 a = (MyId3) aDataGroup.getClassifier();
1375 MyId3 d = (MyId3) dDataGroup.getClassifier();
1376 Collection vBestAttribs;
1377 if ((vClass >= 0 || Double.isNaN(vClass)) && v != null)
1378 vBestAttribs = v.getAttribCombos(vData.numAttributes(), vClass);
1379 else
1380 vBestAttribs = Collections.singleton(makeEmptyConstraint());
1381 if (vBestAttribs == null) return null;
1382 for (Iterator v_i = vBestAttribs.iterator(); v_i.hasNext(); ) {
1383 double[] v_c = (double[]) v_i.next();
1384 OrderConstraintSet ocs = new OrderConstraintSet();
1385 boolean v_r = constrainOrder(ocs, v_c, vData, null);
1386 if (!v_r) {
1387 continue;
1388 }
1389 if (out_t != null) out_t.println(" Order constraints (var="+(int)vClass+"): "+ocs);
1390
1391 Collection aBestAttribs;
1392 if ((aClass >= 0 || Double.isNaN(aClass)) && a != null)
1393 aBestAttribs = a.getAttribCombos(aData.numAttributes(), aClass);
1394 else
1395 aBestAttribs = Collections.singleton(makeEmptyConstraint());
1396 if (aBestAttribs == null) continue;
1397 for (Iterator a_i = aBestAttribs.iterator(); a_i.hasNext(); ) {
1398 double[] a_c = (double[]) a_i.next();
1399 OrderConstraintSet ocsBackup = null;
1400 if (a_i.hasNext()) ocsBackup = ocs.copy();
1401 boolean a_r = constrainOrder(ocs, a_c, aData, a2v);
1402 if (!a_r) {
1403 ocs = ocsBackup;
1404 continue;
1405 }
1406 if (out_t != null) out_t.println(" Order constraints (attrib="+(int)aClass+"): "+ocs);
1407
1408 Collection dBestAttribs;
1409 if ((dClass >= 0 || Double.isNaN(dClass)) && d != null)
1410 dBestAttribs = d.getAttribCombos(dData.numAttributes(), dClass);
1411 else
1412 dBestAttribs = Collections.singleton(makeEmptyConstraint());
1413 if (dBestAttribs != null) {
1414 for (Iterator d_i = dBestAttribs.iterator(); d_i.hasNext(); ) {
1415 double[] d_c = (double[]) d_i.next();
1416 OrderConstraintSet ocsBackup2 = null;
1417 if (d_i.hasNext()) ocsBackup2 = ocs.copy();
1418 boolean d_r = constrainOrder(ocs, d_c, dData, d2v);
1419 if (d_r) {
1420 if (out_t != null) out_t.println(" Order constraints (domain="+(int)dClass+"): "+ocs);
1421 results.add(ocs);
1422 }
1423 ocs = ocsBackup2;
1424 }
1425 }
1426 ocs = ocsBackup;
1427 }
1428 }
1429 return results;
1430 }
1431
1432 static double computeScore(int vC, int aC, int dC,
1433 double[] vMeans, double[] aMeans, double[] dMeans,
1434 double vWeight, double aWeight, double dWeight) {
1435 double score = vMeans[vC] * vWeight;
1436 score += aMeans[aC] * aWeight;
1437 score += dMeans[dC] * dWeight;
1438 return score;
1439 }
1440
1441 static double[] makeEmptyConstraint() {
1442 int size = 0;
1443 double[] d = new double[size];
1444 for (int i = 0; i < d.length; ++i) {
1445 d[i] = Double.NaN;
1446 }
1447 return d;
1448 }
1449
1450 static boolean constrainOrder(OrderConstraintSet ocs, double[] c, Instances data, MultiMap map) {
1451 for (int iii = 0; iii < c.length; ++iii) {
1452 if (Double.isNaN(c[iii])) continue;
1453 int k = (int) c[iii];
1454 OrderAttribute oa = (OrderAttribute) data.attribute(iii);
1455 OrderConstraint oc = oa.getConstraint(k);
1456 if (map != null) {
1457 Collection c1 = map.getValues(oc.getFirst());
1458 Collection c2 = map.getValues(oc.getSecond());
1459 boolean any = false;
1460 for (Iterator ii = c1.iterator(); ii.hasNext();) {
1461 Object a = ii.next();
1462 for (Iterator jj = c2.iterator(); jj.hasNext();) {
1463 Object b = jj.next();
1464 if(a.equals(b)) continue;
1465 OrderConstraint cc = OrderConstraint.makeConstraint(oc.getType(), a, b);
1466 boolean r = ocs.constrain(cc);
1467 if (r) {
1468 any = true;
1469 }
1470 }
1471 }
1472 if (!any) {
1473 if (TRACE > 3) out.println("Constraint "+oc+" conflicts with "+ocs);
1474 return false;
1475 }
1476 } else {
1477 boolean r = ocs.constrain(oc);
1478 if (!r) {
1479 if (TRACE > 3) out.println("Constraint "+oc+" conflicts with "+ocs);
1480 return false;
1481 }
1482 }
1483 }
1484 return true;
1485 }
1486
1487 void printGoodOrder(Collection allVars, Instances inst, MyId3 v) {
1488 Collection vBestAttribs = v.getAttribCombos(inst.numAttributes(), 0.);
1489 if (vBestAttribs != null) {
1490 outer:
1491 for (Iterator ii = vBestAttribs.iterator(); ii.hasNext(); ) {
1492 double[] c = (double[]) ii.next();
1493 OrderConstraintSet ocs = new OrderConstraintSet();
1494 for (int iii = 0; iii < c.length; ++iii) {
1495 if (Double.isNaN(c[iii])) continue;
1496 int k = (int) c[iii];
1497 OrderAttribute oa = (OrderAttribute) inst.attribute(iii);
1498 out.println(oa);
1499 OrderConstraint oc = oa.getConstraint(k);
1500 out.println(oc);
1501 boolean r = ocs.constrain(oc);
1502 if (!r) {
1503 if (TRACE > 1) out.println("Constraint "+oc+" conflicts with "+ocs);
1504 continue outer;
1505 }
1506 }
1507 Order o = ocs.generateRandomOrder(allVars);
1508 out.println("Good order: " + o);
1509 }
1510 }
1511 }
1512
1513 TrialGuess evalOrder(Order o, InferenceRule ir) {
1514 List allVars = o.getFlattened();
1515 return tryNewGoodOrder(null, allVars, ir, -2, o, false);
1516 }
1517
1518 public static class OrderSearchElem implements Comparable{
1519 public double pathScore;
1520 public double pathCost;
1521
1522 public OrderConstraintSet ocs;
1523 public int nextRule;
1524 public Collection rulesLeft;
1525 public OrderSearchElem(OrderSearchElem that){
1526 this.pathScore = that.pathScore;
1527 this.ocs = new OrderConstraintSet(that.ocs);
1528 this.nextRule = that.nextRule;
1529 this.pathCost = that.pathCost;
1530 }
1531 public OrderSearchElem(double score, double cost, OrderConstraintSet ocs, int nextRule){
1532 this.pathScore = score;
1533 this.pathCost = cost;
1534 this.ocs = ocs;
1535 this.nextRule= nextRule;
1536 }
1537
1538 public String toString(){
1539 return "[" + pathScore + ", " + ocs + "]";
1540 }
1541
1542
1543
1544 public int compareTo(Object arg0) {
1545 OrderSearchElem that = (OrderSearchElem) arg0;
1546 return Double.compare(this.pathScore, that.pathScore);
1547 }
1548 }
1549
1550 void cache(int ruleNum, BDDInferenceRule rule, OrderConstraintSet[][] cachedConstraints, double[][] cachedScores){
1551 List vars = new LinkedList(rule.getNecessaryVariables());
1552 Object[] arr = vars.toArray();
1553 Arrays.sort(arr, rule.new VarOrderComparator(solver.VARORDER));
1554 vars = Arrays.asList(arr);
1555
1556 if(true){
1557 if(TRACE > 1) out.println("Finding Constraints for: " + rule);
1558 OrderTranslator t = new MapBasedTranslator(rule.variableToBDDDomain);
1559 EpisodeCollection tc = new EpisodeCollection(rule, 0);
1560
1561 boolean initialized = false;
1562
1563 for (int i = 0; i < NUM_BEST_ORDERS_PER_RULE; ++i) {
1564 if(!initialized){
1565 cachedConstraints[ruleNum] = new OrderConstraintSet[NUM_BEST_ORDERS_PER_RULE];
1566 cachedScores[ruleNum] = new double[NUM_BEST_ORDERS_PER_RULE];
1567 initialized = true;
1568 }
1569 TrialGuess tg = tryNewGoodOrder(tc, vars, rule, -2, null, true);
1570 if (tg == null) break;
1571 OrderConstraintSet newOcs = new OrderConstraintSet();
1572 newOcs.constrain(t.translate(tg.order), null);
1573 cachedConstraints[ruleNum][i] = newOcs;
1574 cachedScores[ruleNum][i] = tg.prediction.score;
1575 tc.addTrial(tg.order, null, 0, System.currentTimeMillis());
1576 }
1577 }else{
1578 MultiMap a2v, d2v;
1579 a2v = new GenericMultiMap();
1580 d2v = new GenericMultiMap();
1581 for (Iterator i = vars.iterator(); i.hasNext();) {
1582 Variable v = (Variable) i.next();
1583 Attribute a = (Attribute) rule.getAttribute(v);
1584 if (a != null) {
1585 a2v.add(a, v);
1586 d2v.add(a.getDomain(), v);
1587 }
1588 }
1589 TrialDataGroup vDataGroup = dataRepository.getVariableDataGroup(rule,vars);
1590 vDataGroup.discretize(.5);
1591 vDataGroup.classify();
1592 TrialDataGroup aDataGroup = dataRepository.getAttribDataGroup(rule,vars);
1593 aDataGroup.discretize(.25);
1594 vDataGroup.classify();
1595 TrialDataGroup dDataGroup = dataRepository.getDomainDataGroup(rule,vars);
1596 dDataGroup.threshold(2);
1597 dDataGroup.classify();
1598 Collection pairs = genConstaints(NUM_BEST_ORDERS_PER_RULE,a2v,d2v,rule,vDataGroup,aDataGroup,dDataGroup);
1599 cachedConstraints[ruleNum] = new OrderConstraintSet[NUM_BEST_ORDERS_PER_RULE];
1600 cachedScores[ruleNum] = new double[NUM_BEST_ORDERS_PER_RULE];
1601 int i = 0;
1602 for(Iterator it = pairs.iterator(); it.hasNext() && i < NUM_BEST_ORDERS_PER_RULE; ++i){
1603 Pair pair = (Pair) it.next();
1604 double score = ((Double) pair.get(0)).doubleValue();
1605 OrderConstraintSet constraints = (OrderConstraintSet) pair.get(1);
1606 cachedConstraints[ruleNum][i] = constraints.translate(rule.variableToBDDDomain);
1607 cachedScores[ruleNum][i] = score;
1608 }
1609 }
1610 }
1611
1612 static String MAX_CON_ORDERS = System.getProperty("considertrials");
1613 static int MAX_GEN_ORDERS = 100;
1614 static final int NUM_BEST_ORDERS_PER_RULE = 3;
1615 void myPrintBestBDDOrders(StringBuffer sb, Collection domains,List rules) {
1616 if(rules.size() == 0) return;
1617 Collection visitedElems = new LinkedList();
1618 double [][] cachedScores = new double[rules.size()][] ;
1619 OrderConstraintSet [][] cachedConstraints = new OrderConstraintSet[rules.size()][];
1620 Queue queue = new StackQueue();
1621 int numPrintedOrders = 0;
1622 int nodes = 0, maxQueueSize = 0;
1623 long allRulesTime = 0;
1624 TrialDataRepository repository = MAX_CON_ORDERS == null ?
1625 dataRepository : dataRepository.reduceByNumTrials(Integer.parseInt(MAX_CON_ORDERS));
1626
1627 for(Iterator it = rules.iterator(); it.hasNext(); )
1628 allRulesTime += ((BDDInferenceRule) it.next()).totalTime;
1629
1630 OrderSearchElem first = new OrderSearchElem(0,0, new OrderConstraintSet(),0);
1631 queue.offer(first);
1632
1633 while(!queue.isEmpty() && numPrintedOrders < MAX_GEN_ORDERS){
1634 maxQueueSize = Math.max(maxQueueSize, queue.size());
1635 OrderSearchElem elem = (OrderSearchElem) queue.poll();
1636 Assert._assert(elem != null);
1637 if(elem.nextRule >= rules.size() || (elem.rulesLeft != null && elem.rulesLeft.isEmpty()) || elem.ocs.onlyOneOrder(domains.size())){
1638 if(TRACE > 1){
1639 out.println("No more rules or constraints on this path");
1640 out.println("Generating orders for: " + elem.ocs);
1641 }
1642 Collection orders;
1643 if (elem.ocs.approxNumOrders(domains.size()) > MAX_GEN_ORDERS) {
1644 if(TRACE > 1) out.println("More than " + MAX_GEN_ORDERS + " orders. Dumping...random sample");
1645 orders = new HashSet();
1646 for (int i = 0; i < 5; ++i) {
1647 orders.add(elem.ocs.generateRandomOrder(domains));
1648 }
1649 } else {
1650 orders = elem.ocs.generateAllOrders(domains);
1651 }
1652 for (Iterator j = orders.iterator(); j.hasNext(); ) {
1653 Order o = (Order) j.next();
1654 sb.append("Score "+format(elem.pathScore, 5)+": "+o.toVarOrderString(null));
1655 sb.append('\n');
1656 }
1657 sb.append("-\n");
1658 numPrintedOrders += orders.size();
1659 continue;
1660 }
1661 ++nodes;
1662
1663 if(TRACE > 3) out.println("Expanding: " + elem);
1664
1665 BDDInferenceRule rule = (BDDInferenceRule) rules.get(elem.nextRule);
1666 boolean cached = cachedConstraints[elem.nextRule] != null;
1667 if(!cached) cache(elem.nextRule, rule, cachedConstraints, cachedScores);
1668 for (int i = 0; i < NUM_BEST_ORDERS_PER_RULE; ++i) {
1669 OrderConstraintSet constraints = cachedConstraints[elem.nextRule][i];
1670 OrderSearchElem newElem = new OrderSearchElem(elem);
1671 if(constraints == null){
1672 if(i == 0) {
1673 ++newElem.nextRule;
1674 queue.offer(newElem);
1675
1676 }
1677 break;
1678 }
1679 double constraintScore = cachedScores[elem.nextRule][i];
1680 ++newElem.nextRule;
1681 Collection invalidConstraints = new LinkedList();
1682 if(TRACE > 3) out.println("Adding constraints: " + constraints);
1683 boolean worked = newElem.ocs.constrain(constraints, invalidConstraints);
1684
1685 newElem.pathCost += invalidConstraints.size() * (rule.totalTime / constraintScore) ;
1686 newElem.pathScore = newElem.pathCost;
1687
1688
1689 if(TRACE > 3) out.println("Couldn't add: " + invalidConstraints);
1690 queue.offer(newElem);
1691
1692
1693 }
1694
1695
1696
1697
1698
1699 }
1700 out.println("Max queue size: " + maxQueueSize + " Nodes expanded: " + nodes);
1701 }
1702
1703 void printBestBDDOrders(StringBuffer sb, double score, Collection domains, OrderConstraintSet ocs,
1704 MultiMap rulesToTrials, List rules) {
1705 if (rules == null || rules.isEmpty()) {
1706 if(TRACE > 1) out.println("No more rules, Generating orders");
1707 Collection orders;
1708 if (ocs.approxNumOrders(domains.size()) > 1000) {
1709 if(TRACE > 1) out.println("More than " + MAX_GEN_ORDERS + " orders. Dumping...random sample");
1710 orders = new LinkedList();
1711 for (int i = 0; i < 5; ++i) {
1712 orders.add(ocs.generateRandomOrder(domains));
1713 }
1714 } else {
1715 if(TRACE > 1) out.println("Generating orders from constraints: " + ocs);
1716 orders = ocs.generateAllOrders(domains);
1717 }
1718 for (Iterator j = orders.iterator(); j.hasNext(); ) {
1719 Order o = (Order) j.next();
1720 sb.append("Score "+format(score)+": "+o.toVarOrderString(null));
1721 sb.append('\n');
1722 }
1723 return;
1724 }
1725 if (!ocs.onlyOneOrder(domains.size())) {
1726 InferenceRule ir = (InferenceRule) rules.get(0);
1727 List rest = rules.subList(1, rules.size());
1728 if (ir instanceof BDDInferenceRule && rulesToTrials.containsKey(ir)) {
1729
1730 BDDInferenceRule bddir = (BDDInferenceRule) ir;
1731 if(TRACE > 1) {
1732 out.println("Generating constraints for rule:\n" + ir.toString());
1733 out.println("Total rule run time: " + bddir.totalTime);
1734 }
1735 OrderTranslator t = new MapBasedTranslator(bddir.variableToBDDDomain);
1736 EpisodeCollection tc = new EpisodeCollection(bddir, 0);
1737 for (int i = 0; i < 5; ++i) {
1738 TrialGuess tg = tryNewGoodOrder(tc, new ArrayList(bddir.necessaryVariables), bddir, -2, null, true);
1739 if (tg == null) break;
1740 OrderConstraintSet ocs2 = new OrderConstraintSet(ocs);
1741 Order bddOrder = t.translate(tg.order);
1742 out.println("Adding constraints for: " + bddOrder);
1743 Collection invalidConstraints = new LinkedList();
1744 boolean worked = ocs2.constrain(bddOrder, invalidConstraints);
1745 double score2 = tg.prediction.score * (bddir.totalTime+1) / 1000;
1746
1747
1748
1749
1750
1751
1752
1753 if (worked) {
1754
1755 printBestBDDOrders(sb, score + score2, domains, ocs2, rulesToTrials, rest);
1756 }else{
1757 out.println("Couldn't add constraints: " + invalidConstraints);
1758 }
1759 tc.addTrial(tg.order, null, 0, System.currentTimeMillis());
1760
1761 }
1762 } else {
1763 printBestBDDOrders(sb, score, domains, ocs, rulesToTrials, rest);
1764 }
1765 }
1766
1767
1768
1769 out.println("Can't add more constraints: " + ocs);
1770 out.println("Left over rules (" + rules.size() + ": " + rules);
1771 Order o = ocs.generateRandomOrder(domains);
1772 for (Iterator k = rules.iterator(); k.hasNext(); ) {
1773 InferenceRule ir = (InferenceRule) k.next();
1774 if(!(ir instanceof BDDInferenceRule)) continue;
1775 BDDInferenceRule bddir = (BDDInferenceRule) ir;
1776 Order o2;
1777 if (false) {
1778 MultiMap d2v = new GenericMultiMap();
1779 for (Iterator a = bddir.variableToBDDDomain.entrySet().iterator(); a.hasNext(); ) {
1780 Map.Entry e = (Map.Entry) a.next();
1781 d2v.add(e.getValue(), e.getKey());
1782 }
1783 o2 = new MapBasedTranslator(d2v).translate(o);
1784 } else {
1785 Map d2v = new HashMap();
1786 for (Iterator a = bddir.variableToBDDDomain.entrySet().iterator(); a.hasNext(); ) {
1787 Map.Entry e = (Map.Entry) a.next();
1788 d2v.put(e.getValue(), e.getKey());
1789 }
1790 o2 = new MapBasedTranslator(d2v).translate(o);
1791 }
1792 TrialGuess tg = tryNewGoodOrder(null, new ArrayList(bddir.necessaryVariables), bddir, -2, o2, true);
1793 score += tg.prediction.score * (bddir.totalTime+1) / 1000;
1794 }
1795 sb.append("Score "+format(score)+": "+o.toVarOrderString(null));
1796 sb.append('\n');
1797 }
1798
1799 public Set getVisitedRules(){
1800 Set visitedRules = new HashSet();
1801 for(Iterator it = allTrials.iterator(); it.hasNext(); ){
1802 EpisodeCollection tc = (EpisodeCollection) it.next();
1803 visitedRules.add(tc.getRule(solver));
1804 }
1805
1806 if(TRACE > 2) out.println("Visited Rules: " + visitedRules);
1807 return visitedRules;
1808 }
1809
1810 public void printBestBDDOrders() {
1811 MultiMap ruleToTrials = new GenericMultiMap();
1812 for (Iterator i = allTrials.iterator(); i.hasNext(); ) {
1813 EpisodeCollection tc = (EpisodeCollection) i.next();
1814 ruleToTrials.add(tc.getRule(solver), tc);
1815 }
1816
1817
1818 SortedSet sortedRules = new TreeSet(new Comparator() {
1819 public int compare(Object o1, Object o2) {
1820 if (o1 == o2) return 0;
1821 if (o1 instanceof NumberingRule) return -1;
1822 if (o2 instanceof NumberingRule) return 1;
1823 BDDInferenceRule r1 = (BDDInferenceRule) o1;
1824 BDDInferenceRule r2 = (BDDInferenceRule) o2;
1825 long diff = r2.totalTime - r1.totalTime;
1826
1827 if (diff != 0)
1828 return (int) diff;
1829 return r1.id - r2.id;
1830 }
1831 });
1832 sortedRules.addAll(filterRules(solver.rules));
1833 ArrayList list = new ArrayList(sortedRules);
1834 for (Iterator i = list.iterator (); i.hasNext(); ) {
1835 BDDInferenceRule rule = (BDDInferenceRule) i.next();
1836 System.out.println(bestOrders(rule, 5));
1837 }
1838 Collection domains = new FlattenedCollection(solver.getBDDDomains().values());
1839 out.println("BDD Domains: "+domains);
1840 OrderConstraintSet ocs = new OrderConstraintSet();
1841 StringBuffer sb = new StringBuffer();
1842
1843 myPrintBestBDDOrders(sb, domains, list);
1844 out.println(sb);
1845 }
1846
1847 /***
1848 * Generate the k best orders for the given inference rule and
1849 * put the result in a string.
1850 *
1851 * @param rule inference rule
1852 * @param k number of orders
1853 * @return string result
1854 */
1855 public String bestOrders(BDDInferenceRule rule, int k) {
1856 List vars = new LinkedList(rule.getNecessaryVariables());
1857 Object[] arr = vars.toArray();
1858 Arrays.sort(arr, rule.new VarOrderComparator(solver.VARORDER));
1859 vars = Arrays.asList(arr);
1860 EpisodeCollection tc = new EpisodeCollection(rule, 0);
1861 StringBuffer sb = new StringBuffer();
1862 sb.append(rule.toString());
1863 sb.append(Strings.lineSep);
1864 for (int i = 1; i <= k; ++i) {
1865 TrialGuess tg = tryNewGoodOrder(tc, vars, rule, -2, null, true);
1866 if (tg == null) break;
1867 sb.append(" Order #").append(i).append(": ");
1868 sb.append(tg.toString());
1869 }
1870 return sb.toString();
1871 }
1872
1873 static Collection filterRules(Collection rules){
1874 Collection filteredRules = new LinkedList();
1875 for(Iterator it = rules.iterator(); it.hasNext(); ){
1876 InferenceRule rule = (InferenceRule) it.next();
1877 if(rule instanceof BDDInferenceRule) filteredRules.add(rule);
1878 }
1879 return filteredRules;
1880 }
1881 public void printBestTrials() {
1882 MultiMap ruleToTrials = new GenericMultiMap();
1883 for (Iterator i = allTrials.iterator(); i.hasNext(); ) {
1884 EpisodeCollection tc = (EpisodeCollection) i.next();
1885 ruleToTrials.add(tc.getRule(solver), tc);
1886 }
1887
1888 SortedSet sortedRules = new TreeSet(new Comparator() {
1889 public int compare(Object o1, Object o2) {
1890 if (o1 == o2) return 0;
1891 BDDInferenceRule r1 = (BDDInferenceRule) o1;
1892 BDDInferenceRule r2 = (BDDInferenceRule) o2;
1893
1894 long diff = r2.totalTime - r1.totalTime;
1895
1896 if (diff != 0)
1897 return (int) diff;
1898 return r1.id - r2.id;
1899 }
1900 });
1901 sortedRules.addAll(ruleToTrials.keySet());
1902
1903 for (Iterator i = sortedRules.iterator(); i.hasNext(); ) {
1904 BDDInferenceRule ir = (BDDInferenceRule) i.next();
1905 Map scoreboard = new HashMap();
1906 for (Iterator j = ruleToTrials.getValues(ir).iterator(); j.hasNext(); ) {
1907 EpisodeCollection tc = (EpisodeCollection) j.next();
1908 TrialInfo ti = tc.getMinimum();
1909 if (ti == null || ti.isMax()) continue;
1910 long[] score = (long[]) scoreboard.get(ti.order);
1911 if (score == null) scoreboard.put(ti.order, score = new long[2]);
1912 score[0]++;
1913 score[1] += ti.cost;
1914 }
1915
1916 if (scoreboard.isEmpty()) continue;
1917
1918 SortedSet sortedTrials = new TreeSet(new Comparator() {
1919 public int compare(Object o1, Object o2) {
1920 long[] counts1 = (long[]) ((Map.Entry) o1).getValue();
1921 long[] counts2 = (long[]) ((Map.Entry) o2).getValue();
1922 long diff = counts2[0] - counts1[0];
1923 if (diff != 0)
1924 return (int) diff;
1925 diff = counts2[1] - counts1[1];
1926 if (diff != 0)
1927 return (int) diff;
1928 Order order1 = (Order) ((Map.Entry) o1).getKey();
1929 Order order2 = (Order) ((Map.Entry) o2).getKey();
1930 return order1.compareTo(order2);
1931 }
1932 });
1933 sortedTrials.addAll(scoreboard.entrySet());
1934
1935 out.println("For rule"+ir.id+": "+ir);
1936 for (Iterator it = sortedTrials.iterator(); it.hasNext();) {
1937 Map.Entry entry = (Map.Entry) it.next();
1938 Order order = (Order) entry.getKey();
1939 long[] counts = (long[]) entry.getValue();
1940 double aveTime = (double)counts[1] / (double) counts[0];
1941 String bddString = order.toVarOrderString(ir.variableToBDDDomain);
1942 out.println(order + " won " + counts[0] + " time(s), average winning time of "+format(aveTime)+" ms");
1943 out.println(" BDD order: "+bddString);
1944 }
1945 out.println();
1946 }
1947
1948 }
1949
1950 public void printTrialsDistro() {
1951 printTrialsDistro(allTrials, solver);
1952 }
1953
1954 public static void printTrialsDistro(Collection trials, Solver solver) {
1955 Map orderToCounts = new HashMap();
1956 final int numRules = solver.getRules().size();
1957 int total = 0, distinct = 0;
1958 for (Iterator it = trials.iterator(); it.hasNext();) {
1959 EpisodeCollection tc = (EpisodeCollection) it.next();
1960 Assert._assert(tc != null);
1961 for (Iterator jt = tc.getTrials().iterator(); jt.hasNext();) {
1962 TrialInfo ti = (TrialInfo) jt.next();
1963 Order order = ti.order;
1964 int[] counts = (int[]) orderToCounts.get(order);
1965 if (counts == null) {
1966 counts = new int[numRules + 1];
1967 orderToCounts.put(order, counts);
1968 ++distinct;
1969 }
1970 ++counts[tc.getRule(solver).id];
1971
1972 ++counts[numRules];
1973 }
1974 total += tc.getNumTrials();
1975 }
1976
1977 SortedSet sortedTrials = new TreeSet(new Comparator() {
1978 public int compare(Object o1, Object o2) {
1979 int[] counts1 = (int[]) ((Map.Entry) o1).getValue();
1980 int[] counts2 = (int[]) ((Map.Entry) o2).getValue();
1981 int diff = counts2[numRules] - counts1[numRules];
1982 if (diff != 0)
1983 return diff;
1984 Order order1 = (Order) ((Map.Entry) o1).getKey();
1985 Order order2 = (Order) ((Map.Entry) o2).getKey();
1986 return order1.compareTo(order2);
1987 }
1988 });
1989
1990 sortedTrials.addAll(orderToCounts.entrySet());
1991 out.println(total + " trials of " + distinct + " distinct orders");
1992 out.println("tried Orders: ");
1993 for (Iterator it = sortedTrials.iterator(); it.hasNext();) {
1994 Map.Entry entry = (Map.Entry) it.next();
1995 Order order = (Order) entry.getKey();
1996 int[] counts = (int[]) entry.getValue();
1997 out.println(order + " tried a total of " + counts[numRules] + " time(s) :");
1998 for (int i = 0; i < counts.length - 1; ++i) {
1999 int count = counts[i];
2000 if (count != 0) {
2001 out.println(" " + count + " time(s) on \n " + solver.getRule(i));
2002 }
2003 }
2004 out.println();
2005 }
2006 }
2007
2008 public static void main(String[] args) throws Exception {
2009 String inputFilename = SystemProperties.getProperty("datalog");
2010 if (args.length > 0) inputFilename = args[0];
2011 if (inputFilename == null) {
2012 return;
2013 }
2014 String solverName = SystemProperties.getProperty("solver", "net.sf.bddbddb.BDDSolver");
2015 Solver s;
2016 s = (Solver) Class.forName(solverName).newInstance();
2017 s.load(inputFilename);
2018
2019 FindBestDomainOrder dis = ((BDDSolver) s).fbo;
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039 dis.printBestBDDOrders();
2040 }
2041
2042
2043 /***
2044 * Run the find best domain order on the given inputs.
2045 *
2046 * @param bdd BDD factory
2047 * @param b1 first input to relprod
2048 * @param b2 second input to relprod
2049 * @param b3 third input to relprod
2050 * @param r1 first rule term
2051 * @param r2 second rule term
2052 * @param vars1 variables of b1
2053 * @param vars2 variables of b2
2054 */
2055 static void findBestDomainOrder(BDDSolver solver, BDDInferenceRule rule, int opNum, BDDFactory bdd, BDD b1, BDD b2, BDDVarSet b3, RuleTerm r1, RuleTerm r2, Collection vars1, Collection vars2) {
2056 Set allVarSet = new HashSet(vars1); allVarSet.addAll(vars2);
2057 allVarSet.removeAll(rule.unnecessaryVariables);
2058 Object[] a = allVarSet.toArray();
2059
2060
2061 Arrays.sort(a, rule.new VarOrderComparator(solver.VARORDER));
2062 List allVars = Arrays.asList(a);
2063
2064 FindBestDomainOrder fbdo = solver.fbo;
2065 if (!fbdo.hasOrdersToTry(allVars, rule)) {
2066 out.println("No more orders to try, skipping find best order for "+vars1+","+vars2);
2067 return;
2068 }
2069 out.println("Finding best order for "+vars1+","+vars2);
2070 long time = System.currentTimeMillis();
2071 Episode ep = fbdo.getNewEpisode(rule, opNum, time, true);
2072 FindBestOrder fbo = new FindBestOrder(solver.BDDNODES, solver.BDDCACHE, solver.BDDNODES / 2, Long.MAX_VALUE, 5000);
2073 try {
2074 fbo.init(b1, b2, b3, BDDFactory.and);
2075 } catch (IOException x) {
2076 solver.err.println("IO Exception occurred: " + x);
2077 fbo.cleanup();
2078 return;
2079 }
2080 out.println("Time to initialize FindBestOrder: "+(System.currentTimeMillis()-time));
2081 int count = BDDInferenceRule.MAX_FBO_TRIALS;
2082 boolean first = true;
2083 long bestTime = Long.MAX_VALUE;
2084 while (--count >= 0) {
2085
2086 TrialGuess guess = fbdo.tryNewGoodOrder(ep, allVars, rule,opNum, first);
2087 if (guess == null || guess.order == null) break;
2088 Order o = guess.order;
2089 String vOrder = o.toVarOrderString(rule.variableToBDDDomain);
2090 out.println("Trying order "+vOrder);
2091 vOrder = solver.fixVarOrder(vOrder, false);
2092 out.println("Complete order "+vOrder);
2093 time = fbo.tryOrder(true, vOrder);
2094 time = Math.min(time, BDDInferenceRule.LONG_TIME);
2095 bestTime = Math.min(time, bestTime);
2096 fbdo.addTrial(rule, allVars,ep, o,guess.prediction, time, System.currentTimeMillis());
2097
2098 if (time >= BDDInferenceRule.LONG_TIME)
2099 fbdo.neverTryAgain(rule, o);
2100 first = false;
2101 }
2102 fbo.cleanup();
2103
2104 fbdo.incorporateTrial(ep);
2105
2106 XMLFactory.dumpXML("fbo.xml", fbdo.toXMLElement());
2107 XMLFactory.dumpXML(solver.TRIALFILE, fbdo.trialsToXMLElement());
2108 }
2109
2110
2111 }