Instructions to use daiweichen/pal-b-large-opt-350m with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use daiweichen/pal-b-large-opt-350m with Transformers:
```python
# Use a pipeline as a high-level helper
# Warning: Pipeline type "summarization" is no longer supported in transformers v5.
# You must load the model directly (see below) or downgrade to v4.x with:
#   pip install "transformers<5.0.0"
from transformers import pipeline

pipe = pipeline("summarization", model="daiweichen/pal-b-large-opt-350m", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("daiweichen/pal-b-large-opt-350m", trust_remote_code=True, dtype="auto")
```
(A minimal pipeline-call sketch follows the notebook links below.)
- Notebooks
- Google Colab
- Kaggle
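Once the pipeline above is created (on transformers v4.x, per the warning), it can be called on raw text. A minimal sketch, assuming the checkpoint's custom code supports the summarization pipeline; the input string is only a placeholder:

```python
# Minimal sketch, assuming transformers v4.x and that the checkpoint's
# remote code supports the summarization pipeline; the input is a placeholder.
from transformers import pipeline

pipe = pipeline("summarization", model="daiweichen/pal-b-large-opt-350m", trust_remote_code=True)
result = pipe("Paste the text you want summarized here.")
print(result[0]["summary_text"])
```

The PAL-B implementation behind this model is reproduced below.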
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@Desc: This is the implementation of PAL-B
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig
from .connector import Connector
from .tensor_initializer import TensorInitializer
from .custom_sfx import CustomSoftMax
from .itemLearner import ItemLearner
from .userLearner import UserLearner
from collections import defaultdict
from typing import Literal, Optional, Tuple
import logging

logger = logging.getLogger(__name__)

class BasePrefLearner(nn.Module):
    def __init__(
        self,
        d_hid: int,
        d_pref: int,
        k: int,
        llm_name: str,
        pref_learner_type: Literal["dist", "dist_normalization", "angle", "norm", "dist_logistic", "angle_hinge"],
        proj_arch: str,
        initializer_type: Literal["gaussian"],
        is_expectation_norm_init: bool,  # the tensor initialization parameters
        sfx_type: Literal["gumbel_softmax", "softmax"],
        sfx_temperature: float,
        is_temperature_learnable: bool,
        is_gumbel_hard: Optional[bool] = None,
        is_partition: bool = False,
        seed: int = 42,
        **kwargs
    ):
        super().__init__()
        self.pref_learner_type = pref_learner_type
        self.is_temperature_learnable = is_temperature_learnable
        # init all necessary modules
        model_config = AutoConfig.from_pretrained(llm_name)
        self.llm = AutoModel.from_pretrained(llm_name, from_tf=bool(".ckpt" in llm_name), config=model_config)
        self.tensor_initializer = TensorInitializer(initializer_type, seed, is_expectation_norm_init=is_expectation_norm_init)
        self.projector_f = Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref)
        self.projectors_gk = [Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref) for _ in range(k)]
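        # Learnable similarity temperature, initialized to log(1/0.07) as in CLIP;
        # forward() exponentiates it and clamps the result at 100.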
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.softmax_w = CustomSoftMax(sfx_type=sfx_type,
                                       temperature=sfx_temperature,
                                       is_temperature_learnable=is_temperature_learnable,
                                       is_gumbel_hard=is_gumbel_hard)
        self.item_learner = ItemLearner(
            llm=self.llm,
            projector=self.projector_f,
        )
        self.is_partition = is_partition
        self.user_learner = UserLearner(k=k, llm=self.llm, projectors=self.projectors_gk, softmax=self.softmax_w, is_partition=is_partition)
        logger.critical('🛑 Remember to call update_trainable_params() after the model is initialized.')

    def update_trainable_params(self, fix_modules: Tuple[str, ...] = ()):
        # capture params
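        # Sub-modules named in fix_modules are frozen (left out of trainable_params);
        # everything else is grouped by name, e.g. so per-group optimizer settings can be built.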
        self.trainable_params = defaultdict(list)
        if "llm" not in fix_modules:
            self.trainable_params["llm"] = self.llm.parameters()
        else:
            self.llm.eval()
        if "itemLearnerProjector" not in fix_modules:
            self.trainable_params["projector_f"].extend(self.item_learner.projector.parameters())
        if "userLearnerProjector" not in fix_modules:
            self.trainable_params["projectors_gk"].extend(list(self.user_learner.projectors.parameters()))
        if "W" not in fix_modules:
            self.trainable_params["W"] = self.user_learner.W.parameters()
        if self.pref_learner_type in ["angle", "dist_logistic"] and "logit_scale" not in fix_modules:
            self.trainable_params["logit_scale"] = self.logit_scale
        if self.is_temperature_learnable and "temperature" not in fix_modules:
            self.trainable_params["temperature"] = self.softmax_w.temperature

    def map_to_pref_embedding_space(self, x, rm_cached=None):
        # (
        #     uid,
        #     {
        #         'input_ids': prompt_input_ids,
        #         'attention_mask': prompt_attention_mask,
        #     },
        #     {
        #         'input_ids': eval_input_ids,
        #         'attention_mask': eval_attention_mask,
        #     })
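        # rm_cached, when provided, is threaded through ItemLearner and UserLearner and
        # returned updated; presumably it caches state so repeated scoring calls can
        # reuse earlier computation.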
        uid, prompt, items = x
        if rm_cached is None:
            items_prime = self.item_learner(items)
            prompt_prime = self.user_learner(uid, prompt)
            return items_prime, prompt_prime
        else:
            items_prime, rm_cached = self.item_learner(items, rm_cached)
            prompt_prime, rm_cached = self.user_learner(uid, prompt, rm_cached)
            return items_prime, prompt_prime, rm_cached


class PrefLearner(BasePrefLearner):  # <f(x),f(u)>
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.uid = None  # set via specify_user_ids() before calling forward()

    def specify_user_ids(self, uid):  # personalize the model for a specific user
        self.uid = uid

    def forward(self, x, rm_cached=None):
        assert self.uid is not None, "Please specify the user id first by calling specify_user_ids() to personalize the reward model"
        prompt, items = x
        if rm_cached is None:
            items_prime, prompt_prime = self.map_to_pref_embedding_space((self.uid, prompt, items))
        else:
            items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space((self.uid, prompt, items), rm_cached)
        # logger.critical(f"{items_prime[0]=}")
        # logger.critical(f"{prompt_prime[0]=}")
        # logger.critical(f"{items_prime.shape=}")
        # logger.critical(f"{prompt_prime.shape=}")
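        # 'angle' score: L2-normalize the last-token embeddings of the user (prompt)
        # and the item, then take their cosine similarity scaled by the learnable
        # (clamped) logit scale.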
        if self.pref_learner_type == 'angle':
            # NOTICE: here we implement the "last token only" version of PAL-B
            prompt_last_prime = prompt_prime[:, -1, :]
            prompt_last_prime = prompt_last_prime.unsqueeze(1)
            prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
            items_last_prime = items_prime[:, -1, :]
            items_last_prime = items_last_prime.unsqueeze(1)
            items_last_prime = items_last_prime / torch.norm(items_last_prime, dim=-1, keepdim=True)
            logit_scale = self.logit_scale.exp()
            clamped_logit_scale = torch.clamp(logit_scale, max=100)
            # logger.critical(f"{prompt_last_prime.shape=}")
            # logger.critical(f"{items_last_prime.shape=}")
            sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale  # (bs, max_token_length)
            if rm_cached is None:
                return sim_score
            else:
                return sim_score, rm_cached
        else:
            raise NotImplementedError
```
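For orientation, here is a hedged sketch of how `PrefLearner` might be instantiated and queried. Every concrete value below (dimensions, `k`, the `proj_arch` string, the backbone name, the user id, and the example texts) is an assumption for illustration; the real values come from the released training configuration, and whether the item/user learners accept raw tokenizer output unchanged is also an assumption.

```python
# Hedged usage sketch; all concrete values are illustrative assumptions.
from transformers import AutoTokenizer
from pal_b import PrefLearner  # hypothetical import path for the module above

model = PrefLearner(
    d_hid=512,                     # backbone hidden size (assumed)
    d_pref=512,                    # preference-embedding dimension (assumed)
    k=2,                           # number of preference prototypes (assumed)
    llm_name="facebook/opt-350m",  # backbone suggested by the checkpoint name (assumed)
    pref_learner_type="angle",
    proj_arch="mlp",               # placeholder; must name a Connector architecture
    initializer_type="gaussian",
    is_expectation_norm_init=False,
    sfx_type="softmax",
    sfx_temperature=1.0,
    is_temperature_learnable=False,
)
model.update_trainable_params(fix_modules=("llm",))  # e.g. freeze the backbone

tok = AutoTokenizer.from_pretrained("facebook/opt-350m")
prompt = tok(["Summarize: the quick brown fox jumps over the lazy dog."], return_tensors="pt")
item = tok(["A fox jumps over a dog."], return_tensors="pt")

model.specify_user_ids(0)      # personalize the reward model to one user (assumed id format)
score = model((prompt, item))  # per-item similarity score from the 'angle' branch
```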