1
2
3
4 package net.sf.bddbddb.order;
5
6 import java.util.Collection;
7 import java.util.Enumeration;
8 import java.util.HashSet;
9 import java.util.Iterator;
10 import java.util.LinkedList;
11
12 import jwutil.collections.Pair;
13 import jwutil.util.Assert;
14 import net.sf.bddbddb.Attribute;
15 import net.sf.bddbddb.FindBestDomainOrder;
16 import net.sf.bddbddb.InferenceRule;
17 import net.sf.bddbddb.Variable;
18 import net.sf.bddbddb.order.OrderConstraint.AfterConstraint;
19 import net.sf.bddbddb.order.OrderConstraint.BeforeConstraint;
20 import net.sf.bddbddb.order.OrderConstraint.InterleaveConstraint;
21 import weka.classifiers.Classifier;
22 import weka.core.FastVector;
23 import weka.core.Instance;
24 import weka.core.Instances;
25
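/**
 * WekaInterface: static helpers for using the Weka machine-learning library
 * with bddbddb variable orders -- building order attributes and order
 * instances, and building and cross-validating Weka classifiers.
 *
 * @author jwhaley
 * @version $Id: WekaInterface.java 531 2005-04-29 06:39:10Z joewhaley $
 */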
32 public abstract class WekaInterface {
33
34 public static OrderAttribute makeOrderAttribute(OrderConstraint c) {
35 return new OrderAttribute(c.a, c.b);
36 }
37
38 public static OrderAttribute makeOrderAttribute(Object a, Object b) {
39 if (OrderConstraint.compare(a, b)) {
40 return new OrderAttribute(a, b);
41 } else {
42 return new OrderAttribute(b, a);
43 }
44 }
45
46 public static int INTERLEAVE = 1;
47 public static int getType(OrderConstraint oc) {
48 if (oc instanceof BeforeConstraint) return 0;
49 else if (oc instanceof InterleaveConstraint) return INTERLEAVE;
50 else if (oc instanceof AfterConstraint) return 2;
51 else return -1;
52 }
53
54 public static class OrderAttribute extends weka.core.Attribute {
55 /***
56 * Version ID for serialization.
57 */
58 private static final long serialVersionUID = 3257291339690555447L;
59
60 Object a, b;
61
62 static FastVector my_nominal_values = new FastVector(3);
63 static {
64 my_nominal_values.addElement("<");
65 my_nominal_values.addElement("~");
66 my_nominal_values.addElement(">");
67 }
68
69 private OrderAttribute(Object a, Object b) {
70 super(a+","+b, my_nominal_values);
71 this.a = a;
72 this.b = b;
73 }
74
75 public OrderConstraint getConstraint(int k) {
76 switch (k) {
77 case 0: return OrderConstraint.makePrecedenceConstraint(a, b);
78 case 1: return OrderConstraint.makeInterleaveConstraint(a, b);
79 case 2: return OrderConstraint.makePrecedenceConstraint(b, a);
80 default: return null;
81 }
82 }
83
84 public OrderConstraint getConstraint(weka.core.Instance i) {
85 int k = (int) i.value(this);
86 return getConstraint(k);
87 }
88
89 }
90
91 public static void addAllPairs(FastVector v, Collection c) {
92 Collection pairs = new HashSet();
93 for (Iterator i = c.iterator(); i.hasNext(); ) {
94 Object a = i.next();
95 Iterator j = c.iterator();
96 while (j.hasNext() && j.next() != a) ;
97 while (j.hasNext()) {
98 Object b = j.next();
99 UnorderedPair pair = new UnorderedPair(a,b);
100 if(pairs.contains(pair)) continue;
101 OrderAttribute oa = makeOrderAttribute(a, b);
102 v.addElement(oa);
103 pairs.add(pair);
104 }
105 }
106
107
108 }
109
// NOTE(review): the declaration on the next line appears truncated -- the
// method name and parameter list are missing, so this does not compile as
// shown. The body reads a Collection named `c`, so the signature presumably
// takes `Collection c`; TODO restore it from version control.
//
// Body: returns a new HashSet holding an UnorderedPair(a, b) for each element
// a of c and each element b that follows the first element identical (==) to
// a. An element occurring more than once in c therefore also produces the
// self-pair (a, a).
110 public static Collection
111 Collection pairs = new HashSet();
112 for (Iterator i = c.iterator(); i.hasNext(); ) {
113 Object a = i.next();
114 Iterator j = c.iterator();
// Advance j just past the first element identical to a; the inner loop then
// pairs a with every element after that position.
115 while (j.hasNext() && j.next() != a) ;
116 while (j.hasNext()) {
117 Object b = j.next();
118 Pair pair = new UnorderedPair(a, b);
119 pairs.add(pair);
120 }
121 }
122 return pairs;
123 }
124
125 public static FastVector constructVarAttributes(Collection vars) {
126 FastVector v = new FastVector();
127 addAllPairs(v, vars);
128 return v;
129 }
130
131 public static FastVector constructAttribAttributes(InferenceRule ir, Collection vars) {
132 Collection attribs = new LinkedList();
133 for (Iterator i = vars.iterator(); i.hasNext(); ) {
134 Variable v = (Variable) i.next();
135 Attribute a = ir.getAttribute(v);
136 if (a != null) attribs.add(a);
137 }
138 FastVector v = new FastVector();
139 addAllPairs(v, attribs);
140 return v;
141 }
142
143 public static FastVector constructDomainAttributes(InferenceRule ir, Collection vars) {
144 Collection domains = new LinkedList();
145 for (Iterator i = vars.iterator(); i.hasNext(); ) {
146 Variable v = (Variable) i.next();
147 Attribute a = ir.getAttribute(v);
148 if (a != null) domains.add(a.getDomain());
149 }
150 FastVector v = new FastVector();
151 addAllPairs(v, domains);
152 return v;
153 }
154
155 public static weka.core.Attribute makeBucketAttribute(int numClusters) {
156 FastVector clusterValues = new FastVector(numClusters);
157 for (int i = 0; i < numClusters; ++i)
158 clusterValues.addElement(Integer.toString(i));
159 return new weka.core.Attribute("costBucket", clusterValues);
160 }
161
162 public static Classifier buildClassifier(String cClassName, Instances data) {
163
164 Classifier classifier = null;
165 try {
166 long time = System.currentTimeMillis();
167 classifier = (Classifier) Class.forName(cClassName).newInstance();
168 classifier.buildClassifier(data);
169 if (FindBestDomainOrder.TRACE > 1) System.out.println("Classifier "+cClassName+" took "+(System.currentTimeMillis()-time)+" ms to build.");
170 if (FindBestDomainOrder.TRACE > 2) System.out.println(classifier);
171 } catch (Exception x) {
172 FindBestDomainOrder.out.println(cClassName + ": " + x.getLocalizedMessage());
173 return null;
174 }
175 return classifier;
176 }
177
178 public static double leaveOneOutCV(Instances data, String cClassName) {
179 return WekaInterface.cvError(data.numInstances(), data, cClassName);
180 }
181
182 public static double cvError(int numFolds, Instances data0, String cClassName) {
183 if (data0.numInstances() < numFolds)
184 return Double.NaN;
185 if (numFolds == 0)
186 return Double.NaN;
187 if (data0.numInstances() == 0)
188 return 0;
189
190 Instances data = new Instances(data0);
191
192 data.stratify(numFolds);
193 Assert._assert(data.classAttribute() != null);
194 double[] estimates = new double[numFolds];
195 for (int i = 0; i < numFolds; ++i) {
196 Instances trainData = data.trainCV(numFolds, i);
197 Assert._assert(trainData.classAttribute() != null);
198 Assert._assert(trainData.numInstances() != 0, "Cannot train classifier on 0 instances.");
199
200 Instances testData = data.testCV(numFolds, i);
201 Assert._assert(testData.classAttribute() != null);
202 Assert._assert(testData.numInstances() != 0, "Cannot test classifier on 0 instances.");
203
204 int temp = FindBestDomainOrder.TRACE;
205 FindBestDomainOrder.TRACE = 0;
206 Classifier classifier = buildClassifier(cClassName, trainData);
207 FindBestDomainOrder.TRACE = temp;
208 int count = testData.numInstances();
209 double loss = 0;
210 double sum = 0;
211 for (Enumeration e = testData.enumerateInstances(); e.hasMoreElements();) {
212 Instance instance = (Instance) e.nextElement();
213 Assert._assert(instance != null);
214 Assert._assert(instance.classAttribute() != null && instance.classAttribute() == trainData.classAttribute());
215 try {
216 double testClass = classifier.classifyInstance(instance);
217 double weight = instance.weight();
218 if (testClass != instance.classValue())
219 loss += weight;
220 sum += weight;
221 } catch (Exception ex) {
222 FindBestDomainOrder.out.println("Exception while classifying: " + instance + "\n" + ex);
223 }
224 }
225 estimates[i] = 1 - loss / sum;
226 }
227 double average = 0;
228 for (int i = 0; i < numFolds; ++i)
229 average += estimates[i];
230
231 return average / numFolds;
232 }
233
234 public static TrialInstances binarize(double classValue, TrialInstances data) {
235 TrialInstances newInstances = data.infoClone();
236 weka.core.Attribute newAttr = makeBucketAttribute(2);
237 TrialInstances.setIndex(newAttr, newInstances.classIndex());
238 newInstances.setClass(newAttr);
239 newInstances.setClassIndex(data.classIndex());
240 for (Enumeration e = data.enumerateInstances(); e.hasMoreElements();) {
241 TrialInstance instance = (TrialInstance) e.nextElement();
242 TrialInstance newInstance = TrialInstance.cloneInstance(instance);
243 newInstance.setDataset(newInstances);
244 if (instance.classValue() <= classValue) {
245 newInstance.setClassValue(0);
246 } else {
247 newInstance.setClassValue(1);
248 }
249 newInstances.add(newInstance);
250 }
251 return newInstances;
252 }
253
254 public static class OrderInstance extends Instance {
255
256 /***
257 * Version ID for serialization.
258 */
259 private static final long serialVersionUID = 3258412811553093939L;
260
261 public static OrderInstance construct(Order o, Instances dataSet) {
262 return construct(o, dataSet, 1);
263 }
264
265 public static OrderInstance construct(Order o, Instances dataSet, double weight) {
266 double[] d = new double[dataSet.numAttributes()];
267 for (int i = 0; i < d.length; ++i) {
268 d[i] = Instance.missingValue();
269 }
270 for (Iterator i = o.getConstraints().iterator(); i.hasNext(); ) {
271 OrderConstraint oc = (OrderConstraint) i.next();
272
273
274 String cName = oc.getFirst()+","+oc.getSecond();
275 OrderAttribute oa = (OrderAttribute) dataSet.attribute(cName);
276 if (oa != null) {
277
278 if(oc.getFirst().equals(oc.getSecond()) && d[oa.index()] == INTERLEAVE)
279 continue;
280
281
282
283 d[oa.index()] = getType(oc);
284 } else {
285
286 System.out.println("Warning: while building OrderInstance for " + o + " couldn't find constraint "+oc+" in data set");
287 System.out.println("dataset\n: " + dataSet);
288 Assert.UNREACHABLE();
289 }
290 }
291 return new OrderInstance(weight, d, o);
292 }
293
294 protected Order o;
295
296 protected OrderInstance(double w, double[] d, Order o) {
297 super(w, d);
298 this.o = o;
299 }
300 protected OrderInstance(OrderInstance that) {
301 super(that);
302 this.o = that.o;
303 }
304
305 public Object copy() {
306 return new OrderInstance(this);
307 }
308
309 public Order getOrder() {
310 return o;
311 }
312
313 }
314
315 }