cmboulanger committed on
Commit
c460c34
·
1 Parent(s): 406ca65

Add human-readable output for comparison of gold and llm-annotation

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. pyproject.toml +1 -0
  3. scripts/evaluate_llm.py +135 -77
  4. uv.lock +4 -0
.gitignore CHANGED
@@ -10,3 +10,6 @@ wheels/
10
  .venv
11
  .env*
12
  .DS_Store
 
 
 
 
10
  .venv
11
  .env*
12
  .DS_Store
13
+
14
+ # Local files
15
+ .local/
pyproject.toml CHANGED
@@ -23,6 +23,7 @@ markers = [
23
  dev = [
24
  "pytest>=8.0",
25
  "pytest-cov>=5.0",
 
26
  ]
27
 
28
  [build-system]
 
23
  dev = [
24
  "pytest>=8.0",
25
  "pytest-cov>=5.0",
26
+ "tqdm>=4.0",
27
  ]
28
 
29
  [build-system]
scripts/evaluate_llm.py CHANGED
@@ -25,6 +25,7 @@ from __future__ import annotations
25
  import argparse
26
  import json
27
  import os
 
28
  import sys
29
  import urllib.error
30
  import urllib.request
@@ -238,20 +239,31 @@ def run_evaluation(
238
  max_items: int | None,
239
  gliner_model: str | None = None,
240
  show_annotations: bool = False,
 
241
  ) -> bool:
242
  """
243
  Evaluate one provider: iterate over gold records with live progress,
244
  then print overall and per-element metrics.
 
 
 
 
245
  Returns True on success, False if a fatal exception occurred.
246
  """
 
 
247
  import warnings
248
  from lxml import etree
249
 
250
  from tei_annotator import preload_gliner_model
251
  from tei_annotator.evaluation import evaluate_element, aggregate, MatchMode
252
- from tei_annotator.evaluation.extractor import extract_spans
253
  from tei_annotator.inference.endpoint import EndpointCapability, EndpointConfig
254
 
 
 
 
 
 
255
  _TEI_NS = "http://www.tei-c.org/ns/1.0"
256
 
257
  mode_map = {
@@ -278,83 +290,119 @@ def run_evaluation(
278
  all_bibls = all_bibls[:max_items]
279
  n_total = len(all_bibls)
280
 
281
- sep = "─" * 64
282
- print(f"\n{sep}")
283
- print(f" Provider : {provider_name}")
284
- print(f" Gold file : {GOLD_FILE.relative_to(_REPO)}")
285
- print(f" Records : {n_total} match-mode: {match_mode_str}")
286
- print(f" GLiNER : {gliner_model or 'disabled'}")
287
- print(sep)
288
-
289
- if gliner_model:
290
- print(f" Loading GLiNER model '{gliner_model}'...", flush=True)
291
- preload_gliner_model(gliner_model)
292
- print(f" GLiNER model ready.")
293
-
294
- per_record = []
295
- failed = 0
296
- for i, bibl in enumerate(all_bibls, 1):
297
- plain_text = "".join(bibl.itertext())
298
- snippet = plain_text[:60].replace("\n", " ")
299
- print(f" [{i:3d}/{n_total}] {snippet}...", end="\r\n", flush=True)
300
- try:
301
- # Suppress the pipeline's best-effort XML validation warning here;
302
- # it surfaces again in the evaluator warning if parsing fails.
303
- with warnings.catch_warnings():
304
- warnings.filterwarnings(
305
- "ignore",
306
- message="Output XML validation failed",
307
- )
308
- result = evaluate_element(
309
- gold_element=bibl,
310
- schema=schema,
311
- endpoint=endpoint,
312
- gliner_model=gliner_model,
313
- match_mode=match_mode,
314
- )
315
- if show_annotations and result.annotation_xml is not None:
316
- sep60 = "─" * 60
317
- print(f"\n {sep60}")
318
- print(f" Annotation:")
319
- print(f" {result.annotation_xml}")
320
- print(f" F1={result.micro_f1:.3f} "
321
- f"missed={[s.element for s in result.unmatched_gold]} "
322
- f"spurious={[s.element for s in result.unmatched_pred]}")
323
- print(f" {sep60}\n")
324
- per_record.append(result)
325
- except Exception as exc:
326
- print(f"\n [{i:3d}/{n_total}] ERROR — {exc}")
327
- failed += 1
328
-
329
- # Clear the progress line
330
- print(" " * 70, end="\r")
331
-
332
- if not per_record:
333
- print(" ✗ All records failed — no results to report.")
334
- return False
335
-
336
- overall = aggregate(per_record)
337
- n_ok = len(per_record)
338
- print(f"\n Completed: {n_ok}/{n_total} records"
339
- + (f" ({failed} failed)" if failed else "") + "\n")
340
- print(overall.report(title=f"Overall — {provider_name}"))
341
-
342
- # Show the five worst records (by F1) for diagnostics
343
- worst = sorted(enumerate(per_record, 1), key=lambda x: x[1].micro_f1)[:5]
344
- if worst and worst[0][1].micro_f1 < 1.0:
345
- print(f"\n Lowest-F1 records (top 5):")
346
- for idx, r in worst:
347
- gold_bibl = all_bibls[idx - 1]
348
- snippet = "".join(gold_bibl.itertext())[:55].replace("\n", " ")
349
- fn_tags = [s.element for s in r.unmatched_gold]
350
- fp_tags = [s.element for s in r.unmatched_pred]
351
- print(
352
- f" #{idx:3d} F1={r.micro_f1:.3f}"
353
- f" missed={fn_tags} spurious={fp_tags}"
354
- )
355
- print(f' "{snippet}..."')
 
 
 
 
 
 
356
 
357
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
 
359
 
360
  # ---------------------------------------------------------------------------
@@ -395,6 +443,15 @@ def _parse_args() -> argparse.Namespace:
395
  default=False,
396
  help="Print the annotated XML output for each record (useful for inspection runs).",
397
  )
 
 
 
 
 
 
 
 
 
398
  p.add_argument(
399
  "--provider",
400
  choices=["gemini", "kisski", "all"],
@@ -443,6 +500,7 @@ def main() -> int:
443
  max_items=args.max_items,
444
  gliner_model=args.gliner_model,
445
  show_annotations=args.show_annotations,
 
446
  )
447
  results.append(ok)
448
 
 
25
  import argparse
26
  import json
27
  import os
28
+ import re
29
  import sys
30
  import urllib.error
31
  import urllib.request
 
239
  max_items: int | None,
240
  gliner_model: str | None = None,
241
  show_annotations: bool = False,
242
+ output_file: Path | None = None,
243
  ) -> bool:
244
  """
245
  Evaluate one provider: iterate over gold records with live progress,
246
  then print overall and per-element metrics.
247
+
248
+ When *output_file* is set all text output is written to that file and a
249
+ tqdm progress bar is shown in the terminal instead of per-record lines.
250
+
251
  Returns True on success, False if a fatal exception occurred.
252
  """
253
+ import contextlib
254
+ import io
255
  import warnings
256
  from lxml import etree
257
 
258
  from tei_annotator import preload_gliner_model
259
  from tei_annotator.evaluation import evaluate_element, aggregate, MatchMode
 
260
  from tei_annotator.inference.endpoint import EndpointCapability, EndpointConfig
261
 
262
+ try:
263
+ from tqdm import tqdm as _tqdm
264
+ except ImportError:
265
+ _tqdm = None
266
+
267
  _TEI_NS = "http://www.tei-c.org/ns/1.0"
268
 
269
  mode_map = {
 
290
  all_bibls = all_bibls[:max_items]
291
  n_total = len(all_bibls)
292
 
293
+ # --- output destination and progress display ----------------------------
294
+ # When --output-file: buffer all prints → file; show tqdm bar on stderr.
295
+ # Otherwise: print to stdout and show manual per-record progress lines.
296
+ _buf = io.StringIO() if output_file else None
297
+ _pbar = (
298
+ _tqdm(total=n_total, desc="Annotating", unit="rec", file=sys.stderr)
299
+ if output_file and _tqdm
300
+ else None
301
+ )
302
+ if output_file and not _tqdm:
303
+ print("WARNING: tqdm not installed — no progress bar. Run: pip install tqdm",
304
+ file=sys.stderr)
305
+
306
+ _ok = False
307
+ with contextlib.redirect_stdout(_buf) if _buf else contextlib.nullcontext():
308
+ sep = "─" * 64
309
+ print(f"\n{sep}")
310
+ print(f" Provider : {provider_name}")
311
+ print(f" Gold file : {GOLD_FILE.relative_to(_REPO)}")
312
+ print(f" Records : {n_total} match-mode: {match_mode_str}")
313
+ print(f" GLiNER : {gliner_model or 'disabled'}")
314
+ print(sep)
315
+
316
+ if gliner_model:
317
+ print(f" Loading GLiNER model '{gliner_model}'...", flush=True)
318
+ preload_gliner_model(gliner_model)
319
+ print(f" GLiNER model ready.")
320
+
321
+ per_record = []
322
+ failed = 0
323
+ for i, bibl in enumerate(all_bibls, 1):
324
+ plain_text = "".join(bibl.itertext())
325
+ snippet = plain_text[:60].replace("\n", " ")
326
+ if _pbar:
327
+ _pbar.set_description(snippet[:45])
328
+ else:
329
+ print(f" [{i:3d}/{n_total}] {snippet}...", end="\r\n", flush=True)
330
+ try:
331
+ # Suppress the pipeline's best-effort XML validation warning here;
332
+ # it surfaces again in the evaluator warning if parsing fails.
333
+ with warnings.catch_warnings():
334
+ warnings.filterwarnings(
335
+ "ignore",
336
+ message="Output XML validation failed",
337
+ )
338
+ result = evaluate_element(
339
+ gold_element=bibl,
340
+ schema=schema,
341
+ endpoint=endpoint,
342
+ gliner_model=gliner_model,
343
+ match_mode=match_mode,
344
+ )
345
+ if show_annotations and result.annotation_xml is not None:
346
+ sep60 = "─" * 60
347
+ gold_parts = [bibl.text or ""]
348
+ for child in bibl:
349
+ child_xml = etree.tostring(child, encoding="unicode", with_tail=True)
350
+ gold_parts.append(re.sub(r'\s+xmlns(?::\w+)?="[^"]*"', "", child_xml))
351
+ gold_xml = "".join(gold_parts)
352
+ print(f"\n {sep60}")
353
+ print(f" Gold: {gold_xml}")
354
+ print(f" Annotation: {result.annotation_xml}")
355
+ print(f" F1={result.micro_f1:.3f} "
356
+ f"missed={[s.element for s in result.unmatched_gold]} "
357
+ f"spurious={[s.element for s in result.unmatched_pred]}")
358
+ print(f" {sep60}\n")
359
+ per_record.append(result)
360
+ if _pbar:
361
+ _pbar.update(1)
362
+ _pbar.set_postfix(F1=f"{result.micro_f1:.3f}")
363
+ except Exception as exc:
364
+ print(f"\n [{i:3d}/{n_total}] ERROR — {exc}")
365
+ failed += 1
366
+ if _pbar:
367
+ _pbar.update(1)
368
+
369
+ if _pbar:
370
+ _pbar.close()
371
+ else:
372
+ # Clear the progress line
373
+ print(" " * 70, end="\r")
374
 
375
+ if not per_record:
376
+ print(" ✗ All records failed — no results to report.")
377
+ else:
378
+ overall = aggregate(per_record)
379
+ n_ok = len(per_record)
380
+ print(f"\n Completed: {n_ok}/{n_total} records"
381
+ + (f" ({failed} failed)" if failed else "") + "\n")
382
+ print(overall.report(title=f"Overall — {provider_name}"))
383
+
384
+ # Show the five worst records (by F1) for diagnostics
385
+ worst = sorted(enumerate(per_record, 1), key=lambda x: x[1].micro_f1)[:5]
386
+ if worst and worst[0][1].micro_f1 < 1.0:
387
+ print(f"\n Lowest-F1 records (top 5):")
388
+ for idx, r in worst:
389
+ gold_bibl = all_bibls[idx - 1]
390
+ snippet = "".join(gold_bibl.itertext())[:55].replace("\n", " ")
391
+ fn_tags = [s.element for s in r.unmatched_gold]
392
+ fp_tags = [s.element for s in r.unmatched_pred]
393
+ print(
394
+ f" #{idx:3d} F1={r.micro_f1:.3f}"
395
+ f" missed={fn_tags} spurious={fp_tags}"
396
+ )
397
+ print(f' "{snippet}..."')
398
+
399
+ _ok = True
400
+
401
+ if _buf is not None:
402
+ output_file.write_text(_buf.getvalue(), encoding="utf-8")
403
+ print(f"\n Output written to: {output_file}")
404
+
405
+ return _ok
406
 
407
 
408
  # ---------------------------------------------------------------------------
 
443
  default=False,
444
  help="Print the annotated XML output for each record (useful for inspection runs).",
445
  )
446
+ p.add_argument(
447
+ "--output-file",
448
+ default=None,
449
+ metavar="PATH",
450
+ help=(
451
+ "Write all evaluation output to this file. "
452
+ "A tqdm progress bar is shown in the terminal instead of per-record lines."
453
+ ),
454
+ )
455
  p.add_argument(
456
  "--provider",
457
  choices=["gemini", "kisski", "all"],
 
500
  max_items=args.max_items,
501
  gliner_model=args.gliner_model,
502
  show_annotations=args.show_annotations,
503
+ output_file=Path(args.output_file) if args.output_file else None,
504
  )
505
  results.append(ok)
506
 
uv.lock CHANGED
@@ -918,6 +918,7 @@ wheels = [
918
  name = "regex"
919
  version = "2026.2.28"
920
  source = { registry = "https://pypi.org/simple" }
 
921
  wheels = [
922
  { url = "https://files.pythonhosted.org/packages/07/42/9061b03cf0fc4b5fa2c3984cbbaed54324377e440a5c5a29d29a72518d62/regex-2026.2.28-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:fcf26c3c6d0da98fada8ae4ef0aa1c3405a431c0a77eb17306d38a89b02adcd7", size = 489574, upload-time = "2026-02-28T02:16:50.455Z" },
923
  { url = "https://files.pythonhosted.org/packages/77/83/0c8a5623a233015595e3da499c5a1c13720ac63c107897a6037bb97af248/regex-2026.2.28-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:02473c954af35dd2defeb07e44182f5705b30ea3f351a7cbffa9177beb14da5d", size = 291426, upload-time = "2026-02-28T02:16:52.52Z" },
@@ -998,6 +999,7 @@ wheels = [
998
  { url = "https://files.pythonhosted.org/packages/6b/ca/d2c03b0efde47e13db895b975b2be6a73ed90b8ba963677927283d43bf74/regex-2026.2.28-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:1c2c95e1a2b0f89d01e821ff4de1be4b5d73d1f4b0bf679fa27c1ad8d2327f1a", size = 800366, upload-time = "2026-02-28T02:19:34.248Z" },
999
  { url = "https://files.pythonhosted.org/packages/14/bd/ee13b20b763b8989f7c75d592bfd5de37dc1181814a2a2747fedcf97e3ba/regex-2026.2.28-cp314-cp314t-win32.whl", hash = "sha256:bbb882061f742eb5d46f2f1bd5304055be0a66b783576de3d7eef1bed4778a6e", size = 274936, upload-time = "2026-02-28T02:19:36.313Z" },
1000
  { url = "https://files.pythonhosted.org/packages/cb/e7/d8020e39414c93af7f0d8688eabcecece44abfd5ce314b21dfda0eebd3d8/regex-2026.2.28-cp314-cp314t-win_amd64.whl", hash = "sha256:6591f281cb44dc13de9585b552cec6fc6cf47fb2fe7a48892295ee9bc4a612f9", size = 284779, upload-time = "2026-02-28T02:19:38.625Z" },
 
1001
  ]
1002
 
1003
  [[package]]
@@ -1132,6 +1134,7 @@ gliner = [
1132
  dev = [
1133
  { name = "pytest" },
1134
  { name = "pytest-cov" },
 
1135
  ]
1136
 
1137
  [package.metadata]
@@ -1147,6 +1150,7 @@ provides-extras = ["gliner"]
1147
  dev = [
1148
  { name = "pytest", specifier = ">=8.0" },
1149
  { name = "pytest-cov", specifier = ">=5.0" },
 
1150
  ]
1151
 
1152
  [[package]]
 
918
  name = "regex"
919
  version = "2026.2.28"
920
  source = { registry = "https://pypi.org/simple" }
921
+ sdist = { url = "https://files.pythonhosted.org/packages/8b/71/41455aa99a5a5ac1eaf311f5d8efd9ce6433c03ac1e0962de163350d0d97/regex-2026.2.28.tar.gz", hash = "sha256:a729e47d418ea11d03469f321aaf67cdee8954cde3ff2cf8403ab87951ad10f2", size = 415184, upload-time = "2026-02-28T02:19:42.792Z" }
922
  wheels = [
923
  { url = "https://files.pythonhosted.org/packages/07/42/9061b03cf0fc4b5fa2c3984cbbaed54324377e440a5c5a29d29a72518d62/regex-2026.2.28-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:fcf26c3c6d0da98fada8ae4ef0aa1c3405a431c0a77eb17306d38a89b02adcd7", size = 489574, upload-time = "2026-02-28T02:16:50.455Z" },
924
  { url = "https://files.pythonhosted.org/packages/77/83/0c8a5623a233015595e3da499c5a1c13720ac63c107897a6037bb97af248/regex-2026.2.28-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:02473c954af35dd2defeb07e44182f5705b30ea3f351a7cbffa9177beb14da5d", size = 291426, upload-time = "2026-02-28T02:16:52.52Z" },
 
999
  { url = "https://files.pythonhosted.org/packages/6b/ca/d2c03b0efde47e13db895b975b2be6a73ed90b8ba963677927283d43bf74/regex-2026.2.28-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:1c2c95e1a2b0f89d01e821ff4de1be4b5d73d1f4b0bf679fa27c1ad8d2327f1a", size = 800366, upload-time = "2026-02-28T02:19:34.248Z" },
1000
  { url = "https://files.pythonhosted.org/packages/14/bd/ee13b20b763b8989f7c75d592bfd5de37dc1181814a2a2747fedcf97e3ba/regex-2026.2.28-cp314-cp314t-win32.whl", hash = "sha256:bbb882061f742eb5d46f2f1bd5304055be0a66b783576de3d7eef1bed4778a6e", size = 274936, upload-time = "2026-02-28T02:19:36.313Z" },
1001
  { url = "https://files.pythonhosted.org/packages/cb/e7/d8020e39414c93af7f0d8688eabcecece44abfd5ce314b21dfda0eebd3d8/regex-2026.2.28-cp314-cp314t-win_amd64.whl", hash = "sha256:6591f281cb44dc13de9585b552cec6fc6cf47fb2fe7a48892295ee9bc4a612f9", size = 284779, upload-time = "2026-02-28T02:19:38.625Z" },
1002
+ { url = "https://files.pythonhosted.org/packages/13/c0/ad225f4a405827486f1955283407cf758b6d2fb966712644c5f5aef33d1b/regex-2026.2.28-cp314-cp314t-win_arm64.whl", hash = "sha256:dee50f1be42222f89767b64b283283ef963189da0dda4a515aa54a5563c62dec", size = 275010, upload-time = "2026-02-28T02:19:40.65Z" },
1003
  ]
1004
 
1005
  [[package]]
 
1134
  dev = [
1135
  { name = "pytest" },
1136
  { name = "pytest-cov" },
1137
+ { name = "tqdm" },
1138
  ]
1139
 
1140
  [package.metadata]
 
1150
  dev = [
1151
  { name = "pytest", specifier = ">=8.0" },
1152
  { name = "pytest-cov", specifier = ">=5.0" },
1153
+ { name = "tqdm", specifier = ">=4.0" },
1154
  ]
1155
 
1156
  [[package]]