Update README.md
#2
by
Jiqing
- opened
README.md
CHANGED
|
@@ -5,7 +5,9 @@ tags: []
|
|
| 5 |
|
| 6 |
# Model Card for Model ID
|
| 7 |
|
| 8 |
-
ProtST for binary localization
|
|
|
|
|
|
|
| 9 |
|
| 10 |
## Running script
|
| 11 |
```python
|
|
@@ -22,6 +24,9 @@ import torch
|
|
| 22 |
import logging
|
| 23 |
import datasets
|
| 24 |
import transformers
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
logging.basicConfig(level=logging.INFO)
|
| 27 |
logger = logging.getLogger(__name__)
|
|
@@ -73,7 +78,8 @@ def create_optimizer(opt_model, lr_ratio=0.1):
|
|
| 73 |
"lr": training_args.learning_rate * lr_ratio
|
| 74 |
},
|
| 75 |
]
|
| 76 |
-
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
|
|
|
| 77 |
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
| 78 |
|
| 79 |
return optimizer
|
|
@@ -98,7 +104,8 @@ def preprocess_logits_for_metrics(logits, labels):
|
|
| 98 |
|
| 99 |
|
| 100 |
if __name__ == "__main__":
|
| 101 |
-
device = torch.device("cpu")
|
|
|
|
| 102 |
raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
|
| 103 |
model = AutoModel.from_pretrained("Jiqing/protst-esm1b-for-sequential-classification", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
|
| 104 |
tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
|
|
@@ -108,8 +115,10 @@ if __name__ == "__main__":
|
|
| 108 |
'learning_rate': 5e-05, 'weight_decay': 0, 'num_train_epochs': 100, 'max_steps': -1, 'lr_scheduler_type': 'constant', 'do_eval': True, \
|
| 109 |
'evaluation_strategy': 'epoch', 'per_device_eval_batch_size': 32, 'logging_strategy': 'epoch', 'save_strategy': 'epoch', 'save_steps': 820, \
|
| 110 |
'dataloader_num_workers': 0, 'run_name': 'downstream_esm1b_localization_fix', 'optim': 'adamw_torch', 'resume_from_checkpoint': False, \
|
| 111 |
-
'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3}
|
| 112 |
-
|
|
|
|
|
|
|
| 113 |
|
| 114 |
def tokenize_protein(example, tokenizer=None):
|
| 115 |
protein_seq = example["prot_seq"]
|
|
@@ -125,7 +134,8 @@ if __name__ == "__main__":
|
|
| 125 |
for split in ["train", "validation", "test"]:
|
| 126 |
raw_dataset[split] = raw_dataset[split].map(func_tokenize_protein, batched=False, remove_columns=["Unnamed: 0", "prot_seq", "localization"])
|
| 127 |
|
| 128 |
-
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
|
|
|
| 129 |
|
| 130 |
transformers.utils.logging.set_verbosity_info()
|
| 131 |
log_level = training_args.get_process_log_level()
|
|
@@ -134,9 +144,16 @@ if __name__ == "__main__":
|
|
| 134 |
optimizer = create_optimizer(model)
|
| 135 |
scheduler = create_scheduler(training_args, optimizer)
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
# build trainer
|
| 138 |
-
trainer = Trainer(
|
|
|
|
| 139 |
model=model,
|
|
|
|
| 140 |
args=training_args,
|
| 141 |
train_dataset=raw_dataset["train"],
|
| 142 |
eval_dataset=raw_dataset["validation"],
|
|
|
|
| 5 |
|
| 6 |
# Model Card for Model ID
|
| 7 |
|
| 8 |
+
ProtST for binary localization.
|
| 9 |
+
|
| 10 |
+
The following script shows how to fine-tune ProtST on Gaudi.
|
| 11 |
|
| 12 |
## Running script
|
| 13 |
```python
|
|
|
|
| 24 |
import logging
|
| 25 |
import datasets
|
| 26 |
import transformers
|
| 27 |
+
+ import habana_frameworks.torch
|
| 28 |
+
+ from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
|
| 29 |
+
|
| 30 |
|
| 31 |
logging.basicConfig(level=logging.INFO)
|
| 32 |
logger = logging.getLogger(__name__)
|
|
|
|
| 78 |
"lr": training_args.learning_rate * lr_ratio
|
| 79 |
},
|
| 80 |
]
|
| 81 |
+
- optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
| 82 |
+
+ optimizer_cls, optimizer_kwargs = GaudiTrainer.get_optimizer_cls_and_kwargs(training_args)
|
| 83 |
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
| 84 |
|
| 85 |
return optimizer
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
if __name__ == "__main__":
|
| 107 |
+
- device = torch.device("cpu")
|
| 108 |
+
+ device = torch.device("hpu")
|
| 109 |
raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
|
| 110 |
model = AutoModel.from_pretrained("Jiqing/protst-esm1b-for-sequential-classification", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
|
| 111 |
tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
|
|
|
|
| 115 |
'learning_rate': 5e-05, 'weight_decay': 0, 'num_train_epochs': 100, 'max_steps': -1, 'lr_scheduler_type': 'constant', 'do_eval': True, \
|
| 116 |
'evaluation_strategy': 'epoch', 'per_device_eval_batch_size': 32, 'logging_strategy': 'epoch', 'save_strategy': 'epoch', 'save_steps': 820, \
|
| 117 |
'dataloader_num_workers': 0, 'run_name': 'downstream_esm1b_localization_fix', 'optim': 'adamw_torch', 'resume_from_checkpoint': False, \
|
| 118 |
+
- 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3}
|
| 119 |
+
+ 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3, "use_habana":True, "use_lazy_mode": True, "use_hpu_graphs_for_inference": True}
|
| 120 |
+
- training_args = HfArgumentParser(TrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
|
| 121 |
+
+ training_args = HfArgumentParser(GaudiTrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
|
| 122 |
|
| 123 |
def tokenize_protein(example, tokenizer=None):
|
| 124 |
protein_seq = example["prot_seq"]
|
|
|
|
| 134 |
for split in ["train", "validation", "test"]:
|
| 135 |
raw_dataset[split] = raw_dataset[split].map(func_tokenize_protein, batched=False, remove_columns=["Unnamed: 0", "prot_seq", "localization"])
|
| 136 |
|
| 137 |
+
- data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
| 138 |
+
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=1024)
|
| 139 |
|
| 140 |
transformers.utils.logging.set_verbosity_info()
|
| 141 |
log_level = training_args.get_process_log_level()
|
|
|
|
| 144 |
optimizer = create_optimizer(model)
|
| 145 |
scheduler = create_scheduler(training_args, optimizer)
|
| 146 |
|
| 147 |
+
+ gaudi_config = GaudiConfig()
|
| 148 |
+
+ gaudi_config.use_fused_adam = True
|
| 149 |
+
+ gaudi_config.use_fused_clip_norm = True
|
| 150 |
+
|
| 151 |
+
|
| 152 |
# build trainer
|
| 153 |
+
- trainer = Trainer(
|
| 154 |
+
+ trainer = GaudiTrainer(
|
| 155 |
model=model,
|
| 156 |
+
+ gaudi_config=gaudi_config,
|
| 157 |
args=training_args,
|
| 158 |
train_dataset=raw_dataset["train"],
|
| 159 |
eval_dataset=raw_dataset["validation"],
|