View Javadoc

1   /*
2    * Created on Nov 6, 2004
3    *
4    * TODO To change the template for this generated file go to
5    * Window - Preferences - Java - Code Style - Code Templates
6    */
7   package net.sf.bddbddb.order;
8   
9   import java.util.Arrays;
10  import java.util.Random;
11  
12  import jwutil.util.Assert;
13  import weka.core.Instance;
14  import weka.core.Instances;
15  import weka.core.NoSupportForMissingValuesException;
16  import weka.core.Utils;
17  
18  /***
19   * @author Administrator
20   *
21   * TODO To change the template for this generated type comment go to
22   * Window - Preferences - Java - Code Style - Code Templates
23   */
24  public class BaggedId3 extends ClassProbabilityEstimator {
25  
26      /***
27       * Version ID for serialization.
28       */
29      private static final long serialVersionUID = 3256726195109115186L;
30      
31      public final static int NUM_TREES = 10;
32      MyId3 [] trees;
33      double [] weights;
34      Random random = new Random(System.currentTimeMillis());
35      int numClasses;
36      Instances origData;
37      public BaggedId3(){
38          trees = new MyId3[NUM_TREES];
39          weights = new double[NUM_TREES];
40          Arrays.fill(weights,1); //trees weight equally
41      }
42      
43      public void setWeights(double [] weights){
44          this.weights = weights;
45      }
46      
47      public void setWeight(int index, double value){
48          weights[index] = value;
49      }
50      /* (non-Javadoc)
51       * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances)
52       */
53      public void buildClassifier(Instances data) throws Exception {
54          numClasses = data.classAttribute().numValues();
55          for(int i = 0; i < NUM_TREES; ++i){
56           Instances newData = data.resample(random);  //random sample with replacement
57           newData.setClassIndex(data.classIndex());
58           trees[i] = new MyId3();
59           trees[i].buildClassifier(newData);
60          }
61          origData = data;
62      }
63      
64      public double classifyInstance(Instance instance){
65          double[] votes = new double[numClasses];
66          for(int i = 0; i < NUM_TREES; ++i){
67              double vote = trees[i].classifyInstance(instance);
68              votes[(int) vote] += weights[i]; 
69          }
70          return Utils.maxIndex(votes); //simple majority
71      }
72  
73      public double classProbability(Instance instance, double targetClass){
74          Assert._assert(targetClass >= 0);
75          Assert._assert(targetClass < numClasses);
76              try {
77                  return distributionForInstance(instance)[(int) targetClass];
78              } catch (NoSupportForMissingValuesException e) {
79                  e.printStackTrace();
80              }
81              
82              return Double.NaN;
83      }
84      public double[] distributionForInstance(Instance instance) throws NoSupportForMissingValuesException{
85          double sum = 0;
86          double [] distribution = new double[numClasses];
87          for(int i = 0; i < NUM_TREES; ++i){
88              double [] treeDist = trees[i].distributionForInstance(instance);
89              for(int j = 0; j <numClasses; ++j){
90                  distribution[j]+= treeDist[j] * weights[i];
91              }
92              sum += weights[i];
93          }
94          Assert._assert(sum != 0, "Sum of Weights is zero");
95          for(int i = 0; i < numClasses; ++i)
96              distribution[i] /= sum;
97          
98          return distribution;
99      }
100     
101     public double classVariance(Instance instance, double targetClass){
102         Assert._assert(targetClass >= 0);
103         Assert._assert(targetClass < numClasses);
104             try {
105                 return varianceForInstance(instance)[(int) targetClass];
106             } catch (NoSupportForMissingValuesException e) {
107                 e.printStackTrace();
108             }
109             
110             return Double.NaN;
111     }
112     public double[] varianceForInstance(Instance instance) throws NoSupportForMissingValuesException{
113         
114         //should this be weighted?
115         double [][] classProbabilities = new double[NUM_TREES + 1][];
116         classProbabilities[NUM_TREES] = new double[numClasses];
117         double [] variance = new double[numClasses];
118         for(int i = 0; i < NUM_TREES; ++i){
119             double [] treeDist = trees[i].distributionForInstance(instance);
120             classProbabilities[i] = treeDist;
121             for(int j = 0; j <numClasses; ++j){
122                 classProbabilities[NUM_TREES][j] += treeDist[j];
123             }
124         }
125         for(int j = 0; j <numClasses; ++j)
126             classProbabilities[NUM_TREES][j] /= NUM_TREES;
127         
128         for(int i = 0; i < NUM_TREES; ++i){
129            for(int j = 0; j <numClasses; ++j){
130              double diff = Math.abs(classProbabilities[i][j] - classProbabilities[NUM_TREES][j]);
131              variance[j] += diff * diff;
132            }
133         }
134         for(int j = 0; j <numClasses; ++j)
135             variance[j] /= NUM_TREES;
136        
137         return variance;
138     }
139     
140     public Instances getData(){ return origData; }
141      
142 }