#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HEAL.Attic;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest model that keeps the complete ALGLIB decision forest and persists it
  /// in ALGLIB's serialized string form. Usable for regression and for classification.
  /// </summary>
  [StorableType("55412E08-DAD4-4C2E-9181-C142E7EA9474")]
  [Item("RandomForestModelFull", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModelFull : ClassificationModel, IRandomForestModel {

    /// <summary>Names of the input variables, in the column order used to query the forest.</summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return inputVariables; }
    }

    [Storable]
    private double[] classValues;

    /// <summary>Number of stored class values; 0 when no class values are stored.</summary>
    public int NumClasses => classValues == null ? 0 : classValues.Length;

    [Storable]
    private string[] inputVariables;

    /// <summary>Number of trees as passed to the constructor.</summary>
    [Storable]
    public int NumberOfTrees { get; private set; }

    // not persisted directly; stored via RandomForestSerialized
    private alglib.decisionforest randomForest;

    /// <summary>
    /// Persistence hook: the getter serializes the forest with ALGLIB, the setter restores it.
    /// A null value leaves the current forest untouched.
    /// </summary>
    [Storable]
    private string RandomForestSerialized {
      get {
        alglib.dfserialize(randomForest, out var serialized);
        return serialized;
      }
      set {
        if (value != null) {
          alglib.dfunserialize(value, out randomForest);
        }
      }
    }

    [StorableConstructor]
    private RandomForestModelFull(StorableConstructorFlag _) : base(_) { }

    /// <summary>Cloning constructor; copies the forest and shares the immutable arrays.</summary>
    private RandomForestModelFull(RandomForestModelFull original, Cloner cloner) : base(original, cloner) {
      if (original.randomForest != null) {
        randomForest = (alglib.decisionforest)original.randomForest.make_copy();
      }
      NumberOfTrees = original.NumberOfTrees;

      // following fields are immutable so we don't need to clone them
      inputVariables = original.inputVariables;
      classValues = original.classValues;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModelFull(this, cloner);
    }

    /// <summary>
    /// Creates a model from a trained ALGLIB decision forest. The forest is copied, so the
    /// caller's instance is not retained.
    /// </summary>
    /// <param name="decisionForest">Trained forest.</param>
    /// <param name="nTrees">Number of trees, exposed as <see cref="NumberOfTrees"/>.</param>
    /// <param name="targetVariable">Name of the target variable (passed to the base class).</param>
    /// <param name="inputVariables">Input variable names; materialized into an array.</param>
    /// <param name="classValues">Class values for classification models; null for regression.</param>
    public RandomForestModelFull(alglib.decisionforest decisionForest, int nTrees, string targetVariable,
      IEnumerable<string> inputVariables, IEnumerable<double> classValues = null) : base(targetVariable) {
      this.name = ItemName;
      this.description = ItemDescription;

      randomForest = (alglib.decisionforest)decisionForest.make_copy();
      NumberOfTrees = nTrees;

      this.inputVariables = inputVariables.ToArray();

      // classValues are only used for classification models
      if (classValues == null) {
        this.classValues = new double[0];
      } else {
        this.classValues = classValues.ToArray();
      }
    }

    /// <summary>Wraps this model and a copy of the problem data in a regression solution.</summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
    }

    /// <summary>Wraps this model and a copy of the problem data in a classification solution.</summary>
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    /// <summary>Delegates the compatibility check to <c>RegressionModel.IsProblemDataCompatible</c>.</summary>
    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
    }

    /// <summary>
    /// Dispatches to the regression or classification compatibility check depending on the
    /// runtime type of <paramref name="problemData"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> is null.</exception>
    /// <exception cref="ArgumentException">The problem data is neither regression nor classification data.</exception>
    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
      if (problemData == null) {
        throw new ArgumentNullException(nameof(problemData), "The provided problemData is null.");
      }

      // regression is checked first, same order as before
      if (problemData is IRegressionProblemData regressionProblemData) {
        return IsProblemDataCompatible(regressionProblemData, out errorMessage);
      }

      if (problemData is IClassificationProblemData classificationProblemData) {
        return IsProblemDataCompatible(classificationProblemData, out errorMessage);
      }

      throw new ArgumentException("The problem data is not compatible with this random forest. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", nameof(problemData));
    }

    /// <summary>
    /// Lazily yields the first output component of the forest for each requested row.
    /// Because this is an iterator, the dataset is read and validated on first enumeration.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(inputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[1];

      alglib.dfcreatebuffer(randomForest, out var buf);
      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dftsprocess(randomForest, buf, x, ref y); // thread-safe process (as long as separate buffers are used)
        yield return y[0];
      }
    }

    /// <summary>
    /// Lazily yields, for each requested row, <c>VariancePop()</c> of the raw output buffer
    /// (sized to the forest's tree count) filled by <c>dfprocessraw</c>.
    /// </summary>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(inputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] ys = new double[this.randomForest.innerobj.ntrees];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        // calls into the forest's inner object are serialized on the forest instance
        lock (randomForest) {
          alglib.dforest.dfprocessraw(randomForest.innerobj, x, ref ys, null);
        }
        yield return ys.VariancePop();
      }
    }

    /// <summary>
    /// Lazily yields, for each requested row, the class value whose forest output is largest.
    /// Ties resolve to the lowest class index.
    /// </summary>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(inputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[NumClasses];

      alglib.dfcreatebuffer(randomForest, out var buf);
      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dftsprocess(randomForest, buf, x, ref y);

        // find the class with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }

    /// <summary>
    /// Converts one tree of the forest into a symbolic expression tree made of
    /// variable-condition nodes and constant leaves.
    /// </summary>
    /// <param name="treeIdx">
    /// Index of the tree. <c>treeIdx - 1</c> trees are skipped, so the index behaves as
    /// 1-based (values 0 and 1 both select the first tree). NOTE(review): confirm callers
    /// pass a 1-based index.
    /// </param>
    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
      var rf = randomForest;
      // hoping that the internal representation of alglib is stable

      // TREE FORMAT
      // W[Offs]      -   size of sub-array (for the tree)
      //     node info:
      // W[K+0]       -   variable number        (-1 for leaf mode)
      // W[K+1]       -   threshold              (class/value for leaf node)
      // W[K+2]       -   ">=" branch index      (absent for leaf node)

      // skip irrelevant trees
      int offset = 0;
      for (int i = 0; i < treeIdx - 1; i++) {
        offset = offset + (int)Math.Round(rf.innerobj.trees[offset]);
      }

      var constSy = new Constant();
      var varCondSy = new VariableCondition() { IgnoreSlope = true };

      // the root node is stored directly after the tree's size entry
      var node = CreateRegressionTreeRec(rf.innerobj.trees, offset, offset + 1, constSy, varCondSy);

      var startNode = new StartSymbol().CreateTreeNode();
      startNode.AddSubtree(node);
      var root = new ProgramRootSymbol().CreateTreeNode();
      root.AddSubtree(startNode);
      return new SymbolicExpressionTree(root);
    }

    /// <summary>
    /// Recursively converts the node at position <paramref name="k"/> of the flat tree array.
    /// </summary>
    /// <param name="trees">Flat tree storage of the forest.</param>
    /// <param name="offset">Start of the current tree; branch indices are relative to it.</param>
    /// <param name="k">Position of the node to convert.</param>
    /// <param name="constSy">Symbol used to create leaf nodes.</param>
    /// <param name="varCondSy">Symbol used to create inner (split) nodes.</param>
    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Constant constSy, VariableCondition varCondSy) {
      // alglib source for evaluation of one tree (dfprocessinternal)
      // offs = 0
      //
      // Set pointer to the root
      //
      // k = offs + 1;
      //
      // //
      // // Navigate through the tree
      // //
      // while (true) {
      //   if ((double)(df.trees[k]) == (double)(-1)) {
      //     if (df.nclasses == 1) {
      //       y[0] = y[0] + df.trees[k + 1];
      //     } else {
      //       idx = (int)Math.Round(df.trees[k + 1]);
      //       y[idx] = y[idx] + 1;
      //     }
      //     break;
      //   }
      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
      //     k = k + innernodewidth;
      //   } else {
      //     k = offs + (int)Math.Round(df.trees[k + 2]);
      //   }
      // }

      // an inner node occupies three entries (variable, threshold, ">=" branch index);
      // corresponds to innernodewidth in the alglib code quoted above
      const int innerNodeWidth = 3;

      // exact comparison against the leaf marker, as in the alglib code
      if (trees[k] == -1.0) {
        var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
        constNode.Value = trees[k + 1];
        return constNode;
      }

      var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
      condNode.VariableName = inputVariables[(int)Math.Round(trees[k])];
      condNode.Threshold = trees[k + 1];
      condNode.Slope = double.PositiveInfinity;

      // "<" branch follows the node directly, ">=" branch is addressed relative to the tree start
      var left = CreateRegressionTreeRec(trees, offset, k + innerNodeWidth, constSy, varCondSy);
      var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), constSy, varCondSy);

      condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
      condNode.AddSubtree(right);
      return condNode;
    }
  }
}