using System;
using System.Collections.Generic;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Linq;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;

namespace Tests {
  /// <summary>
  /// Tests for the vector and interval evaluators (with and without automatic
  /// differentiation w.r.t. selected tree nodes) and for the symbolic derivatives
  /// produced by <see cref="DerivativeCalculator"/>.
  /// </summary>
  [TestClass]
  public class AutoDiffTestClass {
    // Model shared by the evaluation and auto-diff checks: 2.0*x + y.
    private const string LinearModel = "2.0*x+y";

    /// <summary>
    /// Runs every check of this class in sequence; the first failing assertion aborts the test.
    /// </summary>
    [TestMethod]
    public void AutoDiffTest() {
      CheckIntervalEvaluation();
      CheckVectorEvaluation();
      CheckVectorAutoDiff();
      CheckElementaryIntervalOperations();
      CheckIntervalAutoDiff();
      CheckFlowPsiIntervals();
      CheckFlowPsiDerivatives();
    }

    // Interval evaluation of LinearModel without gradients.
    private static void CheckIntervalEvaluation() {
      var parser = new InfixExpressionParser();
      var tree = parser.Parse(LinearModel);
      var evaluator = new IntervalEvaluator();
      var intervals = new Dictionary<string, Interval> {
        { "x", new Interval(-1.0, 1.0) },
        { "y", new Interval(2.0, 10.0) }
      };

      var resultInterval = evaluator.Evaluate(tree, intervals);

      // 2 * [-1, 1] + [2, 10] = [0, 12]
      Assert.AreEqual(0, resultInterval.LowerBound);
      Assert.AreEqual(12, resultInterval.UpperBound);
    }

    // Vector evaluation of LinearModel on the training and test rows of the linear dataset.
    private static void CheckVectorEvaluation() {
      var parser = new InfixExpressionParser();
      var tree = parser.Parse(LinearModel);
      var evaluator = new VectorEvaluator();
      Dataset ds;
      var problemData = CreateLinearProblemData(out ds);

      // The expected outputs are 2*x + y for dataset rows 0-1 (training) and rows 2-4 (test).
      var train = evaluator.Evaluate(tree, ds, problemData.TrainingIndices.ToArray());
      Assert.AreEqual(2, train.Length);
      Assert.AreEqual(3, train[0]);
      Assert.AreEqual(5, train[1]);

      var test = evaluator.Evaluate(tree, ds, problemData.TestIndices.ToArray());
      Assert.AreEqual(3, test.Length);
      Assert.AreEqual(5, test[0]);
      Assert.AreEqual(7, test[1]);
      Assert.AreEqual(9, test[2]);
    }

    // Vector evaluation of LinearModel plus its Jacobian w.r.t. the constant node (column 0)
    // and the y variable node (column 1).
    private static void CheckVectorAutoDiff() {
      var parser = new InfixExpressionParser();
      var tree = parser.Parse(LinearModel);
      var paramNodes = SelectLinearModelParameters(tree);
      var evaluator = new VectorAutoDiffEvaluator();
      Dataset ds;
      var problemData = CreateLinearProblemData(out ds);

      // Materialize the row indices once: they size the output buffers and drive the evaluation.
      var trainRows = problemData.TrainingIndices.ToArray();
      var train = new double[trainRows.Length];
      var trainJac = new double[train.Length, paramNodes.Length];
      evaluator.Evaluate(tree, ds, trainRows, paramNodes, train, trainJac);
      Assert.AreEqual(2, train.Length);
      Assert.AreEqual(3, train[0]);
      Assert.AreEqual(5, train[1]);
      // Each expected Jacobian row is (x_i, y_i): the partial derivatives of c*x + w*y
      // w.r.t. the constant c and the y weight w.
      Assert.AreEqual(1, trainJac[0, 0]);
      Assert.AreEqual(1, trainJac[0, 1]);
      Assert.AreEqual(2, trainJac[1, 0]);
      Assert.AreEqual(1, trainJac[1, 1]);

      var testRows = problemData.TestIndices.ToArray();
      var test = new double[testRows.Length];
      var testJac = new double[test.Length, paramNodes.Length];
      evaluator.Evaluate(tree, ds, testRows, paramNodes, test, testJac);
      Assert.AreEqual(3, test.Length);
      Assert.AreEqual(5, test[0]);
      Assert.AreEqual(7, test[1]);
      Assert.AreEqual(9, test[2]);
      Assert.AreEqual(3, testJac[0, 0]);
      Assert.AreEqual(-1, testJac[0, 1]);
      Assert.AreEqual(4, testJac[1, 0]);
      Assert.AreEqual(-1, testJac[1, 1]);
      Assert.AreEqual(5, testJac[2, 0]);
      Assert.AreEqual(-1, testJac[2, 1]);
    }

