diyclassics Claude Opus 4.6 (1M context) commited on
Commit
f04d50f
·
1 Parent(s): 2c07f6c

refactor: extract shared case study utils and move data to tracked paths

Browse files

- Extract shared BertForSequenceLabeling, get_batches, word_to_subtokens
into tests/case_study_utils.py to reduce duplication across test files
- Move WSD and infilling data from .claude/reference/ (gitignored) to
data/case_studies/ so they ship with the HF repo
- Update conftest.py default model path to latincy/latin-bert
- Add scripts/benchmark.py (model-agnostic benchmark runner)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

data/case_studies/infilling/emendation_filtered.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/case_studies/wsd/latin.sense.data ADDED
The diff for this file is too large to render. See raw diff
 
tests/case_study_utils.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared utilities for Bamman & Burns (2020) case study tests.
2
+
3
+ Provides the subword-to-word transform matrix approach used by all four
4
+ case studies: POS tagging, WSD, infilling, and contextual nearest neighbors.
5
+ """
6
+
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch import nn
12
+ from torch.nn import CrossEntropyLoss
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # Constants
16
+ # ---------------------------------------------------------------------------
17
+ BERT_DIM = 768
18
+ BATCH_SIZE = 32
19
+ DROPOUT_RATE = 0.25
20
+
21
+ # Special tokens that should not go through subword encoding
22
+ SPECIAL_TOKENS = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}
23
+
24
+ # Data paths (relative to repo root)
25
+ REPO_ROOT = Path(__file__).resolve().parent.parent
26
+ DATA_DIR = REPO_ROOT / "data"
27
+ CASE_STUDY_DIR = DATA_DIR / "case_studies"
28
+ WSD_DATA_PATH = CASE_STUDY_DIR / "wsd" / "latin.sense.data"
29
+ INFILLING_DATA_PATH = CASE_STUDY_DIR / "infilling" / "emendation_filtered.txt"
30
+
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # Tokenization helpers
34
+ # ---------------------------------------------------------------------------
35
+ def word_to_subtokens(tokenizer, word):
36
+ """Get subtoken strings for a single word.
37
+
38
+ Special tokens ([CLS], [SEP], etc.) are returned as-is.
39
+ Regular words are tokenized through the subword pipeline,
40
+ matching the original LatinTokenizer.tokenize() behavior.
41
+ """
42
+ if word in SPECIAL_TOKENS:
43
+ return [word]
44
+ return tokenizer.tokenize(word)
45
+
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # Batching with transform matrices
49
+ # ---------------------------------------------------------------------------
50
+ def get_batches(tokenizer, sentences, max_batch, has_labels=True):
51
+ """Tokenize and batch sentences with subword-to-word transform matrices.
52
+
53
+ Each word is tokenized individually (matching original behavior).
54
+ The transform matrix averages subword representations back to
55
+ word-level representations.
56
+
57
+ sentences: list of sentences, where each sentence is a list of items.
58
+ If has_labels=True, each item is [word, label, ...] (list/tuple).
59
+ If has_labels=False, each item is a word string.
60
+
61
+ Returns:
62
+ If has_labels: (data, masks, labels, transforms, ordering)
63
+ If not: (data, masks, transforms, ordering)
64
+ """
65
+ all_data = []
66
+ all_masks = []
67
+ all_labels = [] if has_labels else None
68
+ all_transforms = []
69
+
70
+ for sentence in sentences:
71
+ tok_ids = []
72
+ input_mask = []
73
+ labels = [] if has_labels else None
74
+ transform = []
75
+
76
+ # First pass: get subtokens for each word
77
+ all_toks = []
78
+ n = 0
79
+ for item in sentence:
80
+ word = item[0] if has_labels else item
81
+ toks = word_to_subtokens(tokenizer, word)
82
+ all_toks.append(toks)
83
+ n += len(toks)
84
+
85
+ # Second pass: build transform matrix and collect IDs
86
+ cur = 0
87
+ for idx, item in enumerate(sentence):
88
+ toks = all_toks[idx]
89
+ ind = list(np.zeros(n))
90
+ for j in range(cur, cur + len(toks)):
91
+ ind[j] = 1.0 / len(toks)
92
+ cur += len(toks)
93
+ transform.append(ind)
94
+ tok_ids.extend(tokenizer.convert_tokens_to_ids(toks))
95
+ input_mask.extend(np.ones(len(toks)))
96
+ if has_labels:
97
+ labels.append(int(item[1]))
98
+
99
+ all_data.append(tok_ids)
100
+ all_masks.append(input_mask)
101
+ if has_labels:
102
+ all_labels.append(labels)
103
+ all_transforms.append(transform)
104
+
105
+ lengths = np.array([len(l) for l in all_data])
106
+ ordering = np.argsort(lengths)
107
+
108
+ ordered_data = [None] * len(all_data)
109
+ ordered_masks = [None] * len(all_data)
110
+ ordered_labels = [None] * len(all_data) if has_labels else None
111
+ ordered_transforms = [None] * len(all_data)
112
+
113
+ for i, ind in enumerate(ordering):
114
+ ordered_data[i] = all_data[ind]
115
+ ordered_masks[i] = all_masks[ind]
116
+ if has_labels:
117
+ ordered_labels[i] = all_labels[ind]
118
+ ordered_transforms[i] = all_transforms[ind]
119
+
120
+ batched_data = []
121
+ batched_mask = []
122
+ batched_labels = [] if has_labels else None
123
+ batched_transforms = []
124
+
125
+ i = 0
126
+ current_batch = max_batch
127
+
128
+ while i < len(ordered_data):
129
+ bd = ordered_data[i:i + current_batch]
130
+ bm = ordered_masks[i:i + current_batch]
131
+ bl = ordered_labels[i:i + current_batch] if has_labels else None
132
+ bt = ordered_transforms[i:i + current_batch]
133
+
134
+ ml = max(len(s) for s in bd)
135
+ max_words = max(len(t) for t in bt)
136
+
137
+ for j in range(len(bd)):
138
+ blen = len(bd[j])
139
+ for _k in range(blen, ml):
140
+ bd[j].append(0)
141
+ bm[j].append(0)
142
+ for z in range(len(bt[j])):
143
+ bt[j][z].append(0)
144
+ if has_labels:
145
+ blab = len(bl[j])
146
+ for _k in range(blab, max_words):
147
+ bl[j].append(-100)
148
+ for _k in range(len(bt[j]), max_words):
149
+ bt[j].append(np.zeros(ml))
150
+
151
+ batched_data.append(torch.LongTensor(bd))
152
+ batched_mask.append(torch.FloatTensor(bm))
153
+ if has_labels:
154
+ batched_labels.append(torch.LongTensor(bl))
155
+ batched_transforms.append(torch.FloatTensor(bt))
156
+
157
+ i += current_batch
158
+ if ml > 100:
159
+ current_batch = 12
160
+ if ml > 200:
161
+ current_batch = 6
162
+
163
+ if has_labels:
164
+ return batched_data, batched_mask, batched_labels, batched_transforms, ordering
165
+ return batched_data, batched_mask, batched_transforms, ordering
166
+
167
+
168
+ # ---------------------------------------------------------------------------
169
+ # Sequence labeling model (used by POS and WSD)
170
+ # ---------------------------------------------------------------------------
171
+ class BertForSequenceLabeling(nn.Module):
172
+ """BERT + linear classifier for sequence labeling.
173
+
174
+ Used by POS tagging and WSD case studies. The encoder is frozen
175
+ and a linear head is trained on top.
176
+ """
177
+
178
+ def __init__(self, tokenizer, bert_model, freeze_bert=False,
179
+ num_labels=2, hidden_size=BERT_DIM):
180
+ super().__init__()
181
+ self.tokenizer = tokenizer
182
+ self.num_labels = num_labels
183
+ self.bert = bert_model
184
+ self.bert.eval()
185
+ if freeze_bert:
186
+ for param in self.bert.parameters():
187
+ param.requires_grad = False
188
+ self.dropout = nn.Dropout(DROPOUT_RATE)
189
+ self.classifier = nn.Linear(hidden_size, num_labels)
190
+
191
+ def forward(self, input_ids, attention_mask=None, transforms=None,
192
+ labels=None):
193
+ device = input_ids.device
194
+ if attention_mask is not None:
195
+ attention_mask = attention_mask.to(device)
196
+ if transforms is not None:
197
+ transforms = transforms.to(device)
198
+ if labels is not None:
199
+ labels = labels.to(device)
200
+
201
+ outputs = self.bert(input_ids, attention_mask=attention_mask)
202
+ sequence_output = outputs[0]
203
+ out = torch.matmul(transforms, sequence_output)
204
+ logits = self.classifier(out)
205
+
206
+ if labels is not None:
207
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
208
+ return loss_fct(
209
+ logits.view(-1, self.num_labels), labels.view(-1)
210
+ )
211
+ return logits
212
+
213
+ def get_batches(self, sentences, max_batch):
214
+ """Tokenize and batch with subword-to-word transform matrices.
215
+
216
+ Delegates to the module-level get_batches() function.
217
+ """
218
+ return get_batches(self.tokenizer, sentences, max_batch,
219
+ has_labels=True)
tests/conftest.py CHANGED
@@ -2,7 +2,7 @@
2
 
3
  import pytest
4
 
5
- DEFAULT_MODEL_PATH = "/tmp/latin-bert-hub"
6
 
7
 
8
  def pytest_addoption(parser):
 
2
 
3
  import pytest
4
 
5
+ DEFAULT_MODEL_PATH = "latincy/latin-bert"
6
 
7
 
8
  def pytest_addoption(parser):
tests/test_contextual_nn.py CHANGED
@@ -25,8 +25,13 @@ import torch
25
  from torch import nn
26
  from transformers import AutoTokenizer, BertModel
27
 
28
- BERT_DIM = 768
29
- BATCH_SIZE = 32
 
 
 
 
 
30
 
31
 
32
  def _get_device():
@@ -37,11 +42,8 @@ def _get_device():
37
  return torch.device("mps")
38
  return torch.device("cpu")
39
 
40
- # Special tokens that should not go through subword encoding
41
- _SPECIAL_TOKENS = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}
42
 
43
  # Data paths
44
- DATA_DIR = Path(__file__).parent.parent / "data"
45
  CORPUS_TEXT_DIR = DATA_DIR / "latin_library_text"
46
  CORPUS_BERT_DIR = DATA_DIR / "latin_library_bert"
47
  CORPUS_ARCHIVE = DATA_DIR / "latin_library_text.tar.gz"
@@ -49,143 +51,28 @@ CORPUS_ARCHIVE = DATA_DIR / "latin_library_text.tar.gz"
49
  # Google Drive download URL for Latin Library texts
50
  CORPUS_DOWNLOAD_ID = "1GRe3eFmQBDdF1kIT9T75aPTdquaf8Z8s"
51
 
52
-
53
- # ── Shared helpers ──────────────────────────────────────────────────────
54
-
55
-
56
- def _word_to_subtokens(tokenizer, word):
57
- """Get subtoken strings for a single word.
58
-
59
- Special tokens ([CLS], [SEP], etc.) are returned as-is.
60
- Regular words are tokenized through the subword pipeline.
61
- """
62
- if word in _SPECIAL_TOKENS:
63
- return [word]
64
- return tokenizer.tokenize(word)
65
-
66
-
67
- def _get_batches(tokenizer, sentences, max_batch):
68
- """Tokenize and batch sentences with subword-to-word transform matrices.
69
-
70
- Each word is tokenized individually (matching original behavior).
71
- The transform matrix averages subword representations back to
72
- word-level representations.
73
-
74
- sentences: list of lists of words (including [CLS]/[SEP])
75
- """
76
- all_data = []
77
- all_masks = []
78
- all_transforms = []
79
-
80
- for sentence in sentences:
81
- tok_ids = []
82
- input_mask = []
83
- transform = []
84
-
85
- # First pass: get subtokens for each word
86
- all_toks = []
87
- n = 0
88
- for word in sentence:
89
- toks = _word_to_subtokens(tokenizer, word)
90
- all_toks.append(toks)
91
- n += len(toks)
92
-
93
- # Second pass: build transform matrix and collect IDs
94
- cur = 0
95
- for idx, word in enumerate(sentence):
96
- toks = all_toks[idx]
97
- ind = list(np.zeros(n))
98
- for j in range(cur, cur + len(toks)):
99
- ind[j] = 1.0 / len(toks)
100
- cur += len(toks)
101
- transform.append(ind)
102
- tok_ids.extend(tokenizer.convert_tokens_to_ids(toks))
103
- input_mask.extend(np.ones(len(toks)))
104
-
105
- all_data.append(tok_ids)
106
- all_masks.append(input_mask)
107
- all_transforms.append(transform)
108
-
109
- lengths = np.array([len(l) for l in all_data])
110
- ordering = np.argsort(lengths)
111
-
112
- ordered_data = [None] * len(all_data)
113
- ordered_masks = [None] * len(all_data)
114
- ordered_transforms = [None] * len(all_data)
115
-
116
- for i, ind in enumerate(ordering):
117
- ordered_data[i] = all_data[ind]
118
- ordered_masks[i] = all_masks[ind]
119
- ordered_transforms[i] = all_transforms[ind]
120
-
121
- batched_data = []
122
- batched_mask = []
123
- batched_transforms = []
124
-
125
- i = 0
126
- current_batch = max_batch
127
-
128
- while i < len(ordered_data):
129
- batch_data = ordered_data[i:i + current_batch]
130
- batch_mask = ordered_masks[i:i + current_batch]
131
- batch_transforms = ordered_transforms[i:i + current_batch]
132
-
133
- ml = max(len(s) for s in batch_data)
134
- max_words = max(len(t) for t in batch_transforms)
135
-
136
- for j in range(len(batch_data)):
137
- blen = len(batch_data[j])
138
- for _k in range(blen, ml):
139
- batch_data[j].append(0)
140
- batch_mask[j].append(0)
141
- for z in range(len(batch_transforms[j])):
142
- batch_transforms[j][z].append(0)
143
- for _k in range(len(batch_transforms[j]), max_words):
144
- batch_transforms[j].append(np.zeros(ml))
145
-
146
- batched_data.append(torch.LongTensor(batch_data))
147
- batched_mask.append(torch.FloatTensor(batch_mask))
148
- batched_transforms.append(torch.FloatTensor(batch_transforms))
149
-
150
- i += current_batch
151
- if ml > 100:
152
- current_batch = 12
153
- if ml > 200:
154
- current_batch = 6
155
-
156
- return batched_data, batched_mask, batched_transforms, ordering
157
-
158
-
159
  MAX_SEQ_LEN = 512
160
 
161
 
162
  def _get_word_embeddings(tokenizer, model, sentences, device):
163
- """Get word-level BERT embeddings for a list of sentences.
164
-
165
- Returns list of sentences, each a list of (word, embedding) tuples.
166
- Mirrors the original LatinBERT.get_berts() method.
167
- Sentences whose subword length exceeds MAX_SEQ_LEN are skipped
168
- (returned as empty lists).
169
- """
170
- # Filter out sentences that exceed BERT's max sequence length
171
  valid_indices = []
172
  valid_sentences = []
173
  for idx, sent in enumerate(sentences):
174
  n_subtokens = sum(
175
- len(_word_to_subtokens(tokenizer, w)) for w in sent
176
  )
177
  if n_subtokens <= MAX_SEQ_LEN:
178
  valid_indices.append(idx)
179
  valid_sentences.append(sent)
180
 
181
- # Initialize results with empty lists for all sentences
182
  all_bert_sents = [[] for _ in sentences]
183
 
184
  if not valid_sentences:
185
  return all_bert_sents
186
 
187
- batched_data, batched_mask, batched_transforms, ordering = _get_batches(
188
- tokenizer, valid_sentences, BATCH_SIZE
189
  )
190
 
191
  ordered_preds = []
@@ -206,12 +93,10 @@ def _get_word_embeddings(tokenizer, model, sentences, device):
206
  for row in range(b_size):
207
  ordered_preds.append([np.array(r) for r in out[row]])
208
 
209
- # Restore original ordering within valid sentences
210
  preds_in_order = [None] * len(valid_sentences)
211
  for i, ind in enumerate(ordering):
212
  preds_in_order[ind] = ordered_preds[i]
213
 
214
- # Build (word, embedding) pairs and place back at original indices
215
  for vi, orig_idx in enumerate(valid_indices):
216
  sentence = valid_sentences[vi]
217
  bert_sent = []
@@ -226,13 +111,7 @@ def _get_word_embeddings(tokenizer, model, sentences, device):
226
 
227
 
228
  def test_embedding_parity(model_path):
229
- """Verify our HF tokenizer produces identical word-level embeddings.
230
-
231
- Feeds short sentences through the HF pipeline and checks that
232
- word-level embeddings (after subword averaging via transform matrix)
233
- have cosine similarity > 0.9999 with themselves when computed via
234
- two independent forward passes with the same tokenization.
235
- """
236
  device = _get_device()
237
 
238
  tokenizer = AutoTokenizer.from_pretrained(
@@ -248,16 +127,13 @@ def test_embedding_parity(model_path):
248
  "omnia vincit amor",
249
  ]
250
 
251
- # Build word lists with [CLS]/[SEP], lowercased
252
  sentences = []
253
  for raw in test_sentences_raw:
254
  words = ["[CLS]"] + raw.lower().split() + ["[SEP]"]
255
  sentences.append(words)
256
 
257
- # Get embeddings via our HF pipeline
258
  bert_sents = _get_word_embeddings(tokenizer, model, sentences, device)
259
 
260
- # Verify we get embeddings for all words
261
  for sent_idx, (raw, bert_sent) in enumerate(
262
  zip(test_sentences_raw, bert_sents)
263
  ):
@@ -271,10 +147,8 @@ def test_embedding_parity(model_path):
271
  assert emb.shape == (BERT_DIM,), (
272
  f"Expected ({BERT_DIM},), got {emb.shape}"
273
  )
274
- # Embedding should not be all zeros
275
  assert LA.norm(emb) > 0.1, f"Zero embedding for '{word}'"
276
 
277
- # Run a second forward pass and verify cosine similarity ≈ 1.0
278
  bert_sents_2 = _get_word_embeddings(tokenizer, model, sentences, device)
279
 
280
  for sent_idx in range(len(sentences)):
@@ -288,9 +162,6 @@ def test_embedding_parity(model_path):
288
  f"{cos:.6f} (expected > 0.9999)"
289
  )
290
 
291
- # Verify the transform matrix produces different embeddings for the
292
- # same word in different contexts (contextual, not static)
293
- # "in" appears in sentence 1 ("gallia est omnis divisa in partes tres")
294
  in_emb = None
295
  for word, emb in bert_sents[1]:
296
  if word == "in":
@@ -298,7 +169,6 @@ def test_embedding_parity(model_path):
298
  break
299
  assert in_emb is not None, "'in' not found in sentence 1"
300
 
301
- # "omnia" from sentence 2 should have a different embedding than "in"
302
  omnia_emb = None
303
  for word, emb in bert_sents[2]:
304
  if word == "omnia":
@@ -330,10 +200,7 @@ def test_embedding_parity(model_path):
330
 
331
 
332
  def _read_file_cltk(filename):
333
- """Read a text file and tokenize with CLTK, matching original pipeline.
334
-
335
- Returns list of sentences, each a list of words with [CLS]/[SEP].
336
- """
337
  from cltk.tokenizers.lat.lat import (
338
  LatinWordTokenizer as WordTokenizer,
339
  LatinPunktSentenceTokenizer as SentenceTokenizer,
@@ -364,12 +231,11 @@ def _download_corpus():
364
  import subprocess
365
 
366
  if CORPUS_TEXT_DIR.exists() and any(CORPUS_TEXT_DIR.iterdir()):
367
- return # Already downloaded
368
 
369
  DATA_DIR.mkdir(parents=True, exist_ok=True)
370
 
371
  if not CORPUS_ARCHIVE.exists():
372
- # Download via gdown (handles Google Drive large files)
373
  subprocess.run(
374
  ["pip", "install", "-q", "gdown"],
375
  check=True, capture_output=True,
@@ -383,7 +249,6 @@ def _download_corpus():
383
  check=True,
384
  )
385
 
386
- # Extract
387
  with tarfile.open(CORPUS_ARCHIVE, "r:gz") as tar:
388
  tar.extractall(path=DATA_DIR)
389
 
@@ -395,13 +260,7 @@ def _download_corpus():
395
  def _generate_embeddings_for_file(
396
  tokenizer, model, input_file, output_file, device
397
  ):
398
- """Generate BERT embeddings for a single text file.
399
-
400
- Reads the file with CLTK tokenization, computes word-level embeddings,
401
- and writes them in the original format:
402
- word\\tspace-separated 768 floats
403
- (blank line between sentences)
404
- """
405
  sents = _read_file_cltk(input_file)
406
  if not sents:
407
  return 0
@@ -413,7 +272,7 @@ def _generate_embeddings_for_file(
413
  with open(output_file, "w", encoding="utf-8") as out:
414
  for bert_sent in bert_sents:
415
  if not bert_sent:
416
- continue # skipped (too long)
417
  for word, emb in bert_sent:
418
  out.write(
419
  "%s\t%s\n" % (word, " ".join("%.5f" % x for x in emb))
@@ -426,11 +285,7 @@ def _generate_embeddings_for_file(
426
 
427
  @pytest.mark.slow
428
  def test_generate_embeddings(model_path):
429
- """Generate BERT embeddings for the Latin Library corpus.
430
-
431
- Downloads the corpus if needed, then processes each text file
432
- through the model, saving word-level embeddings to disk.
433
- """
434
  device = _get_device()
435
 
436
  tokenizer = AutoTokenizer.from_pretrained(
@@ -474,11 +329,7 @@ def test_generate_embeddings(model_path):
474
 
475
 
476
  def _load_embedding_file(filename):
477
- """Load pre-generated embeddings from a TSV file.
478
-
479
- Returns (matrix, sents, sent_ids, toks, position_in_sent).
480
- Mirrors the original proc_doc().
481
- """
482
  berts = []
483
  toks = []
484
  sent_ids = []
@@ -518,13 +369,45 @@ def _load_embedding_file(filename):
518
  return matrix, sents, sent_ids, toks, position_in_sent
519
 
520
 
521
- def _load_all_embeddings(bert_dir):
522
- """Load all embedding files from a directory.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
- Uses joblib for parallel loading. Returns the same structure as
525
- the original proc() function.
526
- """
527
- from joblib import Parallel, delayed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
 
529
  files = sorted(
530
  str(f)
@@ -533,96 +416,50 @@ def _load_all_embeddings(bert_dir):
533
  )
534
  assert len(files) > 0, f"No embedding files found in {bert_dir}"
535
 
536
- print(f" Loading {len(files)} embedding files...")
 
 
537
 
538
- results = Parallel(n_jobs=min(10, len(files)))(
539
- delayed(_load_embedding_file)(f) for f in files
540
- )
541
-
542
- matrix_all = []
543
- sents_all = []
544
- sent_ids_all = []
545
- toks_all = []
546
- position_in_sent_all = []
547
- doc_ids = []
548
 
549
- for (matrix, sents, sent_ids, toks, pos), filename in zip(results, files):
550
- matrix_all.append(matrix)
551
- sents_all.append(sents)
552
- sent_ids_all.append(sent_ids)
553
- toks_all.append(toks)
554
- position_in_sent_all.append(pos)
555
- doc_ids.append(filename)
556
 
557
- return matrix_all, sents_all, sent_ids_all, toks_all, position_in_sent_all, doc_ids
 
 
 
 
 
 
 
 
 
 
 
558
 
 
 
 
 
559
 
560
- def _query_nearest_neighbors(
561
- target_bert, matrix_all, sents_all, sent_ids_all, toks_all,
562
- position_in_sent_all, doc_ids, top_n=25
563
- ):
564
- """Find the top-N contextually similar tokens across the corpus.
565
 
