class Bae::Classifier

Attributes

frequency_table[RW]
label_index[RW]
label_index_sequence[RW]
label_instance_count[RW]
total_terms[RW]

Public Class Methods

new()
# File lib/bae/classifier.rb, line 7
# Builds an empty, untrained classifier.
#
# @label_instance_count and @label_index both assign 0 to any key that is
# looked up but missing; for @label_instance_count this is what lets the
# train_from_* methods do `+= 1` on a label seen for the first time.
def initialize
  zero_on_miss = proc { |table, key| table[key] = 0 }

  @frequency_table = ::Hash.new
  @label_instance_count = ::Hash.new(&zero_on_miss)
  @label_index = ::Hash.new(&zero_on_miss)
  # get_next_sequence_value increments before returning, so starting at -1
  # makes the first label's index 0.
  @label_index_sequence = -1
  @total_terms = 0.0
end

Public Instance Methods

classify(data)
# File lib/bae/classifier.rb, line 52
# Classifies a document given either as raw text or as a word => frequency hash.
#
# @param data [String, Hash] text (split on whitespace downstream) or a
#   word => frequency hash
# @return [Hash] label => normalized posterior score, as produced by
#   classify_from_string
# @raise [RuntimeError] when data is neither a String nor a Hash
def classify(data)
  case data
  when ::String then classify_from_string(data)
  when ::Hash then classify_from_hash(data)
  else
    # This is the classification path, so the message must not talk about
    # training data (that wording belongs to #train).
    fail 'Data to classify must be either a string or a hash'
  end
end
classify_from_hash(frequency_hash)
# File lib/bae/classifier.rb, line 62
# Classifies a word => frequency hash by expanding it back into a
# space-separated document and delegating to classify_from_string.
#
# Each word is repeated `frequency` times. classify_from_string de-duplicates
# words, so the only thing that matters is whether a word ends up in the
# document at all.
#
# @param frequency_hash [Hash] word (String) => number of occurrences
# @return [Hash] label => normalized posterior score
def classify_from_hash(frequency_hash)
  segments = frequency_hash.map do |term, count|
    padded = term + ' '
    padded * count
  end

  classify_from_string(segments.join)
end
classify_from_string(document)
# File lib/bae/classifier.rb, line 68
# Classifies a whitespace-separated document.
#
# Starts from the per-label products cached by calculate_likelihoods!.
# For every distinct word of the document that is in the vocabulary, it
# multiplies in p / (1 - p), where p is the Laplace-smoothed estimate
#   (row[index] + 1) / (label_instance_count[label] + vocabulary size).
# Words never seen in training are ignored. Each label's result is weighted
# by its prior, and the scores are normalized.
#
# NOTE(review): frequency rows accumulate word occurrences
# (train_from_string adds 1 per occurrence), while label_instance_count
# counts documents. p can therefore reach or exceed 1 for words repeated
# within documents — confirm this is intended.
#
# Requires finish_training! to have run (it sets @likelihoods, @priors and
# @frequency_table_size).
#
# @param document [String]
# @return [Hash] label => normalized posterior score
def classify_from_string(document)
  words = document.split.uniq
  # Work on a copy so the cached baseline survives for the next call.
  likelihoods = @likelihoods.dup
  posterior = {}
  vocab_size = @frequency_table_size

  label_index.each do |label, index|
    # `each`, not `map`: this loop runs only for its effect on `likelihoods`.
    words.each do |word|
      row = frequency_table[word]
      next if row.nil?

      laplace_word_likelihood = (row[index] + 1.0).to_f / (label_instance_count[label] + vocab_size).to_f
      likelihoods[label] *= laplace_word_likelihood / (1.0 - laplace_word_likelihood)
    end

    posterior[label] = @priors[label] * likelihoods[label]
  end

  normalize(posterior)
end
finish_training!()
# File lib/bae/classifier.rb, line 15
# Derives the cached model parameters from the raw training counts.
# It caches the vocabulary size (number of frequency_table entries), then
# recomputes per-label likelihoods and per-label priors, in that order.
def finish_training!
  @frequency_table_size = @frequency_table.keys.length
  calculate_likelihoods!
  calculate_priors!
