class TensorStream::InferShape
Convenience class for guessing the shape of a tensor
Public Class Methods
_infer_reduction_op_shape(tensor)
click to toggle source
# File lib/tensor_stream/helpers/infer_shape.rb, line 186 def self._infer_reduction_op_shape(tensor) return [] if tensor.inputs[1].nil? return nil if tensor.inputs[0].nil? return nil unless tensor.inputs[0].shape.known? input_shape = tensor.inputs[0].shape.shape rank = input_shape.size axis = tensor.inputs[1].const_value return nil if axis.nil? axis = [axis] unless axis.is_a?(Array) axis = axis.map { |a| a < 0 ? rank - a.abs : a } input_shape.each_with_index.map { |item, index| if axis.include?(index) next 1 if tensor.options[:keepdims] next nil end item }.compact end
infer_shape(tensor)
click to toggle source
# File lib/tensor_stream/helpers/infer_shape.rb, line 10 def self.infer_shape(tensor) case tensor.operation when :assign tensor.inputs[0]&.shape&.shape when :const shape_eval(tensor.options[:value]) when :variable_v2 tensor.shape ? tensor.shape.shape : nil when :placeholder return nil if tensor.inputs[0].nil? return tensor.inputs[0].shape.shape if tensor.inputs.size == 1 TensorShape.infer_shape(tensor.inputs[0].shape.shape, tensor.inputs[1].shape.shape) if tensor.inputs.size == 2 && tensor.inputs[0] && tensor.inputs[1] when :case, :case_grad tensor.inputs[2]&.shape&.shape when :const shape_eval(tensor.options[:value]) when :variable_v2 tensor.shape ? tensor.shape.shape : nil when :assign possible_shape = if tensor.inputs[0]&.shape&.shape tensor.inputs[0].shape.shape else tensor.inputs[1].shape.shape end possible_shape when :index return nil unless tensor.inputs[0].is_a?(Tensor) return nil unless tensor.inputs[0].const_value input_shape = tensor.inputs[0].shape return nil unless input_shape.known? s = input_shape.shape.dup s.shift s when :arg_min, :argmax, :argmin return nil unless tensor.inputs[0].shape.known? return nil if tensor.inputs[1] && tensor.inputs[1].const_value.nil? axis = tensor.inputs[1].nil? ? 0 : tensor.inputs[1].const_value new_shape = tensor.inputs[0].shape.shape new_shape.each_with_index.collect { |shape, index| next nil if index == axis shape }.compact when :mean, :prod, :sum, :arg_max return [] if tensor.inputs[1].nil? return nil if tensor.inputs[0].nil? return nil unless tensor.inputs[0].shape.known? input_shape = tensor.inputs[0].shape.shape rank = input_shape.size axis = tensor.inputs[1].const_value return nil if axis.nil? axis = [axis] unless axis.is_a?(Array) axis = axis.map { |a| a < 0 ? rank - a.abs : a } input_shape.each_with_index.map { |item, index| if axis.include?(index) next 1 if tensor.options[:keepdims] next nil end item }.compact when :flow_group [] when :zeros, :ones, :fill, :random_standard_normal, :random_uniform, :truncated_normal a_shape = tensor.inputs[0] ? tensor.inputs[0].const_value : tensor.options[:shape] return nil if a_shape.nil? a_shape.is_a?(Array) ? a_shape : [a_shape] when :zeros_like, :ones_like tensor.inputs[0].shape.shape when :shape tensor.inputs[0].shape.shape ? [tensor.inputs[0].shape.shape.size] : nil when :pad return nil unless tensor.inputs[0].shape.known? return nil unless tensor.inputs[1].const_value size = tensor.inputs[0].shape.shape.reduce(:*) || 1 dummy_tensor_for_shape = TensorShape.reshape(Array.new(size), tensor.inputs[0].shape) shape_eval(arr_pad(dummy_tensor_for_shape, tensor.inputs[1].const_value)) when :transpose return nil unless shape_full_specified(tensor.inputs[0]) return nil if tensor.inputs[1].is_a?(Tensor) rank = tensor.inputs[0].shape.shape.size perm = tensor.inputs[1] || (0...rank).to_a.reverse perm.map { |p| tensor.inputs[0].shape.shape[p] } when :stack return nil unless shape_full_specified(tensor.inputs[0]) axis = tensor.options[:axis] || 0 new_shape = [tensor.inputs.size] tensor.inputs[0].shape.shape.inject(new_shape) { |ns, i| ns << i } rank = tensor.inputs[0].shape.shape.size + 1 axis = rank + axis if axis < 0 rotated_shape = Array.new(axis + 1) { new_shape.shift } rotated_shape.rotate! + new_shape when :concat return nil if tensor.inputs[0].const_value.nil? axis = tensor.inputs[0].const_value # get axis axis_size = 0 tensor.inputs[1..tensor.inputs.size].each do |input_item| return nil if input_item.shape.shape.nil? return nil if input_item.shape.shape[axis].nil? axis_size += input_item.shape.shape[axis] end new_shape = tensor.inputs[1].shape.shape.dup new_shape[axis] = axis_size new_shape when :slice, :squeeze nil when :broadcast_gradient_args nil when :no_op nil when :softmax_cross_entropy_with_logits_v2, :sparse_softmax_cross_entropy_with_logits nil when :decode_png, :flow_dynamic_stitch, :dynamic_stitch, :gather nil when :eye return [tensor.inputs[0].const_value, tensor.inputs[1].const_value] if tensor.inputs[0].const_value && tensor.inputs[1].const_value nil when :unstack return nil unless tensor.inputs[0].shape.known? new_shape = tensor.inputs[0].shape.shape.dup rank = new_shape.size - 1 axis = tensor.options[:axis] || 0 axis = rank + axis if axis < 0 rotated_shape = Array.new(axis + 1) { new_shape.shift } rotated_shape.rotate!(-1) + new_shape when :conv2d return nil unless tensor.inputs[0].shape.known? return nil unless tensor.inputs[1].shape.known? new_shape = tensor.inputs[0].shape.shape.dup new_shape[3] = tensor.inputs[1].shape.shape[3] # account for stride and padding options strides = tensor.options[:strides] case tensor.options[:padding] when "SAME" new_shape[1] /= strides[1] new_shape[2] /= strides[2] when "VALID" new_shape[1] = (new_shape[1] - tensor.inputs[1].shape.shape[0]) / strides[1] + 1 new_shape[2] = (new_shape[2] - tensor.inputs[1].shape.shape[1]) / strides[2] + 1 else raise TensorStream::ValueError, "Invalid padding option only 'SAME', 'VALID' accepted" end new_shape when :conv2d_backprop_input return nil unless tensor.inputs[0].const_value tensor.inputs[0].const_value else TensorStream::OpMaker.infer_shape(self, tensor) end end