566
- Returns list of (cosine_score, context_window, doc_id) tuples.
567
- """
568
- all_vals = []
569
 
570
- for idx in range(len(doc_ids)):
571
- c_matrix = matrix_all[idx]
572
- c_sents = sents_all[idx]
573
- c_sent_ids = sent_ids_all[idx]
574
- c_toks = toks_all[idx]
575
- c_pos = position_in_sent_all[idx]
576
 
577
- if len(c_matrix) == 0:
578
- continue
579
-
580
- similarity = np.dot(c_matrix, target_bert)
581
- argsort = np.argsort(-similarity)
582
- len_s = len(similarity)
583
-
584
- for i in range(min(100, len_s)):
585
- tid = argsort[i]
586
- if (tid < len(c_sent_ids) and tid < len(c_pos)
587
- and c_sent_ids[tid] < len(c_sents)):
588
- pos = c_pos[tid]
589
- sent = c_sents[c_sent_ids[tid]]
590
- # Build context window (5 words each side)
591
- start = max(0, pos - 5)
592
- end = min(len(sent), pos + 6)
593
- before = " ".join(sent[start:pos])
594
- target = sent[pos]
595
- after = " ".join(sent[pos + 1:end])
596
- context = f"{before} **{target}** {after}".strip()
597
- all_vals.append((
598
- float(similarity[tid]),
599
- context,
600
- doc_ids[idx],
601
- target,
602
- ))
603
-
604
- all_vals.sort(key=lambda x: x[0], reverse=True)
605
- return all_vals[:top_n]
606
-
607
-
608
- # Queries from the paper's README
609
  QUERIES = [
610
  ("in", "gallia est omnis divisa in partes tres"),
611
  ("amor", "omnia vincit amor"),
 
612
  ]
613
 
614
 
615
  @pytest.mark.slow
616
  def test_contextual_nn_queries(model_path):
617
- """Run contextual nearest neighbor queries from the paper.
618
-
619
- Loads pre-generated embeddings, encodes query sentences, and finds
620
- the most contextually similar tokens across the corpus.
621
-
622
- Soft assertions:
623
- - Query word in its own sentence appears with cosine > 0.8
624
- - At least 10 of top-25 results contain the query word
625
- """
626
  device = _get_device()
627
 
628
  assert CORPUS_BERT_DIR.exists(), (
@@ -637,23 +474,16 @@ def test_contextual_nn_queries(model_path):
637
  model.to(device)
638
  model.eval()
639
 
640
- # Load all pre-generated embeddings
641
- corpus = _load_all_embeddings(CORPUS_BERT_DIR)
642
- (matrix_all, sents_all, sent_ids_all, toks_all,
643
- position_in_sent_all, doc_ids) = corpus
644
-
645
  for query_word, query_sent in QUERIES:
646
  print(f"\n{'=' * 60}")
647
  print(f"Query: '{query_word}' in '{query_sent}'")
648
  print("=" * 60)
649
 
650
- # Encode query sentence
651
  words = ["[CLS]"] + query_sent.lower().split() + ["[SEP]"]
652
  bert_sent = _get_word_embeddings(
653
  tokenizer, model, [words], device
654
  )[0]
655
 
656
- # Find the target word's embedding
657
  target_emb = None
658
  for word, emb in bert_sent:
659
  if word == query_word:
@@ -663,30 +493,24 @@ def test_contextual_nn_queries(model_path):
663
  f"Query word '{query_word}' not found in sentence"
664
  )
665
 
666
- # L2-normalize
667
  target_emb = target_emb / LA.norm(target_emb)
668
 
669
- # Find nearest neighbors
670
- results = _query_nearest_neighbors(
671
- target_emb, matrix_all, sents_all, sent_ids_all, toks_all,
672
- position_in_sent_all, doc_ids, top_n=25
673
  )
674
 
675
- # Print results
676
  for rank, (score, context, doc, matched_word) in enumerate(results):
677
  doc_short = Path(doc).stem
678
  print(f" {rank + 1:2d}. {score:.3f} {context} [{doc_short}]")
679
 
680
- # Soft assertions
681
- # 1. Query word in its own context should appear with cosine > 0.8
682
  self_hits = [
683
- r for r in results if r[3] == query_word and r[0] > 0.8
684
  ]
685
  assert len(self_hits) > 0, (
686
- f"Expected '{query_word}' to appear in top-25 with cosine > 0.8"
687
  )
688
 
689
- # 2. At least 10 of top-25 should contain the query word
690
  word_hits = [r for r in results if r[3] == query_word]
691
  assert len(word_hits) >= 10, (
692
  f"Expected at least 10 of top-25 to be '{query_word}', "
@@ -694,4 +518,4 @@ def test_contextual_nn_queries(model_path):
694
  )
695
 
696
  print(f"\n Soft checks passed: {len(self_hits)} self-hits with "
697
- f"cosine > 0.8, {len(word_hits)}/25 contain '{query_word}'")
 
25
  from torch import nn
26
  from transformers import AutoTokenizer, BertModel
27
 
28
+ from case_study_utils import (
29
+ BATCH_SIZE,
30
+ BERT_DIM,
31
+ DATA_DIR,
32
+ get_batches,
33
+ word_to_subtokens,
34
+ )
35
 
36
 
37
  def _get_device():
 
42
  return torch.device("mps")
43
  return torch.device("cpu")
44
 
 
 
45
 
46
  # Data paths
 
47
  CORPUS_TEXT_DIR = DATA_DIR / "latin_library_text"
48
  CORPUS_BERT_DIR = DATA_DIR / "latin_library_bert"
49
  CORPUS_ARCHIVE = DATA_DIR / "latin_library_text.tar.gz"
 
51
  # Google Drive download URL for Latin Library texts
52
  CORPUS_DOWNLOAD_ID = "1GRe3eFmQBDdF1kIT9T75aPTdquaf8Z8s"
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  MAX_SEQ_LEN = 512
55
 
56
 
57
  def _get_word_embeddings(tokenizer, model, sentences, device):
58
+ """Get word-level BERT embeddings for a list of sentences."""
 
 
 
 
 
 
 
59
  valid_indices = []
60
  valid_sentences = []
61
  for idx, sent in enumerate(sentences):
62
  n_subtokens = sum(
63
+ len(word_to_subtokens(tokenizer, w)) for w in sent
64
  )
65
  if n_subtokens <= MAX_SEQ_LEN:
66
  valid_indices.append(idx)
67
  valid_sentences.append(sent)
68
 
 
69
  all_bert_sents = [[] for _ in sentences]
70
 
71
  if not valid_sentences:
72
  return all_bert_sents
73
 
74
+ batched_data, batched_mask, batched_transforms, ordering = get_batches(
75
+ tokenizer, valid_sentences, BATCH_SIZE, has_labels=False
76
  )
77
 
78
  ordered_preds = []
 
93
  for row in range(b_size):
94
  ordered_preds.append([np.array(r) for r in out[row]])
95
 
 
96
  preds_in_order = [None] * len(valid_sentences)
97
  for i, ind in enumerate(ordering):
98
  preds_in_order[ind] = ordered_preds[i]
99
 
 
100
  for vi, orig_idx in enumerate(valid_indices):
101
  sentence = valid_sentences[vi]
102
  bert_sent = []
 
111
 
112
 
113
  def test_embedding_parity(model_path):
114
+ """Verify our HF tokenizer produces identical word-level embeddings."""
 
 
 
 
 
 
115
  device = _get_device()
116
 
117
  tokenizer = AutoTokenizer.from_pretrained(
 
127
  "omnia vincit amor",
128
  ]
129
 
 
130
  sentences = []
131
  for raw in test_sentences_raw:
132
  words = ["[CLS]"] + raw.lower().split() + ["[SEP]"]
133
  sentences.append(words)
134
 
 
135
  bert_sents = _get_word_embeddings(tokenizer, model, sentences, device)
136
 
 
137
  for sent_idx, (raw, bert_sent) in enumerate(
138
  zip(test_sentences_raw, bert_sents)
139
  ):
 
147
  assert emb.shape == (BERT_DIM,), (
148
  f"Expected ({BERT_DIM},), got {emb.shape}"
149
  )
 
150
  assert LA.norm(emb) > 0.1, f"Zero embedding for '{word}'"
151
 
 
152
  bert_sents_2 = _get_word_embeddings(tokenizer, model, sentences, device)
153
 
154
  for sent_idx in range(len(sentences)):
 
162
  f"{cos:.6f} (expected > 0.9999)"
163
  )
164
 
 
 
 
165
  in_emb = None
166
  for word, emb in bert_sents[1]:
167
  if word == "in":
 
169
  break
170
  assert in_emb is not None, "'in' not found in sentence 1"
171
 
 
172
  omnia_emb = None
173
  for word, emb in bert_sents[2]:
174
  if word == "omnia":
 
200
 
201
 
202
  def _read_file_cltk(filename):
203
+ """Read a text file and tokenize with CLTK, matching original pipeline."""
 
 
 
204
  from cltk.tokenizers.lat.lat import (
205
  LatinWordTokenizer as WordTokenizer,
206
  LatinPunktSentenceTokenizer as SentenceTokenizer,
 
231
  import subprocess
232
 
233
  if CORPUS_TEXT_DIR.exists() and any(CORPUS_TEXT_DIR.iterdir()):
234
+ return
235
 
236
  DATA_DIR.mkdir(parents=True, exist_ok=True)
237
 
238
  if not CORPUS_ARCHIVE.exists():
 
239
  subprocess.run(
240
  ["pip", "install", "-q", "gdown"],
241
  check=True, capture_output=True,
 
249
  check=True,
250
  )
251
 
 
252
  with tarfile.open(CORPUS_ARCHIVE, "r:gz") as tar:
253
  tar.extractall(path=DATA_DIR)
254
 
 
260
  def _generate_embeddings_for_file(
261
  tokenizer, model, input_file, output_file, device
262
  ):
263
+ """Generate BERT embeddings for a single text file."""
 
 
 
 
 
 
264
  sents = _read_file_cltk(input_file)
265
  if not sents:
266
  return 0
 
272
  with open(output_file, "w", encoding="utf-8") as out:
273
  for bert_sent in bert_sents:
274
  if not bert_sent:
275
+ continue
276
  for word, emb in bert_sent:
277
  out.write(
278
  "%s\t%s\n" % (word, " ".join("%.5f" % x for x in emb))
 
285
 
286
  @pytest.mark.slow
287
  def test_generate_embeddings(model_path):
288
+ """Generate BERT embeddings for the Latin Library corpus."""
 
 
 
 
289
  device = _get_device()
290
 
291
  tokenizer = AutoTokenizer.from_pretrained(
 
329
 
330
 
331
  def _load_embedding_file(filename):
332
+ """Load pre-generated embeddings from a TSV file."""
 
 
 
 
333
  berts = []
334
  toks = []
335
  sent_ids = []
 
369
  return matrix, sents, sent_ids, toks, position_in_sent
370
 
371
 
372
+ def _search_one_file(args):
373
+ """Search a single embedding file for top-N matches."""
374
+ filename, target_bert, top_n = args
375
+ matrix, sents, sent_ids, toks, position_in_sent = \
376
+ _load_embedding_file(filename)
377
+
378
+ if len(matrix) == 0:
379
+ return []
380
+
381
+ similarity = np.dot(matrix, target_bert)
382
+
383
+ n_candidates = min(top_n, len(similarity))
384
+ if n_candidates >= len(similarity):
385
+ top_indices = np.arange(len(similarity))
386
+ else:
387
+ top_indices = np.argpartition(-similarity, n_candidates)[:n_candidates]
388
 
389
+ results = []
390
+ for tid in top_indices:
391
+ score = float(similarity[tid])
392
+ if (tid < len(sent_ids) and tid < len(position_in_sent)
393
+ and sent_ids[tid] < len(sents)):
394
+ pos = position_in_sent[tid]
395
+ sent = sents[sent_ids[tid]]
396
+ start = max(0, pos - 5)
397
+ end = min(len(sent), pos + 6)
398
+ before = " ".join(sent[start:pos])
399
+ target_word = sent[pos]
400
+ after = " ".join(sent[pos + 1:end])
401
+ context = f"{before} **{target_word}** {after}".strip()
402
+ results.append((score, context, filename, target_word))
403
+
404
+ return results
405
+
406
+
407
+ def _query_streaming(target_bert, bert_dir, top_n=25):
408
+ """Find top-N contextually similar tokens by streaming through files."""
409
+ import heapq
410
+ import multiprocessing
411
 
412
  files = sorted(
413
  str(f)
 
416
  )
417
  assert len(files) > 0, f"No embedding files found in {bert_dir}"
418
 
419
+ n_workers = max(1, multiprocessing.cpu_count() - 1)
420
+ print(f" Searching {len(files)} files with {n_workers} workers...",
421
+ flush=True)
422
 
423
+ args_list = [(f, target_bert, top_n) for f in files]
 
 
 
 
 
 
 
 
 
424
 
425
+ heap = []
426
+ min_score = -float("inf")
427
+ files_done = 0
 
 
 
 
428
 
429
+ with multiprocessing.Pool(n_workers) as pool:
430
+ for file_results in pool.imap_unordered(_search_one_file, args_list,
431
+ chunksize=10):
432
+ for entry in file_results:
433
+ score = entry[0]
434
+ if len(heap) < top_n:
435
+ heapq.heappush(heap, entry)
436
+ if len(heap) == top_n:
437
+ min_score = heap[0][0]
438
+ elif score > min_score:
439
+ heapq.heapreplace(heap, entry)
440
+ min_score = heap[0][0]
441
 
442
+ files_done += 1
443
+ if files_done % 200 == 0:
444
+ print(f" Searched {files_done}/{len(files)} files...",
445
+ flush=True)
446
 
447
+ print(f" Searched {files_done}/{len(files)} files.", flush=True)
 
 
 
 
448
 
449
+ results = sorted(heap, key=lambda x: x[0], reverse=True)
450
+ return results
 
451
 
 
 
 
 
 
 
452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  QUERIES = [
454
  ("in", "gallia est omnis divisa in partes tres"),
455
  ("amor", "omnia vincit amor"),
456
+ ("audentes", "audentes fortuna iuvat"),
457
  ]
458
 
459
 
460
  @pytest.mark.slow
461
  def test_contextual_nn_queries(model_path):
462
+ """Run contextual nearest neighbor queries from the paper."""
 
 
 
 
 
 
 
 
463
  device = _get_device()
464
 
465
  assert CORPUS_BERT_DIR.exists(), (
 
474
  model.to(device)
475
  model.eval()
476
 
 
 
 
 
 
477
  for query_word, query_sent in QUERIES:
478
  print(f"\n{'=' * 60}")
479
  print(f"Query: '{query_word}' in '{query_sent}'")
480
  print("=" * 60)
481
 
 
482
  words = ["[CLS]"] + query_sent.lower().split() + ["[SEP]"]
483
  bert_sent = _get_word_embeddings(
484
  tokenizer, model, [words], device
485
  )[0]
486
 
 
487
  target_emb = None
488
  for word, emb in bert_sent:
489
  if word == query_word:
 
493
  f"Query word '{query_word}' not found in sentence"
494
  )
495
 
 
496
  target_emb = target_emb / LA.norm(target_emb)
497
 
498
+ print(" Searching corpus (streaming)...")
499
+ results = _query_streaming(
500
+ target_emb, CORPUS_BERT_DIR, top_n=25
 
501
  )
502
 
 
503
  for rank, (score, context, doc, matched_word) in enumerate(results):
504
  doc_short = Path(doc).stem
505
  print(f" {rank + 1:2d}. {score:.3f} {context} [{doc_short}]")
506
 
 
 
507
  self_hits = [
508
+ r for r in results if r[3] == query_word and r[0] > 0.7
509
  ]
510
  assert len(self_hits) > 0, (
511
+ f"Expected '{query_word}' to appear in top-25 with cosine > 0.7"
512
  )
513
 
 
514
  word_hits = [r for r in results if r[3] == query_word]
515
  assert len(word_hits) >= 10, (
516
  f"Expected at least 10 of top-25 to be '{query_word}', "
 
518
  )
519
 
520
  print(f"\n Soft checks passed: {len(self_hits)} self-hits with "
521
+ f"cosine > 0.7, {len(word_hits)}/25 contain '{query_word}'")
tests/test_infilling.py CHANGED
@@ -13,28 +13,17 @@ Reference results (from original logs):
13
 
14
  import copy
15
  import re
16
- from pathlib import Path
17
  from typing import List
18
 
19
  import pytest
20
  import torch
21
  from transformers import AutoTokenizer, BertForMaskedLM
22
 
23
- DATA_PATH = (
24
- Path(__file__).parent.parent
25
- / ".claude/reference/latin-bert/case_studies/infilling/data/emendation_filtered.txt"
26
- )
27
 
28
  def _tokenize_text(tokenizer, text: str) -> List[int]:
29
- """Tokenize text word-by-word, matching the original LatinTokenizer behavior.
30
-
31
- The original uses cltk WordTokenizer to split into words, then lowercases
32
- each word and encodes it individually with the SubwordTextEncoder. Our HF
33
- tokenizer's encode() processes the entire string including spaces, which
34
- produces different (incorrect) results because spaces get escaped into
35
- subtoken sequences. Instead, we split on whitespace, lowercase each word,
36
- and encode individually.
37
- """
38
  ids = []
39
  for word in text.split():
40
  word_ids = tokenizer.encode(word.lower(), add_special_tokens=False)
@@ -42,7 +31,6 @@ def _tokenize_text(tokenizer, text: str) -> List[int]:
42
  return ids
43
 
44
 
45
- # Tolerance: allow +/- 1% from reference
46
  REF_P1 = 0.331
47
  REF_P10 = 0.622
48
  REF_P50 = 0.740
@@ -50,10 +38,7 @@ TOLERANCE = 0.01
50
 
51
 
52
  def _proc(model, tokenizer, token_ids, device):
53
- """Predict the subtoken at the [MASK] position for multi-subtoken words.
54
-
55
- Mirrors the original proc() which finds [MASK] by searching token_ids.
56
- """
57
  mask_id = tokenizer.convert_tokens_to_ids("[MASK]")
58
  mask_pos = token_ids.index(mask_id)
59
  t = torch.LongTensor(token_ids).unsqueeze(0).to(device)
@@ -65,13 +50,7 @@ def _proc(model, tokenizer, token_ids, device):
65
 
66
 
67
  def _evaluate_one(model, tokenizer, text_before, text_after, truth, device):
68
- """Evaluate a single infilling example. Returns (p1, p10, p50).
69
-
70
- The original tokenizer lowercases each word before subword encoding.
71
- Our HF tokenizer does not lowercase, so we lowercase the text here
72
- to match the original behavior.
73
- """
74
- # Tokenize word-by-word with lowercasing, matching original behavior
75
  before_ids = _tokenize_text(tokenizer, text_before)
76
  after_ids = _tokenize_text(tokenizer, text_after)
77
  mask_id = tokenizer.convert_tokens_to_ids("[MASK]")
@@ -94,10 +73,6 @@ def _evaluate_one(model, tokenizer, text_before, text_after, truth, device):
94
 
95
  suffix = ""
96
  if not predicted_token.endswith("_"):
97
- # Multi-subtoken: insert predicted subtoken before [MASK]
98
- # so the sequence becomes: ... predicted [MASK] ...
99
- # then predict the next subtoken at the new [MASK] position.
100
- # This mirrors the original predict_word.py behavior.
101
  uptokens = copy.deepcopy(token_ids)
102
  uptokens.insert(mask_pos, predicted_index)
103
  suffix = _proc(model, tokenizer, uptokens, device)
@@ -131,7 +106,7 @@ def test_infilling_precision(model_path):
131
  max_tokens = 100
132
  all_p1 = all_p10 = all_p50 = n = 0
133
 
134
- with open(DATA_PATH) as f:
135
  for line in f:
136
  cols = line.split("\t")
137
  if len(cols) < 5:
 
13
 
14
  import copy
15
  import re
 
16
  from typing import List
17
 
18
  import pytest
19
  import torch
20
  from transformers import AutoTokenizer, BertForMaskedLM
21
 
22
+ from case_study_utils import INFILLING_DATA_PATH
23
+
 
 
24
 
25
  def _tokenize_text(tokenizer, text: str) -> List[int]:
26
+ """Tokenize text word-by-word, matching the original LatinTokenizer behavior."""
 
 
 
 
 
 
 
 
27
  ids = []
28
  for word in text.split():
29
  word_ids = tokenizer.encode(word.lower(), add_special_tokens=False)
 
31
  return ids
32
 
33
 
 
34
  REF_P1 = 0.331
35
  REF_P10 = 0.622
36
  REF_P50 = 0.740
 
38
 
39
 
40
  def _proc(model, tokenizer, token_ids, device):
41
+ """Predict the subtoken at the [MASK] position for multi-subtoken words."""
 
 
 
42
  mask_id = tokenizer.convert_tokens_to_ids("[MASK]")
43
  mask_pos = token_ids.index(mask_id)
44
  t = torch.LongTensor(token_ids).unsqueeze(0).to(device)
 
50
 
51
 
52
  def _evaluate_one(model, tokenizer, text_before, text_after, truth, device):
53
+ """Evaluate a single infilling example. Returns (p1, p10, p50)."""
 
 
 
 
 
 
54
  before_ids = _tokenize_text(tokenizer, text_before)
55
  after_ids = _tokenize_text(tokenizer, text_after)
56
  mask_id = tokenizer.convert_tokens_to_ids("[MASK]")
 
73
 
74
  suffix = ""
75
  if not predicted_token.endswith("_"):
 
 
 
 
76
  uptokens = copy.deepcopy(token_ids)
77
  uptokens.insert(mask_pos, predicted_index)
78
  suffix = _proc(model, tokenizer, uptokens, device)
 
106
  max_tokens = 100
107
  all_p1 = all_p10 = all_p50 = n = 0
108
 
109
+ with open(INFILLING_DATA_PATH) as f:
110
  for line in f:
111
  cols = line.split("\t")
112
  if len(cols) < 5:
tests/test_pos_tagging.py CHANGED
@@ -15,18 +15,19 @@ from pathlib import Path
15
  import numpy as np
16
  import pytest
17
  import torch
18
- from torch import nn
19
- from torch.nn import CrossEntropyLoss
20
  import torch.optim as optim
21
  from transformers import AutoTokenizer, BertModel
22
 
 
 
 
 
 
 
23
  torch.manual_seed(0)
24
  np.random.seed(0)
25
 
26
  TOLERANCE = 0.01
27
- BATCH_SIZE = 32
28
- DROPOUT_RATE = 0.25
29
- BERT_DIM = 768
30
 
31
  UD_REPOS = {
32
  "perseus": "https://github.com/UniversalDependencies/UD_Latin-Perseus.git",
@@ -40,17 +41,9 @@ REFERENCE_ACCURACY = {
40
  "ittb": 0.988,
41
  }
42
 
43
- # Special tokens that should not go through subword encoding
44
- _SPECIAL_TOKENS = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}
45
-
46
 
47
  def _read_conllu_annotations(filename, tagset, labeled=True):
48
- """Read CoNLL-U file, return list of sentences.
49
-
50
- Each sentence is a list of [word, label, sentenceID, filename].
51
- Mirrors the original sequence_reader.read_annotations().
52
- Words are lowercased to match the original pipeline.
53
- """
54
  sentences = []
55
  sentence = [["[CLS]", -100, -1, filename]]
56
  sentence_id = 0
@@ -67,7 +60,7 @@ def _read_conllu_annotations(filename, tagset, labeled=True):
67
  else:
68
  cols = line.rstrip().split("\t")
69
  if "-" in cols[0] or "." in cols[0]:
70
- continue # skip multiword/empty tokens
71
  word = cols[1].lower()
72
  label = tagset[cols[3]] if labeled else 0
73
  sentence.append([word, label, sentence_id, filename])
@@ -93,166 +86,6 @@ def _generate_tagset(filenames):
93
  return {tag: idx for idx, tag in enumerate(tags)}
94
 
95
 
96
- def _word_to_subtokens(tokenizer, word):
97
- """Get subtoken strings for a single word.
98
-
99
- Special tokens ([CLS], [SEP], etc.) are returned as-is.
100
- Regular words are tokenized through the subword pipeline,
101
- matching the original LatinTokenizer.tokenize() behavior which
102
- processes one already-lowercased word at a time.
103
- """
104
- if word in _SPECIAL_TOKENS:
105
- return [word]
106
- return tokenizer.tokenize(word)
107
-
108
-
109
- class BertForSequenceLabeling(nn.Module):
110
- """BERT + linear classifier for sequence labeling.
111
-
112
- Ported from original latin_sequence_labeling.py, replacing
113
- tensor2tensor tokenizer with HF AutoTokenizer.
114
- """
115
-
116
- def __init__(self, tokenizer, model, freeze_bert=False, num_labels=2):
117
- super().__init__()
118
- self.tokenizer = tokenizer
119
- self.num_labels = num_labels
120
- self.bert = model
121
- self.bert.eval()
122
- if freeze_bert:
123
- for param in self.bert.parameters():
124
- param.requires_grad = False
125
- self.dropout = nn.Dropout(DROPOUT_RATE)
126
- self.classifier = nn.Linear(BERT_DIM, num_labels)
127
-
128
- def forward(self, input_ids, attention_mask=None, transforms=None,
129
- labels=None):
130
- device = input_ids.device
131
- if attention_mask is not None:
132
- attention_mask = attention_mask.to(device)
133
- if transforms is not None:
134
- transforms = transforms.to(device)
135
- if labels is not None:
136
- labels = labels.to(device)
137
-
138
- outputs = self.bert(input_ids, attention_mask=attention_mask)
139
- sequence_output = outputs[0]
140
- out = torch.matmul(transforms, sequence_output)
141
- logits = self.classifier(out)
142
-
143
- if labels is not None:
144
- loss_fct = CrossEntropyLoss(ignore_index=-100)
145
- return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
146
- return logits
147
-
148
- def get_batches(self, sentences, max_batch):
149
- """Tokenize and batch sentences with subword-to-word transform
150
- matrices.
151
-
152
- Each word is tokenized individually (matching original behavior).
153
- Special tokens [CLS]/[SEP] produce a single token each.
154
- The transform matrix averages subword representations back to
155
- word-level representations.
156
- """
157
- all_data = []
158
- all_masks = []
159
- all_labels = []
160
- all_transforms = []
161
-
162
- for sentence in sentences:
163
- tok_ids = []
164
- input_mask = []
165
- labels = []
166
- transform = []
167
-
168
- # First pass: get subtokens for each word
169
- all_toks = []
170
- n = 0
171
- for word in sentence:
172
- toks = _word_to_subtokens(self.tokenizer, word[0])
173
- all_toks.append(toks)
174
- n += len(toks)
175
-
176
- # Second pass: build transform matrix and collect IDs
177
- cur = 0
178
- for idx, word in enumerate(sentence):
179
- toks = all_toks[idx]
180
- ind = list(np.zeros(n))
181
- for j in range(cur, cur + len(toks)):
182
- ind[j] = 1.0 / len(toks)
183
- cur += len(toks)
184
- transform.append(ind)
185
- tok_ids.extend(
186
- self.tokenizer.convert_tokens_to_ids(toks)
187
- )
188
- input_mask.extend(np.ones(len(toks)))
189
- labels.append(int(word[1]))
190
-
191
- all_data.append(tok_ids)
192
- all_masks.append(input_mask)
193
- all_labels.append(labels)
194
- all_transforms.append(transform)
195
-
196
- lengths = np.array([len(l) for l in all_data])
197
- ordering = np.argsort(lengths)
198
-
199
- ordered_data = [None] * len(all_data)
200
- ordered_masks = [None] * len(all_data)
201
- ordered_labels = [None] * len(all_data)
202
- ordered_transforms = [None] * len(all_data)
203
-
204
- for i, ind in enumerate(ordering):
205
- ordered_data[i] = all_data[ind]
206
- ordered_masks[i] = all_masks[ind]
207
- ordered_labels[i] = all_labels[ind]
208
- ordered_transforms[i] = all_transforms[ind]
209
-
210
- batched_data = []
211
- batched_mask = []
212
- batched_labels = []
213
- batched_transforms = []
214
-
215
- i = 0
216
- current_batch = max_batch
217
-
218
- while i < len(ordered_data):
219
- batch_data = ordered_data[i:i + current_batch]
220
- batch_mask = ordered_masks[i:i + current_batch]
221
- batch_labels = ordered_labels[i:i + current_batch]
222
- batch_transforms = ordered_transforms[i:i + current_batch]
223
-
224
- ml = max(len(s) for s in batch_data)
225
- mlabel = max(len(l) for l in batch_labels)
226
-
227
- for j in range(len(batch_data)):
228
- blen = len(batch_data[j])
229
- blab = len(batch_labels[j])
230
- for _k in range(blen, ml):
231
- batch_data[j].append(0)
232
- batch_mask[j].append(0)
233
- for z in range(len(batch_transforms[j])):
234
- batch_transforms[j][z].append(0)
235
- for _k in range(blab, mlabel):
236
- batch_labels[j].append(-100)
237
- for _k in range(len(batch_transforms[j]), mlabel):
238
- batch_transforms[j].append(np.zeros(ml))
239
-
240
- batched_data.append(torch.LongTensor(batch_data))
241
- batched_mask.append(torch.FloatTensor(batch_mask))
242
- batched_labels.append(torch.LongTensor(batch_labels))
243
- batched_transforms.append(torch.FloatTensor(batch_transforms))
244
-
245
- i += current_batch
246
- # Adjust batch size for longer sequences (original behavior)
247
- if ml > 100:
248
- current_batch = 12
249
- if ml > 200:
250
- current_batch = 6
251
-
252
- return (batched_data, batched_mask, batched_labels,
253
- batched_transforms, ordering)
254
-
255
-
256
  def _train_and_evaluate(treebank_name, treebank_dir, device, model_path):
257
  """Train POS tagger on a UD treebank and return test accuracy."""
258
  tokenizer = AutoTokenizer.from_pretrained(
@@ -260,36 +93,32 @@ def _train_and_evaluate(treebank_name, treebank_dir, device, model_path):
260
  )
261
  bert_model = BertModel.from_pretrained(model_path)
262
 
263
- # Find CoNLL-U files
264
  conllu_files = sorted(Path(treebank_dir).glob("*.conllu"))
265
  train_file = [f for f in conllu_files if "train" in f.name][0]
266
  test_file = [f for f in conllu_files if "test" in f.name][0]
267
  dev_files = [f for f in conllu_files if "dev" in f.name]
268
 
269
- # Generate tagset from all files
270
  tagset = _generate_tagset([str(f) for f in conllu_files])
271
  num_labels = len(tagset)
272
 
273
  model = BertForSequenceLabeling(
274
- tokenizer, bert_model, freeze_bert=False, num_labels=num_labels
 
275
  )
276
  model.to(device)
277
 
278
- # Prepare training data
279
  train_sents = _read_conllu_annotations(str(train_file), tagset)
280
- batched = model.get_batches(train_sents, BATCH_SIZE)
281
- train_data, train_mask, train_labels, train_transforms, _ = batched
282
 
283
- # Prepare test data
284
  test_sents = _read_conllu_annotations(str(test_file), tagset)
285
- test_batched = model.get_batches(test_sents, BATCH_SIZE)
286
- test_data, test_mask, test_labels, test_transforms, _ = test_batched
287
 
288
- # Prepare dev data (if available)
289
  if dev_files:
290
  dev_sents = _read_conllu_annotations(str(dev_files[0]), tagset)
291
- dev_batched = model.get_batches(dev_sents, BATCH_SIZE)
292
- dev_data, dev_mask, dev_labels, dev_transforms, _ = dev_batched
293
  else:
294
  dev_data = None
295
 
@@ -298,7 +127,7 @@ def _train_and_evaluate(treebank_name, treebank_dir, device, model_path):
298
  best_state = None
299
  best_epoch = 0
300
 
301
- for epoch in range(5): # 5 epochs, matching original run_bert_eval.sh
302
  model.train()
303
  big_loss = 0
304
  for b in range(len(train_data)):
@@ -315,7 +144,6 @@ def _train_and_evaluate(treebank_name, treebank_dir, device, model_path):
315
 
316
  print(f" epoch {epoch}: loss={big_loss:.2f}")
317
 
318
- # Evaluate on dev (if available) to pick best epoch
319
  if dev_data is not None:
320
  model.eval()
321
  cor = tot = 0
@@ -345,7 +173,6 @@ def _train_and_evaluate(treebank_name, treebank_dir, device, model_path):
345
  }
346
  best_epoch = epoch
347
  else:
348
- # No dev set (Perseus): save last epoch
349
  best_state = {
350
  k: v.cpu().clone()
351
  for k, v in model.state_dict().items()
@@ -354,11 +181,9 @@ def _train_and_evaluate(treebank_name, treebank_dir, device, model_path):
354
 
355
  print(f" best epoch: {best_epoch}")
356
 
357
- # Load best model
358
  if best_state is not None:
359
  model.load_state_dict(best_state)
360
 
361
- # Evaluate on test
362
  model.eval()
363
  cor = tot = 0
364
  with torch.no_grad():
 
15
  import numpy as np
16
  import pytest
17
  import torch
 
 
18
  import torch.optim as optim
19
  from transformers import AutoTokenizer, BertModel
20
 
21
+ from case_study_utils import (
22
+ BATCH_SIZE,
23
+ BERT_DIM,
24
+ BertForSequenceLabeling,
25
+ )
26
+
27
  torch.manual_seed(0)
28
  np.random.seed(0)
29
 
30
  TOLERANCE = 0.01
 
 
 
31
 
32
  UD_REPOS = {
33
  "perseus": "https://github.com/UniversalDependencies/UD_Latin-Perseus.git",
 
41
  "ittb": 0.988,
42
  }
43
 
 
 
 
44
 
45
  def _read_conllu_annotations(filename, tagset, labeled=True):
46
+ """Read CoNLL-U file, return list of sentences."""
 
 
 
 
 
47
  sentences = []
48
  sentence = [["[CLS]", -100, -1, filename]]
49
  sentence_id = 0
 
60
  else:
61
  cols = line.rstrip().split("\t")
62
  if "-" in cols[0] or "." in cols[0]:
63
+ continue
64
  word = cols[1].lower()
65
  label = tagset[cols[3]] if labeled else 0
66
  sentence.append([word, label, sentence_id, filename])
 
86
  return {tag: idx for idx, tag in enumerate(tags)}
87
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def _train_and_evaluate(treebank_name, treebank_dir, device, model_path):
90
  """Train POS tagger on a UD treebank and return test accuracy."""
91
  tokenizer = AutoTokenizer.from_pretrained(
 
93
  )
94
  bert_model = BertModel.from_pretrained(model_path)
95
 
 
96
  conllu_files = sorted(Path(treebank_dir).glob("*.conllu"))
97
  train_file = [f for f in conllu_files if "train" in f.name][0]
98
  test_file = [f for f in conllu_files if "test" in f.name][0]
99
  dev_files = [f for f in conllu_files if "dev" in f.name]
100
 
 
101
  tagset = _generate_tagset([str(f) for f in conllu_files])
102
  num_labels = len(tagset)
103
 
104
  model = BertForSequenceLabeling(
105
+ tokenizer, bert_model, freeze_bert=False, num_labels=num_labels,
106
+ hidden_size=BERT_DIM
107
  )
108
  model.to(device)
109
 
 
110
  train_sents = _read_conllu_annotations(str(train_file), tagset)
111
+ train_data, train_mask, train_labels, train_transforms, _ = \
112
+ model.get_batches(train_sents, BATCH_SIZE)
113
 
 
114
  test_sents = _read_conllu_annotations(str(test_file), tagset)
115
+ test_data, test_mask, test_labels, test_transforms, _ = \
116
+ model.get_batches(test_sents, BATCH_SIZE)
117
 
 
118
  if dev_files:
119
  dev_sents = _read_conllu_annotations(str(dev_files[0]), tagset)
120
+ dev_data, dev_mask, dev_labels, dev_transforms, _ = \
121
+ model.get_batches(dev_sents, BATCH_SIZE)
122
  else:
123
  dev_data = None
124
 
 
127
  best_state = None
128
  best_epoch = 0
129
 
130
+ for epoch in range(5):
131
  model.train()
132
  big_loss = 0
133
  for b in range(len(train_data)):
 
144
 
145
  print(f" epoch {epoch}: loss={big_loss:.2f}")
146
 
 
147
  if dev_data is not None:
148
  model.eval()
149
  cor = tot = 0
 
173
  }
174
  best_epoch = epoch
175
  else:
 
176
  best_state = {
177
  k: v.cpu().clone()
178
  for k, v in model.state_dict().items()
 
181
 
182
  print(f" best epoch: {best_epoch}")
183
 
 
184
  if best_state is not None:
185
  model.load_state_dict(best_state)
186
 
 
187
  model.eval()
188
  cor = tot = 0
189
  with torch.no_grad():
tests/test_wsd.py CHANGED
@@ -9,174 +9,27 @@ Reference results (from original logs):
9
  """
