class TensorStream::OpMaker

Attributes

aliases[R]
check_types[R]
custom[R]
custom_post[R]
data_type_block[R]
data_type_coercion[R]
description[R]
exclude[R]
gradient[R]
infer_type_proc[R]
operation[R]
options[R]
parameters[R]
supports_broadcast[R]

Public Class Methods

define_operation(op_code, &block) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 47
# Registers a new op definition under the symbolized +op_code+.
#
# A TensorStream::OpMaker is created for the op, passed to the block for
# configuration, and then stored in the class-level @ops registry (created
# on first use). Returns the configured op maker.
def self.define_operation(op_code, &block)
  @ops ||= {}
  key = op_code.to_sym
  TensorStream::OpMaker.new(key).tap do |maker|
    block.call(maker)
    @ops[key] = maker
  end
end
each_op(&block) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 83
# Passes every registered op that is not flagged as excluded to the block,
# sorted by each op's +operation+ value.
#
# When no op has been registered yet (@ops is still nil) nothing is yielded
# instead of failing with NoMethodError on nil.
# Returns the array of ops that were iterated.
def self.each_op(&block)
  registry = @ops || {}
  registry.values.reject(&:exclude).sort_by(&:operation).each do |op|
    block.call(op)
  end
end
gradient_op(context_caller, node, grad) click to toggle source

Calls an operation's gradient definition.

# File lib/tensor_stream/op_maker.rb, line 55
# Runs the gradient block registered for +node.operation+.
#
# The block is evaluated with +context_caller+ as self and receives
# (grad, node, node.inputs); its result is returned.
#
# Raises RuntimeError when the op is not registered or has no gradient
# block. An uninitialized registry (no op defined yet) is treated the same
# way instead of surfacing as a NoMethodError on nil.
def self.gradient_op(context_caller, node, grad)
  op = (@ops || {})[node.operation]
  raise "No derivative op defined for #{node.operation}" if op.nil? || op.gradient.nil?

  context_caller.instance_exec(grad, node, node.inputs, &op.gradient)
end
infer_data_type(context_caller, tensor, passed_data_type) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 67
# Determines the data type of +tensor+.
#
# Resolution order:
# 1. +passed_data_type+ when it is truthy;
# 2. the op's data_type_block (set through define_data_type), evaluated with
#    +context_caller+ as self and the tensor as argument;
# 3. the data type of the first input, otherwise of the second input;
# 4. :unknown.
#
# An uninitialized registry (no op defined yet) is treated like an
# unregistered op instead of raising NoMethodError on nil.
def self.infer_data_type(context_caller, tensor, passed_data_type)
  return passed_data_type if passed_data_type

  op = (@ops || {})[tensor.operation]
  return context_caller.instance_exec(tensor, &op.data_type_block) if op && op.data_type_block

  source = tensor.inputs[0] || tensor.inputs[1]
  source ? source.data_type : :unknown
end
infer_shape(context_caller, tensor) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 61
# Infers the shape of +tensor+ using the shape proc of its registered op.
#
# The op's infer_type_proc is evaluated with +context_caller+ as self and
# the tensor as argument. Returns nil when the op is not registered; an
# uninitialized registry (no op defined yet) also yields nil instead of
# raising NoMethodError on nil.
def self.infer_shape(context_caller, tensor)
  op = (@ops || {})[tensor.operation]
  return nil unless op

  context_caller.instance_exec(tensor, &op.infer_type_proc)
end
new(op) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 8
# Creates an op definition for +op+ with empty parameter/option lists and
# every flag switched off.
#
# @check_types and @data_type_block are initialized explicitly so that all
# state read by the attribute readers and predicates is defined from the
# start, consistent with the other flags.
#
# The default shape proc returns:
# * nil when the tensor has no first input;
# * the first input's shape when there is exactly one input;
# * TensorStream::TensorShape.infer_shape of both input shapes when there
#   are exactly two inputs and both are present;
# * nil otherwise.
def initialize(op)
  @operation = op
  @parameters = []
  @options = {}
  @gradient = nil
  @data_type_block = nil
  @supports_broadcast = false
  @data_type_coercion = false
  @check_types = false
  @exclude = false
  @description = []
  @aliases = []
  @custom = []
  @custom_post = []
  @infer_type_proc = lambda { |tensor|
    first = tensor.inputs[0]
    second = tensor.inputs[1]
    next nil if first.nil?
    next first.shape.shape if tensor.inputs.size == 1
    next nil unless tensor.inputs.size == 2 && first && second

    TensorStream::TensorShape.infer_shape(first.shape.shape, second.shape.shape)
  }
