| import torch |
| import torch.nn.functional as F |
| from fvcore.nn import sigmoid_focal_loss_jit |
| from torch import nn |
|
|
| import torch.distributed as dist |
| from torch.distributed import get_world_size |
| from torchvision import ops |
|
|
|
|
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    # Short-circuits exactly like the two guard clauses it replaces:
    # `is_initialized()` is only consulted when the package is available.
    return dist.is_available() and dist.is_initialized()
|
|
|
|
def get_fed_loss_classes(gt_classes, num_fed_loss_classes, num_classes, weight):
    """
    Select the classes that take part in the federated loss.

    Args:
        gt_classes: a long tensor with the gt class label of each proposal; only
            its set of unique values is used.
        num_fed_loss_classes: minimum number of classes to keep. Negative classes
            are sampled when fewer unique gt classes than this are present.
        num_classes: number of foreground classes; index ``num_classes`` is the
            extra slot that is never sampled.
        weight: per-foreground-class sampling weights (``num_classes`` entries).
    Returns:
        Tensor:
            the unique gt classes, followed by the sampled negative classes when
            sampling was necessary.
    """
    present_classes = torch.unique(gt_classes)
    num_to_sample = num_fed_loss_classes - len(present_classes)
    if num_to_sample <= 0:
        # Enough distinct classes already present: nothing to sample.
        return present_classes

    # One probability slot per foreground class plus the trailing extra slot,
    # which stays at zero so it can never be drawn.
    sampling_prob = present_classes.new_ones(num_classes + 1).float()
    sampling_prob[-1] = 0
    sampling_prob[:num_classes] = weight.float().clone()
    # Classes that are already present must not be drawn again.
    sampling_prob[present_classes] = 0

    negative_classes = torch.multinomial(sampling_prob, num_to_sample, replacement=False)
    return torch.cat([present_classes, negative_classes])
