View Javadoc

1   // MyId3.java, created Oct 31, 2004 2:13:00 AM by joewhaley
2   // Copyright (C) 2004 John Whaley <jwhaley@alum.mit.edu>
3   // Licensed under the terms of the GNU LGPL; see COPYING for details.
4   package net.sf.bddbddb.order;
5   
6   import java.util.Arrays;
7   import java.util.Enumeration;
8   import java.util.Iterator;
9   import java.util.LinkedList;
10  import java.util.List;
11  
12  import jwutil.util.Assert;
13  import net.sf.bddbddb.FindBestDomainOrder;
14  import weka.classifiers.Classifier;
15  import weka.classifiers.Evaluation;
16  import weka.classifiers.trees.Id3;
17  import weka.core.Attribute;
18  import weka.core.Instance;
19  import weka.core.Instances;
20  import weka.core.NoSupportForMissingValuesException;
21  import weka.core.UnsupportedAttributeTypeException;
22  import weka.core.UnsupportedClassTypeException;
23  import weka.core.Utils;
24  
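/**
 * Class implementing an Id3 decision tree classifier. This version differs from
 * the weka one in that it supports missing attribute values: during training,
 * an instance whose value for the split attribute is missing is passed to every
 * child, and during prediction the class distributions of all children are
 * averaged. Leaf distributions are Laplace-smoothed rather than normalized.
 * 
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author John Whaley
 * @version $Revision: 531 $
 */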
33  public class MyId3 extends Classifier {
34      /***
35       * Version ID for serialization.
36       */
37      private static final long serialVersionUID = 3258129154733322289L;
38      
39      /*** The node's successors. */
40      private MyId3[] m_Successors;
41      /*** Attribute used for splitting. */
42      private Attribute m_Attribute; // not set for leaf.
43      /*** Class value if node is leaf. */
44      private double m_ClassValue;
45      /*** Class distribution if node is leaf. */
46      private double[] m_Distribution;
47      /*** Class attribute of dataset. */
48      private Attribute m_ClassAttribute;
49  
50      public boolean getAttribCombos(Instances i, double cv) {
51          List r = getAttribCombos(i.numAttributes(), cv);
52          if (r == null) return false;
53          for (Iterator ii = r.iterator(); ii.hasNext(); ) {
54              double[] d = (double[]) ii.next();
55              i.add(new Instance(1., d));
56          }
57          return true;
58      }
59      
60      public List getAttribCombos(int nAttribs, double cv) {
61          if (m_Attribute == null) {
62              if (FindBestDomainOrder.compare(m_ClassValue, cv) == 0) {
63                  List result = new LinkedList();
64                  double[] i = new double[nAttribs];
65                  Arrays.fill(i, Double.NaN);
66                  result.add(i);
67                  return result;
68              } else {
69                  return null;
70              }
71          } else {
72              List result = new LinkedList(); 
73              for (int i = 0; i < m_Successors.length; ++i) {
74                  List c = m_Successors[i].getAttribCombos(nAttribs, cv);
75                  if (c != null) {
76                      int index = m_Attribute.index();
77                      for (Iterator j = c.iterator(); j.hasNext(); ) {
78                          double[] d = (double[]) j.next();
79                          d[index] = i;
80                      }
81                      result.addAll(c);
82                  }
83              }
84              if (result.isEmpty()) return null;
85              else return result;
86          }
87      }
88      
89      /***
90       * Returns a string describing the classifier.
91       * 
92       * @return a description suitable for the GUI.
93       */
94      public String globalInfo() {
95          return "Class for constructing an unpruned decision tree based on the ID3 "
96              + "algorithm. Can only deal with nominal attributes. "
97              + "Empty leaves may result in unclassified instances. For more "
98              + "information see: \n\n" + " R. Quinlan (1986). \"Induction of decision "
99              + "trees\". Machine Learning. Vol.1, No.1, pp. 81-106";
100     }
101 
102     /***
103      * Builds Id3 decision tree classifier.
104      * 
105      * @param data
106      *            the training data
107      * @exception Exception
108      *                if classifier can't be built successfully
109      */
110     public void buildClassifier(Instances data) throws Exception {
111         if (!data.classAttribute().isNominal()) {
112             throw new UnsupportedClassTypeException("Id3: nominal class, please.");
113         }
114         Enumeration enumAtt = data.enumerateAttributes();
115         while (enumAtt.hasMoreElements()) {
116             if (!((Attribute) enumAtt.nextElement()).isNominal()) {
117                 throw new UnsupportedAttributeTypeException("Id3: only nominal "
118                     + "attributes, please.");
119             }
120         }
121         data = new Instances(data);
122         data.deleteWithMissingClass();
123         makeTree(data);
124     }
125 
126     /***
127      * Method for building an Id3 tree.
128      * 
129      * @param data
130      *            the training data
131      * @exception Exception
132      *                if decision tree can't be built successfully
133      */
134     private void makeTree(Instances data) throws Exception {
135         // Check if no instances have reached this node.
136         if (data.numInstances() == 0) {
137             m_Attribute = null;
138             m_ClassValue = Instance.missingValue();
139             m_Distribution = new double[data.numClasses()];
140             double sum = 0;
141             laplaceSmooth(m_Distribution, sum, data.numClasses());
142             return;
143         }
144         // Compute attribute with maximum information gain.
145         double[] infoGains = new double[data.numAttributes()];
146         Enumeration attEnum = data.enumerateAttributes();
147         while (attEnum.hasMoreElements()) {
148             Attribute att = (Attribute) attEnum.nextElement();
149             infoGains[att.index()] = computeInfoGain(data, att);
150         }
151         m_Attribute = data.attribute(Utils.maxIndex(infoGains));
152         boolean makeLeaf;
153         makeLeaf = Utils.eq(infoGains[m_Attribute.index()], 0);
154         Instances[] splitData = null;
155         if (!makeLeaf) {
156             splitData = splitData(data, m_Attribute);
157             for (int i = 0; i < splitData.length; ++i) {
158                 if (splitData[i].numInstances() == data.numInstances()) {
159                     //System.out.println("When splitting on attrib
160                     // "+m_Attribute+", child "+i+" is same size as current,
161                     // making into leaf.");
162                     makeLeaf = true;
163                     break;
164                 }
165             }
166         }
167         // Make leaf if information gain is zero.
168         // Otherwise create successors.
169         if (makeLeaf) {
170             m_Attribute = null;
171             m_Distribution = new double[data.numClasses()];
172             Enumeration instEnum = data.enumerateInstances();
173             double sum = 0;
174             while (instEnum.hasMoreElements()) {
175                 Instance inst = (Instance) instEnum.nextElement();
176                 m_Distribution[(int) inst.classValue()]++;
177                 sum += inst.weight();
178             }
179             //laplace smooth the distribution instead
180             laplaceSmooth(m_Distribution, sum, data.numClasses());
181             //Utils.normalize(m_Distribution);
182             m_ClassValue = Utils.maxIndex(m_Distribution);
183             m_ClassAttribute = data.classAttribute();
184         } else {
185             m_Successors = new MyId3[m_Attribute.numValues()];
186             for (int j = 0; j < m_Attribute.numValues(); j++) {
187                 m_Successors[j] = new MyId3();
188                 m_Successors[j].buildClassifier(splitData[j]);
189             }
190         }
191     }
192     
193     public void laplaceSmooth(double [] dist, double sum, int numClasses){
194         for(int i = 0; i < dist.length; ++i){
195             dist[i] = (dist[i] + 1)/ (sum + numClasses);
196         }
197     }
198 
199     /***
200      * Classifies a given test instance using the decision tree.
201      * 
202      * @param instance
203      *            the instance to be classified
204      * @return the classification
205      */
206     public double classifyInstance(Instance instance) {
207         if (m_Attribute == null) {
208             return m_ClassValue;
209         } else if (instance.isMissing(m_Attribute)) {
210             try {
211                 // Use superclass implementation, which uses distributionForInstance.
212                 return super.classifyInstance(instance);
213             } catch (Exception x) {
214                 x.printStackTrace();
215                 Assert.UNREACHABLE();
216                 return 0.;
217             }
218         } else {
219             return m_Successors[(int) instance.value(m_Attribute)].classifyInstance(instance);
220         }
221     }
222 
223     /***
224      * Computes class distribution for instance using decision tree.
225      * 
226      * @param instance
227      *            the instance for which distribution is to be computed
228      * @return the class distribution for the given instance
229      */
230     public double[] distributionForInstance(Instance instance)
231         throws NoSupportForMissingValuesException {
232         if (m_Attribute == null) {
233             return m_Distribution;
234         } else if (instance.isMissing(m_Attribute)) {
235             double[] d = new double[0];
236             for (int i = 0; i < m_Successors.length; ++i) {
237                 double[] dd = m_Successors[i].distributionForInstance(instance);
238                 if (d.length == 0 && dd.length > 0) d = new double[dd.length];
239                 for (int j = 0; j < d.length; ++j) {
240                     d[j] += dd[j];
241                 }
242             }
243             for (int j = 0; j < d.length; ++j) {
244                 d[j] /= m_Successors.length;
245             }
246             return d;
247         } else {
248             return m_Successors[(int) instance.value(m_Attribute)]
249                 .distributionForInstance(instance);
250         }
251     }
252 
253     /***
254      * Prints the decision tree using the private toString method from below.
255      * 
256      * @return a textual description of the classifier
257      */
258     public String toString() {
259         if ((m_Distribution == null) && (m_Successors == null)) {
260             return "Id3: No model built yet.";
261         }
262         return "Id3\n\n" + toString(0);
263     }
264 
265     /***
266      * Computes information gain for an attribute.
267      * 
268      * @param data
269      *            the data for which info gain is to be computed
270      * @param att
271      *            the attribute
272      * @return the information gain for the given attribute and data
273      */
274     private double computeInfoGain(Instances data, Attribute att) throws Exception {
275         double infoGain = computeEntropy(data, att);
276         Instances[] splitData = splitData(data, att);
277         for (int j = 0; j < att.numValues(); j++) {
278             if (splitDataSize[j] > 0) {
279                 infoGain -= ((double) splitDataSize[j] / (double) numI)
280                     * computeEntropy(splitData[j], att);
281             }
282         }
283         return infoGain;
284     }
285 
286     /***
287      * Computes the entropy of a dataset.
288      * 
289      * @param data
290      *            the data for which entropy is to be computed
291      * @return the entropy of the data's class distribution
292      */
293     private double computeEntropy(Instances data, Attribute att) throws Exception {
294         double[] classCounts = new double[data.numClasses()];
295         Enumeration instEnum = data.enumerateInstances();
296         int numInstances = 0;
297         while (instEnum.hasMoreElements()) {
298             Instance inst = (Instance) instEnum.nextElement();
299             if (inst.isMissing(att)) continue;
300             classCounts[(int) inst.classValue()]++;
301             ++numInstances;
302         }
303         double entropy = 0;
304         for (int j = 0; j < data.numClasses(); j++) {
305             if (classCounts[j] > 0) {
306                 entropy -= classCounts[j] * Utils.log2(classCounts[j]);
307             }
308         }
309         entropy /= (double) numInstances;
310         return entropy + Utils.log2(numInstances);
311     }
312     int numI;
313     int splitDataSize[];
314 
315     /***
316      * Splits a dataset according to the values of a nominal attribute.
317      * 
318      * @param data
319      *            the data which is to be split
320      * @param att
321      *            the attribute to be used for splitting
322      * @return the sets of instances produced by the split
323      */
324     private Instances[] splitData(Instances data, Attribute att) {
325         numI = 0;
326         splitDataSize = new int[att.numValues()];
327         Instances[] splitData = new Instances[att.numValues()];
328         for (int j = 0; j < att.numValues(); j++) {
329             splitData[j] = new Instances(data, data.numInstances());
330         }
331         Enumeration instEnum = data.enumerateInstances();
332         while (instEnum.hasMoreElements()) {
333             Instance inst = (Instance) instEnum.nextElement();
334             if (inst.isMissing(att)) {
335                 // Add to all children.
336                 for (int k = 0; k < att.numValues(); ++k) {
337                     splitData[k].add(inst);
338                 }
339             } else {
340                 int k = (int) inst.value(att);
341                 splitData[k].add(inst);
342                 splitDataSize[k]++;
343                 numI++;
344             }
345         }
346         return splitData;
347     }
348 
349     /***
350      * Outputs a tree at a certain level.
351      * 
352      * @param level
353      *            the level at which the tree is to be printed
354      */
355     private String toString(int level) {
356         StringBuffer text = new StringBuffer();
357         if (m_Attribute == null) {
358             if (Instance.isMissingValue(m_ClassValue)) {
359                 text.append(": null");
360             } else {
361                 text.append(": " + m_ClassAttribute.value((int) m_ClassValue));
362             }
363         } else {
364             for (int j = 0; j < m_Attribute.numValues(); j++) {
365                 text.append("\n");
366                 for (int i = 0; i < level; i++) {
367                     text.append("|  ");
368                 }
369                 text.append(m_Attribute.name() + " = " + m_Attribute.value(j));
370                 text.append(m_Successors[j].toString(level + 1));
371             }
372         }
373         return text.toString();
374     }
375 
376     /***
377      * Main method.
378      *
379      * @param args the options for the classifier
380      */
381     public static void main(String[] args) {
382         try {
383             System.out.println(Evaluation.evaluateModel(new Id3(), args));
384         } catch (Exception e) {
385             System.err.println(e.getMessage());
386         }
387     }
388 }