/*
 * SVM.NET Library
 * Copyright (C) 2008 Matthew Johnson
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

using System;
using System.Collections.Generic;
using System.IO;
using System.Globalization;

namespace SVM
{
    /// <summary>
    /// Class encoding a member of a ranked set of labels.
    /// </summary>
    public class RankPair : IComparable<RankPair>
    {
        private double _score, _label;

        /// <summary>
        /// Constructor.
        /// </summary>
        /// <param name="score">Score for this pair</param>
        /// <param name="label">Label associated with the given score</param>
        public RankPair(double score, double label)
        {
            _score = score;
            _label = label;
        }

        /// <summary>
        /// The score for this pair.
        /// </summary>
        public double Score
        {
            get { return _score; }
        }

        /// <summary>
        /// The Label for this pair.
        /// </summary>
        public double Label
        {
            get { return _label; }
        }

        #region IComparable<RankPair> Members

        /// <summary>
        /// Compares this pair to another.  The operands are deliberately swapped so that
        /// a sorted list ends up in descending score order.
        /// </summary>
        /// <param name="other">The pair to compare to</param>
        /// <returns>Whether this should come before or after the argument</returns>
        public int CompareTo(RankPair other)
        {
            // By IComparable<T> convention any instance is greater than null.
            if (other == null)
            {
                return 1;
            }
            return other.Score.CompareTo(Score);
        }

        #endregion

        /// <summary>
        /// Returns a string representation of this pair.
        /// </summary>
        /// <returns>A string in the form Score:Label</returns>
        public override string ToString()
        {
            return string.Format("{0}:{1}", Score, Label);
        }
    }

    /// <summary>
    /// Class encoding the point on a 2D curve.
    /// </summary>
    public class CurvePoint
    {
        private float _x, _y;

        /// <summary>
        /// Constructor.
        /// </summary>
        /// <param name="x">X coordinate</param>
        /// <param name="y">Y coordinate</param>
        public CurvePoint(float x, float y)
        {
            _x = x;
            _y = y;
        }

        /// <summary>
        /// X coordinate
        /// </summary>
        public float X
        {
            get { return _x; }
        }

        /// <summary>
        /// Y coordinate
        /// </summary>
        public float Y
        {
            get { return _y; }
        }

        /// <summary>
        /// Creates a string representation of this point.
        /// </summary>
        /// <returns>string in the form (x, y)</returns>
        public override string ToString()
        {
            return string.Format("({0}, {1})", _x, _y);
        }
    }

    /// <summary>
    /// Class which evaluates an SVM model using several standard techniques.
    /// </summary>
    public class PerformanceEvaluator
    {
        /// <summary>
        /// Confusion-matrix counts obtained when every item up to (and including) a given
        /// rank is treated as a positive prediction and every later item as a negative one.
        /// </summary>
        private class ChangePoint
        {
            public ChangePoint(int tp, int fp, int tn, int fn)
            {
                TP = tp;
                FP = fp;
                TN = tn;
                FN = fn;
            }

            public int TP, FP, TN, FN;

            public override string ToString()
            {
                return string.Format("{0}:{1}:{2}:{3}", TP, FP, TN, FN);
            }
        }

        private List<CurvePoint> _prCurve;
        private double _ap;

        private List<CurvePoint> _rocCurve;
        private double _auc;

        private List<RankPair> _data;
        private List<ChangePoint> _changes;

        /// <summary>
        /// Constructor.
        /// </summary>
        /// <param name="set">A pre-computed ranked pair set.  The list is sorted in place
        /// (descending score) and retained by this instance.</param>
        /// <exception cref="ArgumentNullException">If <paramref name="set"/> is null</exception>
        /// <exception cref="ArgumentException">If <paramref name="set"/> is empty</exception>
        public PerformanceEvaluator(List<RankPair> set)
        {
            if (set == null)
            {
                throw new ArgumentNullException("set");
            }
            _data = set;
            computeStatistics();
        }

        /// <summary>
        /// Constructor.  Predictions are written to the file "tmp.results".
        /// </summary>
        /// <param name="model">Model to evaluate</param>
        /// <param name="problem">Problem to evaluate</param>
        /// <param name="category">Label to evaluate for</param>
        public PerformanceEvaluator(Model model, Problem problem, double category)
            : this(model, problem, category, "tmp.results")
        {
        }

        /// <summary>
        /// Constructor.
        /// </summary>
        /// <param name="model">Model to evaluate</param>
        /// <param name="problem">Problem to evaluate</param>
        /// <param name="category">Category to evaluate for</param>
        /// <param name="resultsFile">Results file for output</param>
        public PerformanceEvaluator(Model model, Problem problem, double category, string resultsFile)
        {
            // NOTE(review): the trailing 'true' presumably asks Predict for per-class
            // probability output, which parseResultsFile relies on -- confirm against
            // Prediction.Predict.
            Prediction.Predict(problem, resultsFile, model, true);
            parseResultsFile(resultsFile, problem.Y, category);

            computeStatistics();
        }

        /// <summary>
        /// Constructor.
        /// </summary>
        /// <param name="resultsFile">Results file</param>
        /// <param name="correctLabels">The correct labels of each data item</param>
        /// <param name="category">The category to evaluate for</param>
        public PerformanceEvaluator(string resultsFile, double[] correctLabels, double category)
        {
            parseResultsFile(resultsFile, correctLabels, category);
            computeStatistics();
        }

        /// <summary>
        /// Reads a results file into <c>_data</c>.  The first line is a header whose
        /// entries (from index 1 onwards) are parsed as category values; each following
        /// line supplies, in the column matching <paramref name="category"/>, the
        /// confidence for one data item.  Items whose correct label equals the category
        /// are stored with label 1, all others with label 0.
        /// </summary>
        /// <param name="resultsFile">Path of the results file to read</param>
        /// <param name="labels">Correct label of each data item, one per data line</param>
        /// <param name="category">The category to evaluate for</param>
        /// <exception cref="ArgumentNullException">If <paramref name="labels"/> is null</exception>
        /// <exception cref="ArgumentException">If the category is not listed in the header</exception>
        /// <exception cref="InvalidDataException">If the file is empty, truncated or has short lines</exception>
        private void parseResultsFile(string resultsFile, double[] labels, double category)
        {
            if (labels == null)
            {
                throw new ArgumentNullException("labels");
            }

            char[] separators = new char[] { ' ' };

            // 'using' guarantees the reader is closed even when parsing throws.
            using (StreamReader input = new StreamReader(resultsFile))
            {
                string header = input.ReadLine();
                if (header == null)
                {
                    throw new InvalidDataException(
                        string.Format("Results file '{0}' is empty.", resultsFile));
                }

                string[] parts = header.Split(separators, StringSplitOptions.RemoveEmptyEntries);
                int confidenceIndex = -1;
                // Index 0 of the header is skipped; only the remaining entries are category values.
                for (int i = 1; i < parts.Length; i++)
                {
                    if (double.Parse(parts[i], CultureInfo.InvariantCulture) == category)
                    {
                        confidenceIndex = i;
                        break;
                    }
                }
                if (confidenceIndex < 0)
                {
                    throw new ArgumentException(
                        string.Format(
                            CultureInfo.InvariantCulture,
                            "Category {0} is not listed in the header of results file '{1}'.",
                            category,
                            resultsFile),
                        "category");
                }

                _data = new List<RankPair>(labels.Length);
                for (int i = 0; i < labels.Length; i++)
                {
                    string line = input.ReadLine();
                    if (line == null)
                    {
                        throw new InvalidDataException(
                            string.Format(
                                "Results file '{0}' ended after {1} data lines; {2} were expected.",
                                resultsFile,
                                i,
                                labels.Length));
                    }

                    parts = line.Split(separators, StringSplitOptions.RemoveEmptyEntries);
                    if (parts.Length <= confidenceIndex)
                    {
                        throw new InvalidDataException(
                            string.Format(
                                "Data line {0} of results file '{1}' has too few columns.",
                                i + 1,
                                resultsFile));
                    }

                    double confidence = double.Parse(parts[confidenceIndex], CultureInfo.InvariantCulture);
                    _data.Add(new RankPair(confidence, labels[i] == category ? 1 : 0));
                }
            }
        }

        /// <summary>
        /// Sorts the ranked data (descending score) and derives the change points,
        /// the Precision-Recall curve and the ROC curve from it.
        /// </summary>
        /// <exception cref="ArgumentException">If there is no data to evaluate</exception>
        private void computeStatistics()
        {
            // Every curve computation indexes _changes[0], so at least one item is required.
            if (_data.Count == 0)
            {
                throw new ArgumentException("At least one ranked pair is required to compute statistics.");
            }

            _data.Sort();

            findChanges();
            computePR();
            computeRoC();
        }

        /// <summary>
        /// Builds one <see cref="ChangePoint"/> per data item, in rank order.
        /// </summary>
        private void findChanges()
        {
            int tp, fp, tn, fn;
            tp = fp = tn = fn = 0;

            // Start with nothing predicted positive: every positive is a false negative
            // and every negative is a true negative.
            for (int i = 0; i < _data.Count; i++)
            {
                if (_data[i].Label == 1)
                {
                    fn++;
                }
                else
                {
                    tn++;
                }
            }

            // Walk down the ranking, moving one item at a time to the "predicted positive" side.
            _changes = new List<ChangePoint>(_data.Count);
            for (int i = 0; i < _data.Count; i++)
            {
                if (_data[i].Label == 1)
                {
                    tp++;
                    fn--;
                }
                else
                {
                    fp++;
                    tn--;
                }
                _changes.Add(new ChangePoint(tp, fp, tn, fn));
            }
        }

        /// <summary>
        /// Precision at a change point: TP / (TP + FP).
        /// </summary>
        private float computePrecision(ChangePoint p)
        {
            return (float)p.TP / (p.TP + p.FP);
        }

        /// <summary>
        /// Recall at a change point: TP / (TP + FN).
        /// </summary>
        private float computeRecall(ChangePoint p)
        {
            return (float)p.TP / (p.TP + p.FN);
        }

        /// <summary>
        /// Computes the Precision-Recall curve and the average precision.  A curve point
        /// is emitted each time the true-positive count increases; the average precision
        /// is the mean of the precisions at those points over the number of positives.
        /// </summary>
        private void computePR()
        {
            // Totals are the same at every change point; read them from the first one.
            int positives = _changes[0].TP + _changes[0].FN;
            int negatives = _changes[0].FP + _changes[0].TN;

            _prCurve = new List<CurvePoint>();
            _prCurve.Add(new CurvePoint(0, 1));

            float precision = computePrecision(_changes[0]);
            float recall = computeRecall(_changes[0]);
            float precisionSum = 0;
            if (_changes[0].TP > 0)
            {
                precisionSum += precision;
                _prCurve.Add(new CurvePoint(recall, precision));
            }

            for (int i = 1; i < _changes.Count; i++)
            {
                precision = computePrecision(_changes[i]);
                recall = computeRecall(_changes[i]);
                if (_changes[i].TP > _changes[i - 1].TP)
                {
                    precisionSum += precision;
                    _prCurve.Add(new CurvePoint(recall, precision));
                }
            }

            // Closing point at recall 1: with everything predicted positive the precision
            // is positives / total.  (The previous positives / negatives ratio is not a
            // precision and could exceed 1.)
            _prCurve.Add(new CurvePoint(1, (float)positives / (positives + negatives)));

            // Yields NaN when the data contains no positive items.
            _ap = precisionSum / positives;
        }

        /// <summary>
        /// Writes the Precision-Recall curve to a tab-delimited file.  The first line is
        /// the average precision, followed by one "x[TAB]y" line per curve point.
        /// Numbers are written using the invariant culture.
        /// </summary>
        /// <param name="filename">Filename for output</param>
        public void WritePRCurve(string filename)
        {
            using (StreamWriter output = new StreamWriter(filename))
            {
                output.WriteLine(_ap.ToString(CultureInfo.InvariantCulture));
                for (int i = 0; i < _prCurve.Count; i++)
                {
                    output.WriteLine(string.Format(
                        CultureInfo.InvariantCulture, "{0}\t{1}", _prCurve[i].X, _prCurve[i].Y));
                }
            }
        }

        /// <summary>
        /// Writes the Receiver Operating Characteristic curve to a tab-delimited file.
        /// The first line is the area under the curve, followed by one "x[TAB]y" line per
        /// curve point.  Numbers are written using the invariant culture.
        /// </summary>
        /// <param name="filename">Filename for output</param>
        public void WriteROCCurve(string filename)
        {
            using (StreamWriter output = new StreamWriter(filename))
            {
                output.WriteLine(_auc.ToString(CultureInfo.InvariantCulture));
                for (int i = 0; i < _rocCurve.Count; i++)
                {
                    output.WriteLine(string.Format(
                        CultureInfo.InvariantCulture, "{0}\t{1}", _rocCurve[i].X, _rocCurve[i].Y));
                }
            }
        }

        /// <summary>
        /// Receiver Operating Characteristic curve
        /// </summary>
        public List<CurvePoint> ROCCurve
        {
            get { return _rocCurve; }
        }

        /// <summary>
        /// Returns the area under the ROC Curve
        /// </summary>
        public double AuC
        {
            get { return _auc; }
        }

        /// <summary>
        /// Precision-Recall curve
        /// </summary>
        public List<CurvePoint> PRCurve
        {
            get { return _prCurve; }
        }

        /// <summary>
        /// The average precision
        /// </summary>
        public double AP
        {
            get { return _ap; }
        }

        /// <summary>
        /// True-positive rate at a change point (identical to recall).
        /// </summary>
        private float computeTPR(ChangePoint cp)
        {
            return computeRecall(cp);
        }

        /// <summary>
        /// False-positive rate at a change point: FP / (FP + TN).
        /// </summary>
        private float computeFPR(ChangePoint cp)
        {
            return (float)cp.FP / (cp.FP + cp.TN);
        }

        /// <summary>
        /// Computes the ROC curve and the area under it.  The curve runs from (0, 0) to
        /// (1, 1); after the first change point a new point is emitted each time the
        /// true-positive count increases, and the area is accumulated with the
        /// trapezoid rule between consecutive emitted points.
        /// </summary>
        private void computeRoC()
        {
            _rocCurve = new List<CurvePoint>();
            _rocCurve.Add(new CurvePoint(0, 0));

            float tpr = computeTPR(_changes[0]);
            float fpr = computeFPR(_changes[0]);
            // The segment from (0, 0) to this first point has zero area: after a single
            // item either fpr or tpr is still 0.
            _rocCurve.Add(new CurvePoint(fpr, tpr));

            _auc = 0;
            for (int i = 1; i < _changes.Count; i++)
            {
                float newTPR = computeTPR(_changes[i]);
                float newFPR = computeFPR(_changes[i]);
                if (_changes[i].TP > _changes[i - 1].TP)
                {
                    // Trapezoid: rectangle under the old TPR plus the triangle up to the new TPR.
                    _auc += tpr * (newFPR - fpr) + .5 * (newTPR - tpr) * (newFPR - fpr);
                    tpr = newTPR;
                    fpr = newFPR;
                    _rocCurve.Add(new CurvePoint(fpr, tpr));
                }
            }

            // Close the curve at (1, 1) and add the final trapezoid.
            _rocCurve.Add(new CurvePoint(1, 1));
            _auc += tpr * (1 - fpr) + .5 * (1 - tpr) * (1 - fpr);
        }
    }
}