| import copy |
| import math |
| from dataclasses import astuple |
|
|
| import torch |
| from torch import nn |
| from torch.nn.modules.transformer import _get_activation_fn |
| from torchvision.ops import RoIAlign |
|
|
# Upper bound applied to the (dw, dh) box-regression deltas before they are
# exponentiated in ``apply_deltas``; exp(log(1000 / 16)) == 62.5, so one
# refinement step can scale a box side by at most that factor.
_DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16)
|
|
def convert_boxes_to_pooler_format(bboxes):
    """
    Flatten batched boxes into the per-box layout consumed by the level poolers.

    Args:
        bboxes (Tensor): boxes of shape (batch_size, num_proposals, D); the rest of
            this file uses D == 4.

    Returns:
        Tensor: shape (batch_size * num_proposals, 1 + D). Column 0 is the index of
        the batch element each box came from (stored in the dtype of `bboxes`);
        the remaining columns are the box values, unchanged.
    """
    bs, num_proposals = bboxes.shape[:2]
    # `reshape` also accepts non-contiguous input (e.g. a permuted tensor), where
    # `view` would raise.
    flat_bboxes = bboxes.reshape(bs * num_proposals, -1)
    # Every batch element contributes the same number of boxes, so an integer
    # repeat count is enough; this avoids building a CPU tensor of sizes and
    # copying it to the device.
    batch_indices = torch.arange(
        bs, dtype=flat_bboxes.dtype, device=flat_bboxes.device
    ).repeat_interleave(num_proposals)
    return torch.cat([batch_indices[:, None], flat_bboxes], dim=1)
|
|
|
|
def assign_boxes_to_levels(
    bboxes,
    min_level,
    max_level,
    canonical_box_size,
    canonical_level,
):
    """
    Map every box to the index of the feature level it should be pooled from.

    The level is ``floor(canonical_level + log2(sqrt(area) / canonical_box_size))``,
    clamped to ``[min_level, max_level]`` and shifted so that ``min_level`` maps
    to index 0.

    Args:
        bboxes (Tensor): boxes of shape (batch_size, num_proposals, D) whose first
            four values are read as (x1, y1, x2, y2).
        min_level (int): smallest level that can be assigned.
        max_level (int): largest level that can be assigned.
        canonical_box_size (int): box side length that maps to `canonical_level`.
        canonical_level (int): level assigned to a box of the canonical size.

    Returns:
        Tensor: int64 tensor of shape (batch_size * num_proposals,) with values in
        ``[0, max_level - min_level]``.
    """
    flat_bboxes = bboxes.reshape(bboxes.shape[0] * bboxes.shape[1], -1)
    widths = flat_bboxes[:, 2] - flat_bboxes[:, 0]
    heights = flat_bboxes[:, 3] - flat_bboxes[:, 1]
    # A box with one inverted side has negative area; sqrt would turn that into
    # NaN and the int64 cast below would then yield an undefined level. Treat
    # such boxes as zero-sized so they land on the lowest level instead.
    box_sizes = torch.sqrt(torch.clamp(widths * heights, min=0))

    # The epsilon keeps log2 finite for zero-sized boxes.
    level_assignments = torch.floor(canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8))
    level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
    return level_assignments.to(torch.int64) - min_level
