class NSFW::Model

Constants

CATEGORIES
MODEL_PATH
SAFETY_THRESHOLD

Attributes

model[R]

Public Class Methods

new(lazy: false)
# File lib/nsfw/model.rb, line 11
# Sets up the wrapper. Unless +lazy+ is truthy, load_model! is invoked
# immediately; otherwise nothing is loaded here.
#
# @param lazy [Boolean] skip the load_model! call when truthy
def initialize(lazy: false)
  return if lazy

  load_model!
end

Public Instance Methods

loaded?()
# File lib/nsfw/model.rb, line 25
# Reports whether @model currently holds anything other than nil.
#
# @return [Boolean] false only while @model is nil
def loaded?
  # Deliberately a nil test rather than a truthiness test: a stored
  # `false` still counts as "loaded".
  @model.nil? ? false : true
end
predict(tensor)
# File lib/nsfw/model.rb, line 15
# Runs +tensor+ through make_prediction and returns that result after
# passing it through format_prediction.
#
# @param tensor input forwarded unchanged to make_prediction
# @return whatever format_prediction returns
def predict(tensor)
  format_prediction(make_prediction(tensor))
end
reshape_tensor(tensor)
# File lib/nsfw/model.rb, line 29
# Hands +tensor+ to OnnxRuntime::Utils.reshape with fixed target
# dimensions and returns the result.
#
# @param tensor data to reshape
# @return the value produced by OnnxRuntime::Utils.reshape
def reshape_tensor(tensor)
  # NOTE(review): presumably batch x height x width x channels for the
  # model input -- confirm against the model file.
  target_dims = [1, 224, 224, 3]
  OnnxRuntime::Utils.reshape(tensor, target_dims)
end
safe?(image)
# File lib/nsfw/model.rb, line 20
# Decides whether +image+ is safe by comparing the "neutral" entry of
# its prediction with SAFETY_THRESHOLD.
#
# @param image object responding to +tensor+
# @return [Boolean] true when the "neutral" value is at least
#   SAFETY_THRESHOLD
def safe?(image)
  neutral_score = predict(image.tensor)["neutral"]
  neutral_score >= SAFETY_THRESHOLD
end

Private Instance Methods

format_prediction(prediction)
# File lib/nsfw/model.rb, line 44
# Pairs each entry of CATEGORIES with the matching value from the first
# element of prediction["Identity"], and returns the pairs as a Hash
# ordered from the largest value to the smallest.
#
# @param prediction object whose +fetch("Identity")+ yields a collection;
#   its first element supplies the values (with a Hash, a missing
#   "Identity" key raises KeyError)
# @return [Hash] category => value, highest value first
def format_prediction(prediction)
  scores = prediction.fetch("Identity").first
  labelled = CATEGORIES.zip(scores)
  # Positive result puts +right+ before +left+, i.e. descending order.
  ranked = labelled.sort do |left, right|
    right.last - left.last
  end
  ranked.to_h
end
load_model!()
# File lib/nsfw/model.rb, line 35
# Returns @model, first creating it via OnnxRuntime::Model.new(MODEL_PATH)
# when @model is not yet set to a truthy value.
#
# @return the stored model object
def load_model!
  return @model if @model

  @model = OnnxRuntime::Model.new(MODEL_PATH)
end
make_prediction(tensor)
# File lib/nsfw/model.rb, line 39
# Ensures a model is present (calling load_model! when loaded? is false),
# then asks @model to predict on the reshaped tensor.
#
# @param tensor input passed through reshape_tensor before prediction
# @return whatever @model.predict returns
def make_prediction(tensor)
  load_model! unless loaded?

  # Built as an explicit Hash so it is passed as one positional argument.
  feed = { "input" => reshape_tensor(tensor) }
  @model.predict(feed)
end