end
scan() click to toggle source
# File lib/tensor_stream/op_maker.rb, line 40
# Loads every op definition file.
#
# Finds all *.rb files in the "ops" directory next to this file and passes
# each one to Kernel#load as "tensor_stream/ops/<basename>".
# Returns the array of discovered file paths.
def self.scan
  ops_dir = File.join(File.dirname(__FILE__), "ops")
  Dir[File.join(ops_dir, "*.rb")].each do |path|
    load File.join("tensor_stream", "ops", File.basename(path))
  end
end

Public Instance Methods

add_custom(custom_code) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 32
# Appends +custom_code+ to @custom; generate_body emits these lines before
# the _op call. Returns the @custom array.
def add_custom(custom_code)
  @custom.push(custom_code)
end
add_custom_post(custom_code) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 36
# Appends +custom_code+ to @custom_post; generate_body emits these lines
# after the _op call. Returns the @custom_post array.
def add_custom_post(custom_code)
  @custom_post.push(custom_code)
end
apply_data_type_coercion!() click to toggle source
# File lib/tensor_stream/op_maker.rb, line 168
# Turns on data type coercion for this op. When set, generate_body emits an
# apply_data_type_coercion call that reassigns the op's parameters.
def apply_data_type_coercion!
  @data_type_coercion = true
end
check_types?() click to toggle source
# File lib/tensor_stream/op_maker.rb, line 188
# Returns the @check_types flag, which is set by
# parameters_must_have_same_data_type!.
def check_types?
  @check_types
end
data_type_coercion?() click to toggle source
# File lib/tensor_stream/op_maker.rb, line 184
# Returns the @data_type_coercion flag, which is set by
# apply_data_type_coercion!.
def data_type_coercion?
  @data_type_coercion
end
default_with_nil(v) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 208
# Maps the :nil placeholder symbol to the source text 'nil'; any other value
# is returned unchanged.
def default_with_nil(v)
  return 'nil' if v == :nil

  v
end
define_data_type(&block) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 154
# Stores +block+ as this op's data type block. infer_data_type evaluates it
# with the tensor as argument when no data type is passed explicitly.
def define_data_type(&block)
  @data_type_block = block
end
define_gradient(&block) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 146
# Stores +block+ as this op's gradient definition. gradient_op evaluates it
# with (grad, node, node.inputs).
def define_gradient(&block)
  @gradient = block
end
define_shape(&block) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 150
# Replaces the default shape inference proc with +block+. infer_shape
# evaluates it with the tensor as argument.
def define_shape(&block)
  @infer_type_proc = block
end
description_lines() click to toggle source
# File lib/tensor_stream/op_maker.rb, line 101
# Returns the description as a flat list of single lines: every description
# entry is split on newlines and the pieces are concatenated in order.
def description_lines
  description.flat_map { |entry| entry.split("\n") }
end
exclude!() click to toggle source
# File lib/tensor_stream/op_maker.rb, line 97
# Flags this op as excluded; each_op skips ops whose exclude flag is set.
def exclude!
  @exclude = true
end
expand_options(print_defaults) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 192
# Renders the op's options as keyword-argument fragments.
#
# With +print_defaults+ truthy, an option that has a default value is
# rendered as "name: default" (the :nil placeholder becomes nil through
# default_with_nil); otherwise it is rendered as "name:".
#
# A default is considered present when it is not nil, so a literal +false+
# default is rendered as "name: false" rather than being dropped.
# Returns an array of strings in option declaration order.
def expand_options(print_defaults)
  @options.map do |name, spec|
    default = spec[:default_value]
    if print_defaults && !default.nil?
      "#{name}: #{default_with_nil(default)}"
    else
      "#{name}:"
    end
  end
end
expand_params(print_defaults) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 158
# Renders the op's parameters as positional-argument fragments.
#
# With +print_defaults+ truthy, a parameter that has a default value is
# rendered as "name = default" (the :nil placeholder becomes nil through
# default_with_nil); otherwise only the name is rendered.
#
# A default is considered present when it is not nil, so a literal +false+
# default is rendered as "name = false" rather than being dropped.
# Returns an array of new strings in parameter declaration order.
def expand_params(print_defaults)
  @parameters.map do |param|
    default = param[:default_value]
    if print_defaults && !default.nil?
      "#{param[:name]} = #{default_with_nil(default)}"
    else
      "#{param[:name]}"
    end
  end
