#region License Information /* HeuristicLab * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL) * * This file is part of HeuristicLab. * * HeuristicLab is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * HeuristicLab is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with HeuristicLab. If not, see . */ #endregion using System; using System.Collections.Generic; using System.Linq; using System.Threading; using HeuristicLab.Common; using HeuristicLab.Core; using HeuristicLab.Data; using HeuristicLab.Optimization; using HeuristicLab.Problems.DataAnalysis; using HEAL.Attic; namespace HeuristicLab.Algorithms.DataAnalysis { [StorableType("7B4D9AE9-0456-4029-80A6-CCB5E33CE356")] public class RegressionRuleSetModel : RegressionModel, IDecisionTreeModel { private const string NumRulesResultName = "Number of rules"; private const string CoveredInstancesResultName = "Covered instances"; public const string RuleSetStateVariableName = "RuleSetState"; #region Properties [Storable] internal List Rules { get; private set; } #endregion #region HLConstructors & Cloning [StorableConstructor] protected RegressionRuleSetModel(StorableConstructorFlag _) : base(_) { } protected RegressionRuleSetModel(RegressionRuleSetModel original, Cloner cloner) : base(original, cloner) { if (original.Rules != null) Rules = original.Rules.Select(cloner.Clone).ToList(); } protected RegressionRuleSetModel(string targetVariable) : base(targetVariable) { } public override IDeepCloneable Clone(Cloner cloner) { return new RegressionRuleSetModel(this, cloner); } #endregion internal static RegressionRuleSetModel CreateRuleModel(string targetAttr, RegressionTreeParameters regressionTreeParams) { return regressionTreeParams.LeafModel.ProvidesConfidence ? new ConfidenceRegressionRuleSetModel(targetAttr) : new RegressionRuleSetModel(targetAttr); } #region RegressionModel public override IEnumerable VariablesUsedForPrediction { get { var f = Rules.FirstOrDefault(); return f != null ? (f.VariablesUsedForPrediction ?? new List()) : new List(); } } public override IEnumerable GetEstimatedValues(IDataset dataset, IEnumerable rows) { if (Rules == null) throw new NotSupportedException("The model has not been built yet"); return rows.Select(row => GetEstimatedValue(dataset, row)); } public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { return new RegressionSolution(this, problemData); } #endregion #region IDecisionTreeModel public void Build(IReadOnlyList trainingRows, IReadOnlyList pruningRows, IScope stateScope, ResultCollection results, CancellationToken cancellationToken) { var regressionTreeParams = (RegressionTreeParameters)stateScope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value; var ruleSetState = (RuleSetState)stateScope.Variables[RuleSetStateVariableName].Value; if (ruleSetState.Code <= 0) { ruleSetState.Rules.Clear(); ruleSetState.TrainingRows = trainingRows; ruleSetState.PruningRows = pruningRows; ruleSetState.Code = 1; } do { var tempRule = RegressionRuleModel.CreateRuleModel(regressionTreeParams.TargetVariable, regressionTreeParams); cancellationToken.ThrowIfCancellationRequested(); if (!results.ContainsKey(NumRulesResultName)) results.Add(new Result(NumRulesResultName, new IntValue(0))); if (!results.ContainsKey(CoveredInstancesResultName)) results.Add(new Result(CoveredInstancesResultName, new IntValue(0))); var t1 = ruleSetState.TrainingRows.Count; tempRule.Build(ruleSetState.TrainingRows, ruleSetState.PruningRows, stateScope, results, cancellationToken); ruleSetState.TrainingRows = ruleSetState.TrainingRows.Where(i => !tempRule.Covers(regressionTreeParams.Data, i)).ToArray(); ruleSetState.PruningRows = ruleSetState.PruningRows.Where(i => !tempRule.Covers(regressionTreeParams.Data, i)).ToArray(); ruleSetState.Rules.Add(tempRule); ((IntValue)results[NumRulesResultName].Value).Value++; ((IntValue)results[CoveredInstancesResultName].Value).Value += t1 - ruleSetState.TrainingRows.Count; } while (ruleSetState.TrainingRows.Count > 0); Rules = ruleSetState.Rules; } public void Update(IReadOnlyList rows, IScope stateScope, CancellationToken cancellationToken) { foreach (var rule in Rules) rule.Update(rows, stateScope, cancellationToken); } public static void Initialize(IScope stateScope) { stateScope.Variables.Add(new Variable(RuleSetStateVariableName, new RuleSetState())); } #endregion #region Helpers private double GetEstimatedValue(IDataset dataset, int row) { foreach (var rule in Rules) { if (rule.Covers(dataset, row)) return rule.GetEstimatedValues(dataset, row.ToEnumerable()).Single(); } throw new ArgumentException("Instance is not covered by any rule"); } #endregion [StorableType("E114F3C9-3C1F-443D-8270-0E10CE12F2A0")] public class RuleSetState : Item { [Storable] public List Rules = new List(); [Storable] public IReadOnlyList TrainingRows = new List(); [Storable] public IReadOnlyList PruningRows = new List(); //State.Code values denote the current action (for pausing) //0...nothing has been done; //1...splitting nodes; [Storable] public int Code = 0; #region HLConstructors & Cloning [StorableConstructor] protected RuleSetState(StorableConstructorFlag _) : base(_) { } protected RuleSetState(RuleSetState original, Cloner cloner) : base(original, cloner) { Rules = original.Rules.Select(cloner.Clone).ToList(); TrainingRows = original.TrainingRows.ToList(); PruningRows = original.PruningRows.ToList(); Code = original.Code; } public RuleSetState() { } public override IDeepCloneable Clone(Cloner cloner) { return new RuleSetState(this, cloner); } #endregion } [StorableType("52E7992B-94CC-4960-AA82-1A399BE735C6")] private sealed class ConfidenceRegressionRuleSetModel : RegressionRuleSetModel, IConfidenceRegressionModel { #region HLConstructors & Cloning [StorableConstructor] private ConfidenceRegressionRuleSetModel(StorableConstructorFlag _) : base(_) { } private ConfidenceRegressionRuleSetModel(ConfidenceRegressionRuleSetModel original, Cloner cloner) : base(original, cloner) { } public ConfidenceRegressionRuleSetModel(string targetVariable) : base(targetVariable) { } public override IDeepCloneable Clone(Cloner cloner) { return new ConfidenceRegressionRuleSetModel(this, cloner); } #endregion #region IConfidenceRegressionModel public IEnumerable GetEstimatedVariances(IDataset dataset, IEnumerable rows) { if (Rules == null) throw new NotSupportedException("The model has not been built yet"); return rows.Select(row => GetEstimatedVariance(dataset, row)); } public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { return new ConfidenceRegressionSolution(this, problemData); } private double GetEstimatedVariance(IDataset dataset, int row) { foreach (var rule in Rules) { if (rule.Covers(dataset, row)) return ((IConfidenceRegressionModel)rule).GetEstimatedVariances(dataset, row.ToEnumerable()).Single(); } throw new ArgumentException("Instance is not covered by any rule"); } #endregion } } }