Distill features
This experiment implements Distilling-Object-Detectors ("Distilling Object Detectors with Fine-grained Feature Imitation").
Goal: instead of imitating the teacher's whole feature map (too hard for the student to learn), the student imitates only the features at anchor locations near objects.

Modify Model
Add a function that builds center anchors (using a fixed set of anchor sizes)
def make_center_anchors(anchors_wh, grid_size=80, device='cpu'):
    """Place every anchor shape at the center of every cell of a square grid.

    Args:
        anchors_wh: sequence of (w, h) pairs, in grid units (one unit = one cell).
            Any number of anchors is supported.
        grid_size: number of cells along each side of the square grid.
        device: device of the returned tensor.

    Returns:
        float32 tensor of shape (grid_size, grid_size, len(anchors_wh), 4) with
        ``center_anchors[row][col][a] == (col + 0.5, row + 0.5, w_a, h_a)``,
        i.e. (x, y, w, h) per anchor, indexed in the same (H, W) order as a
        feature map.

    Example (first anchor (1.3221, 1.7314)):
        center_anchors[0][0][0] -> [0.5, 0.5, 1.3221, 1.7314]
        center_anchors[0][1][0] -> [1.5, 0.5, 1.3221, 1.7314]   # one cell to the right
        center_anchors[1][0][0] -> [0.5, 1.5, 1.3221, 1.7314]   # one cell down
    """
    wh = torch.tensor(anchors_wh, dtype=torch.float32).view(-1, 2)
    num_anchors = wh.size(0)

    # Cell centers; x varies along columns (dim 1), y along rows (dim 0).
    # (A default 'ij' meshgrid of (xx, yy) would put the row index in x and
    # transpose the grid relative to the feature map.)
    centers = torch.arange(grid_size, dtype=torch.float32) + 0.5
    cx = centers.view(1, grid_size).expand(grid_size, grid_size)
    cy = centers.view(grid_size, 1).expand(grid_size, grid_size)
    xy = torch.stack((cx, cy), dim=-1).view(grid_size, grid_size, 1, 2)

    xy = xy.expand(grid_size, grid_size, num_anchors, 2)
    wh = wh.view(1, 1, num_anchors, 2).expand(grid_size, grid_size, num_anchors, 2)
    return torch.cat([xy, wh], dim=3).to(device)  # x y w h
Add a forward function to the model that can also return an intermediate feature map
def _forward_once(self, x, profile=False, visualize=False, target=None):
y, dt = [], [] # outputs
cnt = 0
for m in self.model:
if m.f != -1: # if not from previous layer
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
if profile:
self._profile_one_layer(m, x, dt)
x = m(x) # run
y.append(x if m.i in self.save else None) # save output
if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize)
if isinstance(m, Concat):
cnt += 1
if cnt == 2:
feature = x
if target is not None:
return x, feature
return x
Add a function to the model that computes the imitation mask
def _get_imitation_mask(self, x, targets, iou_factor=0.5):
"""
gt_box: (B, K, 4) [x_min, y_min, x_max, y_max]
"""
self.anchors = [(1.3221, 1.73145), (3.19275, 4.00944), (5.05587, 8.09892), (9.47112, 4.84053),
(11.2364, 10.0071)]
self.num_anchors = len(self.anchors)
out_size = x.size(2)
batch_size = x.size(0)
device = targets.device
mask_batch = torch.zeros([batch_size, out_size, out_size])
if not len(targets):
return mask_batch
gt_boxes = [[] for i in range(batch_size)]
for i in range(len(targets)):
gt_boxes[int(targets[i, 0].data)] += [targets[i, 2:].clone().detach().unsqueeze(0)]
max_num = 0
for i in range(batch_size):
max_num = max(max_num, len(gt_boxes[i]))
if len(gt_boxes[i]) == 0:
gt_boxes[i] = torch.zeros((1, 4), device=device)
else:
gt_boxes[i] = torch.cat(gt_boxes[i], 0)
for i in range(batch_size):
# print(gt_boxes[i].device)
if max_num - gt_boxes[i].size(0):
gt_boxes[i] = torch.cat((gt_boxes[i], torch.zeros((max_num - gt_boxes[i].size(0), 4), device=device)), 0)
gt_boxes[i] = gt_boxes[i].unsqueeze(0)
gt_boxes = torch.cat(gt_boxes, 0)
gt_boxes *= out_size
center_anchors = make_center_anchors(anchors_wh=self.anchors, grid_size=out_size, device=device)
anchors = center_to_corner(center_anchors).view(-1, 4) # (N, 4)
gt_boxes = center_to_corner(gt_boxes)
mask_batch = torch.zeros([batch_size, out_size, out_size], device=device)
for i in range(batch_size):
num_obj = gt_boxes[i].size(0)
if not num_obj:
continue
IOU_map = find_jaccard_overlap(anchors, gt_boxes[i], 0).view(out_size, out_size, self.num_anchors, num_obj)
max_iou, _ = IOU_map.view(-1, num_obj).max(dim=0)
mask_img = torch.zeros([out_size, out_size], dtype=torch.int64, requires_grad=False).type_as(x)
threshold = max_iou * iou_factor
for k in range(num_obj):
mask_per_gt = torch.sum(IOU_map[:, :, :, k] > threshold[k], dim=2)
mask_img += mask_per_gt
mask_img += mask_img
mask_batch[i] = mask_img
mask_batch = mask_batch.clamp(0, 1)
return mask_batch # (B, h, w)
Change the model's forward function to handle all cases (with or without distillation)
def forward(self, x, augment=False, profile=False, visualize=False, target=None):
    """Forward pass supporting augmented, plain, and distillation modes.

    Args:
        x: model input.
        augment: if True, return ``self._forward_augment(x)`` and ignore the rest.
        profile, visualize: forwarded to ``self._forward_once``.
        target: when not None, also compute the captured feature map and the
            imitation mask for these targets.

    Returns:
        ``self._forward_augment(x)`` if ``augment``; otherwise
        ``(preds, features, mask)`` when ``target`` is given, with ``mask`` of
        shape (B, 1, H, W); otherwise the plain ``self._forward_once`` output.
    """
    if augment:
        return self._forward_augment(x)  # augmented inference, None
    if target is None:  # identity check; `!= None` on a tensor is the wrong idiom
        return self._forward_once(x, profile, visualize)  # single-scale inference, train
    preds, features = self._forward_once(x, profile, visualize, target)
    mask = self._get_imitation_mask(features, target).unsqueeze(1)  # add channel dim
    return preds, features, mask
