using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Collections.Specialized;
using System.Drawing.Design;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;

namespace HeuristicLab.Problems.GeneticProgramming.GlucosePrediction {
  /// <summary>
  /// Evaluates glucose-prediction symbolic expression trees on the rows of a dataset.
  /// </summary>
  public static class Interpreter {
    /// <summary>
    /// Inputs shared by all recursive <see cref="InterpretRec"/> calls of one evaluation.
    /// Every array is indexed by dataset row index (not by position within the evaluated rows).
    /// </summary>
    private class Data {
      // full "Glucose_Interpol" column of the dataset
      public double[] realGluc;
      // integrated input curve (one full-length array) for each curved insulin / CH node of the tree
      public Dictionary<ISymbolicExpressionTreeNode, double[]> precalculatedValues;
    }

    /// <summary>
    /// Evaluates <paramref name="model"/> for each of the given <paramref name="rows"/>.
    /// </summary>
    /// <param name="model">Root node of the tree to evaluate.</param>
    /// <param name="dataset">
    /// Dataset providing the "Glucose_target" and "Glucose_Interpol" columns, plus "Insuline", "CH"
    /// and "HourMin" when the tree contains curved insulin / carbohydrate variables.
    /// </param>
    /// <param name="rows">Dataset row indices to produce predictions for.</param>
    /// <returns>
    /// One value per row, in the order of <paramref name="rows"/>; NaN for rows whose "Glucose_target" is NaN.
    /// </returns>
    public static IEnumerable<double> Apply(ISymbolicExpressionTreeNode model, IDataset dataset, IEnumerable<int> rows) {
      // materialize once: the row indices are needed both for the target lookup and for the evaluation loop
      var rowIndices = rows.ToArray();
      // only used to skip rows for which we should not produce an output
      double[] targetGluc = dataset.GetDoubleValues("Glucose_target", rowIndices).ToArray();
      var data = new Data {
        // InterpretRec receives dataset row indices (plus lag offsets), and the precalculated curves span the
        // whole dataset, so the real glucose values must also be the full column rather than only the values
        // of the evaluated rows (which would be misaligned for any partition not starting at row 0).
        realGluc = dataset.GetReadOnlyDoubleValues("Glucose_Interpol").ToArray(),
        precalculatedValues = CreatePrecalculatedValues(model, dataset)
      };

      var predictions = new double[rowIndices.Length];
      for (int k = 0; k < rowIndices.Length; k++) {
        predictions[k] = double.IsNaN(targetGluc[k])
          ? double.NaN
          : InterpretRec(model, data, rowIndices[k]);
      }
      return predictions;
    }

    /// <summary>
    /// Computes the integrated input curve for every curved insulin / carbohydrate variable node of the tree.
    /// </summary>
    /// <returns>A map from each curved-variable node to an array with one value per dataset row.</returns>
    private static Dictionary<ISymbolicExpressionTreeNode, double[]> CreatePrecalculatedValues(ISymbolicExpressionTreeNode root, IDataset dataset) {
      var dict = new Dictionary<ISymbolicExpressionTreeNode, double[]>();
      // here we integrate ins or ch inputs over all rows of the dataset to generate smoothed ins/ch values with the same number of rows
      // the integrated values are reset to zero whenever a new evaluation period starts
      foreach (var node in root.IterateNodesPrefix()) {
        var curvedInsNode = node as CurvedInsVariableTreeNode;
        var curvedChNode = node as CurvedChVariableTreeNode;
        if (curvedInsNode != null) {
          dict.Add(curvedInsNode, Integrate(curvedInsNode, dataset));
        } else if (curvedChNode != null) {
          dict.Add(curvedChNode, Integrate(curvedChNode, dataset));
        }
      }
      return dict;
    }

    /// <summary>Integrates the "Insuline" column with the node's alpha/beta parameters.</summary>
    private static double[] Integrate(CurvedInsVariableTreeNode node, IDataset dataset) {
      return IntegrateCompartments(node.Alpha, node.Beta, dataset.GetReadOnlyDoubleValues("Insuline").ToArray(), dataset);
    }

    /// <summary>Integrates the "CH" column with the node's alpha/beta parameters.</summary>
    private static double[] Integrate(CurvedChVariableTreeNode node, IDataset dataset) {
      return IntegrateCompartments(node.Alpha, node.Beta, dataset.GetReadOnlyDoubleValues("CH").ToArray(), dataset);
    }