10
 
11
  import random
12
- from pathlib import Path
13
 
14
  import numpy as np
15
  import pytest
16
  import torch
17
- from torch import nn
18
- from torch.nn import CrossEntropyLoss
19
  import torch.optim as optim
20
  from transformers import AutoTokenizer, BertModel
21
 
 
 
 
 
 
 
22
  random.seed(1)
23
  torch.manual_seed(0)
24
  np.random.seed(0)
25
 
26
- DATA_PATH = (
27
- Path(__file__).parent.parent
28
- / ".claude/reference/latin-bert/case_studies/wsd/data/latin.sense.data"
29
- )
30
-
31
  REF_ACCURACY = 0.754
32
  TOLERANCE = 0.02 # WSD has more variance due to per-lemma training
33
- BATCH_SIZE = 32
34
- DROPOUT_RATE = 0.25
35
- BERT_DIM = 768
36
  MAX_EPOCHS = 100
37
 
38
- # Special tokens that should not go through subword encoding
39
- _SPECIAL_TOKENS = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}
40
-
41
-
42
- def _word_to_subtokens(tokenizer, word):
43
- """Get subtoken strings for a single word.
44
-
45
- Special tokens ([CLS], [SEP], etc.) are returned as-is.
46
- Regular words are lowercased and tokenized through the subword pipeline,
47
- matching the original LatinTokenizer.tokenize() behavior.
48
- """
49
- if word in _SPECIAL_TOKENS:
50
- return [word]
51
- return tokenizer.tokenize(word.lower())
52
-
53
-
54
- class BertForSequenceLabeling(nn.Module):
55
- """BERT + linear classifier for sequence labeling (binary WSD)."""
56
-
57
- def __init__(self, tokenizer, bert_model, freeze_bert=False,
58
- num_labels=2):
59
- super().__init__()
60
- self.tokenizer = tokenizer
61
- self.num_labels = num_labels
62
- self.bert = bert_model
63
- self.bert.eval()
64
- if freeze_bert:
65
- for param in self.bert.parameters():
66
- param.requires_grad = False
67
- self.dropout = nn.Dropout(DROPOUT_RATE)
68
- self.classifier = nn.Linear(BERT_DIM, num_labels)
69
-
70
- def forward(self, input_ids, attention_mask=None, transforms=None,
71
- labels=None):
72
- device = input_ids.device
73
- if attention_mask is not None:
74
- attention_mask = attention_mask.to(device)
75
- if transforms is not None:
76
- transforms = transforms.to(device)
77
- if labels is not None:
78
- labels = labels.to(device)
79
-
80
- outputs = self.bert(input_ids, attention_mask=attention_mask)
81
- sequence_output = outputs[0]
82
- out = torch.matmul(transforms, sequence_output)
83
- logits = self.classifier(out)
84
-
85
- if labels is not None:
86
- loss_fct = CrossEntropyLoss(ignore_index=-100)
87
- return loss_fct(
88
- logits.view(-1, self.num_labels), labels.view(-1)
89
- )
90
- return logits
91
-
92
- def get_batches(self, sentences, max_batch):
93
- """Tokenize and batch with subword-to-word transform matrices."""
94
- all_data, all_masks, all_labels, all_transforms = [], [], [], []
95
-
96
- for sentence in sentences:
97
- tok_ids, input_mask, labels, transform = [], [], [], []
98
- all_toks = []
99
- n = 0
100
- for word in sentence:
101
- toks = _word_to_subtokens(self.tokenizer, word[0])
102
- all_toks.append(toks)
103
- n += len(toks)
104
-
105
- cur = 0
106
- for idx, word in enumerate(sentence):
107
- toks = all_toks[idx]
108
- ind = list(np.zeros(n))
109
- for j in range(cur, cur + len(toks)):
110
- ind[j] = 1.0 / len(toks)
111
- cur += len(toks)
112
- transform.append(ind)
113
- tok_ids.extend(
114
- self.tokenizer.convert_tokens_to_ids(toks)
115
- )
116
- input_mask.extend(np.ones(len(toks)))
117
- labels.append(int(word[1]))
118
-
119
- all_data.append(tok_ids)
120
- all_masks.append(input_mask)
121
- all_labels.append(labels)
122
- all_transforms.append(transform)
123
-
124
- lengths = np.array([len(l) for l in all_data])
125
- ordering = np.argsort(lengths)
126
-
127
- ordered_data = [None] * len(all_data)
128
- ordered_masks = [None] * len(all_data)
129
- ordered_labels = [None] * len(all_data)
130
- ordered_transforms = [None] * len(all_data)
131
-
132
- for i, ind in enumerate(ordering):
133
- ordered_data[i] = all_data[ind]
134
- ordered_masks[i] = all_masks[ind]
135
- ordered_labels[i] = all_labels[ind]
136
- ordered_transforms[i] = all_transforms[ind]
137
-
138
- batched_data = []
139
- batched_mask = []
140
- batched_labels = []
141
- batched_transforms = []
142
-
143
- i = 0
144
- current_batch = max_batch
145
-
146
- while i < len(ordered_data):
147
- bd = ordered_data[i:i + current_batch]
148
- bm = ordered_masks[i:i + current_batch]
149
- bl = ordered_labels[i:i + current_batch]
150
- bt = ordered_transforms[i:i + current_batch]
151
-
152
- ml = max(len(s) for s in bd)
153
- mlabel = max(len(l) for l in bl)
154
-
155
- for j in range(len(bd)):
156
- for _k in range(len(bd[j]), ml):
157
- bd[j].append(0)
158
- bm[j].append(0)
159
- for z in range(len(bt[j])):
160
- bt[j][z].append(0)
161
- for _k in range(len(bl[j]), mlabel):
162
- bl[j].append(-100)
163
- for _k in range(len(bt[j]), mlabel):
164
- bt[j].append(np.zeros(ml))
165
-
166
- batched_data.append(torch.LongTensor(bd))
167
- batched_mask.append(torch.FloatTensor(bm))
168
- batched_labels.append(torch.LongTensor(bl))
169
- batched_transforms.append(torch.FloatTensor(bt))
170
-
171
- i += current_batch
172
- if ml > 100:
173
- current_batch = 12
174
- if ml > 200:
175
- current_batch = 6
176
-
177
- return (batched_data, batched_mask, batched_labels,
178
- batched_transforms, ordering)
179
-
180
 