|
|
|
|
class CriterionDynamicK(nn.Module):
    """ This class computes the loss for DiffusionDet.
    The process happens in two steps:
        1) a dynamic-k matcher (``HungarianMatcherDynamicK``) assigns predictions to ground-truth boxes
        2) each matched ground-truth / prediction pair is supervised (class and box)
    """

    def __init__(self, config, num_classes, weight_dict):
        """ Create the criterion.
        Parameters:
            config: configuration object. Read here: ``no_object_weight``, ``use_focal``,
                ``use_fed_loss``; ``alpha`` and ``gamma`` when ``use_focal`` is set; and,
                optionally, ``fed_loss_num_classes`` / ``fed_loss_cls_weights`` when
                ``use_fed_loss`` is set. It is also forwarded to the matcher.
            num_classes: number of object categories, omitting the special no-object category
            weight_dict: dict containing as key the names of the losses and as values their
                relative weight. Stored only; it is not read inside this class.
        """
        super().__init__()
        self.config = config
        self.num_classes = num_classes
        self.matcher = HungarianMatcherDynamicK(config)
        self.weight_dict = weight_dict
        self.eos_coef = config.no_object_weight
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss

        if self.use_focal:
            self.focal_loss_alpha = config.alpha
            self.focal_loss_gamma = config.gamma

        if self.use_fed_loss:
            self._setup_fed_loss(config)

    def _setup_fed_loss(self, config):
        """Populate the federated-loss settings that ``loss_labels`` reads.

        Without these attributes the fed-loss branch of ``loss_labels`` fails with an
        AttributeError. Both config fields are optional:
            fed_loss_num_classes: minimum number of classes kept per loss call.
                Defaults to 50 (NOTE(review): value taken from the reference
                DiffusionDet implementation - confirm) and is capped at
                ``num_classes`` so negatives can be sampled without replacement.
            fed_loss_cls_weights: per-class sampling weights of length ``num_classes``.
                Defaults to uniform weights.
        Raises:
            ValueError: if the provided weights do not have ``num_classes`` entries.
        """
        requested = getattr(config, "fed_loss_num_classes", 50)
        self.fed_loss_num_classes = min(int(requested), self.num_classes)

        cls_weights = getattr(config, "fed_loss_cls_weights", None)
        if cls_weights is None:
            cls_weights = torch.ones(self.num_classes)
        cls_weights = torch.as_tensor(cls_weights, dtype=torch.float32)
        if cls_weights.numel() != self.num_classes:
            raise ValueError(
                "fed_loss_cls_weights must contain exactly num_classes entries "
                f"(got {cls_weights.numel()}, expected {self.num_classes})"
            )
        # A buffer follows the module across devices; non-persistent so that
        # existing checkpoints keep loading unchanged.
        self.register_buffer("fed_loss_cls_weights", cls_weights, persistent=False)

    def loss_labels(self, outputs, targets, indices):
        """Classification loss (sigmoid focal loss or BCE, optionally restricted to the federated classes)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]

        Parameters:
            outputs: dict with 'pred_logits'; its first two dims index (image, query).
            targets: list with one dict per image.
            indices: per image a pair (query selector, matched gt indices) as produced by the matcher.
        Returns:
            dict with the single key 'loss_ce'.
        Raises:
            NotImplementedError: when neither ``use_focal`` nor ``use_fed_loss`` is enabled.
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        # Every query starts as "no object" (index num_classes); matched queries
        # are overwritten with the label of their ground truth.
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        num_matched = 0
        for batch_idx, target in enumerate(targets):
            valid_query, gt_multi_idx = indices[batch_idx]
            if len(gt_multi_idx) == 0:
                continue
            matched_labels = target["labels"][gt_multi_idx]
            target_classes[batch_idx, valid_query] = matched_labels
            num_matched += matched_labels.shape[0]

        if not (self.use_focal or self.use_fed_loss):
            raise NotImplementedError

        # Normalizer: number of matched pairs, or 1 when nothing was matched.
        num_boxes = max(num_matched, 1)

        # One-hot over num_classes + 1 slots, then drop the no-object column so
        # unmatched queries end up with an all-zero target row.
        target_classes_onehot = F.one_hot(target_classes, self.num_classes + 1)[:, :, :-1]
        target_classes_onehot = target_classes_onehot.to(src_logits.dtype).flatten(0, 1)
        src_logits = src_logits.flatten(0, 1)

        if self.use_focal:
            cls_loss = sigmoid_focal_loss_jit(src_logits, target_classes_onehot, alpha=self.focal_loss_alpha,
                                              gamma=self.focal_loss_gamma, reduction="none")
        else:
            cls_loss = F.binary_cross_entropy_with_logits(src_logits, target_classes_onehot, reduction="none")

        if self.use_fed_loss:
            K = self.num_classes
            N = src_logits.shape[0]
            fed_loss_classes = get_fed_loss_classes(
                target_classes,
                num_fed_loss_classes=self.fed_loss_num_classes,
                num_classes=K,
                weight=self.fed_loss_cls_weights,
            )
            # Binary mask over the foreground classes that take part in the loss.
            fed_loss_classes_mask = fed_loss_classes.new_zeros(K + 1)
            fed_loss_classes_mask[fed_loss_classes] = 1
            fed_loss_classes_mask = fed_loss_classes_mask[:K]
            weight = fed_loss_classes_mask.view(1, K).expand(N, K).float()
            loss_ce = torch.sum(cls_loss * weight) / num_boxes
        else:
            loss_ce = torch.sum(cls_loss) / num_boxes

        return {'loss_ce': loss_ce}

    def loss_boxes(self, outputs, targets, indices):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the keys "boxes" (converted here from (center_x, center_y, w, h)
        to corner format for the L1 term), "boxes_xyxy" (used for the GIoU term) and
        "image_size_xyxy" (divides the predicted boxes for the L1 term).

        Returns:
            dict with 'loss_bbox' and 'loss_giou'; both are zero-valued tensors that still depend on
            the predictions when no image has a matched pair.
        """
        assert 'pred_boxes' in outputs
        src_boxes = outputs['pred_boxes']

        pred_box_list = []
        pred_norm_box_list = []
        tgt_box_list = []
        tgt_box_xyxy_list = []
        for batch_idx, target in enumerate(targets):
            valid_query, gt_multi_idx = indices[batch_idx]
            if len(gt_multi_idx) == 0:
                continue
            matched_pred = src_boxes[batch_idx][valid_query]
            pred_box_list.append(matched_pred)
            pred_norm_box_list.append(matched_pred / target['image_size_xyxy'])
            tgt_box_list.append(target["boxes"][gt_multi_idx])
            tgt_box_xyxy_list.append(target["boxes_xyxy"][gt_multi_idx])

        if len(pred_box_list) == 0:
            # Multiplying by zero keeps the result attached to the predictions.
            return {'loss_bbox': outputs['pred_boxes'].sum() * 0,
                    'loss_giou': outputs['pred_boxes'].sum() * 0}

        src_boxes = torch.cat(pred_box_list)
        src_boxes_norm = torch.cat(pred_norm_box_list)
        target_boxes = torch.cat(tgt_box_list)
        target_boxes_abs_xyxy = torch.cat(tgt_box_xyxy_list)
        num_boxes = src_boxes.shape[0]

        losses = {}
        loss_bbox = F.l1_loss(src_boxes_norm, ops.box_convert(target_boxes, 'cxcywh', 'xyxy'), reduction='none')
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        # Only the diagonal (prediction i vs. its own target i) is a matched pair.
        loss_giou = 1 - torch.diag(ops.generalized_box_iou(src_boxes, target_boxes_abs_xyxy))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    def get_loss(self, loss, outputs, targets, indices):
        """Dispatch to the loss function registered under ``loss`` ('labels' or 'boxes')."""
        loss_map = {
            'labels': self.loss_labels,
            'boxes': self.loss_boxes,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        Returns:
             dict of losses; entries computed from ``outputs['aux_outputs'][i]`` carry the suffix ``_{i}``.
        """
        loss_names = ["labels", "boxes"]
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Match the final predictions against the targets.
        indices, _ = self.matcher(outputs_without_aux, targets)

        losses = {}
        for loss in loss_names:
            losses.update(self.get_loss(loss, outputs, targets, indices))

        # Each auxiliary output is matched and supervised independently.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices, _ = self.matcher(aux_outputs, targets)
                for loss in loss_names:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices)
                    losses.update({k + f'_{i}': v for k, v in l_dict.items()})

        return losses
|
|
|
|
def get_in_boxes_info(boxes, target_gts, center_radius=2.5):
    """Flag which prediction centers lie inside each ground-truth box and/or its center region.

    Args:
        boxes: 2-D tensor with one prediction per row; only columns 0 and 1 are read,
            as the x and y coordinate of the prediction's center.
        target_gts: 2-D tensor of ground-truth boxes in (cx, cy, w, h) format.
        center_radius: half-extent of the center region, expressed as a multiple of the
            ground-truth width (x direction) and height (y direction). The default
            of 2.5 is the value that used to be hard-coded here.
    Returns:
        tuple:
            is_in_boxes_anchor: bool tensor with one entry per prediction, True when the
                center lies inside at least one gt box or at least one center region.
            is_in_boxes_and_center: bool tensor (predictions x gts), True where the center
                lies inside both the gt box and that gt's center region.
    """
    # Corner coordinates of the ground truths; same arithmetic as a
    # cxcywh -> xyxy box conversion.
    gt_cx, gt_cy, gt_w, gt_h = target_gts.unbind(-1)
    gt_x1 = gt_cx - 0.5 * gt_w
    gt_y1 = gt_cy - 0.5 * gt_h
    gt_x2 = gt_cx + 0.5 * gt_w
    gt_y2 = gt_cy + 0.5 * gt_h

    # Column vectors of prediction centers broadcast against row vectors of gt edges.
    center_x = boxes[:, 0].unsqueeze(1)
    center_y = boxes[:, 1].unsqueeze(1)

    # Strictly inside the ground-truth box.
    is_in_boxes = (
        (center_x > gt_x1.unsqueeze(0))
        & (center_x < gt_x2.unsqueeze(0))
        & (center_y > gt_y1.unsqueeze(0))
        & (center_y < gt_y2.unsqueeze(0))
    )
    is_in_boxes_all = is_in_boxes.any(dim=1)

    # Strictly inside the region centered on the gt that extends
    # center_radius * width horizontally and center_radius * height vertically.
    radius_x = center_radius * (gt_x2 - gt_x1)
    radius_y = center_radius * (gt_y2 - gt_y1)
    is_in_centers = (
        (center_x > (gt_cx - radius_x).unsqueeze(0))
        & (center_x < (gt_cx + radius_x).unsqueeze(0))
        & (center_y > (gt_cy - radius_y).unsqueeze(0))
        & (center_y < (gt_cy + radius_y).unsqueeze(0))
    )
    is_in_centers_all = is_in_centers.any(dim=1)

    is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
    is_in_boxes_and_center = is_in_boxes & is_in_centers
    return is_in_boxes_anchor, is_in_boxes_and_center
|
|
|
|
class HungarianMatcherDynamicK(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-k (dynamic) matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, config):
        """Read the matching hyper-parameters from ``config``.

        Read here: ``use_focal``, ``use_fed_loss``, ``class_weight``, ``giou_weight``,
        ``l1_weight``, ``ota_k`` and, when ``use_focal`` is set, ``alpha`` and ``gamma``.
        """
        super().__init__()
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        self.cost_class = config.class_weight
        self.cost_giou = config.giou_weight
        self.cost_bbox = config.l1_weight
        self.ota_k = config.ota_k

        if self.use_focal:
            self.focal_loss_alpha = config.alpha
            self.focal_loss_gamma = config.gamma

        assert self.cost_class != 0 or self.cost_bbox != 0 or self.cost_giou != 0, "all costs cant be 0"

    def _classification_cost(self, bz_out_prob, bz_tgt_ids):
        """Return the (queries x gts) classification cost for one image."""
        if self.use_focal:
            alpha = self.focal_loss_alpha
            gamma = self.focal_loss_gamma
            neg_cost_class = (1 - alpha) * (bz_out_prob ** gamma) * (-(1 - bz_out_prob + 1e-8).log())
            pos_cost_class = alpha * ((1 - bz_out_prob) ** gamma) * (-(bz_out_prob + 1e-8).log())
            return pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids]
        if self.use_fed_loss:
            neg_cost_class = (-(1 - bz_out_prob + 1e-8).log())
            pos_cost_class = (-(bz_out_prob + 1e-8).log())
            return pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids]
        return -bz_out_prob[:, bz_tgt_ids]

    def forward(self, outputs, targets):
        """ simOTA for detr

        Parameters:
            outputs: dict with 'pred_logits' and 'pred_boxes'; the boxes are treated as xyxy
                and divided by the per-image 'image_size_xyxy' for the L1 cost.
            targets: list with one dict per image holding 'labels', 'boxes', 'boxes_xyxy',
                'image_size_xyxy' and 'image_size_xyxy_tgt'.
        Returns:
            tuple (indices, matched_ids), both lists with one entry per image:
                indices[i]: (bool query selector, gt index for every selected query)
                matched_ids[i]: for every gt, the index of its lowest-cost matched query
        """
        with torch.no_grad():
            pred_logits = outputs["pred_logits"]
            bs = pred_logits.shape[0]

            if self.use_focal or self.use_fed_loss:
                out_prob = pred_logits.sigmoid()
            else:
                out_prob = pred_logits.softmax(-1)
            out_bbox = outputs["pred_boxes"]

            indices = []
            matched_ids = []
            assert bs == len(targets)
            for batch_idx in range(bs):
                bz_boxes = out_bbox[batch_idx]
                bz_out_prob = out_prob[batch_idx]
                target = targets[batch_idx]
                bz_tgt_ids = target["labels"]

                if len(bz_tgt_ids) == 0:
                    # No ground truth in this image: select no query at all.
                    non_valid = torch.zeros(bz_out_prob.shape[0]).to(bz_out_prob) > 0
                    indices.append((non_valid, torch.arange(0, 0).to(bz_out_prob)))
                    matched_ids.append(torch.arange(0, 0).to(bz_out_prob))
                    continue

                bz_gtboxs = target['boxes']
                bz_gtboxs_abs_xyxy = target['boxes_xyxy']
                fg_mask, is_in_boxes_and_center = get_in_boxes_info(
                    ops.box_convert(bz_boxes, 'xyxy', 'cxcywh'),
                    ops.box_convert(bz_gtboxs_abs_xyxy, 'xyxy', 'cxcywh')
                )

                pair_wise_ious = ops.box_iou(bz_boxes, bz_gtboxs_abs_xyxy)

                cost_class = self._classification_cost(bz_out_prob, bz_tgt_ids)

                # L1 cost between boxes divided by their respective image sizes.
                bz_out_bbox_ = bz_boxes / target['image_size_xyxy']
                bz_tgt_bbox_ = bz_gtboxs_abs_xyxy / target['image_size_xyxy_tgt']
                cost_bbox = torch.cdist(bz_out_bbox_, bz_tgt_bbox_, p=1)

                cost_giou = -ops.generalized_box_iou(bz_boxes, bz_gtboxs_abs_xyxy)

                # Pairs whose center is not inside both the gt box and its center
                # region are penalized; queries outside every gt even more so.
                cost = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + 100.0 * (
                    ~is_in_boxes_and_center)
                cost[~fg_mask] = cost[~fg_mask] + 10000.0

                indices_batchi, matched_qidx = self.dynamic_k_matching(cost, pair_wise_ious, bz_gtboxs.shape[0])

                indices.append(indices_batchi)
                matched_ids.append(matched_qidx)

        return indices, matched_ids

    def dynamic_k_matching(self, cost, pair_wise_ious, num_gt):
        """Assign a dynamic number of queries to every ground truth.

        Parameters:
            cost: (queries x gts) cost matrix. It is not modified; penalties are applied
                to an internal copy.
            pair_wise_ious: (queries x gts) IoU matrix used to derive k for every gt.
            num_gt: number of ground truths (columns of ``cost``).
        Returns:
            ((selected_query, gt_indices), matched_query_id):
                selected_query: bool tensor over queries, True for matched queries
                gt_indices: for every selected query, the index of its ground truth
                matched_query_id: for every gt, the index of its lowest-cost matched query
        """
        # Work on a copy so the penalties added below never leak into the caller's tensor.
        cost = cost.clone()
        matching_matrix = torch.zeros_like(cost)

        # topk raises when k exceeds the number of queries, so clamp it.
        n_candidate_k = min(self.ota_k, pair_wise_ious.shape[0])

        # k for every gt = truncated sum of its top IoUs, at least 1.
        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=0)
        # One host transfer for all gts instead of one per loop iteration.
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1).tolist()

        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
            matching_matrix[:, gt_idx][pos_idx] = 1.0

        # A query picked by several gts keeps only the gt with its lowest cost.
        multi_matched = matching_matrix.sum(1) > 1
        if multi_matched.any():
            _, cost_argmin = torch.min(cost[multi_matched], dim=1)
            matching_matrix[multi_matched] = 0
            matching_matrix[multi_matched, cost_argmin] = 1

        # Repair loop: give every gt that lost all of its queries a new one.
        while True:
            unmatched_gt = matching_matrix.sum(0) == 0
            num_unmatched = int(unmatched_gt.sum())
            if num_unmatched == 0:
                break

            # Push already-matched queries back so free queries are preferred.
            cost[matching_matrix.sum(1) > 0] += 100000.0
            for gt_idx in torch.nonzero(unmatched_gt, as_tuple=False).squeeze(1):
                pos_idx = torch.argmin(cost[:, gt_idx])
                matching_matrix[:, gt_idx][pos_idx] = 1.0

            # Conflicts must be detected on the CURRENT matrix: a mask computed
            # before this loop misses queries that were just claimed by several gts.
            multi_matched = matching_matrix.sum(1) > 1
            if multi_matched.any():
                # Choose only among the gts actually assigned to the query, so the
                # resolution cannot hand the query to an unrelated gt.
                assigned_cost = cost[multi_matched].masked_fill(
                    matching_matrix[multi_matched] == 0, float('inf'))
                _, cost_argmin = torch.min(assigned_cost, dim=1)
                resolved = matching_matrix.clone()
                resolved[multi_matched] = 0
                resolved[multi_matched, cost_argmin] = 1
                # Accept the resolution only when it makes progress. Otherwise the
                # query stays shared, every gt is covered, and the loop terminates.
                if int((resolved.sum(0) == 0).sum()) < num_unmatched:
                    matching_matrix = resolved

        assert not (matching_matrix.sum(0) == 0).any()
        selected_query = matching_matrix.sum(1) > 0
        gt_indices = matching_matrix[selected_query].max(1)[1]
        assert selected_query.sum() == len(gt_indices)

        # Restrict the per-gt argmin to matched pairs.
        cost[matching_matrix == 0] = cost[matching_matrix == 0] + float('inf')
        matched_query_id = torch.min(cost, dim=0)[1]

        return (selected_query, gt_indices), matched_query_id
|
|