"""Embedding-model wrappers that turn a piece of text into a ``torch.Tensor``.

Three callable wrappers are provided (``SBert``, ``ModelWithPooling`` and
``Qwen3Embedding``) together with ``get_embedding_model``, a cached factory
that picks the wrapper for a model name.
"""
from functools import lru_cache

import torch
import torch.nn.functional as F
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

QWEN3_EMBEDDING_MODEL = 'Qwen/Qwen3-Embedding-0.6B'
LEGACY_MODELS = [
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-MiniLM-L12-v2',
    'cyclone/simcse-chinese-roberta-wwm-ext',
    'bert-base-chinese',
    'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
]
# Every model name listed by this module: the legacy models followed by Qwen3.
list_models = [*LEGACY_MODELS, QWEN3_EMBEDDING_MODEL]


class SBert:
    """Callable wrapper around a ``SentenceTransformer`` with memoized results."""

    def __init__(self, path):
        """Load the sentence-transformers model at ``path`` onto ``DEVICE``."""
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)
        # The cache is created per instance. Decorating the method itself with
        # ``lru_cache`` would build one class-level cache keyed on ``self``,
        # keeping every instance (and its model) alive for as long as one of
        # its entries stays in the cache.
        self._encode_cached = lru_cache(maxsize=10000)(self._encode)
        logger.info(f'Load {self.__class__} from {path} ...')

    def _encode(self, x) -> torch.Tensor:
        """Encode ``x`` with the wrapped model, bypassing the cache."""
        return self.model.encode(x, convert_to_tensor=True)

    def __call__(self, x) -> torch.Tensor:
        """Return the embedding of ``x`` as a tensor.

        ``x`` must be hashable (e.g. a ``str``) because results are memoized.
        The returned tensor is the cached object itself, so callers should not
        modify it in place.
        """
        return self._encode_cached(x)


class ModelWithPooling:
    """Callable wrapper around a Hugging Face ``AutoModel`` with selectable pooling."""

    # Pooling strategies accepted by ``__call__``.
    POOLINGS = ('cls', 'pooler', 'mean', 'last-avg', 'first-last-avg')

    def __init__(self, path):
        """Load the tokenizer and model stored at ``path``."""
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # NOTE(review): unlike the other wrappers in this module, the model is
        # not moved to DEVICE here - confirm whether that is intentional.
        self.model = AutoModel.from_pretrained(path)
        # Per-instance cache; see SBert.__init__ for why it is not class-level.
        self._embed_cached = lru_cache(maxsize=100)(self._embed)
        logger.info(f'Load {self.__class__} from {path} ...')

    @staticmethod
    def _sequence_mean(hidden: torch.Tensor) -> torch.Tensor:
        """Average ``hidden`` ([b, s, h]) over the sequence axis, giving [b, h]."""
        channels_first = hidden.transpose(1, 2)  # [b, h, s]
        pooled = torch.avg_pool1d(channels_first, kernel_size=channels_first.shape[-1])
        return pooled.squeeze(-1)  # [b, h]

    @torch.no_grad()
    def _embed(self, text, pooling):
        """Tokenize ``text``, run the model and pool the output (uncached)."""
        # Reject unknown strategies before paying for a forward pass.
        if pooling not in self.POOLINGS:
            raise ValueError(f'Unknown pooling {pooling}')

        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        # Only 'first-last-avg' reads the per-layer hidden states.
        need_all_layers = pooling == 'first-last-avg'
        outputs = self.model(**inputs, output_hidden_states=need_all_layers)

        if pooling == 'cls':
            o = outputs.last_hidden_state[:, 0]  # [b, h]
        elif pooling == 'pooler':
            o = outputs.pooler_output  # [b, h]
        elif pooling in ('mean', 'last-avg'):
            o = self._sequence_mean(outputs.last_hidden_state)  # [b, h]
        else:  # 'first-last-avg'
            first_avg = self._sequence_mean(outputs.hidden_states[1])  # [b, h]
            last_avg = self._sequence_mean(outputs.hidden_states[-1])  # [b, h]
            avg = torch.stack((first_avg, last_avg), dim=1)  # [b, 2, h]
            o = torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)  # [b, h]

        # Drop the batch axis when it has size one (single input text).
        return o.squeeze(0)

    def __call__(self, text: str, pooling='mean'):
        """Return the pooled embedding of ``text``.

        Args:
            text: Input text. It must be hashable because results are memoized.
            pooling: One of ``'cls'`` (first token of the last hidden state),
                ``'pooler'`` (the model's ``pooler_output``), ``'mean'`` /
                ``'last-avg'`` (sequence average of the last hidden state) or
                ``'first-last-avg'`` (average of the sequence averages of
                ``hidden_states[1]`` and ``hidden_states[-1]``).

        Returns:
            The pooled tensor with a size-one batch axis squeezed away. It is
            the cached object itself, so callers should not modify it in place.

        Raises:
            ValueError: If ``pooling`` is not one of the supported strategies.
        """
        return self._embed_cached(text, pooling)