    // Interval arithmetic for scaling, sqr and cube over zero-straddling, positive and negative ranges.
    private static void CheckElementaryIntervalOperations() {
      var intervals = new Dictionary<string, Interval> {
        { "x", new Interval(-2.0, 3.0) },  // straddles zero
        { "p", new Interval(1.0, 2.0) },   // strictly positive
        { "n", new Interval(-2.0, -1.0) }  // strictly negative
      };

      AssertInterval("10*x", intervals, -20, 30);
      AssertInterval("sqr(p)", intervals, 1, 4);
      AssertInterval("sqr(n)", intervals, 1, 4);
      AssertInterval("sqr(x)", intervals, 0, 9);
      AssertInterval("cube(p)", intervals, 1, 8);
      AssertInterval("cube(n)", intervals, -8, -1);
      AssertInterval("cube(x)", intervals, -8, 27);
    }

    // Interval evaluation of LinearModel with gradients of both bounds w.r.t. its parameter nodes.
    private static void CheckIntervalAutoDiff() {
      var parser = new InfixExpressionParser();
      var tree = parser.Parse(LinearModel);
      var paramNodes = SelectLinearModelParameters(tree);
      var evaluator = new IntervalEvaluator();
      var intervals = new Dictionary<string, Interval> {
        { "x", new Interval(-1.0, 1.0) },
        { "y", new Interval(2.0, 10.0) }
      };

      var resultInterval = evaluator.Evaluate(tree, intervals, paramNodes, out double[] lowerGradient, out double[] upperGradient);

      Assert.AreEqual(0, resultInterval.LowerBound);
      Assert.AreEqual(12, resultInterval.UpperBound);
      // Expected: lower bound = c*(-1) + w*2, upper bound = c*1 + w*10.
      Assert.AreEqual(-1, lowerGradient[0]);
      Assert.AreEqual(2, lowerGradient[1]);
      Assert.AreEqual(1, upperGradient[0]);
      Assert.AreEqual(10, upperGradient[1]);
    }

    // Step-by-step interval evaluation of the flow psi expression (as discussed with Fabrício).
    private static void CheckFlowPsiIntervals() {
      var intervals = CreateFlowPsiIntervals();
      var parser = new InfixExpressionParser();
      var evaluator = new IntervalEvaluator();

      var result = evaluator.Evaluate(parser.Parse("x5/x4"), intervals);
      Assert.AreEqual(0.25, result.LowerBound); // 0.2 / 0.8
      Assert.AreEqual(1, result.UpperBound);    // 0.5 / 0.5

      result = evaluator.Evaluate(parser.Parse("log(x5/x4)"), intervals);
      Assert.AreEqual(Math.Log(0.25), result.LowerBound, 1e-6);
      Assert.AreEqual(0, result.UpperBound);

      result = evaluator.Evaluate(parser.Parse("x3 * log(x5/x4)"), intervals);
      Assert.AreEqual(10 * Math.Log(0.25), result.LowerBound, 1e-6);
      Assert.AreEqual(0, result.UpperBound);

      result = evaluator.Evaluate(parser.Parse("x1*x2*x5"), intervals);
      Assert.AreEqual(360, result.LowerBound);  // 60 * 30 * 0.2
      Assert.AreEqual(1300, result.UpperBound); // 65 * 40 * 0.5

      result = evaluator.Evaluate(parser.Parse("x4/x5"), intervals);
      Assert.AreEqual(1, result.LowerBound, 1e-6);
      Assert.AreEqual(4, result.UpperBound);

      result = evaluator.Evaluate(parser.Parse("sqr(x4/x5)"), intervals);
      Assert.AreEqual(1, result.LowerBound);
      Assert.AreEqual(16, result.UpperBound);

      result = evaluator.Evaluate(parser.Parse("(1 - sqr(x4/x5)) "), intervals);
      Assert.AreEqual(-15, result.LowerBound);
      Assert.AreEqual(0, result.UpperBound);

      result = evaluator.Evaluate(parser.Parse("x1*x2*x5 *(1 - sqr(x4/x5))"), intervals);
      Assert.AreEqual(-19500, result.LowerBound); // 1300 * (-15)
      Assert.AreEqual(0, result.UpperBound);

      result = evaluator.Evaluate(parser.Parse("x1*x2*x5 *(1 - sqr(x4/x5)) + x3 * log(x5/x4)"), intervals);
      Assert.AreEqual(-19500 + 10 * Math.Log(0.25), result.LowerBound, 1e-3);
      Assert.AreEqual(0, result.UpperBound);
    }