end
generate_body() click to toggle source
# File lib/tensor_stream/op_maker.rb, line 105
# Builds the source text for the body of the generated op method.
#
# Lines are produced in this order:
# 1. a check_allowed_types call for each parameter declared with :validate;
# 2. an apply_data_type_coercion reassignment when coercion is enabled;
# 3. a check_data_types call when type checking is enabled;
# 4. the custom lines;
# 5. the _op call, assigned to +result+ when custom post lines exist;
# 6. the custom post lines.
# Every line is indented by six spaces and the lines are joined with "\n".
def generate_body
  arg_list = expand_params(false).join(', ')
  op_call = "_op(:#{operation}, #{(expand_params(false) + options_call).join(', ')})"

  lines = parameters.select { |p| p[:validate] }.map do |p|
    "check_allowed_types(#{p[:name]}, TensorStream::Ops::#{p[:validate]})"
  end
  lines << "#{arg_list} = apply_data_type_coercion(#{arg_list})" if data_type_coercion?
  lines << "check_data_types(#{arg_list})" if check_types?
  lines.concat(custom)
  lines << (custom_post.empty? ? op_call : "result = #{op_call}")
  lines.concat(custom_post)

  lines.map { |line| "      #{line}" }.join("\n")
end
option(name, description, default_value = nil, options = {}) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 142
# Declares a keyword option for the op, stored in @options under +name+
# (re-declaring a name overwrites the earlier entry).
#
# The stored entry holds the description, the default value and the extra
# +options+ hash; options_call reads :exclude and :alias from that hash.
# Returns the stored entry.
def option(name, description, default_value = nil, options = {})
  entry = { description: description, default_value: default_value, options: options }
  @options[name] = entry
end
options_call() click to toggle source
# File lib/tensor_stream/op_maker.rb, line 198
# Renders the keyword arguments forwarded to the _op call.
#
# Options whose extra options hash has a truthy :exclude are skipped. Each
# remaining option becomes "target: name", where target is the option's
# :alias when one is given and the option name otherwise.
def options_call
  forwarded = @options.reject { |_name, spec| spec.dig(:options, :exclude) }
  forwarded.map do |name, spec|
    target = spec.dig(:options, :alias) || name
    "#{target}: #{name}"
  end
end
other_names(aliases) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 28
# Adds +aliases+ to the op's alias list.
#
# A new array is assigned to @aliases, so the previously held array is not
# mutated. Returns the new alias list.
def other_names(aliases)
  @aliases = @aliases + aliases
end
parameter(name, description, default_value = nil, validate: nil) click to toggle source

adds a parameter to the op

# File lib/tensor_stream/op_maker.rb, line 133
# Declares a positional parameter for the op.
#
# The entry appended to @parameters holds the name (converted to a String),
# the description, the default value and the optional +validate+ value,
# which generate_body turns into a check_allowed_types call.
# Returns the @parameters array.
def parameter(name, description, default_value = nil, validate: nil)
  entry = {
    name: name.to_s,
    description: description,
    default_value: default_value,
    validate: validate
  }
  @parameters.push(entry)
end
parameters_must_have_same_data_type!() click to toggle source
# File lib/tensor_stream/op_maker.rb, line 164
# Turns on type checking for this op. When set, generate_body emits a
# check_data_types call over the op's parameters.
def parameters_must_have_same_data_type!
  @check_types = true
end
supports_broadcasting!() click to toggle source
# File lib/tensor_stream/op_maker.rb, line 172
# Marks the op as supporting broadcasting.
#
# Raises RuntimeError unless at least two parameters have been declared at
# the time of the call.
def supports_broadcasting!
  raise "Ops with parameters < 2 cannot support broadcasting" unless @parameters.size > 1

  @supports_broadcast = true
end
supports_broadcasting?() click to toggle source
# File lib/tensor_stream/op_maker.rb, line 180
# Returns the @supports_broadcast flag, which is set by
# supports_broadcasting!.
def supports_broadcasting?
  @supports_broadcast
end
what_it_does(description) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 89
# Appends +description+ to the op's description entries.
# Returns the @description array.
def what_it_does(description)
  @description.push(description)
end
what_it_does_code(description) click to toggle source
# File lib/tensor_stream/op_maker.rb, line 93
# Appends +description+ wrapped in <tt>...</tt> markup to the op's
# description entries. Returns the @description array.
def what_it_does_code(description)
  @description.push("<tt>#{description}</tt>")
end