1.8. Regularization
torchmfbd v0.1 allows the user to define regularization terms that are added to the loss function by inheriting from torchmfbd.Regularization. The smooth and IUWT regularization terms can be added via the configuration file. More options will be added in the future, but you can also define your own regularization terms externally, as shown below.
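Before defining a custom term, it may help to recall how any regularization term interacts with the loss. The following is a generic PyTorch sketch (not torchmfbd's internal implementation): a regularization term is simply an extra scalar, scaled by a hyperparameter, added to the data-fidelity loss before backpropagation.

    import torch

    # Generic illustration: the total loss is the data-fidelity term plus a
    # weighted penalty, and gradients flow through both.
    obj = torch.randn(8, 8, requires_grad=True)
    target = torch.zeros(8, 8)

    data_loss = torch.mean((obj - target) ** 2)   # fidelity term
    lambda_reg = 0.01
    reg_loss = lambda_reg * torch.sum(obj ** 2)   # quadratic penalty toward zero
    total = data_loss + reg_loss
    total.backward()                              # obj.grad now includes both contributions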
As an example, let us assume that we want to add a new regularization term to the loss function that penalizes values of the object away from zero. To this end, we define the following class:
import torch
import torchmfbd

class MyRegularization(torchmfbd.Regularization):

    def __init__(self, lambda_reg, variable, value):
        # Register this term as an 'external' regularization acting on `variable`
        super(MyRegularization, self).__init__('external', lambda_reg, variable)
        self.variable = variable
        self.lambda_reg = lambda_reg
        self.value = value

    def __call__(self, x):
        # Add your regularization term here.
        # x is a list with one tensor per object; penalize deviations from self.value
        n_o = len(x)
        loss = 0.0
        for i in range(n_o):
            loss += self.lambda_reg * torch.sum((x[i] - self.value)**2)
        return loss
Now we instantiate the torchmfbd.Deconvolution class and add the external regularization:
deconv = torchmfbd.Deconvolution('qs_8542_kl_patches.yaml')
myregularization = MyRegularization(lambda_reg=0.01, variable='object', value=0.0)
deconv.add_external_regularization(myregularization)
The external regularization term will be added to the loss function and will be optimized along with the other terms.
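To check the penalty itself without running a full deconvolution, the __call__ logic above can be exercised in isolation. The sketch below reproduces the quadratic penalty as a plain callable (QuadraticPenalty is a stand-in name; the real class inherits from torchmfbd.Regularization as shown earlier) so it runs without torchmfbd installed:

    import torch

    # Stand-in for MyRegularization.__call__: same penalty, no base class needed
    class QuadraticPenalty:
        def __init__(self, lambda_reg, value):
            self.lambda_reg = lambda_reg
            self.value = value

        def __call__(self, x):
            # x is a list with one tensor per object
            loss = 0.0
            for obj in x:
                loss += self.lambda_reg * torch.sum((obj - self.value) ** 2)
            return loss

    reg = QuadraticPenalty(lambda_reg=0.01, value=0.0)
    objects = [torch.ones(2, 2), torch.zeros(2, 2)]
    result = float(reg(objects))  # 0.01 * 4 ones-squared = 0.04

A term like this pushes every object pixel toward `value`; in practice lambda_reg must be tuned so the penalty does not overwhelm the data-fidelity term.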