mlr_callback_set.unfreeze {mlr3torch}    R Documentation
Unfreezing Weights Callback
Description
Unfreezes selected weights (parameters of the network) once training reaches a specified epoch or batch (step).
Super class
mlr3torch::CallbackSet -> CallbackSetUnfreeze
Methods
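Public methods
CallbackSetUnfreeze$new()
CallbackSetUnfreeze$on_begin()
CallbackSetUnfreeze$on_epoch_begin()
CallbackSetUnfreeze$on_batch_begin()
CallbackSetUnfreeze$clone()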
Inherited methods
Method new()
Creates a new instance of this R6 class.
Usage
CallbackSetUnfreeze$new(starting_weights, unfreeze)
Arguments
starting_weights
(Select)
A Select denoting the weights that are trainable from the start.
unfreeze
(data.table)
A data.table with a column weights (a list column of Selects) and a column epoch or batch. The selector in each row indicates which parameters to unfreeze, while the epoch or batch column indicates when to do so.
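A Select is used both for starting_weights and for each entry of the weights column. The base-R sketch below illustrates how the selector from the Examples section resolves against a set of parameter names. It assumes that a selector behaves like a function that receives the network's parameter names and returns the selected subset; select_name_sketch(), select_invert_sketch(), and the parameter names are hypothetical stand-ins, not mlr3torch's own select_name() and select_invert().

```r
# Hypothetical stand-ins for mlr3torch's selectors, NOT the package's code:
# each selector is modeled as a function that receives the parameter names
# and returns the subset it selects.
select_name_sketch = function(names) {
  function(param_names) param_names[param_names %in% names]
}
select_invert_sketch = function(select) {
  # Select every parameter that the wrapped selector does NOT select.
  function(param_names) setdiff(param_names, select(param_names))
}

# Assumed parameter names, chosen for illustration only.
param_names = c("0.weight", "0.bias", "3.weight", "3.bias", "6.weight", "6.bias")

# Mirrors cb.unfreeze.starting_weights in the Examples section.
starting = select_invert_sketch(
  select_name_sketch(c("0.weight", "3.weight", "6.weight", "6.bias"))
)
print(starting(param_names))
# [1] "0.bias" "3.bias"
```

Under these assumptions, only the two biases would be trainable from the start, and the four named weights would wait for the unfreeze schedule.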
Method on_begin()
Sets the starting weights: the parameters selected by starting_weights are trainable, and the remaining ones are frozen.
Usage
CallbackSetUnfreeze$on_begin()
Method on_epoch_begin()
Unfreezes the weights scheduled for the current epoch, if any.
Usage
CallbackSetUnfreeze$on_epoch_begin()
Method on_batch_begin()
Unfreezes the weights scheduled for the current batch, if any.
Usage
CallbackSetUnfreeze$on_batch_begin()
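The interplay of on_begin(), on_epoch_begin(), and on_batch_begin() can be sketched as a small simulation in base R. This is not the package's implementation: it assumes that weights outside the starting set are frozen at the beginning, that each schedule row takes effect once the training counter reaches its epoch or batch value, and that unfrozen weights stay trainable. This page does not say whether the batch counter restarts each epoch, so the sketch treats it as a generic counter. trainable_at() and the parameter names are hypothetical; the schedule mirrors the Examples section, with selectors replaced by plain character vectors.

```r
# Hypothetical simulation of an unfreezing schedule, NOT mlr3torch's code.
# Returns the parameters assumed trainable once the counter reaches `counter`.
trainable_at = function(counter, param_names, starting, schedule, unit = "epoch") {
  # Rows of the schedule whose epoch (or batch) value has been reached.
  due = schedule[[unit]] <= counter
  unfrozen = unlist(schedule$weights[due])
  # Keep the original parameter order; unfreezing is treated as permanent.
  param_names[param_names %in% union(starting, unfrozen)]
}

# Illustrative parameter names and starting set (everything except the four
# weights frozen in the Examples section).
param_names = c("0.weight", "0.bias", "3.weight", "3.bias", "6.weight", "6.bias")
starting_names = c("0.bias", "3.bias")

# Epoch-based schedule with a list column of parameter names.
schedule = data.frame(epoch = c(2, 5))
schedule$weights = list("0.weight", c("3.weight", "6.weight"))

for (epoch in 1:6) {
  cat("epoch", epoch, ":", trainable_at(epoch, param_names, starting_names, schedule), "\n")
}
```

For a batch-based schedule, name the counter column batch and call trainable_at() with unit = "batch". Note that in this schedule "6.bias" is never unfrozen, so it stays frozen for the whole run.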
Method clone()
The objects of this class are cloneable with this method.
Usage
CallbackSetUnfreeze$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
See Also
Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_callback_set.tb, mlr_context_torch, t_clbk(), torch_callback()
Examples
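library(mlr3torch)
library(data.table)

task = tsk("iris")
cb = t_clbk("unfreeze")
mlp = lrn("classif.mlp", callbacks = cb,
  # Freeze these four parameters at the start; all others are trainable.
  cb.unfreeze.starting_weights = select_invert(
    select_name(c("0.weight", "3.weight", "6.weight", "6.bias"))
  ),
  # Unfreeze "0.weight" at epoch 2 and "3.weight" and "6.weight" at epoch 5;
  # "6.bias" is never unfrozen and stays frozen for the whole run.
  cb.unfreeze.unfreeze = data.table(
    epoch = c(2, 5),
    weights = list(select_name("0.weight"), select_name(c("3.weight", "6.weight")))
  ),
  epochs = 6, batch_size = 150, neurons = c(1, 1, 1)
)
mlp$train(task)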