181
  def _get_labs(before, target, after, label):
182
  """Build a labeled sentence for WSD.
@@ -186,7 +39,7 @@ def _get_labs(before, target, after, label):
186
  """
187
  sent = []
188
  for word in before.split(" "):
189
- if word: # skip empty strings from split on empty/whitespace
190
  sent.append((word, -100))
191
  sent.append((target, label))
192
  for word in after.split(" "):
@@ -220,11 +73,7 @@ def _read_wsd_data(filename):
220
 
221
 
222
  def _get_splits(data):
223
- """10-fold cross-validation splits.
224
-
225
- For each sense (0 and 1), examples are assigned to folds by index.
226
- testFold = idx % 10, devFold = testFold - 1 (wrapping to 9).
227
- """
228
  trains, tests, devs = [], [], []
229
  for _i in range(10):
230
  trains.append([])
@@ -253,11 +102,7 @@ def _get_splits(data):
253
 
254
  def _evaluate(model, batched_data, batched_mask, batched_labels,
255
  batched_transforms, device):
256
- """Evaluate model on batched data, return (correct, total).
257
-
258
- Mirrors the original evaluate() method which returns (cor, tot),
259
- with accumulation happening outside this function.
260
- """
261
  model.eval()
262
  cor = 0
263
  tot = 0
