using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HEAL.Attic;
using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;

namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {

  /// <summary>
  /// Implementation of the Fast Function Extraction (FFX) algorithm:
  /// builds a large pool of candidate basis functions (interactions, denominators,
  /// exponentiations, nonlinear and hinge functions), fits sparse linear models over
  /// them with elastic-net regression, and returns the Pareto front of model
  /// complexity (number of bases) versus test error.
  /// </summary>
  [Item(Name = "FastFunctionExtraction", Description = "Implementation of the Fast Function Extraction (FFX) algorithm in C#.")]
  [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
  [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")]
  public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm {
    #region constants
    // Exponents applied to input variables when exponentiation is enabled.
    private static readonly HashSet<double> exponents = new HashSet<double> { -1.0, -0.5, +0.5, +1.0 };
    // Unary operators applied to basis functions when nonlinear functions are enabled.
    private static readonly HashSet<NonlinearOperator> nonlinFuncs = new HashSet<NonlinearOperator> { NonlinearOperator.Abs, NonlinearOperator.Log, NonlinearOperator.None };
    // Hinge-function thresholds: numHingeThrs values sampled in [minHingeThr, maxHingeThr]
    // (relative position within each variable's range — TODO confirm against BFUtils).
    private static readonly double minHingeThr = 0.2;
    private static readonly double maxHingeThr = 0.9;
    private static readonly int numHingeThrs = 5;

    private const string ConsiderInteractionsParameterName = "Consider Interactions";
    private const string ConsiderDenominationParameterName = "Consider Denomination";
    private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
    private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
    private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear Functions";
    private const string LambdaParameterName = "Elastic Net Lambda";
    private const string PenaltyParameterName = "Elastic Net Penalty";
    private const string MaxNumBasisFuncsParameterName = "Maximum Number of Basis Functions";
    #endregion

    #region parameters
    public IValueParameter<BoolValue> ConsiderInteractionsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderDenominationsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderDenominationParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderExponentiationsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderExponentiationParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderNonlinearFuncsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderNonlinearFuncsParameterName]; }
    }
    public IValueParameter<BoolValue> ConsiderHingeFuncsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[ConsiderHingeFuncsParameterName]; }
    }
    public IValueParameter<DoubleValue> PenaltyParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[PenaltyParameterName]; }
    }
    public IValueParameter<DoubleValue> LambdaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
    }
    public IValueParameter<IntValue> MaxNumBasisFuncsParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxNumBasisFuncsParameterName]; }
    }
    #endregion

    #region properties
    public bool ConsiderInteractions {
      get { return ConsiderInteractionsParameter.Value.Value; }
      set { ConsiderInteractionsParameter.Value.Value = value; }
    }
    public bool ConsiderDenominations {
      get { return ConsiderDenominationsParameter.Value.Value; }
      set { ConsiderDenominationsParameter.Value.Value = value; }
    }
    public bool ConsiderExponentiations {
      get { return ConsiderExponentiationsParameter.Value.Value; }
      set { ConsiderExponentiationsParameter.Value.Value = value; }
    }
    public bool ConsiderNonlinearFunctions {
      get { return ConsiderNonlinearFuncsParameter.Value.Value; }
      set { ConsiderNonlinearFuncsParameter.Value.Value = value; }
    }
    public bool ConsiderHingeFunctions {
      get { return ConsiderHingeFuncsParameter.Value.Value; }
      set { ConsiderHingeFuncsParameter.Value.Value = value; }
    }
    public double Penalty {
      get { return PenaltyParameter.Value.Value; }
      set { PenaltyParameter.Value.Value = value; }
    }
    // NOTE(review): Lambda is declared as an optional parameter but is currently
    // not forwarded to the elastic-net fit (see CreateFFXModels) — verify intent.
    public DoubleValue Lambda {
      get { return LambdaParameter.Value; }
      set { LambdaParameter.Value = value; }
    }
    public int MaxNumBasisFuncs {
      get { return MaxNumBasisFuncsParameter.Value.Value; }
      set { MaxNumBasisFuncsParameter.Value.Value = value; }
    }
    #endregion

    #region ctor
    [StorableConstructor]
    private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }

    public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) { }

    public FastFunctionExtraction() : base() {
      base.Problem = new RegressionProblem();
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want the models to include interactions, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want the models to include denominations, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want the models to include exponentiation, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want the models to include nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want the models to include Hinge Functions, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<IntValue>(MaxNumBasisFuncsParameterName, "Set how many basis functions the models can have at most. if Max Num Basis Funcs is negative => no restriction on size", new IntValue(20)));
      Parameters.Add(new OptionalValueParameter<DoubleValue>(LambdaParameterName, "Optional: the value of lambda for which to calculate an elastic-net solution. lambda == null => calculate the whole path of all lambdas"));
      Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.05)));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new FastFunctionExtraction(this, cloner);
    }
    #endregion

    public override Type ProblemType { get { return typeof(RegressionProblem); } }
    public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }

    /// <summary>
    /// Runs FFX on the current problem data and adds one result per Pareto-optimal
    /// model plus a collection of regression solutions ("Model Accuracies").
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      var models = Fit(Problem.ProblemData, Penalty, out var numBases,
        ConsiderExponentiations, ConsiderNonlinearFunctions, ConsiderInteractions,
        ConsiderDenominations, ConsiderHingeFunctions, MaxNumBasisFuncs);

      int i = 0;
      var numBasesArr = numBases.ToArray();
      var solutionsArr = new List<ISymbolicRegressionSolution>();
      foreach (var model in models) {
        // one result per model, labeled with its complexity (number of bases)
        Results.Add(new Result("Num Bases: " + numBasesArr[i++], model));
        solutionsArr.Add(new SymbolicRegressionSolution(model, Problem.ProblemData));
      }
      Results.Add(new Result("Model Accuracies", new ItemCollection<ISymbolicRegressionSolution>(solutionsArr)));
    }

    /// <summary>
    /// Fits FFX models for all enabled approach combinations and returns the
    /// Pareto-optimal models (complexity vs. test error), ordered by complexity.
    /// </summary>
    /// <param name="data">Regression problem data (training and test partitions).</param>
    /// <param name="elnetPenalty">Elastic-net mixing parameter alpha in [0,1]: 0 = ridge, 1 = lasso.</param>
    /// <param name="numBases">Out: number of bases of each returned model, parallel to the result sequence.</param>
    /// <param name="maxNumBases">Maximum bases per model; negative means no restriction.</param>
    public static IEnumerable<ISymbolicRegressionModel> Fit(IRegressionProblemData data, double elnetPenalty, out IEnumerable<int> numBases, bool exp = true, bool nonlinFuncs = true, bool interactions = true, bool denoms = false, bool hingeFuncs = true, int maxNumBases = -1) {
      var approaches = CreateApproaches(interactions, denoms, exp, nonlinFuncs, hingeFuncs, maxNumBases, elnetPenalty);
      var allFFXModels = approaches
        .SelectMany(approach => CreateFFXModels(data, approach)).ToList();
      // Final Pareto filter over all generated models from all different approaches
      var nondominatedFFXModels = NondominatedModels(data, allFFXModels);
      numBases = nondominatedFFXModels.Select(ffxModel => ffxModel.NumBases).ToArray();
      return nondominatedFFXModels.Select(ffxModel => ffxModel.ToSymbolicRegressionModel(data.TargetVariable));
    }

    // Selects the Pareto front of (number of bases, test MSE), both minimized,
    // and returns it ordered by increasing complexity.
    private static IEnumerable<FFXModel> NondominatedModels(IRegressionProblemData data, IEnumerable<FFXModel> ffxModels) {
      var modelsArr = ffxModels.ToArray();                          // materialize once; avoids re-enumeration below
      var originalValues = data.TargetVariableTestValues.ToArray(); // identical for every model, hoisted out of the loop

      var qualities = new double[modelsArr.Length][];
      for (int i = 0; i < modelsArr.Length; i++) {
        var estimatedValues = modelsArr[i].Simulate(data, data.TestIndices);
        // do not create a regressionSolution here for better performance:
        // RegressionSolutions calculate all kinds of errors when calling the ctor, but we only need testMSE
        var testMSE = OnlineMeanSquaredErrorCalculator.Calculate(originalValues, estimatedValues, out var state);
        // NOTE(review): ArrayTypeMismatchException is semantically odd here (InvalidOperationException
        // would fit better), kept to avoid changing the catchable exception type.
        if (state != OnlineCalculatorError.None) throw new ArrayTypeMismatchException("could not calculate TestMSE");
        qualities[i] = new double[] { modelsArr[i].NumBases, testMSE };
      }
      // false, false => minimize both objectives
      return DominationCalculator<FFXModel>.CalculateBestParetoFront(modelsArr, qualities, new bool[] { false, false })
        .Select(tuple => tuple.Item1)
        .OrderBy(ffxModel => ffxModel.NumBases);
    }

    // Build FFX models for a single approach:
    // Step 1: generate candidate basis functions; Step 2: elastic-net path on the
    // normalized data, then rebias coefficients back to the original scale.
    private static IEnumerable<FFXModel> CreateFFXModels(IRegressionProblemData data, Approach approach) {
      // FFX Step 1
      var basisFunctions = BFUtils.CreateBasisFunctions(data, approach).ToArray();
      // FFX Step 2
      var funcsArr = basisFunctions.ToArray();
      var elnetData = BFUtils.PrepareData(data, funcsArr);
      var normalizedElnetData = BFUtils.Normalize(elnetData, out var X_avgs, out var X_stds, out var y_avg, out var y_std);
      ElasticNetLinearRegression.RunElasticNetLinearRegression(normalizedElnetData, approach.ElasticNetPenalty,
        out var _, out var _, out var _, out var candidateCoeffsNorm, out var interceptNorm,
        maxVars: approach.MaxNumBases);
      var coefs = RebiasCoefs(candidateCoeffsNorm, interceptNorm, X_avgs, X_stds, y_avg, y_std, out var intercept);
      // create models out of the learned coefficients
      var ffxModels = GetModelsFromCoeffs(coefs, intercept, funcsArr, approach);
      // one last LS-optimization step on the training data
      foreach (var ffxModel in ffxModels) {
        if (ffxModel.NumBases > 0) ffxModel.OptimizeCoefficients(data);
      }
      return ffxModels;
    }

    // Transforms coefficients learned on normalized data back to the original scale.
    // For each path solution i: coef[i,j] = normCoef[i,j] * y_std / X_std[j] and the
    // intercept absorbs the feature means.
    private static double[,] RebiasCoefs(double[,] unbiasedCoefs, double[] unbiasedIntercepts, double[] X_avgs, double[] X_stds, double y_avg, double y_std, out double[] rebiasedIntercepts) {
      var rows = unbiasedIntercepts.Length;
      var cols = X_stds.Length;
      var rebiasedCoefs = new double[rows, cols];
      rebiasedIntercepts = new double[rows];
      for (int i = 0; i < rows; i++) {
        rebiasedIntercepts[i] = unbiasedIntercepts[i] * y_std + y_avg;
        for (int j = 0; j < cols; j++) {
          rebiasedCoefs[i, j] = unbiasedCoefs[i, j] * y_std / X_stds[j];
          rebiasedIntercepts[i] -= rebiasedCoefs[i, j] * X_avgs[j];
        }
      }
      return rebiasedCoefs;
    }

    // finds all models with unique combinations of basis functions
    private static IEnumerable<FFXModel> GetModelsFromCoeffs(double[,] candidateCoeffs, double[] intercept, IBasisFunction[] funcsArr, Approach approach) {
      var ffxModels = new List<FFXModel>();
      var seenBasisCombinations = new HashSet<string>();
      for (int i = 0; i < intercept.Length; i++) {
        var row = candidateCoeffs.GetRow(i);
        var nonzeroIndices = row.FindAllIndices(val => val != 0).ToArray();
        // NOTE(review): assumes Approach maps a negative maxNumBases ("no restriction")
        // to a large positive value — verify, otherwise this filter drops everything.
        if (nonzeroIndices.Length > approach.MaxNumBases) continue;
        // ignore duplicate models (models with same combination of basis functions)
        if (!seenBasisCombinations.Add(string.Join(",", nonzeroIndices))) continue;
        ffxModels.Add(new FFXModel(intercept[i], nonzeroIndices.Select(idx => (row[idx], funcsArr[idx]))));
      }
      return ffxModels;
    }

    // Enumerates all allowed feature combinations (interactions, denominator,
    // exponentiation, nonlinear funcs, hinge funcs) and creates one Approach each.
    private static IEnumerable<Approach> CreateApproaches(bool interactions, bool denominator, bool exponentiations, bool nonlinearFuncs, bool hingeFunctions, int maxNumBases, double penalty) {
      var approaches = new List<Approach>();
      var valids = new bool[5] { interactions, denominator, exponentiations, nonlinearFuncs, hingeFunctions };

      // return true if ALL indices of true values of arr1 also have true values in arr2
      bool follows(BitArray arr1, bool[] arr2) {
        if (arr1.Length != arr2.Length) throw new ArgumentException("invalid lengths");
        for (int i = 0; i < arr1.Length; i++) {
          if (arr1[i] && !arr2[i]) return false;
        }
        return true;
      }

      for (int i = 0; i < 32; i++) { // Iterate all combinations of 5 bools.
        // map i to a bool array of length 5
        // BUGFIX: the bit index must advance on EVERY iteration; previously it only
        // advanced for set bits, collapsing all combinations to contiguous prefixes.
        var v = i;
        var arr = new BitArray(5);
        var popCount = 0;
        for (int b = 0; v > 0; b++, v /= 2) {
          if (v % 2 == 1) {
            arr[b] = true;
            popCount++;
          }
        }
        if (!follows(arr, valids)) continue;
        if (popCount >= 4) continue;    // not too many features at once
        if (arr[0] && arr[2]) continue; // never need both exponent and inter
        approaches.Add(new Approach(arr[0], arr[1], arr[2], arr[3], arr[4],
          exponents, nonlinFuncs, maxNumBases, penalty, minHingeThr, maxHingeThr, numHingeThrs));
      }
      return approaches;
    }
  }
}