class TensorStream::OpMaker
Attributes
aliases[R]
check_types[R]
custom[R]
custom_post[R]
data_type_block[R]
data_type_coercion[R]
description[R]
exclude[R]
gradient[R]
infer_type_proc[R]
operation[R]
options[R]
parameters[R]
supports_broadcast[R]
Public Class Methods
define_operation(op_code, &block)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 47
# Registers a new operation definition in the class-level registry.
#
# A TensorStream::OpMaker is created for the symbolized +op_code+, handed to
# the given block for configuration, and then stored under that symbol.
#
# @param op_code [#to_sym] name of the operation being defined
# @yieldparam definition [TensorStream::OpMaker] the object to configure
# @return [TensorStream::OpMaker] the definition that was stored
def self.define_operation(op_code, &block)
  @ops ||= {}
  op_key = op_code.to_sym
  definition = TensorStream::OpMaker.new(op_key)
  block.call(definition)
  @ops[op_key] = definition
end
each_op(&block)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 83
# Yields every registered, non-excluded operation definition, ordered by its
# +operation+ name.
#
# The registry is created lazily by define_operation, so it may still be nil
# here; in that case nothing is yielded instead of failing on nil.
#
# @yieldparam op [TensorStream::OpMaker] a registered definition
# @return [Array] the definitions that were iterated
def self.each_op(&block)
  registry = @ops || {}
  registry.values.sort_by { |op| op.operation }.reject(&:exclude).each do |op|
    block.call(op)
  end
end
gradient_op(context_caller, node, grad)
click to toggle source
Calls an operation's gradient definition.
# File lib/tensor_stream/op_maker.rb, line 55
# Runs the gradient block registered for +node.operation+.
#
# The block is evaluated with +context_caller+ as self and receives
# (grad, node, node.inputs).
#
# @param context_caller [Object] object on which the gradient block is instance_exec'd
# @param node [#operation, #inputs] node whose gradient is requested
# @param grad [Object] incoming gradient passed through to the block
# @return [Object] whatever the gradient block returns
# @raise [RuntimeError] when the operation is not registered or has no gradient
def self.gradient_op(context_caller, node, grad)
  # The registry is created lazily; treat "no registry yet" like "op not found"
  # so the caller gets the descriptive error rather than a NoMethodError on nil.
  op_def = (@ops || {})[node.operation]
  raise "No derivative op defined for #{node.operation}" if op_def.nil? || op_def.gradient.nil?

  context_caller.instance_exec(grad, node, node.inputs, &op_def.gradient)
end
infer_data_type(context_caller, tensor, passed_data_type)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 67
# Determines the data type of +tensor+.
#
# Resolution order:
# 1. +passed_data_type+, when truthy;
# 2. the data-type block registered for the tensor's operation, evaluated
#    with +context_caller+ as self and the tensor as argument;
# 3. the data type of the first available input (index 0, then index 1);
# 4. +:unknown+.
#
# @param context_caller [Object] object on which the data-type block is instance_exec'd
# @param tensor [#operation, #inputs] tensor being inspected
# @param passed_data_type [Object, nil] explicit data type, if any
# @return [Object] the resolved data type, or +:unknown+
def self.infer_data_type(context_caller, tensor, passed_data_type)
  return passed_data_type if passed_data_type

  # The registry is created lazily; a missing registry simply means no
  # custom data-type block exists, so fall through to the input-based rule.
  op_def = (@ops || {})[tensor.operation]
  if op_def && op_def.data_type_block
    return context_caller.instance_exec(tensor, &op_def.data_type_block)
  end

  source = tensor.inputs[0] || tensor.inputs[1]
  source ? source.data_type : :unknown
end
infer_shape(context_caller, tensor)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 61
# Infers the shape of +tensor+ using the shape proc registered for its
# operation.
#
# @param context_caller [Object] object on which the shape proc is instance_exec'd
# @param tensor [#operation] tensor whose shape is requested
# @return [Object, nil] result of the shape proc, or nil when the operation
#   is not registered (including when nothing has been registered at all)
def self.infer_shape(context_caller, tensor)
  # The registry is created lazily, so it can still be nil at this point.
  op_def = (@ops || {})[tensor.operation]
  return nil unless op_def

  context_caller.instance_exec(tensor, &op_def.infer_type_proc)