end
load_from_json(json)
# File lib/bae/classifier.rb, line 97
# Restores classifier state from a JSON string with the keys to_json writes,
# then recomputes derived parameters via finish_training!.
#
# @param json [String]
# @raise [RuntimeError] if a required top-level key is missing (or null/false)
# @raise [JSON::ParserError] if json is not valid JSON
def load_from_json(json)
  state = ::JSON.parse(json)

  fail 'Missing frequency_table' unless state['frequency_table']
  fail 'Missing label_instance_count' unless state['label_instance_count']
  fail 'Missing label_index' unless state['label_index']
  fail 'Missing label_index_sequence' unless state['label_index_sequence']
  fail 'Missing total_terms' unless state['total_terms']

  @frequency_table = state['frequency_table']
  # JSON.parse yields plain hashes. Re-attach the zero-on-miss default that
  # initialize installs, so `@label_instance_count[label] += 1` in
  # train_from_* still works for a label first seen after loading.
  @label_instance_count = ::Hash.new { |hash, label| hash[label] = 0 }
  @label_instance_count.merge!(state['label_instance_count'])
  @label_index = state['label_index']
  @label_index_sequence = state['label_index_sequence']
  # initialize uses a Float; keep it one so calculate_priors! never
  # performs integer division.
  @total_terms = state['total_terms'].to_f

  finish_training!
end
load_state(path)
# File lib/bae/classifier.rb, line 115
# Reads a JSON state file (such as one written by save_state) and loads it
# with load_from_json.
#
# @param path [String] file path; expanded with File.expand_path, so `~`
#   and relative paths are accepted
# @return whatever load_from_json returns
def load_state(path)
  load_from_json(::File.read(::File.expand_path(path)))
end
save_state(path)
# File lib/bae/classifier.rb, line 91
# Writes the classifier's JSON state (from to_json) to a file, replacing
# any existing contents.
#
# @param path [String] file path; expanded with File.expand_path
# @return [Integer] number of bytes written
def save_state(path)
  destination = ::File.expand_path(path)
  # The file is opened (and truncated) before to_json is evaluated.
  ::File.open(destination, 'w') { |file| file.write(to_json) }
end
to_json()
# File lib/bae/classifier.rb, line 120
# Serializes the raw training state (not the derived likelihoods/priors)
# as a JSON object. The keys match those that load_from_json requires.
#
# @return [String]
def to_json
  {
    'frequency_table' => frequency_table,
    'label_instance_count' => label_instance_count,
    'label_index' => label_index,
    'label_index_sequence' => label_index_sequence,
    'total_terms' => total_terms
  }.to_json
end
train(label, training_data)
# File lib/bae/classifier.rb, line 22
# Records one training document under `label`.
#
# @param label label the document belongs to (used as a hash key)
# @param training_data [String, Hash] raw text, or a word => frequency hash
# @raise [RuntimeError] when training_data is neither a String nor a Hash
def train(label, training_data)
  return train_from_string(label, training_data) if training_data.is_a?(::String)
  return train_from_hash(label, training_data) if training_data.is_a?(::Hash)

  fail 'Training data must either be a string or hash'
end
train_from_hash(label, frequency_hash)
# File lib/bae/classifier.rb, line 43
# Records one training document, given as a word => frequency hash, under `label`.
#
# @param label label the document belongs to (used as a hash key)
# @param frequency_hash [Hash] word => number of occurrences in the document
def train_from_hash(label, frequency_hash)
  # Index the label once, before touching any words. The check is
  # loop-invariant, and doing it up front means a label trained only on an
  # empty hash still gets an index. It is counted in label_instance_count
  # below, so it should also be classifiable.
  update_label_index(label)

  frequency_hash.each do |word, frequency|
    update_frequency_table(label, word, frequency)
  end

  @label_instance_count[label] += 1
  # Despite its name, total_terms goes up once per training document.
  @total_terms += 1
