using System;
using System.Collections.Generic;
using System.Linq;
using HEAL.Attic;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Operators;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Reinitialization strategy controller that adapts the <c>InitialFrequency</c> of the grammar's
  /// allowed symbols from the node impacts observed in the current symbolic regression trees.
  /// </summary>
  /// <remarks>
  /// For every tree, the impact of each node that has at least one subtree is obtained from
  /// <c>SymbolicRegressionSolutionImpactValuesCalculator</c> on the training rows. Finite impacts are
  /// averaged per symbol name, the mean is clamped at zero, and the symbol's initial frequency is moved
  /// towards it by <see cref="LearningRate"/>. The result never drops below <see cref="MinimumFrequency"/>.
  /// Symbols that do not occur in any tree keep their current frequency.
  /// </remarks>
  [Item("NodeImpactsReinitializationStrategyController", "")]
  [StorableType("C015C8C3-283D-4F51-A582-367587596709")]
  public class NodeImpactsReinitializationStrategyController : InstrumentedOperator, IReinitializationStrategyController {
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string SymbolicExpressionTreeGrammarParameterName = "SymbolicExpressionTreeGrammar";
    private const string SymbolicDataAnalysisTreeInterpreterParameterName = "SymbolicDataAnalysisTreeInterpreter";
    private const string ProblemDataParameterName = "ProblemData";
    private const string EstimationLimitsParameterName = "EstimationLimits";
    private const string MinimumFrequencyParameterName = "MinimumFrequency";
    private const string LearningRateParameterName = "LearningRate";

    // Shared by the constructor and AfterDeserialization so both create identical parameters.
    private const string MinimumFrequencyParameterDescription = "Minimum Frequency for the controller to set to symbols.";
    private const string LearningRateParameterDescription = "Learning Rate for how fast the frequency is adapted. Zero learning rate means no adaption, one means new frequency = symbol frequency.";
    private const double DefaultMinimumFrequency = 0.0;
    private const double DefaultLearningRate = 0.1;

    #region Parameter Properties
    public IScopeTreeLookupParameter<ISymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (IScopeTreeLookupParameter<ISymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public IValueLookupParameter<ISymbolicExpressionGrammar> SymbolicExpressionTreeGrammarParameter {
      get { return (IValueLookupParameter<ISymbolicExpressionGrammar>)Parameters[SymbolicExpressionTreeGrammarParameterName]; }
    }
    public ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter> SymbolicDataAnalysisTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>)Parameters[SymbolicDataAnalysisTreeInterpreterParameterName]; }
    }
    public ILookupParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (ILookupParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IValueLookupParameter<DoubleLimit> EstimationLimitsParameter {
      get { return (IValueLookupParameter<DoubleLimit>)Parameters[EstimationLimitsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> MinimumFrequencyParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MinimumFrequencyParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> LearningRateParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[LearningRateParameterName]; }
    }
    #endregion

    #region Properties
    /// <summary>Lower bound applied to every symbol frequency this controller writes.</summary>
    public double MinimumFrequency {
      get { return MinimumFrequencyParameter.Value.Value; }
      set { MinimumFrequencyParameter.Value.Value = value; }
    }

    /// <summary>
    /// Step size of the frequency update: 0 keeps the old frequency, 1 replaces it with the
    /// (clamped) mean impact. Values outside [0, 1] are not rejected.
    /// </summary>
    public double LearningRate {
      get { return LearningRateParameter.Value.Value; }
      set { LearningRateParameter.Value.Value = value; }
    }
    #endregion

    #region Constructors
    public NodeImpactsReinitializationStrategyController() {
      Parameters.Add(new ScopeTreeLookupParameter<ISymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees whose node impacts are used to adapt the symbol frequencies."));
      Parameters.Add(new ValueLookupParameter<ISymbolicExpressionGrammar>(SymbolicExpressionTreeGrammarParameterName, "The tree grammar that defines the correct syntax of symbolic expression trees that should be created."));
      Parameters.Add(new LookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>(SymbolicDataAnalysisTreeInterpreterParameterName, "The symbolic data analysis tree interpreter for the symbolic expression tree."));
      Parameters.Add(new LookupParameter<IRegressionProblemData>(ProblemDataParameterName, "The problem data for the symbolic regression solution."));
      Parameters.Add(new ValueLookupParameter<DoubleLimit>(EstimationLimitsParameterName, "The lower and upper limit for the estimated values produced by the symbolic regression model."));
      Parameters.Add(CreateMinimumFrequencyParameter());
      Parameters.Add(CreateLearningRateParameter());
    }

    private NodeImpactsReinitializationStrategyController(NodeImpactsReinitializationStrategyController original, Cloner cloner)
      : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NodeImpactsReinitializationStrategyController(this, cloner);
    }

    [StorableConstructor]
    private NodeImpactsReinitializationStrategyController(StorableConstructorFlag _) : base(_) { }

    // Older persisted instances may predate the frequency parameters; add them with their defaults.
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (!Parameters.ContainsKey(MinimumFrequencyParameterName))
        Parameters.Add(CreateMinimumFrequencyParameter());
      if (!Parameters.ContainsKey(LearningRateParameterName))
        Parameters.Add(CreateLearningRateParameter());
    }

    // Creates the MinimumFrequency parameter with its default value.
    private static FixedValueParameter<DoubleValue> CreateMinimumFrequencyParameter() {
      return new FixedValueParameter<DoubleValue>(MinimumFrequencyParameterName, MinimumFrequencyParameterDescription, new DoubleValue(DefaultMinimumFrequency));
    }

    // Creates the LearningRate parameter with its default value.
    private static FixedValueParameter<DoubleValue> CreateLearningRateParameter() {
      return new FixedValueParameter<DoubleValue>(LearningRateParameterName, LearningRateParameterDescription, new DoubleValue(DefaultLearningRate));
    }
    #endregion

    /// <summary>
    /// Computes per-symbol mean node impacts over all trees in scope and moves each allowed
    /// grammar symbol's <c>InitialFrequency</c> towards its (non-negative) mean impact.
    /// </summary>
    /// <returns>The operation returned by <c>base.InstrumentedApply()</c>.</returns>
    /// <exception cref="InvalidOperationException">
    /// The grammar, the problem data or the estimation limits cannot be resolved.
    /// </exception>
    public override IOperation InstrumentedApply() {
      var trees = SymbolicExpressionTreeParameter.ActualValue;
      var grammar = SymbolicExpressionTreeGrammarParameter.ActualValue;
      var problemData = ProblemDataParameter.ActualValue;
      var estimationLimits = EstimationLimitsParameter.ActualValue;
      // Fall back to the batch interpreter when none is available in the scope.
      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue ?? new SymbolicDataAnalysisExpressionTreeBatchInterpreter();
      double minimumFrequency = MinimumFrequency;
      double learningRate = LearningRate;

      if (grammar == null)
        throw new InvalidOperationException("NodeImpactsReinitializationStrategyController: no symbolic expression tree grammar is available.");
      if (problemData == null)
        throw new InvalidOperationException("NodeImpactsReinitializationStrategyController: no problem data is available.");
      if (estimationLimits == null)
        throw new InvalidOperationException("NodeImpactsReinitializationStrategyController: no estimation limits are available.");

      var impactSums = new Dictionary<string, double>();
      var impactCounts = new Dictionary<string, int>();
      var impactValuesCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
      var trainingRows = problemData.TrainingIndices;

      foreach (var tree in trees) {
        var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, interpreter, estimationLimits.Lower, estimationLimits.Upper);

        // Skip the ProgramRoot/Start wrapper nodes so only nodes of the actual expression are rated.
        // The SubtreeCount check avoids GetSubtree(0) throwing on a wrapper node without children.
        var root = tree.Root;
        while ((root.Symbol is ProgramRootSymbol || root.Symbol is StartSymbol) && root.SubtreeCount > 0) {
          root = root.GetSubtree(0);
        }

        // Only nodes with subtrees (function nodes) are rated; terminals are ignored.
        foreach (var node in root.IterateNodesPrefix().Where(n => n.SubtreeCount > 0)) {
          impactValuesCalculator.CalculateImpactAndReplacementValues(model, node, problemData, trainingRows,
            out double impactValue, out _, out _);

          // A NaN/infinite impact would flow through the mean into the symbol's InitialFrequency
          // (Math.Max propagates NaN), so such values are left out of the average.
          if (double.IsNaN(impactValue) || double.IsInfinity(impactValue)) continue;

          var name = node.Symbol.Name;
          if (impactCounts.TryGetValue(name, out int count)) {
            impactCounts[name] = count + 1;
            impactSums[name] += impactValue;
          } else {
            impactCounts[name] = 1;
            impactSums[name] = impactValue;
          }
        }
      }

      foreach (var symbol in grammar.AllowedSymbols) {
        // Symbols without any rated node keep their current frequency.
        if (!impactSums.TryGetValue(symbol.Name, out double impactSum)) continue;

        // Target frequency: the symbol's mean impact, with negative means clamped to zero.
        // NOTE(review): impact values are used directly as frequencies; confirm the scales are compatible.
        double targetFrequency = Math.Max(0, impactSum / impactCounts[symbol.Name]);
        double oldFrequency = symbol.InitialFrequency;
        // Linear step towards the target: learningRate 0 keeps the old value, 1 adopts the target.
        double newFrequency = oldFrequency + (targetFrequency - oldFrequency) * learningRate;
        symbol.InitialFrequency = Math.Max(minimumFrequency, newFrequency);
      }

      return base.InstrumentedApply();
    }
  }
}