end
new(op)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 8
# Creates an empty definition for the operation +op+.
#
# All flags start as false, all collections start empty, and a default
# shape-inference lambda is installed. The lambda returns:
# * nil when the tensor has no first input;
# * the first input's shape when there is exactly one input;
# * TensorStream::TensorShape.infer_shape of both input shapes when there are
#   exactly two non-nil inputs;
# * nil otherwise.
#
# @param op [Object] operation identifier stored as +operation+
def initialize(op)
  @operation = op
  @parameters = []
  @options = {}
  @gradient = nil
  @supports_broadcast = false
  @data_type_coercion = false
  # These two are exposed through readers like the flags above, so they are
  # given explicit defaults instead of being left as uninitialised ivars.
  @check_types = false
  @data_type_block = nil
  @exclude = false
  @description = []
  @aliases = []
  @custom = []
  @custom_post = []
  @infer_type_proc = lambda { |tensor|
    next nil if tensor.inputs[0].nil?
    next tensor.inputs[0].shape.shape if tensor.inputs.size == 1

    TensorStream::TensorShape.infer_shape(tensor.inputs[0].shape.shape, tensor.inputs[1].shape.shape) if tensor.inputs.size == 2 && tensor.inputs[0] && tensor.inputs[1]
  }
end
scan()
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 40
# Loads every Ruby file found in the "ops" directory next to this file.
#
# Each file is loaded by its "tensor_stream/ops/<basename>" path rather than
# by the absolute path that was globbed.
#
# @return [Array<String>] the globbed file paths
def self.scan
  ops_glob = File.join(File.dirname(__FILE__), "ops", "*.rb")
  Dir[ops_glob].each do |path|
    load File.join("tensor_stream", "ops", File.basename(path))
  end
end
Public Instance Methods
add_custom(custom_code)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 32
# Appends a line of custom code; generate_body emits these lines before the
# +_op+ call.
#
# @param custom_code [String] line to append
# @return [Array] the accumulated custom lines
def add_custom(custom_code)
  @custom.push(custom_code)
end
add_custom_post(custom_code)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 36
# Appends a line of custom code; generate_body emits these lines after the
# +_op+ call, whose value is then assigned to +result+.
#
# @param custom_code [String] line to append
# @return [Array] the accumulated post-call custom lines
def add_custom_post(custom_code)
  @custom_post.push(custom_code)
end
apply_data_type_coercion!()
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 168 def apply_data_type_coercion! @data_type_coercion = true end
check_types?()
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 188 def check_types? @check_types end
data_type_coercion?()
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 184 def data_type_coercion? @data_type_coercion end
default_with_nil(v)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 208
# Converts the +:nil+ placeholder symbol into the string 'nil'; every other
# value is returned unchanged.
#
# @param v [Object] a default value
# @return [Object] 'nil' for +:nil+, otherwise +v+
def default_with_nil(v)
  return 'nil' if v == :nil

  v
end
define_data_type(&block)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 154 def define_data_type(&block) @data_type_block = block end
define_gradient(&block)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 146 def define_gradient(&block) @gradient = block end
define_shape(&block)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 150 def define_shape(&block) @infer_type_proc = block end
description_lines()
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 101
# Returns the description broken into individual lines: every description
# entry is split on newlines and the pieces are gathered into one flat list.
#
# @return [Array<String>] description lines in order
def description_lines
  description.flat_map { |entry| entry.split("\n") }
end
exclude!()
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 97 def exclude! @exclude = true end
expand_options(print_defaults)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 192
# Renders the declared options as keyword-argument fragments.
#
# An option is rendered as "name: default" when +print_defaults+ is truthy
# and the option has a default value; otherwise it is rendered as "name:".
# Only a nil default counts as "no default", so a literal +false+ default is
# rendered as "name: false" rather than being silently dropped.
#
# @param print_defaults [Boolean] whether to include default values
# @return [Array<String>] one fragment per option, in declaration order
def expand_options(print_defaults)
  @options.map do |key, meta|
    default = meta[:default_value]
    if print_defaults && !default.nil?
      "#{key}: #{default_with_nil(default)}"
    else
      "#{key}:"
    end
  end
end
expand_params(print_defaults)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 158
# Renders the declared parameters as positional-argument fragments.
#
# A parameter is rendered as "name = default" when +print_defaults+ is truthy
# and the parameter has a default value; otherwise just its name is used.
# Only a nil default counts as "no default", so a literal +false+ default is
# rendered as "name = false" rather than being silently dropped.
#
# @param print_defaults [Boolean] whether to include default values
# @return [Array<String>] one fragment per parameter, in declaration order
def expand_params(print_defaults)
  @parameters.map do |param|
    default = param[:default_value]
    if print_defaults && !default.nil?
      "#{param[:name]} = #{default_with_nil(default)}"
    else
      "#{param[:name]}"
    end
  end
