1
2
3
4
5
6
7 package net.sf.bddbddb.order;
8
9 import java.util.Collection;
10 import java.util.Comparator;
11 import java.util.HashMap;
12 import java.util.HashSet;
13 import java.util.Iterator;
14 import java.util.LinkedList;
15 import java.util.List;
16 import java.util.Map;
17 import java.util.Set;
18 import java.io.PrintStream;
19 import jwutil.collections.GenericMultiMap;
20 import jwutil.collections.MultiMap;
21 import jwutil.collections.SortedArraySet;
22 import jwutil.util.Assert;
23 import net.sf.bddbddb.BDDSolver;
24 import net.sf.bddbddb.FindBestDomainOrder;
25 import net.sf.bddbddb.InferenceRule;
26 import weka.classifiers.Classifier;
27 import weka.core.FastVector;
28
/**
 * Repository of ordering trial data. It holds the collection of recorded
 * trials and builds, caches and incrementally updates the instance sets
 * (grouped by variable, attribute and domain) that are derived from them.
 *
 * @author Administrator
 */
35 public class TrialDataRepository {
36
37 Collection allTrials;
38 static int TRACE = FindBestDomainOrder.TRACE;
39 Map
40 MultiMap varListeners, attribListeners, domainListeners;
41 static PrintStream out = FindBestDomainOrder.out;
42 BDDSolver solver;
/**
 * Creates an empty repository bound to the given solver.
 *
 * @param solver solver stored for later rule lookups
 */
public TrialDataRepository(BDDSolver solver){
    // Start without any recorded trials.
    allTrials = new LinkedList();
    // One data-group cache per ordering level...
    varDataMap = new HashMap();
    attribDataMap = new HashMap();
    domainDataMap = new HashMap();
    // ...and one pair-to-group listener table per ordering level.
    varListeners = new GenericMultiMap();
    attribListeners = new GenericMultiMap();
    domainListeners = new GenericMultiMap();
    this.solver = solver;
}
53
/**
 * Creates a repository that starts out with an existing collection of trials.
 *
 * @param allTrials trials to use; the collection is stored as-is, not copied
 * @param solver    solver stored for later rule lookups
 */
public TrialDataRepository(Collection allTrials, BDDSolver solver){
    // Let the single-argument constructor set up the caches and listener tables.
    this(solver);
    // Swap the empty list created above for the caller's collection.
    this.allTrials = allTrials;
}
58
59 public TrialInstances buildVarInstances(InferenceRule ir, List allVars) {
60 FastVector attributes = new FastVector();
61 WekaInterface.addAllPairs(attributes, allVars);
62 attributes.addElement(new weka.core.Attribute("score"));
63 int capacity = 30;
64 OrderTranslator filter = new FilterTranslator(allVars);
65 TrialInstances data = new TrialInstances("Var Ordering Constraints", attributes, capacity);
66 if (allVars.size() <= 1) return data;
67 for (Iterator i = allTrials.iterator(); i.hasNext();) {
68 EpisodeCollection tc2 = (EpisodeCollection) i.next();
69 InferenceRule ir2 = tc2.getRule(solver);
70 if (ir != ir2) continue;
71 addToInstances(data, tc2, filter);
72 }
73 data.setClassIndex(data.numAttributes() - 1);
74 return data;
75 }
76
77 public TrialInstances buildAttribInstances(InferenceRule ir, List allVars) {
78 Collection allAttribs = VarToAttribMap.convert(allVars, ir);
79 if (TRACE > 1) out.println("Attribs: "+allAttribs);
80 FastVector attributes = new FastVector();
81 WekaInterface.addAllPairs(attributes, allAttribs);
82 attributes.addElement(new weka.core.Attribute("score"));
83 int capacity = 30;
84 TrialInstances data = new TrialInstances("Attribute Ordering Constraints", attributes, capacity);
85 if (allAttribs.size() <= 1) return data;
86 for (Iterator i = allTrials.iterator(); i.hasNext();) {
87 EpisodeCollection tc2 = (EpisodeCollection) i.next();
88 InferenceRule ir2 = tc2.getRule(solver);
89 OrderTranslator t = new VarToAttribTranslator(ir2);
90 t = new OrderTranslator.Compose(t, new FilterTranslator(allAttribs));
91 addToInstances(data, tc2, t);
92 }
93 data.setClassIndex(data.numAttributes() - 1);
94 return data;
95 }
96
97 public TrialInstances buildDomainInstances(InferenceRule ir, List allVars) {
98 Collection allDomains = AttribToDomainMap.convert(VarToAttribMap.convert(allVars, ir));
99 if (TRACE > 1) out.println("Domains: "+allDomains);
100 FastVector attributes = new FastVector();
101 WekaInterface.addAllPairs(attributes, allDomains);
102 attributes.addElement(new weka.core.Attribute("score"));
103 int capacity = 30;
104 TrialInstances data = new TrialInstances("Domain Ordering Constraints", attributes, capacity);
105 if (allDomains.size() <= 1) return data;
106 for (Iterator i = allTrials.iterator(); i.hasNext();) {
107 EpisodeCollection tc2 = (EpisodeCollection) i.next();
108 InferenceRule ir2 = tc2.getRule(solver);
109 OrderTranslator t = new VarToAttribTranslator(ir2);
110 t = new OrderTranslator.Compose(t, AttribToDomainTranslator.INSTANCE);
111 t = new OrderTranslator.Compose(t, new FilterTranslator(allDomains));
112 addToInstances(data, tc2, t);
113 }
114 data.setClassIndex(data.numAttributes() - 1);
115 return data;
116 }
117
118 public static void addToInstances(TrialInstances data, EpisodeCollection tc, OrderTranslator t) {
119 if (tc.getNumTrials() == 0) return;
120 double best;
121 if (tc.getMinimum().isMax()) best = 1;
122 else best = (double) tc.getMinimum().cost + 1;
123 for (Iterator j = tc.trials.values().iterator(); j.hasNext();) {
124 TrialInfo ti = (TrialInfo) j.next();
125 double score = (double) (ti.cost + 1) / best;
126 Order o = t == null ? ti.order : t.translate(ti.order);
127 if (o.numberOfElements() <= 1) continue;
128 TrialInstance tinst = TrialInstance.construct(ti, o, score, data);
129 if (tinst != null) data.add(tinst);
130 }
131 }
132
133
134
135 public TrialDataGroup getVariableDataGroup(InferenceRule rule, List variables){
136 TrialDataGroup dataGroup = (TrialDataGroup) varDataMap.get(variables);
137 if(dataGroup == null){
138 dataGroup = new TrialDataGroup.VariableTrialDataGroup(variables, buildVarInstances(rule, variables));
139 varDataMap.put(variables, dataGroup);
140 Collection pairs = WekaInterface.generateAllPairs(variables);
141 for(Iterator it = pairs.iterator(); it.hasNext(); ){
142 varListeners.add(it.next(), dataGroup);
143 }
144 }
145 return dataGroup;
146 }
147
148
149 public TrialDataGroup getAttribDataGroup(InferenceRule rule, List variables){
150 Set attribs = new HashSet(VarToAttribMap.convert(variables, rule));
151 TrialDataGroup dataGroup = (TrialDataGroup) attribDataMap.get(attribs);
152 if(dataGroup == null){
153 dataGroup = new TrialDataGroup.AttribTrialDataGroup(attribs, buildAttribInstances(rule, variables));
154 attribDataMap.put(attribs, dataGroup);
155 Collection pairs = WekaInterface.generateAllPairs(attribs);
156 for(Iterator it = pairs.iterator(); it.hasNext(); ){
157 attribListeners.add(it.next(), dataGroup);
158 }
159 }
160 return dataGroup;
161 }
162
163 public TrialDataGroup getDomainDataGroup(InferenceRule rule, List variables){
164 List domains = new LinkedList(AttribToDomainMap.convert(VarToAttribMap.convert(variables, rule)));
165 Set domainSet = new HashSet(domains);
166 TrialDataGroup dataGroup = (TrialDataGroup) domainDataMap.get(domains);
167 if(dataGroup == null){
168 dataGroup = new TrialDataGroup.DomainTrialDataGroup(domains, buildDomainInstances(rule, variables) );
169 domainDataMap.put(domains, dataGroup);
170 Collection pairs = WekaInterface.generateAllPairs(domains);
171 for(Iterator it = pairs.iterator(); it.hasNext(); ){
172 domainListeners.add(it.next(), dataGroup);
173 }
174 }
175 return dataGroup;
176 }
177
178 public boolean addTrial(InferenceRule rule, List variables, TrialInfo info){
179 Order o_v = info.order;
180 EpisodeCollection tc = info.getCollection();
181
182
183 boolean changed = false;
184 Collection varPairs = WekaInterface.generateAllPairs(variables);
185 Collection notified = new HashSet();
186 for(Iterator it = varPairs.iterator(); it.hasNext(); ){
187 Collection listeners = varListeners.getValues(it.next());
188 Assert._assert(listeners != null);
189 for(Iterator jt = listeners.iterator(); jt.hasNext();){
190 TrialDataGroup dataGroup = (TrialDataGroup) jt.next();
191 if(!notified.contains(dataGroup)){
192 changed |= dataGroup.update(o_v, info,tc);
193 notified.add(dataGroup);
194 }
195 }
196 }
197
198 OrderTranslator translator = new VarToAttribTranslator(rule);
199 Order o_a = translator.translate(o_v);
200 Collection attribs = VarToAttribMap.convert(variables, rule);
201 Collection attribPairs = WekaInterface.generateAllPairs(attribs);
202 for(Iterator it = attribPairs.iterator(); it.hasNext(); ){
203 Collection listeners = attribListeners.getValues(it.next());
204 Assert._assert(listeners != null);
205 for(Iterator jt = listeners.iterator(); jt.hasNext();){
206 TrialDataGroup dataGroup = (TrialDataGroup) jt.next();
207 if(!notified.contains(dataGroup)){
208 changed |= dataGroup.update(o_a, info,tc);
209 notified.add(dataGroup);
210 }
211 }
212 }
213 Order o_d = AttribToDomainTranslator.INSTANCE.translate(o_a);
214 Collection domainPairs = WekaInterface.generateAllPairs(AttribToDomainMap.convert(attribs));
215 for(Iterator it = domainPairs.iterator(); it.hasNext(); ){
216 Collection domListeners = domainListeners.getValues(it.next());
217 Assert._assert(domListeners != null);
218 for(Iterator jt = domListeners.iterator(); jt.hasNext(); ){
219 TrialDataGroup dataGroup = (TrialDataGroup) jt.next();
220 if(!notified.contains(dataGroup)){
221 changed |= dataGroup.update(o_d, info, tc);
222 notified.add(dataGroup);
223 }
224 }
225 }
226
227 return changed;
228 }
229
230 public TrialDataRepository reduceByNumTrials(int numTrials){
231 Collection newAllTrials = new LinkedList();
232 numTrials = Math.min(allTrials.size(), numTrials);
233 SortedArraySet sortedTrials = (SortedArraySet) SortedArraySet.FACTORY.makeSet(
234 new Comparator(){
235 public int compare(Object o1, Object o2) {
236 TrialInfo t1 = (TrialInfo) o1;
237 TrialInfo t2 = (TrialInfo) o2;
238 return FindBestDomainOrder.signum(t1.timestamp - t2.timestamp);
239 }
240 });
241 sortedTrials.addAll(allTrials);
242 newAllTrials.addAll(sortedTrials.subList(0, numTrials - 1));
243 return new TrialDataRepository(newAllTrials, this.solver);
244 }
245
246
247
248 public abstract static class TrialDataGroup{
249
/** Class name handed to WekaInterface.buildClassifier by classify(). */
public static String CLASSIFIER = "net.sf.bddbddb.order.MyId3";
/** Live instance set; update() appends to it. */
private TrialInstances trialInstances;
/** Copy of trialInstances handed out by getTrialInstances(); refreshed when stale. */
private TrialInstances trialInstancesCopy;
/** Result of the most recent discretize() or threshold() rebuild. */
private Discretization discretization;
/** Discretization factor; a negative value means "not in use". */
private double discretizeParam = 0;
/** Threshold value; a negative value means "not in use". */
private double thresholdParam = 0;
/** Maps an episode collection to the TrialInstance objects that update() created for it. */
private MultiMap trialMap;
/** Classifier produced by the most recent classify() call, if any. */
private Classifier classifier;
/**
 * Staleness counters. forceRebuildNext() sets all three to positive infinity;
 * the Disc and Instances counters are reset to 0 by the corresponding rebuild.
 */
private double infoSinceClassRebuild, infoSinceDiscRebuild, infoSinceInstances;
/** A cached result is rebuilt once its staleness counter exceeds this value. */
private double infoThreshold;
/** Restricts incoming orders to this group's elements; assigned by the subclasses. */
protected FilterTranslator filter;
/**
 * Initializes the group with its backing instance set. Both the discretize
 * and the threshold parameter start out as -1 (unset).
 *
 * @param instances instance set that later updates append to
 */
protected TrialDataGroup(TrialInstances instances){
    this.trialMap = new GenericMultiMap();
    this.trialInstances = instances;
    // -1 marks both modes as unset until one of the setters is called.
    this.thresholdParam = -1;
    this.discretizeParam = -1;
}
267
268 /***
269 * @return Returns the classifier.
270 */
271 public Classifier classify() {
272 if(discretizeParam < 0 && thresholdParam < 0)
273 return null;
274 Assert._assert(discretizeParam < 0 ^ thresholdParam < 0);
275
276 if(discretizeParam > 0)
277 discretize(discretizeParam);
278 else
279 threshold(thresholdParam);
280
281 TrialInstances instances = getTrialInstances();
282 classifier = instances.numInstances() > 0 ? WekaInterface.buildClassifier(CLASSIFIER, instances) : null;
283 return classifier;
284 }
285
286 public Classifier getClassifier(){
287 return classifier;
288 }
289 public void setDiscretizeParam(double discretize){
290 discretizeParam = discretize;
291 thresholdParam = -1;
292 }
293
294 public void setThresholdParam(double thresholdParam){
295 this.thresholdParam = thresholdParam;
296 discretizeParam = -1;
297 }
298 /***
299 * @return Returns the discretization.
300 */
301 public Discretization discretize(double discretizeFact) {
302 if((discretizeFact != discretizeParam) || (infoSinceDiscRebuild > infoThreshold)){
303 setDiscretizeParam(discretizeFact);
304 discretization = getTrialInstances().discretize(discretizeParam);
305 infoSinceDiscRebuild = 0;
306 }
307 return discretization;
308 }
309
310 public Discretization getDiscretization(){
311
312 Assert._assert(discretizeParam != -1 || thresholdParam != -1);
313 if(discretizeParam != -1) return discretize(discretizeParam);
314 return threshold(thresholdParam);
315
316 }
317
318 public Discretization threshold(double threshold){
319 if((threshold != thresholdParam) || (infoSinceDiscRebuild > infoThreshold)){
320 setThresholdParam(threshold);
321 discretization = getTrialInstances().threshold(thresholdParam);
322 infoSinceDiscRebuild = 0;
323 }
324 return discretization;
325 }
326
327 /***
328 * @return Returns the instances.
329 */
330 public TrialInstances getTrialInstances() {
331 if(trialInstancesCopy == null || infoSinceInstances > infoThreshold){
332 trialInstancesCopy = trialInstances.copy();
333 infoSinceInstances = 0;
334 }
335 return trialInstancesCopy;
336 }
337 public void forceRebuildNext(){
338 infoSinceClassRebuild = Double.POSITIVE_INFINITY;
339 infoSinceDiscRebuild = Double.POSITIVE_INFINITY;
340 infoSinceInstances = Double.POSITIVE_INFINITY;
341 }
342 public boolean update(Order order, TrialInfo info, EpisodeCollection tc){
343 forceRebuildNext();
344 double trialColBest;
345 if (tc.getMinimum().isMax()) trialColBest = 1;
346 else trialColBest = (double) tc.getMinimum().cost + 1;
347 Order filteredOrder = filter.translate(order);
348 Collection trials = trialMap.getValues(tc);
349 if(trials != null){
350 for(Iterator it = trials.iterator(); it.hasNext(); ){
351 TrialInstance instance = (TrialInstance) it.next();
352 instance.recomputeCost(trialColBest);
353 }
354 }
355
356 Assert._assert(filteredOrder.numberOfElements() > 1);
357
358 double score = (double) (info.cost + 1) / trialColBest;
359 TrialInstance instance = TrialInstance.construct(info, filteredOrder, score, trialInstances);
360 if(instance == null){
361 System.out.println("Failed constructing instance of " + filteredOrder + " with " + filter + " on " + trialInstances);
362 Assert.UNREACHABLE();
363 }
364 trialMap.add(tc, instance);
365
366 trialInstances.add(instance);
367 return true;
368 }
369
370 public static class VariableTrialDataGroup extends TrialDataGroup{
371 private Collection variables;
372 public VariableTrialDataGroup(Collection variables, TrialInstances instances){
373 super(instances);
374 this.variables = variables;
375 this.filter = new FilterTranslator(variables);
376 }
377
378 public Collection getVariables(){ return new LinkedList(variables); }
379 }
380
381 public static class AttribTrialDataGroup extends TrialDataGroup{
382 private Collection attribs;
383 public AttribTrialDataGroup(Collection attribs, TrialInstances instances){
384 super(instances);
385 this.attribs = attribs;
386 this.filter = new FilterTranslator(attribs);
387 }
388 }
389
390 public static class DomainTrialDataGroup extends TrialDataGroup{
391 private Collection domains;
392 public DomainTrialDataGroup(Collection domains, TrialInstances instances){
393 super(instances);
394 this.domains = domains;
395 this.filter = new FilterTranslator(domains);
396 }
397 }
398
399 }
400 }