Instructions to use daiweichen/pal-b-large-opt-350m with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use daiweichen/pal-b-large-opt-350m with Transformers:
```python
# Use a pipeline as a high-level helper
# Warning: Pipeline type "summarization" is no longer supported in transformers v5.
# You must load the model directly (see below) or downgrade to v4.x with:
# pip install "transformers<5.0.0"
from transformers import pipeline

pipe = pipeline("summarization", model="daiweichen/pal-b-large-opt-350m", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("daiweichen/pal-b-large-opt-350m", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .connector import Connector | |
| from .projector import Projector | |
| from .tensor_initializer import TensorInitializer | |
| from .custom_sfx import CustomSoftMax | |
| import numpy as np | |
| import warnings | |
| from typing import Literal | |
| import logging | |
logger = logging.getLogger(__name__)


class UserLearner(nn.Module):
    """Mix ``k`` group-level projectors into a per-user "ideal point".

    Each of the ``k`` projectors maps the LLM's last-token prompt embedding to a
    preference vector. Every registered user owns a length-``k`` weight vector
    in ``self.W``; ``forward`` normalises it with ``self.softmax`` and returns the
    weighted sum of the ``k`` projector outputs.
    """

    k: int  # the number of groups
    llm: nn.Module
    projectors: nn.ModuleDict  # keys "0" .. str(k - 1), in insertion order
    u_id_set: set
    softmax: nn.Module
    is_partition: bool
    W: nn.ParameterDict  # user id -> length-k raw mixing weights
    is_argmax: bool

    def __init__(
        self,
        k: int,
        llm: nn.Module,
        projectors: "list[Projector]",
        softmax: nn.Module,
        is_partition: bool = False,
    ):
        """Register the backbone, the ``k`` projectors and the weight normaliser.

        Args:
            k: number of groups; must equal ``len(projectors)``.
            llm: backbone called with ``input_ids`` / ``attention_mask`` whose
                output exposes ``last_hidden_state``.
            projectors: one module per group, stored under keys ``"0"`` ..
                ``str(k - 1)``.
            softmax: module applied to the stacked per-user weights.
            is_partition: selects the "Partition" flavour, which flags argmax
                weighting in eval mode (see ``train``).

        Raises:
            AssertionError: if ``len(projectors) != k``.
        """
        super().__init__()
        self.k = k
        self.llm = llm
        self.softmax = softmax
        # user-id registration table and trainable per-user weights
        self.u_id_set = set()
        self.W = nn.ParameterDict()
        self.tmp_store_user_ideal_points = None
        # register all k projectors in the ModuleDict
        assert len(projectors) == k, f"The num of projectors should match up with num of groups: {k} != {len(projectors)}"
        self.projectors = nn.ModuleDict()
        for i in range(k):
            self.projectors[str(i)] = projectors[i]
        self.is_partition = is_partition
        # A new module starts in training mode, where both variants use softmax
        # weights; define the flag now so it exists before train()/eval().
        self.is_argmax = False

    def init_weight(self, u_ids: list, reinit: bool = False):
        """Create a random length-``k`` weight vector for each new user id.

        Args:
            u_ids: user ids to register; each becomes a key of ``self.W``.
            reinit: when True, re-draw the weights of already-registered ids too.
        """
        for u_id in u_ids:
            if u_id not in self.u_id_set or reinit:
                device = next(self.projectors[str(0)].parameters()).device
                # Move first, then wrap: ``Parameter(...).to(device)`` returns a
                # plain non-leaf tensor whenever the device differs. Sampling on
                # the CPU keeps the random stream identical to before.
                self.W[u_id] = nn.Parameter(
                    torch.randn(self.k, dtype=torch.float32).to(device),
                    requires_grad=True,
                )
                self.u_id_set.add(u_id)
            else:
                logger.warning('๐ wait? same user?')

    def get_sfx_w(self, u_ids: list):
        """Return ``self.softmax`` of the stacked user weights, shape (bs, k)."""
        w = torch.stack([self.W[key] for key in u_ids], dim=0)  # (bs, k)
        return self.softmax(w)

    def get_hardmax_w(self, u_ids: list):
        """Return a float one-hot of each user's argmax weight, shape (bs, k)."""
        w = torch.stack([self.W[key] for key in u_ids], dim=0)
        return F.one_hot(w.argmax(dim=1), num_classes=self.k).float()  # (bs, k)

    def infer_gk(self, prompt_tokens, rm_cached=None):
        '''Run all k group projectors on the prompt's last-token embedding.

        prompt_tokens: {'input_ids': torch.tensor, 'attention_mask': torch.tensor}
        rm_cached: None disables caching. Otherwise pass a dict (an empty dict is
            fine). Its "user_learner" entry holds the backbone's past_key_values
            and is updated in place. If the entry is missing, the whole prompt is
            encoded once. If it is present, only the newest token is fed.

        Returns logits of shape (bs, k, seq_len, pref_dim). They are built from
        the last position's embedding repeated along seq_len. With rm_cached the
        seq_len axis is 1 and the tuple (logits, rm_cached) is returned.
        '''
        input_ids = prompt_tokens['input_ids']
        attention_mask = prompt_tokens['attention_mask']
        if rm_cached is None:
            embeds = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        elif "user_learner" not in rm_cached:
            # First cached call: encode the full prompt once and keep its cache.
            # Only the last position is kept so the output length matches the
            # later one-token calls below.
            res = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=True,
            )
            rm_cached["user_learner"] = res.past_key_values
            embeds = res.last_hidden_state[:, -1:, :]
        else:
            # Incremental step: feed only the newest token; the cache holds the prefix.
            # NOTE(review): attention_mask is not passed here — confirm this is
            # safe for padded batches.
            res = self.llm(
                input_ids=input_ids[:, -1:],
                past_key_values=rm_cached["user_learner"],
                use_cache=True,
            )
            rm_cached["user_learner"] = res.past_key_values
            embeds = res.last_hidden_state
        # embeds: (bs, seq_len, hid_dim)
        bs, seq_len, hid_dim = embeds.shape
        # Only the prompt's last token is used; repeat it along seq_len.
        embeds = embeds[:, -1, :].unsqueeze(1).repeat(1, seq_len, 1)
        embeds = embeds.view(-1, hid_dim)  # (bs*seq_len, hid_dim)
        # each g: (bs*seq_len, hid_dim) -> (bs*seq_len, pref_dim)
        logits = torch.stack(
            [g(embeds).view(bs, seq_len, -1) for g in self.projectors.values()],
            dim=1,
        )  # (bs, k, seq_len, pref_dim)
        if rm_cached is None:
            return logits
        return logits, rm_cached

    def return_user_ideal_points(self):
        """Return the ideal points stored by the latest ``forward`` call.

        Raises:
            ValueError: if ``forward`` has not stored anything yet.
        """
        if self.tmp_store_user_ideal_points is None:
            raise ValueError('No user ideal points stored')
        return self.tmp_store_user_ideal_points

    def forward(self, uid, prompt_tokens, rm_cached=None):  # only pass the prompt tokens
        '''Compute the user's ideal point as a weighted mix of the k projectors.

        uid: a key of self.W; the same user's weights are used for every row.
        prompt_tokens: {'input_ids': torch.tensor, 'attention_mask': torch.tensor}
        rm_cached: optional cache dict, forwarded to infer_gk.

        Returns y_hat of shape (bs, seq_len, pref_dim), or (y_hat, rm_cached)
        when rm_cached is given. y_hat is also stored for
        return_user_ideal_points().
        '''
        if rm_cached is None:
            prompt_logits = self.infer_gk(prompt_tokens)
        else:
            prompt_logits, rm_cached = self.infer_gk(prompt_tokens, rm_cached)
        bs = prompt_tokens['input_ids'].size(0)
        # NOTE(review): is_argmax (set by train()/eval()) is not consulted here;
        # weights are always softmax-normalised — confirm whether eval-mode
        # partition models are meant to use get_hardmax_w instead.
        w = self.get_sfx_w([uid] * bs)  # (bs, k)
        # broadcast over (seq_len, pref_dim), then sum out the group axis
        w = w.unsqueeze(-1).unsqueeze(-1)  # (bs, k, 1, 1)
        y_hat = (w * prompt_logits).sum(dim=1)  # (bs, seq_len, pref_dim)
        self.tmp_store_user_ideal_points = y_hat
        if rm_cached is None:
            return y_hat
        return y_hat, rm_cached

    def eval(self):
        """Switch to evaluation mode and return ``self``.

        ``nn.Module.eval`` is just ``train(False)``. Delegating to the overridden
        ``train`` sets ``is_argmax`` and emits the mode warning exactly once.
        """
        return self.train(False)

    def train(self, mode: bool = True):
        """Set train/eval mode, update ``is_argmax`` and return ``self``.

        ``is_argmax`` is True only for the partition variant in eval mode; every
        other combination uses softmax weights. A warning announces the mode.
        """
        super().train(mode)
        if mode:
            if self.is_partition:
                warnings.warn("๐ค UserPromptLearner(Partition version) is in train mode: sfx")
            else:
                warnings.warn("๐ค UserPromptLearner(Mixture version) is in train mode: sfx")
            self.is_argmax = False
        elif self.is_partition:
            warnings.warn("๐ค UserPromptLearner(Partition version) is in eval mode: argmax")
            self.is_argmax = True
        else:
            warnings.warn("๐ค UserPromptLearner(Mixture version) is in eval mode: sfx")
            self.is_argmax = False
        return self