|
|
|
|
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Sinusoidal embedding of a 1-D tensor of time steps.

    The first ``dim // 2`` output features are sines and the next ``dim // 2``
    are cosines, over geometrically spaced frequencies from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time (Tensor): 1-D tensor of shape (N,).

        Returns:
            Tensor: embeddings of shape (N, dim).
        """
        half_dim = self.dim // 2
        # With a single frequency (dim 2 or 3) `half_dim - 1` is 0; the only
        # exponent is then 0 anyway, so any non-zero divisor gives the same result
        # without raising ZeroDivisionError.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        angles = time[:, None] * frequencies[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos pairs only fill an even width; zero-pad the last feature so
            # the output width always equals `dim`.
            padding = embeddings.new_zeros((embeddings.shape[0], 1))
            embeddings = torch.cat((embeddings, padding), dim=-1)
        return embeddings
|
|
|
|
class HeadDynamicK(nn.Module):
    """
    Stack of `config.num_heads` identical `DiffusionDetHead` modules applied in
    sequence; each head refines the boxes predicted by the previous one, and all
    heads are conditioned on the same time embedding.
    """

    def __init__(self, config, roi_input_shape):
        """
        Args:
            config: object providing `num_labels`, `num_heads`, `deep_supervision`,
                `hidden_dim`, `use_focal`, `use_fed_loss`, `prior_prob` (read only
                when one of the two loss flags is set), plus everything
                `DiffusionDetHead` reads.
            roi_input_shape: forwarded unchanged to `DiffusionDetHead`.
        """
        super().__init__()
        num_classes = config.num_labels

        # Build one head and deep-copy it, so all heads start from the same values.
        ddet_head = DiffusionDetHead(config, roi_input_shape, num_classes)
        self.num_head = config.num_heads
        self.head_series = nn.ModuleList([copy.deepcopy(ddet_head) for _ in range(self.num_head)])
        self.return_intermediate = config.deep_supervision

        # Time embedding: sinusoidal features followed by a two-layer MLP.
        self.hidden_dim = config.hidden_dim
        time_dim = self.hidden_dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(self.hidden_dim),
            nn.Linear(self.hidden_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # Classification bias prior, only needed for focal / federated loss.
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        self.num_classes = num_classes
        if self.use_focal or self.use_fed_loss:
            prior_prob = config.prior_prob
            self.bias_value = -math.log((1 - prior_prob) / prior_prob)
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise all matrices and set the classification bias prior."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # Address the classification bias explicitly. Selecting it by
        # `p.shape[-1] == num_classes` also overwrote unrelated parameters whose
        # last dimension happened to match (e.g. the size-4 box-delta bias when
        # num_classes is 3 or 4).
        if self.use_focal or self.use_fed_loss:
            for head in self.head_series:
                nn.init.constant_(head.class_logits.bias, self.bias_value)

    def forward(self, features, bboxes, t):
        """
        Args:
            features: feature maps, forwarded unchanged to every head.
            bboxes (Tensor): initial boxes of shape (batch_size, num_proposals, 4).
            t (Tensor): 1-D tensor of time steps fed to `time_mlp`.

        Returns:
            tuple(Tensor, Tensor): class logits and predicted boxes. With
            `return_intermediate` the outputs of all heads are stacked along a new
            leading dimension; otherwise only the last head's outputs are
            returned, with a leading dimension of size 1.
        """
        time = self.time_mlp(t)

        inter_class_logits = []
        inter_pred_bboxes = []

        class_logits, pred_bboxes = None, None
        for ddet_head in self.head_series:
            class_logits, pred_bboxes, _ = ddet_head(features, bboxes, time)
            if self.return_intermediate:
                inter_class_logits.append(class_logits)
                inter_pred_bboxes.append(pred_bboxes)
            # The next head starts from these boxes, but no gradient flows back
            # through the box coordinates.
            bboxes = pred_bboxes.detach()

        if self.return_intermediate:
            return torch.stack(inter_class_logits), torch.stack(inter_pred_bboxes)

        return class_logits[None], pred_bboxes[None]
|
|
|
|
class DynamicConv(nn.Module):
    """
    Instance interaction layer: each proposal feature generates two matrices
    that are applied, one after the other, to that proposal's own RoI features.
    """

    def __init__(self, config):
        """
        Args:
            config: object providing `hidden_dim`, `dim_dynamic`, `num_dynamic`
                and `pooler_resolution`.
        """
        super().__init__()

        hidden_dim = config.hidden_dim
        dim_dynamic = config.dim_dynamic

        self.hidden_dim = hidden_dim
        self.dim_dynamic = dim_dynamic
        self.num_dynamic = config.num_dynamic
        # Number of entries in one generated (hidden_dim x dim_dynamic) matrix.
        self.num_params = hidden_dim * dim_dynamic
        self.dynamic_layer = nn.Linear(hidden_dim, self.num_dynamic * self.num_params)

        self.norm1 = nn.LayerNorm(dim_dynamic)
        self.norm2 = nn.LayerNorm(hidden_dim)

        self.activation = nn.ReLU(inplace=True)

        # The output projection consumes all pooled locations flattened together.
        flattened_dim = hidden_dim * config.pooler_resolution ** 2
        self.out_layer = nn.Linear(flattened_dim, hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)

    def forward(self, pro_features, roi_features):
        """
        Args:
            pro_features (Tensor): proposal features of shape (1, N, hidden_dim).
            roi_features (Tensor): RoI features of shape
                (pooler_resolution ** 2, N, hidden_dim).

        Returns:
            Tensor: features of shape (N, hidden_dim).
        """
        # Generate the per-proposal matrices and put the proposal axis first.
        kernels = self.dynamic_layer(pro_features).permute(1, 0, 2)
        first_kernel = kernels[:, :, :self.num_params]
        second_kernel = kernels[:, :, self.num_params:]
        first_kernel = first_kernel.view(-1, self.hidden_dim, self.dim_dynamic)
        second_kernel = second_kernel.view(-1, self.dim_dynamic, self.hidden_dim)

        # hidden_dim -> dim_dynamic
        hidden = torch.bmm(roi_features.permute(1, 0, 2), first_kernel)
        hidden = self.activation(self.norm1(hidden))

        # dim_dynamic -> hidden_dim
        hidden = torch.bmm(hidden, second_kernel)
        hidden = self.activation(self.norm2(hidden))

        # Collapse all pooled locations into one vector per proposal.
        hidden = self.out_layer(hidden.flatten(1))
        return self.activation(self.norm3(hidden))
|
|
|
|
class DiffusionDetHead(nn.Module):
    """
    One refinement stage: pools RoI features for the current boxes, mixes the
    proposals with self-attention and a dynamic instance interaction, modulates
    the result with the time embedding, and predicts class logits and refined
    boxes.
    """

    def __init__(self, config, roi_input_shape, num_classes):
        """
        Args:
            config: object providing `hidden_dim`, `dim_feedforward`,
                `num_attn_heads`, `dropout`, `activation`, `roi_head_in_features`,
                `pooler_resolution`, `sampling_ratio`, `num_cls`, `num_reg`,
                `use_focal`, `use_fed_loss`, plus everything `DynamicConv` reads.
            roi_input_shape: mapping from feature name to a mapping that has a
                'stride' entry.
            num_classes (int): number of object classes.
        """
        super().__init__()

        width = config.hidden_dim
        drop_prob = config.dropout
        self.hidden_dim = width

        # One pooling scale per selected feature map: the inverse of its stride.
        self.pooler = ROIPooler(
            output_size=config.pooler_resolution,
            scales=tuple(1.0 / roi_input_shape[name]['stride'] for name in config.roi_head_in_features),
            sampling_ratio=config.sampling_ratio,
        )

        # Proposal mixing: self-attention followed by dynamic instance interaction.
        self.self_attn = nn.MultiheadAttention(width, config.num_attn_heads, dropout=drop_prob)
        self.inst_interact = DynamicConv(config)

        # Feed-forward network.
        self.linear1 = nn.Linear(width, config.dim_feedforward)
        self.dropout = nn.Dropout(drop_prob)
        self.linear2 = nn.Linear(config.dim_feedforward, width)

        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)
        self.norm3 = nn.LayerNorm(width)
        self.dropout1 = nn.Dropout(drop_prob)
        self.dropout2 = nn.Dropout(drop_prob)
        self.dropout3 = nn.Dropout(drop_prob)

        self.activation = _get_activation_fn(config.activation)

        # Maps the (4 * hidden_dim) time embedding to a per-feature scale and shift.
        self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(width * 4, width * 2))

        # Classification and regression towers.
        self.cls_module = self._make_tower(config.num_cls, width)
        self.reg_module = self._make_tower(config.num_reg, width)

        # Prediction layers; without focal / federated loss one extra logit is added.
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        num_outputs = num_classes if (self.use_focal or self.use_fed_loss) else num_classes + 1
        self.class_logits = nn.Linear(width, num_outputs)
        self.bboxes_delta = nn.Linear(width, 4)
        self.scale_clamp = _DEFAULT_SCALE_CLAMP
        self.bbox_weights = (2.0, 2.0, 1.0, 1.0)

    @staticmethod
    def _make_tower(num_layers, width):
        """Build `num_layers` repetitions of (bias-free Linear, LayerNorm, ReLU)."""
        layers = []
        for _ in range(num_layers):
            layers.extend([nn.Linear(width, width, False), nn.LayerNorm(width), nn.ReLU(inplace=True)])
        return nn.ModuleList(layers)

    def forward(self, features, bboxes, time_emb):
        """
        Args:
            features: feature maps handed to the RoI pooler.
            bboxes (Tensor): boxes of shape (batch_size, num_proposals, 4).
            time_emb (Tensor): time embedding of shape (batch_size, 4 * hidden_dim).

        Returns:
            tuple: class logits of shape (batch_size, num_proposals, -1), predicted
            boxes of shape (batch_size, num_proposals, -1), and the object features
            of shape (1, batch_size * num_proposals, hidden_dim).
        """
        bs, num_proposals = bboxes.shape[:2]
        num_boxes = bs * num_proposals
        width = self.hidden_dim

        pooled = self.pooler(features, bboxes)

        # Initial proposal features: mean of the pooled features over all locations.
        queries = pooled.view(bs, num_proposals, width, -1).mean(-1)

        # Location-major layout expected by the instance interaction.
        roi_features = pooled.view(num_boxes, width, -1).permute(2, 0, 1)

        # Self-attention across the proposals of each image.
        queries = queries.view(bs, num_proposals, width).permute(1, 0, 2)
        attended = self.self_attn(queries, queries, value=queries)[0]
        queries = self.norm1(queries + self.dropout1(attended))

        # Instance interaction between every proposal and its own RoI features.
        queries = queries.view(num_proposals, bs, width).permute(1, 0, 2).reshape(1, num_boxes, width)
        interacted = self.inst_interact(queries, roi_features)
        obj_features = self.norm2(queries + self.dropout2(interacted))

        # Feed-forward network with residual connection.
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(obj_features))))
        obj_features = self.norm3(obj_features + self.dropout3(ffn_out))

        fc_feature = obj_features.transpose(0, 1).reshape(num_boxes, -1)

        # Time conditioning: one (scale, shift) pair per image, repeated for each
        # of its proposals.
        scale_shift = torch.repeat_interleave(self.block_time_mlp(time_emb), num_proposals, dim=0)
        scale, shift = scale_shift.chunk(2, dim=1)
        fc_feature = fc_feature * (scale + 1) + shift

        cls_feature = fc_feature.clone()
        reg_feature = fc_feature.clone()
        for layer in self.cls_module:
            cls_feature = layer(cls_feature)
        for layer in self.reg_module:
            reg_feature = layer(reg_feature)

        class_logits = self.class_logits(cls_feature)
        pred_bboxes = self.apply_deltas(self.bboxes_delta(reg_feature), bboxes.view(-1, 4))

        return class_logits.view(bs, num_proposals, -1), pred_bboxes.view(bs, num_proposals, -1), obj_features

    def apply_deltas(self, deltas, boxes):
        """
        Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
                deltas[i] holds k potentially different box transformations for
                the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 4), as (x1, y1, x2, y2).

        Returns:
            Tensor: transformed boxes of shape (N, k*4), as (x1, y1, x2, y2).
        """
        boxes = boxes.to(deltas.dtype)

        x1, y1 = boxes[:, 0], boxes[:, 1]
        widths = boxes[:, 2] - x1
        heights = boxes[:, 3] - y1
        ctr_x = x1 + 0.5 * widths
        ctr_y = y1 + 0.5 * heights

        wx, wy, ww, wh = self.bbox_weights
        dx = deltas[:, 0::4] / wx
        dy = deltas[:, 1::4] / wy
        # Cap the size deltas so torch.exp() below cannot blow up.
        dw = torch.clamp(deltas[:, 2::4] / ww, max=self.scale_clamp)
        dh = torch.clamp(deltas[:, 3::4] / wh, max=self.scale_clamp)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        half_w = 0.5 * (torch.exp(dw) * widths[:, None])
        half_h = 0.5 * (torch.exp(dh) * heights[:, None])

        pred_boxes = torch.zeros_like(deltas)
        pred_boxes[:, 0::4] = pred_ctr_x - half_w
        pred_boxes[:, 1::4] = pred_ctr_y - half_h
        pred_boxes[:, 2::4] = pred_ctr_x + half_w
        pred_boxes[:, 3::4] = pred_ctr_y + half_h

        return pred_boxes
|
|
|
|
class ROIPooler(nn.Module):
    """
    Region of interest feature map pooler that supports pooling from one or more
    feature maps.
    """

    def __init__(
        self,
        output_size,
        scales,
        sampling_ratio,
        canonical_box_size=224,
        canonical_level=4,
    ):
        """
        Args:
            output_size (int or tuple[int, int]): pooled output size; an int is
                expanded to a square (size, size).
            scales (sequence of float): one spatial scale per feature map. Each
                must be a power of 1/2, and consecutive levels must be covered
                without gaps (checked by the assertions below).
            sampling_ratio (int): forwarded to every `RoIAlign`.
            canonical_box_size (int): box side length mapped to `canonical_level`.
            canonical_level (int): level assigned to a canonically sized box.
        """
        super().__init__()

        # A scale of 1 / 2**k corresponds to feature level k.
        min_level = -(math.log2(scales[0]))
        max_level = -(math.log2(scales[-1]))

        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        assert len(output_size) == 2 and isinstance(output_size[0], int) and isinstance(output_size[1], int)
        assert math.isclose(min_level, int(min_level)) and math.isclose(max_level, int(max_level))
        assert (len(scales) == max_level - min_level + 1)
        assert 0 <= min_level <= max_level
        assert canonical_box_size > 0

        self.output_size = output_size
        self.min_level = int(min_level)
        self.max_level = int(max_level)
        self.canonical_level = canonical_level
        self.canonical_box_size = canonical_box_size
        self.level_poolers = nn.ModuleList(
            RoIAlign(
                output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=True
            )
            for scale in scales
        )

    def forward(self, x, bboxes):
        """
        Args:
            x (list[Tensor]): one feature map per pooling level, each of shape
                (batch_size, C, H_l, W_l).
            bboxes (Tensor): boxes of shape (batch_size, num_proposals, 4).

        Returns:
            Tensor: pooled features of shape
            (batch_size * num_proposals, C, output_size[0], output_size[1]).
        """
        num_level_assignments = len(self.level_poolers)
        assert len(x) == num_level_assignments and len(bboxes) == x[0].size(0)

        pooler_fmt_boxes = convert_boxes_to_pooler_format(bboxes)

        # With a single level there is nothing to assign.
        if num_level_assignments == 1:
            return self.level_poolers[0](x[0], pooler_fmt_boxes)

        level_assignments = assign_boxes_to_levels(
            bboxes, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level
        )

        num_boxes = pooler_fmt_boxes.shape[0]
        num_channels = x[0].shape[1]
        # Use both output dimensions: allocating (h, h) mismatched the RoIAlign
        # result whenever a non-square (h, w) output size was configured.
        out_height, out_width = self.output_size
        output = torch.zeros(
            (num_boxes, num_channels, out_height, out_width), dtype=x[0].dtype, device=x[0].device
        )

        # Pool each box from the level it was assigned to and scatter the result
        # back to the box's original position.
        for level, (x_level, pooler) in enumerate(zip(x, self.level_poolers)):
            inds = (level_assignments == level).nonzero(as_tuple=True)[0]
            output.index_put_((inds,), pooler(x_level, pooler_fmt_boxes[inds]))

        return output
|
|