from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

from data_collator import MyDataCollator
from dataset import MyDataset
|
|
# --- Configuration ----------------------------------------------------------
model_name = 'bert-base-uncased'
batch_size = 16
num_epochs = 3

# Load the tokenizer ONCE and share it everywhere. The original script called
# AutoTokenizer.from_pretrained(model_name) three separate times (train set,
# val set, collator) — an expensive, potentially network-bound call, and all
# three consumers must use the same vocabulary anyway.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Train/validation datasets backed by CSV files on disk.
# NOTE(review): assumes MyDataset(path, tokenizer) handles tokenization
# internally — confirm against dataset.py.
train_data = MyDataset('train.csv', tokenizer)
val_data = MyDataset('val.csv', tokenizer)

# Batches and pads tokenized examples at collation time.
data_collator = MyDataCollator(tokenizer)

# BERT encoder with a fresh sequence-classification head over 8 labels.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)
|
|
# Training configuration.
#
# BUG FIX: load_best_model_at_end=True requires the checkpoint-save strategy
# to match the evaluation strategy. The original passed save_steps=500 while
# leaving save_strategy at its default ('steps') alongside
# evaluation_strategy='epoch', which makes Trainer raise at construction
# time ("--load_best_model_at_end requires the save and eval strategy to
# match"). Saving per epoch fixes this; save_steps is then meaningless and
# has been removed.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy='epoch',
    save_strategy='epoch',          # must match evaluation_strategy for best-model tracking
    save_total_limit=2,             # keep only the two most recent checkpoints
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    greater_is_better=True,
    save_on_each_node=True,
)
|
|
def compute_metrics(eval_pred):
    """Compute accuracy for the Trainer's evaluation loop.

    ``eval_pred.predictions`` holds raw logits of shape
    (n_examples, n_labels) and ``eval_pred.label_ids`` holds integer class
    indices — both are numpy arrays, not torch tensors.

    BUG FIX: the original lambda called ``torch.argmax`` (``torch`` was
    never imported) on numpy arrays, took an argmax over the 1-D
    ``label_ids`` (they are class indices, not one-hot), and returned a raw
    match *count* instead of a ratio. This version computes the fraction of
    correct predictions.
    """
    predicted_labels = eval_pred.predictions.argmax(axis=-1)
    accuracy = float((predicted_labels == eval_pred.label_ids).mean())
    return {'accuracy': accuracy}


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

# Run the fine-tuning loop (evaluates and checkpoints once per epoch, then
# reloads the best checkpoint by accuracy).
trainer.train()