New Loss Calculation
Add a function that returns the Intersection over Union (IoU) between box1 (1, 4) and each box of box2 (n, 4)
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """IoU (or GIoU / DIoU / CIoU) between box1 (1, 4) and every box in box2 (n, 4).

    Args:
        box1, box2: box tensors, (x, y, w, h) if ``xywh`` else (x1, y1, x2, y2).
        GIoU, DIoU, CIoU: select a variant; CIoU takes precedence over DIoU,
            and both take precedence over GIoU. Plain IoU if none is set.
        eps: small constant guarding the divisions.

    Returns:
        (n, 1) tensor of the selected overlap metric.
    """
    def _corners(box):
        # Returns ((left, top, right, bottom), width, height) as (k, 1) columns.
        if xywh:
            cx, cy, bw, bh = box.chunk(4, 1)
            half_w, half_h = bw / 2, bh / 2
            return (cx - half_w, cy - half_h, cx + half_w, cy + half_h), bw, bh
        left, top, right, bottom = box.chunk(4, 1)
        return (left, top, right, bottom), right - left, bottom - top

    (ax1, ay1, ax2, ay2), aw, ah = _corners(box1)
    (bx1, by1, bx2, by2), bw, bh = _corners(box2)

    overlap_w = (torch.min(ax2, bx2) - torch.max(ax1, bx1)).clamp(0)
    overlap_h = (torch.min(ay2, by2) - torch.max(ay1, by1)).clamp(0)
    inter = overlap_w * overlap_h
    union = aw * ah + bw * bh - inter + eps
    iou = inter / union

    if not (CIoU or DIoU or GIoU):
        return iou  # IoU

    # Smallest enclosing box.
    enclose_w = torch.max(ax2, bx2) - torch.min(ax1, bx1)
    enclose_h = torch.max(ay2, by2) - torch.min(ay1, by1)

    if not (CIoU or DIoU):
        enclose_area = enclose_w * enclose_h + eps
        return iou - (enclose_area - union) / enclose_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf

    # Distance / Complete IoU https://arxiv.org/abs/1911.08287v1
    diag_sq = enclose_w ** 2 + enclose_h ** 2 + eps
    center_dist_sq = ((bx1 + bx2 - ax1 - ax2) ** 2 + (by1 + by2 - ay1 - ay2) ** 2) / 4
    if not CIoU:
        return iou - center_dist_sq / diag_sq  # DIoU

    aspect_term = (4 / math.pi ** 2) * torch.pow(torch.atan(bw / (bh + eps)) - torch.atan(aw / (ah + eps)), 2)
    with torch.no_grad():
        alpha = aspect_term / (aspect_term - iou + (1 + eps))
    return iou - (center_dist_sq / diag_sq + aspect_term * alpha)  # CIoU
Calculate the imitation loss between the teacher's and the student's features inside the imitation mask
def imitation_loss(teacher, student, mask):
    """Masked squared-error loss between student and teacher feature maps.

    Computes ``sum(mask * (student - teacher)**2) / (2 * N)``, where ``N`` is
    ``mask.sum()`` clamped to at least 1.

    Args:
        teacher, student: feature tensors of matching shape, or None.
        mask: tensor broadcastable to the features (e.g. (B, 1, H, W) of 0/1).

    Returns:
        Scalar tensor, or the int 0 when either feature is None.
    """
    if student is None or teacher is None:
        return 0
    masked_sq_err = torch.pow(student - teacher, 2) * mask
    # An all-zero mask (batch without boxes) would otherwise give 0/0 = NaN and
    # poison the gradients; with the clamp the loss is simply 0.
    num_pos = mask.sum().clamp(min=1)
    return masked_sq_err.sum() / num_pos / 2
Modify the ComputeLoss class (only the changed lines are shown below; refer to the full source for the complete class)
# Excerpt of ComputeLoss.__call__: only the distillation-related lines are shown.
# teacher/student are the feature tensors compared by imitation_loss and mask is the
# imitation mask; imitation_loss returns 0 when teacher or student is None.
def __call__(self, p, targets, teacher=None, student=None, mask=None): # p: predictions; teacher/student: features; mask: imitation mask
# Imitation loss scaled by a fixed weight of 0.01.
# NOTE(review): how lmask is added to the total loss is not shown in this excerpt.
lmask = imitation_loss(teacher, student, mask) * 0.01
Run the Code
python train.py --data coco.yaml --epochs 101 --weights "original_model" --ft_weights "teacher_model"