    /// <summary>
    /// Integrates <paramref name="input"/> through a chain of three compartments:
    ///   d Q1 / dt = input(t) - alpha * Q1(t)
    ///   d Q2 / dt = alpha * (Q1(t) - Q2(t))
    ///   d Q3 / dt = alpha * Q2(t) - beta * Q3(t)
    /// using one explicit Euler step (dt = 1) per row. All compartments are reset to zero at the start of
    /// every period (see <see cref="IsStartOfNewPeriod"/>), including row 0.
    /// </summary>
    /// <param name="input">Input values, one per dataset row.</param>
    /// <returns>Q3 for each dataset row.</returns>
    private static double[] IntegrateCompartments(double alpha, double beta, double[] input, IDataset dataset) {
      var time = dataset.GetReadOnlyDoubleValues("HourMin").ToArray();
      int rowCount = dataset.Rows;
      var s = new double[rowCount];
      double q1 = 0, q2 = 0, q3 = 0;
      // start at row 0 so that the first row's input is integrated as well (IsStartOfNewPeriod handles t == 0)
      for (int t = 0; t < rowCount; t++) {
        if (IsStartOfNewPeriod(time, t)) {
          q1 = q2 = q3 = 0;
        }
        // every update must use the previous row's state, so compute all three before assigning
        double q1Next = q1 + input[t] - alpha * q1;
        double q2Next = q2 + alpha * (q1 - q2);
        double q3Next = q3 + alpha * q2 - beta * q3;
        q1 = q1Next;
        q2 = q2Next;
        q3 = q3Next;
        s[t] = q3;
      }
      return s;
    }

    /// <summary>
    /// True for row 0, and for any row whose time is 2005 while the previous row's time is not 2000.
    /// NOTE(review): presumably "HourMin" encodes hhmm, i.e. a period starts at 20:05 unless 20:00 precedes it — confirm.
    /// </summary>
    private static bool IsStartOfNewPeriod(double[] time, int t) {
      return t == 0 || (time[t].IsAlmost(2005) && !time[t - 1].IsAlmost(2000));
    }

    /// <summary>
    /// Recursively evaluates <paramref name="node"/> at dataset row <paramref name="k"/>.
    /// </summary>
    /// <returns>The node's value; NaN when a real-glucose lag points outside the dataset.</returns>
    /// <exception cref="NotSupportedException">For predicted-glucose and real-insulin variables.</exception>
    /// <exception cref="InvalidProgramException">For symbols this interpreter does not know.</exception>
    private static double InterpretRec(ISymbolicExpressionTreeNode node, Data data, int k) {
      if (node.Symbol is SimpleSymbol) {
        switch (node.Symbol.Name) {
          case "+":
          case "+Ins":
          case "+Ch": {
              return InterpretRec(node.GetSubtree(0), data, k) + InterpretRec(node.GetSubtree(1), data, k);
            }
          case "-":
          case "-Ins":
          case "-Ch": {
              return InterpretRec(node.GetSubtree(0), data, k) - InterpretRec(node.GetSubtree(1), data, k);
            }
          case "*":
          case "*Ins":
          case "*Ch": {
              return InterpretRec(node.GetSubtree(0), data, k) * InterpretRec(node.GetSubtree(1), data, k);
            }
          case "/Ch":
          case "/Ins":
          case "/": {
              return InterpretRec(node.GetSubtree(0), data, k) / InterpretRec(node.GetSubtree(1), data, k);
            }
          case "Exp":
          case "ExpIns":
          case "ExpCh": {
              return Math.Exp(InterpretRec(node.GetSubtree(0), data, k));
            }
          case "Sin":
          case "SinIns":
          case "SinCh": {
              return Math.Sin(InterpretRec(node.GetSubtree(0), data, k));
            }
          case "CosCh":
          case "CosIns":
          case "Cos": {
              return Math.Cos(InterpretRec(node.GetSubtree(0), data, k));
            }
          case "LogCh":
          case "LogIns":
          case "Log": {
              return Math.Log(InterpretRec(node.GetSubtree(0), data, k));
            }
          case "Func": {
              // Func(a, b, c) = a + b - c
              return InterpretRec(node.GetSubtree(0), data, k)
                + InterpretRec(node.GetSubtree(1), data, k)
                - InterpretRec(node.GetSubtree(2), data, k);
            }
          case "ExprGluc":
          case "ExprCh":
          case "ExprIns": {
              // grouping nodes: pass the single child's value through
              return InterpretRec(node.GetSubtree(0), data, k);
            }
          default: {
              throw new InvalidProgramException("Found an unknown symbol " + node.Symbol);
            }
        }
      } else if (node.Symbol is PredictedGlucoseVariableSymbol) {
        throw new NotSupportedException();
      } else if (node.Symbol is RealGlucoseVariableSymbol) {
        var n = (RealGlucoseVariableTreeNode)node;
        int row = k + n.RowOffset;
        // lags reaching before the first / after the last dataset row have no value
        if (row < 0 || row >= data.realGluc.Length) return double.NaN;
        return data.realGluc[row];
      } else if (node.Symbol is CurvedChVariableSymbol) {
        return data.precalculatedValues[node][k];
      } else if (node.Symbol is RealInsulineVariableSymbol) {
        throw new NotSupportedException();
      } else if (node.Symbol is CurvedInsVariableSymbol) {
        return data.precalculatedValues[node][k];
      } else if (node.Symbol is Constant) {
        var n = (ConstantTreeNode)node;
        return n.Value;
      } else {
        throw new InvalidProgramException("found unknown symbol " + node.Symbol);
      }
    }
  }
}