#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

#if INCLUDE_DIFFSHARP

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.Serialization;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using DiffSharp.Interop.Float64;
using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Converts a symbolic expression tree into a DiffSharp expression. Each node evaluates either to a
  /// scalar (<see cref="D"/>) or to a vector (<see cref="DV"/>). Numeric constants, variable weights
  /// (optionally) and linear-scaling terms (optionally) are read from a DiffSharp parameter vector, so the
  /// result is expressed in terms of that vector's entries.
  /// </summary>
  public class TreeToDiffSharpConverter {
    /// <summary>
    /// Signature of a function of scalar variables, scalar parameters and vector parameters.
    /// NOTE(review): not used inside this class — presumably consumed by callers; confirm before removing.
    /// </summary>
    public delegate D ParametricFunction(D[] vars, D[] @params, DV[] vectorParams);

    #region helper class
    /// <summary>
    /// Value-equality description of a variable: its name, its value (factor variables only) and its lag.
    /// </summary>
    public class DataForVariable {
      public readonly string variableName;
      public readonly string variableValue; // for factor vars
      public readonly int lag;

      public DataForVariable(string varName, string varValue, int lag) {
        this.variableName = varName;
        this.variableValue = varValue;
        this.lag = lag;
      }

      public override bool Equals(object obj) {
        // The static string.Equals is ordinal (like the instance overload) but does not throw when a field is null,
        // e.g. when variableValue is unset for a non-factor variable.
        return obj is DataForVariable other
               && string.Equals(other.variableName, variableName)
               && string.Equals(other.variableValue, variableValue)
               && other.lag == lag;
      }

      public override int GetHashCode() {
        return (variableName?.GetHashCode() ?? 0) ^ (variableValue?.GetHashCode() ?? 0) ^ lag;
      }
    }

    /// <summary>
    /// Result of converting one node: a scalar, a vector, or <see cref="NaN"/> (neither). NaN marks an
    /// operation that is not defined for the given operand kinds; it propagates through all further operations.
    /// </summary>
    public class EvaluationResult {
      public D Scalar { get; }
      public bool IsScalar => !ReferenceEquals(Scalar, NanScalar);

      public DV Vector { get; }
      public bool IsVector => !ReferenceEquals(Vector, NaNVector);

      public bool IsNaN => !IsScalar && !IsVector;

      public EvaluationResult(D scalar) {
        // 'is null' avoids dispatching to an equality operator that DiffSharp types may overload.
        if (scalar is null) throw new ArgumentNullException(nameof(scalar));
        Scalar = scalar;
        Vector = NaNVector;
      }

      public EvaluationResult(DV vector) {
        if (vector is null) throw new ArgumentNullException(nameof(vector));
        Scalar = NanScalar;
        Vector = vector;
      }

      private EvaluationResult() {
        Scalar = NanScalar;
        Vector = NaNVector;
      }

      // Sentinels for the unused slot, compared by reference only. They are declared before NaN because
      // static field initializers run in textual order and the NaN instance reads both of them.
      private static readonly DV NaNVector = new DV(new[] { double.NaN });
      private static readonly D NanScalar = new D(double.NaN);
      public static readonly EvaluationResult NaN = new EvaluationResult();
    }
    #endregion

    /// <summary>
    /// Converts <paramref name="tree"/> and returns its (scalar) output as a DiffSharp value.
    /// </summary>
    /// <param name="tree">The tree to convert; conversion starts at the root's first subtree.</param>
    /// <param name="makeVariableWeightsVariable">
    /// If true, variable weights are read from <paramref name="variables"/>; otherwise they are constants.</param>
    /// <param name="addLinearScalingTerms">
    /// If true, the start symbol's output t is returned as <c>t * alpha + beta</c>, with beta and alpha read from
    /// <paramref name="variables"/>.</param>
    /// <param name="variables">
    /// Parameter vector consumed in node-visit order (depth-first, left to right) by linear-scaling terms,
    /// constants and — if enabled — variable weights. <see cref="GetInitialConstants"/> returns the initial
    /// values in the same order.</param>
    /// <param name="scalarParameters">Scalar values of variables, keyed by variable name.</param>
    /// <param name="vectorsParameters">Vector values of variables, keyed by variable name.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> is null.</exception>
    /// <exception cref="ConversionException">The tree contains an unsupported symbol or an unknown variable.</exception>
    /// <exception cref="InvalidOperationException">The tree does not evaluate to a scalar.</exception>
    public static D Evaluate(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms, DV variables,
      IDictionary<string, D> scalarParameters, IDictionary<string, DV> vectorsParameters) {
      if (tree is null) throw new ArgumentNullException(nameof(tree));

      var transformator = new TreeToDiffSharpConverter(
        variables, scalarParameters, vectorsParameters,
        makeVariableWeightsVariable, addLinearScalingTerms);
      var result = transformator.ConvertNode(tree.Root.GetSubtree(0));
      if (!result.IsScalar)
        throw new InvalidOperationException("Result of evaluation is not a scalar.");
      return result.Scalar;
    }

    /// <summary>
    /// Returns the initial values of all parameters (linear-scaling terms, constants and — if enabled — variable
    /// weights) in the order in which <see cref="Evaluate"/> reads them from its parameter vector.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> is null.</exception>
    /// <exception cref="ConversionException">The tree contains an unsupported symbol or an unknown variable.</exception>
    public static List<double> GetInitialConstants(ISymbolicExpressionTree tree,
      bool makeVariableWeightsVariable, bool addLinearScalingTerms,
      IDictionary<string, D> scalarParameters, IDictionary<string, DV> vectorsParameters) {
      if (tree is null) throw new ArgumentNullException(nameof(tree));

      // No parameter vector: every parameter falls back to its initial value while being recorded.
      var transformator = new TreeToDiffSharpConverter(
        null, scalarParameters, vectorsParameters,
        makeVariableWeightsVariable, addLinearScalingTerms);
      transformator.ConvertNode(tree.Root.GetSubtree(0));
      return transformator.initialConstants;
    }

    private readonly IDictionary<string, D> scalarParameters;
    private readonly IDictionary<string, DV> vectorsParameters;
    private readonly bool makeVariableWeightsVariable;
    private readonly bool addLinearScalingTerms;

    // Initial value of every parameter, in the order parameters are encountered during conversion.
    private readonly List<double> initialConstants;
    // Optional parameter vector; when null, initial values are used instead.
    private readonly DV variables;
    // Index of the next entry of 'variables' to consume.
    private int variableIdx;

    private TreeToDiffSharpConverter(DV variables,
      IDictionary<string, D> scalarParameters, IDictionary<string, DV> vectorsParameters,
      bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
      this.scalarParameters = scalarParameters;
      this.vectorsParameters = vectorsParameters;
      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
      this.addLinearScalingTerms = addLinearScalingTerms;
      initialConstants = new List<double>();
      this.variables = variables;
      variableIdx = 0;
    }

    #region Evaluation helpers
    /// <summary>
    /// Applies a binary operation, choosing the overload that matches the operand kinds.
    /// Returns <see cref="EvaluationResult.NaN"/> when no matching overload is given (including NaN operands).
    /// </summary>
    private static EvaluationResult ArithmeticApply(EvaluationResult lhs, EvaluationResult rhs,
      Func<D, D, D> ssFunc = null,
      Func<D, DV, DV> svFunc = null,
      Func<DV, D, DV> vsFunc = null,
      Func<DV, DV, DV> vvFunc = null) {
      if (lhs.IsScalar && rhs.IsScalar && ssFunc != null) return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
      if (lhs.IsScalar && rhs.IsVector && svFunc != null) return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
      if (lhs.IsVector && rhs.IsScalar && vsFunc != null) return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
      // NOTE(review): vectors of different lengths are passed to vvFunc unchanged; no length alignment is done here.
      if (lhs.IsVector && rhs.IsVector && vvFunc != null) return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
      return EvaluationResult.NaN;
    }

    /// <summary>
    /// Applies an element-wise unary function (scalar to scalar, vector to vector).
    /// Returns <see cref="EvaluationResult.NaN"/> when no matching function is given.
    /// </summary>
    private static EvaluationResult FunctionApply(EvaluationResult val,
      Func<D, D> sFunc = null,
      Func<DV, DV> vFunc = null) {
      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
      return EvaluationResult.NaN;
    }

    /// <summary>
    /// Applies an aggregation that always yields a scalar (scalar to scalar, vector to scalar).
    /// Returns <see cref="EvaluationResult.NaN"/> when no matching function is given.
    /// </summary>
    private static EvaluationResult AggregateApply(EvaluationResult val,
      Func<D, D> sFunc = null,
      Func<DV, D> vFunc = null) {
      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
      return EvaluationResult.NaN;
    }

    /// <summary>
    /// Records <paramref name="initialValue"/> as the next parameter and returns that parameter's value: the next
    /// entry of the parameter vector if one was supplied, otherwise <paramref name="initialValue"/> itself.
    /// </summary>
    private D NextParameter(double initialValue) {
      initialConstants.Add(initialValue);
      // When 'variables' is null, the null-conditional skips the index expression, so variableIdx is not advanced.
      return variables?[variableIdx++] ?? initialValue;
    }

    /// <summary>
    /// Weight of a variable node: a parameter if <see cref="makeVariableWeightsVariable"/> is set, else a constant.
    /// </summary>
    private D GetVariableWeight(VariableTreeNodeBase varNode) {
      if (!makeVariableWeightsVariable) return varNode.Weight;
      return NextParameter(varNode.Weight);
    }
    #endregion

    /// <summary>
    /// Recursively converts <paramref name="node"/> and its subtrees.
    /// </summary>
    /// <exception cref="ConversionException">The symbol is not supported or a variable name is unknown.</exception>
    private EvaluationResult ConvertNode(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is Constant) {
        // Constants are assumed to be scalars.
        return new EvaluationResult(NextParameter(((ConstantTreeNode)node).Value));
      }
      if (node.Symbol is Variable) {
        var varNode = (VariableTreeNodeBase)node;
        // Scalar parameters take precedence if a name appears in both dictionaries.
        if (scalarParameters.TryGetValue(varNode.VariableName, out var scalarPar))
          return new EvaluationResult(GetVariableWeight(varNode) * scalarPar);
        if (vectorsParameters.TryGetValue(varNode.VariableName, out var vectorPar))
          return new EvaluationResult(GetVariableWeight(varNode) * vectorPar);
        throw new ConversionException($"Variable '{varNode.VariableName}' has neither a scalar nor a vector value.");
      }

      if (node.Symbol is Addition) {
        var terms = node.Subtrees.Select(ConvertNode).ToList();
        return terms.Aggregate((a, b) => ArithmeticApply(a, b,
          (s1, s2) => s1 + s2,
          (s1, v2) => s1 + v2,
          (v1, s2) => v1 + s2,
          (v1, v2) => v1 + v2));
      }
      if (node.Symbol is Subtraction) {
        var terms = node.Subtrees.Select(ConvertNode).ToList();
        // A single operand means unary negation.
        if (terms.Count == 1) return FunctionApply(terms[0], s => D.Neg(s), v => DV.Neg(v));
        return terms.Aggregate((a, b) => ArithmeticApply(a, b,
          (s1, s2) => s1 - s2,
          (s1, v2) => s1 - v2,
          (v1, s2) => v1 - s2,
          (v1, v2) => v1 - v2));
      }
      if (node.Symbol is Multiplication) {
        var terms = node.Subtrees.Select(ConvertNode).ToList();
        return terms.Aggregate((a, b) => ArithmeticApply(a, b,
          (s1, s2) => s1 * s2,
          (s1, v2) => s1 * v2,
          (v1, s2) => v1 * s2,
          (v1, v2) => DV.PointwiseMultiply(v1, v2)));
      }
      if (node.Symbol is Division) {
        var terms = node.Subtrees.Select(ConvertNode).ToList();
        // A single operand means the reciprocal.
        if (terms.Count == 1) return FunctionApply(terms[0], s => 1.0 / s, v => 1.0 / v);
        return terms.Aggregate((a, b) => ArithmeticApply(a, b,
          (s1, s2) => s1 / s2,
          (s1, v2) => s1 / v2,
          (v1, s2) => v1 / s2,
          (v1, v2) => DV.PointwiseDivision(v1, v2)));
      }

      if (node.Symbol is Absolute) {
        return FunctionApply(ConvertNode(node.GetSubtree(0)), s => D.Abs(s), v => DV.Abs(v));
      }
      if (node.Symbol is Logarithm) {
        return FunctionApply(ConvertNode(node.GetSubtree(0)), s => D.Log(s), v => DV.Log(v));
      }
      if (node.Symbol is Exponential) {
        return FunctionApply(ConvertNode(node.GetSubtree(0)), s => D.Pow(Math.E, s), v => DV.Pow(Math.E, v));
      }
      if (node.Symbol is Square) {
        return FunctionApply(ConvertNode(node.GetSubtree(0)), s => D.Pow(s, 2), v => DV.Pow(v, 2));
      }
      if (node.Symbol is SquareRoot) {
        return FunctionApply(ConvertNode(node.GetSubtree(0)), s => D.Sqrt(s), v => DV.Sqrt(v));
      }
      if (node.Symbol is Cube) {
        return FunctionApply(ConvertNode(node.GetSubtree(0)), s => D.Pow(s, 3), v => DV.Pow(v, 3));
      }
      if (node.Symbol is CubeRoot) {
        // sign(x) * |x|^(1/3) keeps the real cube root defined for negative inputs.
        return FunctionApply(ConvertNode(node.GetSubtree(0)),
          s => D.Sign(s) * D.Pow(D.Abs(s), 1.0 / 3.0),
          v => DV.PointwiseMultiply(DV.Sign(v), DV.Pow(DV.Abs(v), 1.0 / 3.0)));
      }
      if (node.Symbol is Sine) {
        return FunctionApply(ConvertNode(node.GetSubtree(0)), s => D.Sin(s), v => DV.Sin(v));
      }
      if (node.Symbol is Cosine) {
        return FunctionApply(ConvertNode(node.GetSubtree(0)), s => D.Cos(s), v => DV.Cos(v));
      }
      if (node.Symbol is Tangent) {
        return FunctionApply(ConvertNode(node.GetSubtree(0)), s => D.Tan(s), v => DV.Tan(v));
      }
      if (node.Symbol is HyperbolicTangent) {
        return FunctionApply(ConvertNode(node.GetSubtree(0)), s => D.Tanh(s), v => DV.Tanh(v));
      }

      // Aggregations: a scalar operand is treated as a single-element vector.
      if (node.Symbol is Sum) {
        return AggregateApply(ConvertNode(node.GetSubtree(0)), s => s, v => DV.Sum(v));
      }
      if (node.Symbol is Mean) {
        return AggregateApply(ConvertNode(node.GetSubtree(0)), s => s, v => DV.Mean(v));
      }
      if (node.Symbol is StandardDeviation) {
        return AggregateApply(ConvertNode(node.GetSubtree(0)),
          s => 0.0,
          v => DV.StandardDev(v)); // TODO: use pop-stdev instead (DV.StandardDev is presumably the sample stdev — confirm)
      }
      if (node.Symbol is Length) {
        // NOTE(review): Sum / Mean equals the length only when the mean is non-zero; otherwise the result is NaN or ±∞.
        // Length is excluded by IsCompatible until a proper length is available (TODO: no length?).
        return AggregateApply(ConvertNode(node.GetSubtree(0)),
          s => 1,
          v => DV.Sum(v) / DV.Mean(v));
      }
      if (node.Symbol is Min) {
        return AggregateApply(ConvertNode(node.GetSubtree(0)), s => s, v => DV.Min(v));
      }
      if (node.Symbol is Max) {
        return AggregateApply(ConvertNode(node.GetSubtree(0)), s => s, v => DV.Max(v));
      }
      if (node.Symbol is Variance) {
        // The variance of a single value is 0 (consistent with StandardDeviation above).
        return AggregateApply(ConvertNode(node.GetSubtree(0)), s => 0.0, v => DV.Variance(v));
      }

      if (node.Symbol is StartSymbol) {
        if (addLinearScalingTerms) {
          // Scaling terms are consumed before the subtree, so they come first in the parameter vector
          // when conversion starts at the start symbol. Order: beta (offset), then alpha (factor).
          var beta = NextParameter(0.0);
          var alpha = NextParameter(1.0);
          var t = ConvertNode(node.GetSubtree(0));
          if (!t.IsScalar) throw new InvalidOperationException("Must be a scalar result");
          return new EvaluationResult(t.Scalar * alpha + beta);
        }
        return ConvertNode(node.GetSubtree(0));
      }

      throw new ConversionException();
    }

    /// <summary>
    /// Returns true if every node below the tree's root uses a symbol type that the converter supports.
    /// </summary>
    /// <remarks>
    /// Only symbol types are checked: a <see cref="Variable"/> whose name is missing from both parameter
    /// dictionaries still causes a <see cref="ConversionException"/> during conversion.
    /// </remarks>
    public static bool IsCompatible(ISymbolicExpressionTree tree) {
      return tree.Root.GetSubtree(0).IterateNodesPrefix().All(n => IsSupportedSymbol(n.Symbol));
    }

    /// <summary>Whitelist of symbol types accepted by <see cref="IsCompatible"/>.</summary>
    private static bool IsSupportedSymbol(ISymbol symbol) {
      // Not (yet) supported: BinaryFactorVariable, FactorVariable, LaggedVariable, Erf, Norm, AnalyticQuotient,
      // Length (see the note in ConvertNode), Skewness, Kurtosis, EuclideanDistance, Covariance.
      return symbol is Variable
             || symbol is Constant
             || symbol is Addition
             || symbol is Subtraction
             || symbol is Multiplication
             || symbol is Division
             || symbol is Logarithm
             || symbol is Exponential
             || symbol is SquareRoot
             || symbol is Square
             || symbol is Sine
             || symbol is Cosine
             || symbol is Tangent
             || symbol is HyperbolicTangent
             || symbol is StartSymbol
             || symbol is Absolute
             || symbol is Cube
             || symbol is CubeRoot
             || symbol is Sum
             || symbol is Mean
             || symbol is StandardDeviation
             || symbol is Min
             || symbol is Max
             || symbol is Variance;
    }

    #region exception class
    /// <summary>Thrown when a tree contains a node that cannot be converted.</summary>
    [Serializable]
    public class ConversionException : Exception {
      public ConversionException() {
      }

      public ConversionException(string message) : base(message) {
      }

      public ConversionException(string message, Exception inner) : base(message, inner) {
      }

      protected ConversionException(
        SerializationInfo info,
        StreamingContext context) : base(info, context) {
      }
    }
    #endregion
  }
}

#endif