1
2
3
4
5
6
7 package net.sf.bddbddb.order;
8
9 import java.util.Arrays;
10 import java.util.Random;
11
12 import jwutil.util.Assert;
13 import weka.core.Instance;
14 import weka.core.Instances;
15 import weka.core.NoSupportForMissingValuesException;
16 import weka.core.Utils;
17
18 /***
19 * @author Administrator
20 *
21 * TODO To change the template for this generated type comment go to
22 * Window - Preferences - Java - Code Style - Code Templates
23 */
24 public class BaggedId3 extends ClassProbabilityEstimator {
25
26 /***
27 * Version ID for serialization.
28 */
29 private static final long serialVersionUID = 3256726195109115186L;
30
31 public final static int NUM_TREES = 10;
32 MyId3 [] trees;
33 double [] weights;
34 Random random = new Random(System.currentTimeMillis());
35 int numClasses;
36 Instances origData;
37 public BaggedId3(){
38 trees = new MyId3[NUM_TREES];
39 weights = new double[NUM_TREES];
40 Arrays.fill(weights,1);
41 }
42
43 public void setWeights(double [] weights){
44 this.weights = weights;
45 }
46
47 public void setWeight(int index, double value){
48 weights[index] = value;
49 }
50
51
52
53 public void buildClassifier(Instances data) throws Exception {
54 numClasses = data.classAttribute().numValues();
55 for(int i = 0; i < NUM_TREES; ++i){
56 Instances newData = data.resample(random);
57 newData.setClassIndex(data.classIndex());
58 trees[i] = new MyId3();
59 trees[i].buildClassifier(newData);
60 }
61 origData = data;
62 }
63
64 public double classifyInstance(Instance instance){
65 double[] votes = new double[numClasses];
66 for(int i = 0; i < NUM_TREES; ++i){
67 double vote = trees[i].classifyInstance(instance);
68 votes[(int) vote] += weights[i];
69 }
70 return Utils.maxIndex(votes);
71 }
72
73 public double classProbability(Instance instance, double targetClass){
74 Assert._assert(targetClass >= 0);
75 Assert._assert(targetClass < numClasses);
76 try {
77 return distributionForInstance(instance)[(int) targetClass];
78 } catch (NoSupportForMissingValuesException e) {
79 e.printStackTrace();
80 }
81
82 return Double.NaN;
83 }
84 public double[] distributionForInstance(Instance instance) throws NoSupportForMissingValuesException{
85 double sum = 0;
86 double [] distribution = new double[numClasses];
87 for(int i = 0; i < NUM_TREES; ++i){
88 double [] treeDist = trees[i].distributionForInstance(instance);
89 for(int j = 0; j <numClasses; ++j){
90 distribution[j]+= treeDist[j] * weights[i];
91 }
92 sum += weights[i];
93 }
94 Assert._assert(sum != 0, "Sum of Weights is zero");
95 for(int i = 0; i < numClasses; ++i)
96 distribution[i] /= sum;
97
98 return distribution;
99 }
100
101 public double classVariance(Instance instance, double targetClass){
102 Assert._assert(targetClass >= 0);
103 Assert._assert(targetClass < numClasses);
104 try {
105 return varianceForInstance(instance)[(int) targetClass];
106 } catch (NoSupportForMissingValuesException e) {
107 e.printStackTrace();
108 }
109
110 return Double.NaN;
111 }
112 public double[] varianceForInstance(Instance instance) throws NoSupportForMissingValuesException{
113
114
115 double [][] classProbabilities = new double[NUM_TREES + 1][];
116 classProbabilities[NUM_TREES] = new double[numClasses];
117 double [] variance = new double[numClasses];
118 for(int i = 0; i < NUM_TREES; ++i){
119 double [] treeDist = trees[i].distributionForInstance(instance);
120 classProbabilities[i] = treeDist;
121 for(int j = 0; j <numClasses; ++j){
122 classProbabilities[NUM_TREES][j] += treeDist[j];
123 }
124 }
125 for(int j = 0; j <numClasses; ++j)
126 classProbabilities[NUM_TREES][j] /= NUM_TREES;
127
128 for(int i = 0; i < NUM_TREES; ++i){
129 for(int j = 0; j <numClasses; ++j){
130 double diff = Math.abs(classProbabilities[i][j] - classProbabilities[NUM_TREES][j]);
131 variance[j] += diff * diff;
132 }
133 }
134 for(int j = 0; j <numClasses; ++j)
135 variance[j] /= NUM_TREES;
136
137 return variance;
138 }
139
140 public Instances getData(){ return origData; }
141
142 }