class Tensorflow::Keras::Metrics::Mean
Public Class Methods
new(name: nil, dtype: :float)
click to toggle source
# File lib/tensorflow/keras/metrics/mean.rb, line 5
# Sets up a running-mean metric.
#
# Two weights are created through Utils.add_weight, both zero-initialized and
# typed with the metric dtype: "total" (running sum) and "count" (running
# number of elements).
#
# @param name [Object] accepted for API compatibility; not used by the code
#   shown here
# @param dtype [Symbol] dtype for both weights (default :float)
def initialize(name: nil, dtype: :float)
  @dtype = dtype
  # Creation order (total first, then count) matches the original.
  @total, @count = %w[total count].map do |weight_name|
    Utils.add_weight(name: weight_name, initializer: "zeros", dtype: @dtype)
  end
end
Public Instance Methods
call(*args)
click to toggle source
# File lib/tensorflow/keras/metrics/mean.rb, line 11
# Shorthand for #update_state: passes every positional argument straight
# through and returns its result unchanged.
def call(*inputs)
  update_state(*inputs)
end
reset_states()
click to toggle source
# File lib/tensorflow/keras/metrics/mean.rb, line 25
# Clears the accumulated state so the metric starts a fresh mean.
#
# Both accumulators (@total and @count, created in #initialize with the metric
# dtype) are set back to zero of that dtype. Without this the method was a
# no-op, and values from earlier epochs leaked into later results.
#
# NOTE(review): assumes the weight objects returned by Utils.add_weight
# support `assign` (they are known to support `assign_add`) — confirm.
def reset_states
  zero = Tensorflow.cast(0, @dtype)
  @total.assign(zero)
  @count.assign(zero)
end
result()
click to toggle source
# File lib/tensorflow/keras/metrics/mean.rb, line 21
# Returns the running mean: accumulated total divided by accumulated count.
#
# div_no_nan (TensorFlow DivNoNan) returns 0 when the denominator is 0, so a
# metric that has seen no values reports 0 instead of NaN.
#
# The count is cast to the metric's own dtype rather than a hard-coded :float.
# @total is created with @dtype, and both div_no_nan operands must share one
# dtype. The old :float cast broke metrics built with e.g. dtype: :double,
# while the default :float case behaves exactly as before.
def result
  RawOps.div_no_nan(@total, Tensorflow.cast(@count, @dtype))
end
update_state(values)
click to toggle source
# File lib/tensorflow/keras/metrics/mean.rb, line 15
# Folds a batch of values into the running mean.
#
# The values are cast to the metric dtype. Their sum is added to @total and
# their element count (RawOps.size) to @count.
#
# @param values [Object] the values to average (presumably a tensor or
#   array-like accepted by Tensorflow.cast — TODO confirm)
def update_state(values)
  # Cast the caller's `values`. The previous code cast the local `input`,
  # which is still nil at that point, so every update was lost. The dtype is
  # now positional, like the other Tensorflow.cast calls in this class.
  input = Tensorflow.cast(values, @dtype)
  @total.assign_add(Math.reduce_sum(input))
  @count.assign_add(Tensorflow.cast(RawOps.size(input), @dtype))
end