#region License Information
/* HeuristicLab
* Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Algorithms.DataAnalysis;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HEAL.Attic;
namespace HeuristicLab.Problems.DataAnalysis.FeatureSelection {
  /// <summary>
  /// Recursive feature elimination for regression problems: repeatedly trains the
  /// configured data-analysis algorithm, ranks the input variables by shuffle-based
  /// variable impact, and drops the least important fraction ("FeaturesDrop") of the
  /// remaining features each iteration until none are left. Impacts, the selected
  /// feature sets, and solution qualities (MAE / R²) are collected as results.
  /// </summary>
  [Item("VariableImpactBasedFeatureSelectionAlgorithm", "")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 9999)]
  [StorableType("EB47CA07-6F01-4FC1-9351-54ACB1F2CF24")]
  public class VariableImpactBasedFeatureSelectionAlgorithm : BasicAlgorithm {
    // Pausing is not supported: the outer elimination loop keeps its state in locals
    // of Run and could not be resumed after deserialization.
    public override bool SupportsPause {
      get { return false; }
    }

    #region Problem Type
    public override Type ProblemType {
      get { return typeof(IRegressionProblem); }
    }
    public new IRegressionProblem Problem {
      get { return (IRegressionProblem)base.Problem; }
      set { base.Problem = value; }
    }
    #endregion

    #region Parameter Properties
    // The inner algorithm that is trained in every elimination iteration.
    private ValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>> AlgorithmParameter {
      get { return (ValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>>)Parameters["Algorithm"]; }
    }
    // Fraction (0..1) of the currently remaining features to drop per iteration.
    private FixedValueParameter<PercentValue> FeaturesDropParameter {
      get { return (FixedValueParameter<PercentValue>)Parameters["FeaturesDrop"]; }
    }
    #endregion

    #region Results Parameter
    #endregion

    #region Constructor, Cloning & Persistence
    public VariableImpactBasedFeatureSelectionAlgorithm() {
      Parameters.Add(new ValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>>("Algorithm", new RandomForestRegression()));
      Parameters.Add(new FixedValueParameter<PercentValue>("FeaturesDrop", new PercentValue(0.2)));
      // ToDo: Use ResultParameters
      //Parameters.Add(new ResultParameter<>());
      Problem = new RegressionProblem();
    }

    [StorableConstructor]
    protected VariableImpactBasedFeatureSelectionAlgorithm(StorableConstructorFlag _)
      : base(_) { }

    protected VariableImpactBasedFeatureSelectionAlgorithm(VariableImpactBasedFeatureSelectionAlgorithm original, Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new VariableImpactBasedFeatureSelectionAlgorithm(this, cloner);
    }
    #endregion

    protected override void Run(CancellationToken cancellationToken) {
      // Work on clones so the user-configured algorithm/problem stay untouched.
      var clonedAlgorithm = (FixedDataAnalysisAlgorithm<IRegressionProblem>)AlgorithmParameter.Value.Clone();
      var clonedProblem = (IRegressionProblem)Problem.Clone();
      double featureDrop = FeaturesDropParameter.Value.Value;

      SetupAlgorithm(clonedAlgorithm, clonedProblem);
      Results.Add(new Result("Algorithm", clonedAlgorithm));

      // Start with all currently checked input variables.
      var remainingFeatures = clonedProblem.ProblemData.InputVariables.CheckedItems.Select(x => x.Value.Value).ToList();

      // One data row per input variable; eliminated variables get NaN entries later.
      var variableImpactsDataTable = new DataTable("VariableImpacts");
      foreach (var variable in clonedProblem.ProblemData.InputVariables) {
        variableImpactsDataTable.Rows.Add(new DataRow(variable.Value));
      }
      Results.Add(new Result("VariableImpacts", variableImpactsDataTable));

      // Matrix of the selected feature set per iteration; rows beyond the current
      // feature count simply stay empty in later (smaller) iterations.
      var selectedVariablesResult = new StringMatrix(remainingFeatures.Count, 1) {
        ColumnNames = new[] { "StartUp" }
      };
      for (int i = 0; i < remainingFeatures.Count; i++) {
        selectedVariablesResult[i, 0] = remainingFeatures[i];
      }
      Results.Add(new Result("SelectedFeatures", selectedVariablesResult));

      var qualitiesResult = new DataTable("SolutionQualities");
      qualitiesResult.Rows.Add(new DataRow("MAE Training"));
      qualitiesResult.Rows.Add(new DataRow("MAE Test"));
      qualitiesResult.Rows.Add(new DataRow("R² Training") { VisualProperties = { SecondYAxis = true } });
      qualitiesResult.Rows.Add(new DataRow("R² Test") { VisualProperties = { SecondYAxis = true } });
      Results.Add(new Result("Qualities", qualitiesResult));

      int iteration = 0;
      while (remainingFeatures.Any()) {
        cancellationToken.ThrowIfCancellationRequested();

        // NOTE(review): assumes the inner algorithm can be restarted after it
        // finished; Prepare is only called once in SetupAlgorithm — verify against
        // FixedDataAnalysisAlgorithm's lifecycle.
        clonedAlgorithm.Start(cancellationToken);

        // Floor (integer truncation) guarantees strict shrinkage for any drop > 0,
        // so the loop cannot get stuck; with drop == 0 it would never terminate.
        int numberOfRemainingVariables = (int)(remainingFeatures.Count * (1.0 - featureDrop));

        var variableImpacts = GetVariableImpacts(clonedAlgorithm.Results).ToDictionary(x => x.Item1, x => x.Item2);
        remainingFeatures = variableImpacts
          .OrderByDescending(x => x.Value)
          .Take(numberOfRemainingVariables)
          .Select(x => x.Key)
          .ToList();

        // Record this iteration's impact per variable; NaN for already-dropped ones.
        foreach (var row in variableImpactsDataTable.Rows) {
          double impact;
          row.Values.Add(variableImpacts.TryGetValue(row.Name, out impact) ? impact : double.NaN);
        }

        ((IStringConvertibleMatrix)selectedVariablesResult).Columns++;
        selectedVariablesResult.ColumnNames = selectedVariablesResult.ColumnNames.Select(c => c.Replace("Column", "Iteration"));
        //selectedVariablesResult.ColumnNames = selectedVariablesResult.ColumnNames.Concat(new[] { $"Iteration {iteration}" });
        for (int i = 0; i < remainingFeatures.Count; i++) {
          selectedVariablesResult[i, selectedVariablesResult.Columns - 1] = remainingFeatures[i];
        }

        var solution = clonedAlgorithm.Results.Select(r => r.Value).OfType<IRegressionSolution>().Single();
        qualitiesResult.Rows["MAE Training"].Values.Add(solution.TrainingMeanAbsoluteError);
        qualitiesResult.Rows["MAE Test"].Values.Add(solution.TestMeanAbsoluteError);
        qualitiesResult.Rows["R² Training"].Values.Add(solution.TrainingRSquared);
        qualitiesResult.Rows["R² Test"].Values.Add(solution.TestRSquared);

        UpdateSelectedInputs(clonedProblem, remainingFeatures);
        iteration++;
      }
    }

    // Wires the cloned problem into the cloned algorithm and resets previous runs.
    private static void SetupAlgorithm(FixedDataAnalysisAlgorithm<IRegressionProblem> algorithm, IRegressionProblem problem) {
      algorithm.Problem = problem;
      algorithm.Prepare(clearRuns: true);
    }

    // Extracts the single regression solution from the inner algorithm's results and
    // computes per-variable impacts by shuffling each variable on the training partition.
    private static IEnumerable<Tuple<string, double>> GetVariableImpacts(ResultCollection results) {
      //var solution = (IRegressionSolution)results["Random forest regression solution"].Value;
      var solution = results.Select(r => r.Value).OfType<IRegressionSolution>().Single();
      return RegressionSolutionVariableImpactsCalculator.CalculateImpacts(
        solution,
        replacementMethod: RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle,
        factorReplacementMethod: RegressionSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Shuffle,
        dataPartition: RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.Training);
    }

    // Checks exactly the remaining features in the problem's input-variable list so the
    // next training run only uses those inputs.
    private static void UpdateSelectedInputs(IRegressionProblem problem, List<string> remainingFeatures) {
      var remaining = new HashSet<string>(remainingFeatures);
      foreach (var inputFeature in problem.ProblemData.InputVariables) {
        problem.ProblemData.InputVariables.SetItemCheckedState(inputFeature, remaining.Contains(inputFeature.Value));
      }
    }
  }
}