end
generate_body()
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 105
# Builds the source text for the body of the generated op method.
#
# Lines are emitted in this order:
# 1. a +check_allowed_types+ call for every parameter that has +:validate+;
# 2. an +apply_data_type_coercion+ assignment, when coercion is enabled;
# 3. a +check_data_types+ call, when type checking is enabled;
# 4. the custom lines;
# 5. the +_op+ call (assigned to +result+ when post-call custom lines exist);
# 6. the post-call custom lines.
#
# @return [String] the lines, each prefixed with the indent and joined by newlines
def generate_body
  lines = parameters.select { |p| p[:validate] }.map do |p|
    "check_allowed_types(#{p[:name]}, TensorStream::Ops::#{p[:validate]})"
  end

  if data_type_coercion?
    names = expand_params(false).join(', ')
    lines << "#{names} = apply_data_type_coercion(#{names})"
  end
  lines << "check_data_types(#{expand_params(false).join(', ')})" if check_types?
  lines.concat(custom)

  op_call = "_op(:#{operation}, #{(expand_params(false) + options_call).join(', ')})"
  lines << (custom_post.empty? ? op_call : "result = #{op_call}")
  lines.concat(custom_post)

  # NOTE(review): the indent prefix is reproduced exactly as it appears in the
  # extracted listing; confirm its width against the original file.
  lines.map { |line| " #{line}" }.join("\n")
end
option(name, description, default_value = nil, options = {})
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 142
# Declares a keyword option for the op, replacing any earlier declaration
# with the same name.
#
# @param name [Object] option name, used as the registry key
# @param description [String] human-readable description
# @param default_value [Object, nil] default value (nil means none)
# @param options [Hash] extra settings; options_call reads +:exclude+ and +:alias+
# @return [Hash] the stored option entry
def option(name, description, default_value = nil, options = {})
  entry = {
    description: description,
    default_value: default_value,
    options: options
  }
  @options[name] = entry
end
options_call()
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 198
# Renders the options as "keyword: variable" fragments for the generated
# +_op+ call.
#
# Options flagged with +:exclude+ are left out. When an option declares an
# +:alias+, the alias is used as the keyword while the option name remains
# the variable being passed.
#
# @return [Array<String>] one fragment per non-excluded option
def options_call
  included = @options.reject { |_key, meta| meta.dig(:options, :exclude) }
  included.map do |key, meta|
    keyword = meta.dig(:options, :alias) || key
    "#{keyword}: #{key}"
  end
end
other_names(aliases)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 28
# Adds alternative names for the op. A new array is assigned each time; the
# previously stored array is left unmodified.
#
# @param aliases [Array] names to add
# @return [Array] the combined list of aliases
def other_names(aliases)
  @aliases = @aliases + aliases
end
parameter(name, description, default_value = nil, validate: nil)
click to toggle source
Adds a parameter to the op.
# File lib/tensor_stream/op_maker.rb, line 133
# Declares a positional parameter for the op.
#
# @param name [#to_s] parameter name, stored as a String
# @param description [String] human-readable description
# @param default_value [Object, nil] default value (nil means none)
# @param validate [Object, nil] when set, generate_body emits a
#   +check_allowed_types+ call referencing TensorStream::Ops::<validate>
# @return [Array<Hash>] all parameters declared so far
def parameter(name, description, default_value = nil, validate: nil)
  entry = {
    name: name.to_s,
    description: description,
    default_value: default_value,
    validate: validate
  }
  @parameters.push(entry)
end
parameters_must_have_same_data_type!()
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 164 def parameters_must_have_same_data_type! @check_types = true end
supports_broadcasting!()
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 172
# Flags the op as supporting broadcasting. At least two parameters must
# already have been declared.
#
# @return [true]
# @raise [RuntimeError] when fewer than two parameters are declared
def supports_broadcasting!
  raise "Ops with parameters < 2 cannot support broadcasting" unless @parameters.size > 1

  @supports_broadcast = true
end
supports_broadcasting?()
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 180 def supports_broadcasting? @supports_broadcast end
what_it_does(description)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 89
# Appends a piece of descriptive text to the op's description.
#
# @param description [String] text to append
# @return [Array] the accumulated description entries
def what_it_does(description)
  @description.push(description)
end
what_it_does_code(description)
click to toggle source
# File lib/tensor_stream/op_maker.rb, line 93
# Appends a description entry wrapped in <tt> tags.
#
# @param description [String] text to wrap and append
# @return [Array] the accumulated description entries
def what_it_does_code(description)
  wrapped = "<tt>#{description}</tt>"
  @description.push(wrapped)
end