using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
using HeuristicLab.Common;
using HeuristicLab.Data;
using HeuristicLab.Problems.DataAnalysis;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;

[assembly: InternalsVisibleTo("UnitTests")]
namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {
  /// <summary>
  /// Utility functions for creating basis functions and for preparing the
  /// regression data they are fitted on.
  /// </summary>
  internal static class BFUtils {
    /// <summary>
    /// Builds the set of candidate basis functions enabled by <paramref name="approach"/>:
    /// simple (univariate) bases, optionally hinge bases, optionally interaction
    /// (product) bases and optionally denominator bases.
    /// </summary>
    /// <param name="data">Problem data the candidate functions are evaluated on.</param>
    /// <param name="approach">Switches and settings selecting which kinds of bases are generated.</param>
    /// <returns>A lazily concatenated sequence of all generated basis functions.</returns>
    public static IEnumerable<IBasisFunction> CreateBasisFunctions(IRegressionProblemData data, Approach approach) {
      // fall back to the identity exponent / "no operator" when the respective feature is disabled
      var exponents = approach.AllowExp ? approach.Exponents : new HashSet<double> { 1 };
      var funcs = approach.AllowNonLinearFunctions ? approach.NonLinearFunctions : new HashSet<NonlinearOperator> { NonlinearOperator.None };
      var simpleBasisFuncs = CreateSimpleBases(data, exponents, funcs);

      if (approach.AllowHinge) {
        // only allow hinge functions for features with exponent = 1 (deemed too complex otherwise)
        var linearSimpleBasisFuncs = simpleBasisFuncs.Where(simpleBf => simpleBf.Exponent == 1 && simpleBf.Operator == NonlinearOperator.None);
        simpleBasisFuncs = simpleBasisFuncs.Concat(CreateHingeBases(data, linearSimpleBasisFuncs, approach.MinHingeThreshold, approach.MaxHingeThreshold, approach.NumHingeThresholds));
      }

      IEnumerable<IBasisFunction> functions = simpleBasisFuncs;

      if (approach.AllowInteractions) {
        var multivariateBases = CreateMultivariateBases(data, simpleBasisFuncs.ToArray());
        functions = functions.Concat(multivariateBases);
      }

      if (approach.AllowDenominators) {
        var denominatorBases = CreateDenominatorBases(functions);
        functions = functions.Concat(denominatorBases);
      }

      return functions;
    }

    /// <summary>
    /// Creates one simple basis function per combination of allowed input variable,
    /// exponent and nonlinear operator, keeping only those that evaluate to finite,
    /// non-NaN values on <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Supplies the allowed input variables and their values.</param>
    /// <param name="exponents">Exponents to apply to each variable.</param>
    /// <param name="nonLinearFunctions">Operators to apply to each exponentiated variable.</param>
    /// <returns>The valid simple basis functions.</returns>
    public static IEnumerable<ISimpleBasisFunction> CreateSimpleBases(IRegressionProblemData problemData, HashSet<double> exponents, HashSet<NonlinearOperator> nonLinearFunctions) {
      var simpleBasisFunctions = new List<ISimpleBasisFunction>();
      // loop-invariant: is the plain (operator-free) variant generated as well?
      bool plainVariantIncluded = nonLinearFunctions.Contains(NonlinearOperator.None);

      foreach (var variableName in problemData.AllowedInputVariables) {
        var min = problemData.Dataset.GetDoubleValues(variableName).Min();

        foreach (var exp in exponents) {
          var simpleBase = new SimpleBasisFunction(variableName, exp, NonlinearOperator.None);
          // if the basis function is not valid without any operator, then it won't be valid
          // in combination with any nonlinear operator -> skip
          if (!Ok(simpleBase.Evaluate(problemData))) continue;

          bool isSquare = exp == 2.0 || exp == -2.0;
          foreach (NonlinearOperator op in nonLinearFunctions) {
            // ignore cases where op has no effect:
            // abs(x^2) and abs(x^-2) duplicate the plain variant, and abs(x) == x for non-negative x
            if (op == NonlinearOperator.Abs && isSquare && plainVariantIncluded) continue;
            if (op == NonlinearOperator.Abs && min >= 0) continue;

            var nonsimpleBase = (SimpleBasisFunction)simpleBase.DeepCopy();
            nonsimpleBase.Operator = op;
            if (!Ok(nonsimpleBase.Evaluate(problemData))) continue;
            simpleBasisFunctions.Add(nonsimpleBase);
          }
        }
      }
      return simpleBasisFunctions;
    }

    /// <summary>
    /// Creates product bases from the given univariate bases: first the square of every
    /// base with exponent 1, then pairwise products in order of decreasing importance.
    /// Pair generation stops once the result holds twice as many entries as there are
    /// univariate bases.
    /// </summary>
    /// <param name="data">Problem data used for ranking and for validating products.</param>
    /// <param name="univariateBases">Candidate factors.</param>
    /// <returns>The generated product basis functions.</returns>
    public static IEnumerable<IBasisFunction> CreateMultivariateBases(IRegressionProblemData data, IList<ISimpleBasisFunction> univariateBases) {
      var orderedFuncs = OrderBasisFuncsByImportance(data, univariateBases).ToArray();
      var multivariateBases = new List<IBasisFunction>();
      int maxSize = 2 * orderedFuncs.Length;

      foreach (var bf in orderedFuncs) {
        // disallow bases with exponents
        if (bf.Exponent != 1) continue;
        multivariateBases.Add(new ProductBaseFunction(bf, bf, true));
      }

      for (int i = 0; i < orderedFuncs.Length; i++) {
        var b_i = orderedFuncs[i];
        for (int j = 0; j < i; j++) {
          var b_j = orderedFuncs[j];
          // disallow op() * op(); deemed too complex
          // (the check only looks at the second factor: it must be operator-free)
          if (b_j.Operator != NonlinearOperator.None) continue;

          var b_inter = new ProductBaseFunction(b_i, b_j, true);
          if (!Ok(b_inter.Evaluate(data))) continue;

          multivariateBases.Add(b_inter);
          if (multivariateBases.Count >= maxSize) return multivariateBases;
        }
      }
      return multivariateBases;
    }

    /// <summary>
    /// Orders basis functions by decreasing importance. The importance of a basis
    /// function is the absolute value of its coefficient in a linear fit on the
    /// normalized data.
    /// </summary>
    /// <param name="data">Problem data the fit is computed on.</param>
    /// <param name="candidateFunctions">Functions to rank.</param>
    /// <returns>The candidate functions, most important first (lazily evaluated).</returns>
    public static IEnumerable<ISimpleBasisFunction> OrderBasisFuncsByImportance(IRegressionProblemData data, IList<ISimpleBasisFunction> candidateFunctions) {
      var elnetData = PrepareData(Normalize(data, out _, out _, out _, out _), candidateFunctions);
      // penalty = 0 and lambda = 0 -> LS-fit; the error measures are not needed here
      var coeff = ElasticNetLinearRegression.CalculateModelCoefficients(elnetData, 0, 0, out var _, out var _);

      // the last entry is the intercept and is not associated with any basis function;
      // rank by magnitude so that strongly negative coefficients count as important too
      var importance = coeff.Take(coeff.Length - 1).Select(c => Math.Abs(c)).ToArray();

      // presumably Argsort yields ascending order; reversed to get decreasing importance
      var order = Utils.Argsort(importance);
      Array.Reverse(order);
      return order.Select(idx => candidateFunctions[idx]);
    }

    /// <summary>
    /// Creates hinge bases for every function in <paramref name="simple_bfs"/>.
    /// </summary>
    /// <param name="data">Problem data the functions are evaluated on.</param>
    /// <param name="simple_bfs">Functions for which hinge bases are generated.</param>
    /// <param name="relative_start_thr">Position of the first threshold, relative to the observed value range.</param>
    /// <param name="relative_end_thr">Position of the last threshold, relative to the observed value range.</param>
    /// <param name="num_thrs">Number of thresholds per function.</param>
    /// <param name="trainingPartition">Rows used to determine the value range; defaults to the training partition of <paramref name="data"/>.</param>
    /// <returns>All generated hinge bases.</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="relative_start_thr"/> is not smaller than <paramref name="relative_end_thr"/>
    /// (only raised when <paramref name="simple_bfs"/> is non-empty).
    /// </exception>
    public static IList<ISimpleBasisFunction> CreateHingeBases(IRegressionProblemData data, IEnumerable<ISimpleBasisFunction> simple_bfs, double relative_start_thr = 0.0, double relative_end_thr = 1.0, int num_thrs = 3, IntRange trainingPartition = null) {
      return simple_bfs
        .SelectMany(simple_bf => CreateHingeBases(data, simple_bf, relative_start_thr, relative_end_thr, num_thrs, trainingPartition))
        .ToList();
    }

    /// <summary>
    /// Creates a "greater than" and a "less than" hinge base on the feature of
    /// <paramref name="simple_bf"/> for each of <paramref name="num_thrs"/> thresholds
    /// spread between the relative start and end of the value range observed in the partition.
    /// </summary>
    /// <returns>
    /// The hinge bases; empty when the partition contains no rows or the function is
    /// constant on it.
    /// </returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="relative_start_thr"/> is not smaller than <paramref name="relative_end_thr"/>.
    /// </exception>
    private static IEnumerable<ISimpleBasisFunction> CreateHingeBases(IRegressionProblemData data, ISimpleBasisFunction simple_bf, double relative_start_thr, double relative_end_thr, int num_thrs, IntRange trainingPartition) {
      if (relative_start_thr >= relative_end_thr) throw new ArgumentException($"{nameof(relative_start_thr)} must be smaller than {nameof(relative_end_thr)}.");

      var ans = new List<ISimpleBasisFunction>();
      var vals = simple_bf.Evaluate(data);
      var temp = trainingPartition ?? data.TrainingPartition;

      double min = Double.MaxValue;
      double max = Double.MinValue;
      for (int i = temp.Start; i < temp.End; i++) {
        min = Math.Min(min, vals[i]);
        max = Math.Max(max, vals[i]);
      }

      // no usable range: constant values (max == min) or an empty partition
      // (min/max still hold their sentinels, i.e. max < min)
      if (max <= min) return ans;

      var full_range = max - min;
      var start_thr = min + relative_start_thr * full_range;
      var end_thr = min + relative_end_thr * full_range;
      var thresholds = Utils.Linspace(start_thr, end_thr, num_thrs);

      foreach (var thr in thresholds) {
        ans.Add(new SimpleBasisFunction(simple_bf.Feature, 1, NonlinearOperator.GT_Hinge, true, thr));
        ans.Add(new SimpleBasisFunction(simple_bf.Feature, 1, NonlinearOperator.LT_Hinge, true, thr));
      }
      return ans;
    }

    /// <summary>
    /// Returns a deep copy, with the <c>IsDenominator</c> flag cleared, of every basis
    /// function in <paramref name="basisFunctions"/> whose flag is set.
    /// </summary>
    /// <param name="basisFunctions">Functions to derive the copies from.</param>
    /// <returns>The flag-flipped copies; the input functions are not modified.</returns>
    public static IEnumerable<IBasisFunction> CreateDenominatorBases(IEnumerable<IBasisFunction> basisFunctions) {
      // NOTE(review): functions are selected when IsDenominator is true and the copies get
      // IsDenominator = false, and PrepareData appends the target variable to the name when
      // the flag is false. The flag may therefore mean the opposite of its name -
      // confirm against the IBasisFunction definition.
      var ans = new List<IBasisFunction>();
      foreach (var bf in basisFunctions) {
        if (!bf.IsDenominator) continue;
        var denomFunc = bf.DeepCopy();
        denomFunc.IsDenominator = false;
        ans.Add(denomFunc);
      }
      return ans;
    }

    /// <summary>
    /// Builds new regression problem data whose input variables are the evaluated
    /// basis functions and whose target is the unmodified target of
    /// <paramref name="problemData"/>. Training and test partition bounds are copied.
    /// </summary>
    /// <param name="problemData">Source of the rows, the target variable and the partitions.</param>
    /// <param name="basisFunctions">Functions that become the columns of the new dataset.</param>
    /// <returns>The problem data over the basis function values.</returns>
    public static IRegressionProblemData PrepareData(IRegressionProblemData problemData, IEnumerable<IBasisFunction> basisFunctions) {
      // materialize once; the sequence is needed for both sizing and filling the matrix
      var funcs = basisFunctions.ToList();
      int numRows = problemData.Dataset.Rows;
      int numCols = funcs.Count;
      var allowedInputVars = new HashSet<string>();
      double[,] variableValues = new double[numRows, numCols + 1]; // +1 for target var

      int col = 0;
      foreach (var basisFunc in funcs) {
        allowedInputVars.Add(basisFunc.ToString() + (!basisFunc.IsDenominator ? " * " + problemData.TargetVariable : ""));
        var vals = basisFunc.Evaluate(problemData);
        for (int i = 0; i < numRows; i++) {
          variableValues[i, col] = vals[i];
        }
        col++;
      }

      // add the unmodified target variable to the dataset (last column)
      var allVariables = new HashSet<string>(allowedInputVars);
      allVariables.Add(problemData.TargetVariable);
      var targetVals = problemData.TargetVariableValues.ToArray();
      for (int i = 0; i < numRows; i++) {
        variableValues[i, col] = targetVals[i];
      }

      var temp = new Dataset(allVariables, variableValues);
      IRegressionProblemData rpd = new RegressionProblemData(temp, allowedInputVars, problemData.TargetVariable);
      rpd.TrainingPartition.Start = problemData.TrainingPartition.Start;
      rpd.TrainingPartition.End = problemData.TrainingPartition.End;
      rpd.TestPartition.Start = problemData.TestPartition.Start;
      rpd.TestPartition.End = problemData.TestPartition.End;
      return rpd;
    }

    /// <summary>
    /// Creates problem data over a normalized copy of the dataset of <paramref name="data"/>
    /// and reports the averages and population standard deviations of the input
    /// variables and the target. Standard deviations of 0 are reported as 1.
    /// The partition bounds of <paramref name="data"/> are not copied to the result.
    /// </summary>
    /// <param name="data">Problem data to normalize.</param>
    /// <param name="X_avgs">Average of each allowed input variable.</param>
    /// <param name="X_stds">Population standard deviation of each allowed input variable (0 replaced by 1).</param>
    /// <param name="y_avg">Average of the target variable.</param>
    /// <param name="y_std">Population standard deviation of the target variable (0 replaced by 1).</param>
    /// <returns>Problem data with the same input and target variables over the normalized dataset.</returns>
    public static IRegressionProblemData Normalize(IRegressionProblemData data, out double[] X_avgs, out double[] X_stds, out double y_avg, out double y_std) {
      X_avgs = data.AllowedInputVariables
        .Select(varname => data.Dataset.GetDoubleValues(varname).Average())
        .ToArray();
      X_stds = data.AllowedInputVariables
        .Select(varname => data.Dataset.GetDoubleValues(varname).StandardDeviationPop())
        .ToArray();
      for (int i = 0; i < X_stds.Length; i++) {
        if (X_stds[i] == 0) X_stds[i] = 1;
      }

      y_avg = data.TargetVariableValues.Average();
      y_std = data.TargetVariableValues.StandardDeviationPop();
      if (y_std == 0) y_std = 1;

      // normalize the dataset once and reuse it
      var normalizedDataset = Normalize(data.Dataset);
      return new RegressionProblemData(normalizedDataset, data.AllowedInputVariables, data.TargetVariable);
    }

    /// <summary>
    /// Returns a normalized version of <paramref name="ds"/>; every variable is passed
    /// through <c>Utils.Normalize</c>.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="ds"/> contains variables that are not double variables.</exception>
    private static IDataset Normalize(IDataset ds) {
      var doubleNames = ds.DoubleVariables.ToArray();
      if (ds.VariableNames.Count() != doubleNames.Length) throw new ArgumentException(nameof(ds));

      var variableVals = new List<List<double>>();
      foreach (var name in doubleNames) {
        var vals = Utils.Normalize(ds.GetDoubleValues(name).ToArray());
        variableVals.Add(vals.ToList());
      }
      return new Dataset(doubleNames, variableVals);
    }

    /// <summary>
    /// Returns true when <paramref name="data"/> contains neither NaN nor infinite values.
    /// </summary>
    private static bool Ok(IEnumerable<double> data) => data.All(x => !double.IsNaN(x) && !double.IsInfinity(x));
  }
}