#ifndef IGMDK_RANDOM_FOREST_H #define IGMDK_RANDOM_FOREST_H #include "ClassificationCommon.h" #include "DecisionTree.h" #include "../Utils/Vector.h" #include "../RandomNumberGeneration/Random.h" #include namespace igmdk{ class RandomForest { Vector forest; int nClasses; public: template RandomForest(DATA const& data, int nTrees = 300): nClasses(findNClasses(data)) { assert(data.getSize() > 1); for(int i = 0; i < nTrees; ++i) { PermutedData resample(data); for(int j = 0; j < data.getSize(); ++j) resample.addIndex(GlobalRNG().mod(data.getSize())); forest.append(DecisionTree(resample, 0, true)); } } template static int classifyWork(NUMERIC_X const& x, ENSEMBLE const& e, int nClasses) { Vector counts(nClasses, 0); for(int i = 0; i < e.getSize(); ++i) ++counts[e[i].predict(x)]; return argMax(counts.getArray(), counts.getSize()); } int predict(NUMERIC_X const& x)const {return classifyWork(x, forest, nClasses);} Vector classifyProbs(NUMERIC_X const& x)const { Vector counts(nClasses, 0); for(int i = 0; i < forest.getSize(); ++i) ++counts[forest[i].predict(x)]; normalizeProbs(counts); return counts; } }; }//end namespace #endif