| |
| |
| |
| |
| |
| |
| |
| from argparse import Namespace |
|
|
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
| from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput |
| from .adaptor_mlp import create_mlp_from_state, create_mlp_from_config |
|
|
|
|
class GenericAdaptor(AdaptorBase):
    """Adaptor that maps RADIO summary/feature outputs through a pair of MLP heads.

    The two heads (`head_mlp` for the pooled summary, `feat_mlp` for the
    per-patch features) are either restored from a serialized ``state`` dict
    or constructed fresh from ``mlp_config``.
    """

    def __init__(self, main_config: Namespace, adaptor_config, state, mlp_config=None):
        """
        Args:
            main_config: Namespace carrying at least ``mlp_version``.
            adaptor_config: Adaptor-specific config (unused here; kept for the
                common adaptor constructor interface).
            state: Serialized MLP weights, or ``None`` to build from config.
            mlp_config: Dict with ``"summary"`` and ``"feature"`` sub-dicts
                (each holding input/hidden/output dims and ``num_inner``).
                Required when ``state`` is ``None``.

        Raises:
            AssertionError: If both ``state`` and ``mlp_config`` are ``None``.
        """
        super().__init__()

        if state is not None:
            # Restore both heads from the state dict via their key prefixes.
            self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.')
            self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.')
        else:
            # Explicit raise instead of `assert`: asserts are stripped under
            # `python -O`, which would defer this misconfiguration to an opaque
            # KeyError below. Same exception type, so callers are unaffected.
            if mlp_config is None:
                raise AssertionError("Config must not be None if state is None")

            self.head_mlp = self._build_mlp(main_config.mlp_version, mlp_config["summary"])
            self.feat_mlp = self._build_mlp(main_config.mlp_version, mlp_config["feature"])

    @staticmethod
    def _build_mlp(mlp_version, cfg):
        """Construct one MLP head from a config dict of dims and inner-layer count."""
        return create_mlp_from_config(
            mlp_version,
            cfg["input_dim"],
            cfg["hidden_dim"],
            cfg["output_dim"],
            cfg["num_inner"],
        )

    def forward(self, input: AdaptorInput) -> RadioOutput:
        """Project summary and features through the heads, preserving input dtypes."""
        # Run the MLPs in the parameter dtype (e.g. under mixed precision the
        # inputs may differ), then cast back so callers see their own dtype.
        param_dtype = next(self.parameters()).dtype
        summary = self.head_mlp(input.summary.to(dtype=param_dtype)).to(dtype=input.summary.dtype)
        feat = self.feat_mlp(input.features.to(dtype=param_dtype)).to(dtype=input.features.dtype)

        if input.feature_fmt == 'NCHW':
            # Features arrive flattened as (B, H*W, C); recover the spatial grid
            # implied by the image size and patch size, then move channels first.
            rows = input.images.shape[-2] // input.patch_size
            cols = input.images.shape[-1] // input.patch_size
            feat = feat.reshape(feat.shape[0], rows, cols, feat.shape[2]).permute(0, 3, 1, 2)

        return RadioOutput(summary, feat)
|
|