end
train_from_string(label, document)
# File lib/bae/classifier.rb, line 32
# Records one whitespace-separated training document under `label`.
# Every occurrence of a word adds 1 to that word's count for the label.
#
# @param label label the document belongs to (used as a hash key)
# @param document [String]
def train_from_string(label, document)
  # Index the label once, before touching any words. The check is
  # loop-invariant, and doing it up front means a label trained only on an
  # empty document still gets an index. It is counted in
  # label_instance_count below, so it should also be classifiable.
  update_label_index(label)

  document.split.each do |word|
    update_frequency_table(label, word, 1)
  end

  @label_instance_count[label] += 1
  # Despite its name, total_terms goes up once per training document.
  @total_terms += 1
end

Private Instance Methods

calculate_likelihoods!()
# File lib/bae/classifier.rb, line 132
# Caches, for every label, the product over all vocabulary rows of
#
#   1 - (row[index] + 1) / (label_instance_count[label] + vocabulary size)
#
# i.e. the complement of each Laplace-smoothed per-word estimate, multiplied
# together. classify_from_string starts from these products and multiplies
# in p / (1 - p) for each vocabulary word that appears in the document.
#
# NOTE(review): rows accumulate word occurrences while label_instance_count
# counts documents, so a smoothed estimate can exceed 1 (making a factor
# negative) for words repeated within documents — confirm intended.
#
# Reads @frequency_table_size, which finish_training! sets beforehand.
# @return [Hash] label => product (also stored in @likelihoods)
def calculate_likelihoods!
  vocabulary_size = @frequency_table_size

  @likelihoods = label_index.each_with_object({}) do |(label, column), likelihood_by_label|
    product = 1.0
    frequency_table.each_value do |counts|
      smoothed = (counts[column] + 1.0).to_f / (label_instance_count[label] + vocabulary_size).to_f
      product *= (1.0 - smoothed)
    end
    likelihood_by_label[label] = product
  end
end
calculate_priors!()
# File lib/bae/classifier.rb, line 147
# Computes each label's prior as its share of training documents:
# label_instance_count[label] / total_terms. Both counters are incremented
# once per train_from_* call.
#
# @return [Hash] label => prior (also stored in @priors)
def calculate_priors!
  @priors = label_instance_count.each_with_object({}) do |(label, count), priors|
    # fdiv always divides in floating point. Plain `/` would truncate to 0
    # if total_terms were an Integer (e.g. restored from hand-written JSON).
    priors[label] = count.fdiv(total_terms)
  end
end
get_next_sequence_value()
# File lib/bae/classifier.rb, line 154
# Advances the label index counter and returns the new value. On a fresh
# classifier the first call returns 0, because initialize starts it at -1.
def get_next_sequence_value
  @label_index_sequence = @label_index_sequence + 1
end
normalize(posterior)
# File lib/bae/classifier.rb, line 158
# Scales scores so they sum to 1.
#
# @param posterior [Hash] label => unnormalized score
# @return [Hash] a new hash with the same keys, each value divided by the
#   total (if every score is 0.0 the total is 0.0 and the values are NaN)
def normalize(posterior)
  # Sequential left-to-right addition, as in a plain fold.
  total = posterior.values.inject(0.0, :+)
  posterior.transform_values { |score| score / total }
end
update_frequency_table(label, word, frequency)
# File lib/bae/classifier.rb, line 178
# Adds `frequency` to `word`'s count in the column belonging to `label`.
#
# Each frequency_table row is an array indexed by label_index values. A word
# seen for the first time gets a zero-filled row with one slot per known
# label. The train_from_* methods call update_label_index before this, so
# `label` already has an index.
def update_frequency_table(label, word, frequency)
  row = frequency_table[word]
  index = label_index[label]

  unless row
    row = Array.new(label_index.size, 0)
    frequency_table[word] = row
  end

  row[index] += frequency
end
update_label_index(label)
# File lib/bae/classifier.rb, line 167
# Gives a label seen for the first time the next sequence index, and widens
# every existing frequency_table row with a 0 count at that index. Does
# nothing for a label that is already indexed.
def update_label_index(label)
  # key? uses the same hash/eql? lookup as label_index[label] elsewhere.
  # It is O(1), unlike scanning label_index.keys with ==. The == scan
  # wrongly treated e.g. 1.0 as already indexed when 1 was.
  return if label_index.key?(label)

  index = get_next_sequence_value
  label_index[label] = index

  frequency_table.each_value { |counts| counts[index] = 0 }
end