| |
| import inspect |
| import json |
| import shutil |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from huggingface_hub import PyTorchModelHubMixin |
|
|
| from fuse_clip.fuse_clip_arch import FuseCLIP |
| from open_clip import get_input_dtype, SimpleTokenizer |
|
|
|
|
class FuseLIP(FuseCLIP, PyTorchModelHubMixin):
    """FuseLIP with save_pretrained / from_pretrained / push_to_hub.

    Checkpoint directory layout, as written by ``_save_pretrained`` and read
    back by ``_from_pretrained``:

    * ``pytorch_model.bin`` -- ``state_dict()`` saved with ``torch.save``.
    * ``config.json`` -- JSON dump of ``get_config()``.
    * ``fuse_clip_hub.py`` -- copy of the source file that defines this class.
    """

    def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
        """Write weights, config and the defining source file.

        Args:
            save_directory: Target directory; created (with parents) if it
                does not exist yet.
            **kwargs: Accepted for caller compatibility; not used here.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")
        (save_directory / "config.json").write_text(
            json.dumps(self.get_config(), indent=2),
            encoding="utf-8",
        )

        # Ship the module that defines FuseLIP next to the weights.
        source_path = Path(inspect.getfile(FuseLIP))
        shutil.copy(source_path, save_directory / "fuse_clip_hub.py")

    @classmethod
    def _from_pretrained(
        cls,
        save_directory: "Path | str | None" = None,
        *,
        model_id=None,
        revision=None,
        cache_dir=None,
        force_download=False,
        proxies=None,
        resume_download=None,
        local_files_only=False,
        token=None,
        **kwargs,
    ):
        """Rebuild a model from a checkpoint directory or a Hub repository.

        Two calling conventions are supported:

        * ``cls._from_pretrained(path, **model_kwargs)`` -- load from a local
          directory (the original behaviour).
        * the keyword-only convention used by
          ``ModelHubMixin.from_pretrained`` (``model_id=...``,
          ``revision=...``, ...). If ``model_id`` is not an existing local
          directory it is treated as a Hub repository ID and the two
          checkpoint files are fetched with ``snapshot_download``.

        Args:
            save_directory: Local checkpoint directory. Takes precedence over
                ``model_id`` when both are given.
            model_id: Local directory or Hub repository ID.
            revision, cache_dir, force_download, proxies, local_files_only,
                token: Download options, used only when files are fetched
                from the Hub. They are never forwarded to the constructor.
            resume_download: Accepted so the mixin can pass it; ignored.
            **kwargs: Extra constructor arguments. They override values of
                the same name coming from ``config.json`` (previously an
                overlap raised ``TypeError``).

        Returns:
            The instantiated model with the saved ``state_dict`` loaded
            (``strict=True``) on CPU.

        Raises:
            TypeError: If neither ``save_directory`` nor ``model_id`` is given.
            KeyError: If ``config.json`` lacks ``context_length``,
                ``mlm_probability`` or ``image_tokenizer``.
        """
        location = save_directory if save_directory is not None else model_id
        if location is None:
            raise TypeError(
                "_from_pretrained() requires `save_directory` or `model_id`"
            )
        checkpoint_dir = Path(location)

        # Only identifiers that arrived through `model_id` may name a Hub
        # repo; an explicit `save_directory` is always read as a local path.
        if save_directory is None and not checkpoint_dir.is_dir():
            from huggingface_hub import snapshot_download

            download_kwargs = {
                "repo_id": str(model_id),
                "revision": revision,
                "cache_dir": cache_dir,
                "force_download": force_download,
                "local_files_only": local_files_only,
                "token": token,
                "allow_patterns": ["config.json", "pytorch_model.bin"],
            }
            if proxies is not None:
                download_kwargs["proxies"] = proxies
            checkpoint_dir = Path(snapshot_download(**download_kwargs))

        cfg = json.loads(
            (checkpoint_dir / "config.json").read_text(encoding="utf-8")
        )

        # NOTE(review): this tokenizer is neither passed to the model nor
        # returned; it is kept because it was built here before -- confirm
        # whether it should be attached to the model or dropped.
        tokenizer = SimpleTokenizer(context_length=cfg["context_length"])
        tokenizer.pad_token_id = 0

        if cfg["mlm_probability"] > 0:
            mask_token = "[MASK]"
            if mask_token not in tokenizer.encoder:
                # Take the first ID after the current largest one.
                mask_token_id = max(tokenizer.encoder.values()) + 1

                tokenizer.encoder[mask_token] = mask_token_id
                tokenizer.decoder[mask_token_id] = mask_token

                tokenizer.all_special_ids.append(mask_token_id)
                tokenizer.mask_token = mask_token_id
                tokenizer.vocab_size += 1

                print(f"Added `[MASK]` token with ID {mask_token_id}")
            else:
                mask_token_id = tokenizer.encoder[mask_token]
                print(f"`[MASK]` token already exists with ID {mask_token_id}")

        # Translate the saved config into constructor arguments.
        cfg["image_tokenizer_path"] = cfg["image_tokenizer"]
        cfg["init_logit_scale"] = np.log(10)
        cfg["init_logit_bias"] = -10
        cfg["input_dtype"] = get_input_dtype("fp32")
        # `text_config` is never read here, so tolerate its absence.
        cfg.pop("text_config", None)
        del cfg["image_tokenizer"]
        del cfg["context_length"]

        # Merge instead of `cls(**cfg, **kwargs)` so an overlapping key is an
        # override rather than a "multiple values" TypeError.
        model = cls(**{**cfg, **kwargs})

        # SECURITY: the checkpoint may come from the Hub and torch.load
        # unpickles it. The file holds a plain state_dict, so restrict the
        # unpickler to tensors/containers with weights_only=True.
        state = torch.load(
            checkpoint_dir / "pytorch_model.bin",
            map_location="cpu",
            weights_only=True,
        )
        model.load_state_dict(state, strict=True)
        return model
|
|