class Secryst::ClipGradNorm
Public Class Methods
clip_grad_norm(parameters, max_norm:, norm_type:2)
click to toggle source
# File lib/secryst/clip_grad_norm.rb, line 5
#
# Rescales the gradients of +parameters+ so that their combined norm does not
# exceed +max_norm+. Parameters whose +grad+ is nil/false are ignored.
#
# The combined norm is the +norm_type+-norm of the per-parameter norms, where
# each per-parameter norm is taken over the flattened gradient (element-wise
# vector norm). With +norm_type+ equal to Float::INFINITY it is the largest
# absolute gradient element across all parameters.
#
# @param parameters [Array] objects exposing +grad+ / +grad=+ (Torch tensors)
# @param max_norm [Numeric] upper bound for the combined norm (coerced with +to_f+)
# @param norm_type [Numeric] order of the norm; Float::INFINITY selects the max norm
# @return [Torch::Tensor, Numeric] +Torch.tensor(0.0)+ when no parameter has a
#   gradient; otherwise the combined norm measured before any rescaling
def self.clip_grad_norm(parameters, max_norm:, norm_type: 2)
  with_grad = parameters.select { |p| p.grad }
  max_norm = max_norm.to_f
  return Torch.tensor(0.0) if with_grad.empty?

  # Flatten first: Numo::Linalg.norm with an explicit order treats 2-D input as
  # a matrix (operator norm) and rejects higher ranks, whereas clipping needs
  # the norm over all elements of the gradient.
  flat_grads = with_grad.map { |p| p.grad.detach.numo.flatten }

  total_norm =
    if norm_type == Float::INFINITY
      flat_grads.map { |g| g.abs.max }.max
    else
      per_param = flat_grads.map { |g| Numo::Linalg.norm(g, norm_type) }
      Numo::Linalg.norm(Numo::NArray.concatenate(per_param), norm_type)
    end

  # The small epsilon keeps the division finite when every gradient is zero.
  clip_coef = max_norm / (total_norm + 1e-6)
  with_grad.each { |p| p.grad = p.grad.detach * clip_coef } if clip_coef < 1

  total_norm
end