class AprendizajeMaquina::DecisionTree
Public Class Methods
new(dataset)
click to toggle source
# File lib/aprendizaje_maquina/decision_tree.rb, line 3 def initialize(dataset) @dataset = dataset end
Public Instance Methods
display_tree()
click to toggle source
# File lib/aprendizaje_maquina/decision_tree.rb, line 7 def display_tree node_root = build_tree(@dataset) colection = [node_root] branches = [] tree = "root --> #{node_root[1][0]}:#{node_root[1][1]}?\n" for node in 0...node_root[2].length branches << build_tree(node_root[2][node]) colection << branches 1000.times do subbranches = [] true_or_false = lambda { |node| node == 0 ? true : false } branches.each do |branch| if branch.is_a?(Array) tree << "#{true_or_false.call(node)} --> "+"#{branch[1][0]}:#{branch[1][1]}?\n" for node in 0...branch[2].length if build_tree(branch[2][node]).is_a? Hash tree << "#{true_or_false.call(node)} --> "+"#{build_tree(branch[2][node])}\n" else subbranches << build_tree(branch[2][node]) end end elsif branch.is_a?(Hash) tree << "#{true_or_false.call(node)} --> "+"#{branch}\n" end end branches = subbranches colection << branches if colection.last.empty? colection.pop break end end end return tree end
predict(observation)
click to toggle source
# File lib/aprendizaje_maquina/decision_tree.rb, line 43 def predict(observation) node_root = build_tree(@dataset) until node_root.is_a?(Hash) if observation[node_root[1][0]].is_a?(Integer) or observation[node_root[1][0]].is_a?(Float) if observation[node_root[1][0]] >= node_root[1][1] branch = build_tree(node_root[2][0]) else branch = build_tree(node_root[2][1]) end else if observation[node_root[1][0]] == node_root[1][1] branch = build_tree(node_root[2][0]) else branch = build_tree(node_root[2][1]) end end node_root = branch end return node_root end
Private Instance Methods
build_tree(dataset)
click to toggle source
# File lib/aprendizaje_maquina/decision_tree.rb, line 106 def build_tree(dataset) best_info_gain = 0.0 column_and_value_attribute = nil best_sets = nil for column_attribute in 0...dataset[0].length-1 # elimina la etiqueta for row in dataset value_attribute = row[column_attribute] node_true, node_false = split_dataset(dataset,column_attribute,value_attribute) information_gain = entropy(dataset) - (node_true.length.to_f/dataset.length) * entropy(node_true) - (node_false.length.to_f/dataset.length) * entropy(node_false) if information_gain > best_info_gain # pick the highest information_gain best_info_gain = information_gain column_and_value_attribute = column_attribute, value_attribute best_sets = node_true, node_false end end end if best_info_gain > 0 return best_info_gain, column_and_value_attribute, best_sets else return count_classes(dataset) end end
count_classes(dataset)
click to toggle source
# File lib/aprendizaje_maquina/decision_tree.rb, line 84 def count_classes(dataset) hash_count = {} dataset.each do |row| if hash_count.include?(row[-1]) hash_count[row[-1]] += 1 else hash_count[row[-1]] = 1 end end return hash_count end
entropy(dataset)
click to toggle source
# File lib/aprendizaje_maquina/decision_tree.rb, line 96 def entropy(dataset) classes_count = count_classes(dataset) ent = 0.0 classes_count.each_value do |value| prob = value.to_f / dataset.length ent -= prob * Math.log2(prob) end return ent end
split_dataset(dataset, column, value)
click to toggle source
# File lib/aprendizaje_maquina/decision_tree.rb, line 66 def split_dataset(dataset, column, value) if value.is_a? Integer or value.is_a? Float split_function = lambda { |row| row[column] >= value } else split_function = lambda { |row| row[column] == value } end set1 = [] set2 = [] for row in dataset if split_function.call(row) set1 << row else set2 << row end end return set1,set2 end