View Javadoc

1   // WekaInterface.java, created Oct 31, 2004 1:17:46 AM by joewhaley
2   // Copyright (C) 2004 John Whaley <jwhaley@alum.mit.edu>
3   // Licensed under the terms of the GNU LGPL; see COPYING for details.
4   package net.sf.bddbddb.order;
5   
6   import java.util.Collection;
7   import java.util.Enumeration;
8   import java.util.HashSet;
9   import java.util.Iterator;
10  import java.util.LinkedList;
11  
12  import jwutil.collections.Pair;
13  import jwutil.util.Assert;
14  import net.sf.bddbddb.Attribute;
15  import net.sf.bddbddb.FindBestDomainOrder;
16  import net.sf.bddbddb.InferenceRule;
17  import net.sf.bddbddb.Variable;
18  import net.sf.bddbddb.order.OrderConstraint.AfterConstraint;
19  import net.sf.bddbddb.order.OrderConstraint.BeforeConstraint;
20  import net.sf.bddbddb.order.OrderConstraint.InterleaveConstraint;
21  import weka.classifiers.Classifier;
22  import weka.core.FastVector;
23  import weka.core.Instance;
24  import weka.core.Instances;
25  
26  /***
27   * WekaInterface
28   * 
29   * @author jwhaley
30   * @version $Id: WekaInterface.java 531 2005-04-29 06:39:10Z joewhaley $
31   */
32  public abstract class WekaInterface {
33      
34      public static OrderAttribute makeOrderAttribute(OrderConstraint c) {
35          return new OrderAttribute(c.a, c.b);
36      }
37      
38      public static OrderAttribute makeOrderAttribute(Object a, Object b) {
39          if (OrderConstraint.compare(a, b)) {
40              return new OrderAttribute(a, b);
41          } else {
42              return new OrderAttribute(b, a);
43          }
44      }
45      
46      public static int INTERLEAVE = 1;
47      public static int getType(OrderConstraint oc) {
48          if (oc instanceof BeforeConstraint) return 0;
49          else if (oc instanceof InterleaveConstraint) return INTERLEAVE;
50          else if (oc instanceof AfterConstraint) return 2;
51          else return -1;
52      }
53      
54      public static class OrderAttribute extends weka.core.Attribute {
55          /***
56           * Version ID for serialization.
57           */
58          private static final long serialVersionUID = 3257291339690555447L;
59  
60          Object a, b;
61          
62          static FastVector my_nominal_values = new FastVector(3);
63          static {
64              my_nominal_values.addElement("<");
65              my_nominal_values.addElement("~");
66              my_nominal_values.addElement(">"); 
67          }
68          
69          private OrderAttribute(Object a, Object b) {
70              super(a+","+b, my_nominal_values);
71              this.a = a;
72              this.b = b;
73          }
74          
75          public OrderConstraint getConstraint(int k) {
76              switch (k) {
77                  case 0: return OrderConstraint.makePrecedenceConstraint(a, b);
78                  case 1: return OrderConstraint.makeInterleaveConstraint(a, b);
79                  case 2: return OrderConstraint.makePrecedenceConstraint(b, a);
80                  default: return null;
81              }
82          }
83          
84          public OrderConstraint getConstraint(weka.core.Instance i) {
85              int k = (int) i.value(this);
86              return getConstraint(k);
87          }
88  
89      }
90      
91      public static void addAllPairs(FastVector v, Collection c) {
92          Collection pairs = new HashSet();
93          for (Iterator i = c.iterator(); i.hasNext(); ) {
94              Object a = i.next();
95              Iterator j = c.iterator();
96              while (j.hasNext() && j.next() != a) ;
97              while (j.hasNext()) {
98                  Object b = j.next();
99                  UnorderedPair pair = new UnorderedPair(a,b);
100                 if(pairs.contains(pair)) continue;
101                 OrderAttribute oa = makeOrderAttribute(a, b);
102                 v.addElement(oa);
103                 pairs.add(pair);
104             }
105         }
106    //     System.out.println(new HashSet(c) + " Size: " + v.size());
107         
108     }
109     
110     public static Collection /*UnorderedPair*/ generateAllPairs(Collection c){
111         Collection pairs = new HashSet();
112         for (Iterator i = c.iterator(); i.hasNext(); ) {
113             Object a = i.next();
114             Iterator j = c.iterator();
115             while (j.hasNext() && j.next() != a) ;
116             while (j.hasNext()) {
117                 Object b = j.next();
118                 Pair pair = new UnorderedPair(a, b);
119                 pairs.add(pair);
120             }
121         }
122         return pairs;
123     }
124 
125     public static FastVector constructVarAttributes(Collection vars) {
126         FastVector v = new FastVector();
127         addAllPairs(v, vars);
128         return v;
129     }
130     
131     public static FastVector constructAttribAttributes(InferenceRule ir, Collection vars) {
132         Collection attribs = new LinkedList();
133         for (Iterator i = vars.iterator(); i.hasNext(); ) {
134             Variable v = (Variable) i.next();
135             Attribute a = ir.getAttribute(v);
136             if (a != null) attribs.add(a);
137         }
138         FastVector v = new FastVector();
139         addAllPairs(v, attribs);
140         return v;
141     }
142     
143     public static FastVector constructDomainAttributes(InferenceRule ir, Collection vars) {
144         Collection domains = new LinkedList();
145         for (Iterator i = vars.iterator(); i.hasNext(); ) {
146             Variable v = (Variable) i.next();
147             Attribute a = ir.getAttribute(v);
148             if (a != null) domains.add(a.getDomain());
149         }
150         FastVector v = new FastVector();
151         addAllPairs(v, domains);
152         return v;
153     }
154     
155     public static weka.core.Attribute makeBucketAttribute(int numClusters) {
156         FastVector clusterValues = new FastVector(numClusters);
157         for (int i = 0; i < numClusters; ++i)
158             clusterValues.addElement(Integer.toString(i));
159         return new weka.core.Attribute("costBucket", clusterValues);
160     }
161 
162     public static Classifier buildClassifier(String cClassName, Instances data) {
163         // Build the classifier.
164         Classifier classifier = null;
165         try {
166             long time = System.currentTimeMillis();
167             classifier = (Classifier) Class.forName(cClassName).newInstance();
168             classifier.buildClassifier(data);
169             if (FindBestDomainOrder.TRACE > 1) System.out.println("Classifier "+cClassName+" took "+(System.currentTimeMillis()-time)+" ms to build.");
170             if (FindBestDomainOrder.TRACE > 2) System.out.println(classifier);
171         } catch (Exception x) {
172             FindBestDomainOrder.out.println(cClassName + ": " + x.getLocalizedMessage());
173             return null;
174         }
175         return classifier;
176     }
177 
178     public static double leaveOneOutCV(Instances data, String cClassName) {
179         return WekaInterface.cvError(data.numInstances(), data, cClassName);
180     }
181 
182     public static double cvError(int numFolds, Instances data0, String cClassName) {
183         if (data0.numInstances() < numFolds)
184             return Double.NaN; //more folds than elements
185         if (numFolds == 0)
186             return Double.NaN; // no folds
187         if (data0.numInstances() == 0)
188             return 0; //no instances
189     
190         Instances data = new Instances(data0);
191         //data.randomize(new Random(System.currentTimeMillis()));
192         data.stratify(numFolds);
193         Assert._assert(data.classAttribute() != null);
194         double[] estimates = new double[numFolds];
195         for (int i = 0; i < numFolds; ++i) {
196             Instances trainData = data.trainCV(numFolds, i);
197             Assert._assert(trainData.classAttribute() != null);
198             Assert._assert(trainData.numInstances() != 0, "Cannot train classifier on 0 instances.");
199     
200             Instances testData = data.testCV(numFolds, i);
201             Assert._assert(testData.classAttribute() != null);
202             Assert._assert(testData.numInstances() != 0, "Cannot test classifier on 0 instances.");
203     
204             int temp = FindBestDomainOrder.TRACE;
205             FindBestDomainOrder.TRACE = 0;
206             Classifier classifier = buildClassifier(cClassName, trainData);
207             FindBestDomainOrder.TRACE = temp;
208             int count = testData.numInstances();
209             double loss = 0;
210             double sum = 0;
211             for (Enumeration e = testData.enumerateInstances(); e.hasMoreElements();) {
212                 Instance instance = (Instance) e.nextElement();
213                 Assert._assert(instance != null);
214                 Assert._assert(instance.classAttribute() != null && instance.classAttribute() == trainData.classAttribute());
215                 try {
216                     double testClass = classifier.classifyInstance(instance);
217                     double weight = instance.weight();
218                     if (testClass != instance.classValue())
219                         loss += weight;
220                     sum += weight;
221                 } catch (Exception ex) {
222                     FindBestDomainOrder.out.println("Exception while classifying: " + instance + "\n" + ex);
223                 }
224             }
225             estimates[i] = 1 - loss / sum;
226         }
227         double average = 0;
228         for (int i = 0; i < numFolds; ++i)
229             average += estimates[i];
230     
231         return average / numFolds;
232     }
233 
234     public static TrialInstances binarize(double classValue, TrialInstances data) {
235         TrialInstances newInstances = data.infoClone();
236         weka.core.Attribute newAttr = makeBucketAttribute(2);
237         TrialInstances.setIndex(newAttr, newInstances.classIndex());
238         newInstances.setClass(newAttr);
239         newInstances.setClassIndex(data.classIndex());
240         for (Enumeration e = data.enumerateInstances(); e.hasMoreElements();) {
241             TrialInstance instance = (TrialInstance) e.nextElement();
242             TrialInstance newInstance = TrialInstance.cloneInstance(instance);
243             newInstance.setDataset(newInstances);
244             if (instance.classValue() <= classValue) {
245                 newInstance.setClassValue(0);
246             } else {
247                 newInstance.setClassValue(1);
248             }
249             newInstances.add(newInstance);
250         }
251         return newInstances;
252     }
253 
254     public static class OrderInstance extends Instance {
255         
256         /***
257          * Version ID for serialization.
258          */
259         private static final long serialVersionUID = 3258412811553093939L;
260 
261         public static OrderInstance construct(Order o, Instances dataSet) {
262             return construct(o, dataSet, 1);
263         }
264         
265         public static OrderInstance construct(Order o, Instances dataSet, double weight) {
266             double[] d = new double[dataSet.numAttributes()];
267             for (int i = 0; i < d.length; ++i) {
268                 d[i] = Instance.missingValue();
269             }
270             for (Iterator i = o.getConstraints().iterator(); i.hasNext(); ) {
271                 OrderConstraint oc = (OrderConstraint) i.next();
272                 // TODO: use a map from Pair to int instead of building String and doing linear search.
273             
274                 String cName = oc.getFirst()+","+oc.getSecond();
275                 OrderAttribute oa = (OrderAttribute) dataSet.attribute(cName);
276                 if (oa != null) {
277                     
278                     if(oc.getFirst().equals(oc.getSecond()) && d[oa.index()] == INTERLEAVE) 
279                         continue;
280                     /* TODO should only one type of constraint for
281                      * when first == second and they are not interleaved
282                      */  
283                     d[oa.index()] = getType(oc);
284                 } else {
285                     
286                     System.out.println("Warning: while building OrderInstance for " + o + " couldn't find constraint "+oc+" in data set");
287                     System.out.println("dataset\n: " + dataSet);
288                     Assert.UNREACHABLE();
289                 }
290             }
291             return new OrderInstance(weight, d, o);
292         }
293         
294         protected Order o;
295         
296         protected OrderInstance(double w, double[] d, Order o) {
297             super(w, d);
298             this.o = o;
299         }
300         protected OrderInstance(OrderInstance that) {
301             super(that);
302             this.o = that.o;
303         }
304         
305         public Object copy() {
306             return new OrderInstance(this);
307         }
308         
309         public Order getOrder() {
310             return o;
311         }
312         
313     }
314 
315 }