module Chainer::Utils::Variable

Public Class Methods

check_grad_type(func, x, gx)
# File lib/chainer/utils/variable.rb, line 4
# Verifies that a gradient is compatible with the data held by a variable.
#
# Nothing is checked when the variable has no data or no gradient was given.
# Otherwise the following properties are compared in order; the first one
# that fails raises:
#   1. +gx+ must be an instance of the parent class of +x.data+'s class;
#   2. +gx+ must be of exactly the same class as +x.data+;
#   3. +gx.shape+ must equal +x.data.shape+.
#
# @param func [Object] the function associated with the gradient; it is not
#   consulted by this implementation
# @param x [#data] variable whose +data+ the gradient is compared against
# @param gx [Object, nil] gradient to validate
# @return [nil]
# @raise [TypeError] on a parent-class, exact-class, or shape mismatch
def self.check_grad_type(func, x, gx)
  data = x.data
  return if data.nil? || gx.nil?

  data_class = data.class
  grad_class = gx.class

  # Coarse check against the common parent class first (presumably the array
  # base class, e.g. Numo::NArray -- confirm), then the exact concrete class,
  # which the error message treats as the dtype.
  if !gx.is_a?(data_class.superclass)
    raise TypeError, "Type of data and grad mismatch\n#{data_class} != #{grad_class}"
  elsif grad_class != data_class
    raise TypeError, "Dtype(Class) of data and grad mismatch\n#{data_class} != #{grad_class}"
  elsif gx.shape != data.shape
    raise TypeError, "Shape of data and grad mismatch\n#{data.shape} != #{gx.shape}"
  end
end