class Mirlo::TestResult

Attributes

n_samples[R]

Public Class Methods

new(possible_classes = []) click to toggle source
# File lib/mirlo/test_result.rb, line 4
# Creates an empty test result.
#
# @param possible_classes [Array] classes known up front. The array is
#   copied, so later calls to #add (which append unseen targets) never
#   mutate the caller's object and work even if the caller passed a
#   frozen array.
def initialize(possible_classes = [])
  @possible_classes = possible_classes.dup
  # Pairs [expected, prediction] that were never recorded read as 0
  # (the default is returned, not stored).
  @confusion_matrix = Hash.new(0)
  @n_samples = 0
end

Public Instance Methods

add(sample, prediction) click to toggle source
# File lib/mirlo/test_result.rb, line 10
# Records one evaluated sample.
#
# Registers the sample's target as a possible class the first time it is
# seen, bumps the confusion-matrix cell [target, prediction], and
# increments the sample counter.
#
# @param sample an object responding to +target+
# @param prediction the value predicted for +sample+
# @return [Integer] the updated sample count
def add(sample, prediction)
  target = sample.target
  @possible_classes.push(target) unless @possible_classes.include?(target)

  cell = [target, prediction]
  @confusion_matrix[cell] += 1

  @n_samples += 1
end
confusion_matrix(expected, prediction) click to toggle source
# File lib/mirlo/test_result.rb, line 16
# Looks up how many recorded samples had target +expected+ and were
# predicted as +prediction+.
#
# @param expected the sample target (matrix row)
# @param prediction the predicted value (matrix column)
# @return [Integer] the cell count; the hash default for pairs never seen
def confusion_matrix(expected, prediction)
  cell = [expected, prediction]
  @confusion_matrix[cell]
end
error_percentage() click to toggle source
# File lib/mirlo/test_result.rb, line 38
# Fraction of recorded samples whose prediction differed from the target.
#
# Despite the name, the result is a ratio in 0.0..1.0, not a value scaled
# by 100.
#
# @return [Float] n_errors / n_samples; 0.0 when no samples have been
#   recorded (instead of the NaN that 0.0/0 would produce)
def error_percentage
  return 0.0 if n_samples.zero?

  # to_f guards against integer division and a nil error count.
  n_errors.to_f / n_samples
end
mean_squared_error() click to toggle source
# File lib/mirlo/test_result.rb, line 20
# Squared error accumulated over every recorded sample.
#
# Each confusion-matrix cell [expected, prediction] => times contributes
# error_for(expected, prediction, times).
#
# NOTE(review): despite the name, the total is not divided by the number
# of samples; callers currently get a sum. Confirm before changing.
#
# @return [Numeric] the summed squared error; 0 when nothing was recorded
def mean_squared_error
  @confusion_matrix.inject(0) do |total, (results, times)|
    expected, prediction = results
    total + error_for(expected, prediction, times)
  end
end
n_errors() click to toggle source
# File lib/mirlo/test_result.rb, line 29
# Number of recorded samples whose prediction differed from the target.
#
# Sums the counts of every off-diagonal confusion-matrix cell in a single
# pass.
#
# @return [Integer] the misclassification count; 0 when there are none
#   (the previous select/inject(:+) form returned nil in that case)
def n_errors
  @confusion_matrix.inject(0) do |count, (results, times)|
    expected, prediction = results
    expected == prediction ? count : count + times
  end
end

Private Instance Methods

error_for(expected, prediction, times) click to toggle source
# File lib/mirlo/test_result.rb, line 44
# Squared error of one confusion-matrix cell, weighted by its count.
#
# Computes sum_i (expected[i] - prediction[i])**2 over the elements of
# +expected+, then multiplies by +times+.
#
# @param expected an indexable target vector (iterated with each_with_index)
# @param prediction an indexable prediction vector, read at the same indices
# @param times [Integer] how many samples fell into this cell
# @return [Numeric] the weighted squared error; 0 for an empty vector
#   (the previous inject(:+) form returned nil and then raised on `* times`)
def error_for(expected, prediction, times)
  squared_sum = expected.each_with_index.inject(0) do |sum, (expected_val, i)|
    sum + (expected_val - prediction[i])**2
  end
  squared_sum * times
end