dejanseo commited on
Commit
7a907f5
·
verified ·
1 Parent(s): b76832a

Upload train.py

Browse files
Files changed (1) hide show
  1. cross-entropy/train.py +346 -0
cross-entropy/train.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import argparse
4
+ import numpy as np
5
+ from sklearn.model_selection import train_test_split
6
+
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument("--bump", type=int, default=0, help="Extra epochs to train (resumes from last checkpoint)")
9
+ args = parser.parse_args()
10
+ from transformers import (
11
+ AutoTokenizer,
12
+ AutoModelForTokenClassification,
13
+ TrainingArguments,
14
+ Trainer,
15
+ DataCollatorForTokenClassification,
16
+ )
17
+ from datasets import Dataset
18
+ import wandb
19
+
20
+ MODEL_NAME = "microsoft/deberta-v3-large"
21
+ TRAIN_FILE = "train.json"
22
+ CACHE_FILE = "chunks.cache.json"
23
+ MAX_LEN = 512
24
+ STRIDE = 128
25
+ LABEL2ID = {"O": 0, "B-SPAN": 1, "I-SPAN": 2}
26
+ ID2LABEL = {v: k for k, v in LABEL2ID.items()}
27
+
28
+
29
+ def parse_annotated(annotated):
30
+ """Parse 'title[SEP]text with [SPAN]...[/SPAN]' into title, plain_text, and char offsets."""
31
+ title, body = annotated.split("[SEP]", 1)
32
+
33
+ # Extract span offsets from body
34
+ spans = []
35
+ plain = ""
36
+ i = 0
37
+ while i < len(body):
38
+ if body[i:i+6] == "[SPAN]":
39
+ start = len(plain)
40
+ i += 6
41
+ while i < len(body) and body[i:i+7] != "[/SPAN]":
42
+ plain += body[i]
43
+ i += 1
44
+ end = len(plain)
45
+ spans.append((start, end))
46
+ if body[i:i+7] == "[/SPAN]":
47
+ i += 7
48
+ else:
49
+ plain += body[i]
50
+ i += 1
51
+
52
+ return title.strip(), plain, spans
53
+
54
+
55
+ def chunk_with_title(title_ids, text_ids, text_labels, max_len, stride):
56
+ """Create overlapping chunks, each prefixed with title tokens."""
57
+ # Reserve space: [CLS] + title + [SEP] + ... + [SEP]
58
+ title_budget = len(title_ids) + 3 # CLS, SEP after title, SEP at end
59
+ text_budget = max_len - title_budget
60
+
61
+ if text_budget <= 0:
62
+ return []
63
+
64
+ chunks = []
65
+ start = 0
66
+
67
+ while start < len(text_ids):
68
+ end = min(start + text_budget, len(text_ids))
69
+ chunk_text_ids = text_ids[start:end]
70
+ chunk_labels = list(text_labels[start:end])
71
+
72
+ # Fix BIO boundary: if chunk starts mid-span, first span token must be B-SPAN
73
+ for j, lbl in enumerate(chunk_labels):
74
+ if lbl == LABEL2ID["I-SPAN"]:
75
+ chunk_labels[j] = LABEL2ID["B-SPAN"]
76
+ break
77
+ elif lbl != -100:
78
+ break
79
+
80
+ # Build full sequence: [CLS] title [SEP] text_chunk [SEP]
81
+ input_ids = [tokenizer.cls_token_id] + title_ids + [tokenizer.sep_token_id] + chunk_text_ids + [tokenizer.sep_token_id]
82
+ labels = [-100] + [-100] * len(title_ids) + [-100] + chunk_labels + [-100]
83
+ attention_mask = [1] * len(input_ids)
84
+
85
+ # Pad to max_len
86
+ pad_len = max_len - len(input_ids)
87
+ if pad_len > 0:
88
+ input_ids += [tokenizer.pad_token_id] * pad_len
89
+ labels += [-100] * pad_len
90
+ attention_mask += [0] * pad_len
91
+
92
+ chunks.append({
93
+ "input_ids": input_ids,
94
+ "attention_mask": attention_mask,
95
+ "labels": labels,
96
+ })
97
+
98
+ if end >= len(text_ids):
99
+ break
100
+ start += stride
101
+
102
+ return chunks
103
+
104
+
105
+ print("Loading tokenizer...")
106
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
107
+
108
+ import os
109
+ if os.path.exists(CACHE_FILE):
110
+ print(f"Loading cached chunks from {CACHE_FILE}...")
111
+ with open(CACHE_FILE, "r", encoding="utf-8") as f:
112
+ all_chunks = json.load(f)
113
+ print(f"Loaded {len(all_chunks):,} chunks from cache")
114
+ else:
115
+ print(f"Loading {TRAIN_FILE}...")
116
+ with open(TRAIN_FILE, "r", encoding="utf-8") as f:
117
+ raw_data = json.load(f)
118
+
119
+ print(f"Parsing and tokenizing {len(raw_data):,} articles...")
120
+ all_chunks = []
121
+
122
+ for i, item in enumerate(raw_data):
123
+ title, plain_text, span_offsets = parse_annotated(item["annotated"])
124
+
125
+ # Tokenize title (no special tokens)
126
+ title_enc = tokenizer(title, add_special_tokens=False)
127
+ title_ids = title_enc["input_ids"]
128
+
129
+ # Tokenize text with offset mapping
130
+ text_enc = tokenizer(plain_text, add_special_tokens=False, return_offsets_mapping=True)
131
+ text_ids = text_enc["input_ids"]
132
+ text_offsets_map = text_enc["offset_mapping"]
133
+
134
+ # Build token-level BIO labels for text tokens
135
+ text_labels = []
136
+ for tok_idx, (tok_start, tok_end) in enumerate(text_offsets_map):
137
+ if tok_start == 0 and tok_end == 0:
138
+ text_labels.append(-100)
139
+ continue
140
+
141
+ label = LABEL2ID["O"]
142
+ for span_start, span_end in span_offsets:
143
+ if tok_start >= span_start and tok_end <= span_end:
144
+ if tok_start == span_start:
145
+ label = LABEL2ID["B-SPAN"]
146
+ else:
147
+ label = LABEL2ID["I-SPAN"]
148
+ break
149
+ text_labels.append(label)
150
+
151
+ # Chunk
152
+ chunks = chunk_with_title(title_ids, text_ids, text_labels, MAX_LEN, STRIDE)
153
+ all_chunks.extend(chunks)
154
+
155
+ if (i + 1) % 2000 == 0:
156
+ print(f" [{i+1:,}/{len(raw_data):,}] chunks so far: {len(all_chunks):,}")
157
+
158
+ print(f"Total chunks: {len(all_chunks):,}")
159
+ print(f"Saving cache to {CACHE_FILE}...")
160
+ with open(CACHE_FILE, "w", encoding="utf-8") as f:
161
+ json.dump(all_chunks, f)
162
+ print("Cache saved.")
163
+
164
+ # Verify label distribution
165
+ all_labels_flat = []
166
+ for c in all_chunks:
167
+ all_labels_flat.extend([l for l in c["labels"] if l >= 0])
168
+ from collections import Counter
169
+ dist = Counter(all_labels_flat)
170
+ total_labeled = sum(dist.values())
171
+ print(f"Label distribution:")
172
+ for label_id, count in sorted(dist.items()):
173
+ print(f" {ID2LABEL[label_id]}: {count:,} ({count/total_labeled*100:.2f}%)")
174
+
175
+ # Split train/val
176
+ print("Splitting 95/5 train/val...")
177
+ train_chunks, val_chunks = train_test_split(all_chunks, test_size=0.05, random_state=42)
178
+ print(f"Train: {len(train_chunks):,} | Val: {len(val_chunks):,}")
179
+
180
+ train_ds = Dataset.from_list(train_chunks)
181
+ val_ds = Dataset.from_list(val_chunks)
182
+
183
+ # Model
184
+ print("Loading model...")
185
+ model = AutoModelForTokenClassification.from_pretrained(
186
+ MODEL_NAME,
187
+ num_labels=len(LABEL2ID),
188
+ id2label=ID2LABEL,
189
+ label2id=LABEL2ID,
190
+ )
191
+ model = model.float() # DeBERTa-v3 stores weights in FP16 natively; cast to FP32 for stable optimizer updates
192
+
193
+ data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=False)
194
+
195
+
196
+ def extract_spans_from_bio(seq):
197
+ """Extract contiguous spans from a BIO label sequence. Returns list of (start, end) tuples."""
198
+ spans = []
199
+ start = None
200
+ for i, label in enumerate(seq):
201
+ if label == LABEL2ID["B-SPAN"]:
202
+ if start is not None:
203
+ spans.append((start, i))
204
+ start = i
205
+ elif label == LABEL2ID["I-SPAN"]:
206
+ if start is None:
207
+ start = i # treat orphan I as B
208
+ else:
209
+ if start is not None:
210
+ spans.append((start, i))
211
+ start = None
212
+ if start is not None:
213
+ spans.append((start, len(seq)))
214
+ return spans
215
+
216
+
217
+ def compute_metrics(eval_pred):
218
+ logits, labels = eval_pred
219
+ preds = np.argmax(logits, axis=-1)
220
+
221
+ # Token-level per-class metrics
222
+ mask = labels.flatten() >= 0
223
+ flat_labels = labels.flatten()[mask]
224
+ flat_preds = preds.flatten()[mask]
225
+
226
+ results = {}
227
+ for label_name, label_id in LABEL2ID.items():
228
+ tp = ((flat_preds == label_id) & (flat_labels == label_id)).sum()
229
+ fp = ((flat_preds == label_id) & (flat_labels != label_id)).sum()
230
+ fn = ((flat_preds != label_id) & (flat_labels == label_id)).sum()
231
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0
232
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0
233
+ f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
234
+ results[f"{label_name}_precision"] = float(precision)
235
+ results[f"{label_name}_recall"] = float(recall)
236
+ results[f"{label_name}_f1"] = float(f1)
237
+
238
+ # Entity-level span F1 (overlap-based)
239
+ total_tp = 0
240
+ total_pred = 0
241
+ total_true = 0
242
+
243
+ for i in range(len(labels)):
244
+ # Build valid label/pred sequences (skip -100)
245
+ valid_labels = []
246
+ valid_preds = []
247
+ for j in range(len(labels[i])):
248
+ if labels[i][j] >= 0:
249
+ valid_labels.append(labels[i][j])
250
+ valid_preds.append(preds[i][j])
251
+
252
+ pred_spans = extract_spans_from_bio(valid_preds)
253
+ true_spans = extract_spans_from_bio(valid_labels)
254
+
255
+ total_pred += len(pred_spans)
256
+ total_true += len(true_spans)
257
+
258
+ # Match: pred span overlaps >= 50% with a true span (and vice versa)
259
+ matched_true = set()
260
+ for ps, pe in pred_spans:
261
+ for idx, (ts, te) in enumerate(true_spans):
262
+ if idx in matched_true:
263
+ continue
264
+ overlap = max(0, min(pe, te) - max(ps, ts))
265
+ pred_len = pe - ps
266
+ true_len = te - ts
267
+ if pred_len > 0 and true_len > 0:
268
+ if overlap / pred_len >= 0.5 and overlap / true_len >= 0.5:
269
+ total_tp += 1
270
+ matched_true.add(idx)
271
+ break
272
+
273
+ entity_precision = total_tp / total_pred if total_pred > 0 else 0
274
+ entity_recall = total_tp / total_true if total_true > 0 else 0
275
+ entity_f1 = 2 * entity_precision * entity_recall / (entity_precision + entity_recall) if (entity_precision + entity_recall) > 0 else 0
276
+ results["entity_precision"] = float(entity_precision)
277
+ results["entity_recall"] = float(entity_recall)
278
+ results["entity_f1"] = float(entity_f1)
279
+
280
+ # Console report
281
+ total = len(flat_preds)
282
+ print(f"\n{'='*60}")
283
+ print(f" EVAL — Token-level ({total:,} tokens)")
284
+ print(f" {'Class':<10} {'Prec':>8} {'Rec':>8} {'F1':>8} | {'Pred':>8} {'True':>8}")
285
+ print(f" {'-'*54}")
286
+ for label_name, label_id in LABEL2ID.items():
287
+ p = results[f"{label_name}_precision"]
288
+ r = results[f"{label_name}_recall"]
289
+ f = results[f"{label_name}_f1"]
290
+ pred_count = (flat_preds == label_id).sum()
291
+ true_count = (flat_labels == label_id).sum()
292
+ print(f" {label_name:<10} {p:>8.4f} {r:>8.4f} {f:>8.4f} | {pred_count:>8,} {true_count:>8,}")
293
+ print(f" {'-'*54}")
294
+ print(f" Entity-level: P={entity_precision:.4f} R={entity_recall:.4f} F1={entity_f1:.4f} ({total_tp}/{total_pred} pred, {total_true} true)")
295
+ print(f"{'='*60}\n")
296
+
297
+ return results
298
+
299
+
300
+ resume = args.bump > 0
301
+ total_epochs = 1 + args.bump
302
+
303
+ wandb.init(project="span-extractor", name=f"deberta-v3-large-ce{f'-bump{args.bump}' if resume else ''}")
304
+
305
+ training_args = TrainingArguments(
306
+ output_dir="./span_model_ce",
307
+ num_train_epochs=total_epochs,
308
+ per_device_train_batch_size=4,
309
+ per_device_eval_batch_size=8,
310
+ gradient_accumulation_steps=4,
311
+ learning_rate=2e-5,
312
+ weight_decay=0.01,
313
+ warmup_ratio=0.1,
314
+ bf16=True,
315
+ logging_steps=1,
316
+ eval_strategy="steps",
317
+ eval_steps=500,
318
+ save_strategy="steps",
319
+ save_steps=500,
320
+ save_total_limit=3,
321
+ load_best_model_at_end=True,
322
+ metric_for_best_model="entity_f1",
323
+ greater_is_better=True,
324
+ dataloader_num_workers=0,
325
+ report_to="wandb",
326
+ remove_unused_columns=False,
327
+ )
328
+
329
+ trainer = Trainer(
330
+ model=model,
331
+ args=training_args,
332
+ train_dataset=train_ds,
333
+ eval_dataset=val_ds,
334
+ data_collator=data_collator,
335
+ compute_metrics=compute_metrics,
336
+ )
337
+
338
+ print(f"Training... (epochs={total_epochs}, resume={resume})")
339
+ trainer.train(resume_from_checkpoint=resume)
340
+
341
+ print("Saving final model...")
342
+ trainer.save_model("./span_model_ce/final")
343
+ tokenizer.save_pretrained("./span_model_ce/final")
344
+
345
+ wandb.finish()
346
+ print("Done.")