class Qwen3Embedding:
    """Callable wrapper producing L2-normalized last-token embeddings."""

    def __init__(self, path):
        """Load the tokenizer (left padding) and model at ``path`` onto ``DEVICE``."""
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left')
        self.model = AutoModel.from_pretrained(path)
        self.model.to(DEVICE)
        self.model.eval()
        # Per-instance cache; see SBert.__init__ for why it is not class-level.
        self._embed_cached = lru_cache(maxsize=100)(self._embed)
        logger.info(f'Load {self.__class__} from {path} ...')

    @staticmethod
    def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Select, for each row, the hidden state at its last attended position.

        Args:
            last_hidden_states: Hidden states indexed as [batch, seq, hidden].
            attention_mask: Mask indexed as [batch, seq]; assumed to hold 1 for
                real tokens and 0 for padding - TODO confirm.

        Returns:
            One hidden-state vector per row.
        """
        # If the final mask column sums to the batch size, every row ends on a
        # real token (left padding), so the last position can be taken directly.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_states[:, -1]
        # Otherwise index each row at (number of attended tokens - 1).
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[
            torch.arange(batch_size, device=last_hidden_states.device),
            sequence_lengths,
        ]

    @torch.no_grad()
    def _embed(self, text):
        """Tokenize ``text``, run the model and pool/normalize (uncached)."""
        inputs = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=8192,
            return_tensors='pt',
        )
        inputs = {key: value.to(DEVICE) for key, value in inputs.items()}
        outputs = self.model(**inputs)
        embeddings = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
        # Drop the batch axis when it has size one (single input text).
        return embeddings.squeeze(0)

    def __call__(self, text: str, pooling='mean'):
        """Return the L2-normalized last-token embedding of ``text``.

        Args:
            text: Input text. It must be hashable because results are memoized.
            pooling: Accepted so the signature matches ``ModelWithPooling``;
                it is not used, and it is left out of the cache key so that
                different values do not create duplicate entries.

        Returns:
            The normalized embedding with a size-one batch axis squeezed away.
            It is the cached object itself, so callers should not modify it in
            place.
        """
        return self._embed_cached(text)


@lru_cache(maxsize=8)
def get_embedding_model(model_name: str):
    """Return a cached embedding wrapper for ``model_name``.

    ``QWEN3_EMBEDDING_MODEL`` is served by ``Qwen3Embedding``; any other name is
    handed to ``ModelWithPooling``. Up to 8 wrappers are kept alive at once.
    """
    if model_name == QWEN3_EMBEDDING_MODEL:
        return Qwen3Embedding(model_name)
    return ModelWithPooling(model_name)


def test_sbert():
    """Smoke test: ``SBert`` yields a 768-element vector for one string."""
    m = SBert('bert-base-chinese')
    o = m('hello')
    print(o.size())
    assert o.size() == (768,)


def test_hf_model():
    """Smoke test: ``ModelWithPooling`` with 'cls' pooling yields a 768-element vector."""
    m = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
    o = m('hello', pooling='cls')
    print(o.size())
    assert o.size() == (768,)