    // Symbolic partial derivatives of the flow psi expression, compared via their infix rendering.
    private static void CheckFlowPsiDerivatives() {
      var parser = new InfixExpressionParser();
      var formatter = new InfixExpressionFormatter();
      var expr = parser.Parse("x1*x2*x5*(1 - sqr(x4/x5)) + x3 * log(x5/x4)");

      var dfdx1 = DerivativeCalculator.Derive(expr, "x1");
      // x2*x5*(1 - sqr(x4/x5))
      Assert.AreEqual("('x2' * 'x5' * ((SQR(('x4' / 'x5')) * (-1)) + 1))", formatter.Format(dfdx1));

      var dfdx2 = DerivativeCalculator.Derive(expr, "x2");
      // x1*x5*(1 - sqr(x4/x5))
      Assert.AreEqual("('x1' * 'x5' * ((SQR(('x4' / 'x5')) * (-1)) + 1))", formatter.Format(dfdx2));

      var dfdx3 = DerivativeCalculator.Derive(expr, "x3");
      // log(x5/x4)
      Assert.AreEqual("LOG(('x5' / 'x4'))", formatter.Format(dfdx3));

      var dfdx4 = DerivativeCalculator.Derive(expr, "x4");
      // -2*x1*x2*x5*x4/x5*1/x5 - x3*1/(x5/x4)*x5/sqr(x4), i.e. -2*x1*x2*x4/x5 - x3/x4
      Assert.AreEqual("((('x1' * 'x2' * 'x5' * 'x4' * 2) / ('x5' * (-1*'x5'))) + (('x4' * 'x5' * 'x3') / ('x5' * SQR('x4') * (-1))))", formatter.Format(dfdx4));

      var dfdx5 = DerivativeCalculator.Derive(expr, "x5");
      // x3/x5 + x1*x2*(1 - sqr(x4/x5)) + 2*x1*x2*sqr(x4)/sqr(x5)
      Assert.AreEqual("((('x4' * 'x3') / ('x5' * 'x4')) + ('x1' * 'x2' * ((SQR(('x4' / 'x5')) * (-1)) + 1)) + (('x1' * 'x2' * 'x5' * ('x4' * 'x4') * 2) / ('x5' * SQR('x5') * 1)))", formatter.Format(dfdx5));
    }

    // Builds the five-row (x, y, f(x)) dataset and its regression problem data with target "f(x)".
    private static RegressionProblemData CreateLinearProblemData(out Dataset dataset) {
      var vars = new string[] { "x", "y", "f(x)" };
      var values = new double[,] {
        { 1, 1, 0 },
        { 2, 1, 0 },
        { 3, -1, 0 },
        { 4, -1, 0 },
        { 5, -1, 0 },
      };
      dataset = new Dataset(vars, values);
      return new RegressionProblemData(dataset, vars, "f(x)");
    }

    // Selects the parameter nodes of LinearModel: the first constant node and the variable node for "y".
    private static ISymbolicExpressionTreeNode[] SelectLinearModelParameters(ISymbolicExpressionTree tree) {
      var constantNode = tree.IterateNodesPostfix().First(n => n is ConstantTreeNode);
      var yNode = tree.IterateNodesPostfix().First(n => n is VariableTreeNode variableNode && variableNode.VariableName == "y");
      return new ISymbolicExpressionTreeNode[] { constantNode, yNode };
    }

    // Variable ranges shared by the two flow psi checks.
    private static Dictionary<string, Interval> CreateFlowPsiIntervals() {
      return new Dictionary<string, Interval> {
        { "x1", new Interval(60.0, 65.0) },
        { "x2", new Interval(30.0, 40.0) },
        { "x3", new Interval(5.0, 10.0) },
        { "x4", new Interval(0.5, 0.8) },
        { "x5", new Interval(0.2, 0.5) }
      };
    }

    /// <summary>
    /// Parses <paramref name="expression"/>, evaluates it over <paramref name="intervals"/> with an
    /// <see cref="IntervalEvaluator"/>, and asserts that both bounds equal the expected values exactly.
    /// </summary>
    private static void AssertInterval(string expression, Dictionary<string, Interval> intervals, double expectedLow, double expectedHigh) {
      var parser = new InfixExpressionParser();
      var tree = parser.Parse(expression);
      var evaluator = new IntervalEvaluator();
      var result = evaluator.Evaluate(tree, intervals);
      Assert.AreEqual(expectedLow, result.LowerBound);
      Assert.AreEqual(expectedHigh, result.UpperBound);
    }
  }
}