#region License Information
/* HeuristicLab
* Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
#if INCLUDE_DIFFSHARP
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.Serialization;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using DiffSharp.Interop.Float64;
using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector;
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
public class TreeToDiffSharpConverter {
/// <summary>
/// Signature of a function that takes an array of scalar variables, an array of scalar
/// parameters and an array of vector parameters and returns a single scalar <c>D</c>.
/// </summary>
/// <param name="vars">Scalar values passed as variables.</param>
/// <param name="params">Scalar values passed as parameters.</param>
/// <param name="vectorParams">Vector values passed as parameters.</param>
/// <returns>The scalar function value.</returns>
// NOTE(review): this delegate is not referenced anywhere in the visible part of the file.
public delegate D ParametricFunction(D[] vars, D[] @params, DV[] vectorParams);
#region helper class
/// <summary>
/// Immutable description of a variable reference, consisting of the variable name,
/// a variable value (used for factor variables) and a lag.
/// Two instances are equal when all three components are equal; name and value are
/// compared ordinally and either of them may be <c>null</c>.
/// </summary>
public class DataForVariable {
  public readonly string variableName;
  public readonly string variableValue; // for factor vars
  public readonly int lag;

  /// <summary>Creates a new description from its three components.</summary>
  /// <param name="varName">Name of the variable.</param>
  /// <param name="varValue">Value of the variable (for factor variables); may be <c>null</c>.</param>
  /// <param name="lag">Lag of the variable reference.</param>
  public DataForVariable(string varName, string varValue, int lag) {
    this.variableName = varName;
    this.variableValue = varValue;
    this.lag = lag;
  }

  /// <summary>Value equality over name, value and lag.</summary>
  /// <returns><c>true</c> if <paramref name="obj"/> is a <see cref="DataForVariable"/> with equal components.</returns>
  public override bool Equals(object obj) {
    if (obj is DataForVariable other) {
      // static string.Equals is ordinal (like the instance method) but tolerates null on both sides
      return string.Equals(other.variableName, variableName) &&
             string.Equals(other.variableValue, variableValue) &&
             other.lag == lag;
    }
    return false;
  }

  /// <summary>Hash code consistent with <see cref="Equals(object)"/>; null strings contribute 0.</summary>
  public override int GetHashCode() {
    var nameHash = variableName?.GetHashCode() ?? 0;
    var valueHash = variableValue?.GetHashCode() ?? 0;
    return nameHash ^ valueHash ^ lag;
  }
}
/// <summary>
/// Result of converting a tree node: either a scalar (<c>D</c>), a vector (<c>DV</c>),
/// or neither (see <see cref="NaN"/>). The unused slot always holds a private sentinel
/// instance, and the kind of result is detected by reference comparison with that sentinel.
/// </summary>
public class EvaluationResult {
  // Sentinels marking "no vector" / "no scalar". They are declared before NaN because
  // static field initializers run in textual order and NaN's constructor reads both.
  private static readonly DV NaNVector = new DV(new[] { double.NaN });
  private static readonly D NanScalar = new D(double.NaN);

  /// <summary>The shared result that is neither a scalar nor a vector.</summary>
  public static readonly EvaluationResult NaN = new EvaluationResult();

  /// <summary>The scalar value, or the scalar sentinel if this result is not a scalar.</summary>
  public D Scalar { get; }

  /// <summary>The vector value, or the vector sentinel if this result is not a vector.</summary>
  public DV Vector { get; }

  /// <summary><c>true</c> if <see cref="Scalar"/> is not the scalar sentinel.</summary>
  public bool IsScalar {
    get { return !ReferenceEquals(Scalar, NanScalar); }
  }

  /// <summary><c>true</c> if <see cref="Vector"/> is not the vector sentinel.</summary>
  public bool IsVector {
    get { return !ReferenceEquals(Vector, NaNVector); }
  }

  /// <summary><c>true</c> if this result is neither a scalar nor a vector.</summary>
  public bool IsNaN {
    get { return !(IsScalar || IsVector); }
  }

  /// <summary>Wraps a scalar value.</summary>
  /// <exception cref="ArgumentNullException"><paramref name="scalar"/> is <c>null</c>.</exception>
  public EvaluationResult(D scalar) {
    if (scalar == null) {
      throw new ArgumentNullException(nameof(scalar));
    }
    Vector = NaNVector;
    Scalar = scalar;
  }

  /// <summary>Wraps a vector value.</summary>
  /// <exception cref="ArgumentNullException"><paramref name="vector"/> is <c>null</c>.</exception>
  public EvaluationResult(DV vector) {
    if (vector == null) {
      throw new ArgumentNullException(nameof(vector));
    }
    Scalar = NanScalar;
    Vector = vector;
  }

  // Only used to build the shared NaN instance: both slots hold their sentinel.
  private EvaluationResult() {
    Vector = NaNVector;
    Scalar = NanScalar;
  }
}
#endregion
/// <summary>
/// Converts the first subtree below the root of <paramref name="tree"/> and returns
/// the resulting scalar.
/// </summary>
/// <param name="tree">Tree to evaluate.</param>
/// <param name="makeVariableWeightsVariable">If <c>true</c>, variable weights are taken from <paramref name="variables"/> instead of the tree nodes.</param>
/// <param name="addLinearScalingTerms">If <c>true</c>, an offset and a scale factor are applied to the result of the start symbol.</param>
/// <param name="variables">Values that replace the constants (and optionally weights) of the tree in visiting order.</param>
/// <param name="scalarParameters">Scalar values, looked up by variable name.</param>
/// <param name="vectorsParameters">Vector values, looked up by variable name.</param>
/// <returns>The scalar value of the converted tree.</returns>
/// <exception cref="InvalidOperationException">The tree does not evaluate to a scalar.</exception>
public static D Evaluate(ISymbolicExpressionTree tree,
  bool makeVariableWeightsVariable, bool addLinearScalingTerms,
  DV variables,
  IDictionary scalarParameters, IDictionary vectorsParameters) {
  var converter = new TreeToDiffSharpConverter(
    variables,
    scalarParameters, vectorsParameters,
    makeVariableWeightsVariable, addLinearScalingTerms);
  var startNode = tree.Root.GetSubtree(0);
  var evaluated = converter.ConvertNode(startNode);
  if (evaluated.IsScalar) {
    return evaluated.Scalar;
  }
  throw new InvalidOperationException("Result of evaluation is not a scalar.");
}
//public static bool TryConvert(ISymbolicExpressionTree tree, IDataset dataset,
// bool makeVariableWeightsVariable, bool addLinearScalingTerms,
// out double[] initialConstants, out List scalarParameters, out List vectorParameters,
// out D func) {
// var transformator = new TreeToDiffSharpConverter(dataset, makeVariableWeightsVariable, addLinearScalingTerms);
// try {
// D term = transformator.ConvertNode(tree.Root.GetSubtree(0));
// initialConstants = transformator.initialConstants.ToArray();
// var scalarParameterEntries = transformator.scalarParameters.ToArray(); // guarantee same order for keys and values
// var vectorParameterEntries = transformator.vectorParameters.ToArray(); // guarantee same order for keys and values
// scalarParameters = scalarParameterEntries.Select(kvp => kvp.Key).ToList();
// vectorParameters = vectorParameterEntries.Select(kvp => kvp.Key).ToList();
// func = term;
// return true;
// } catch (ConversionException) {
// initialConstants = null;
// scalarParameters = null;
// vectorParameters = null;
// func = null;
// }
// return false;
//}
/// <summary>
/// Traverses the first subtree below the root of <paramref name="tree"/> and collects the
/// initial values of all constants (and, depending on the flags, variable weights and
/// linear scaling terms) in the order in which the conversion visits them.
/// </summary>
/// <param name="tree">Tree whose constants are collected.</param>
/// <param name="makeVariableWeightsVariable">If <c>true</c>, variable weights are collected as well.</param>
/// <param name="addLinearScalingTerms">If <c>true</c>, the initial offset (0.0) and scale (1.0) are collected first.</param>
/// <param name="scalarParameters">Scalar values, looked up by variable name.</param>
/// <param name="vectorsParameters">Vector values, looked up by variable name.</param>
/// <returns>The list of collected initial values.</returns>
public static List GetInitialConstants(ISymbolicExpressionTree tree,
  bool makeVariableWeightsVariable, bool addLinearScalingTerms,
  IDictionary scalarParameters, IDictionary vectorsParameters) {
  // no variable vector is supplied, so the conversion uses the values stored in the tree nodes
  var converter = new TreeToDiffSharpConverter(
    null,
    scalarParameters, vectorsParameters,
    makeVariableWeightsVariable, addLinearScalingTerms);
  var startNode = tree.Root.GetSubtree(0);
  converter.ConvertNode(startNode);
  return converter.initialConstants;
}
/*private readonly IDataset dataset;*/
// Scalar values for variable nodes, looked up by variable name during conversion.
private readonly IDictionary scalarParameters;
// Vector values for variable nodes, looked up by variable name during conversion.
private readonly IDictionary vectorsParameters;
// If true, the weight of each variable node is recorded in initialConstants and
// read from 'variables' (when available) instead of being used as a fixed factor.
private readonly bool makeVariableWeightsVariable;
// If true, converting the start symbol applies an offset and a scale to its subtree result.
private readonly bool addLinearScalingTerms;
// Initial values of constants / weights / scaling terms in the order they are visited.
private readonly List initialConstants;
// Values that replace the tree's own constants; may be null, in which case the
// values stored in the tree nodes are used.
private readonly DV variables;
// Index of the next entry of 'variables' to be consumed.
private int variableIdx;
//private readonly Dictionary scalarParameters;
//private readonly Dictionary vectorParameters;
/// <summary>
/// Creates a converter for a single traversal; instances are only created by the
/// static entry points of this class.
/// </summary>
/// <param name="variables">Values replacing the tree's constants in visiting order, or <c>null</c> to use the values stored in the tree.</param>
/// <param name="scalarParameters">Scalar values, looked up by variable name.</param>
/// <param name="vectorsParameters">Vector values, looked up by variable name.</param>
/// <param name="makeVariableWeightsVariable">Whether variable weights are treated like constants.</param>
/// <param name="addLinearScalingTerms">Whether offset and scale are applied at the start symbol.</param>
private TreeToDiffSharpConverter(
  DV variables,
  IDictionary scalarParameters, IDictionary vectorsParameters,
  bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
  // inputs
  this.variables = variables;
  this.scalarParameters = scalarParameters;
  this.vectorsParameters = vectorsParameters;

  // options
  this.makeVariableWeightsVariable = makeVariableWeightsVariable;
  this.addLinearScalingTerms = addLinearScalingTerms;

  // traversal state
  this.variableIdx = 0;
  this.initialConstants = new List();
}
#region Evaluation helpers
/// <summary>
/// Applies a binary operation to two results, choosing the callback that matches the
/// scalar/vector combination of the operands.
/// </summary>
/// <param name="lhs">Left operand.</param>
/// <param name="rhs">Right operand.</param>
/// <param name="ssFunc">Callback for scalar (left) and scalar (right); may be <c>null</c>.</param>
/// <param name="svFunc">Callback for scalar (left) and vector (right); may be <c>null</c>.</param>
/// <param name="vsFunc">Callback for vector (left) and scalar (right); may be <c>null</c>.</param>
/// <param name="vvFunc">Callback for vector (left) and vector (right); may be <c>null</c>.</param>
/// <returns>
/// The wrapped callback result, or <see cref="EvaluationResult.NaN"/> if no non-null
/// callback matches the operand combination.
/// </returns>
private static EvaluationResult ArithmeticApply(EvaluationResult lhs, EvaluationResult rhs,
  Func ssFunc = null,
  Func svFunc = null,
  Func vsFunc = null,
  Func vvFunc = null) {
  if (lhs.IsScalar) {
    if (rhs.IsScalar && ssFunc != null) {
      return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
    }
    if (rhs.IsVector && svFunc != null) {
      return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
    }
  }
  if (lhs.IsVector) {
    if (rhs.IsScalar && vsFunc != null) {
      return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
    }
    if (rhs.IsVector && vvFunc != null) {
      return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
    }
  }
  // unsupported combination (or one of the operands is already NaN)
  return EvaluationResult.NaN;
}
/// <summary>
/// Applies a unary function to a result, using <paramref name="sFunc"/> for scalars and
/// <paramref name="vFunc"/> for vectors.
/// </summary>
/// <param name="val">Argument of the function.</param>
/// <param name="sFunc">Callback for a scalar argument; may be <c>null</c>.</param>
/// <param name="vFunc">Callback for a vector argument; may be <c>null</c>.</param>
/// <returns>
/// The wrapped callback result, or <see cref="EvaluationResult.NaN"/> if no non-null
/// callback matches the kind of <paramref name="val"/>.
/// </returns>
private static EvaluationResult FunctionApply(EvaluationResult val,
  Func sFunc = null,
  Func vFunc = null) {
  if (sFunc != null && val.IsScalar) {
    return new EvaluationResult(sFunc(val.Scalar));
  }
  if (vFunc != null && val.IsVector) {
    return new EvaluationResult(vFunc(val.Vector));
  }
  return EvaluationResult.NaN;
}
/// <summary>
/// Applies an aggregation to a result, using <paramref name="sFunc"/> for scalars and
/// <paramref name="vFunc"/> for vectors.
/// </summary>
/// <param name="val">Argument of the aggregation.</param>
/// <param name="sFunc">Callback for a scalar argument; may be <c>null</c>.</param>
/// <param name="vFunc">Callback for a vector argument; may be <c>null</c>.</param>
/// <returns>
/// The wrapped callback result, or <see cref="EvaluationResult.NaN"/> if no non-null
/// callback matches the kind of <paramref name="val"/>.
/// </returns>
private static EvaluationResult AggregateApply(EvaluationResult val,
  Func sFunc = null,
  Func vFunc = null) {
  if (sFunc != null && val.IsScalar) {
    return new EvaluationResult(sFunc(val.Scalar));
  }
  if (vFunc != null && val.IsVector) {
    return new EvaluationResult(vFunc(val.Vector));
  }
  return EvaluationResult.NaN;
}
#endregion
/// <summary>
/// Recursively converts <paramref name="node"/> and its subtrees into a DiffSharp value.
/// </summary>
/// <remarks>
/// While traversing, the initial value of every constant (and, if
/// <c>makeVariableWeightsVariable</c> is set, of every variable weight) is appended to
/// <c>initialConstants</c>. If a <c>variables</c> vector was supplied, the corresponding
/// entry of that vector is used in place of the value stored in the node; entries are
/// consumed in visiting order.
/// Operations for which no implementation matches the operand kinds yield
/// <see cref="EvaluationResult.NaN"/>.
/// </remarks>
/// <param name="node">Node to convert.</param>
/// <returns>The scalar or vector result (or NaN) of the subtree rooted at <paramref name="node"/>.</returns>
/// <exception cref="ConversionException">
/// The symbol of the node is not supported, or a variable node refers to a name that is
/// contained in neither the scalar nor the vector parameters.
/// </exception>
/// <exception cref="InvalidOperationException">
/// Linear scaling is enabled and the subtree of the start symbol is not a scalar.
/// </exception>
private EvaluationResult ConvertNode(ISymbolicExpressionTreeNode node) {
  if (node.Symbol is Constant) {
    // assume scalar constant
    var constant = ((ConstantTreeNode)node).Value;
    initialConstants.Add(constant);
    var c = variables?[variableIdx++] ?? constant;
    return new EvaluationResult(c);
  }
  if (node.Symbol is Variable) {
    if (!(node is VariableTreeNodeBase varNode))
      throw new ConversionException("Node with a Variable symbol is not a VariableTreeNodeBase.");

    // scalar parameters take precedence over vector parameters of the same name
    if (scalarParameters.TryGetValue(varNode.VariableName, out var scalarPar)) {
      if (!makeVariableWeightsVariable) return new EvaluationResult(varNode.Weight * scalarPar);
      var weight = varNode.Weight;
      initialConstants.Add(weight);
      var w = variables?[variableIdx++] ?? weight;
      return new EvaluationResult(w * scalarPar);
    }
    if (vectorsParameters.TryGetValue(varNode.VariableName, out var vectorPar)) {
      if (!makeVariableWeightsVariable) return new EvaluationResult(varNode.Weight * vectorPar);
      var weight = varNode.Weight;
      initialConstants.Add(weight);
      var w = variables?[variableIdx++] ?? weight;
      return new EvaluationResult(w * vectorPar);
    }
    // fail with a diagnostic instead of falling through to the generic "unsupported symbol" error
    throw new ConversionException(
      $"Variable '{varNode.VariableName}' is contained in neither the scalar nor the vector parameters.");
  }
  // TODO: factor variables are not supported yet
  if (node.Symbol is Addition) {
    var terms = node.Subtrees.Select(ConvertNode).ToList();
    return terms.Aggregate((a, b) =>
      ArithmeticApply(a, b,
        (s1, s2) => s1 + s2,
        (s1, v2) => s1 + v2,
        (v1, s2) => v1 + s2,
        (v1, v2) => v1 + v2
      ));
  }
  if (node.Symbol is Subtraction) {
    var terms = node.Subtrees.Select(ConvertNode).ToList();
    // a single argument is negated
    if (terms.Count == 1) return FunctionApply(terms[0],
      s => D.Neg(s),
      v => DV.Neg(v));
    return terms.Aggregate((a, b) =>
      ArithmeticApply(a, b,
        (s1, s2) => s1 - s2,
        (s1, v2) => s1 - v2,
        (v1, s2) => v1 - s2,
        (v1, v2) => v1 - v2
      ));
  }
  if (node.Symbol is Multiplication) {
    var terms = node.Subtrees.Select(ConvertNode).ToList();
    return terms.Aggregate((a, b) =>
      ArithmeticApply(a, b,
        (s1, s2) => s1 * s2,
        (s1, v2) => s1 * v2,
        (v1, s2) => v1 * s2,
        (v1, v2) => DV.PointwiseMultiply(v1, v2)
      ));
  }
  if (node.Symbol is Division) {
    var terms = node.Subtrees.Select(ConvertNode).ToList();
    // a single argument is inverted
    if (terms.Count == 1) return FunctionApply(terms[0],
      s => 1.0 / s,
      v => 1.0 / v);
    return terms.Aggregate((a, b) =>
      ArithmeticApply(a, b,
        (s1, s2) => s1 / s2,
        (s1, v2) => s1 / v2,
        (v1, s2) => v1 / s2,
        (v1, v2) => DV.PointwiseDivision(v1, v2)
      ));
  }
  if (node.Symbol is Absolute) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Abs(s),
      v => DV.Abs(v)
    );
  }
  if (node.Symbol is Logarithm) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Log(s),
      v => DV.Log(v)
    );
  }
  if (node.Symbol is Exponential) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Pow(Math.E, s),
      v => DV.Pow(Math.E, v)
    );
  }
  if (node.Symbol is Square) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Pow(s, 2),
      v => DV.Pow(v, 2)
    );
  }
  if (node.Symbol is SquareRoot) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Sqrt(s),
      v => DV.Sqrt(v)
    );
  }
  if (node.Symbol is Cube) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Pow(s, 3),
      v => DV.Pow(v, 3)
    );
  }
  if (node.Symbol is CubeRoot) {
    // sign(x) * |x|^(1/3) so that negative arguments yield the real cube root
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Sign(s) * D.Pow(D.Abs(s), 1.0 / 3.0),
      v => DV.PointwiseMultiply(DV.Sign(v), DV.Pow(DV.Abs(v), 1.0 / 3.0))
    );
  }
  if (node.Symbol is Sine) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Sin(s),
      v => DV.Sin(v)
    );
  }
  if (node.Symbol is Cosine) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Cos(s),
      v => DV.Cos(v)
    );
  }
  if (node.Symbol is Tangent) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Tan(s),
      v => DV.Tan(v)
    );
  }
  if (node.Symbol is HyperbolicTangent) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Tanh(s),
      v => DV.Tanh(v)
    );
  }
  if (node.Symbol is Sum) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      s => s,
      v => DV.Sum(v)
    );
  }
  if (node.Symbol is Mean) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      s => s,
      v => DV.Mean(v)
    );
  }
  if (node.Symbol is StandardDeviation) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      s => 0,
      v => DV.StandardDev(v) //TODO: use pop-stdev instead
    );
  }
  if (node.Symbol is Length) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      s => 1,
      // TODO: no length? sum / mean is used instead; NOTE(review): this is 0/0 when the elements sum to zero
      v => DV.Sum(v) / DV.Mean(v)
    );
  }
  if (node.Symbol is Min) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      s => s,
      v => DV.Min(v)
    );
  }
  if (node.Symbol is Max) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      s => s,
      v => DV.Max(v)
    );
  }
  if (node.Symbol is Variance) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      // the variance of a single value is zero (consistent with StandardDeviation above)
      s => 0,
      v => DV.Variance(v)
    );
  }
  // TODO: Skewness, Kurtosis, EuclideanDistance and Covariance are not supported yet
  if (node.Symbol is StartSymbol) {
    if (addLinearScalingTerms) {
      // The scaling terms are at the beginning of the parameter vector:
      // first the offset β (initial value 0.0), then the scale α (initial value 1.0).
      initialConstants.Add(0.0);
      initialConstants.Add(1.0);
      var beta = variables?[variableIdx++] ?? 0.0;
      var alpha = variables?[variableIdx++] ?? 1.0;
      var t = ConvertNode(node.GetSubtree(0));
      if (!t.IsScalar) throw new InvalidOperationException("Must be a scalar result");
      return new EvaluationResult(t.Scalar * alpha + beta);
    } else return ConvertNode(node.GetSubtree(0));
  }
  throw new ConversionException(
    $"Symbol of type '{node.Symbol?.GetType().Name}' is not supported by the DiffSharp conversion.");
}
/// <summary>
/// Checks whether every node below the root of <paramref name="tree"/> (starting at the
/// root's first subtree, in prefix order) uses a symbol from the list accepted below.
/// </summary>
/// <remarks>
/// Not accepted here: BinaryFactorVariable, FactorVariable, LaggedVariable, Erf, Norm,
/// AnalyticQuotient, Length, Skewness, Kurtosis, EuclideanDistance and Covariance.
/// </remarks>
/// <param name="tree">Tree to check.</param>
/// <returns><c>true</c> if the tree contains only accepted symbols; otherwise <c>false</c>.</returns>
public static bool IsCompatible(ISymbolicExpressionTree tree) {
  var nodes = tree.Root.GetSubtree(0).IterateNodesPrefix();
  return nodes.All(n => {
    var symbol = n.Symbol;
    return symbol is Variable
        || symbol is Constant
        || symbol is Addition
        || symbol is Subtraction
        || symbol is Multiplication
        || symbol is Division
        || symbol is Logarithm
        || symbol is Exponential
        || symbol is SquareRoot
        || symbol is Square
        || symbol is Sine
        || symbol is Cosine
        || symbol is Tangent
        || symbol is HyperbolicTangent
        || symbol is StartSymbol
        || symbol is Absolute
        || symbol is Cube
        || symbol is CubeRoot
        || symbol is Sum
        || symbol is Mean
        || symbol is StandardDeviation
        || symbol is Min
        || symbol is Max
        || symbol is Variance;
  });
}
#region exception class
/// <summary>
/// Thrown when a symbolic expression tree cannot be converted.
/// </summary>
[Serializable]
public class ConversionException : Exception {
  /// <summary>Creates the exception without a message.</summary>
  public ConversionException()
    : base() { }

  /// <summary>Creates the exception with the given message.</summary>
  /// <param name="message">Description of the conversion problem.</param>
  public ConversionException(string message)
    : base(message) { }

  /// <summary>Creates the exception with the given message and causing exception.</summary>
  /// <param name="message">Description of the conversion problem.</param>
  /// <param name="inner">Exception that caused this exception.</param>
  public ConversionException(string message, Exception inner)
    : base(message, inner) { }

  /// <summary>Deserialization constructor.</summary>
  /// <param name="info">Serialized object data.</param>
  /// <param name="context">Source or destination of the serialized data.</param>
  protected ConversionException(SerializationInfo info, StreamingContext context)
    : base(info, context) { }
}
#endregion
}
}
#endif