#region License Information
/* HeuristicLab
* Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HEAL.Attic;
using HeuristicLab.Problems.DataAnalysis;
namespace HeuristicLab.Algorithms.DataAnalysis {
/// <summary>
/// Represents a nearest neighbour model for regression and classification.
/// </summary>
[StorableType("04A07DF6-6EB5-4D29-B7AE-5BE204CAF6BC")]
[Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
public sealed class NearestNeighbourModel : ClassificationModel, INearestNeighbourModel {
// Trained ALGLIB k-nn model. The field itself is not marked [Storable];
// it is persisted through the SerializedModel property below.
private alglib.knnmodel model;

/// <summary>
/// String form of <see cref="model"/> used for persistence.
/// The getter serializes the model with alglib.knnserialize, the setter
/// restores it with alglib.knnunserialize.
/// </summary>
[Storable]
private string SerializedModel {
  get {
    // An instance without a model has nothing to serialize. Returning null
    // mirrors the setter, which ignores null values, instead of failing
    // inside alglib.knnserialize.
    if (model == null) return null;
    alglib.knnserialize(model, out var ser);
    return ser;
  }
  set { if (value != null) alglib.knnunserialize(value, out model); }
}
/// <summary>
/// Names of the input variables this model was built with
/// (returns the stored <c>allowedInputVariables</c> array).
/// </summary>
public override IEnumerable<string> VariablesUsedForPrediction {
  get { return allowedInputVariables; }
}
// Names of the input variables, in the column order used when building the scaled data matrix.
[Storable]
private string[] allowedInputVariables;
// Class values for classification; remains null when the model is built without class values (regression).
[Storable]
private double[] classValues;
// Number of neighbours passed to the ALGLIB model builder.
[Storable]
private int k;
// Scaling factor per input variable, followed by a trailing 1.0 for the target column.
[Storable]
private double[] weights;
// Offset per input variable (negated mean over the training rows), followed by a trailing 0 for the target column.
[Storable]
private double[] offsets;
// Constructor marked for the storable (persistence) framework; it only forwards the flag to the base class.
[StorableConstructor]
private NearestNeighbourModel(StorableConstructorFlag _) : base(_) { }
/// <summary>
/// Copy constructor used for cloning. Copies the ALGLIB model (if present),
/// the neighbour count and all arrays of <paramref name="original"/>.
/// </summary>
/// <param name="original">The model to copy.</param>
/// <param name="cloner">Cloner forwarded to the base class copy constructor.</param>
private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
  : base(original, cloner) {
  // the ALGLIB model is optional; keep it null when the original has none
  model = original.model == null
    ? null
    : (alglib.knnmodel)original.model.make_copy();

  k = original.k;

  // every array gets its own copy so the clone does not alias the original's arrays
  weights = (double[])original.weights.Clone();
  offsets = (double[])original.offsets.Clone();
  allowedInputVariables = (string[])original.allowedInputVariables.Clone();

  // class values exist only for classification models
  classValues = original.classValues == null
    ? null
    : (double[])original.classValues.Clone();
}
/// <summary>
/// Builds a nearest neighbour model from the given rows of a dataset.
/// A classification model is built when <paramref name="classValues"/> is given,
/// otherwise a regression model with a single output.
/// </summary>
/// <param name="dataset">Dataset providing the input and target values.</param>
/// <param name="rows">Rows of the dataset used to build the model.</param>
/// <param name="k">Number of neighbours passed to the ALGLIB model builder.</param>
/// <param name="targetVariable">Name of the target variable; it becomes the last column of the training matrix.</param>
/// <param name="allowedInputVariables">Names of the input variables.</param>
/// <param name="weights">
/// Optional scaling factor per input variable. When null, each variable is scaled by
/// 1 / population standard deviation (or 1.0 when the standard deviation is almost zero).
/// </param>
/// <param name="classValues">Class values for classification, or null for regression.</param>
/// <exception cref="ArgumentException">The number of weights differs from the number of input variables.</exception>
/// <exception cref="NotSupportedException">The scaled training data contains NaN or infinity.</exception>
public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
  : base(targetVariable) {
  Name = ItemName;
  Description = ItemDescription;
  this.k = k;
  this.allowedInputVariables = allowedInputVariables.ToArray();

  // center every input variable on its mean over the training rows
  this.offsets = this.allowedInputVariables
    .Select(name => dataset.GetDoubleValues(name, rows).Average() * -1)
    .Concat(new double[] { 0 }) // no offset for target variable
    .ToArray();

  if (weights == null) {
    // automatic determination of weights (all features should have variance = 1)
    this.weights = this.allowedInputVariables
      .Select(name => {
        var pop = dataset.GetDoubleValues(name, rows).StandardDeviationPop();
        return pop.IsAlmost(0) ? 1.0 : 1.0 / pop;
      })
      .Concat(new double[] { 1.0 }) // no scaling for target variable
      .ToArray();
  } else {
    // user specified weights (+ 1 for target)
    this.weights = weights.Concat(new double[] { 1.0 }).ToArray();
    if (this.weights.Length - 1 != this.allowedInputVariables.Length)
      throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
  }

  // training matrix: one column per input variable plus the target as last column
  double[,] inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rows, this.offsets, this.weights);

  if (inputMatrix.ContainsNanOrInfinity())
    throw new NotSupportedException(
      "Nearest neighbour model does not support NaN or infinity values in the input dataset.");

  var nRows = inputMatrix.GetLength(0);
  var nFeatures = inputMatrix.GetLength(1) - 1; // last column holds the target

  if (classValues != null) {
    this.classValues = (double[])classValues.Clone();
    int nClasses = classValues.Length;

    // map original class values to values [0..nClasses-1]
    var classIndices = new Dictionary<double, int>();
    for (int i = 0; i < nClasses; i++)
      classIndices[classValues[i]] = i;

    for (int row = 0; row < nRows; row++) {
      inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]];
    }
  }

  alglib.knnbuildercreate(out var knnbuilder);
  if (classValues == null) {
    alglib.knnbuildersetdatasetreg(knnbuilder, inputMatrix, nRows, nFeatures, nout: 1);
  } else {
    alglib.knnbuildersetdatasetcls(knnbuilder, inputMatrix, nRows, nFeatures, classValues.Length);
  }
  // eps=0 (exact k-nn search is performed); the build report is not used
  alglib.knnbuilderbuildknnmodel(knnbuilder, k, eps: 0.0, out model, out _);
}
/// <summary>
/// Reads the given variables and rows from the dataset into a matrix, attaching one
/// linear transformation per variable (column).
/// </summary>
/// <param name="dataset">Dataset to read from.</param>
/// <param name="variables">Names of the variables; their order defines the column order.</param>
/// <param name="rows">Rows to read.</param>
/// <param name="offsets">Offset per column; must have at least as many entries as there are variables.</param>
/// <param name="factors">Scaling factor per column; must have at least as many entries as there are variables.</param>
/// <returns>The matrix produced by <c>dataset.ToArray</c> for the variables, transformations and rows.</returns>
private static double[,] CreateScaledData(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, double[] offsets, double[] factors) {
  // Addend is pre-multiplied by the factor, so a transformation of the form
  // value * Multiplier + Addend yields (value + offset) * factor.
  // NOTE(review): assumes LinearTransformation applies Multiplier before Addend - confirm.
  var transforms =
    variables.Select(
      (_, colIdx) =>
        new LinearTransformation(variables) { Addend = offsets[colIdx] * factors[colIdx], Multiplier = factors[colIdx] });
  return dataset.ToArray(variables, transforms, rows);
}
/// <summary>
/// Creates a copy of this model through the private copy constructor.
/// </summary>
public override IDeepCloneable Clone(Cloner cloner) {
return new NearestNeighbourModel(this, cloner);
}
/// <summary>
/// Lazily yields one estimated value per requested row. The input variables are
/// scaled with the stored offsets and weights, each row is passed to the ALGLIB
/// model, and the first element of the model output is returned.
/// </summary>
/// <param name="dataset">Dataset to read the input variables from.</param>
/// <param name="rows">Rows for which estimates are produced.</param>
/// <returns>The estimates, in the order of the scaled data rows.</returns>
public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
  double[,] inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);

  int n = inputData.GetLength(0);
  int columns = inputData.GetLength(1);
  double[] x = new double[columns];

  // one query buffer for the whole enumeration
  alglib.knncreatebuffer(model, out var buf);
  var y = new double[1];

  for (int row = 0; row < n; row++) {
    for (int column = 0; column < columns; column++) {
      x[column] = inputData[row, column];
    }
    alglib.knntsprocess(model, buf, x, ref y); // thread-safe process
    yield return y[0];
  }
}
/// <summary>
/// Lazily yields one estimated class value per requested row. The input variables
/// are scaled with the stored offsets and weights, each row is passed to the ALGLIB
/// model, and the class whose output entry is largest is returned (the lowest class
/// index wins on ties).
/// </summary>
/// <param name="dataset">Dataset to read the input variables from.</param>
/// <param name="rows">Rows for which class values are produced.</param>
/// <returns>The original class value of the selected class for every row.</returns>
/// <exception cref="InvalidOperationException">
/// The model has no class values. Because this method is an iterator, the exception
/// is raised when enumeration starts, not when the method is called.
/// </exception>
public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
  if (classValues == null) throw new InvalidOperationException("No class values are defined.");

  double[,] inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);

  int n = inputData.GetLength(0);
  int columns = inputData.GetLength(1);
  double[] x = new double[columns];

  // one query buffer for the whole enumeration
  alglib.knncreatebuffer(model, out var buf);
  var y = new double[classValues.Length];

  for (int row = 0; row < n; row++) {
    for (int column = 0; column < columns; column++) {
      x[column] = inputData[row, column];
    }
    alglib.knntsprocess(model, buf, x, ref y); // thread-safe process

    // find the most probable class: compare output values with each other,
    // not the class index with an output value
    var maxC = 0;
    for (int i = 1; i < y.Length; i++)
      if (y[i] > y[maxC]) maxC = i;

    yield return classValues[maxC];
  }
}
/// <summary>
/// Checks compatibility with regression problem data by delegating to
/// RegressionModel.IsProblemDataCompatible.
/// </summary>
public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
}
/// <summary>
/// Dispatches the compatibility check to the regression or classification overload,
/// depending on the runtime type of <paramref name="problemData"/>.
/// Regression problem data is tested first.
/// </summary>
/// <param name="problemData">The problem data to check.</param>
/// <param name="errorMessage">Message produced by the overload that performed the check.</param>
/// <returns>The result of the selected overload.</returns>
/// <exception cref="ArgumentNullException"><paramref name="problemData"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="problemData"/> is neither regression nor classification problem data.</exception>
public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
  if (problemData == null) {
    throw new ArgumentNullException("problemData", "The provided problemData is null.");
  }

  if (problemData is IRegressionProblemData regressionData) {
    return IsProblemDataCompatible(regressionData, out errorMessage);
  }

  if (problemData is IClassificationProblemData classificationData) {
    return IsProblemDataCompatible(classificationData, out errorMessage);
  }

  throw new ArgumentException("The problem data is not compatible with this nearest neighbour model. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
}
/// <summary>
/// Explicit IRegressionModel implementation: creates a NearestNeighbourRegressionSolution
/// for this model and a new RegressionProblemData instance constructed from
/// <paramref name="problemData"/>.
/// </summary>
IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
return new NearestNeighbourRegressionSolution(this, new RegressionProblemData(problemData));
}
/// <summary>
/// Creates a NearestNeighbourClassificationSolution for this model and a new
/// ClassificationProblemData instance constructed from <paramref name="problemData"/>.
/// </summary>
public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
return new NearestNeighbourClassificationSolution(this, new ClassificationProblemData(problemData));
}
}
}