View Javadoc

/*
 * Created on Jan 18, 2005
 *
 * Repository of ordering-trial data used to learn variable, attribute and domain orderings.
 */
7   package net.sf.bddbddb.order;
8   
9   import java.util.Collection;
10  import java.util.Comparator;
11  import java.util.HashMap;
12  import java.util.HashSet;
13  import java.util.Iterator;
14  import java.util.LinkedList;
15  import java.util.List;
16  import java.util.Map;
17  import java.util.Set;
18  import java.io.PrintStream;
19  import jwutil.collections.GenericMultiMap;
20  import jwutil.collections.MultiMap;
21  import jwutil.collections.SortedArraySet;
22  import jwutil.util.Assert;
23  import net.sf.bddbddb.BDDSolver;
24  import net.sf.bddbddb.FindBestDomainOrder;
25  import net.sf.bddbddb.InferenceRule;
26  import weka.classifiers.Classifier;
27  import weka.core.FastVector;
28  
29  /***
30   * @author Administrator
31   *
32   * TODO To change the template for this generated type comment go to
33   * Window - Preferences - Java - Code Style - Code Templates
34   */
35  public class TrialDataRepository {
36      
37      Collection allTrials;
38      static int TRACE = FindBestDomainOrder.TRACE;
39      Map /*InferenceRule, TrialDataGroup*/ varDataMap, /*Set, TrialDataGroup */attribDataMap, domainDataMap;
40      MultiMap varListeners, attribListeners, domainListeners;
41      static PrintStream out = FindBestDomainOrder.out;
42      BDDSolver solver;
43      public TrialDataRepository(BDDSolver solver){
44          this.solver = solver;
45          varDataMap = new HashMap();
46          varListeners = new GenericMultiMap();
47          attribDataMap = new HashMap();
48          attribListeners = new GenericMultiMap();
49          domainDataMap = new HashMap();
50          domainListeners = new GenericMultiMap();
51          allTrials = new LinkedList();
52      }
53      
54      public TrialDataRepository(Collection allTrials, BDDSolver solver){
55          this(solver);
56          this.allTrials = allTrials;
57      }
58      
59      public TrialInstances buildVarInstances(InferenceRule ir, List allVars) {
60          FastVector attributes = new FastVector();
61          WekaInterface.addAllPairs(attributes, allVars);
62          attributes.addElement(new weka.core.Attribute("score"));
63          int capacity = 30;
64          OrderTranslator filter = new FilterTranslator(allVars);
65          TrialInstances data = new TrialInstances("Var Ordering Constraints", attributes, capacity);
66          if (allVars.size() <= 1) return data;
67          for (Iterator i = allTrials.iterator(); i.hasNext();) {
68              EpisodeCollection tc2 = (EpisodeCollection) i.next();
69              InferenceRule ir2 = tc2.getRule(solver);
70              if (ir != ir2) continue;
71              addToInstances(data, tc2, filter);
72          }
73          data.setClassIndex(data.numAttributes() - 1);
74          return data;
75      }
76      
77      public TrialInstances buildAttribInstances(InferenceRule ir, List allVars) {
78          Collection allAttribs = VarToAttribMap.convert(allVars, ir);
79          if (TRACE > 1) out.println("Attribs: "+allAttribs);
80          FastVector attributes = new FastVector();
81          WekaInterface.addAllPairs(attributes, allAttribs);
82          attributes.addElement(new weka.core.Attribute("score"));
83          int capacity = 30;
84          TrialInstances data = new TrialInstances("Attribute Ordering Constraints", attributes, capacity);
85          if (allAttribs.size() <= 1) return data;
86          for (Iterator i = allTrials.iterator(); i.hasNext();) {
87              EpisodeCollection tc2 = (EpisodeCollection) i.next();
88              InferenceRule ir2 = tc2.getRule(solver);
89              OrderTranslator t = new VarToAttribTranslator(ir2);
90              t = new OrderTranslator.Compose(t, new FilterTranslator(allAttribs));
91              addToInstances(data, tc2, t);
92          }
93          data.setClassIndex(data.numAttributes() - 1);
94          return data;
95      }
96      
97      public TrialInstances buildDomainInstances(InferenceRule ir, List allVars) {
98          Collection allDomains = AttribToDomainMap.convert(VarToAttribMap.convert(allVars, ir));
99          if (TRACE > 1) out.println("Domains: "+allDomains);
100         FastVector attributes = new FastVector();
101         WekaInterface.addAllPairs(attributes, allDomains);
102         attributes.addElement(new weka.core.Attribute("score"));
103         int capacity = 30;
104         TrialInstances data = new TrialInstances("Domain Ordering Constraints", attributes, capacity);
105         if (allDomains.size() <= 1) return data;
106         for (Iterator i = allTrials.iterator(); i.hasNext();) {
107             EpisodeCollection tc2 = (EpisodeCollection) i.next();
108             InferenceRule ir2 = tc2.getRule(solver);
109             OrderTranslator t = new VarToAttribTranslator(ir2);
110             t = new OrderTranslator.Compose(t, AttribToDomainTranslator.INSTANCE);
111             t = new OrderTranslator.Compose(t, new FilterTranslator(allDomains));
112             addToInstances(data, tc2, t);
113         }
114         data.setClassIndex(data.numAttributes() - 1);
115         return data;
116     }
117     
118     public static void addToInstances(TrialInstances data, EpisodeCollection tc, OrderTranslator t) {
119         if (tc.getNumTrials() == 0) return;
120         double best;
121         if (tc.getMinimum().isMax()) best = 1;
122         else best = (double) tc.getMinimum().cost + 1;
123         for (Iterator j = tc.trials.values().iterator(); j.hasNext();) {
124             TrialInfo ti = (TrialInfo) j.next();
125             double score = (double) (ti.cost + 1) / best;
126             Order o = t == null ? ti.order : t.translate(ti.order);
127             if (o.numberOfElements() <= 1) continue;
128             TrialInstance tinst = TrialInstance.construct(ti, o, score, data);
129             if (tinst != null) data.add(tinst);
130         }
131     }
132     
133 
134  
135     public TrialDataGroup getVariableDataGroup(InferenceRule rule, List variables){
136         TrialDataGroup dataGroup = (TrialDataGroup) varDataMap.get(variables);
137         if(dataGroup == null){
138             dataGroup = new TrialDataGroup.VariableTrialDataGroup(variables, buildVarInstances(rule, variables));
139             varDataMap.put(variables, dataGroup);
140             Collection pairs = WekaInterface.generateAllPairs(variables);
141             for(Iterator it = pairs.iterator(); it.hasNext(); ){
142                 varListeners.add(it.next(), dataGroup); 
143             }
144         }
145         return dataGroup;
146     }
147     
148    
149     public TrialDataGroup getAttribDataGroup(InferenceRule rule, List variables){
150         Set attribs = new HashSet(VarToAttribMap.convert(variables, rule));
151         TrialDataGroup dataGroup = (TrialDataGroup) attribDataMap.get(attribs);
152         if(dataGroup == null){
153             dataGroup = new TrialDataGroup.AttribTrialDataGroup(attribs, buildAttribInstances(rule, variables));
154             attribDataMap.put(attribs, dataGroup);
155             Collection pairs = WekaInterface.generateAllPairs(attribs);
156             for(Iterator it = pairs.iterator(); it.hasNext(); ){
157                 attribListeners.add(it.next(), dataGroup); 
158             }
159         }
160         return dataGroup;
161     }
162     
163     public TrialDataGroup getDomainDataGroup(InferenceRule rule, List variables){
164         List domains = new LinkedList(AttribToDomainMap.convert(VarToAttribMap.convert(variables, rule)));
165         Set domainSet = new HashSet(domains);
166         TrialDataGroup dataGroup = (TrialDataGroup) domainDataMap.get(domains);
167         if(dataGroup == null){
168             dataGroup = new TrialDataGroup.DomainTrialDataGroup(domains, buildDomainInstances(rule, variables) );
169             domainDataMap.put(domains, dataGroup);
170             Collection pairs = WekaInterface.generateAllPairs(domains);
171             for(Iterator it = pairs.iterator(); it.hasNext(); ){
172                 domainListeners.add(it.next(), dataGroup); 
173             }
174         }
175         return dataGroup;
176     }
177     
178     public boolean addTrial(InferenceRule rule, List variables, TrialInfo info){
179         Order o_v = info.order;
180         EpisodeCollection tc = info.getCollection();
181         
182         //boolean changed = varData.update(o_v,info, trialColBest);
183         boolean changed = false;
184         Collection varPairs = WekaInterface.generateAllPairs(variables);
185         Collection notified = new HashSet();
186         for(Iterator it = varPairs.iterator(); it.hasNext(); ){
187             Collection listeners = varListeners.getValues(it.next());                                              
188             Assert._assert(listeners != null);
189             for(Iterator jt = listeners.iterator(); jt.hasNext();){
190                 TrialDataGroup dataGroup = (TrialDataGroup) jt.next();
191                 if(!notified.contains(dataGroup)){
192                     changed |= dataGroup.update(o_v, info,tc);
193                     notified.add(dataGroup);
194                 }
195             }
196         }
197         
198         OrderTranslator translator = new VarToAttribTranslator(rule);
199         Order o_a = translator.translate(o_v);
200         Collection attribs = VarToAttribMap.convert(variables, rule);
201         Collection attribPairs = WekaInterface.generateAllPairs(attribs);
202         for(Iterator it = attribPairs.iterator(); it.hasNext(); ){
203             Collection listeners = attribListeners.getValues(it.next());                                              
204             Assert._assert(listeners != null);
205             for(Iterator jt = listeners.iterator(); jt.hasNext();){
206                 TrialDataGroup dataGroup = (TrialDataGroup) jt.next();
207                 if(!notified.contains(dataGroup)){
208                     changed |= dataGroup.update(o_a, info,tc);
209                     notified.add(dataGroup);
210                 }
211             }
212         }
213         Order o_d = AttribToDomainTranslator.INSTANCE.translate(o_a);
214         Collection domainPairs = WekaInterface.generateAllPairs(AttribToDomainMap.convert(attribs));
215         for(Iterator it = domainPairs.iterator(); it.hasNext(); ){
216             Collection domListeners = domainListeners.getValues(it.next());
217             Assert._assert(domListeners != null);
218             for(Iterator jt = domListeners.iterator(); jt.hasNext(); ){
219                 TrialDataGroup dataGroup = (TrialDataGroup) jt.next();
220                 if(!notified.contains(dataGroup)){
221                     changed |= dataGroup.update(o_d, info, tc);
222                     notified.add(dataGroup);
223                 }
224             }
225         }
226        
227         return changed;
228     }
229     
230     public TrialDataRepository reduceByNumTrials(int numTrials){
231         Collection newAllTrials = new LinkedList();
232         numTrials = Math.min(allTrials.size(), numTrials);
233         SortedArraySet sortedTrials = (SortedArraySet) SortedArraySet.FACTORY.makeSet(
234                 new Comparator(){
235                     public int compare(Object o1, Object o2) {
236                         TrialInfo t1 = (TrialInfo) o1;
237                         TrialInfo t2 = (TrialInfo) o2;
238                         return FindBestDomainOrder.signum(t1.timestamp  - t2.timestamp);
239                     }
240                 });
241         sortedTrials.addAll(allTrials);
242         newAllTrials.addAll(sortedTrials.subList(0, numTrials - 1));
243         return new TrialDataRepository(newAllTrials, this.solver);
244     }
245     
246         
247         
248     public abstract static class TrialDataGroup{
249 
250         public static String CLASSIFIER = "net.sf.bddbddb.order.MyId3";
251         private TrialInstances trialInstances;
252         private TrialInstances trialInstancesCopy;
253         private Discretization discretization;
254         private double discretizeParam = 0;
255         private double thresholdParam = 0;
256         private MultiMap /*EpisodeCollection, Instances*/ trialMap;
257         private Classifier classifier;
258         private double infoSinceClassRebuild, infoSinceDiscRebuild, infoSinceInstances;
259         private double infoThreshold; 
260         protected FilterTranslator filter;
261         protected TrialDataGroup(TrialInstances instances){
262             trialInstances = instances;
263             discretizeParam  = -1;
264             thresholdParam = -1;
265             trialMap = new GenericMultiMap();
266         }
267         
268         /***
269          * @return Returns the classifier.
270          */
271         public Classifier classify() {
272             if(discretizeParam < 0 && thresholdParam < 0)
273                 return null;
274             Assert._assert(discretizeParam < 0 ^ thresholdParam < 0); //kinda weird
275            
276            if(discretizeParam > 0)
277                discretize(discretizeParam);
278            else
279                threshold(thresholdParam);
280            
281           TrialInstances instances = getTrialInstances();
282             classifier = instances.numInstances() > 0 ? WekaInterface.buildClassifier(CLASSIFIER, instances) : null;
283             return classifier;
284         }
285         
286         public Classifier getClassifier(){
287             return classifier;
288         }
289         public void setDiscretizeParam(double discretize){
290             discretizeParam = discretize;
291             thresholdParam = -1;
292         }
293         
294         public void setThresholdParam(double thresholdParam){
295             this.thresholdParam = thresholdParam;
296             discretizeParam = -1;
297         }
298         /***
299          * @return Returns the discretization.
300          */
301         public Discretization discretize(double discretizeFact) {
302             if((discretizeFact != discretizeParam) || (infoSinceDiscRebuild > infoThreshold)){
303                 setDiscretizeParam(discretizeFact);
304                 discretization = getTrialInstances().discretize(discretizeParam);
305                 infoSinceDiscRebuild = 0;
306             }
307             return discretization;
308         }
309         
310         public Discretization getDiscretization(){
311             //Assert._assert(discretization != null && discretizeParam != -1 && (infoSinceDiscRebuild <= infoThreshold));
312             Assert._assert(discretizeParam != -1 || thresholdParam != -1);
313             if(discretizeParam != -1) return discretize(discretizeParam);
314             return threshold(thresholdParam);
315            
316         }
317        
318         public Discretization threshold(double threshold){
319             if((threshold != thresholdParam) || (infoSinceDiscRebuild > infoThreshold)){
320                 setThresholdParam(threshold);
321                 discretization = getTrialInstances().threshold(thresholdParam);
322                 infoSinceDiscRebuild = 0;
323             }
324             return discretization;
325         }
326         
327         /***
328          * @return Returns the instances.
329          */
330         public TrialInstances getTrialInstances() {
331             if(trialInstancesCopy == null || infoSinceInstances > infoThreshold){
332                 trialInstancesCopy = trialInstances.copy();
333                 infoSinceInstances = 0;
334             }
335             return trialInstancesCopy;
336         }
337         public void forceRebuildNext(){
338             infoSinceClassRebuild = Double.POSITIVE_INFINITY;
339             infoSinceDiscRebuild = Double.POSITIVE_INFINITY;
340             infoSinceInstances = Double.POSITIVE_INFINITY;
341         }
342         public boolean update(Order order, TrialInfo info, EpisodeCollection tc){
343             forceRebuildNext();
344             double trialColBest;
345             if (tc.getMinimum().isMax()) trialColBest = 1;
346             else trialColBest = (double) tc.getMinimum().cost + 1;
347             Order filteredOrder = filter.translate(order);
348             Collection trials = trialMap.getValues(tc);
349             if(trials != null){
350                 for(Iterator it = trials.iterator(); it.hasNext(); ){
351                     TrialInstance instance = (TrialInstance) it.next();
352                     instance.recomputeCost(trialColBest);
353                 }
354             }
355             
356             Assert._assert(filteredOrder.numberOfElements() > 1);
357           //  System.out.println("Order: " + order + "\n" + filter + "\nfiltered order: " + filteredOrder);
358             double score = (double) (info.cost + 1) / trialColBest; 
359             TrialInstance instance = TrialInstance.construct(info, filteredOrder, score, trialInstances);
360             if(instance == null){
361                 System.out.println("Failed constructing instance of " + filteredOrder + " with " + filter + " on " + trialInstances);
362                 Assert.UNREACHABLE();
363             }
364             trialMap.add(tc, instance);
365             //System.out.println("Adding new Instance to DataGroup: " + this);
366             trialInstances.add(instance);
367             return true;
368         }
369         
370         public static class VariableTrialDataGroup extends TrialDataGroup{
371             private Collection variables;
372             public VariableTrialDataGroup(Collection variables, TrialInstances instances){
373                 super(instances);
374                 this.variables = variables;
375                 this.filter = new FilterTranslator(variables);
376             }
377           
378             public Collection getVariables(){ return new LinkedList(variables); }
379         }
380         
381         public static class AttribTrialDataGroup extends TrialDataGroup{
382             private Collection attribs;
383             public AttribTrialDataGroup(Collection attribs, TrialInstances instances){
384                super(instances);
385                this.attribs = attribs;
386                this.filter = new FilterTranslator(attribs);
387             }
388         }
389         
390         public static class DomainTrialDataGroup extends TrialDataGroup{
391             private Collection domains;
392             public DomainTrialDataGroup(Collection domains, TrialInstances instances){
393                 super(instances);
394                 this.domains = domains;
395                 this.filter = new FilterTranslator(domains);
396             }
397         }
398         
399     }
400 }