class NSFW::Model
Constants
- CATEGORIES
- MODEL_PATH
- SAFETY_THRESHOLD
Attributes
model [R] — read-only attribute; nil until the ONNX model has been loaded
Public Class Methods
new(lazy: false)
click to toggle source
# File lib/nsfw/model.rb, line 11
#
# Builds a classifier instance.
#
# @param lazy [Boolean] when falsy (the default) the ONNX model is loaded
#   right away; when truthy, loading is skipped here and happens on the
#   first prediction instead
def initialize(lazy: false)
  return if lazy

  load_model!
end
Public Instance Methods
loaded?()
click to toggle source
# File lib/nsfw/model.rb, line 25
#
# Reports whether the ONNX model has been loaded yet.
#
# @return [Boolean] false while @model is nil, true otherwise
def loaded?
  return false if @model.nil?

  true
end
predict(tensor)
click to toggle source
# File lib/nsfw/model.rb, line 15
#
# Runs the model on +tensor+ and returns the per-category scores.
#
# @param tensor input data; it is reshaped to [1, 224, 224, 3] before
#   inference
# @return [Hash] category => score, ordered from highest to lowest score
def predict(tensor)
  format_prediction(make_prediction(tensor))
end
reshape_tensor(tensor)
click to toggle source
# File lib/nsfw/model.rb, line 29
#
# Reshapes +tensor+ into the single-item batch layout fed to the model.
#
# NOTE(review): the dimensions look like NHWC (batch, height, width,
# channels) for one 224x224 three-channel image — confirm against the
# model export before relying on that reading.
#
# @param tensor numeric data accepted by OnnxRuntime::Utils.reshape
#   (presumably 1 * 224 * 224 * 3 values — TODO confirm)
# @return the data reshaped to [1, 224, 224, 3]
def reshape_tensor(tensor)
  batch_shape = [1, 224, 224, 3]
  OnnxRuntime::Utils.reshape(tensor, batch_shape)
end
safe?(image)
click to toggle source
# File lib/nsfw/model.rb, line 20
#
# Decides whether +image+ is safe based solely on its "neutral" score.
#
# @param image any object responding to #tensor
# @return [Boolean] true when the "neutral" score is at least
#   SAFETY_THRESHOLD
def safe?(image)
  scores = predict(image.tensor)
  scores["neutral"] >= SAFETY_THRESHOLD
end
Private Instance Methods
format_prediction(prediction)
click to toggle source
# File lib/nsfw/model.rb, line 44
#
# Turns raw model output into a Hash of category => score, ordered from the
# highest score to the lowest.
#
# @param prediction [Hash] raw model output; the scores are read from the
#   first row of its "Identity" entry
# @return [Hash] each entry of CATEGORIES paired with its score, sorted by
#   descending score
# @raise [KeyError] if +prediction+ has no "Identity" entry
# @raise [ArgumentError] if the number of scores differs from the number of
#   CATEGORIES
def format_prediction(prediction)
  scores = prediction.fetch("Identity").first

  # Array#zip pads missing scores with nil (which then blows up inside the
  # sort) and silently drops surplus ones, mislabelling the result. Insist on
  # a one-to-one category/score mapping instead.
  unless scores.size == CATEGORIES.size
    raise ArgumentError,
          "expected #{CATEGORIES.size} scores, got #{scores.size}"
  end

  CATEGORIES.zip(scores).sort_by { |_category, score| -score }.to_h
end
load_model!()
click to toggle source
# File lib/nsfw/model.rb, line 35
#
# Loads the ONNX model from MODEL_PATH, reusing it if already loaded.
#
# @return the OnnxRuntime::Model memoized in @model
def load_model!
  return @model if @model

  @model = OnnxRuntime::Model.new(MODEL_PATH)
end
make_prediction(tensor)
click to toggle source
# File lib/nsfw/model.rb, line 39
#
# Runs raw inference, loading the model first if that has not happened yet
# (e.g. when the instance was built with lazy: true).
#
# @param tensor input data; reshaped via #reshape_tensor before being fed
#   to the model under the "input" key
# @return the raw output of OnnxRuntime::Model#predict
def make_prediction(tensor)
  load_model! unless loaded?
  feed = { "input" => reshape_tensor(tensor) }
  @model.predict(feed)
end