package net.sf.bddbddb.order;

import java.util.Arrays;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import jwutil.util.Assert;
import net.sf.bddbddb.FindBestDomainOrder;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.NoSupportForMissingValuesException;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;

/***
 * Class implementing an Id3 decision tree classifier. This version differs
 * from the Weka one in that it supports missing attribute values.
 *
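 * <p>
 * A minimal usage sketch (hypothetical toy data; assumes the Weka 3.4-era
 * API that this file already imports, plus weka.core.FastVector):
 *
 * <pre>
 * // One nominal feature and a nominal class.
 * FastVector featVals = new FastVector();
 * featVals.addElement("a");
 * featVals.addElement("b");
 * FastVector classVals = new FastVector();
 * classVals.addElement("yes");
 * classVals.addElement("no");
 * FastVector attInfo = new FastVector();
 * attInfo.addElement(new Attribute("feature", featVals));
 * attInfo.addElement(new Attribute("class", classVals));
 * Instances train = new Instances("toy", attInfo, 2);
 * train.setClassIndex(1);
 * train.add(new Instance(1., new double[] { 0, 0 })); // a -> yes
 * train.add(new Instance(1., new double[] { 1, 1 })); // b -> no
 * MyId3 tree = new MyId3();
 * tree.buildClassifier(train);
 * double label = tree.classifyInstance(train.instance(0));
 * </pre>
 *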
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author John Whaley
 * @version $Revision: 531 $
 */
public class MyId3 extends Classifier {
    /***
     * Version ID for serialization.
     */
    private static final long serialVersionUID = 3258129154733322289L;

    /*** The node's successors. */
    private MyId3[] m_Successors;
    /*** Attribute used for splitting. */
    private Attribute m_Attribute;
    /*** Class value if node is leaf. */
    private double m_ClassValue;
    /*** Class distribution if node is leaf. */
    private double[] m_Distribution;
    /*** Class attribute of dataset. */
    private Attribute m_ClassAttribute;
    public boolean getAttribCombos(Instances i, double cv) {
        List r = getAttribCombos(i.numAttributes(), cv);
        if (r == null) return false;
        for (Iterator ii = r.iterator(); ii.hasNext(); ) {
            double[] d = (double[]) ii.next();
            i.add(new Instance(1., d));
        }
        return true;
    }

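    /***
     * Recursively collects the attribute combinations that this subtree
     * classifies with the given class value. Each combination is a double
     * array of length nAttribs, with NaN for attributes whose value does not
     * matter along the corresponding root-to-leaf path.
     *
     * @param nAttribs number of attributes in the dataset
     * @param cv class value to match
     * @return a list of double[] combinations, or null if none match
     */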
    public List getAttribCombos(int nAttribs, double cv) {
        if (m_Attribute == null) {
            // Leaf node: report a single all-missing combination if the
            // class value matches, otherwise nothing.
            if (FindBestDomainOrder.compare(m_ClassValue, cv) == 0) {
                List result = new LinkedList();
                double[] i = new double[nAttribs];
                Arrays.fill(i, Double.NaN);
                result.add(i);
                return result;
            } else {
                return null;
            }
        } else {
            // Interior node: collect combinations from each successor and
            // fill in this node's attribute with the branch value.
            List result = new LinkedList();
            for (int i = 0; i < m_Successors.length; ++i) {
                List c = m_Successors[i].getAttribCombos(nAttribs, cv);
                if (c != null) {
                    int index = m_Attribute.index();
                    for (Iterator j = c.iterator(); j.hasNext(); ) {
                        double[] d = (double[]) j.next();
                        d[index] = i;
                    }
                    result.addAll(c);
                }
            }
            if (result.isEmpty()) return null;
            else return result;
        }
    }

    /***
     * Returns a string describing the classifier.
     *
     * @return a description suitable for the GUI.
     */
    public String globalInfo() {
        return "Class for constructing an unpruned decision tree based on the ID3 "
            + "algorithm. Can only deal with nominal attributes. "
            + "Empty leaves may result in unclassified instances. For more "
            + "information see: \n\n" + " R. Quinlan (1986). \"Induction of decision "
            + "trees\". Machine Learning. Vol.1, No.1, pp. 81-106";
    }

    /***
     * Builds an Id3 decision tree classifier.
     *
     * @param data
     *            the training data
     * @exception Exception
     *                if classifier can't be built successfully
     */
    public void buildClassifier(Instances data) throws Exception {
        if (!data.classAttribute().isNominal()) {
            throw new UnsupportedClassTypeException("Id3: nominal class, please.");
        }
        Enumeration enumAtt = data.enumerateAttributes();
        while (enumAtt.hasMoreElements()) {
            if (!((Attribute) enumAtt.nextElement()).isNominal()) {
                throw new UnsupportedAttributeTypeException("Id3: only nominal "
                    + "attributes, please.");
            }
        }
        // Copy the data and remove instances with a missing class value.
        data = new Instances(data);
        data.deleteWithMissingClass();
        makeTree(data);
    }

    /***
     * Method for building an Id3 tree.
     *
     * @param data
     *            the training data
     * @exception Exception
     *                if decision tree can't be built successfully
     */
    private void makeTree(Instances data) throws Exception {
        // Check if no instances have reached this node.
        if (data.numInstances() == 0) {
            m_Attribute = null;
            m_ClassValue = Instance.missingValue();
            m_Distribution = new double[data.numClasses()];
            // With no instances, smoothing yields a uniform distribution.
            laplaceSmooth(m_Distribution, 0, data.numClasses());
            return;
        }
        // Compute the attribute with maximum information gain.
        double[] infoGains = new double[data.numAttributes()];
        Enumeration attEnum = data.enumerateAttributes();
        while (attEnum.hasMoreElements()) {
            Attribute att = (Attribute) attEnum.nextElement();
            infoGains[att.index()] = computeInfoGain(data, att);
        }
        m_Attribute = data.attribute(Utils.maxIndex(infoGains));
        boolean makeLeaf = Utils.eq(infoGains[m_Attribute.index()], 0);
        Instances[] splitData = null;
        if (!makeLeaf) {
            splitData = splitData(data, m_Attribute);
            for (int i = 0; i < splitData.length; ++i) {
                if (splitData[i].numInstances() == data.numInstances()) {
                    // Splitting makes no progress: every instance lands in
                    // the same branch (e.g. the attribute is always missing),
                    // so make this node a leaf to avoid recursing forever.
                    makeLeaf = true;
                    break;
                }
            }
        }
        // Make a leaf if the information gain is zero, otherwise create
        // one successor per attribute value.
        if (makeLeaf) {
            m_Attribute = null;
            m_Distribution = new double[data.numClasses()];
            Enumeration instEnum = data.enumerateInstances();
            double sum = 0;
            while (instEnum.hasMoreElements()) {
                Instance inst = (Instance) instEnum.nextElement();
                m_Distribution[(int) inst.classValue()]++;
                sum += inst.weight();
            }
            laplaceSmooth(m_Distribution, sum, data.numClasses());
            m_ClassValue = Utils.maxIndex(m_Distribution);
            m_ClassAttribute = data.classAttribute();
        } else {
            m_Successors = new MyId3[m_Attribute.numValues()];
            for (int j = 0; j < m_Attribute.numValues(); j++) {
                m_Successors[j] = new MyId3();
                m_Successors[j].buildClassifier(splitData[j]);
            }
        }
    }

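    /***
     * Applies Laplace smoothing to the given class distribution in place:
     * each count is incremented by one and the result is normalized by
     * (sum + numClasses).
     *
     * @param dist class counts to smooth and normalize
     * @param sum total weight of the counted instances
     * @param numClasses number of classes
     */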
    public void laplaceSmooth(double[] dist, double sum, int numClasses) {
        for (int i = 0; i < dist.length; ++i) {
            dist[i] = (dist[i] + 1) / (sum + numClasses);
        }
    }

    /***
     * Classifies a given test instance using the decision tree.
     *
     * @param instance
     *            the instance to be classified
     * @return the classification
     */
    public double classifyInstance(Instance instance) {
        if (m_Attribute == null) {
            return m_ClassValue;
        } else if (instance.isMissing(m_Attribute)) {
            try {
                // Fall back on the superclass, which picks the class with the
                // highest probability from distributionForInstance().
                return super.classifyInstance(instance);
            } catch (Exception x) {
                x.printStackTrace();
                Assert.UNREACHABLE();
                return 0.;
            }
        } else {
            return m_Successors[(int) instance.value(m_Attribute)].classifyInstance(instance);
        }
    }

    /***
     * Computes the class distribution for an instance using the decision tree.
     *
     * @param instance
     *            the instance for which the distribution is to be computed
     * @return the class distribution for the given instance
     */
    public double[] distributionForInstance(Instance instance)
        throws NoSupportForMissingValuesException {
        if (m_Attribute == null) {
            return m_Distribution;
        } else if (instance.isMissing(m_Attribute)) {
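            // The split attribute is missing: average the distributions
            // returned by the successors, weighting each branch equally.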
            double[] d = new double[0];
            for (int i = 0; i < m_Successors.length; ++i) {
                double[] dd = m_Successors[i].distributionForInstance(instance);
                if (d.length == 0 && dd.length > 0) d = new double[dd.length];
                for (int j = 0; j < d.length; ++j) {
                    d[j] += dd[j];
                }
            }
            for (int j = 0; j < d.length; ++j) {
                d[j] /= m_Successors.length;
            }
            return d;
        } else {
            return m_Successors[(int) instance.value(m_Attribute)]
                .distributionForInstance(instance);
        }
    }

    /***
     * Prints the decision tree using the private toString method below.
     *
     * @return a textual description of the classifier
     */
    public String toString() {
        if ((m_Distribution == null) && (m_Successors == null)) {
            return "Id3: No model built yet.";
        }
        return "Id3\n\n" + toString(0);
    }

    /***
     * Computes information gain for an attribute.
     *
     * @param data
     *            the data for which info gain is to be computed
     * @param att
     *            the attribute
     * @return the information gain for the given attribute and data
     * @exception Exception
     *                if computation fails
     */
    private double computeInfoGain(Instances data, Attribute att) throws Exception {
        double infoGain = computeEntropy(data, att);
        // splitData() also fills in numI and splitDataSize, which count only
        // the instances for which att is not missing.
        Instances[] splitData = splitData(data, att);
        for (int j = 0; j < att.numValues(); j++) {
            if (splitDataSize[j] > 0) {
                infoGain -= ((double) splitDataSize[j] / (double) numI)
                    * computeEntropy(splitData[j], att);
            }
        }
        return infoGain;
    }

    /***
     * Computes the entropy of a dataset, ignoring instances for which the
     * given attribute is missing.
     *
     * @param data
     *            the data for which entropy is to be computed
     * @param att
     *            the attribute whose missing values are skipped
     * @return the entropy of the data's class distribution
     * @exception Exception
     *                if computation fails
     */
    private double computeEntropy(Instances data, Attribute att) throws Exception {
        double[] classCounts = new double[data.numClasses()];
        Enumeration instEnum = data.enumerateInstances();
        int numInstances = 0;
        while (instEnum.hasMoreElements()) {
            Instance inst = (Instance) instEnum.nextElement();
            if (inst.isMissing(att)) continue;
            classCounts[(int) inst.classValue()]++;
            ++numInstances;
        }
        double entropy = 0;
        for (int j = 0; j < data.numClasses(); j++) {
            if (classCounts[j] > 0) {
                entropy -= classCounts[j] * Utils.log2(classCounts[j]);
            }
        }
        entropy /= (double) numInstances;
        return entropy + Utils.log2(numInstances);
    }

    /*** Number of instances with the split attribute present, from the last call to splitData(). */
    int numI;
    /*** Per-value instance counts from the last call to splitData(). */
    int[] splitDataSize;

    /***
     * Splits a dataset according to the values of a nominal attribute. As a
     * side effect, fills in numI and splitDataSize with the counts of
     * instances for which the attribute is present.
     *
     * @param data
     *            the data which is to be split
     * @param att
     *            the attribute to be used for splitting
     * @return the sets of instances produced by the split
     */
    private Instances[] splitData(Instances data, Attribute att) {
        numI = 0;
        splitDataSize = new int[att.numValues()];
        Instances[] splitData = new Instances[att.numValues()];
        for (int j = 0; j < att.numValues(); j++) {
            splitData[j] = new Instances(data, data.numInstances());
        }
        Enumeration instEnum = data.enumerateInstances();
        while (instEnum.hasMoreElements()) {
            Instance inst = (Instance) instEnum.nextElement();
            if (inst.isMissing(att)) {
                // Missing attribute value: add the instance to every split.
                for (int k = 0; k < att.numValues(); ++k) {
                    splitData[k].add(inst);
                }
            } else {
                int k = (int) inst.value(att);
                splitData[k].add(inst);
                splitDataSize[k]++;
                numI++;
            }
        }
        return splitData;
    }

    /***
     * Outputs a tree at a certain level.
     *
     * @param level
     *            the level at which the tree is to be printed
     * @return a textual description of the subtree
     */
    private String toString(int level) {
        StringBuffer text = new StringBuffer();
        if (m_Attribute == null) {
            if (Instance.isMissingValue(m_ClassValue)) {
                text.append(": null");
            } else {
                text.append(": " + m_ClassAttribute.value((int) m_ClassValue));
            }
        } else {
            for (int j = 0; j < m_Attribute.numValues(); j++) {
                text.append("\n");
                for (int i = 0; i < level; i++) {
                    text.append("|  ");
                }
                text.append(m_Attribute.name() + " = " + m_Attribute.value(j));
                text.append(m_Successors[j].toString(level + 1));
            }
        }
        return text.toString();
    }

    /***
     * Main method.
     *
     * @param args the options for the classifier
     */
    public static void main(String[] args) {
        try {
            // Evaluate this classifier (not Weka's Id3) from the command line.
            System.out.println(Evaluation.evaluateModel(new MyId3(), args));
        } catch (Exception e) {
            System.err.println(e.getMessage());
        }
    }
}