@@ -283,19 +128,13 @@ def _evaluate(model, batched_data, batched_mask, batched_labels,
283
 
284
  @pytest.mark.slow
285
  def test_wsd_accuracy(model_path):
286
- """Reproduce WSD case study from Bamman & Burns (2020).
287
-
288
- Trains a separate binary classifier per lemma (201 lemmas) with
289
- 10-fold cross-validation. Uses fold 0 splits (train/dev/test).
290
- Accumulates dev and test correct/total across all lemmas at each
291
- epoch, then picks the best dev epoch and reports test accuracy.
292
- """
293
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
294
 
295
  tokenizer = AutoTokenizer.from_pretrained(
296
  model_path, trust_remote_code=True
297
  )
298
- data = _read_wsd_data(str(DATA_PATH))
299
 
300
  dev_cors = [0.0] * MAX_EPOCHS
301
  test_cors = [0.0] * MAX_EPOCHS
@@ -305,7 +144,6 @@ def test_wsd_accuracy(model_path):
305
  for lemma_idx, lemma in enumerate(data):
306
  print(f"\n[{lemma_idx + 1}/{len(data)}] {lemma}")
307
 
308
- # Fresh model per lemma
309
  bert_model = BertModel.from_pretrained(model_path)
310
  model = BertForSequenceLabeling(
311
  tokenizer, bert_model, freeze_bert=False, num_labels=2
@@ -313,18 +151,15 @@ def test_wsd_accuracy(model_path):
313
  model.to(device)
314
 
315
  trains, devs, tests = _get_splits(data[lemma])
316
- train_data = trains[0]
317
- dev_data = devs[0]
318
- test_data = tests[0]
319
 
320
  train_b, train_m, train_l, train_t, _ = model.get_batches(
321
- train_data, BATCH_SIZE
322
  )
323
  dev_b, dev_m, dev_l, dev_t, _ = model.get_batches(
324
- dev_data, BATCH_SIZE
325
  )
326
  test_b, test_m, test_l, test_t, _ = model.get_batches(
327
- test_data, BATCH_SIZE
328
  )
329
 
330
  optimizer = optim.Adam(model.parameters(), lr=5e-5)
@@ -342,21 +177,14 @@ def test_wsd_accuracy(model_path):
342
  optimizer.step()
343
  model.zero_grad()
344
 
345
- # Evaluate dev
346
- c, t = _evaluate(
347
- model, dev_b, dev_m, dev_l, dev_t, device
348
- )
349
  dev_cors[epoch] += c
350
  dev_n[epoch] += t
351
 
352
- # Evaluate test
353
- c, t = _evaluate(
354
- model, test_b, test_m, test_l, test_t, device
355
- )
356
  test_cors[epoch] += c
357
  test_n[epoch] += t
358
 
359
- # Print per-lemma dev accuracy summary
360
  for epoch in range(MAX_EPOCHS):
361
  if dev_n[epoch] > 0:
362
  dev_acc = dev_cors[epoch] / dev_n[epoch]
@@ -365,7 +193,6 @@ def test_wsd_accuracy(model_path):
365
  f"lemma={lemma} n={dev_n[epoch]}"
366
  )
367
 
368
- # Find best dev epoch, report test accuracy at that epoch
369
  best_epoch = max(
370
  range(MAX_EPOCHS),
371
  key=lambda i: dev_cors[i] / dev_n[i] if dev_n[i] > 0 else 0,
 
9
  """
10
 
11
  import random
 
12
 
13
  import numpy as np
14
  import pytest
15
  import torch
 
 
16
  import torch.optim as optim
17
  from transformers import AutoTokenizer, BertModel
18
 
19
+ from case_study_utils import (
20
+ BATCH_SIZE,
21
+ BertForSequenceLabeling,
22
+ WSD_DATA_PATH,
23
+ )
24
+
25
  random.seed(1)
26
  torch.manual_seed(0)
27
  np.random.seed(0)
28
 
 
 
 
 
 
29
  REF_ACCURACY = 0.754
30
  TOLERANCE = 0.02 # WSD has more variance due to per-lemma training
 
 
 
31
  MAX_EPOCHS = 100
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def _get_labs(before, target, after, label):
35
  """Build a labeled sentence for WSD.
 
39
  """
40
  sent = []
41
  for word in before.split(" "):
42
+ if word:
43
  sent.append((word, -100))
44
  sent.append((target, label))
45
  for word in after.split(" "):
 
73
 
74
 
75
  def _get_splits(data):
76
+ """10-fold cross-validation splits."""
 
 
 
 
77
  trains, tests, devs = [], [], []
78
  for _i in range(10):
79
  trains.append([])
 
102
 
103
  def _evaluate(model, batched_data, batched_mask, batched_labels,
104
  batched_transforms, device):
105
+ """Evaluate model on batched data, return (correct, total)."""
 
 
 
 
106
  model.eval()
107
  cor = 0
108
  tot = 0
 
128
 
129
  @pytest.mark.slow
130
  def test_wsd_accuracy(model_path):
131
+ """Reproduce WSD case study from Bamman & Burns (2020)."""
 
 
 
 
 
 
132
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
133
 
134
  tokenizer = AutoTokenizer.from_pretrained(
135
  model_path, trust_remote_code=True
136
  )
137
+ data = _read_wsd_data(str(WSD_DATA_PATH))
138
 
139
  dev_cors = [0.0] * MAX_EPOCHS
140
  test_cors = [0.0] * MAX_EPOCHS
 
144
  for lemma_idx, lemma in enumerate(data):
145
  print(f"\n[{lemma_idx + 1}/{len(data)}] {lemma}")
146
 
 
147
  bert_model = BertModel.from_pretrained(model_path)
148
  model = BertForSequenceLabeling(
149
  tokenizer, bert_model, freeze_bert=False, num_labels=2
 
151
  model.to(device)
152
 
153
  trains, devs, tests = _get_splits(data[lemma])
 
 
 
154
 
155
  train_b, train_m, train_l, train_t, _ = model.get_batches(
156
+ trains[0], BATCH_SIZE
157
  )
158
  dev_b, dev_m, dev_l, dev_t, _ = model.get_batches(
159
+ devs[0], BATCH_SIZE
160
  )
161
  test_b, test_m, test_l, test_t, _ = model.get_batches(
162
+ tests[0], BATCH_SIZE
163
  )
164
 
165
  optimizer = optim.Adam(model.parameters(), lr=5e-5)
 
177
  optimizer.step()
178
  model.zero_grad()
179
 
180
+ c, t = _evaluate(model, dev_b, dev_m, dev_l, dev_t, device)
 
 
 
181
  dev_cors[epoch] += c
182
  dev_n[epoch] += t
183
 
184
+ c, t = _evaluate(model, test_b, test_m, test_l, test_t, device)
 
 
 
185
  test_cors[epoch] += c
186
  test_n[epoch] += t
187
 
 
188
  for epoch in range(MAX_EPOCHS):
189
  if dev_n[epoch] > 0:
190
  dev_acc = dev_cors[epoch] / dev_n[epoch]
 
193
  f"lemma={lemma} n={dev_n[epoch]}"
194
  )
195
 
 
196
  best_epoch = max(
197
  range(MAX_EPOCHS),
198
  key=lambda i: dev_cors[i] / dev_n[i] if dev_n[i] > 0 else 0,