Tags: Feature Extraction, Transformers, Sinhala, Hindi, English, tokenizer, WWHO, SGPE, linguis_trie, token, tokenization, Syllable, remeinium, linguistics, NLP, BPE, GPE
Instructions to use Remeinium/WWHO with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
  - Transformers

How to use Remeinium/WWHO with Transformers:

```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Remeinium/WWHO")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Remeinium/WWHO", dtype="auto")
```

- Notebooks
  - Google Colab
  - Kaggle
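
Because the trainer below exports a Hugging Face-compatible `tokenizer.json` (see `export_hf_tokenizer` in `_save_output`), the vocabulary can also be loaded with the standalone `tokenizers` library. This is a minimal sketch and assumes the `Remeinium/WWHO` repository hosts the exported `tokenizer.json` at its root; the sample sentence is only illustrative.

```python
# Minimal sketch: load the exported tokenizer.json directly.
# Assumes tokenizer.json is present in the Remeinium/WWHO repository.
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

path = hf_hub_download(repo_id="Remeinium/WWHO", filename="tokenizer.json")
tok = Tokenizer.from_file(path)

enc = tok.encode("සිංහල වචන mixed with English")
print(enc.tokens)
print(enc.ids)
```

The full WWHO (SGPE) GPE trainer source follows.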
```python
"""
WWHO (SGPE) GPE Trainer
"""
import argparse
import gc
import heapq
import json
import logging
import os
import pickle
import re
import time
from collections import Counter, defaultdict
from multiprocessing import Pool, cpu_count

from tqdm import tqdm

from router import CodeSwitchSegmenter
from export import export_hf_tokenizer

# ─── Logging ──────
try:
    import psutil as _psutil

    def _ram_mb() -> str:
        p = _psutil.Process()
        rss = p.memory_info().rss / 1024**2
        avail = _psutil.virtual_memory().available / 1024**2
        return f"RSS={rss:.0f}MB avail={avail:.0f}MB"
except ImportError:
    def _ram_mb() -> str:
        try:
            with open("/proc/meminfo") as f:
                info = {l.split(":")[0]: int(l.split()[1])
                        for l in f if ":" in l}
            avail = info.get("MemAvailable", 0) // 1024
            return f"avail={avail}MB"
        except Exception:
            return "ram=N/A"


_logger: logging.Logger | None = None


def _log(msg: str):
    full = f"[{time.strftime('%H:%M:%S')}] [{_ram_mb()}] {msg}"
    print(full, flush=True)
    if _logger:
        _logger.info(full)


def _setup_logging(output_dir: str):
    global _logger
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, "training.log")
    logging.basicConfig(
        filename=log_path,
        level=logging.INFO,
        format="%(message)s",
    )
    _logger = logging.getLogger("wwho_trainer")
    _log(f"Log started: {log_path}")

SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]

# ─── Multiprocessing ──────
_worker_segmenter: CodeSwitchSegmenter | None = None
_worker_dfa_map: dict | None = None
_worker_script_mode: str = "mixed"


def _init_worker(script_mode: str):
    """Initialise the per-process code-switch segmenter and syllable DFAs."""
    global _worker_segmenter, _worker_dfa_map, _worker_script_mode
    from linguis_trie import load_dfa_map

    _worker_script_mode = script_mode
    _worker_dfa_map = load_dfa_map(script_mode)
    language_blocks = {lang: dfa.unicode_blocks for lang, dfa in _worker_dfa_map.items()}
    _worker_segmenter = CodeSwitchSegmenter(language_blocks)


def _pretokenize_line(text: str) -> list[str]:
    """Split a line into script segments, then into syllables via the DFA."""
    tokens: list[str] = []
    for seg in _worker_segmenter.segment(text):
        if seg.language == "latin":
            tokens.append(seg.text)
        else:
            dfa = _worker_dfa_map.get(seg.language)
            if not dfa:
                tokens.append(seg.text)
                continue
            syllables = dfa.tokenize(seg.text, leading_space=seg.has_leading_space)
            tokens.extend(syllables)
    return tokens


def _is_boundary_token(token: str) -> bool:
    """True if the token contains no character from a target (non-Latin) script."""
    for ch in token:
        if _worker_segmenter:
            lang = _worker_segmenter._get_char_language(ch)
            if lang is not None and lang != "latin":
                return False
    return True


def segment_into_words(syllables: list[str]) -> list[list[str]]:
    """Group a flat syllable stream into words; boundary tokens become singleton words."""
    words: list[list[str]] = []
    current: list[str] = []
    for tok in syllables:
        if _is_boundary_token(tok):
            if current:
                words.append(current)
                current = []
            words.append([tok])
        else:
            # A syllable that starts with whitespace opens a new word.
            if tok[0] in (' ', '\t', '\n', '\r') and current:
                words.append(current)
                current = []
            current.append(tok)
    if current:
        words.append(current)
    return words
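
# Note (illustrative): given a pre-tokenized stream such as
#   ["සිං", "හ", "ල", " english", " වච", "න"]
# segment_into_words() groups it into
#   [["සිං", "හ", "ල"], [" english"], [" වච", "න"]]
# A purely-Latin token becomes a boundary word of its own, a token that starts
# with whitespace opens a new word, and only the non-boundary groups are
# counted for merging (see _process_batch). The exact syllable strings depend
# on the linguis_trie DFA, so the forms above are assumptions.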


# ─── Symbol Table ──────
class SymbolTable:
    """Bidirectional mapping between token strings and integer ids."""

    def __init__(self):
        self._str_to_id: dict[str, int] = {}
        self._id_to_str: list[str] = []

    def get_or_add(self, token: str) -> int:
        if token in self._str_to_id:
            return self._str_to_id[token]
        new_id = len(self._id_to_str)
        self._str_to_id[token] = new_id
        self._id_to_str.append(token)
        return new_id

    def add_merged(self, a_id: int, b_id: int) -> int:
        merged_str = self._id_to_str[a_id] + self._id_to_str[b_id]
        return self.get_or_add(merged_str)

    def to_str(self, token_id: int) -> str:
        return self._id_to_str[token_id]

    def to_id(self, token: str) -> int | None:
        return self._str_to_id.get(token)

    def __len__(self) -> int:
        return len(self._id_to_str)


# ─── GPETrainer ──────
class GPETrainer:
    def __init__(
        self,
        vocab_size: int = 128_000,
        min_freq: int = 2,
        num_workers: int | None = None,
        checkpoint_every: int = 20_000,
        prune_freq: int = 100,
        script_mode: str = "mixed",
    ):
        self.target_vocab_size = vocab_size
        self.min_freq = min_freq
        self.num_workers = num_workers or max(1, cpu_count() - 1)
        self.checkpoint_every = checkpoint_every
        self.prune_freq = prune_freq
        self.script_mode = script_mode
        self.merges: list[tuple[int, int]] = []
        self.symbols = SymbolTable()

    def stream_and_count(
        self, train_file: str, output_dir: str = "output"
    ) -> tuple[Counter, set[str]]:
        # ── 1. Count lines ──────
        print(" counting lines...", end=" ", flush=True)
        with open(train_file, "r", encoding="utf-8") as f:
            num_lines = sum(1 for _ in f)
        print(f"{num_lines:,}")

        CHUNK_SIZE = 5_000_000
        BATCH = 4_096
        partial_dir = os.path.join(output_dir, "_partial_counters")
        os.makedirs(partial_dir, exist_ok=True)
        _init_worker(self.script_mode)

        total_lines = 0
        chunk_idx = 0
        partial_paths: list[str] = []
        PARTIAL_PRUNE = 2

        def _save_partial(counter: Counter, idx: int, n_sent: int):
            if PARTIAL_PRUNE > 1:
                to_save = Counter(
                    {k: v for k, v in counter.items() if v >= PARTIAL_PRUNE}
                )
            else:
                to_save = counter
            pkl_path = os.path.join(partial_dir, f"partial_{idx:04d}.pkl")
            with open(pkl_path, "wb") as pf:
                pickle.dump(to_save, pf, protocol=pickle.HIGHEST_PROTOCOL)
            partial_paths.append(pkl_path)
            pkl_mb = os.path.getsize(pkl_path) / 1024**2
            pbar.write(
                f" chunk {idx+1} done: {n_sent:,} sent "
                f"-> {len(to_save):,} word types (pruned from {len(counter):,}) "
                f"-> {pkl_path} ({pkl_mb:.0f} MB)"
            )
            _log(f"CHUNK {idx+1} saved: {n_sent:,} sent, "
                 f"{len(to_save):,} word types, {pkl_mb:.0f} MB")
            del to_save
            counter.clear()
            gc.collect()
            _log(f"CHUNK {idx+1} post-gc")

        # ── 2. Stream, pre-tokenize in the worker pool, spill partial counters ──────
        chunk_counter: Counter = Counter()
        chunk_sent = 0
        batch_buf: list[str] = []
        pool = Pool(
            processes=self.num_workers,
            initializer=_init_worker,
            initargs=(self.script_mode,),
        )
        with open(train_file, "r", encoding="utf-8") as f:
            pbar = tqdm(f, total=num_lines, unit=" sent",
                        desc=f" pre-tokenizing [chunk 1]")
            for raw_line in pbar:
                try:
                    obj = json.loads(raw_line)
                    text = obj.get("text", "").strip()
                except json.JSONDecodeError:
                    text = raw_line.strip()
                if not text:
                    continue
                batch_buf.append(text)
                total_lines += 1
                chunk_sent += 1
                if len(batch_buf) >= BATCH:
                    self._process_batch(pool, batch_buf, chunk_counter)
                    batch_buf = []
                if chunk_sent >= CHUNK_SIZE:
                    if batch_buf:
                        self._process_batch(pool, batch_buf, chunk_counter)
                        batch_buf = []
                    pool.close()
                    pool.join()
                    pool = None
                    gc.collect()
                    _save_partial(chunk_counter, chunk_idx, chunk_sent)
                    chunk_idx += 1
                    chunk_sent = 0
                    pbar.set_description(
                        f" pre-tokenizing [chunk {chunk_idx + 1}]"
                    )
                    gc.collect()
                    pool = Pool(
                        processes=self.num_workers,
                        initializer=_init_worker,
                        initargs=(self.script_mode,),
                    )
            if batch_buf:
                self._process_batch(pool, batch_buf, chunk_counter)
            pool.close()
            pool.join()
            gc.collect()
            if chunk_counter:
                _save_partial(chunk_counter, chunk_idx, chunk_sent)
                chunk_idx += 1
            pbar.close()
        print(f" {total_lines:,} sentences -> {chunk_idx} chunks processed")

        # ── 3. Sequential merge with intermediate pruning ──────
        _log(f"MERGE START: {len(partial_paths)} partial counters, min_freq={self.min_freq}")
        N = len(partial_paths)
        word_counter: Counter = Counter()
        for i, pkl_path in enumerate(partial_paths):
            _log(f"MERGE [{i+1}/{N}] loading {pkl_path}")
            with open(pkl_path, "rb") as pf:
                partial: Counter = pickle.load(pf)
            _log(f"MERGE [{i+1}/{N}] loaded {len(partial):,} types, updating master...")
            word_counter.update(partial)
            del partial
            gc.collect()
            _log(f"MERGE [{i+1}/{N}] after update+gc: {len(word_counter):,} types")
            remaining = N - i - 1
            safe_prune = max(1, self.min_freq - remaining)
            before = len(word_counter)
            if safe_prune > 1:
                word_counter = Counter(
                    {k: v for k, v in word_counter.items() if v >= safe_prune}
                )
            if i > 0 and i % 5 == 0:
                hard_threshold = max(2, self.min_freq // 2)
                word_counter = Counter(
                    {k: v for k, v in word_counter.items() if v >= hard_threshold}
                )
                _log(f"MERGE [{i+1}/{N}] HARD PRUNE TRIGGERED (threshold={hard_threshold})")
            gc.collect()
            pruned_n = before - len(word_counter)
            if pruned_n > 0:
                msg = (f" [{i+1}/{N}] merged -> {len(word_counter):,} types "
                       f"(pruned {pruned_n:,})")
                print(msg, flush=True)
                _log(f"MERGE [{i+1}/{N}] post-prune: {len(word_counter):,} types "
                     f"(removed {pruned_n:,})")
            else:
                print(f" [{i+1}/{N}] merged -> {len(word_counter):,} types", flush=True)
                _log(f"MERGE [{i+1}/{N}] no prune needed, {len(word_counter):,} types")
            os.remove(pkl_path)
            _log(f"MERGE [{i+1}/{N}] deleted {pkl_path}")
        try:
            os.rmdir(partial_dir)
        except OSError:
            pass
        n_types = len(word_counter)
        n_instances = sum(word_counter.values())
        print(f"\n Final: {total_lines:,} sent -> {n_types:,} word types "
              f"({n_instances:,} instances)")
        return word_counter, set()

    def _process_batch(
        self,
        pool: Pool,
        batch: list[str],
        word_counter: Counter,
    ):
        syllable_streams = pool.map(_pretokenize_line, batch, chunksize=128)
        for stream in syllable_streams:
            words = segment_into_words(stream)
            for w in words:
                if not w:
                    continue
                if not _is_boundary_token(w[0]):
                    word_counter[tuple(w)] += 1

    def compute_syllable_freqs(self, word_counter: Counter) -> Counter:
        syl_freq: Counter[str] = Counter()
        for word_tuple, word_freq in word_counter.items():
            for syl in word_tuple:
                syl_freq[syl] += word_freq
        return syl_freq

    def build_word_types(
        self,
        word_counter: Counter,
        boundary_tokens: set[str],
        syl_freq: Counter | None = None,
    ) -> tuple[list[list[int]], list[int]]:
        UNK_SENTINEL = -1
        pruned_set: set[str] = set()
        if syl_freq is not None and self.prune_freq > 0:
            for syl, freq in syl_freq.items():
                if freq < self.prune_freq:
                    pruned_set.add(syl)
        word_types: list[list[int]] = []
        word_freqs: list[int] = []
        pruned_word_count = 0
        for word_tuple, freq in word_counter.items():
            ids = []
            for tok in word_tuple:
                if tok in pruned_set:
                    ids.append(UNK_SENTINEL)
                else:
                    ids.append(self.symbols.get_or_add(tok))
            word_types.append(ids)
            word_freqs.append(freq)
            if UNK_SENTINEL in ids:
                pruned_word_count += 1
        if pruned_set:
            print(f" pruned {len(pruned_set):,} rare syllables (freq < {self.prune_freq})")
            print(f" {pruned_word_count:,} word types contain [UNK] syllables")
        return word_types, word_freqs

    def build_token_index(self, word_types: list[list[int]]) -> dict[int, set[int]]:
        index: dict[int, set[int]] = defaultdict(set)
        for wt_idx, wt in enumerate(word_types):
            for tid in wt:
                if tid >= 0:
                    index[tid].add(wt_idx)
        return dict(index)

    def count_all_pairs(
        self,
        word_types: list[list[int]],
        word_freqs: list[int],
    ) -> dict[tuple[int, int], int]:
        counts: dict[tuple[int, int], int] = defaultdict(int)
        for wt_idx, wt in enumerate(word_types):
            f = word_freqs[wt_idx]
            for i in range(len(wt) - 1):
                a, b = wt[i], wt[i + 1]
                if a < 0 or b < 0:
                    continue
                counts[(a, b)] += f
        return dict(counts)

    def _build_heap(self, pair_counts: dict) -> list:
        heap = [(-freq, pair) for pair, freq in pair_counts.items() if freq > 0]
        heapq.heapify(heap)
        return heap

    def _heap_push(self, heap, pair, freq):
        if freq > 0:
            heapq.heappush(heap, (-freq, pair))

    def _pop_best(self, heap, pair_counts):
        # Lazy deletion: skip heap entries whose cached count is stale.
        while heap:
            neg_freq, pair = heapq.heappop(heap)
            actual = pair_counts.get(pair, 0)
            if actual <= 0:
                continue
            if actual != -neg_freq:
                self._heap_push(heap, pair, actual)
                continue
            return pair, actual
        return None, 0

    def merge_and_update(
        self,
        word_types: list[list[int]],
        word_freqs: list[int],
        pair: tuple[int, int],
        pair_counts: dict[tuple[int, int], int],
        token_index: dict[int, set[int]],
        merged_id: int,
        heap: list,
    ) -> int:
        a, b = pair
        total_applied = 0
        candidates = list(token_index.get(a, set()) & token_index.get(b, set()))
        pair_counts.pop(pair, None)
        dirty_pairs: dict[tuple[int, int], int] = {}
        for wt_idx in candidates:
            wt = word_types[wt_idx]
            freq = word_freqs[wt_idx]
            if len(wt) < 2:
                continue
            new_wt: list[int] = []
            i = 0
            changed = False
            while i < len(wt):
                if i + 1 < len(wt) and wt[i] == a and wt[i + 1] == b:
                    if new_wt and new_wt[-1] >= 0:
                        lp = (new_wt[-1], a)
                        pair_counts[lp] = pair_counts.get(lp, 0) - freq
                        dirty_pairs[lp] = pair_counts[lp]
                    if i + 2 < len(wt) and wt[i + 2] >= 0:
                        rp = (b, wt[i + 2])
                        pair_counts[rp] = pair_counts.get(rp, 0) - freq
                        dirty_pairs[rp] = pair_counts[rp]
                    new_wt.append(merged_id)
                    total_applied += freq
                    changed = True
                    if len(new_wt) >= 2 and new_wt[-2] >= 0:
                        lp = (new_wt[-2], merged_id)
                        pair_counts[lp] = pair_counts.get(lp, 0) + freq
                        dirty_pairs[lp] = pair_counts[lp]
                    if i + 2 < len(wt) and wt[i + 2] >= 0:
                        rp = (merged_id, wt[i + 2])
                        pair_counts[rp] = pair_counts.get(rp, 0) + freq
                        dirty_pairs[rp] = pair_counts[rp]
                    i += 2
                else:
                    new_wt.append(wt[i])
                    i += 1
            if changed:
                word_types[wt_idx] = new_wt
                if merged_id not in token_index:
                    token_index[merged_id] = set()
                token_index[merged_id].add(wt_idx)
                remaining = set(new_wt)
                if a not in remaining and wt_idx in token_index.get(a, set()):
                    token_index[a].discard(wt_idx)
                if b not in remaining and wt_idx in token_index.get(b, set()):
                    token_index[b].discard(wt_idx)
        for tok_id in (a, b):
            if tok_id in token_index and not token_index[tok_id]:
                del token_index[tok_id]
        for p, cnt in dirty_pairs.items():
            if cnt <= 0:
                pair_counts.pop(p, None)
            else:
                self._heap_push(heap, p, cnt)
        return total_applied
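
    # Note: merge_and_update applies one merge (a, b) -> merged_id only to the
    # word types that contain both symbols (the intersection of their
    # token_index sets) and patches the affected neighbour pair counts in
    # place, so the global pair table is never recounted. Updated counts are
    # pushed back onto the heap; _pop_best later discards stale heap entries
    # whose cached frequency no longer matches pair_counts.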

    def save_checkpoint(self, step: int, output_dir: str, elapsed: float):
        merge_strs = [
            [self.symbols.to_str(a), self.symbols.to_str(b)]
            for a, b in self.merges
        ]
        ckpt = {
            "step": step,
            "script_mode": self.script_mode,
            "merges": merge_strs,
            "elapsed_seconds": round(elapsed, 1),
        }
        path = os.path.join(output_dir, f"checkpoint_{step}.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(ckpt, f, ensure_ascii=False)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        return path, size_mb

    def load_checkpoint(self, ckpt_path: str):
        with open(ckpt_path, "r", encoding="utf-8") as f:
            ckpt = json.load(f)
        print(f" loaded checkpoint: step {ckpt['step']}, "
              f"{len(ckpt['merges'])} merges, "
              f"{ckpt['elapsed_seconds']:.1f}s elapsed")
        return ckpt

    def replay_merges(self, merge_strs, word_types, word_freqs, token_index, pair_counts):
        print(f" replaying {len(merge_strs)} merges...", flush=True)
        t0 = time.time()
        dummy_heap: list = []
        for a_str, b_str in tqdm(merge_strs, desc=" replaying", unit=" merge"):
            a_id = self.symbols.to_id(a_str)
            b_id = self.symbols.to_id(b_str)
            if a_id is None or b_id is None:
                continue
            merged_id = self.symbols.add_merged(a_id, b_id)
            self.merges.append((a_id, b_id))
            self.merge_and_update(
                word_types, word_freqs, (a_id, b_id), pair_counts,
                token_index, merged_id, dummy_heap,
            )
        print(f" replayed {len(self.merges)} merges in {time.time()-t0:.1f}s")

    def train(self, train_file: str, output_dir: str = "output",
              resume_path: str | None = None):
        os.makedirs(output_dir, exist_ok=True)
        print(f"WWHO (SGPE) GPE Trainer — script_mode={self.script_mode}, "
              f"workers={self.num_workers}")
        print(f"Training file: {train_file}\n")

        print("[1/5] Streaming pre-tokenization (CodeSwitchRouter)...")
        t_start = time.time()
        word_counter, boundary_tokens = self.stream_and_count(train_file, output_dir)

        print("\n[2/5] Building ID corpus...")
        syl_freq = None
        if self.prune_freq > 0:
            syl_freq = self.compute_syllable_freqs(word_counter)
            total_syls = len(syl_freq)
            surviving = sum(1 for f in syl_freq.values() if f >= self.prune_freq)
            print(f" syllable pruning: {total_syls:,} unique syllables, "
                  f"{surviving:,} survive (freq >= {self.prune_freq})")
        word_types, word_freqs = self.build_word_types(
            word_counter, boundary_tokens, syl_freq=syl_freq,
        )
        del word_counter, syl_freq
        base_vocab = len(self.symbols)
        total_instances = sum(word_freqs)
        print(f" base vocab (syllables + boundaries): {base_vocab:,}")
        print(f" word types: {len(word_types):,} ({total_instances:,} instances)")

        print("\n[3/5] Building index and counting pairs...")
        token_index = self.build_token_index(word_types)
        pair_counts = self.count_all_pairs(word_types, word_freqs)
        print(f" {len(pair_counts):,} unique pairs")

        start_step = 0
        elapsed_prior = 0.0
        if resume_path:
            print(f"\n Resuming from {resume_path}...")
            ckpt = self.load_checkpoint(resume_path)
            self.replay_merges(
                ckpt["merges"], word_types, word_freqs, token_index, pair_counts,
            )
            start_step = ckpt["step"]
            elapsed_prior = ckpt["elapsed_seconds"]
            pair_counts = self.count_all_pairs(word_types, word_freqs)
            print(f" rebuilt pair counts: {len(pair_counts):,} unique pairs")

        total_vocab_needed = self.target_vocab_size - len(SPECIAL_TOKENS)
        num_merges = max(0, total_vocab_needed - base_vocab)
        remaining = num_merges - start_step
        print(f"\n merge budget: {num_merges:,} "
              f"(starting at {start_step}, {remaining:,} remaining, min_freq={self.min_freq})")

        print(f"\n[4/5] Merge loop...")
        heap = self._build_heap(pair_counts)
        t0 = time.time()
        pbar = tqdm(range(start_step + 1, num_merges + 1),
                    desc=" merging", unit=" merge")
        for step in pbar:
            best_pair, freq = self._pop_best(heap, pair_counts)
            if best_pair is None or freq < self.min_freq:
                pbar.write(f" stopping at step {step}: "
                           f"{'no pairs' if best_pair is None else f'freq={freq} < {self.min_freq}'}")
                break
            a_id, b_id = best_pair
            merged_id = self.symbols.add_merged(a_id, b_id)
            self.merges.append(best_pair)
            n_applied = self.merge_and_update(
                word_types, word_freqs, best_pair, pair_counts,
                token_index, merged_id, heap,
            )
            if step <= 10 or step % 1000 == 0:
                a_s = self.symbols.to_str(a_id)
                b_s = self.symbols.to_str(b_id)
                m_s = self.symbols.to_str(merged_id)
                elapsed = time.time() - t0 + elapsed_prior
                pbar.write(f" [{step:>7}/{num_merges}] "
                           f"'{a_s}' + '{b_s}' -> '{m_s}' "
                           f"(freq={freq:,}, applied={n_applied:,}) [{elapsed:.1f}s]")
            if self.checkpoint_every > 0 and step % self.checkpoint_every == 0:
                elapsed = time.time() - t0 + elapsed_prior
                path, sz = self.save_checkpoint(step, output_dir, elapsed)
                pbar.write(f" >> checkpoint: {path} ({sz:.2f} MB)")
            pbar.set_postfix(freq=freq, vocab=len(self.symbols))
        pbar.close()
        merge_elapsed = time.time() - t0
        total_elapsed = merge_elapsed + elapsed_prior
        print(f" done: {len(self.merges)} merges in {merge_elapsed:.1f}s "
              f"(total {total_elapsed:.1f}s)")

        print("\n[5/5] Building vocabulary and exporting...")
        self._save_output(word_types, word_freqs, boundary_tokens, output_dir)
        wall = time.time() - t_start
        print(f"\nTotal wall time: {wall:.1f}s ({wall/60:.1f} min)")

    def _save_output(self, word_types, word_freqs, boundary_tokens, output_dir):
        final_freq: Counter[int] = Counter()
        for wt_idx, wt in enumerate(word_types):
            f = word_freqs[wt_idx]
            for tid in wt:
                if tid >= 0:
                    final_freq[tid] += f

        vocab: dict[str, int] = {}
        for i, st in enumerate(SPECIAL_TOKENS):
            vocab[st] = i
        next_id = len(SPECIAL_TOKENS)
        for tid, _ in final_freq.most_common():
            if len(vocab) >= self.target_vocab_size:
                break
            tok_str = self.symbols.to_str(tid)
            if tok_str not in vocab:
                vocab[tok_str] = next_id
                next_id += 1
        for sid in range(len(self.symbols)):
            if len(vocab) >= self.target_vocab_size:
                break
            s = self.symbols.to_str(sid)
            if s not in vocab:
                vocab[s] = next_id
                next_id += 1
        print(f" vocab size: {len(vocab):,}")
        print(f" merge rules: {len(self.merges):,}")

        merge_strs = [
            [self.symbols.to_str(a), self.symbols.to_str(b)]
            for a, b in self.merges
        ]
        output = {
            "version": "wwho_sgpe",
            "script_mode": self.script_mode,
            "vocab_size": len(vocab),
            "special_tokens": SPECIAL_TOKENS,
            "num_merges": len(self.merges),
            "prune_freq": self.prune_freq,
            "leading_space": True,
            "merges": merge_strs,
            "vocab": vocab,
        }
        path = os.path.join(output_dir, "vocab.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(output, f, ensure_ascii=False, indent=2)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f" saved: {path} ({size_mb:.2f} MB)")

        self.save_checkpoint(len(self.merges), output_dir, 0)

        hf_path = os.path.join(output_dir, "tokenizer.json")
        export_hf_tokenizer(vocab, merge_strs, SPECIAL_TOKENS, hf_path,
                            script_mode=self.script_mode)

        print(f"\n{'='*60}")
        print(f"TRAINING COMPLETE — WWHO")
        print(f" Script mode: {self.script_mode}")
        print(f" Vocab size: {len(vocab):,}")
        print(f" Merge rules: {len(self.merges):,}")
        print(f"{'='*60}")


def main():
    parser = argparse.ArgumentParser(description="WWHO (SGPE) GPE Trainer")
    parser.add_argument("--train_file", type=str, default="dataset/mixed_train.jsonl")
    parser.add_argument("--vocab_size", type=int, default=128_000,
                        help="Target SGPE vocab size (default 128K)")
    parser.add_argument("--min_freq", type=int, default=2)
    parser.add_argument("--prune_freq", type=int, default=100,
                        help="Drop syllables below this corpus frequency to [UNK]")
    parser.add_argument("--output_dir", type=str, default="output")
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--checkpoint_every", type=int, default=20_000)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--script_mode", type=str, default="mixed",
                        choices=["sinhala", "devanagari", "mixed"],
                        help="Which Indic script(s) to merge in BPE "
                             "(English/code always stays as boundary tokens)")
    args = parser.parse_args()

    _setup_logging(args.output_dir)
    _log(f"Starting WWHO (SGPE) trainer: train_file={args.train_file} "
         f"vocab_size={args.vocab_size} script_mode={args.script_mode} "
         f"prune_freq={args.prune_freq} min_freq={args.min_freq}")

    trainer = GPETrainer(
        vocab_size=args.vocab_size,
        min_freq=args.min_freq,
        num_workers=args.num_workers,
        checkpoint_every=args.checkpoint_every,
        prune_freq=args.prune_freq,
        script_mode=args.script_mode,
    )
    trainer.train(args.train_file, args.output_dir, resume_path=args.resume)


if __name__ == "__main__":
    main()
```
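
For a quick smoke test, the trainer can also be driven programmatically instead of through the CLI. The sketch below is illustrative only: the module name `wwho_gpe_trainer`, the dataset path, and the reduced vocabulary size are assumptions, while the constructor arguments mirror the argparse defaults above.

```python
# Illustrative sketch, not part of the shipped script.
# Assumes the trainer source above is saved as wwho_gpe_trainer.py and that
# dataset/si_train.jsonl exists (one JSON object with a "text" field per line).
from wwho_gpe_trainer import GPETrainer

trainer = GPETrainer(
    vocab_size=32_000,      # smaller than the 128K default, for a fast run
    min_freq=2,
    prune_freq=100,
    script_mode="sinhala",  # "sinhala", "devanagari", or "mixed"
)
trainer.train("dataset/si_train.jsonl", output_dir="output_si")
# On success this writes output_si/vocab.json and output_si/tokenizer.json.
```

The CLI equivalent would be `python wwho_gpe_trainer.py --train_file dataset/si_train.jsonl --script_mode sinhala --vocab_size 32000`, again with a hypothetical script and dataset name.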