#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HEAL.Attic;
using HeuristicLab.Algorithms.DataAnalysis;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;

namespace HeuristicLab.Problems.DataAnalysis.FeatureSelection {
  /// <summary>
  /// Backward feature elimination for regression problems: repeatedly trains the wrapped
  /// data-analysis algorithm, ranks the input variables by their calculated impact, drops
  /// the weakest fraction ("FeaturesDrop") and retrains on the remaining inputs until no
  /// features are left. Per-iteration impacts, selected features, and solution qualities
  /// (MAE / R²) are collected into the result collection.
  /// </summary>
  [Item("VariableImpactBasedFeatureSelectionAlgorithm", "")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 9999)]
  [StorableType("EB47CA07-6F01-4FC1-9351-54ACB1F2CF24")]
  public class VariableImpactBasedFeatureSelectionAlgorithm : BasicAlgorithm {
    public override bool SupportsPause {
      get { return false; }
    }

    #region Problem Type
    public override Type ProblemType {
      get { return typeof(IRegressionProblem); }
    }
    public new IRegressionProblem Problem {
      get { return (IRegressionProblem)base.Problem; }
      set { base.Problem = value; }
    }
    #endregion

    #region Parameter Properties
    // The wrapped data-analysis algorithm that is (re-)trained in every elimination iteration.
    private ValueParameter<FixedDataAnalysisAlgorithm> AlgorithmParameter {
      get { return (ValueParameter<FixedDataAnalysisAlgorithm>)Parameters["Algorithm"]; }
    }
    // Fraction of the currently selected features that is dropped after each iteration.
    private FixedValueParameter<PercentValue> FeaturesDropParameter {
      get { return (FixedValueParameter<PercentValue>)Parameters["FeaturesDrop"]; }
    }
    #endregion

    #region Results Parameter
    #endregion

    #region Constructor, Cloning & Persistence
    public VariableImpactBasedFeatureSelectionAlgorithm() {
      Parameters.Add(new ValueParameter<FixedDataAnalysisAlgorithm>("Algorithm", new RandomForestRegression()));
      Parameters.Add(new FixedValueParameter<PercentValue>("FeaturesDrop", new PercentValue(0.2)));
      // ToDo: Use ResultParameters
      //Parameters.Add(new ResultParameter<>());
      Problem = new RegressionProblem();
    }

    [StorableConstructor]
    protected VariableImpactBasedFeatureSelectionAlgorithm(StorableConstructorFlag _) : base(_) { }

    public VariableImpactBasedFeatureSelectionAlgorithm(VariableImpactBasedFeatureSelectionAlgorithm original, Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new VariableImpactBasedFeatureSelectionAlgorithm(this, cloner);
    }
    #endregion

    /// <summary>
    /// Executes the iterative elimination loop. Works on clones of the configured algorithm
    /// and problem so the originals stay untouched; fills the "Algorithm", "VariableImpacts",
    /// "SelectedFeatures" and "Qualities" results.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      var clonedAlgorithm = (FixedDataAnalysisAlgorithm)AlgorithmParameter.Value.Clone();
      var clonedProblem = (IRegressionProblem)Problem.Clone();
      double featureDrop = FeaturesDropParameter.Value.Value;

      SetupAlgorithm(clonedAlgorithm, clonedProblem);
      Results.Add(new Result("Algorithm", clonedAlgorithm));

      // Start with all currently checked input variables.
      var remainingFeatures = clonedProblem.ProblemData.InputVariables.CheckedItems.Select(x => x.Value.Value).ToList();

      // One data-table row per input variable; eliminated variables get NaN entries later.
      var variableImpactsDataTable = new DataTable("VariableImpacts");
      foreach (var variable in clonedProblem.ProblemData.InputVariables) {
        variableImpactsDataTable.Rows.Add(new DataRow(variable.Value));
      }
      Results.Add(new Result("VariableImpacts", variableImpactsDataTable));

      // First matrix column records the initial feature set; one column is appended per iteration.
      var selectedVariablesResult = new StringMatrix(remainingFeatures.Count, 1) { ColumnNames = new[] { "StartUp" } };
      for (int i = 0; i < remainingFeatures.Count; i++) {
        selectedVariablesResult[i, 0] = remainingFeatures[i];
      }
      Results.Add(new Result("SelectedFeatures", selectedVariablesResult));

      var qualitiesResult = new DataTable("SolutionQualities");
      qualitiesResult.Rows.Add(new DataRow("MAE Training"));
      qualitiesResult.Rows.Add(new DataRow("MAE Test"));
      qualitiesResult.Rows.Add(new DataRow("R² Training") { VisualProperties = { SecondYAxis = true } });
      qualitiesResult.Rows.Add(new DataRow("R² Test") { VisualProperties = { SecondYAxis = true } });
      Results.Add(new Result("Qualities", qualitiesResult));

      int iteration = 0;
      while (remainingFeatures.Any()) {
        clonedAlgorithm.Start(cancellationToken);

        int numberOfRemainingVariables = (int)(remainingFeatures.Count * (1.0 - featureDrop)); // floor to avoid getting stuck
        var variableImpacts = GetVariableImpacts(clonedAlgorithm.Results).ToDictionary(x => x.Item1, x => x.Item2);
        // Keep the highest-impact variables, drop the rest.
        remainingFeatures = variableImpacts
          .OrderByDescending(x => x.Value)
          .Take(numberOfRemainingVariables)
          .Select(x => x.Key)
          .ToList();

        // Record this iteration's impact for every variable; NaN for already-eliminated ones.
        foreach (var row in variableImpactsDataTable.Rows) {
          row.Values.Add(variableImpacts.ContainsKey(row.Name) ? variableImpacts[row.Name] : double.NaN);
        }

        // Growing the matrix auto-names the new column "Column N"; rename to "Iteration N".
        ((IStringConvertibleMatrix)selectedVariablesResult).Columns++;
        selectedVariablesResult.ColumnNames = selectedVariablesResult.ColumnNames.Select(c => c.Replace("Column", "Iteration"));
        //selectedVariablesResult.ColumnNames = selectedVariablesResult.ColumnNames.Concat(new[] { $"Iteration {iteration}" });
        for (int i = 0; i < remainingFeatures.Count; i++) {
          selectedVariablesResult[i, selectedVariablesResult.Columns - 1] = remainingFeatures[i];
        }

        // NOTE(review): assumes the wrapped algorithm produces exactly one IRegressionSolution result.
        var solution = clonedAlgorithm.Results.Select(r => r.Value).OfType<IRegressionSolution>().Single();
        qualitiesResult.Rows["MAE Training"].Values.Add(solution.TrainingMeanAbsoluteError);
        qualitiesResult.Rows["MAE Test"].Values.Add(solution.TestMeanAbsoluteError);
        qualitiesResult.Rows["R² Training"].Values.Add(solution.TrainingRSquared);
        qualitiesResult.Rows["R² Test"].Values.Add(solution.TestRSquared);

        UpdateSelectedInputs(clonedProblem, remainingFeatures);
        iteration++;
      }
    }

    // Wires the cloned problem into the cloned algorithm and clears any previous runs/results.
    private static void SetupAlgorithm(FixedDataAnalysisAlgorithm algorithm, IRegressionProblem problem) {
      algorithm.Problem = problem;
      algorithm.Prepare(clearRuns: true);
    }

    // Extracts the single regression solution from the algorithm's results and computes
    // shuffle-based variable impacts on the training partition as (variable name, impact) pairs.
    private static IEnumerable<Tuple<string, double>> GetVariableImpacts(ResultCollection results) {
      //var solution = (IRegressionSolution)results["Random forest regression solution"].Value;
      var solution = results.Select(r => r.Value).OfType<IRegressionSolution>().Single();
      return RegressionSolutionVariableImpactsCalculator.CalculateImpacts(
        solution,
        replacementMethod: RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle,
        factorReplacementMethod: RegressionSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Shuffle,
        dataPartition: RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.Training);
    }

    // Checks exactly the surviving features in the problem's input-variable list so the
    // next training iteration only uses those inputs.
    private void UpdateSelectedInputs(IRegressionProblem problem, List<string> remainingFeatures) {
      foreach (var inputFeature in problem.ProblemData.InputVariables) {
        bool isRemaining = remainingFeatures.Contains(inputFeature.Value);
        problem.ProblemData.InputVariables.SetItemCheckedState(inputFeature, isRemaining);
      }
    }
  }
}