| | import torch |
| | from transformers import AutoModel |
| |
|
| |
|
def build_text_encoder(config):
    """Instantiate the pretrained text encoder described by *config*.

    Args:
        config: Object exposing ``model_type`` (str) and
            ``pretrained_name_or_path`` (str or path passed to
            ``AutoModel.from_pretrained``).

    Returns:
        The loaded Hugging Face encoder model.

    Raises:
        NotImplementedError: If ``config.model_type`` is not a supported
            encoder type (currently only ``"mpnet"``).
    """
    if config.model_type == "mpnet":
        return AutoModel.from_pretrained(config.pretrained_name_or_path)
    # Name the bad value so misconfiguration is diagnosable from the traceback.
    raise NotImplementedError(
        f"Unsupported model_type: {config.model_type!r} (expected 'mpnet')"
    )
| |
|
| |
|
| | |
def mean_pooling(model_output, attention_mask):
    """Masked mean-pool token embeddings into one vector per sequence.

    Args:
        model_output: Transformer output; element 0 holds the token
            embeddings of shape (batch, seq_len, hidden).
        attention_mask: Tensor of shape (batch, seq_len); nonzero entries
            mark real tokens, zeros mark padding.

    Returns:
        Tensor of shape (batch, hidden) — the mean of the unmasked token
        embeddings for each sequence.
    """
    embeddings = model_output[0]
    # Broadcast the mask across the hidden dimension so padding tokens
    # contribute nothing to the sum.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = torch.sum(embeddings * mask, 1)
    # Clamp guards against division by zero for all-padding sequences.
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts
| |
|