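Distill logits
Use the teacher’s output distribution as the soft target for the student, in place of a hard ground-truth label.
Divide both the student’s and the teacher’s logits by a temperature T before the softmax; this softens the distributions and makes the relations between classes more apparent.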
class SoftTarget(nn.Module):
    """Soft-target knowledge-distillation (KD) loss.

    Distilling the Knowledge in a Neural Network:
    https://arxiv.org/pdf/1503.02531.pdf

    The only change KD makes to ordinary training is this extra loss function:
    the KL divergence between the temperature-softened teacher and student
    distributions, multiplied by ``T * T``.
    """

    def __init__(self, T):
        """Store the softmax temperature ``T``.

        Raises:
            ValueError: if ``T`` is a plain number that is not strictly
                positive.
        """
        super().__init__()
        # T == 0 divides the logits by zero and a negative T inverts their
        # ranking, so fail fast instead of silently producing a bad loss.
        if isinstance(T, (int, float)) and T <= 0:
            raise ValueError(f"temperature T must be positive, got {T!r}")
        self.T = T

    def forward(self, out_s, out_t):
        """Return the distillation loss between student and teacher logits.

        Args:
            out_s: student logits; the softmax is taken over dim 1.
            out_t: teacher logits, same shape as ``out_s``.
        """
        log_p_student = F.log_softmax(out_s / self.T, dim=1)
        p_teacher = F.softmax(out_t / self.T, dim=1)
        # 'batchmean' sums the element-wise KL terms and divides by the batch
        # size. The T * T factor follows the referenced paper, which notes
        # that soft-target gradients scale as 1 / T^2.
        loss = F.kl_div(log_p_student, p_teacher,
                        reduction='batchmean') * self.T * self.T
        return loss
Distill
The full code is in optimizer/loss.py
Setup
Initialize the loss accumulators and read the hyperparameters
# Zero-initialised accumulators for the class, box and objectness distillation
# losses. They are created on the device of the teacher predictions so that the
# in-place additions below never mix devices (the legacy
# torch.cuda.FloatTensor constructor always used the *current* CUDA device).
t_lcls, t_lbox, t_lobj = (torch.zeros(1, device=t_p[0].device) for _ in range(3))
h = model.hyp  # hyperparameters
Choose the loss function type
# Box and objectness terms are always distilled with an element-wise MSE
# (reduction="none" keeps per-element values so they can be weighted later).
DboxLoss = nn.MSELoss(reduction="none")
DobjLoss = nn.MSELoss(reduction="none")
# The class term depends on `dist_loss`: "l2" -> MSE, "kl" -> KL divergence,
# any other value -> BCE with logits.
DclsLoss = {
    "l2": nn.MSELoss,
    "kl": nn.KLDivLoss,
}.get(dist_loss, nn.BCEWithLogitsLoss)(reduction="none")
Iterate through the predictions
# Pair each student layer output with the teacher output at the same index.
for i, pi in enumerate(p):  # layer index, layer predictions
    t_pi = t_p[i]
    # Sigmoid of the teacher's channel 4; the box and class terms use it as a
    # per-position weight (presumably objectness - confirm the output layout).
    t_obj_scale = torch.sigmoid(t_pi[..., 4])
Calculate Bbox errors
# BBox
# Weight every box term by the teacher's confidence. A trailing singleton axis
# (rather than a weight expanded to 4 channels) is required so that the same
# weight broadcasts against the 4-channel box slice and against the 2-channel
# xy / wh slices used in the reg_norm branch.
b_obj_scale = t_obj_scale.unsqueeze(-1)
if not reg_norm:
    # Plain element-wise MSE on the raw box channels 0..3.
    t_lbox += torch.mean(DboxLoss(pi[..., :4],
                                  t_pi[..., :4]) * b_obj_scale)
else:
    # Reshape the per-layer normaliser so it broadcasts over the prediction
    # tensor (assumes reg_norm[i] is (anchors, 2) - TODO confirm).
    wh_norm_scale = reg_norm[i].unsqueeze(
        0).unsqueeze(-2).unsqueeze(-2)
    # Centre (xy) term on sigmoid-activated channels 0..1.
    t_lbox += torch.mean(DboxLoss(pi[..., :2].sigmoid(),
                                  t_pi[..., :2].sigmoid()) * b_obj_scale)
    # Size (wh) term on sigmoid-activated channels 2..3.
    # NOTE(review): only the teacher side is multiplied by wh_norm_scale -
    # confirm this asymmetry is intended.
    t_lbox += torch.mean(DboxLoss(pi[..., 2:4].sigmoid(),
                                  t_pi[..., 2:4].sigmoid() * wh_norm_scale) * b_obj_scale)
Calculate class errors
# Class
if model.nc > 1:  # cls loss (only if multiple classes)
    # Teacher-confidence weight; the trailing singleton axis broadcasts over
    # the class channels (5 onward).
    c_obj_scale = t_obj_scale.unsqueeze(-1)
    if dist_loss == "kl":
        # Temperature-softened KL divergence between the class distributions,
        # rescaled by T * T.
        kl_loss = DclsLoss(F.log_softmax(pi[..., 5:] / T, dim=-1),
                           F.softmax(t_pi[..., 5:] / T, dim=-1)) * (T * T)
        t_lcls += torch.mean(kl_loss * c_obj_scale)
    elif dist_loss == "l2":
        # MSE directly between student and teacher class logits.
        t_lcls += torch.mean(DclsLoss(pi[..., 5:],
                                      t_pi[..., 5:]) * c_obj_scale)
    else:
        # BCEWithLogitsLoss takes logits as input but expects targets in
        # [0, 1], so the teacher logits are converted to probabilities first.
        t_lcls += torch.mean(DclsLoss(pi[..., 5:],
                                      t_pi[..., 5:].sigmoid()) * c_obj_scale)
Scale the loss with distillation ratio
# Scale each accumulated term in place by the global distillation ratio and
# by its own gain from the model hyperparameters.
t_lbox *= distill_ratio * h['box']
t_lobj *= distill_ratio * h['obj']
t_lcls *= distill_ratio * h['cls']
Run the code
$ python train.py --data coco.yaml --epochs 101 --weights "./checkpoint/enasv2_L1_yolov5s.pt" --t_weights "./checkpoint/yolov5m.pt"