#!/usr/bin/env python3 """ NeoLLM model with FANformer, RMSNorm, ResFormer, Learnable Multipliers, full attention augmented with optional Momentum, MEA, and LUCID operators, Gated Attention (Qiu et al., 2025) combined with Affine-Scaled Attention (Bae et al., 2026), an optional Leviathan continuous token embedding generator, an optional Leviathan-JTok-M token-indexed modulation module, optional Spelling Bee Embeddings (Rabe et al., 2026), optional Context Re-Positioning (Li et al., 2026), optional REPO-GRAPE contextual group positioning (Li et al., 2026 + Zhang et al., 2026), optional GOAT-style factorised attention log-priors (Litman & Guo, 2026), and optional StackMemory (Zhang et al., NeurIPS 2025). Attention stack (orthogonal, all active simultaneously when enabled): 1. Gated Attention (use_gated_attention implicit via q_proj gate chunk): applies a head-specific elementwise sigmoid gate to the concatenated SDPA output before o_proj (G1 position, Qiu et al. 2025 §2.2). Introduces non-linearity between W_V and W_O, sparse input-dependent gating, and eliminates attention sink. 2. Affine-Scaled Attention (use_affine_scaled_attention): modulates softmax attention weights directly as [α(X)·softmax(QK^T/√dk) + β(X)] V relaxing the unit-sum constraint of softmax. α is per-head, per-query, input-dependent and bounded in [0,1] via linear_clipping. β is a moving-average bias that prevents collapse. Reduces first-token bias, increases attention entropy, and is complementary to Gated Attention (Bae et al. 2026, Table 2: Affine-Scaled + Gated > either alone). Flash/SDPA path: Expanding [α·softmax(QKᵀ)+β]V distributively yields two terms: α · attention(Q,K,V) — backend computes this directly + β · Σ_{j∈A(i)} V_j — sum over the same valid keys A(i) For global causal attention A(i) is a prefix and this is V.cumsum. For sliding-window or packed sequences the sum is windowed or segmented so β cannot leak outside the backend attention mask. Per-weight tensors (attn_weights_pre/post_affine) are unavailable in flash mode. Eager path: applies the same valid-key mask to α·softmax + β and keeps full weight access for interpretability. References: FANformer: "FANformer: Improving Large Language Models Through Effective Periodicity Modeling" Learnable Multipliers: "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers" Gated Attention: Qiu et al. (2025). "Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free." arXiv:2505.06708. Affine-Scaled Attention: Bae et al. (2026). "Affine-Scaled Attention: Towards Flexible and Stable Transformer Attention." arXiv:2602.23057. Leviathan Generator: Batley & Saha (2026). "A Separable Architecture for Continuous Token Representation in Language Models." arXiv:2601.22040. KHRONOS: Batley & Saha (2025). "KHRONOS: a Kernel-Based Neural Architecture for Rapid, Resource-Efficient Scientific Computation." arXiv:2505.13315. JTok / JTok-M: Yang et al. (2026). "JTok: On Token Embedding as Another Axis of Scaling Law via Joint Token Self-Modulation." arXiv:2602.00800. Spelling Bee Embeddings: Rabe, Clymo & Dong (2026). "Spelling Bee Embeddings for Language Modeling." arXiv:2601.18030. Context Re-Positioning: Li, Zhao, Cai & Sproat (2026). "REPO: Language Models with Context Re-Positioning." arXiv:2512.14391. GRAPE: Zhang et al. (2026). "Group Representational Position Encoding." arXiv:2512.07805 / ICLR 2026. GOAT priors: Litman & Guo (2026). "You Need Better Attention Priors." arXiv:2601.15380. StackMemory / STACKTRANS: Zhang et al. (2025). "Recursive Transformer: Boosting Reasoning Ability with State Stack." NeurIPS 2025. """ import math from typing import Callable, List, Optional, Union, Tuple import torch import torch.nn.functional as F from torch import nn from cut_cross_entropy import linear_cross_entropy from torch.utils.checkpoint import checkpoint from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.masking_utils import create_causal_mask from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_layers import GradientCheckpointingLayer from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, logging from .configuration_neollm import NeoLLMConfig from transformers import AutoConfig, AutoModel, AutoModelForCausalLM torch._dynamo.config.capture_scalar_outputs = True logger = logging.get_logger(__name__) # ── Optional flash_attn direct import (required only for IHA seq-expand, P>1) ── # Attempted once at module load time so the symbols are available when # NeoLLMAttention.__init__ stores them as instance attributes. If flash_attn # is not installed the names are set to None and a clear ImportError is raised # inside __init__ only when the user actually enables IHA with P>1. # The common path (P=1 or use_iha=False) never triggers the error. # # Functions surfaced: # flash_attn_func – standard batched causal attention (no padding path) # flash_attn_varlen_func – variable-length / packed-sequence causal attention # # Reference: Dao-AILab/flash-attention, flash_attn.flash_attn_interface try: from flash_attn.flash_attn_interface import ( # type: ignore[import] flash_attn_func as _IHA_FA_FUNC, flash_attn_varlen_func as _IHA_FA_VARLEN, ) _IHA_FLASH_ATTN_AVAILABLE = True except ImportError: _IHA_FA_FUNC = None _IHA_FA_VARLEN = None _IHA_FLASH_ATTN_AVAILABLE = False class ScalarMultiplier(nn.Module): """ Scalar Learnable Multiplier: W̃ = s·W From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers": Allows the effective matrix norm ||W̃|| = s·||W|| to adapt to data, escaping the WD-noise equilibrium that constrains ||W|| ∝ √(η/λ). """ def __init__(self, initial_value: float = 1.0): super().__init__() self.multiplier = nn.Parameter(torch.tensor(initial_value)) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.multiplier * x class VectorMultiplier(nn.Module): """ Vector Learnable Multipliers: W̃ = diag(r)·W·diag(c) From "Learnable Multipliers: Freeing the Scale of Language Model Matrix Layers": Frees not only the overall matrix norm but also individual row/column norms from the WD-noise equilibrium, enabling richer feature scale diversity. """ def __init__(self, dim: int, multiplier_type: str = "row", initial_value: float = 1.0): super().__init__() self.multiplier_type = multiplier_type self.multiplier = nn.Parameter(torch.ones(dim) * initial_value) def forward(self, x: torch.Tensor) -> torch.Tensor: return x * self.multiplier class LinearWithMultipliers(nn.Module): """ Linear layer with optional row and/or column learnable multipliers. Implements: y = (r ⊙ (W @ (c ⊙ x))) + b when enabled. With ``enable_multipliers=False`` no multiplier parameters are instantiated and the module reduces to its wrapped ``nn.Linear``. """ def __init__( self, in_features: int, out_features: int, bias: bool = True, use_row_multiplier: bool = False, use_column_multiplier: bool = False, enable_multipliers: bool = True, ): super().__init__() self.linear = nn.Linear(in_features, out_features, bias=bias) self.enable_multipliers = bool(enable_multipliers) self.use_row_multiplier = bool(use_row_multiplier and self.enable_multipliers) self.use_column_multiplier = bool(use_column_multiplier and self.enable_multipliers) if self.use_row_multiplier: self.row_multiplier = VectorMultiplier(out_features, multiplier_type="row") if self.use_column_multiplier: self.column_multiplier = VectorMultiplier(in_features, multiplier_type="column") def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_column_multiplier: x = self.column_multiplier(x) x = self.linear(x) if self.use_row_multiplier: x = self.row_multiplier(x) return x class EmbeddingWithMultipliers(nn.Module): """ Token embedding matrix with optional vocabulary-row and hidden-channel learnable multipliers. Effective embedding: E_eff[token, channel] = r_token * E[token, channel] * c_channel This mirrors the Learnable Multipliers paper's embedding recommendation, while keeping the actual weight available as ``.weight`` for tooling that expects a standard ``nn.Embedding``-like object. It is only intended for the untied, non-generator embedding path. """ def __init__( self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, enable_multipliers: bool = True, ): super().__init__() self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.padding_idx = padding_idx self.enable_multipliers = bool(enable_multipliers) self.use_row_multiplier = self.enable_multipliers self.use_column_multiplier = self.enable_multipliers if self.enable_multipliers: self.row_multiplier = nn.Parameter(torch.ones(num_embeddings)) self.column_multiplier = VectorMultiplier(embedding_dim, multiplier_type="column") @property def weight(self) -> torch.nn.Parameter: return self.embedding.weight def forward(self, input_ids: torch.Tensor) -> torch.Tensor: x = self.embedding(input_ids) if self.use_column_multiplier: x = self.column_multiplier(x) if self.use_row_multiplier: row_scale = F.embedding(input_ids, self.row_multiplier.unsqueeze(-1)).to(dtype=x.dtype) x = x * row_scale return x # ==================== LEVIATHAN CONTINUOUS TOKEN GENERATOR ==================== class LeviathanGenerator(nn.Module): """ Continuous token embedding generator for the Leviathan architecture. Replaces E ∈ R^{V×D} with a separable generator G : {0,...,V-1} → R^D. Three-stage pipeline: latent compositional indexing → B-spline basis expansion → tensor-product aggregation (Batley & Saha, 2026, §3.1). When ``return_internals=True`` the forward returns ``(embeddings, z_tilde, B_vals)`` for reuse by JTok-M surfaces in every decoder layer, avoiding redundant B-spline evaluation. B-spline basis uses the KHRONOS closed-form quadratic kernel (ckhronos.py) and is explicitly normalized across the knot dimension after evaluation, matching the Leviathan reference implementation described by the authors. This keeps every per-dimension basis vector on a partition-of-unity scale even near the finite grid boundaries. Sign parity tracking matches KHRONOS KHRONOSLayer exactly. Initialization: spline_weight_delta ~ normal(mean=0.0, std=0.1), with effective spline weights computed as (1 + spline_weight_delta). This matches the authors' ``1 + wd_i`` parameterization so phi ≈ 1.0 at init and the product of d_seed factors starts near 1.0 instead of ~10^{-21}. FP8 note: Leviathan deliberately stores the shared JTok-M seed projection as raw Parameters rather than nn.Linear. This keeps the generator outside TorchAO Float8Linear conversion even if an external FP8 filter is too broad. """ def __init__(self, config: NeoLLMConfig): super().__init__() vocab_size = config.vocab_size hidden_size = config.hidden_size d_seed = config.generator_d_seed num_modes = config.generator_num_modes num_knots = config.generator_num_knots spline_degree = config.generator_spline_degree k = config.generator_k krank = getattr(config, "generator_krank", 64) b = math.ceil(vocab_size ** (1.0 / k)) self.b = b self.k = k self.d_seed = d_seed self.num_modes = num_modes self.num_knots = num_knots self.spline_degree = spline_degree self.krank = krank self.hidden_size = hidden_size # ── Stage 1: shared codebook lookup ────────────────────────────── # Produces z [N, d_seed] — the raw seed before any per-head # preprocessing. This is the only shared computation across heads. self.codebooks = nn.Parameter(torch.empty(k, b, d_seed)) # Shared knot grid — fixed, not learned. # Used by both the generator heads and the JTok-M shared path. self.register_buffer( "knot_grid", torch.linspace(0.0, 1.0, num_knots), persistent=False, ) # ── JTok-M shared path ──────────────────────────────────────────── # seed_proj + seed_norm produce z_tilde [N, d_seed] used by JTok-M # surfaces in every decoder layer. This path is kept separate from # the per-head generator path so JTok-M is completely unaffected # by the generator architecture change. # Stored as raw Parameters instead of nn.Linear so external TorchAO # FP8 conversion cannot wrap this numerically sensitive projection in # Float8Linear. The forward still computes the same affine map via # F.linear(z, seed_proj_weight, seed_proj_bias). self.seed_proj_weight = nn.Parameter(torch.empty(d_seed, d_seed)) self.seed_proj_bias = nn.Parameter(torch.empty(d_seed)) self.seed_norm = nn.LayerNorm(d_seed) # ── Per-head generator (fused, vectorized) ─────────────────────── # Mathematically identical to 8 independent heads but fused into # single tensors so the entire per-head path executes in 6 kernels # instead of 8×5=40, and the maximum intermediate tensor appears # once instead of 8 times — eliminating the fragmentation that caused # OOM during backward. # # head_proj_weight [num_modes*d_seed, d_seed]: # Replaces ModuleList of 8 Linear(d_seed, d_seed, bias=False). # Forward: z @ W^T → [N, M*d_seed] → reshape [N, M, d_seed]. # Gradient is identical: each [d_seed, d_seed] block receives # gradient only from its own head output. # # head_norm_weight [num_modes, d_seed], head_norm_bias [num_modes, d_seed]: # Replace ModuleList of 8 LayerNorm(d_seed) with independent # weight/bias per head. Manual LN formula over last dim preserves # exact per-head normalization semantics. # # head_scale [num_modes, d_seed]: unchanged, already fused. # # head_spline_delta [num_modes, d_seed, num_knots, krank]: learned # delta coefficients wd_i. The effective spline coefficient used in # the manifold block is (1 + wd_i), matching the authors' stable # multiplicative parameterization. # # head_out_weight [num_modes, krank, hidden_size]: # Replaces ModuleList of 8 Linear(krank, hidden_size, bias=False). # Forward: einsum("nmk,mkd->nd", modes, W) — all heads projected # and summed in a single kernel. self.head_proj_weight = nn.Parameter( torch.empty(num_modes * d_seed, d_seed) ) self.head_norm_weight = nn.Parameter(torch.ones(num_modes, d_seed)) self.head_norm_bias = nn.Parameter(torch.zeros(num_modes, d_seed)) self.head_norm_eps = 1e-5 # head_scale: [num_modes, d_seed], initialized to (num_knots - 1) self.head_scale = nn.Parameter( torch.full((num_modes, d_seed), float(num_knots - 1)) ) # head_spline_delta: [num_modes, d_seed, num_knots, krank] # Effective coefficient is (1 + head_spline_delta). self.head_spline_delta = nn.Parameter( torch.empty(num_modes, d_seed, num_knots, krank) ) self.head_out_weight = nn.Parameter( torch.empty(num_modes, krank, hidden_size) ) def _base_k_decompose(self, token_ids: torch.Tensor) -> torch.Tensor: """ Deterministic base-b decomposition: i → (i_0, ..., i_{k-1}). Maps token indices directly to codebook coordinates via arithmetic: token x → (x // b^{k-1}, ..., x % b). """ ids = token_ids.long().clone() coords = torch.empty( *token_ids.shape, self.k, dtype=torch.long, device=token_ids.device, ) for r in range(self.k - 1, -1, -1): coords[..., r] = ids % self.b ids = ids // self.b return coords @staticmethod def _normalize_bspline_basis(B: torch.Tensor) -> torch.Tensor: """ Explicitly normalize a quadratic B-spline basis across knot points. The finite-grid quadratic kernel does not sum to exactly 1 near the boundaries of [0, 1]. The Leviathan reference implementation divides by the post-evaluation sum across knots, so every basis vector satisfies sum_g B[..., g] ≈ 1 before the separable tensor-product aggregation. """ denom = B.sum(dim=-1, keepdim=True).clamp_min(1e-12) return B / denom def _bspline_basis(self, x_flat: torch.Tensor) -> torch.Tensor: """ KHRONOS quadratic B-spline basis with fixed scalar scale. Used exclusively by the JTok-M shared path (z_tilde → B_vals). JTok-M surfaces have their own spline_coeff and call _modes_from_basis with these B_vals. The basis is normalized across knots after evaluation so the shared geometry matches the Leviathan generator. Args: x_flat: [N, d_seed], values in [0, 1]. Returns: [N, d_seed, num_knots] float32. """ scale = float(self.num_knots - 1) x32 = x_flat.float() x_e = x32.unsqueeze(-1) grid = self.knot_grid.float().view(1, 1, -1) d = (x_e - grid).abs() * scale B = torch.where( d < 0.5, 0.75 - d ** 2, torch.where(d < 1.5, 0.5 * (1.5 - d) ** 2, torch.zeros_like(d)), ) # [N, d_seed, num_knots] float32 return self._normalize_bspline_basis(B) def _bspline_basis_all_heads( self, x_all: torch.Tensor, ) -> torch.Tensor: """ Vectorized KHRONOS quadratic B-spline basis for all heads at once. Mathematically identical to calling _bspline_basis_head 8 times in a loop, but materializes the full [N, M, d_seed, n_knots] tensor in a single kernel instead of 8 sequential [N, d_seed, n_knots] tensors. Args: x_all: [N, M, d_seed], values in [0, 1], all heads stacked. Returns: [N, M, d_seed, n_knots] float32. NOTE: Este método se mantiene para compatibilidad con JTok-M y análisis. El forward del generator ya NO lo usa — usa _compute_head en su lugar. """ x32 = x_all.float() x_e = x32.unsqueeze(-1) # [N, M, d_seed, 1] grid = self.knot_grid.float().view(1, 1, 1, -1) # [1, 1, 1, n_knots] # head_scale [M, d_seed] → [1, M, d_seed, 1] sc = self.head_scale.float().unsqueeze(0).unsqueeze(-1) d = (x_e - grid).abs() * sc # [N, M, d_seed, n_knots] B = torch.where( d < 0.5, 0.75 - d ** 2, torch.where(d < 1.5, 0.5 * (1.5 - d) ** 2, torch.zeros_like(d)), ) # [N, M, d_seed, n_knots] float32 return self._normalize_bspline_basis(B) def _compute_head( self, z: torch.Tensor, m: int, ) -> torch.Tensor: """ Forward completo para el cabezal m del generator, sin particionar la dimensión ``d_seed`` en chunks. Matemática aplicada directamente: phi[n, d, k] = Σ_g B[n, d, g] · (1 + wd[m, d, g, k]) modes[n, k] = Π_d phi[n, d, k] out[n, :] = modes[n, :] @ W_out[m] Esta versión materializa ``phi`` completo con forma ``[N, d_seed, krank]`` para cada cabezal. Es más directa y elimina el manejo por chunks del producto KHRONOS, a costa de mayor uso de VRAM. """ d = self.d_seed kr = self.krank # ── Proyección lineal para el cabezal m ────────────────────────── proj_w = self.head_proj_weight[m * d : (m + 1) * d] # [d_seed, d_seed] zh = F.linear( z.to(dtype=proj_w.dtype, device=proj_w.device), proj_w, ) # [N, d_seed] zh = zh.float() # ── LayerNorm manual por cabezal ───────────────────────────────── norm_w = self.head_norm_weight[m].float() norm_b = self.head_norm_bias[m].float() mean = zh.mean(dim=-1, keepdim=True) var = zh.var(dim=-1, keepdim=True, unbiased=False) zh = (zh - mean) / (var + self.head_norm_eps).sqrt() zh = zh * norm_w + norm_b # ── Sigmoid(x/2) → coordenada latente en [0,1]^d_seed ──────────── zh = torch.sigmoid(zh / 2.0).clamp(0.0, 1.0) # [N, d_seed] # ── KHRONOS full log-product, sin chunks ───────────────────────── grid = self.knot_grid.float().view(1, 1, -1) # [1, 1, n_knots] sc = self.head_scale[m].float().view(1, -1, 1) # [1, d_seed, 1] dist = (zh.unsqueeze(-1) - grid).abs() * sc # [N, d_seed, n_knots] B = torch.where( dist < 0.5, 0.75 - dist ** 2, torch.where(dist < 1.5, 0.5 * (1.5 - dist) ** 2, torch.zeros_like(dist)), ) # [N, d_seed, n_knots] B = self._normalize_bspline_basis(B) effective_spline = 1.0 + self.head_spline_delta[m].float() phi = torch.einsum( "ndg,dgk->ndk", B, effective_spline, ) # [N, d_seed, krank] log_mag = torch.log(phi.abs() + 1e-9).sum(dim=1) # [N, krank] num_neg = (phi < 0).to(torch.int32).sum(dim=1) # [N, krank] prod_sign = 1.0 - 2.0 * (num_neg % 2).float() # [N, krank] modes_m = prod_sign * torch.exp(log_mag) # [N, krank] # ── Proyección de salida del cabezal ───────────────────────────── out_m = ( modes_m.to(self.head_out_weight.dtype) @ self.head_out_weight[m] ) # [N, hidden_size] return out_m def _khronos_all_heads( self, B_all: torch.Tensor, ) -> torch.Tensor: """ Vectorized KHRONOS tensor-product for all heads at once. Mathematically identical to calling _khronos_head_product 8 times, but uses a single einsum over the head dimension. The sign-parity aggregation is performed independently per head via the M dimension. Args: B_all: [N, M, d_seed, n_knots] float32 Returns: [N, M, krank] in float32. """ # per_dim: [N, M, d_seed, krank] # einsum: token n, head m, seed-dim d, knot g → krank k # Effective spline coefficient is (1 + head_spline_delta). per_dim = torch.einsum( "nmdg,mdgk->nmdk", B_all, 1.0 + self.head_spline_delta.float(), ) per_dim_abs = per_dim.abs() + 1e-9 # Sum log-magnitudes over d_seed dimension → [N, M, krank] log_mag = torch.log(per_dim_abs).sum(dim=2) num_neg = (per_dim < 0).long().sum(dim=2) prod_sign = 1.0 - 2.0 * (num_neg % 2).float() return prod_sign * torch.exp(log_mag) # [N, M, krank] def _modes_from_basis( self, B_vals: torch.Tensor, spline_coeff: torch.Tensor, target_dtype: torch.dtype, ) -> torch.Tensor: """ Shared tensor-product aggregation used by both the input generator and JTok-M surfaces. phi_{r,j} = spline_coeff[j,r,:] · B_vals[n,r,:] M_j = sign_j * exp(Σ_r log|phi_{r,j}|) (KHRONOS sign-parity) Args: B_vals: [N, d_seed, n_knots] float32 spline_coeff: [num_modes, d_seed, n_knots] target_dtype: output dtype Returns: modes: [N, num_modes] in target_dtype """ phi = torch.einsum( "jrk,nrk->njr", spline_coeff.float(), B_vals, ) # [N, num_modes, d_seed] phi_abs = phi.abs() + 1e-9 log_mag = torch.log(phi_abs).sum(dim=-1) # [N, M] num_neg = (phi < 0).long().sum(dim=-1) # [N, M] prod_sign = 1.0 - 2.0 * (num_neg % 2).float() # [N, M] return (prod_sign * torch.exp(log_mag)).to(target_dtype) # [N, M] def forward( self, token_ids: torch.Tensor, return_internals: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ Generate embeddings from discrete token indices. Two parallel paths run from the shared codebook output z: JTok-M path (return_internals): z → seed_proj → seed_norm → sigmoid → z_tilde [N, d_seed] z_tilde → _bspline_basis (fixed scale) → B_vals [N, d_seed, n_knots] These are returned for reuse by every decoder layer's JTok-M surfaces without redundant B-spline evaluation. Generator path (per-head, real Leviathan architecture): For each head i: z → head_proj_i → head_norm_i → sigmoid(x/2) → z_tilde_i z_tilde_i → _bspline_basis_head(scale_i) → B_vals_i B_vals_i × (1 + head_spline_delta_i) → _khronos_head_product → [N, krank] head_out_i([N, krank]) → [N, hidden_size] e = sum of all head_out_i outputs (no W_res, per author's code) Args: token_ids: (batch, seq_len) or (seq_len,) return_internals: if True, also return z_tilde and B_vals for reuse by JTok-M surfaces in every decoder layer. Returns: embeddings [*token_ids.shape, hidden_size], or (embeddings, z_tilde [N, d_seed], B_vals [N, d_seed, n_knots]) when return_internals=True. """ target_dtype = self.codebooks.dtype orig_shape = token_ids.shape N = token_ids.numel() # ── Shared Stage 1: compositional codebook indexing ─────────────── coords = self._base_k_decompose(token_ids) coords_flat = coords.reshape(N, self.k) z = torch.zeros(N, self.d_seed, device=token_ids.device, dtype=target_dtype) for r in range(self.k): z = z + self.codebooks[r][coords_flat[:, r]] # ── Optional JTok-M shared path ───────────────────────────────── # This path is expensive: B_vals has shape [N, d_seed, n_knots]. # It is needed only by JTok-M. During normal training with # use_jtokm=False, do not put this branch in the compiled graph. need_internals = return_internals z_tilde = None B_vals = None if need_internals: seed_h = F.linear( z.to(dtype=self.seed_proj_weight.dtype, device=self.seed_proj_weight.device), self.seed_proj_weight, self.seed_proj_bias, ) z_tilde = torch.sigmoid(self.seed_norm(seed_h)) # [N, d_seed] B_vals = self._bspline_basis(z_tilde.clamp(0.0, 1.0)) # [N, d_seed, n_knots] # ── Per-head generator path, sin chunking sobre d_seed ───────────── # Cada cabezal LEV se evalúa completo: # B [N, d_seed, n_knots] # phi [N, d_seed, krank] # modes [N, krank] # # Esta versión elimina la acumulación por chunks del producto KHRONOS. # Mantiene el loop por cabezal para conservar cabezales independientes, # pero dentro de cada cabezal materializa la forma completa. target_dtype = self.codebooks.dtype e = torch.zeros(N, self.hidden_size, device=token_ids.device, dtype=target_dtype) for m in range(self.num_modes): e = e + self._compute_head(z, m) # No W_res — confirmed absent in the authors' implementation e = e.reshape(*orig_shape, self.hidden_size) if return_internals: # return_internals=True is used by JTok-M and must return the # shared Leviathan geometry. The branch above is guaranteed to # have run in this case. return e, z_tilde, B_vals return e # ==================== LEVIATHAN-JTOK-M MODULATION MODULE ==================== class LeviathanJTokM(nn.Module): """ Leviathan-JTok-M token-indexed modulation module for one decoder layer. Fuses the Leviathan continuous geometry with JTok-M (Yang et al., 2026): - Instead of per-token lookup tables (O(V·D) per layer), uses n_e independent CP-separable surfaces over the shared z̃_x from the Leviathan generator, reusing B_vals already computed in the embedding stage. - Context-dependent router: gates over h̃^ℓ_x (hidden state after attention) using Sigmoid+TopK — not Softmax — to avoid inter-surface competition. - Additive injection with 1/√(2ℓ) scaling coordinated with the existing LNS factor 1/√ℓ, maintaining a constant JTok-M / backbone ratio of 1/√2 ≈ 0.707 across all depths (instead of 1/√(2N_l) which would grow JTok-M dominance in deep layers as LNS suppresses backbone activations). - Fully vectorized: all surfaces evaluated in one einsum, TopK with fixed K produces static shapes — compatible with torch.compile max-autotune. Scaling note (LNS coordination): LNS applies 1/√ℓ to backbone sublayer inputs (ℓ = 1-indexed layer). JTok-M applies 1/√(2ℓ) to its injection residual. Ratio: [1/√(2ℓ)] / [1/√ℓ] = 1/√2 — constant at every depth. Parameter cost per layer (defaults n_e=5, M_mod=4, d_seed=128, D=512): spline_coeff: n_e × M_mod × d_seed × n_knots = 5×4×128×32 = 81,920 W_out: n_e × M_mod × D = 5×4×512 = 10,240 W_res: n_e × d_seed × D = 5×128×512 = 327,680 router R: D × n_e = 512×5 = 2,560 scaler s: D = 512 Total per layer: ~422,912 → ~5.07M for 12 layers. References: Yang, Y. et al. (2026). JTok. arXiv:2602.00800. Batley & Saha (2026). Leviathan. arXiv:2601.22040. """ def __init__(self, config: NeoLLMConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx self.n_e = config.jtokm_num_experts self.top_k = config.jtokm_top_k self.M_mod = config.jtokm_num_modes self.d_seed = config.generator_d_seed self.n_knots = config.generator_num_knots self.hidden_size = config.hidden_size self.norm_eps = config.jtokm_norm_eps # LNS-coordinated scale: 1/√(2ℓ), ℓ = layer_idx + 1 (1-indexed) ell = max(layer_idx + 1, 1) self.lns_scale = 1.0 / math.sqrt(2.0 * ell) # n_e CP-separable surfaces — each with its own spline_coeff, W_out, W_res # Stored as fused tensors for a single vectorized einsum: # spline_coeff: [n_e, M_mod, d_seed, n_knots] # W_out: [n_e, M_mod, D] # W_res: [n_e, d_seed, D] self.spline_coeff = nn.Parameter( torch.empty(self.n_e, self.M_mod, self.d_seed, self.n_knots) ) self.W_out = nn.Parameter( torch.empty(self.n_e, self.M_mod, self.hidden_size) ) self.W_res = nn.Parameter( torch.empty(self.n_e, self.d_seed, self.hidden_size) ) # Context-dependent router: RMSNorm(h̃) @ R → [N, n_e] self.router = nn.Linear(config.hidden_size, self.n_e, bias=False) # Learnable per-dimension scaler (JTok eq. 7 / JTok-M eq. 12) self.scaler = nn.Parameter(torch.ones(config.hidden_size)) # ── Surface evaluation ──────────────────────────────────────────────── def _eval_surfaces( self, B_vals: torch.Tensor, z_tilde: torch.Tensor, target_dtype: torch.dtype, ) -> torch.Tensor: """ Evaluate all n_e surfaces vectorized over the full token batch. phi[n, i, j, r] = spline_coeff[i, j, r, :] · B_vals[n, r, :] M[n, i, j] = sign * exp(Σ_r log|phi[n,i,j,r]|) m[n, i] = W_out[i] @ M[n,i] + W_res[i] @ z̃[n] All shapes are static → torch.compile compatible. Args: B_vals: [N, d_seed, n_knots] float32 z_tilde: [N, d_seed] target_dtype: model dtype Returns: surfaces: [N, n_e, D] """ N = B_vals.shape[0] # phi: [N, n_e, M_mod, d_seed] # einsum: "ijrk, nrk -> nijr" where i=n_e, j=M_mod, r=d_seed, k=n_knots phi = torch.einsum( "ijrk,nrk->nijr", self.spline_coeff.float(), # [n_e, M_mod, d_seed, n_knots] B_vals, # [N, d_seed, n_knots] ) # [N, n_e, M_mod, d_seed] # KHRONOS sign-parity product aggregation over d_seed phi_abs = phi.abs() + 1e-9 log_mag = torch.log(phi_abs).sum(dim=-1) # [N, n_e, M_mod] num_neg = (phi < 0).long().sum(dim=-1) # [N, n_e, M_mod] prod_sign = 1.0 - 2.0 * (num_neg % 2).float() # [N, n_e, M_mod] modes = (prod_sign * torch.exp(log_mag)).to(target_dtype) # modes: [N, n_e, M_mod] # W_out projection: [N, n_e, M_mod] × [n_e, M_mod, D] → [N, n_e, D] out_modes = torch.einsum("nim,imd->nid", modes, self.W_out.to(target_dtype)) # W_res residual: z_tilde [N, d_seed] × [n_e, d_seed, D] → [N, n_e, D] z = z_tilde.to(target_dtype) out_res = torch.einsum("nd,idc->nic", z, self.W_res.to(target_dtype)) surfaces = out_modes + out_res # [N, n_e, D] return surfaces # ── Router ──────────────────────────────────────────────────────────── def _rms_norm(self, x: torch.Tensor) -> torch.Tensor: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.norm_eps) def _route_and_mix( self, h_tilde: torch.Tensor, surfaces: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Context-dependent routing over h_tilde (hidden state after attention). Sigmoid + TopK (not Softmax) avoids inter-surface competition: g = RMSNorm(h̃) @ R [N, n_e] K winners selected per token w_i = σ(g_i) / Σ_{j∈K} σ(g_j) (normalized over selected only) e = Σ_{i∈K} w_i * surfaces[n, i] All tensor shapes are static. TopK with fixed K returns [N, K] indices. torch.compile compatible. Also computes load-balancing statistics p_i and f_i for aux loss. Args: h_tilde: [N, D] — hidden state after attention (before MLP) surfaces: [N, n_e, D] Returns: mixed: [N, D] aux_stats: (p_sum [n_e], f_sum [n_e], N) for loss accumulation """ N = h_tilde.shape[0] # Router logits [N, n_e] g = self.router(self._rms_norm(h_tilde)) # TopK selection — static shape [N, top_k] topk_vals, topk_idx = torch.topk(g, self.top_k, dim=-1) # Sigmoid weights over selected K surfaces (JTok-M eq. 10-11) sig_vals = torch.sigmoid(topk_vals) # [N, K] w = sig_vals / sig_vals.sum(dim=-1, keepdim=True) # [N, K] # Gather selected surfaces [N, K, D] and weight-sum # topk_idx: [N, K] → expand to [N, K, D] idx_exp = topk_idx.unsqueeze(-1).expand(N, self.top_k, self.hidden_size) selected = surfaces.gather(dim=1, index=idx_exp) # [N, K, D] mixed = (w.unsqueeze(-1) * selected).sum(dim=1) # [N, D] # Load-balancing statistics for aux loss (Appendix B, Yang et al. 2026) # p_i = mean routing probability over batch # f_i = fraction of tokens actually routed to i with torch.no_grad(): sig_all = torch.sigmoid(g) # [N, n_e] p_sum = sig_all.sum(dim=0) # [n_e] # one-hot mask of selections: [N, n_e] onehot = torch.zeros_like(g).scatter_( 1, topk_idx, 1.0 ) f_sum = onehot.sum(dim=0) # [n_e] return mixed, (p_sum, f_sum, N) # ── Forward ─────────────────────────────────────────────────────────── def forward( self, h_tilde: torch.Tensor, z_tilde: torch.Tensor, B_vals: torch.Tensor, ) -> Tuple[torch.Tensor, Tuple]: """ Compute additive JTok-M residual for one decoder layer. Args: h_tilde: [N, D] hidden state after attention (before MLP) z_tilde: [N, d_seed] latent coordinate from generator B_vals: [N, d_seed, n_k] B-spline basis (computed once, reused) Returns: delta_r: [N, D] additive residual (already scaled) aux_stats: tuple for accumulating load-balance loss """ target_dtype = h_tilde.dtype # All n_e surfaces in one vectorized pass surfaces = self._eval_surfaces(B_vals, z_tilde, target_dtype) # Context-dependent routing mixed, aux_stats = self._route_and_mix(h_tilde, surfaces) # Normalise direction, apply scaler, scale with 1/√(2ℓ) # Norm_ε decouples direction from magnitude (JTok Appendix D.2) mixed_norm = mixed / (mixed.norm(dim=-1, keepdim=True) + self.norm_eps) delta_r = self.lns_scale * self.scaler * mixed_norm # [N, D] return delta_r, aux_stats def compute_jtokm_aux_loss( aux_stats_list: list, n_e: int, weight: float, ) -> torch.Tensor: """ Aggregate load-balancing aux loss over all active layers. L_aux = λ · n_e · Σ_i p_i · f_i averaged over all layers with active JTok-M. Args: aux_stats_list: list of (p_sum [n_e], f_sum [n_e], N) per layer n_e: number of experts weight: λ coefficient Returns: scalar loss tensor """ total_loss = None for p_sum, f_sum, N in aux_stats_list: p_i = p_sum / N # average routing probability [n_e] f_i = f_sum / (N * 1.0) # load fraction [n_e] layer_loss = weight * n_e * (p_i * f_i).sum() total_loss = layer_loss if total_loss is None else total_loss + layer_loss if total_loss is None: return torch.tensor(0.0) return total_loss / len(aux_stats_list) # ==================== ORIGINAL COMPONENTS ==================== class FANLayer(nn.Module): """ Fourier Analysis Network (FAN) layer. FANLayer'(X) = [cos(WpX) || sin(WpX) || (Wp¯X + Bp¯)] """ def __init__(self, hidden_size: int, fan_ratio: float = 0.25): super().__init__() self.hidden_size = hidden_size self.fan_ratio = fan_ratio output_dim = hidden_size + int(hidden_size * fan_ratio) self.p_output_dim = int(output_dim * fan_ratio) self.g_output_dim = output_dim - self.p_output_dim * 2 self.input_linear = nn.Linear( hidden_size, self.p_output_dim + self.g_output_dim, bias=True ) self._init_weights() def _init_weights(self): nn.init.normal_(self.input_linear.weight, mean=0.0, std=0.02) if self.input_linear.bias is not None: nn.init.zeros_(self.input_linear.bias) def forward( self, x: torch.Tensor, ) -> torch.Tensor: pg = self.input_linear(x) p, g = torch.split(pg, [self.p_output_dim, self.g_output_dim], dim=-1) cos_p = torch.cos(p) sin_p = torch.sin(p) return torch.cat([cos_p, sin_p, g], dim=-1) class StackMemory(nn.Module): """ Differentiable hidden-state stack used by STACKTRANS. This implementation follows the released StackTrans source operator: hidden states are first projected down to a low-rank stack space, split into multiple memory heads, updated by soft push/pop/no-op probabilities, read through query-over-stack attention, projected back to the model dimension, and added residually to the input stream. Shapes: hidden_states: [B, S, D] stack: [B, H, K, d_s/H] mask: [B, H, K] where ``H = num_mem_heads``, ``K = stack_slots``, and ``d_s = stack_d_model``. """ def __init__(self, config: "NeoLLMConfig"): super().__init__() self.config = config self.num_mem_heads = config.num_mem_heads self.stack_slots = config.stack_slots self.head_dim = config.stack_d_model // self.num_mem_heads self.down_proj = nn.Linear(config.hidden_size, config.stack_d_model) self.up_proj = nn.Linear(config.stack_d_model, config.hidden_size) self.action_head = nn.Linear(config.stack_d_model, 3 * self.num_mem_heads) self.gate_proj = nn.Linear(self.head_dim, 1) self.res_weight = nn.Parameter(torch.ones(1)) self.cache_size = getattr(config, "stack_memory_cache_size", 2048) self.cache_position = 0 self.enable_cache = False def reset_cache(self): self.cache_position = 0 def _vectorized_update(self, stack, mask, actions, k_values): batch_size, seq_len = actions.shape[:2] stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1) mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1) push_stack = torch.cat([ k_values.unsqueeze(3), stack[:, :, :, :-1] ], dim=3) push_mask = torch.cat([ torch.ones_like(mask[:, :, :, :1]), mask[:, :, :, :-1] ], dim=3) pop_stack = torch.cat([ stack[:, :, :, 1:], torch.zeros_like(stack[:, :, :, :1]) ], dim=3) pop_mask = torch.cat([ mask[:, :, :, 1:], torch.zeros_like(mask[:, :, :, :1]) ], dim=3) action_weights = actions.unsqueeze(-1).unsqueeze(-1) stacks = torch.stack([push_stack, pop_stack, stack], dim=3) masks = torch.stack([push_mask, pop_mask, mask], dim=3) new_stack = (stacks * action_weights).sum(dim=3) new_mask = (masks * action_weights.squeeze(-1)).sum(dim=3) return new_stack, new_mask def forward(self, hidden_states, stack, mask): batch_size, seq_len, _ = hidden_states.shape new_hidden_states = self.down_proj(hidden_states) action_logits = self.action_head(new_hidden_states) / math.sqrt(self.head_dim) actions = F.softmax( action_logits.view(batch_size, seq_len, self.num_mem_heads, 3), dim=-1 ) k_values = new_hidden_states.view( batch_size, seq_len, self.num_mem_heads, self.head_dim ) new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values) gate_scores = self.gate_proj(new_stack).squeeze(-1) gate_weights = F.softmax(gate_scores + (1 - new_mask) * -80.0, dim=-1) memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3) memory_output = memory_output.view(batch_size, seq_len, -1) memory_output = self.up_proj(memory_output) output = memory_output * self.res_weight + hidden_states if self.training and self.enable_cache: self._update_cache(k_values.detach(), actions.detach()) return output, new_stack[:, -1], new_mask[:, -1] def _update_cache(self, k_values, actions): seq_len = k_values.shape[1] if self.cache_position + seq_len <= self.cache_size: self.k_cache[self.cache_position:self.cache_position + seq_len] = k_values[0] self.action_cache[self.cache_position:self.cache_position + seq_len] = actions[0] self.cache_position += seq_len else: self.reset_cache() def step(self, hidden_state, stack, mask): if not self.enable_cache: return self.forward(hidden_state.unsqueeze(1), stack, mask) if self.cache_position > 0: cached_k = self.k_cache[:self.cache_position] cached_actions = self.action_cache[:self.cache_position] k_values = torch.cat([cached_k.unsqueeze(0), hidden_state], dim=1) actions = torch.cat([ cached_actions.unsqueeze(0), self.action_head(hidden_state).softmax(dim=-1) ], dim=1) else: k_values = hidden_state actions = self.action_head(hidden_state).softmax(dim=-1) new_stack, new_mask = self._vectorized_update( stack.unsqueeze(1), mask.unsqueeze(1), actions.unsqueeze(0), k_values.unsqueeze(0) ) gate_scores = self.gate_proj(new_stack).squeeze(-1) gate_weights = F.softmax(gate_scores + (1 - new_mask) * -80.0, dim=-1) memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3) self._update_cache(k_values, actions) return ( memory_output.squeeze(0) * self.res_weight + hidden_state, new_stack.squeeze(0), new_mask.squeeze(0) ) class LNS(nn.Module): """ LayerNorm Scaling: applies 1/√ℓ to suppress variance growth with depth. From "The Curse of Depth in Large Language Models". """ def __init__(self, layer_idx: int): super().__init__() self.layer_idx = max(layer_idx + 1, 1) self.scale = 1.0 / math.sqrt(self.layer_idx) def forward(self, x: torch.Tensor) -> torch.Tensor: return x * self.scale class GPAS(nn.Module): """Gradient-Preserving Activation Scaling.""" def __init__(self, d_model: int): super().__init__() self.d_model = d_model self.alpha = nn.Parameter(torch.zeros(1)) def forward( self, x: torch.Tensor, ) -> torch.Tensor: silu_alpha = F.silu(self.alpha) subtracted = silu_alpha * x.detach() return x - subtracted def _make_norm(dim: int, eps: float) -> nn.Module: """Build the active Transformer normalization module. The dynamic normalization path has been removed, so every backbone, Q/K, final, and MEA normalization site uses standard RMSNorm. """ return nn.RMSNorm(dim, eps=eps) def _apply_norm( norm: nn.Module, x: torch.Tensor, ) -> torch.Tensor: """Apply RMSNorm and optionally record the normalized output.""" output = norm(x) return output # ==================== ROTARY EMBEDDING ==================== class NeoLLMRotaryEmbedding(nn.Module): inv_freq: torch.Tensor def __init__(self, config: NeoLLMConfig, device=None): super().__init__() self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_type = "default" if (hasattr(config, "rope_scaling") and config.rope_scaling is not None and isinstance(config.rope_scaling, dict)): rope_type = config.rope_scaling.get( "rope_type", config.rope_scaling.get("type") ) if rope_type and rope_type in ROPE_INIT_FUNCTIONS: self.rope_type = rope_type rope_init_fn = self.compute_default_rope_parameters if self.rope_type != "default": rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) @staticmethod def compute_default_rope_parameters( config: NeoLLMConfig = None, device: Optional["torch.device"] = None, seq_len: int = None, ) -> tuple["torch.Tensor", float]: base = config.rope_theta dim = getattr(config, "head_dim", None) or \ config.hidden_size // config.num_attention_heads dim = int(dim * getattr(config, "partial_rotary_factor", 1.0)) inv_freq = 1.0 / ( base ** (torch.arange(0, dim, 2, dtype=torch.int64) .to(device=device, dtype=torch.float) / dim) ) return inv_freq, 1.0 @torch.no_grad() @dynamic_rope_update def forward(self, x, position_ids): if position_ids.dim() == 1: position_ids = position_ids.unsqueeze(0) B = x.shape[0] if position_ids.shape[0] != B: position_ids = position_ids.expand(B, -1) device_type = (x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu") if self.inv_freq.device.type == "meta": inv_freq_data, _ = self.compute_default_rope_parameters( self.config, device=x.device ) self.register_buffer("inv_freq", inv_freq_data, persistent=False) self.register_buffer("original_inv_freq", inv_freq_data.clone(), persistent=False) inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32) with torch.autocast(device_type=device_type, enabled=False): freqs = (position_ids.to(dtype=torch.float32).unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def linear_clipping(x: torch.Tensor) -> torch.Tensor: """ Piecewise-linear activation for Affine-Scaled Attention scaling factors. Avoids the saturation problem of sigmoid, which collapses most outputs toward 0 or 1 and loses the intermediate scaling range the model needs for fine-grained per-query attention modulation. f(x) = 0 if x ≤ -5 = 0.1·x + 0.5 if -5 < x < 5 = 1 if x ≥ 5 Equivalent to: clamp(0.1·x + 0.5, 0, 1). Output range: [0, 1]. Gradient: 0.1 across the entire non-saturated region. Reference: Bae et al. (2026), Affine-Scaled Attention §6. """ return torch.clamp(0.1 * x + 0.5, 0.0, 1.0) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) rotary_dim = cos.shape[-1] q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) return torch.cat([q_embed, q_pass], dim=-1), torch.cat([k_embed, k_pass], dim=-1) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def causal_first_difference(x: torch.Tensor) -> torch.Tensor: return x - F.pad(x[..., :-1, :], (0, 0, 1, 0)) def rms_key_unit_norm(x: torch.Tensor, eps: float) -> torch.Tensor: return F.normalize(x.float(), p=2, dim=-1, eps=eps) * math.sqrt(x.shape[-1]) def infer_key_validity( attention_mask: Optional[torch.Tensor], seq_len: int, num_heads: int ) -> Optional[torch.Tensor]: if attention_mask is None: return None if attention_mask.ndim == 2: if attention_mask.shape[-1] != seq_len: return None valid = attention_mask.to(dtype=torch.bool).unsqueeze(1) # [B, 1, S] else: if attention_mask.ndim != 4: return None if attention_mask.shape[-2] != seq_len or attention_mask.shape[-1] != seq_len: return None diag = attention_mask.diagonal(dim1=-2, dim2=-1) valid = torch.isfinite(diag) & (diag == 0) if valid.shape[1] == 1 and num_heads != 1: valid = valid.expand(-1, num_heads, -1) elif valid.shape[1] != num_heads: valid = valid[:, :1, :].expand(-1, num_heads, -1) return valid def _expand_pair_valid_mask( attention_mask: Optional[torch.Tensor], batch_size: int, num_heads: int, query_len: int, key_len: int, ) -> Optional[torch.Tensor]: """ Return a boolean [B,H,Sq,Sk] mask for positions that attention may read. Supported contracts used in this file: • 2D padding mask [B,Sk], with truthy values for valid keys. • 4D additive mask [B,1|H,Sq,Sk], with 0 on valid pairs and -inf/large negative values on invalid causal, local or padding pairs. """ if attention_mask is None: return None if attention_mask.ndim == 2: if attention_mask.shape[-1] < key_len: return None valid = attention_mask[:, :key_len].to(dtype=torch.bool) # [B,Sk] valid = valid[:, None, None, :].expand(batch_size, num_heads, query_len, key_len) return valid if attention_mask.ndim != 4: return None if attention_mask.shape[-2] < query_len or attention_mask.shape[-1] < key_len: return None valid = attention_mask[:, :, :query_len, :key_len] == 0 if valid.shape[1] == 1 and num_heads != 1: valid = valid.expand(-1, num_heads, -1, -1) elif valid.shape[1] != num_heads: valid = valid[:, :1, :, :].expand(-1, num_heads, -1, -1) return valid def _masked_value_sum_from_pairs( value_expanded: torch.Tensor, pair_valid: torch.Tensor, ) -> torch.Tensor: """Compute Σ_{j∈A(i)} V_j from an explicit pair-valid mask.""" return torch.matmul(pair_valid.to(value_expanded.dtype), value_expanded) def _causal_value_sum( value_states: torch.Tensor, query_len: int, window_left: Optional[int] = None, ) -> torch.Tensor: """ Compute Σ_{j∈A(i)} V_j for causal or sliding-window causal attention. value_states: [B,H,Sk,d]. The returned tensor is [B,H,Sq,d]. For cached decoding, Sq may be smaller than Sk; query rows are aligned to the last Sq key positions, matching causal self-attention with cache. The sliding-window path uses prefix differences with slices, not padded gathers. This keeps the operation exact while avoiding extra index tensors and repeated index_select calls in the hot path. """ key_len = value_states.shape[2] prefix = value_states.cumsum(dim=2) # [B,H,Sk,d] if window_left is None or window_left < 0 or int(window_left) >= key_len: sums_all = prefix else: left = int(window_left) sums_all = prefix.clone() if left + 1 < key_len: # For i > left: # sum_{j=i-left}^{i} V_j = prefix[i] - prefix[i-left-1]. # The first left+1 rows are already the causal prefixes. sums_all[:, :, left + 1 :, :] = ( prefix[:, :, left + 1 :, :] - prefix[:, :, : -(left + 1), :] ) if query_len == key_len: return sums_all start = max(key_len - query_len, 0) return sums_all[:, :, start:start + query_len, :] def _segmented_causal_value_sum( value_states: torch.Tensor, segment_lengths: torch.Tensor, window_left: Optional[int] = None, ) -> torch.Tensor: """ Compute segmented causal/window value sums without a Python loop over documents. Used only for packed IHA where B=1 and position_ids reset. value_states: [1,H,Sk,d] segment_lengths: [num_segments] lengths in the same expanded sequence axis return: [1,H,Sk,d] """ key_len = value_states.shape[2] if key_len == 0: return value_states lengths = segment_lengths.to(device=value_states.device, dtype=torch.long) seg_ends = lengths.cumsum(0) seg_starts = seg_ends - lengths seg_start_per_pos = torch.repeat_interleave(seg_starts, lengths) pos = torch.arange(key_len, device=value_states.device, dtype=torch.long) if window_left is None or window_left < 0: start_idx = seg_start_per_pos else: start_idx = torch.maximum(seg_start_per_pos, pos - int(window_left)) end_idx = pos + 1 prefix = F.pad(value_states.cumsum(dim=2), (0, 0, 1, 0)) # [1,H,Sk+1,d] end_vals = prefix.index_select(2, end_idx) start_vals = prefix.index_select(2, start_idx) return end_vals - start_vals def _affine_valid_value_sum( value_expanded: torch.Tensor, attention_mask: Optional[torch.Tensor], query_len: int, sliding_window: Optional[Union[int, Tuple[int, int]]] = None, ) -> torch.Tensor: """ Compute the β term support S_i = Σ_{j∈A(i)} V_j for Affine attention. It first respects an explicit 4D additive mask exactly. Otherwise it uses causal prefix sums, optionally restricted by a left sliding-window size. """ batch_size, num_heads, key_len, _ = value_expanded.shape pair_valid = None if attention_mask is not None and attention_mask.ndim == 4: pair_valid = _expand_pair_valid_mask( attention_mask, batch_size, num_heads, query_len, key_len ) if pair_valid is not None: return _masked_value_sum_from_pairs(value_expanded, pair_valid) # 2D padding masks zero invalid keys before the prefix/window sum. if attention_mask is not None and attention_mask.ndim == 2: if attention_mask.shape[-1] >= key_len: valid = attention_mask[:, :key_len].to(value_expanded.dtype) value_expanded = value_expanded * valid[:, None, :, None] window_left = None if isinstance(sliding_window, tuple): window_left = int(sliding_window[0]) if len(sliding_window) > 0 else None elif sliding_window is not None: window_left = int(sliding_window) return _causal_value_sum(value_expanded, query_len, window_left) def head_linear_compose( hidden_states: torch.Tensor, mixing_matrix: torch.Tensor ) -> torch.Tensor: return torch.einsum( "bhtd,hk->bktd", hidden_states, mixing_matrix.to(device=hidden_states.device, dtype=hidden_states.dtype), ) def head_linear_compose_pseudo( hidden_states: torch.Tensor, mixing_matrix: torch.Tensor, num_pseudo_heads: int, ) -> torch.Tensor: """ Applies the same MEA head mixing independently inside each IHA pseudo-slot. Layout convention: hidden_states: [B, H_in * P, S, d] with flattened order (head, pseudo) mixing_matrix: [H_in, H_out] returns: [B, H_out * P, S, d] """ if num_pseudo_heads <= 1: return head_linear_compose(hidden_states, mixing_matrix) batch, total_heads, slen, head_dim = hidden_states.shape if total_heads % num_pseudo_heads != 0: raise ValueError( f"IHA+MEA expected total_heads divisible by P, got " f"{total_heads} vs P={num_pseudo_heads}." ) num_component_heads = total_heads // num_pseudo_heads if mixing_matrix.shape[0] != num_component_heads: raise ValueError( f"IHA+MEA expected {num_component_heads} component heads for MEA, " f"got mixing_matrix.shape[0]={mixing_matrix.shape[0]}." ) hidden_states = hidden_states.reshape( batch, num_component_heads, num_pseudo_heads, slen, head_dim ) mixed = torch.einsum( "bhpsd,hk->bkpsd", hidden_states, mixing_matrix.to(device=hidden_states.device, dtype=hidden_states.dtype), ) return mixed.reshape(batch, mixing_matrix.shape[1] * num_pseudo_heads, slen, head_dim) class MEAHeadRMSNorm(nn.Module): """MEA head-level RMS normalization grouped by KV structure (GQA-aware).""" def __init__( self, num_heads: int, head_dim: int, num_kv_groups: int, eps: float = 1e-6, ): super().__init__() self.num_heads = num_heads self.head_dim = head_dim self.num_kv_groups = num_kv_groups self.num_kv_heads = num_heads // num_kv_groups self.group_dim = num_kv_groups * head_dim self.norm = _make_norm(self.group_dim, eps=eps) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch, seq_len, num_heads, head_dim = hidden_states.shape if num_heads != self.num_heads or head_dim != self.head_dim: raise ValueError( f"MEAHeadRMSNorm expected ({self.num_heads}, {self.head_dim}), " f"received ({num_heads}, {head_dim})" ) grouped = hidden_states.reshape(batch, seq_len, self.num_kv_heads, self.group_dim) return self.norm(grouped).reshape(batch, seq_len, num_heads, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(query.dtype) attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous() attn_output = nn.functional.dropout(attn_output, p=dropout, training=module.training) return attn_output, attn_weights def affine_scaled_eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, alpha: torch.Tensor, beta: torch.Tensor, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): """ Affine-Scaled Attention (eager path). Replaces the standard weighted sum softmax(QK^T/√dk) V with: [α(X) · softmax(QK^T/√dk) + β(X)] V α is a per-head, per-query input-dependent scale in [0, 1]. β is an input-dependent bias that compensates for deviations of α from its running average, preventing the effective attention mass from collapsing. Both α and β are computed in NeoLLMAttention.forward and passed in; this function only performs the affine reweighting and the value aggregation. The existing Gated Attention gate (applied post-SDPA to the concatenated output before o_proj) is orthogonal to this and is not modified here. Reference: Bae et al. (2026), Affine-Scaled Attention, Eq. 6–8. Args: alpha: [batch, num_heads, seq_q, 1] — input-dependent scale per query beta: [batch, num_heads, seq_q, 1] — input-dependent bias per query """ key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling pair_valid = _expand_pair_valid_mask( attention_mask, batch_size=query.shape[0], num_heads=key_states.shape[1], query_len=query.shape[-2], key_len=key_states.shape[-2], ) if attention_mask is not None: if attention_mask.ndim == 4: attn_weights = attn_weights + attention_mask[:, :, : query.shape[-2], : key_states.shape[-2]] elif attention_mask.ndim == 2: key_valid = attention_mask[:, : key_states.shape[-2]].to(dtype=torch.bool) additive = torch.zeros_like(attn_weights) additive = additive.masked_fill(~key_valid[:, None, None, :], torch.finfo(attn_weights.dtype).min) attn_weights = attn_weights + additive attn_weights_softmax = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(query.dtype) if pair_valid is not None: attn_weights_softmax = attn_weights_softmax.masked_fill(~pair_valid, 0.0) # Affine reweighting over valid keys only. β must obey the same causal, # local and padding mask as softmax; otherwise invalid V_j would leak in. # Shapes: α, β are [B, H, S_q, 1], weights are [B, H, S_q, S_k]. attn_weights_affine = alpha * attn_weights_softmax + beta if pair_valid is not None: attn_weights_affine = attn_weights_affine.masked_fill(~pair_valid, 0.0) attn_output = torch.matmul(attn_weights_affine, value_states).transpose(1, 2).contiguous() attn_output = nn.functional.dropout(attn_output, p=dropout, training=module.training) return attn_output, attn_weights_affine def affine_scaled_flash_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, alpha: torch.Tensor, beta: torch.Tensor, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): """ Affine-Scaled Attention — flash/sdpa path. Mathematical decomposition of [α·softmax(QKᵀ)+β]V using only the public flash/sdpa interface — no kernel modification required. Derivation ---------- The paper formula expands distributively: [α · softmax(QKᵀ/√dk) + β] · V = α · [softmax(QKᵀ/√dk) · V] ← term 1: standard backend output + β · [Σ_{j∈A(i)} V_j] ← term 2: same valid keys A(i) For global causal attention, A(i) is the prefix j≤i and term 2 is a cumsum. For sliding-window or packed/additive-mask paths, term 2 must use the corresponding windowed or explicit valid-key sum. Dropout ------- The eager path drops entries of the combined weight matrix (α·softmax + β) before multiplying by V. With the flash interface we cannot access that combined matrix, so we apply dropout=0 to the flash kernel and instead apply nn.functional.dropout to the final combined output tensor. This is output dropout rather than weight dropout — a different (but standard) regularisation that achieves the same intent without the intermediate weight matrix. During inference dropout=0 so the paths are identical. Mask handling ------------- The β value-sum must respect the same valid keys as the backend attention. A 2D padding mask zeros invalid keys before prefix/window sums. A 4D additive mask is used directly to compute Σ_{j∈A(i)} V_j exactly. Memory overhead vs standard flash call --------------------------------------- Global causal attention adds one [B, H_q, S, d_head] value-sum tensor. Exact 4D-mask fallback may materialise a [B, H_q, S_q, S_k] boolean mask. Args: alpha: [B, H_q, S_q, 1] — input-dependent scale per query, in [0, 1] beta: [B, H_q, S_q, 1] — moving-average bias per query """ # ── Term 1: standard flash / sdpa output ───────────────────────────── # dropout=0.0: we apply dropout to the combined output below instead. attn_fn = ALL_ATTENTION_FUNCTIONS[module.config._attn_implementation] flash_out, _ = attn_fn( module, query, key, value, attention_mask, dropout=0.0, scaling=scaling, **kwargs, ) # flash_out: [B, S, H_q, d_head] — HF wrappers all return this layout # ── Term 2: β · Σ_{j∈A(i)} V_j ────────────────────────────────────── # Fast path: compute the valid-key value sum in KV-head space, then # broadcast it to query heads during the final affine combine. Since the # sum is linear, WindowSum(repeat_kv(V)) == repeat_kv(WindowSum(V)); this # avoids doing cumsum/window arithmetic on the expanded H_q heads. alpha_t = alpha.permute(0, 2, 1, 3) # [B, S_q, H_q, 1] beta_t = beta.permute(0, 2, 1, 3) # [B, S_q, H_q, 1] if attention_mask is not None and attention_mask.ndim == 4: # Arbitrary 4D additive masks may be expressed in query-head space, so # keep the exact fallback on repeated V. This path is not the regular # causal/window hot path. value_expanded = repeat_kv(value, module.num_key_value_groups) value_sum = _affine_valid_value_sum( value_expanded, attention_mask=attention_mask, query_len=query.shape[-2], sliding_window=kwargs.get("sliding_window", None), ) v_sum_t = value_sum.transpose(1, 2) # [B,S_q,H_q,d] output = alpha_t * flash_out + beta_t * v_sum_t else: # Keep the Affine β branch as a causal prefix summary even when the # attention backend itself uses a local sliding window. This restores # the original IHA/local behavior: local logits plus a cheap global # causal value context. Padding masks are still handled inside # _affine_valid_value_sum; explicit 4D masks above remain exact. value_sum_kv = _affine_valid_value_sum( value, attention_mask=attention_mask, query_len=query.shape[-2], sliding_window=None, ) # [B,H_kv,S_q,d] v_sum_t = value_sum_kv.transpose(1, 2) # [B,S_q,H_kv,d] H_q = flash_out.shape[2] H_kv = v_sum_t.shape[2] groups = module.num_key_value_groups if groups == 1 or H_q == H_kv: output = alpha_t * flash_out + beta_t * v_sum_t else: Bsz, Sq, _, Dh = flash_out.shape flash_g = flash_out.reshape(Bsz, Sq, H_kv, groups, Dh) alpha_g = alpha_t.reshape(Bsz, Sq, H_kv, groups, 1) beta_g = beta_t.reshape(Bsz, Sq, H_kv, groups, 1) output = (alpha_g * flash_g + beta_g * v_sum_t.unsqueeze(3)).reshape( Bsz, Sq, H_q, Dh ) # ── Combine and apply dropout to the full affine output ─────────────── output = nn.functional.dropout(output, p=dropout, training=module.training) # attn_weights is None — flash never exposes the softmax weight matrix. return output, None class HadamardOProj(nn.Module): """ Parameter-free Walsh–Hadamard output projection with learnable affine rescaling. Replaces the dense W_O ∈ R^{d×d} in multi-head attention with a fixed orthogonal Walsh–Hadamard Transform followed by a per-channel learnable affine: output = α ⊙ FWHT(x) + β Motivation (Aggarwal & Kumar, 2026, arXiv:2603.08343): The standard dense o_proj develops extreme condition numbers during training (κ up to 10^5 observed in practice) because the optimiser has no incentive to keep singular values balanced — some directions are amplified while others collapse toward zero. This makes the layer hostile to FP8 quantisation, which uses a single per-tensor scale and therefore loses the low-magnitude directions entirely. The Walsh–Hadamard Transform is a fixed orthogonal matrix whose singular values are all identically 1, making κ = 1 by construction. It cannot develop condition-number pathology because it has no parameters. The learnable α/β restore per-channel expressivity at a cost of 2·d parameters instead of d². Properties: - Condition number: κ = 1 (exact, permanent, by construction) - Parameters: 2·d vs d² for dense (~25% attention params saved) - Forward FLOPs: O(d log d) vs O(d²) for dense - Norm preservation: FWHT is isometric — ‖FWHT(x)‖₂ = ‖x‖₂ - FP8 friendliness: single per-tensor scale covers all directions equally - Requires: d must be a power of 2 The FWHT is implemented as an in-place iterative butterfly (Cooley-Tukey pattern over additions/subtractions) followed by 1/√d normalisation to produce an orthonormal transform (H^T H = I). No external dependency. Reference: Aggarwal, S. & Kumar, L. (2026). "Rethinking Attention Output Projection: Structured Hadamard Transforms for Efficient Transformers." arXiv:2603.08343. """ def __init__(self, dim: int, bias: bool = True): super().__init__() assert dim > 0 and (dim & (dim - 1)) == 0, ( f"HadamardOProj requires dim to be a power of 2, got {dim}" ) self.dim = dim self.norm = dim ** -0.5 # 1/√d — makes H^T H = I # Learnable affine rescaling: α ⊙ FWHT(x) + β # Initialised to α=1, β=0 so the layer starts as a pure WHT, # identical to an orthonormal projection with unit gain. self.alpha = nn.Parameter(torch.ones(dim)) self.beta = nn.Parameter(torch.zeros(dim)) if bias else None def _fwht(self, x: torch.Tensor) -> torch.Tensor: """ Iterative in-place Fast Walsh–Hadamard Transform over the last dim. Butterfly pattern: log₂(d) stages, each pairing elements at stride h. Cost: d·log₂(d) additions/subtractions, zero multiplications. Compatible with torch.compile — all shapes are static, no Python loops visible to the tracer once d is fixed. """ h = 1 while h < self.dim: # Reshape to expose pairs at current stride x = x.reshape(*x.shape[:-1], -1, 2 * h) a, b = x[..., :h], x[..., h:] # Butterfly: (a+b, a-b) — only additions and subtractions x = torch.cat([a + b, a - b], dim=-1) x = x.reshape(*x.shape[:-2], self.dim) h *= 2 return x def forward( self, x: torch.Tensor, ) -> torch.Tensor: """ Args: x: [..., dim] — concatenated multi-head attention outputs Returns: α ⊙ (FWHT(x) / √dim) + β of shape [..., dim] """ out = self._fwht(x) * self.norm # normalise: H^T H = I out = out * self.alpha # per-channel learnable scale if self.beta is not None: out = out + self.beta # per-channel learnable bias return out class REPOModule(nn.Module): """ Context Re-Positioning module f_ϕ (Li et al., 2026, arXiv:2512.14391). Replaces the fixed linear integer indices ``0…L-1`` fed to RoPE with continuous, data-dependent positions ``z_i`` learned end-to-end. Architecture (Eq. 4–6 of the paper): # Position representation — shared across all heads in this layer r_i = Swish(h_i W_g) ⊙ (h_i W_c) r_i ∈ R^{d_p} # Position assignment — independent per head z_i^(h) = r_i w_z^(h) z_i^(h) ∈ R (scalar) where ``h_i ∈ R^d`` is the hidden state of token ``i`` entering the decoder layer (pre-FANLayer), and ``d_p = hidden_size // 8`` by default. The resulting positions ``z [B, H, S]`` are real-valued and unconstrained. They are used to compute per-head ``cos/sin`` embeddings inline from ``inv_freq``, replacing the standard integer-based ``position_embeddings`` for this layer. Design notes: - ``W_g`` and ``W_c`` are shared across heads (parameter efficiency). - ``W_z`` is a single ``[d_p, num_heads]`` matrix; each column is the per-head assignment vector ``w_z^(h)``. Vectorized as one matmul. - The raw hidden state ``h_i`` (not the FAN-augmented or normed variant) is used as input, matching the paper's formulation and avoiding circular dependency with q/k norm. - No bias on any projection — consistent with the paper's Eq. 4–5. Reference: Li, H., Zhao, T., Cai, D. & Sproat, R. (2026). "REPO: Language Models with Context Re-Positioning." arXiv:2512.14391. """ def __init__(self, hidden_size: int, d_p: int, num_heads: int): super().__init__() self.hidden_size = hidden_size self.d_p = d_p self.num_heads = num_heads # SwiGLU position representation (shared across heads, Eq. 4) self.W_g = nn.Linear(hidden_size, d_p, bias=False) self.W_c = nn.Linear(hidden_size, d_p, bias=False) # Per-head position assignment (vectorized, Eq. 5) # W_z[:, h] is w_z^(h) for head h self.W_z = nn.Linear(d_p, num_heads, bias=False) def forward( self, hidden_states: torch.Tensor, ) -> torch.Tensor: """ Args: hidden_states: [B, S, hidden_size] — residual stream entering the decoder layer, before FANLayer augmentation. Returns: z: [B, H, S] — continuous per-head position scalars. z[:, h, i] is the position assigned to token i by head h. """ # Position representation (Eq. 4): Swish(h W_g) ⊙ (h W_c) r = F.silu(self.W_g(hidden_states)) * self.W_c(hidden_states) # [B, S, d_p] # Per-head assignment (Eq. 5): z^(h) = r W_z[:, h] # W_z output: [B, S, H] → transpose to [B, H, S] z = self.W_z(r).transpose(1, 2).contiguous() # [B, H, S] return z def _apply_repo_rope( q: torch.Tensor, k: torch.Tensor, z: torch.Tensor, inv_freq: torch.Tensor, attention_scaling: float, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply RoPE to Q and K using continuous per-head positions from REPO. Replaces the standard ``apply_rotary_pos_emb(q, k, cos, sin)`` call for layers where REPO is active. Builds ``cos/sin`` inline from ``z`` and ``inv_freq`` so that the rotation is differentiable w.r.t. ``z`` and therefore w.r.t. the parameters of REPOModule. Args: q: [B, H, S, head_dim] k: [B, H_kv, S, head_dim] (GQA: H_kv ≤ H) z: [B, H_repo, S] — per-head positions from REPOModule. Under IHA(P>1), pseudo-heads inherit the continuous position of their parent query head before RoPE. inv_freq: [rotary_dim/2] — frozen RoPE frequency vector attention_scaling: float — scaling factor from NeoLLMRotaryEmbedding Returns: (q_embed, k_embed) with the same shapes as (q, k). Implementation note on GQA: Q has ``num_attention_heads`` heads; K/V have ``num_key_value_heads`` heads (fewer under GQA). REPO produces one position per Q head. For K we average the positions of the Q heads that map to each KV head (groups of size ``num_key_value_groups``). This is the minimal approach consistent with the paper's per-head independence claim: each KV head receives a position that is representative of the Q heads it serves. """ B, H_repo, S = z.shape H_q_eff = q.shape[1] H_k_eff = k.shape[1] if H_q_eff % H_repo != 0: raise ValueError( f"REPO expected q heads divisible by z heads, got " f"H_q={H_q_eff} vs H_repo={H_repo}." ) P = H_q_eff // H_repo if H_k_eff % max(P, 1) != 0: raise ValueError( f"REPO expected k heads divisible by pseudo factor P, got " f"H_k={H_k_eff} vs P={P}." ) # Keep the logical factorization (original_query_head, pseudo_slot) # explicit so both Q and K receive positions in the same IHA pseudo-slot. z_q_struct = z.unsqueeze(2).expand(B, H_repo, P, S) # [B, H_q0, P, S] H_k_base = H_k_eff // max(P, 1) if H_repo % H_k_base != 0: raise ValueError( f"REPO expected original query heads divisible by key head groups, got " f"H_repo={H_repo} vs H_k_base={H_k_base}." ) q_per_k_base = H_repo // H_k_base z_q = z_q_struct.reshape(B, H_q_eff, S) # [B, H_q_eff, S] z_k_struct = z_q_struct.view(B, H_k_base, q_per_k_base, P, S).mean(dim=2) z_k_heads = z_k_struct.reshape(B, H_k_eff, S) # [B, H_k_eff, S] H_q = z_q.shape[1] H_kv = z_k_heads.shape[1] if H_q % H_kv != 0: raise ValueError( f"REPO expected query heads divisible by key heads, got H_q={H_q} vs H_kv={H_kv}." ) rotary_dim = inv_freq.shape[0] * 2 # inv_freq covers half the rotary dim # inv_freq arrives from rotary_emb at forward time via repo_rope_args — # already float32 on the correct device, no .to() needed, no DeviceCopy op. # No autocast barrier: explicit .float() casts on z_q/z_k are sufficient # to maintain float32 precision for the trig ops. Removing the context # manager lets Inductor plan all intermediate tensors as part of a single # static memory graph, eliminating mid-forward allocations that cause # VRAM variance under max-autotune. inv_freq_f = inv_freq # z_q: [B, H_q, S, 1] × inv_freq: [rotary_dim/2] → [B, H_q, S, rotary_dim/2] z_q = z_q.float().unsqueeze(-1) # [B, H_q, S, 1] freqs_q = z_q * inv_freq_f # [B, H_q, S, r/2] emb_q = torch.cat([freqs_q, freqs_q], dim=-1) # [B, H, S, r] cos_q = (emb_q.cos() * attention_scaling).to(q.dtype) sin_q = (emb_q.sin() * attention_scaling).to(q.dtype) # KV positions: average the original query heads that map to each key head, # independently inside each IHA pseudo-slot. z_k = z_k_heads.float().unsqueeze(-1) # [B, H_kv, S, 1] freqs_k = z_k * inv_freq_f # [B, H_kv, S, r/2] emb_k = torch.cat([freqs_k, freqs_k], dim=-1) # [B, H_kv, S, r] cos_k = (emb_k.cos() * attention_scaling).to(k.dtype) sin_k = (emb_k.sin() * attention_scaling).to(k.dtype) # Rotate only the first rotary_dim channels; pass the rest through unchanged. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] q_embed = torch.cat( [(q_rot * cos_q) + (rotate_half(q_rot) * sin_q), q_pass], dim=-1 ) k_embed = torch.cat( [(k_rot * cos_k) + (rotate_half(k_rot) * sin_k), k_pass], dim=-1 ) return q_embed, k_embed class RepoGrapePositioning(nn.Module): r""" Minimal REPO-GRAPE-M positional operator. This module intentionally keeps only the specific operator from the proposal: 1. REPO predicts contextual coordinates z_i^(h)=f_phi^(h)(h_i). 2. REPO-GRAPE uses u_i^(h)=z_i^(h), with no extra position mode. 3. GRAPE-M applies canonical commuting SO(d) rotary planes to Q/K. 4. Each query head learns its own angular spectrum: theta_{h,r}=inv_freq_r*exp(s_{h,r}). Therefore (G_h(u_i)q_i)^T(G_h(u_j)k_j) = q_i^T G_h(u_j-u_i) k_j, preserving the relative GRAPE law while staying compatible with the existing attention backend. """ def __init__( self, config: NeoLLMConfig, layer_idx: int, num_attention_heads: int, num_key_value_heads: int, head_dim: int, ): super().__init__() self.config = config self.layer_idx = layer_idx self.num_attention_heads = int(num_attention_heads) self.num_key_value_heads = int(num_key_value_heads) self.head_dim = int(head_dim) self.max_rot_half = self.head_dim // 2 # The single REPO-GRAPE parameter: a learned log-scale over the # base RoPE/GRAPE frequencies, per query head and rotary plane. self.freq_log_scale = nn.Parameter( torch.zeros(self.num_attention_heads, self.max_rot_half, dtype=torch.float32), requires_grad=True, ) def _apply(self, fn): # Keep the small spectral parameters in fp32 under model.to(dtype). super()._apply(fn) p = self.freq_log_scale p.data = p.data.float() if p.grad is not None: p.grad.data = p.grad.data.float() return self def transform_positions( self, z: torch.Tensor, ) -> torch.Tensor: """Use the REPO coordinate directly: u_i^(h)=z_i^(h).""" return z def _query_freq( self, inv_freq: torch.Tensor, H_repo: int, ) -> torch.Tensor: rot_half = int(inv_freq.shape[0]) if rot_half > self.max_rot_half: raise ValueError( f"REPO-GRAPE rotary_dim/2={rot_half} exceeds head_dim/2={self.max_rot_half}." ) if H_repo != self.num_attention_heads: raise ValueError( f"REPO-GRAPE expected H_repo={self.num_attention_heads}, got {H_repo}." ) base = inv_freq.float().view(1, rot_half) log_scale = self.freq_log_scale[:, :rot_half].float() freq = base * torch.exp(log_scale) freq = freq.to(device=inv_freq.device) return freq def apply_multiplicative( self, q: torch.Tensor, k: torch.Tensor, z: torch.Tensor, inv_freq: torch.Tensor, attention_scaling: float, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply REPO-GRAPE-M rotation to Q/K using contextual REPO coordinates. Args: q: [B, H_q_eff, S, head_dim] k: [B, H_k_eff, S, head_dim] z: [B, H_q_base, S] inv_freq: [rotary_dim/2] attention_scaling: scaling from NeoLLMRotaryEmbedding """ B, H_repo, S = z.shape H_q_eff = q.shape[1] H_k_eff = k.shape[1] if H_q_eff % H_repo != 0: raise ValueError( f"REPO-GRAPE expected q heads divisible by z heads, " f"got H_q={H_q_eff} vs H_repo={H_repo}." ) P = H_q_eff // H_repo if H_k_eff % max(P, 1) != 0: raise ValueError( f"REPO-GRAPE expected k heads divisible by pseudo factor P, " f"got H_k={H_k_eff} vs P={P}." ) z_q_struct = z.unsqueeze(2).expand(B, H_repo, P, S) # [B,H_q_base,P,S] H_k_base = H_k_eff // max(P, 1) if H_repo % H_k_base != 0: raise ValueError( f"REPO-GRAPE expected original query heads divisible by key groups, " f"got H_repo={H_repo} vs H_k_base={H_k_base}." ) q_per_k_base = H_repo // H_k_base z_q = z_q_struct.reshape(B, H_q_eff, S) z_k = ( z_q_struct .view(B, H_k_base, q_per_k_base, P, S) .mean(dim=2) .reshape(B, H_k_eff, S) ) rot_half = int(inv_freq.shape[0]) rotary_dim = rot_half * 2 freq_repo = self._query_freq(inv_freq, H_repo) freq_q_struct = freq_repo.unsqueeze(1).expand(H_repo, P, rot_half) freq_q = freq_q_struct.reshape(H_q_eff, rot_half) freq_k = ( freq_q_struct .view(H_k_base, q_per_k_base, P, rot_half) .mean(dim=1) .reshape(H_k_eff, rot_half) ) freqs_q = z_q.float().unsqueeze(-1) * freq_q.to(device=q.device).view( 1, H_q_eff, 1, rot_half ) freqs_k = z_k.float().unsqueeze(-1) * freq_k.to(device=k.device).view( 1, H_k_eff, 1, rot_half ) emb_q = torch.cat([freqs_q, freqs_q], dim=-1) emb_k = torch.cat([freqs_k, freqs_k], dim=-1) cos_q = (emb_q.cos() * attention_scaling).to(q.dtype) sin_q = (emb_q.sin() * attention_scaling).to(q.dtype) cos_k = (emb_k.cos() * attention_scaling).to(k.dtype) sin_k = (emb_k.sin() * attention_scaling).to(k.dtype) q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] q_embed = torch.cat( [(q_rot * cos_q) + (rotate_half(q_rot) * sin_q), q_pass], dim=-1, ) k_embed = torch.cat( [(k_rot * cos_k) + (rotate_half(k_rot) * sin_k), k_pass], dim=-1, ) return q_embed, k_embed class RepoGoatPrior(nn.Module): r""" Factorised GOAT-style attention log-prior for NeoLLM. GOAT interprets an additive attention bias as a log-prior inside the row-wise KL/EOT objective: p_i = softmax(s_i / tau + log pi_i). This module keeps that term separate from the REPO-GRAPE-M geometric score. Instead of materialising K_ij as a dense [S,S] matrix, it appends a small positional subspace to Q and K: [Q | Q_prior] [K | K_prior]^T / sqrt(d_head) = QK^T / sqrt(d_head) + K_prior_logits. V is padded with zeros in the appended dimensions and the attention output is sliced back to head_dim by the caller, so downstream modules see exactly the same shape whether the prior is enabled or disabled. Components ---------- 1. Recency key-only prior: q_r = lambda_h, k_r = c_j This is equivalent to lambda_h(c_j - c_i) modulo a per-row constant. 2. Relative Fourier prior: alpha_h,r cos(omega_r(c_i-c_j)) + beta_h,r sin(omega_r(c_j-c_i)) represented by two separable coordinates per frequency. 3. Sink key-only prior: q_s = rho_h, k_s = exp(-softplus(delta_h) c_j) This gives an explicit, controllable sink profile instead of forcing the model to create sinks indirectly through content norms. Initialisation is an exact no-op: all amplitudes are zero, while the key-side bases remain non-zero so gradients immediately reach the amplitudes. Parameters are kept in fp32 under model.to(dtype), matching the REPO-GRAPE spectral parameters. """ def __init__( self, config: NeoLLMConfig, layer_idx: int, num_attention_heads: int, head_dim: int, ): super().__init__() self.config = config self.layer_idx = layer_idx self.num_attention_heads = int(num_attention_heads) self.head_dim = int(head_dim) self.num_frequencies = int(getattr(config, "repo_goat_num_frequencies", 3)) self.prior_dim = 2 * self.num_frequencies + 2 self.sqrt_head_dim = math.sqrt(float(self.head_dim)) # Amplitudes live on the query side so each query head can learn a # different prior even under GQA, where K/V heads are shared by groups. self.recency_slope = nn.Parameter(torch.zeros(self.num_attention_heads, dtype=torch.float32)) self.rel_alpha = nn.Parameter( torch.zeros(self.num_attention_heads, self.num_frequencies, dtype=torch.float32) ) self.rel_beta = nn.Parameter( torch.zeros(self.num_attention_heads, self.num_frequencies, dtype=torch.float32) ) self.sink_strength = nn.Parameter(torch.zeros(self.num_attention_heads, dtype=torch.float32)) # Learnable positive sink decay, initialized to the configured value. sink_decay = float(getattr(config, "repo_goat_sink_decay", 4.0)) inv_softplus = math.log(math.exp(sink_decay) - 1.0) self.sink_log_decay = nn.Parameter( torch.full((self.num_attention_heads,), inv_softplus, dtype=torch.float32) ) if self.num_frequencies > 0: base_freq = 2.0 * math.pi * torch.arange( 1, self.num_frequencies + 1, dtype=torch.float32 ) else: base_freq = torch.empty(0, dtype=torch.float32) self.register_buffer("prior_freq", base_freq, persistent=False) def _apply(self, fn): # Keep small prior amplitudes in fp32 under model.to(dtype). super()._apply(fn) for p in ( self.recency_slope, self.rel_alpha, self.rel_beta, self.sink_strength, self.sink_log_decay, ): p.data = p.data.float() if p.grad is not None: p.grad.data = p.grad.data.float() self.prior_freq = self.prior_freq.float() return self def _normalised_positions( self, seq_len: int, batch_size: int, device: torch.device, position_ids: Optional[torch.Tensor], ) -> torch.Tensor: """ Return a differentiable-free causal coordinate c in [0,1]. If position_ids are supplied, they are respected, including the IHA seq-expanded case where the attention sequence length is S*P but position_ids still describe the original S tokens. When no compatible position_ids exist, fall back to arange(seq_len). """ pos = None if position_ids is not None: pid = position_ids if pid.dim() == 1: pid = pid.unsqueeze(0) elif pid.dim() > 2: pid = pid[..., 0] base_len = pid.shape[-1] if base_len == seq_len: pos = pid elif base_len > 0 and seq_len % base_len == 0: P = seq_len // base_len offsets = torch.arange(P, device=pid.device, dtype=pid.dtype) pos = (pid.unsqueeze(-1) * P + offsets.view(1, 1, P)).reshape( pid.shape[0], seq_len ) if pos is None: pos = torch.arange(seq_len, device=device, dtype=torch.float32).view(1, seq_len) else: pos = pos.to(device=device, dtype=torch.float32) if pos.shape[0] == 1 and batch_size != 1: pos = pos.expand(batch_size, -1) elif pos.shape[0] != batch_size: pos = pos[:1].expand(batch_size, -1) pos = pos - pos.amin(dim=-1, keepdim=True) denom = pos.amax(dim=-1, keepdim=True).clamp_min(1.0) return (pos / denom).clamp(0.0, 1.0) def append_prior_subspace( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, position_ids: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: """ Append prior dimensions to q/k/v and return the appended width. Args: q: [B, H_q, S_q, head_dim] k: [B, H_kv, S_k, head_dim] v: [B, H_kv, S_k, head_dim] Returns: q_ext, k_ext, v_ext, prior_dim """ B, H_q, S_q, _ = q.shape H_kv, S_k = k.shape[1], k.shape[2] if H_q != self.num_attention_heads: raise ValueError( f"REPO-GOAT expected H_q={self.num_attention_heads}, got {H_q}." ) c_q = self._normalised_positions(S_q, B, q.device, position_ids) c_k = self._normalised_positions(S_k, B, k.device, position_ids) q_prior = q.new_zeros(B, H_q, S_q, self.prior_dim) k_prior = k.new_zeros(B, H_kv, S_k, self.prior_dim) c_q_f = c_q.float().unsqueeze(1) # [B,1,S_q] c_k_f = c_k.float().unsqueeze(1) # [B,1,S_k] scale = self.sqrt_head_dim # 0: recency. softmax ignores the missing row-constant -lambda_h c_i, # so lambda_h*c_j is distributionally equivalent to lambda_h(c_j-c_i). q_prior[..., 0] = ( self.recency_slope.float().view(1, H_q, 1) * scale ).to(q.dtype) k_prior[..., 0] = c_k_f.expand(B, H_kv, S_k).to(k.dtype) # 1..2R: relative Fourier kernel. Coefficients are query-head specific; # key-side features are shared over KV heads, preserving GQA semantics. if self.num_frequencies > 0: omega = self.prior_freq.to(device=q.device, dtype=torch.float32).view(1, 1, 1, -1) angle_q = c_q_f.unsqueeze(-1) * omega # [B,1,S_q,R] angle_k = c_k_f.unsqueeze(-1) * omega # [B,1,S_k,R] cos_q = angle_q.cos() sin_q = angle_q.sin() cos_k = angle_k.cos() sin_k = angle_k.sin() alpha = self.rel_alpha.float().view(1, H_q, 1, self.num_frequencies) beta = self.rel_beta.float().view(1, H_q, 1, self.num_frequencies) q_cos = (alpha * cos_q - beta * sin_q) * scale q_sin = (alpha * sin_q + beta * cos_q) * scale q_prior[..., 1 : 1 + 2 * self.num_frequencies : 2] = q_cos.to(q.dtype) q_prior[..., 2 : 1 + 2 * self.num_frequencies : 2] = q_sin.to(q.dtype) k_prior[..., 1 : 1 + 2 * self.num_frequencies : 2] = cos_k.expand( B, H_kv, S_k, self.num_frequencies ).to(k.dtype) k_prior[..., 2 : 1 + 2 * self.num_frequencies : 2] = sin_k.expand( B, H_kv, S_k, self.num_frequencies ).to(k.dtype) # Last coordinate: key-only sink profile. sink_decay = F.softplus(self.sink_log_decay.float()).view(1, H_q, 1) sink_profile_k = torch.exp(-F.softplus(self.sink_log_decay.float()).mean() * c_k_f) q_prior[..., -1] = ( self.sink_strength.float().view(1, H_q, 1) * scale ).to(q.dtype) k_prior[..., -1] = sink_profile_k.expand(B, H_kv, S_k).to(k.dtype) # Value prior channels are zeros: prior affects weights, never values. v_prior = v.new_zeros(B, v.shape[1], v.shape[2], self.prior_dim) q_ext = torch.cat([q, q_prior], dim=-1).contiguous() k_ext = torch.cat([k, k_prior], dim=-1).contiguous() v_ext = torch.cat([v, v_prior], dim=-1).contiguous() return q_ext, k_ext, v_ext, self.prior_dim class NeoLLMAttention(nn.Module): """ Full attention with FANformer, RMSNorm, ResFormer, Learnable Multipliers, optional Momentum, MEA head-level composition, optional LUCID preconditioning, optional Affine-Scaled Attention, optional Exclusive Self Attention, optional Directional Routing (Taylor, 2026), optional Context Re-Positioning (Li et al., 2026), optional REPO-GRAPE contextual group positioning (Li et al., 2026 + Zhang et al., 2026), and optional REPO-GOAT factorised log-priors (Litman & Guo, 2026). Directional Routing inserts at position C — post-XSA, pre-reshape — where the output is already normalized (MEAHeadRMSNorm) and has auto-position removed (XSA). The router suppresses directions of cross-domain interference orthogonal to the self-position already cleaned by XSA. Pipeline (all active simultaneously when enabled): FANLayer → q_proj(gate) → q_norm/k_norm → REPO/RoPE or REPO-GRAPE-M → Momentum → MEA(K,V) → LUCID(V) → v_ref → optional GOAT log-prior → Affine-Scaled SDPA → MEAHeadRMSNorm → XSA → Directional Routing → reshape → o_proj · sigmoid(gate) → dropout RoPE variants (controlled by config.use_repo and layer_idx): use_repo=False (default): standard integer RoPE via pre-computed position_embeddings — identical to prior behaviour. use_repo=True, layer_idx >= repo_start_layer: REPOModule f_ϕ predicts continuous per-head positions z [B, H, S] from hidden_states. cos/sin are built inline from z and inv_freq so the rotation is differentiable w.r.t. f_ϕ parameters. use_repo=True, layer_idx < repo_start_layer: standard integer RoPE (lower layers capture surface features that benefit less from re-positioning). o_proj variants (controlled by config.use_hadamard_o_proj): False (default): dense LinearWithMultipliers — full expressivity, with Learnable Multipliers controlled by config.use_learnable_multipliers; develops high κ during training (FP8 risk). True: HadamardOProj — fixed WHT + learnable α/β, κ = 1 by construction, 25% fewer attention params, FP8-friendly (Aggarwal & Kumar, 2026, arXiv:2603.08343). References: Directional Routing: Taylor (2026). arXiv:2603.14923. Hadamard o_proj: Aggarwal & Kumar (2026). arXiv:2603.08343. Context Re-Positioning: Li et al. (2026). arXiv:2512.14391. """ def __init__(self, config: NeoLLMConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr( config, "head_dim", config.hidden_size // config.num_attention_heads ) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim ** -0.5 self.sqrt_head_dim = math.sqrt(self.head_dim) self.attention_dropout = config.attention_dropout self.is_causal = True self.use_momentum_attention = getattr(config, "use_momentum_attention", False) self.momentum_gamma = float(getattr(config, "momentum_gamma", 0.0)) self.use_mea_attention = getattr(config, "use_mea_attention", False) self.mea_component_key_value_heads = int( getattr(config, "mea_component_key_value_heads", config.num_key_value_heads) ) self.mea_groupnorm_eps = float( getattr(config, "mea_groupnorm_eps", config.rms_norm_eps) ) self.use_lucid_attention = getattr(config, "use_lucid_attention", False) self.lucid_attention_eps = float( getattr(config, "lucid_attention_eps", config.rms_norm_eps) ) self.use_hadamard_o_proj = getattr(config, "use_hadamard_o_proj", False) self.use_learnable_multipliers = getattr(config, "use_learnable_multipliers", True) self.fan_layer = FANLayer( hidden_size=config.hidden_size, fan_ratio=getattr(config, "fan_ratio", 0.125), ) fan_output_dim = config.hidden_size + int( config.hidden_size * getattr(config, "fan_ratio", 0.125) ) self.q_proj = LinearWithMultipliers( fan_output_dim, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias, use_row_multiplier=True, use_column_multiplier=False, enable_multipliers=self.use_learnable_multipliers, ) self.num_mea_component_heads = ( self.mea_component_key_value_heads if self.use_mea_attention else config.num_key_value_heads ) self.k_proj = nn.Linear( fan_output_dim, self.num_mea_component_heads * self.head_dim, bias=config.attention_bias, ) self.v_proj = nn.Linear( fan_output_dim, self.num_mea_component_heads * self.head_dim, bias=config.attention_bias, ) # ── Output projection (Aggarwal & Kumar, 2026, arXiv:2603.08343) ──── # use_hadamard_o_proj=False (default): dense LinearWithMultipliers. # use_hadamard_o_proj=True: HadamardOProj — fixed WHT + learnable α/β. # κ = 1 by construction, 25% fewer attention params, FP8-friendly. # Requires hidden_size to be a power of 2 (512 ✓, 1024 ✓, 768 ✗). _o_in = config.num_attention_heads * self.head_dim if self.use_hadamard_o_proj: assert _o_in == config.hidden_size, ( f"HadamardOProj requires in_dim == out_dim, " f"got {_o_in} vs {config.hidden_size}" ) self.o_proj = HadamardOProj(config.hidden_size, bias=config.attention_bias) else: self.o_proj = LinearWithMultipliers( _o_in, config.hidden_size, bias=config.attention_bias, use_row_multiplier=True, use_column_multiplier=True, enable_multipliers=self.use_learnable_multipliers, ) self.q_norm = _make_norm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = _make_norm(self.head_dim, eps=config.rms_norm_eps) if self.use_mea_attention: self.mea_key_mix = nn.Parameter( torch.eye(self.num_mea_component_heads, config.num_key_value_heads) ) self.mea_value_mix = nn.Parameter( torch.eye(self.num_mea_component_heads, config.num_key_value_heads) ) self.mea_output_norm = MEAHeadRMSNorm( num_heads=config.num_attention_heads, head_dim=self.head_dim, num_kv_groups=self.num_key_value_groups, eps=self.mea_groupnorm_eps, ) else: self.mea_key_mix = None self.mea_value_mix = None self.mea_output_norm = None self.dropout = nn.Dropout(config.dropout_rate) self.use_fan_residual = getattr(config, "use_fan_residual", True) if self.use_fan_residual: self.lambda_1 = nn.Parameter(torch.tensor(0.5)) self.lambda_2 = nn.Parameter(torch.tensor(0.5)) else: self.lambda_1 = None self.lambda_2 = None # ── Affine-Scaled Attention (Bae et al., 2026) ─────────────────────── self.use_affine_scaled_attention = getattr(config, "use_affine_scaled_attention", False) self.affine_momentum = float(getattr(config, "affine_momentum", 0.9)) # ── Exclusive Self Attention (Zhai, 2026) ──────────────────────────── self.use_xsa = getattr(config, "use_xsa", False) self.xsa_eps = float(getattr(config, "xsa_eps", 1e-6)) if self.use_affine_scaled_attention: self.alpha_proj = nn.Linear( config.hidden_size, config.num_attention_heads, bias=False ) self.register_buffer( "alpha_ma", torch.zeros(1, config.num_attention_heads, 1, 1), persistent=True, ) # ── Directional Routing (Taylor, 2026) ─────────────────────────────── # Each attention head learns K unit-norm direction vectors in head-space. # A shared 4-layer MLP router — conditioned on the mean-pooled sequence # representation — produces per-input sigmoid weights r_{h,k} ∈ [0,1] # that control how much of each directional component is suppressed from # the head's output after XSA (position C): # # o'_h = o_h - Σ_k r_{h,k} · (o_h · d_{h,k}) · d_{h,k} # # Position C (post-XSA, pre-reshape) is chosen because: # - XSA already removed auto-position (self-position noise). # - MEAHeadRMSNorm already normalized the output — suppression has # predictable magnitude since d_{h,k} is unit-norm. # - Directions live in head-space d_head before o_proj, preserving # the vocabulary projection interpretability of the paper. # - No interaction with the Gated Attention gate (applied post o_proj). # # When use_xsa=False, position C reduces to post-RMSNorm, pre-reshape — # routing still applies correctly, directions just also span the # self-position subspace (no XSA cleaned it first). # # Router: mean-pools hidden_states (pre-FAN residual stream) over S, # passes through 4-layer MLP, outputs H×K logits, temperature-scaled # sigmoid → r_{h,k}. Temperature T=5.0 pushes weights toward {0,1}. # The router is shared across all heads within this layer, exactly as # in the paper. No auxiliary loss — learns from LM objective only. # # direction_vecs: [H, K, d_head] — unit-normalized in forward, not init. # Initialized from normal(0, 1) and normalized at first forward pass. self.use_directional_routing = getattr(config, "use_directional_routing", False) self.directional_routing_k = int(getattr(config, "directional_routing_k", 4)) self.directional_routing_temp = float(getattr(config, "directional_routing_temp", 5.0)) if self.use_directional_routing: H = config.num_attention_heads K = self.directional_routing_k D = self.head_dim R = config.hidden_size # router hidden dim # Direction vectors: [H, K, d_head] # Stored unnormalized — unit-normalized during forward. self.direction_vecs = nn.Parameter( torch.randn(H, K, D) ) # 4-layer MLP router shared across heads within this layer. # Input: mean-pooled hidden_states [B, hidden_size] # Output: [B, H*K] → reshape [B, H, K] → sigmoid(T·x) → r_{h,k} # Intermediate dim = hidden_size throughout, matching the paper. self.direction_router = nn.Sequential( nn.RMSNorm(R, eps=config.rms_norm_eps), nn.Linear(R, R, bias=True), nn.GELU(), nn.Linear(R, R, bias=True), nn.GELU(), nn.Linear(R, R, bias=True), nn.GELU(), nn.Linear(R, H * K, bias=True), ) else: self.direction_vecs = None self.direction_router = None # ── Context Re-Positioning / REPO-GRAPE ───────────────────────────── # REPO-GRAPE reuses the REPO coordinate module f_phi and replaces only # the positional action applied to Q/K. Enabling use_repo_grape=True # therefore activates the REPO coordinate path even when use_repo=False. _repo_requested = bool(getattr(config, "use_repo", False)) _repo_grape_requested = bool(getattr(config, "use_repo_grape", False)) _repo_start_layer = getattr(config, "repo_start_layer", config.num_hidden_layers // 3) self.use_repo = ( (_repo_requested or _repo_grape_requested) and layer_idx >= _repo_start_layer ) self.use_repo_grape = _repo_grape_requested and self.use_repo if self.use_repo: _d_p = getattr(config, "repo_d_p", config.hidden_size // 8) self.repo_module = REPOModule( hidden_size=config.hidden_size, d_p=_d_p, num_heads=config.num_attention_heads, ) else: self.repo_module = None if self.use_repo_grape: self.repo_grape = RepoGrapePositioning( config=config, layer_idx=layer_idx, num_attention_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, head_dim=self.head_dim, ) else: self.repo_grape = None # ── REPO-GOAT factorised log-prior ─────────────────────────────────── # Independent of use_repo/use_repo_grape: it can be ablated with standard # RoPE or composed with REPO-GRAPE-M. It activates only from # repo_start_layer onward to mirror the REPO positional stack schedule. self.use_repo_goat_prior = ( bool(getattr(config, "use_repo_goat_prior", False)) and layer_idx >= _repo_start_layer ) if self.use_repo_goat_prior: self.repo_goat_prior = RepoGoatPrior( config=config, layer_idx=layer_idx, num_attention_heads=config.num_attention_heads, head_dim=self.head_dim, ) else: self.repo_goat_prior = None # ── Interleaved Head Attention (Duvvuri et al., 2026) ───────────────── # Implementa cross-head mixing antes de RoPE: cada pseudo-head h,p # es combinación lineal aprendida de los H input-heads originales. # P=1: solo mixing (misma forma), compatible con todos los flags. # P>1: expande la secuencia real S → S·P y colapsa vía R después # de SDPA. num_key_value_groups se preserva porque Q y KV se # expanden en la dimensión de secuencia, no en la de heads. # # Schedule local-global: # 'L' → capa local con IHA y ventana W=N/(2P²). # 'G' → capa global de transporte. Por defecto NO usa IHA, porque # el paper empareja FLOPs como 4 locales IHA + 1 global estándar. # Activa config.iha_global_layers_use_iha=True solo si quieres # la ablación más cara donde la G también ejecuta global IHA. # # Si una capa 'G' no usa IHA, no se instancian iha_alpha_q/k/v ni R # en esa capa. Así el flag desactiva compute, parámetros y grafo. # Init identidad: IHA ≡ MHA en paso 0 (Teorema 2, inclusión M⊆P_P). # Referencia: Duvvuri et al. (2026), arXiv:2602.21371. self.iha_P = int(getattr(config, "iha_num_pseudo_heads", 1)) _base_use_iha = bool(getattr(config, "use_iha", False)) _pattern = list(getattr(config, "iha_local_global_pattern", "LLLLG").upper()) _pos = layer_idx % max(len(_pattern), 1) self.iha_schedule_token = _pattern[_pos] if _pattern else "G" self.iha_global_layers_use_iha = bool( getattr(config, "iha_global_layers_use_iha", False) ) # Per-layer activation: # local slots always use IHA when global use_iha=True; # global slots use IHA only under the explicit ablation flag. self.iha_layer_uses_iha = bool( _base_use_iha and ( self.iha_schedule_token == "L" or (self.iha_schedule_token == "G" and self.iha_global_layers_use_iha) ) ) self.use_iha = self.iha_layer_uses_iha if self.use_iha: _H_q = config.num_attention_heads _H_kv = self.num_mea_component_heads _P = self.iha_P # alpha_q[h_out, h_in, p]: mezcla de h_in sobre todos los heads # → pseudo p de head h_out. # En K/V usamos H_comp (pre-MEA), de modo que MEA pueda recomponer # H_comp → H_kv independientemente dentro de cada pseudo-slot. self.iha_alpha_q = nn.Parameter(torch.zeros(_H_q, _H_q, _P)) self.iha_alpha_k = nn.Parameter(torch.zeros(_H_kv, _H_kv, _P)) self.iha_alpha_v = nn.Parameter(torch.zeros(_H_kv, _H_kv, _P)) # R[h, p]: colapsa el slot p del pseudo-token hacia head h. # Shape [H_q, P] — colapso sobre la dimensión de pseudo-slots en # la secuencia expandida (paper Alg. 1 Step 8: einsum 'hp,hnpd→hnd'). # Init identidad: slot p=0 pasa íntegro, resto cero → M⊆P_P (Teo. 2). self.iha_R = nn.Parameter(torch.zeros(_H_q, _P)) # Init identidad: IHA ≡ MHA en paso 0 (Teorema 2, inclusión M⊆P_P). with torch.no_grad(): for _h in range(_H_q): self.iha_alpha_q.data[_h, _h, 0] = 1.0 self.iha_R.data[_h, 0] = 1.0 # slot 0 → identidad for _h in range(_H_kv): self.iha_alpha_k.data[_h, _h, 0] = 1.0 self.iha_alpha_v.data[_h, _h, 0] = 1.0 # ── Schedule local-global (paper §5.1 / Appendix C) ────────────── # El pattern "LLLLG" repite cada len(pattern) capas: # 'L' → capa local IHA con sliding window (FLOP-cheap). # 'G' → capa global. Con iha_global_layers_use_iha=False esta # rama no llega aquí, porque la capa G no instancia IHA. # Con True, la G sí llega aquí y usa IHA global full attention. # Solo relevante cuando P>1; con P=1 no hay expansión. self.iha_is_local = ( self.iha_P > 1 and self.iha_schedule_token == "L" ) # Tamaño de ventana W para capas locales. # Si el usuario fija iha_sliding_window, se usa ese valor. # Si no, se resuelve en forward como W = N/(2P²) usando la # longitud real S del batch, siguiendo la formulación exacta # del paper para el schedule local-global. _cfg_window = getattr(config, "iha_sliding_window", None) self.iha_window = ( int(_cfg_window) if _cfg_window is not None else None ) # ── Flash-attn instance references (P>1 only) ───────────────── # Stored as plain attributes (not nn.Parameters/buffers) so # Dynamo sees them as compile-time constants — no global mutation # inside forward, no graph break. # Raise at __init__ time (not forward time) so the error surfaces # immediately when the model is constructed, not at the first step. if self.iha_P > 1: if not _IHA_FLASH_ATTN_AVAILABLE: raise ImportError( "Interleaved Head Attention with P>1 (seq-expand / paper-correct " "mode) requires flash_attn to be explicitly installed.\n" "Install it with:\n" " pip install flash-attn --no-build-isolation\n\n" "If flash_attn is not available, set iha_num_pseudo_heads=1 " "(P=1) to fall back to the head-expand approximation, which " "uses the standard HF attention backend." ) self._iha_flash_attn_func = _IHA_FA_FUNC self._iha_flash_attn_varlen = _IHA_FA_VARLEN else: self._iha_flash_attn_func = None self._iha_flash_attn_varlen = None else: self.iha_is_local = False self.iha_window = None self._iha_flash_attn_func = None self._iha_flash_attn_varlen = None def _resolve_iha_window(self, seq_len: int) -> int: """ Resolve the local IHA sliding window for the current sequence length. Exact paper recipe (Sec. 5.1 / Appendix C): W := N / (2 P^2) where N is the actual sequence length seen by the layer and P is the number of pseudo-heads per head. When a manual window is provided in config, that explicit value takes precedence. """ if self.iha_window is not None: return max(1, min(int(self.iha_window), seq_len)) denom = max(2 * self.iha_P * self.iha_P, 1) return max(1, seq_len // denom) def _build_iha_local_mask( self, seq_len: int, window_size: int, device: torch.device, dtype: torch.dtype, ) -> torch.Tensor: """ Construye máscara aditiva de ventana deslizante [1, 1, S, S]. Implementa la restricción de la capa IHA local del paper (§5.1): cada token i solo atiende a los últimos `window_size` tokens previos (i − W ≤ j ≤ i). Se SUMA a la máscara causal existente, combinando causalidad + restricción de banda en un único tensor. Compatibilidad con backends (paper §2, FlashAttention §A): - eager : se suma a attn_weights antes del softmax → -inf → exp=0 - sdpa : misma interfaz que eager vía scaled_dot_product_attention - flash2/3: la localidad se expresa con `sliding_window` del backend; este helper solo se usa en rutas que sí consumen una máscara aditiva 4D directamente. Window auto follows the exact paper schedule: W := N / (2P²) with N = current sequence length. This is the recipe used by the local-global FLOP-matched schedule in Sec. 5.1 / Appendix C. Args: seq_len: S — longitud de la secuencia actual. window_size: W — tokens hacia atrás que cada query puede atender. device, dtype: del tensor Q para evitar copias de dispositivo. Returns: [1, 1, S, S] float — 0 en posiciones válidas, -inf en posiciones fuera de ventana (j < i − W). """ # i [S,1], j [1,S] → out-of-window where j < i - W i = torch.arange(seq_len, device=device).unsqueeze(1) # [S, 1] j = torch.arange(seq_len, device=device).unsqueeze(0) # [1, S] mask = torch.zeros(seq_len, seq_len, device=device, dtype=dtype) mask[j < (i - window_size)] = float("-inf") return mask.unsqueeze(0).unsqueeze(0) # [1, 1, S, S] def _apply_iha_pseudo_heads( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ IHA Steps 2-4 (Algorithm 1, Duvvuri et al. 2026, arXiv:2602.21371). Implements the paper-correct **seq-expand** approach: pseudo-heads are merged into the *sequence* dimension, yielding H heads each attending over S·P virtual tokens rather than the approximation of creating H·P independent heads over S tokens. Step 2 — Pseudo-head mixing across original heads (Alg. 1 lines 2-4): Q̃[h,p,s,:] = Σ_m α_Q[h,m,p] · Q[m,s,:] Step 3 — Interleave P pseudo-slots into sequence dimension (line 5): For each head h: sequence becomes (s=0,p=0), (s=0,p=1), …, (s=0,p=P-1), (s=1,p=0), …, (s=S-1,p=P-1) Virtual position of (s,p) is s·P + p. This is what enables up to P² distinct attention patterns per head (pseudo-query p₁ can attend to pseudo-key p₂ at the *same* original position s, which is impossible in the head-expand approximation). Compatibility with FlashAttention is preserved because Step 4 is standard scaled dot-product attention — no custom kernel needed. Requires flash_attn (imported lazily) for P>1; checked at forward call time, not here, so the function itself stays pure-PyTorch. P=1 Squeeze the unit P dimension — output shapes identical to input. Functionally equivalent to a learned linear recombination of heads (no seq expansion, no new parameters beyond alpha). P>1 (seq-expand) Q: [B, H_q, S, d] → [B, H_q, S·P, d] K/V: [B, H_kv, S, d] → [B, H_kv, S·P, d] GQA ratio H_q/H_kv is invariant (both expand in the seq dim). Args ---- q: [B, H_q, S, d_head] k: [B, H_kv, S, d_head] v: [B, H_kv, S, d_head] Returns ------- q_out: [B, H_q, S·P, d_head] (= [B, H_q, S, d] if P=1) k_out: [B, H_kv, S·P, d_head] (= [B, H_kv, S, d] if P=1) v_out: [B, H_kv, S·P, d_head] (= [B, H_kv, S, d] if P=1) """ P = self.iha_P B, H_q, S, d = q.shape H_kv = k.shape[1] # Step 2: learned linear combination across original heads. # alpha[h_out, h_in, p] × q[b, h_in, s, d] → [B, H_q, P, S, d] # einsum contracts h_in (index m) and distributes over pseudo-slot p. q_p = torch.einsum("hmp,bmsd->bhpsd", self.iha_alpha_q, q) # [B,H_q, P,S,d] k_p = torch.einsum("hmp,bmsd->bhpsd", self.iha_alpha_k, k) # [B,H_kv,P,S,d] v_p = torch.einsum("hmp,bmsd->bhpsd", self.iha_alpha_v, v) # [B,H_kv,P,S,d] if P == 1: # Squeeze unit P axis — shapes unchanged, all downstream paths intact. q_out = q_p.squeeze(2) # [B, H_q, S, d] k_out = k_p.squeeze(2) # [B, H_kv, S, d] v_out = v_p.squeeze(2) # [B, H_kv, S, d] else: # Step 3: interleave P pseudo-slots into the sequence dimension. # Layout of q_p: [B, H, P, S, d] # Desired: [B, H, S·P, d] with virtual pos s·P+p # # .transpose(2, 3) → [B, H, S, P, d] (swap P and S axes) # .reshape(...) → [B, H, S·P, d] (fuse S and P into one axis) # # The resulting index is: out[b, h, s·P+p, d] = q_p[b, h, p, s, d] # which is exactly the interleaving the paper describes. q_out = q_p.transpose(2, 3).reshape(B, H_q, S * P, d) k_out = k_p.transpose(2, 3).reshape(B, H_kv, S * P, d) v_out = v_p.transpose(2, 3).reshape(B, H_kv, S * P, d) return q_out, k_out, v_out def _apply_iha_collapse(self, attn_out: torch.Tensor) -> torch.Tensor: """ IHA Step 5 (Algorithm 1, line 7-8, Duvvuri et al. 2026). Collapses the P pseudo-slot axis back to one representation per real token position using the learned R matrix. **Paper formulation (seq-expand):** reshape(O, [H, N, P, d]) → (H, N, P, d) einsum('hp, hnpd → hnd', R, O) → (H, N, d) where R ∈ ℝ^{H×P} weights the contribution of each pseudo-slot p to the final head-h output. This is *not* a collapse over heads — it is a collapse over the P interleaved virtual tokens at each real position. Initialization (identity / Theorem 2 inclusion M ⊆ P_P): R[h, 0] = 1.0, R[h, p>0] = 0.0 for all h so at step 0 the output equals the p=0 pseudo-slot, which by the alpha initialization equals the original MHA head output. Args ---- attn_out : [B, S·P, H_q, d_head] Flash-attention output in HF layout (batch, seq, heads, dim). S·P is the expanded virtual sequence length. Returns ------- [B, S, H_q, d_head] One representation per real token, one per original head. """ B, SP, H_q, d = attn_out.shape P = self.iha_P S = SP // P # Step 7: separate the interleaved P pseudo-slots from the N real positions. # [B, S·P, H, d] → [B, S, P, H, d] out_struct = attn_out.reshape(B, S, P, H_q, d) # Step 8: weighted sum over pseudo-slots. # R: [H_q, P] einsum index p → 'hp,bsphd → bshd' return torch.einsum("hp,bsphd->bshd", self.iha_R, out_struct) # ── IHA seq-expand helpers ──────────────────────────────────────────────── def _build_iha_interleaved_rope( self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor, inv_freq: torch.Tensor, attention_scaling: float, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Build interleaved RoPE for the seq-expanded IHA sequence and apply it. The paper (§4, interleaving note) requires that virtual token (s, p) receives integer position s·P + p, giving each pseudo-slot a distinct rotary phase. This cannot be done with the precomputed cos/sin (which cover positions 0…S-1); we rebuild cos/sin inline from inv_freq exactly as _apply_repo_rope does for the REPO path. Interleaved position of virtual token (s, p): pos_iha[b, s·P + p] = position_ids[b, s] · P + p Construction: position_ids_iha[B, S·P]: • Expand position_ids[B, S] → [B, S, 1] · P + arange(P)[1,1,P] → [B, S, P], then reshape → [B, S·P] cos/sin computed inline (no autocast, float32 trig): freqs[B, S·P, rot_dim/2] = pos_iha[:, :, None] · inv_freq[None, None, :] emb = cat(freqs, freqs) → [B, S·P, rot_dim] cos, sin scaled by attention_scaling apply_rotary_pos_emb is then called with q[B, H_q, S·P, d] and k[B, H_kv, S·P, d] using the expanded cos/sin[B, S·P, rot_dim]. Args ---- q, k : already seq-expanded, [B, H, S·P, d] position_ids: original [B, S] integer positions (handles packing) inv_freq: rotary frequency vector [rot_dim/2] (from repo_rope_args) attention_scaling: float scale (from repo_rope_args) Returns ------- (q_rot, k_rot) with same shapes as input. """ # Don't read batch size B from q here — position_ids may have # pid_B=1 (broadcast-ready) while q has B=batch_size. # apply_rotary_pos_emb handles the broadcast internally. SP = q.shape[2] P = self.iha_P S = SP // P # ── Normalize position_ids to exactly 2-D [pid_B, S] ───────────────── # Trainers can pass position_ids in three shapes: # [S] 1-D: single sequence, no batch dim # [pid_B, S] 2-D: standard (may be pid_B=1 broadcast or pid_B=B) # [pid_B, S, k] 3-D: e.g. [1, S, 2] when trainers attach per-doc # metadata (doc_id, pos) or (cos_pos, sin_pos); # only the sequential index (dim -1 first col) is # needed for RoPE. # All three are normalized to 2-D so the arithmetic below is stable. pid = position_ids if pid.dim() == 1: pid = pid.unsqueeze(0) # [1, S] elif pid.dim() > 2: pid = pid[..., 0] # [pid_B, S] — take sequential dim # Build interleaved position indices [pid_B, S·P] # Virtual token (s, p) at index s·P+p receives position pid[b,s]·P + p. # p_offsets[P] = [0, 1, …, P-1] # pid.unsqueeze(-1) → [pid_B, S, 1] # * P + p_offsets.view(1, 1, P) → [pid_B, S, P] # reshape(-1, SP) → [pid_B, S·P] (-1 keeps pid_B intact even when 1) p_offsets = torch.arange(P, device=pid.device, dtype=pid.dtype) # [P] pos_iha = (pid.unsqueeze(-1) * P + p_offsets.view(1, 1, P) ).reshape(-1, SP) # [pid_B, S·P] # Compute cos/sin inline from inv_freq. # No torch.no_grad() wrapper: position_ids / inv_freq have no grad # (int64 and frozen buffer respectively), so no gradient path is # created regardless. Wrapping in no_grad() inside a compiled region # causes a Dynamo graph break (it cannot inline autograd context # managers mid-graph). inv_freq_f = inv_freq.to(dtype=torch.float32) freqs = pos_iha.float().unsqueeze(-1) * inv_freq_f # [pid_B, S·P, r/2] emb = torch.cat([freqs, freqs], dim=-1) # [pid_B, S·P, r] cos = (emb.cos() * attention_scaling).to(q.dtype) # [pid_B, S·P, r] sin = (emb.sin() * attention_scaling).to(q.dtype) # apply_rotary_pos_emb adds unsqueeze(1) internally for the head dim # and broadcasts over the batch dim, so pid_B=1 works for any B. q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin) return q_rot, k_rot def _iha_seq_expand_flash_forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alpha: Optional[torch.Tensor], beta: Optional[torch.Tensor], use_affine: bool, dropout: float, S_orig: int, position_ids: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, None]: """ Flash attention for the IHA seq-expand path (P>1). Bypasses the HF flash-attention wrapper entirely to avoid the three blockers identified in the architecture analysis: • attention_mask truncation in _upad_input (key silently clipped to S) • cu_seqlens / packing incompatibility (_prepare_from_posids shape mismatch) • attention_mask→additive-mask conversion that HF injects (unsupported for seq length S·P when the mask is shaped for S) Calls flash_attn_func (dense path) or flash_attn_varlen_func (packed path) directly. Affine-Scaled Attention is also computed inline here so that Term 2 sums V over the same valid keys as flash attention on the S·P expanded sequence. Packing detection ----------------- Packing is assumed when batch_size == 1 AND position_ids is not a simple 0…S-1 range (i.e., contains document-boundary resets). This mirrors HF _is_packed_sequence logic. When packing is detected, cu_seqlens_iha are derived by multiplying per-document lengths by P. Sliding window -------------- For local IHA layers (iha_is_local=True), the window in the expanded sequence is W·P where W = S_orig // (2·P²) (paper Appendix C). Affine-Scaled Attention (inline) --------------------------------- When use_affine=True: output = α · flash_out + β · Σ_{j∈A(i)} V_j α, β are [B, H_q, S·P, 1] (already expanded over the seq dim before this call). The V sum is causal prefix-sum for global IHA, windowed for local IHA, and segmented when packed position_ids reset. Args ---- q, k, v : [B, H_q/kv, S·P, d] — after IHA expansion and RoPE alpha, beta: [B, H_q, S·P, 1] or None use_affine : whether Affine-Scaled Attention is active dropout : attention dropout probability S_orig : original (pre-expand) sequence length S position_ids: [B, S] or None — for packing detection + cu_seqlens Returns ------- (attn_out, None) attn_out: [B, S·P, H_q, d_head] (flash layout — batch, seq, heads, dim) """ # flash_attn functions are stored as instance attributes in __init__ # (self._iha_flash_attn_func / self._iha_flash_attn_varlen) so Dynamo # sees them as compile-time constants — no global lookup, no graph break. flash_attn_fn = self._iha_flash_attn_func flash_varlen_fn = self._iha_flash_attn_varlen P = self.iha_P B, H_q, SP, d_head = q.shape H_kv = k.shape[1] # flash_attn layout: [B, seq, heads, dim] q_fa = q.transpose(1, 2).contiguous() # [B, S·P, H_q, d] k_fa = k.transpose(1, 2).contiguous() # [B, S·P, H_kv, d] v_fa = v.transpose(1, 2).contiguous() # [B, S·P, H_kv, d] # ── Sliding window (local IHA layers) ──────────────────────────────── # Paper Appendix C: W := N / (2·P²) in original-sequence tokens. # In the expanded sequence each query can see W·P virtual keys to the # left. flash_attn window_size=(left, right) in virtual-token units. if self.iha_is_local: W_orig = max(1, S_orig // max(2 * P * P, 1)) window = (W_orig * P, 0) else: window = (-1, -1) # ── Packing detection ───────────────────────────────────────────────── # Packing: batch_size==1 and position_ids contains document-boundary # resets (non-monotonic positions). Mirrors HF _is_packed_sequence. # # position_ids may be 3-D [1, S, k] from some trainers; normalise to # 1-D [S] for the monotonicity check by taking the first column. _pos_for_pack = None if position_ids is not None and B == 1: _p = position_ids[0] # [S] or [S, k] if _p.dim() > 1: _p = _p[:, 0] # [S] — take sequential index column _pos_for_pack = _p.long() # [S] # .all().item() would break the compiled graph; use bool() which Dynamo # can evaluate at trace time when the tensor is a compile-time constant, # and falls back to a graph-break-free path otherwise. is_packed = ( _pos_for_pack is not None and not bool((torch.diff(_pos_for_pack.float()) >= 0).all()) ) # ── Packed / varlen path ────────────────────────────────────────────── if is_packed: # Derive cu_seqlens from document boundaries in the 1-D position # sequence. A new document starts wherever the position index resets # (pos[i] < pos[i-1]). pos0 = _pos_for_pack # [S] # Boundary mask: True at positions that start a new document. boundaries = torch.cat([ torch.ones(1, device=pos0.device, dtype=torch.bool), pos0[1:] < pos0[:-1], # reset ]) doc_starts = boundaries.nonzero(as_tuple=False).squeeze(1) # indices doc_lengths = torch.diff( torch.cat([doc_starts, torch.tensor([S_orig], device=pos0.device)]) ) # [num_docs] # Scale each document length by P for the expanded sequence. doc_lengths_p = doc_lengths * P # [num_docs] cu_seqlens = torch.zeros( len(doc_lengths) + 1, device=q.device, dtype=torch.int32 ) cu_seqlens[1:] = doc_lengths_p.cumsum(0).to(torch.int32) max_seqlen_p = int(doc_lengths_p.max().item()) # Flatten batch dim (B=1 for packing): [1, S·P, H, d] → [S·P, H, d] q_flat = q_fa.squeeze(0) k_flat = k_fa.squeeze(0) v_flat = v_fa.squeeze(0) out_flat = flash_varlen_fn( q_flat, k_flat, v_flat, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen_p, max_seqlen_k=max_seqlen_p, dropout_p=dropout if self.training else 0.0, softmax_scale=self.scaling, causal=True, window_size=window, ) # [S·P, H_q, d] flash_out = out_flat.unsqueeze(0) # [1, S·P, H_q, d] # ── Dense (non-packed) path ─────────────────────────────────────────── else: flash_out = flash_attn_fn( q_fa, k_fa, v_fa, dropout_p=dropout if self.training else 0.0, softmax_scale=self.scaling, causal=True, window_size=window, ) # [B, S·P, H_q, d] # ── Affine-Scaled Attention inline (flash path) ─────────────────────── # Decomposition used here: # Term 1: α · flash_out (standard flash output) # Term 2: β · Σ_{j≤i} V_j (causal prefix value summary) # This intentionally restores the original local-IHA behavior: the # attention logits may use a sliding window, but the β branch supplies a # cheap global causal value context. For packed IHA, the prefix is # reset at each document boundary to avoid cross-document leakage. # α, β arrive pre-expanded to [B, H_q, S·P, 1]; permute to [B,S·P,H,1] # to match the flash output layout [B, S·P, H_q, d]. if use_affine: # Compute the β value-sum in KV-head space first. For GQA this is # exact because the valid-key sum is linear and independent of the # repeated query-head view: # Sum(repeat_kv(V)) == repeat_kv(Sum(V)). # The broadcast to H_q happens only in the final affine combine. # β uses the full causal prefix, not the local flash window. # Keep the sum in KV-head space first; GQA expansion happens only # in the final affine combine below. if is_packed: value_sum_kv = _segmented_causal_value_sum( v, segment_lengths=doc_lengths_p, window_left=None, ) # [1,H_kv,S·P,d] else: value_sum_kv = _causal_value_sum( v, query_len=SP, window_left=None, ) # [B,H_kv,S·P,d] alpha_t = alpha.permute(0, 2, 1, 3) # [B,S·P,H_q,1] beta_t = beta.permute(0, 2, 1, 3) # [B,S·P,H_q,1] value_sum_t = value_sum_kv.transpose(1, 2) # [B,S·P,H_kv,d] if self.num_key_value_groups == 1 or H_q == H_kv: attn_out = alpha_t * flash_out + beta_t * value_sum_t else: flash_g = flash_out.reshape(B, SP, H_kv, self.num_key_value_groups, d_head) alpha_g = alpha_t.reshape(B, SP, H_kv, self.num_key_value_groups, 1) beta_g = beta_t.reshape(B, SP, H_kv, self.num_key_value_groups, 1) attn_out = ( alpha_g * flash_g + beta_g * value_sum_t.unsqueeze(3) ).reshape(B, SP, H_q, d_head) attn_out = torch.nn.functional.dropout( attn_out, p=dropout, training=self.training ) else: attn_out = flash_out return attn_out, None def _apply_momentum_attention( self, q: torch.Tensor, k: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: if not self.use_momentum_attention or self.momentum_gamma == 0.0: return q, k dq = causal_first_difference(q) dk = causal_first_difference(k) q_new = q + self.momentum_gamma * dq k_new = k + self.momentum_gamma * dk return q_new, k_new def _apply_mea_head_mixing( self, k: torch.Tensor, v: torch.Tensor, _force_standard: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: if not self.use_mea_attention: return k, v if self.use_iha and self.iha_P > 1 and not _force_standard: # Preserve the IHA pseudo-slot factorization: MEA mixes the KV # component heads inside each pseudo-slot, but never mixes # information across different pseudo indices. # NOTE: this branch is reached only by the head-expand path (P>1 # but _force_standard=False). The seq-expand path (P>1) calls # with _force_standard=True and falls through to the standard branch. k_mixed = head_linear_compose_pseudo( k, self.mea_key_mix, self.iha_P ).contiguous() v_mixed = head_linear_compose_pseudo( v, self.mea_value_mix, self.iha_P ).contiguous() else: # Standard path — also used by IHA seq-expand (P>1, _force_standard=True). # K/V may be [B, H_kv, S·P, d]; head_linear_compose operates on the # head dimension and leaves the sequence dimension untouched. k_mixed = head_linear_compose(k, self.mea_key_mix).contiguous() v_mixed = head_linear_compose(v, self.mea_value_mix).contiguous() return k_mixed, v_mixed def _apply_lucid_preconditioner( self, k: torch.Tensor, v: torch.Tensor, attention_mask: Optional[torch.Tensor], ) -> torch.Tensor: if not self.use_lucid_attention: return v.contiguous() key_rn = rms_key_unit_norm(k, eps=self.lucid_attention_eps) logits = torch.matmul(key_rn, key_rn.transpose(-1, -2)) * self.scaling - self.sqrt_head_dim prec = torch.tril(torch.exp(logits)) kv = infer_key_validity(attention_mask, k.shape[-2], k.shape[1]) if kv is not None: prec = prec * (kv.unsqueeze(-1) & kv.unsqueeze(-2)).to(prec.dtype) eye = torch.eye(prec.shape[-1], device=prec.device, dtype=prec.dtype).view( 1, 1, prec.shape[-1], prec.shape[-1] ) prec = prec + eye * (1.0 - prec.diagonal(dim1=-2, dim2=-1).unsqueeze(-1)) result = torch.linalg.solve_triangular( prec, v.float(), upper=False, unitriangular=True ).to(v.dtype).contiguous() return result def _apply_directional_routing( self, attn_out: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: """ Directional suppression at position C (post-XSA, pre-reshape). Args: attn_out: [B, S, H, d_head] — output after XSA and RMSNorm. hidden_states: [B, S, hidden_size] — pre-FAN residual stream, used as router input (same as paper's x_i). Returns: [B, S, H, d_head] with selected directional components suppressed. """ # ── Router: one routing decision per sequence ───────────────────── # Mean-pool over sequence dimension S → [B, hidden_size]. # The router is sequence-level (not token-level), matching the paper. # This produces one suppression pattern per input sequence, # which the paper shows learns domain-adaptive behavior in early # layers and fixed syntactic pruning in late layers. pooled = hidden_states.mean(dim=1) # [B, hidden_size] logits = self.direction_router(pooled) # [B, H*K] r = torch.sigmoid(self.directional_routing_temp * logits) r = r.view(hidden_states.shape[0], self.config.num_attention_heads, self.directional_routing_k) # [B, H, K] # Expand over sequence for broadcasting: [B, 1, H, K] r = r.unsqueeze(1) # ── Unit-normalize direction vectors ────────────────────────────── # Normalize at forward time, not at init, following the paper. # d: [H, K, d_head] → unit norm along d_head dimension. d = F.normalize(self.direction_vecs, dim=-1) # [H, K, d_head] # ── Directional suppression ─────────────────────────────────────── # attn_out: [B, S, H, d_head] # For each head h and direction k: # proj_{h,k} = (o_h · d_{h,k}) scalar per (B, S, H, K) # suppress = r_{h,k} · proj_{h,k} · d_{h,k} # o'_h = o_h - Σ_k suppress_{h,k} # # proj: einsum over d_head dimension # attn_out [B, S, H, D] × d [H, K, D] → [B, S, H, K] proj = torch.einsum("bshd,hkd->bshk", attn_out, d) # [B, S, H, K] # r [B, 1, H, K] × proj [B, S, H, K] → [B, S, H, K] weighted = r * proj # [B, S, H, K] # Σ_k weighted_{h,k} · d_{h,k}: # weighted [B, S, H, K] × d [H, K, D] → [B, S, H, D] suppression = torch.einsum("bshk,hkd->bshd", weighted, d) result = attn_out - suppression return result def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, first_layer_fan: Optional[torch.Tensor] = None, repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None, position_ids: Optional[torch.LongTensor] = None, key_value_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] if key_value_states is None: key_states = hidden_states value_states = hidden_states else: key_states, value_states = key_value_states h_fan_q = self.fan_layer(hidden_states) if key_value_states is None: h_fan_k = h_fan_q h_fan_v = h_fan_q else: h_fan_k = self.fan_layer(key_states) h_fan_v = self.fan_layer(value_states) if self.use_fan_residual and first_layer_fan is not None: h_fan_q = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan_q h_fan_k = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan_k h_fan_v = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan_v current_layer_fan = h_fan_q.clone() query_shape = (*input_shape, self.config.num_attention_heads, self.head_dim) kv_shape = (*input_shape, self.num_mea_component_heads, self.head_dim) q_raw, gate = torch.chunk( self.q_proj(h_fan_q).view(*input_shape, self.config.num_attention_heads, self.head_dim * 2), 2, dim=-1, ) gate = gate.reshape(*input_shape, -1) q = self.q_norm(q_raw.view(query_shape)).transpose(1, 2) k = self.k_norm(self.k_proj(h_fan_k).view(kv_shape)).transpose(1, 2) v = self.v_proj(h_fan_v).view(kv_shape).transpose(1, 2) # ── IHA: cross-head pseudo-head mixing (Duvvuri et al., 2026) ──────── # Position: post active Q/K norm (RMSNorm or RMSNorm), pre-RoPE/REPO. # Applies learned linear combination of all H heads into pseudo-heads # before positional encoding (paper Alg. 1, Steps 2-4). # # P=1 → shapes unchanged; all downstream paths fully intact. # P>1 → seq-expand (paper-correct) only on layers where self.use_iha=True: # Q : [B, H_q, S, d] → [B, H_q, S·P, d] # K/V : [B, H_kv, S, d] → [B, H_kv, S·P, d] # Virtual token (s, p) is at index s·P+p in the expanded dim. # GQA ratio H_q / H_kv is invariant (both expand in seq, not heads). # For pattern token 'G' with iha_global_layers_use_iha=False, # self.use_iha=False in this layer and the standard global path runs. _iha_seq_expand = self.use_iha and self.iha_P > 1 _S_orig = q.shape[2] # original sequence length before any IHA expansion if self.use_iha: q, k, v = self._apply_iha_pseudo_heads(q, k, v) # ── RoPE / REPO ─────────────────────────────────────────────────────── cos, sin = position_embeddings if self.use_repo: # REPO path: f_ϕ predicts continuous per-head positions from the # residual stream, then cos/sin are built inline from those positions # so the rotation is differentiable w.r.t. REPOModule parameters. # inv_freq and attention_scaling arrive via repo_rope_args, sourced # directly from rotary_emb at forward time — no buffer on this module, # no meta-tensor issue on lm_eval / to(device) paths. # (Li et al., 2026, §3.2 — Eq. 6–7) z = self.repo_module(hidden_states) # [B, H, S] if self.use_repo_grape: z = self.repo_grape.transform_positions(z) inv_freq, attn_scaling = repo_rope_args if _iha_seq_expand: # IHA seq-expand + REPO/REPO-GRAPE: each real position s has P virtual slots. # Assign continuous position z[b,h,s]·P + p to slot (s,p), scaling z by P # to create room for P sub-positions while preserving relative ordering. P = self.iha_P p_offsets = torch.arange(P, device=z.device, dtype=z.dtype) # z[B,H,S] → [B,H,S,1]*P + [1,1,1,P] → [B,H,S,P] → [B,H,S·P] z_expanded = (z.unsqueeze(-1) * P + p_offsets.view(1, 1, 1, P) ).reshape(z.shape[0], z.shape[1], _S_orig * P) if self.use_repo_grape: q, k = self.repo_grape.apply_multiplicative( q, k, z_expanded, inv_freq, attn_scaling) else: q, k = _apply_repo_rope(q, k, z_expanded, inv_freq, attn_scaling) else: # Standard REPO path (P=1 or no IHA): z covers original S positions. if self.use_repo_grape: q, k = self.repo_grape.apply_multiplicative( q, k, z, inv_freq, attn_scaling) else: q, k = _apply_repo_rope(q, k, z, inv_freq, attn_scaling) elif _iha_seq_expand: # IHA seq-expand + standard integer RoPE: # Build interleaved positions pos_iha[s·P+p] = position_ids[s]·P + p # and compute cos/sin inline so that each pseudo-slot receives a # distinct rotary phase — paper §4 interleaving note. # inv_freq is available via repo_rope_args (NeoLLMModel always passes # it when use_iha=True and P>1, regardless of use_repo). inv_freq, attn_scaling = repo_rope_args _pid = position_ids if _pid is None: # Fall back to 0…S-1 when position_ids were not threaded through. _pid = torch.arange(_S_orig, device=q.device).unsqueeze(0).expand( q.shape[0], -1 ) q, k = self._build_iha_interleaved_rope(q, k, _pid, inv_freq, attn_scaling) else: # Standard path: integer positions pre-computed by NeoLLMModel. q, k = apply_rotary_pos_emb(q, k, cos, sin) q, k = self._apply_momentum_attention(q, k) # ── MEA head mixing ─────────────────────────────────────────────────── # For IHA seq-expand (P>1): K/V are [B, H_kv, S·P, d] — same number of # heads as without IHA, just a longer sequence. Use the standard # head_linear_compose path; the pseudo-slot-aware path (head_linear_compose_pseudo) # was designed for the head-expand layout and would mismatch here. # For P=1 or no IHA: the existing branch logic in _apply_mea_head_mixing # is unchanged and handles both the pseudo-slot-aware and normal cases. if _iha_seq_expand: k, v = self._apply_mea_head_mixing(k, v, _force_standard=True) else: k, v = self._apply_mea_head_mixing(k, v) v = self._apply_lucid_preconditioner(k, v, attention_mask) # Capture v_ref for XSA after MEA mixing and LUCID preconditioning. # This is the vector that actually participated in SDPA aggregation. v_ref = v if self.use_xsa else None # ── REPO-GOAT factorised log-prior ─────────────────────────────────── # Appends Q/K prior channels and zero V channels immediately before the # attention backend. The output is sliced back to head_dim after SDPA, # so all downstream modules remain shape-compatible. repo_goat_prior_dim = 0 if self.use_repo_goat_prior: q, k, v, repo_goat_prior_dim = self.repo_goat_prior.append_prior_subspace( q, k, v, position_ids=position_ids, ) # ── IHA: local sliding-window mask for non-seq-expand path (P=1) ───── # For P>1 (seq-expand): the sliding window is expressed as window_size # in flash_attn_func and handled inside _iha_seq_expand_flash_forward. # This block only applies when P=1 (or use_iha=False) and the standard # HF attention backend is in use. # # Para capas marcadas como 'L' en el pattern (iha_is_local=True): # suma la máscara de ventana deslizante W a la máscara causal existente. # Esto restringe cada query a solo los últimos W tokens, reduciendo el # costo de atención de O(S²) a O(S·W) por head y FLOP-matcheando contra # global attention estándar con el schedule 4L+1G del paper. # # Compatibilidad FlashAttention-2 (paper §2 y §4, Algoritmo 1): # IHA preserva el operador de atención estándar — FlashAttention lo # recibe sin modificaciones. La máscara aditiva -inf es convertida # internamente por el wrapper HF a índices de bloque para flash2. # Para sdpa/eager la suma directa de -inf funciona nativamente. # Con P=1 iha_is_local=False siempre (ningún overhead). # En capas 'G' con iha_global_layers_use_iha=False, self.use_iha=False y # esta rama tampoco se activa: la capa es global estándar, paper-faithful. # flash_attention_2/3 keeps the 2D padding mask contract and receives # the locality constraint through `sliding_window`; eager/sdpa still # need the explicit additive band mask. flash_local_sliding_window = None if self.use_iha and self.iha_is_local and not _iha_seq_expand: _S = q.shape[2] # S real (puede diferir de max_pos con packing) _W = self._resolve_iha_window(_S) if self.config._attn_implementation in {"flash_attention_2", "flash_attention_3"}: flash_local_sliding_window = _W else: _band = self._build_iha_local_mask(_S, _W, q.device, q.dtype) attention_mask = ( _band if attention_mask is None else attention_mask + _band ) # ── Affine-Scaled Attention ─────────────────────────────────────── # Active whenever use_affine_scaled_attention=True, regardless of # attention backend. Two code paths — same math, different execution: # eager : full weight access, attn_weights_pre/post_affine captured. # flash/sdpa: α·backend_out + β·Σ valid V, no weight tensors materialised. alpha = None beta = None use_affine = self.use_affine_scaled_attention if use_affine: alpha = linear_clipping(self.alpha_proj(hidden_states)) # [B, S, H] alpha = alpha.permute(0, 2, 1).unsqueeze(-1) # [B, H, S, 1] N = k.shape[-2] beta = (self.alpha_ma.to(alpha.dtype) - alpha) / max(N, 1) if self.training: with torch.no_grad(): batch_mean = alpha.mean(dim=(0, 2), keepdim=True) self.alpha_ma.copy_( self.affine_momentum * self.alpha_ma + (1.0 - self.affine_momentum) * batch_mean ) # ── IHA: expand alpha/beta to match the attention sequence ──────── # P=1 or no IHA: no expansion needed (shapes already correct). # # seq-expand (P>1): SDPA operates over S·P virtual tokens. # alpha_proj produces one scalar per *original* token per head # (computed from hidden_states before IHA expansion, so shape is # [B, H_q, S, 1]). Each original token's alpha/beta applies # uniformly to all P pseudo-slots — alpha is a property of the # real token, not of the virtual slot. # Expansion: repeat_interleave over the sequence dim (dim=2), NOT # the head dim (dim=1) which was the head-expand approximation. # EMA stats remain on [B, H, S, 1] (pre-expansion) — no change. if _iha_seq_expand: alpha = alpha.repeat_interleave(self.iha_P, dim=2) # [B,H_q,S·P,1] beta = beta.repeat_interleave(self.iha_P, dim=2) # [B,H_q,S·P,1] # Recompute beta with N = S·P for the expanded attention axis. # alpha_ma shape [1,H,1,1] broadcasts correctly against expanded alpha. N_expanded = k.shape[-2] # S·P after IHA expansion beta = (self.alpha_ma.to(alpha.dtype) - alpha) / max(N_expanded, 1) if _iha_seq_expand: # ── IHA seq-expand: bypass HF wrapper, call flash_attn directly ── # The HF flash-attention wrapper (_flash_attention_forward via # ALL_ATTENTION_FUNCTIONS) has three hard blockers for the S·P expanded # sequence: # 1. _upad_input silently truncates K/V to mask length S (line 45-46) # 2. _prepare_from_posids shape-mismatches packed sequences # 3. attention_mask shaped [B,S] is incompatible with S·P tokens # _iha_seq_expand_flash_forward calls flash_attn_func / # flash_attn_varlen_func directly, handling packing, sliding window, # and the Affine-Scaled inline decomposition (α·flash + β·Σ valid V). # alpha/beta are [B, H_q, S·P, 1] at this point (already expanded). attn_out, attn_weights = self._iha_seq_expand_flash_forward( q, k, v, alpha=alpha if use_affine else None, beta=beta if use_affine else None, use_affine=use_affine, dropout=0.0 if not self.training else self.attention_dropout, S_orig=_S_orig, position_ids=position_ids, ) elif use_affine: if self.config._attn_implementation == "eager": # Eager: materialises softmax weights for the affine-scaled path. attn_out, attn_weights = affine_scaled_eager_attention_forward( self, q, k, v, attention_mask, scaling=self.scaling, alpha=alpha, beta=beta, dropout=0.0 if not self.training else self.attention_dropout, **kwargs, ) else: # Flash / SDPA: valid-key value sum, no weight tensors. backend_kwargs = kwargs if flash_local_sliding_window is not None: backend_kwargs = {**kwargs, "sliding_window": flash_local_sliding_window} attn_out, attn_weights = affine_scaled_flash_attention_forward( self, q, k, v, attention_mask, scaling=self.scaling, alpha=alpha, beta=beta, dropout=0.0 if not self.training else self.attention_dropout, **backend_kwargs, ) else: if self.config._attn_implementation == "eager": attn_fn = eager_attention_forward else: attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] backend_kwargs = kwargs if flash_local_sliding_window is not None: backend_kwargs = {**kwargs, "sliding_window": flash_local_sliding_window} attn_out, attn_weights = attn_fn( self, q, k, v, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **backend_kwargs, ) if repo_goat_prior_dim: # Remove the zero-valued prior value channels. The prior has already # affected the attention probabilities through Q/K logits; these extra # output channels carry no value information by construction. attn_out = attn_out[..., : self.head_dim].contiguous() # ── IHA Step 5: collapse P pseudo-slot axis → one output per real token ─ # seq-expand (P>1): _iha_seq_expand_flash_forward returns [B, S·P, H, d]. # _apply_iha_collapse reshapes to [B, S, P, H, d] and applies # R[H, P] via einsum 'hp,bsphd→bshd' → [B, S, H, d]. # This is the paper Alg. 1 Step 7-8 (R ∈ ℝ^{H×P}, collapses P slots). # P=1: not executed — shapes already correct (R would be identity anyway). if self.use_iha and self.iha_P > 1: attn_out = self._apply_iha_collapse(attn_out) attn_out = attn_out.reshape(*input_shape, -1, self.head_dim) if self.use_mea_attention: attn_out = self.mea_output_norm(attn_out) # ── Exclusive Self Attention (position B, pre-routing) ──────────────── # Removes auto-position component before directional routing so that # direction_vecs specialize exclusively in cross-domain interference. # # v_ref is captured post-MEA, post-LUCID — the value representation that # actually participated in SDPA. For seq-expand IHA (P>1) v_ref has # shape [B, H_kv, S·P, d]; after repeat_kv and transpose it becomes # [B, S·P, H_q, d]. We apply the same R collapse as for attn_out so # that both tensors live in the same [B, S, H_q, d] space before the # orthogonal projection. For P=1 or no IHA the block is unchanged. if self.use_xsa and v_ref is not None: v_ref_expanded = repeat_kv(v_ref, self.num_key_value_groups) # [B, H_q, S(·P), d] v_ref_t = v_ref_expanded.transpose(1, 2) # [B, S(·P), H_q, d] if _iha_seq_expand: # seq-expand: collapse [B, S·P, H_q, d] → [B, S, H_q, d] using # the same R[H_q, P] as attn_out, preserving the shared projection # subspace needed for the XSA dot product to be meaningful. v_ref_t = self._apply_iha_collapse(v_ref_t) elif self.use_iha and self.iha_P > 1: # head-expand legacy path (P>1 but seq-expand inactive — kept for # backward compatibility if the architecture is ever reconfigured). v_ref_t = self._apply_iha_collapse(v_ref_t) v_ref_t = v_ref_t.to(attn_out.dtype) proj = (attn_out * v_ref_t).sum(dim=-1, keepdim=True) norm_sq = (v_ref_t * v_ref_t).sum(dim=-1, keepdim=True).clamp(min=self.xsa_eps) xsa_comp = (proj / norm_sq) * v_ref_t attn_out = attn_out - xsa_comp # ── Directional Routing (position C, post-XSA, pre-reshape) ────── # Suppresses cross-domain interference directions from the head output. # Operates on [B, S, H, d_head] before reshape and o_proj. # When use_xsa=False: directions span full head-space (no XSA pre-clean). # When use_directional_routing=False: this block is skipped entirely. if self.use_directional_routing: attn_out = self._apply_directional_routing( attn_out, hidden_states) # ── Reshape → o_proj → Gated Attention gate → dropout ──────────── attn_out_flat = attn_out.reshape(*input_shape, -1).contiguous() attn_out_gated = self.o_proj(attn_out_flat * torch.sigmoid(gate)) attn_out_gated = self.dropout(attn_out_gated) return attn_out_gated, attn_weights, current_layer_fan class PolyNorm(nn.Module): def __init__( self, eps: float = 1e-6, proj_eps: float = 1e-6, exclusive_init: float = 0.5, exclusive: bool = True, ): super().__init__() self.weight = nn.Parameter(torch.ones(3) / 3) self.bias = nn.Parameter(torch.zeros(1)) self.eps = eps self.exclusive = exclusive if exclusive: self.proj_eps = proj_eps # Dos fuerzas exclusivas aprendibles en (0, 1), una por rama de orden alto. # Se parametrizan con logits para que sigmoid mantenga alpha ∈ (0, 1). exclusive_init = float(min(max(exclusive_init, 1e-4), 1.0 - 1e-4)) init = torch.full((2,), exclusive_init, dtype=torch.float32) self.exclusive_logits = nn.Parameter(torch.logit(init)) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def _exclusive(self, branch, ref, alpha, x1_f, ref_norm_sq): """ Elimina de `branch` la componente alineada con `ref` (la rama lineal x1), ponderada por alpha ∈ (0, 1), y renormaliza el resultado. branch := _norm(branch - alpha · proj_{ref}(branch)) El denominador ref_norm_sq se pasa precalculado para evitar duplicarlo cuando se llama dos veces por forward (una vez para x2, otra para x3). """ branch_f = branch.float() dot = (branch_f * x1_f).sum(dim=-1, keepdim=True) proj_coeff = (dot / ref_norm_sq).to(branch.dtype) out = branch - alpha.to(branch.dtype) * proj_coeff * ref return self._norm(out) def forward( self, x: torch.Tensor, ) -> torch.Tensor: # Caché de potencias: x_sq reutilizado en x1 y x2; x_cu = x·x_sq evita pow(3) x_sq = x.pow(2) x_cu = x * x_sq # Tres ramas normalizadas x1 = x * x_sq.mean(-1, keepdim=True).add(self.eps).rsqrt() x2 = x_sq * (x_sq * x_sq).mean(-1, keepdim=True).add(self.eps).rsqrt() x3 = x_cu * (x_cu * x_cu).mean(-1, keepdim=True).add(self.eps).rsqrt() if self.exclusive: # Fuerzas exclusivas aprendibles alpha2, alpha3 = torch.sigmoid(self.exclusive_logits).unbind() # Precalcular ref (x1) en fp32 y su norma al cuadrado — compartido por x2 y x3 x1_f = x1.float() ref_norm_sq = x1_f.pow(2).sum(-1, keepdim=True).clamp_min(self.proj_eps) # Ortogonalización parcial de las ramas de orden alto respecto a la lineal x2 = self._exclusive(x2, x1, alpha2, x1_f, ref_norm_sq) x3 = self._exclusive(x3, x1, alpha3, x1_f, ref_norm_sq) output = ( self.weight[0] * x3 + self.weight[1] * x2 + self.weight[2] * x1 + self.bias ) return output class NeoLLMMLP(nn.Module): """MLP with FANformer integration and Learnable Multipliers.""" def __init__(self, config): super().__init__() self.fan_layer = FANLayer( hidden_size=config.hidden_size, fan_ratio=getattr(config, "fan_ratio_ffn", 0.0625), ) fan_dim = config.hidden_size + int( config.hidden_size * getattr(config, "fan_ratio_ffn", 0.0625) ) self.gate_proj = LinearWithMultipliers( fan_dim, config.intermediate_size, bias=False, use_row_multiplier=True, use_column_multiplier=False, enable_multipliers=getattr(config, "use_learnable_multipliers", True), ) self.up_proj = nn.Linear(fan_dim, config.intermediate_size, bias=False) self.down_proj = LinearWithMultipliers( config.intermediate_size, config.hidden_size, bias=False, use_row_multiplier=True, use_column_multiplier=True, enable_multipliers=getattr(config, "use_learnable_multipliers", True), ) self.act_fn = PolyNorm(exclusive_init=0.00, exclusive=getattr(config, "polynorm_exclusive", True)) self.dropout = nn.Dropout(config.dropout_rate) def forward( self, x: torch.Tensor, ) -> torch.Tensor: x_fan = self.fan_layer(x) gate_out = self.gate_proj(x_fan) up_out = self.up_proj(x_fan) act_out = self.act_fn(gate_out) act_x_up = act_out * up_out result = self.down_proj(self.dropout(act_x_up)) return result class NeoLLMDecoderLayer(GradientCheckpointingLayer): """ Decoder layer with standard residual connections, optional JTok-M injection. Flow (JTok-M active): 1. ActiveNorm(RMSNorm) → LNS(1/√ℓ) → Attention → residual + GPAS 2. [capture h̃ = hidden after attention for JTok-M router] 3. ActiveNorm(RMSNorm) → LNS(1/√ℓ) → MLP → Δm 4. h^{ℓ+1} = h̃ + Δm + Δr (Δr from JTok-M, scaled 1/√(2ℓ)) 5. GPAS LNS coordination: LNS factor: 1/√ℓ JTok-M factor: 1/√(2ℓ) → ratio = 1/√2 constant at all depths. """ def __init__(self, config: NeoLLMConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.layer_idx = layer_idx self.use_jtokm = config.use_jtokm self.use_lns = bool(getattr(config, "use_lns", False)) self.use_gpas = bool(getattr(config, "use_gpas", False)) self.use_siamesenorm = bool(getattr(config, "use_siamesenorm", False)) self.siamese_normalized_input = bool(getattr(config, "siamese_normalized_input", True)) self.siamese_depth_scaling = bool(getattr(config, "siamese_depth_scaling", True)) self.siamese_attn_x_scale_init = float(getattr(config, "siamese_attn_x_scale_init", 1.0)) # Controls only the first pre-attention normalisation applied directly # to the embedding stream. Defaults to True for checkpoint/config # backward compatibility. When False, layer 0 does not instantiate # input_layernorm at all, so the flag removes the corresponding # active norm parameters instead of merely bypassing them in forward. self.use_embedding_input_norm = bool( getattr(config, "use_embedding_input_norm", True) ) self.has_input_layernorm = (not self.use_siamesenorm) and not ( self.layer_idx == 0 and not self.use_embedding_input_norm ) self.self_attn = NeoLLMAttention(config, layer_idx) self.mlp = NeoLLMMLP(config) if self.use_siamesenorm: self.input_layernorm = None self.post_attention_layernorm = None # SiameseNorm is RMS-only by config validation. These modules are # constructed only when the Siamese topology is active, so no # inactive RMSNorm pre-norm modules remain in the graph. self.siamese_attn_x_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.siamese_attn_y_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.siamese_mlp_x_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.siamese_mlp_y_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.siamese_attn_input_norm = ( nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.siamese_normalized_input else None ) self.siamese_mlp_input_norm = ( nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.siamese_normalized_input else None ) self.siamese_attn_x_scale = nn.Parameter( torch.full((config.hidden_size,), self.siamese_attn_x_scale_init, dtype=torch.float32) ) else: self.input_layernorm = ( _make_norm( config.hidden_size, eps=config.rms_norm_eps, ) if self.has_input_layernorm else None ) self.post_attention_layernorm = _make_norm( config.hidden_size, eps=config.rms_norm_eps, ) self.siamese_attn_x_norm = None self.siamese_attn_y_norm = None self.siamese_mlp_x_norm = None self.siamese_mlp_y_norm = None self.siamese_attn_input_norm = None self.siamese_mlp_input_norm = None self.siamese_attn_x_scale = None self.lns_attn = LNS(layer_idx) if self.use_lns else None self.lns_mlp = LNS(layer_idx) if self.use_lns else None self.gpas_attn = GPAS(config.hidden_size) if self.use_gpas else None self.gpas_mlp = GPAS(config.hidden_size) if self.use_gpas else None self.current_layer_fan = None # ── StackMemory / STACKTRANS (Zhang et al., NeurIPS 2025) ──────── # Optional differentiable hidden-state stack inserted between # Transformer layers. The stack is applied before this layer's # attention block, matching the released StackTrans source flow. self.use_stack_memory = getattr(config, "use_stack_memory", False) self.stack_memory = StackMemory(config) if self.use_stack_memory else None if self.use_jtokm: self.jtokm = LeviathanJTokM(config, layer_idx) else: self.jtokm = None # ── Attention Residuals (Kimi Team, 2026) ───────────────────────── # Replaces fixed residual accumulation with learned softmax attention # over preceding layer outputs. Each decoder layer has two learnable # pseudo-queries — one for pre-attention and one for pre-MLP — plus a # shared RMSNorm applied to keys to prevent magnitude-dominated softmax. # Pseudo-queries are initialized to ZERO so AttnRes starts as uniform # average (equivalent to standard residual mean) and training volatility # is avoided. This is critical per the paper's ablation. self.use_attn_res = getattr(config, 'use_attn_res', False) if self.use_attn_res: self.attn_res_query_attn = nn.Parameter(torch.zeros(config.hidden_size)) self.attn_res_query_mlp = nn.Parameter(torch.zeros(config.hidden_size)) self.attn_res_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.attn_res_query_attn = None self.attn_res_query_mlp = None self.attn_res_norm = None # ── LAuReL: Learned Augmented Residual Layer (Menghani et al., ICML 2025) ─ # Generalises the canonical residual connection with learned scalar # weights (RW) and/or a low-rank linear correction (LR). Applied # independently to the attention and MLP sublayers (two residual # junctions per decoder layer). # # LAUREL-RW (use_laurel_rw): # raw scalars α̃, β̃ → softmax([α̃, β̃]) = (α, β) bounded in (0,1) # Residual becomes: α·f(x) + β·x # # LAUREL-LR (use_laurel_lr): # A: nn.Linear(D→r, bias=False) initialised column-orthogonal # B: nn.Linear(r→D, bias=False) initialised to zero # Residual becomes: f(x) + B(A(x)) + x # # LAUREL-RW+LR (both active, paper eq. 5): # Residual becomes: α·f(x) + β·(B(A(x)) + x) # # Mutex with use_attn_res is enforced at config validation time. self.use_laurel = getattr(config, 'use_laurel', False) self.use_laurel_rw = getattr(config, 'use_laurel_rw', True) self.use_laurel_lr = getattr(config, 'use_laurel_lr', False) D = config.hidden_size r = getattr(config, 'laurel_lr_rank', 32) if self.use_laurel and self.use_laurel_rw: # Two raw scalars per sublayer; softmax-normalised in forward. # Initialised to [0, 0] → softmax → (0.5, 0.5) at step 0, # matching a standard equal-weight residual as the starting point. self.laurel_rw_attn = nn.Parameter(torch.zeros(2)) # [α̃_attn, β̃_attn] self.laurel_rw_mlp = nn.Parameter(torch.zeros(2)) # [α̃_mlp, β̃_mlp] else: self.laurel_rw_attn = None self.laurel_rw_mlp = None if self.use_laurel and self.use_laurel_lr: # Attention sublayer low-rank matrices. # A: D×r (projects D→r), initialised column-orthogonal per §3.3. # B: r×D (projects r→D), initialised to zero → identity start. self.laurel_lr_A_attn = nn.Linear(D, r, bias=False) self.laurel_lr_B_attn = nn.Linear(r, D, bias=False) # MLP sublayer low-rank matrices (independent capacity). self.laurel_lr_A_mlp = nn.Linear(D, r, bias=False) self.laurel_lr_B_mlp = nn.Linear(r, D, bias=False) # Initialise: B→zero, A→column-orthogonal (paper footnote 2): # A_{i,j} = 1/√(rD) if i mod r == j else 0 for A_mat in (self.laurel_lr_A_attn, self.laurel_lr_A_mlp): nn.init.zeros_(A_mat.weight) for j in range(r): for i in range(D): if i % r == j: A_mat.weight.data[j, i] = 1.0 / (r * D) ** 0.5 for B_mat in (self.laurel_lr_B_attn, self.laurel_lr_B_mlp): nn.init.zeros_(B_mat.weight) else: self.laurel_lr_A_attn = None self.laurel_lr_B_attn = None self.laurel_lr_A_mlp = None self.laurel_lr_B_mlp = None def apply_stack_memory( self, hidden_states: torch.Tensor, stack_memory: Optional[torch.Tensor], stack_memory_mask: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Apply this layer's differentiable StackMemory module before attention. The memory tensors are owned by ``NeoLLMModel.forward`` and threaded across decoder layers, reproducing the StackTrans pattern where a hidden-state stack sits between standard Transformer layers while the attention implementation remains unchanged. """ if (not self.use_stack_memory) or self.stack_memory is None: return hidden_states, stack_memory, stack_memory_mask if stack_memory is None or stack_memory_mask is None: raise ValueError( "StackMemory is enabled, but stack_memory/stack_memory_mask " "were not initialized by NeoLLMModel.forward." ) return self.stack_memory(hidden_states, stack_memory, stack_memory_mask) def _attn_res( self, sources: list, partial: torch.Tensor, query: torch.Tensor, ) -> torch.Tensor: """ Depth-wise softmax attention over preceding layer outputs. Computes: V = stack(sources + [partial]) [N+1, B, S, D] K = RMSNorm(V) [N+1, B, S, D] logits = query · K [N+1, B, S] weights = softmax(logits, dim=0) [N+1, B, S] h = Σ_n weights_n · V_n [B, S, D] The pseudo-query is shared across positions (per the paper design). RMSNorm on keys prevents layers with large-magnitude outputs from dominating the softmax. Initialized to zero → uniform weights at step 0, reducing to standard residual mean. Args: sources: list of [B, S, D] tensors — completed block summaries or all previous layer outputs (Full AttnRes). partial: [B, S, D] — current intra-block partial sum. query: [D] — learnable pseudo-query for this sublayer. Returns: [B, S, D] — weighted combination of sources + partial. """ all_v = sources + [partial] # list of [B, S, D] V = torch.stack(all_v, dim=0) # [N+1, B, S, D] K = self.attn_res_norm(V) # [N+1, B, S, D] logits = torch.einsum('d,nbsd->nbs', query, K) # [N+1, B, S] weights = torch.softmax(logits, dim=0) # [N+1, B, S] return torch.einsum('nbs,nbsd->bsd', weights, V) # [B, S, D] def _laurel_residual( self, residual: torch.Tensor, delta: torch.Tensor, rw_param: Optional[torch.Tensor], A_mat, B_mat, slot: str = "attn", ) -> torch.Tensor: """ Computes the LAuReL-augmented residual junction (Menghani et al., ICML 2025). Dispatches among three regimes depending on which sub-variants are active: LAUREL-RW only (use_laurel_rw=True, use_laurel_lr=False): α, β = softmax([α̃, β̃]) out = α · delta + β · residual (paper §2.1) LAUREL-LR only (use_laurel_rw=False, use_laurel_lr=True): lr_delta = B(A(residual)) out = delta + lr_delta + residual (paper eq. 3) LAUREL-RW+LR (both active, paper eq. 5): α, β = softmax([α̃, β̃]) lr_delta = B(A(residual)) out = α · delta + β · (lr_delta + residual) In all cases the standard residual identity (out = delta + residual) is recovered at initialisation: RW starts at (α=0.5, β=0.5); LR starts with B=0 so lr_delta=0. Args: residual: [B, S, D] — accumulated residual stream (x_i in the paper). delta: [B, S, D] — sublayer output f(x_i) (attention or MLP). rw_param: Parameter[2] — raw [α̃, β̃] before softmax; None if RW off. A_mat: nn.Linear(D→r, bias=False) — LR down-proj; None if LR off. B_mat: nn.Linear(r→D, bias=False) — LR up-proj; None if LR off. slot: "attn" or "mlp". Returns: [B, S, D] — augmented residual output for the current sublayer. """ has_rw = rw_param is not None has_lr = A_mat is not None # ── LAUREL-LR: low-rank residual correction ──────────────────────── if has_lr: lr_delta = B_mat(A_mat(residual)) # [B, S, D] else: lr_delta = None # ── LAUREL-RW: learned scalar gate ──────────────────────────────── if has_rw: ab = torch.softmax(rw_param.float(), dim=0).to(residual.dtype) alpha = ab[0] beta = ab[1] else: alpha = beta = None # ── Compose output ───────────────────────────────────────────────── if has_rw and has_lr: # LAUREL-RW+LR (paper eq. 5): α·f(x) + β·(BAx + x) return alpha * delta + beta * (lr_delta + residual) elif has_rw: # LAUREL-RW (paper §2.1): α·f(x) + β·x return alpha * delta + beta * residual else: # LAUREL-LR (paper eq. 3): f(x) + BAx + x return delta + lr_delta + residual def _siamese_stream_scale(self, ref: torch.Tensor) -> torch.Tensor: if not self.siamese_depth_scaling: return ref.new_tensor(1.0) return ref.new_tensor(1.0 / math.sqrt(2.0 * float(self.layer_idx + 1))) def forward_siamesenorm( self, x_states: torch.Tensor, y_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, first_layer_fan: Optional[torch.Tensor] = None, z_tilde: Optional[torch.Tensor] = None, B_vals: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None, position_ids: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple: # SiameseNorm keeps two coupled streams with shared Attention/MLP # parameters. All Siamese normalization modules are RMSNorm by # construction; no dynamic-normalization branch remains. # ── Attention shared block ──────────────────────────────────────── x_attn_norm = self.siamese_attn_x_norm(x_states) y_attn_norm = self.siamese_attn_y_norm(y_states) x_scale = self.siamese_attn_x_scale.to(dtype=x_attn_norm.dtype, device=x_attn_norm.device) h_attn = x_scale * x_attn_norm + y_attn_norm if self.siamese_attn_input_norm is not None: h_attn = self.siamese_attn_input_norm(h_attn) h_lns = self.lns_attn(h_attn) if self.use_lns else h_attn attn_out, attn_weights, self.current_layer_fan = self.self_attn( hidden_states=h_lns, key_value_states=None, attention_mask=attention_mask, position_embeddings=position_embeddings, first_layer_fan=first_layer_fan, repo_rope_args=repo_rope_args, position_ids=position_ids, **kwargs, ) stream_scale = self._siamese_stream_scale(attn_out) x_after_attn = x_states + stream_scale * attn_out y_after_attn = y_states + attn_out if self.use_gpas: x_after_attn = self.gpas_attn(x_after_attn) # ── MLP shared block ────────────────────────────────────────────── x_mlp_norm = self.siamese_mlp_x_norm(x_after_attn) y_mlp_norm = self.siamese_mlp_y_norm(y_after_attn) h_mlp = x_mlp_norm + y_mlp_norm if self.siamese_mlp_input_norm is not None: h_mlp = self.siamese_mlp_input_norm(h_mlp) h_lns2 = self.lns_mlp(h_mlp) if self.use_lns else h_mlp delta_m = self.mlp(h_lns2) shared_update = delta_m aux_stats = None if self.use_jtokm and z_tilde is not None and B_vals is not None: orig_shape = x_after_attn.shape h_flat = x_after_attn.reshape(-1, self.hidden_size) z_flat = z_tilde.reshape(-1, z_tilde.shape[-1]) B_flat = B_vals.reshape(-1, B_vals.shape[-2], B_vals.shape[-1]) delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat) shared_update = shared_update + delta_r.reshape(orig_shape) x_next = x_after_attn + stream_scale * shared_update y_next = y_after_attn + shared_update if self.use_gpas: x_next = self.gpas_mlp(x_next) outputs = (x_next, y_next) if output_attentions: outputs += (attn_weights,) if aux_stats is not None: outputs += (aux_stats,) return outputs def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, first_layer_fan: Optional[torch.Tensor] = None, z_tilde: Optional[torch.Tensor] = None, B_vals: Optional[torch.Tensor] = None, attn_res_sources: Optional[list] = None, attn_res_partial: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None, position_ids: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple: # ── Snapshot input ──────────────────────────────────────────────── # ── Attention Residuals: compute pre-attention input ────────────── # When active, the input to the attention sublayer is no longer the # raw hidden_states (accumulated residual) but a softmax-weighted # combination of all previous layer outputs (or block summaries). # attn_res_partial carries the intra-block standard residual that # connects the attention and MLP sublayers within this layer. # When inactive, flow is identical to the original. if self.use_attn_res and attn_res_sources is not None and attn_res_partial is not None: h_attn = self._attn_res( attn_res_sources, attn_res_partial, self.attn_res_query_attn, ) residual_attn = attn_res_partial else: h_attn = hidden_states residual_attn = hidden_states # ── Attention block ─────────────────────────────────────────────── # Optional embedding input normalisation: # layer_idx == 0: h_attn is the embedding stream entering the first # attention block, so the flag can expose raw # embedding magnitudes to attention. # layer_idx > 0: keep the standard pre-norm Transformer flow exactly # as before. if self.input_layernorm is None: h_normed = h_attn else: h_normed = _apply_norm(self.input_layernorm, h_attn) h_lns = self.lns_attn(h_normed) if self.use_lns else h_normed hidden_states, attn_weights, self.current_layer_fan = self.self_attn( hidden_states=h_lns, key_value_states=None, attention_mask=attention_mask, position_embeddings=position_embeddings, first_layer_fan=first_layer_fan, repo_rope_args=repo_rope_args, position_ids=position_ids, **kwargs, ) # ── Residual junction (attention sublayer) ──────────────────────── if self.use_laurel: attn_aug = self._laurel_residual( residual_attn, hidden_states, self.laurel_rw_attn, self.laurel_lr_A_attn, self.laurel_lr_B_attn, slot="attn", ) else: attn_aug = residual_attn + hidden_states h_tilde = self.gpas_attn(attn_aug) if self.use_gpas else attn_aug # ── Attention Residuals: compute pre-MLP input ──────────────────── # After attention, the partial sum is updated with h_tilde. # The pre-MLP AttnRes attends over the same sources but with h_tilde # as the current partial — capturing the within-layer attention output. if self.use_attn_res and attn_res_sources is not None: h_mlp = self._attn_res( attn_res_sources, h_tilde, self.attn_res_query_mlp, ) residual_mlp = h_tilde else: h_mlp = h_tilde residual_mlp = h_tilde # ── MLP block ───────────────────────────────────────────────────── h_normed2 = _apply_norm(self.post_attention_layernorm, h_mlp) h_lns2 = self.lns_mlp(h_normed2) if self.use_lns else h_normed2 delta_m = self.mlp(h_lns2) # ── JTok-M injection (additive alongside MLP residual) ──────────── aux_stats = None # ── Residual junction (MLP sublayer) ───────────────────────────── # LAuReL augments the base MLP residual (residual_mlp + delta_m). # When JTok-M is active, its additive delta_r is summed on top of # the LAuReL output — JTok-M is orthogonal to the residual gate and # always contributes as a plain additive correction. if self.use_laurel: mlp_aug = self._laurel_residual( residual_mlp, delta_m, self.laurel_rw_mlp, self.laurel_lr_A_mlp, self.laurel_lr_B_mlp, slot="mlp", ) else: mlp_aug = residual_mlp + delta_m if self.use_jtokm and z_tilde is not None and B_vals is not None: orig_shape = h_tilde.shape h_flat = h_tilde.reshape(-1, self.hidden_size) z_flat = z_tilde.reshape(-1, z_tilde.shape[-1]) B_flat = B_vals.reshape(-1, B_vals.shape[-2], B_vals.shape[-1]) delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat) delta_r = delta_r.reshape(orig_shape) hidden_states = ( self.gpas_mlp(mlp_aug + delta_r) if self.use_gpas else mlp_aug + delta_r ) else: hidden_states = ( self.gpas_mlp(mlp_aug) if self.use_gpas else mlp_aug ) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) if aux_stats is not None: outputs += (aux_stats,) return outputs class SpellingBeeEmbedding(nn.Module): """ Spelling Bee Embeddings (Rabe et al., 2026, arXiv:2601.18030). Augments token embeddings with character-level information derived from the UTF-8 byte sequence of each token. The spelling bee embedding is the mean of the standard token embedding and a character-level summary: e_bee(t) = 0.5 * (e_tok(t) + e_chars(t)) e_chars(t) = inv_sqrt_len(t) * Σ_{i=0}^{15} RoPE(e_byte[b_i], i) where inv_sqrt_len = 1/√|t| is precomputed per token type at setup time. Key design decisions vs. a naïve per-occurrence implementation: 1. **Vocab-level computation** — e_chars is built over the full vocabulary once per forward (shape [V, d]), then gathered by token_ids. A naïve implementation would compute [B*S, 16, d] per step, repeating identical work for every occurrence of a frequent token. This approach reduces the dominant intermediate from O(B·S·16·d) to O(V·16·d), where V ≪ B·S in practice for most batches. 2. **Static [256, 16, d] rope_bytes table** — RoPE is applied once over all 256 possible byte values at all 16 positions, producing a table with fully static shapes. torch.compile / max_autotune can fuse the construction of this table (two elementwise ops + concat over fixed dims) into a single kernel. Token-level e_chars is then a gather + sum over this table, also fully static. 3. **Precomputed inv_sqrt_lens** — 1/√byte_len is computed once in set_byte_table and stored as a persistent buffer. The per-forward normalisation becomes a single elementwise multiply, with no sqrt or division in the hot path. Compatible with both the standard embed_tokens path and the LeviathanGenerator path. **Inference cost: zero overhead after baking.** Call ``bake_inference_table(token_embeds_weight)`` once after training to collapse the SBE into a single embedding table indistinguishable from a standard nn.Embedding lookup. **Setup: call ``set_byte_table(tokenizer)`` once after model init** (and before any .to(device) / FP8 conversion) before training. The byte table and inv_sqrt_lens are persistent buffers saved in checkpoints. References: Rabe, Clymo & Dong (2026). "Spelling Bee Embeddings for Language Modeling." arXiv:2601.18030. """ MAX_BYTES: int = 16 def __init__(self, config: "NeoLLMConfig"): super().__init__() d = config.hidden_size base = getattr(config, "rope_theta", 10000.0) # Guardado para poder recomputar los buffers RoPE en _reset_rope_buffers. self._rope_base = base # 256 × d byte embedding lookup (one per UTF-8 byte value 0..255). self.byte_emb = nn.Embedding(256, d) # ── Persistent buffers (saved in checkpoints) ───────────────────── # token_bytes [vocab_size, MAX_BYTES]: UTF-8 byte values per token, # padded with 0x00 up to MAX_BYTES positions. self.register_buffer( "token_bytes", torch.zeros(config.vocab_size, self.MAX_BYTES, dtype=torch.long), persistent=True, ) # inv_sqrt_lens [vocab_size]: precomputed 1/sqrt(byte_len) per token. # Replaces the runtime sqrt+division of the naïve implementation. self.register_buffer( "inv_sqrt_lens", torch.ones(config.vocab_size, dtype=torch.float), persistent=True, ) # LayerNorm over character embeddings — mirrors the reference impl # (character_norm=True by default in littletrainingloop). Runs in # float32 for stability, applied at vocab level before the batch gather. self.char_norm = nn.LayerNorm(d) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): """ Sobreescribe la carga de state_dict para eliminar los tres buffers non-persistent (intra_cos, intra_sin, pos_idx) antes de aplicar el state_dict, evitando que versiones anteriores del checkpoint (donde eran persistent=True) sobreescriban los valores correctos calculados en __init__ con valores corruptos del safetensors. from_pretrained de HuggingFace bypasea _register_load_state_dict_pre_hook y carga directamente por nombre, por lo que este override es necesario. """ for key in ("intra_cos", "intra_sin", "pos_idx"): state_dict.pop(prefix + key, None) super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ) def set_byte_table(self, tokenizer) -> None: """ Precompute the UTF-8 byte table and inv_sqrt_lens from a tokenizer. Must be called **once** after model instantiation and **before** ``.to(device)`` / FP8 conversion so the buffers land on the correct device after those transforms. Both buffers are persistent and will be saved/restored from checkpoints automatically. Args: tokenizer: Any HuggingFace tokenizer with ``convert_ids_to_tokens(int) -> str | None``. """ vocab_size = self.token_bytes.shape[0] byte_ids = torch.zeros(vocab_size, self.MAX_BYTES, dtype=torch.long) inv_sqrt = torch.ones(vocab_size, dtype=torch.float) # default 1/√1 for token_id in range(vocab_size): token_str = tokenizer.convert_ids_to_tokens(token_id) if token_str is None: continue # Some tokenizers use a special space character (Ġ / ▁); encode # directly to UTF-8 so byte values match raw text bytes. try: raw = token_str.encode("utf-8") except Exception: raw = b"\x00" n = min(len(raw), self.MAX_BYTES) for i in range(n): byte_ids[token_id, i] = raw[i] inv_sqrt[token_id] = 1.0 / math.sqrt(max(n, 1)) self.token_bytes.copy_(byte_ids.to(self.token_bytes.device)) self.inv_sqrt_lens.copy_(inv_sqrt.to(self.inv_sqrt_lens.device)) # ── Core helpers ────────────────────────────────────────────────────────── def _build_rope_bytes(self) -> torch.Tensor: """ Build the static [256, MAX_BYTES, d] RoPE-encoded byte table. For each of the 256 possible byte values and each of the MAX_BYTES intra-token positions, applies RoPE rotation using the current byte_emb.weight. All shapes are fully static, so torch.compile can fuse this into a single kernel. cos/sin se computan inline aquí en vez de usar buffers registrados. Con device_map + accelerate, from_pretrained materializa tensores del safetensors directamente — non-persistent buffers que no están en el checkpoint quedan como memoria sin inicializar. Computar inline elimina esa dependencia sin overhead apreciable ([16, d//2] es ínfimo). Returns: rope_bytes [256, MAX_BYTES, d] """ w = self.byte_emb.weight # [256, d] half = w.shape[-1] // 2 w1 = w[:, :half].unsqueeze(1) # [256, 1, half] w2 = w[:, half:].unsqueeze(1) # [256, 1, half] # Computar RoPE inline — formas estáticas, torch.compile lo fusiona. theta = 1.0 / (self._rope_base ** ( torch.arange(0, half, dtype=torch.float32, device=w.device) * 2.0 / (half * 2) )) pos = torch.arange(self.MAX_BYTES, dtype=torch.float32, device=w.device) freqs = torch.outer(pos, theta) # [MAX_BYTES, half] cos = freqs.cos().to(w.dtype).unsqueeze(0) # [1, MAX_BYTES, half] sin = freqs.sin().to(w.dtype).unsqueeze(0) # [1, MAX_BYTES, half] return torch.cat( [w1 * cos - w2 * sin, w1 * sin + w2 * cos], dim=-1, ) # [256, MAX_BYTES, d] # ── Forward ─────────────────────────────────────────────────────────────── def forward( self, token_ids: torch.Tensor, # [B, S] or [N] token_embeds: torch.Tensor, # [B, S, d] or [N, d] ) -> torch.Tensor: """ Args: token_ids: integer token indices to look up byte sequences. token_embeds: embeddings from embed_tokens or LeviathanGenerator. Returns: Spelling bee embeddings — same shape as token_embeds. """ # ── Step 1: build rope_bytes over 256 byte types × 16 positions ─── # Shape [256, MAX_BYTES, d] — fully static, one kernel via compile. rope_bytes = self._build_rope_bytes() # [256, MAX_BYTES, d] # ── Step 2: build e_chars over vocab types, not occurrences ──────── # token_bytes [V, MAX_BYTES]: byte value at each position per token. # pos_idx [MAX_BYTES]: column selector 0..MAX_BYTES-1. # rope_bytes[token_bytes, pos_idx[None, :], :] selects, for each # vocab token and each position, the RoPE-rotated embedding of that # byte at that position. Result [V, MAX_BYTES, d], then sum → [V, d]. e_chars_vocab = rope_bytes[ self.token_bytes, # [V, MAX_BYTES] — row index torch.arange(self.MAX_BYTES, device=rope_bytes.device), # [MAX_BYTES] — col index ].sum(1) # [V, d] # ── Step 3: apply precomputed 1/√byte_len per vocab type ──────────── # No sqrt or division in the hot path — pure multiply. e_chars_vocab = e_chars_vocab * self.inv_sqrt_lens.unsqueeze(-1) # [V, d] # ── Step 3b: LayerNorm over character embeddings (float32 for stability) # Mirrors reference impl (character_norm=True). Ensures E[‖e_chars‖²] # matches E[‖e_tok‖²] regardless of token byte-length distribution. e_chars_vocab = self.char_norm(e_chars_vocab.to(self.char_norm.weight.dtype)).to(token_embeds.dtype) # ── Step 4: gather only the tokens present in this batch ──────────── # This is the only B×S operation — a single embedding lookup. e_chars = e_chars_vocab[token_ids] # [B, S, d] or [N, d] # ── Step 5: mean with token embeddings ────────────────────────────── return (token_embeds + e_chars) * 0.5 # ── Inference utility ───────────────────────────────────────────────────── @torch.no_grad() def bake_inference_table( self, token_emb_weight: torch.Tensor, ) -> torch.Tensor: """ Collapse SBE into a single [vocab_size, d] embedding table. After baking, the SBE computation is indistinguishable from a standard nn.Embedding lookup — zero additional overhead at inference time. Args: token_emb_weight: [vocab_size, d] — weight matrix of embed_tokens or the equivalent table (e.g. after Leviathan). Returns: [vocab_size, d] — baked spelling bee embedding table. Usage:: baked = model.model.spelling_bee.bake_inference_table( model.model.embed_tokens.weight ) model.model.embed_tokens.weight.copy_(baked) # Optionally free byte_emb parameters: # del model.model.spelling_bee """ rope_bytes = self._build_rope_bytes() # [256, MAX_BYTES, d] e_chars_vocab = rope_bytes[ self.token_bytes, torch.arange(self.MAX_BYTES, device=rope_bytes.device), ].sum(1) * self.inv_sqrt_lens.unsqueeze(-1) # [V, d] e_chars_vocab = self.char_norm(e_chars_vocab.to(self.char_norm.weight.dtype)).to(token_emb_weight.dtype) return (token_emb_weight + e_chars_vocab) * 0.5 class NeoLLMPreTrainedModel(PreTrainedModel): """ Base class with custom weight initialization for all NeoLLM components. LeviathanGenerator (real per-head architecture): - codebooks: normal(0, initializer_range) - head_proj[i]: normal(0, initializer_range) — standard for linear - head_norm[i]: weight=1, bias=0 — default LayerNorm init - head_scale: filled with (num_knots - 1) — matches ckhronos.py exactly: scale initialized to stretch the knot grid uniformly so d = |x - grid| * (num_knots-1) maps the [0,1] input to [0, num_knots-1] at init. - head_spline_delta: normal(mean=0.0, std=0.1). The effective coefficient is (1 + delta), matching the Leviathan reference parameterization and keeping the product across d_seed dimensions near 1 at step 0. - head_out[i]: normal(0, initializer_range / sqrt(num_modes)) — scaled by 1/√M so the sum of M head outputs starts with the same variance as a single head projection. - seed_proj_weight/bias: normal(0, initializer_range), bias=0 — JTok-M shared path No W_res — confirmed absent in the authors' implementation. LeviathanJTokM: - spline_coeff: normal(mean=1.0, std=0.1) — same as generator - W_out: normal(0, initializer_range) - W_res: normal(0, initializer_range) - router: parent handles (normal init) - scaler: ones (identity at init) NeoLLMAttention (Affine-Scaled Attention): - alpha_proj: normal(0, 0.02) — near-zero so linear_clipping(≈0) ≈ 0.5 at init, giving a mild ~0.5× scaling of softmax weights per head rather than collapsing to 0 or 1. - alpha_ma: zeros — running EMA starts at 0, β starts as −α/N ≈ small negative offset; model quickly learns to adjust both. REPOModule (Context Re-Positioning): - W_g, W_c, W_z: default normal init from parent _init_weights. No special initialization required — the SwiGLU sub-layer starts near-zero, so z_i ≈ 0 for all tokens at step 0, which is equivalent to constant position assignment (NoPE-like). The model quickly learns to differentiate positions as needed. """ config: NeoLLMConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["NeoLLMDecoderLayer"] _supports_attention_backend = True _supports_flash_attn = True _supports_flash_attn_2 = True _supports_sdpa = True _is_stateful = True def _init_weights(self, module): super()._init_weights(module) if isinstance(module, NeoLLMAttention): if getattr(module, "use_fan_residual", False): if hasattr(module, "lambda_1") and module.lambda_1 is not None: module.lambda_1.data.fill_(0.5) if hasattr(module, "lambda_2") and module.lambda_2 is not None: module.lambda_2.data.fill_(0.5) if hasattr(module, "mea_key_mix") and module.mea_key_mix is not None: # Identity initialization: at step 0 MEA behaves as standard attention # and all matrix entries receive gradient immediately from the first step. # For square matrices (normal training case) this is exact identity. # For rectangular matrices (KV compression, h' Tuple: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) if return_dict is None: cfg_dict = vars(self.config) return_dict = cfg_dict.get( "return_dict", cfg_dict.get("use_return_dict", True), ) if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("Specify exactly one of input_ids or inputs_embeds") # ── Embedding stage ──────────────────────────────────────────────── z_tilde = None B_vals = None if inputs_embeds is None: if self.config.use_token_generator: if self.config.use_jtokm: # Return internals for reuse by JTok-M surfaces inputs_embeds, z_tilde, B_vals = self.token_generator( input_ids, return_internals=True) # Reshape to [batch, seq, d_seed] and [batch, seq, d_seed, n_knots] z_tilde = z_tilde.reshape(*input_ids.shape, self.config.generator_d_seed) B_vals = B_vals.reshape( *input_ids.shape, self.config.generator_d_seed, self.config.generator_num_knots, ) else: inputs_embeds = self.token_generator(input_ids) else: inputs_embeds = self.embed_tokens(input_ids) # ── Spelling Bee Embeddings (applied post-embedding, pre-decoder) ────── # input_ids may be None when inputs_embeds was passed directly by the # caller; in that case SBE cannot run (no token_ids available) and is # silently skipped — consistent with the standard embedding bypass path. if self.spelling_bee is not None and input_ids is not None: inputs_embeds = self.spelling_bee(input_ids, inputs_embeds) if position_ids is None: position_ids = torch.arange( 0, inputs_embeds.shape[1], device=inputs_embeds.device ).unsqueeze(0) causal_mask = create_causal_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=None, position_ids=position_ids, ) hidden_states = inputs_embeds use_siamesenorm = bool(getattr(self.config, "use_siamesenorm", False)) siamese_y_states = inputs_embeds if use_siamesenorm else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_aux_stats = [] # ── StackMemory state ───────────────────────────────────────────── # Same source-level idea as StackTrans: # memory = self.memory.detach()[:b] # memory_mask = self.memory_mask.detach()[:b] # Here the base buffer has batch size 1 and is expanded locally so no # module attribute is created or reassigned inside torch.compile. use_stack_memory = getattr(self.config, "use_stack_memory", False) if use_stack_memory: batch_size = hidden_states.shape[0] stack_memory = self.memory.detach().to(dtype=hidden_states.dtype).expand( batch_size, -1, -1, -1 ) stack_memory_mask = self.memory_mask.detach().expand(batch_size, -1, -1) else: stack_memory = None stack_memory_mask = None position_embeddings = self.rotary_emb(hidden_states, position_ids) self.first_layer_fan = None if getattr(self.config, "use_fan_residual", True) else False # ── REPO: pass inv_freq by reference at forward time ────────────────── # rotary_emb.inv_freq is already on the correct device (managed by # NeoLLMRotaryEmbedding as a buffer) — no .to(), no DeviceCopy op. # Computed once here and passed through the decoder layer chain so # NeoLLMAttention never needs to store it as a buffer itself, avoiding # the meta-tensor issue that occurs when lm_eval calls .to(device). # # Extended for IHA seq-expand (P>1): when use_iha=True and P>1, the # attention layer also needs inv_freq to build interleaved RoPE positions # (s·P + p) for the expanded virtual sequence — even on layers that do not # use REPO. Passing repo_rope_args unconditionally in this case incurs # only a small reference pass; it is ignored by layers where neither REPO # nor IHA seq-expand is active. _iha_needs_inv_freq = ( getattr(self.config, "use_iha", False) and int(getattr(self.config, "iha_num_pseudo_heads", 1)) > 1 ) _repo_grape_needs_inv_freq = bool(getattr(self.config, "use_repo_grape", False)) repo_rope_args = ( (self.rotary_emb.inv_freq, self.rotary_emb.attention_scaling) if (getattr(self.config, "use_repo", False) or _repo_grape_needs_inv_freq or _iha_needs_inv_freq) else None ) # ── Attention Residuals state ────────────────────────────────────── # Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per # decoder layer — all previous outputs are kept, max N=num_layers+1. # Block AttnRes (attn_res_num_blocks>0): sources grows by one entry per # block boundary — at most num_blocks+1 entries, far less memory. # In both modes, attn_res_partial is the current intra-block accumulated # hidden state that connects the attn and MLP sublayers and flows between # decoder layers within a block. use_attn_res = getattr(self.config, 'use_attn_res', False) attn_res_sources = None attn_res_partial = None if use_attn_res: attn_res_sources = [hidden_states] # b_0 = token embedding attn_res_partial = hidden_states # initial partial sum num_blocks = getattr(self.config, 'attn_res_num_blocks', 0) block_size = ( max(self.config.num_hidden_layers // num_blocks, 1) if num_blocks > 0 else 1 # Full AttnRes: every layer is its own "block" ) for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # ── Block AttnRes: boundary handling ────────────────────────── # At each block boundary (excluding layer 0): append the current # partial sum to sources as a completed block summary, then reset # partial to None so the new block builds from scratch — matching # the paper's pseudocode exactly. # For Full AttnRes (block_size=1): every layer is a boundary, so # partial is appended and reset after every layer. The partial is # re-seeded from the previous hidden_states below. if use_attn_res and layer_idx > 0 and layer_idx % block_size == 0: attn_res_sources = attn_res_sources + [attn_res_partial] attn_res_partial = hidden_states # start new block from current output if use_stack_memory: hidden_states, stack_memory, stack_memory_mask = decoder_layer.apply_stack_memory( hidden_states, stack_memory, stack_memory_mask ) if use_siamesenorm: layer_outputs = decoder_layer.forward_siamesenorm( hidden_states, siamese_y_states, position_embeddings=position_embeddings, attention_mask=causal_mask, first_layer_fan=self.first_layer_fan, z_tilde=z_tilde, B_vals=B_vals, output_attentions=output_attentions, repo_rope_args=repo_rope_args, position_ids=position_ids, **kwargs, ) hidden_states = layer_outputs[0] siamese_y_states = layer_outputs[1] extras_start = 2 else: layer_outputs = decoder_layer( hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask, first_layer_fan=self.first_layer_fan, z_tilde=z_tilde, B_vals=B_vals, attn_res_sources=attn_res_sources, attn_res_partial=attn_res_partial if use_attn_res else None, output_attentions=output_attentions, repo_rope_args=repo_rope_args, position_ids=position_ids, **kwargs, ) hidden_states = layer_outputs[0] extras_start = 1 # Update AttnRes partial sum — the new partial is the layer output if use_attn_res: attn_res_partial = hidden_states if output_attentions: all_attentions = all_attentions + (layer_outputs[extras_start],) extras_start += 1 # Collect JTok-M aux stats. if self.config.use_jtokm: for item in layer_outputs[extras_start:]: if isinstance(item, tuple) and len(item) == 3: all_aux_stats.append(item) break if (getattr(self.config, "use_fan_residual", True) and self.first_layer_fan is None and hasattr(decoder_layer, "current_layer_fan")): self.first_layer_fan = decoder_layer.current_layer_fan if use_siamesenorm: hidden_states = self.siamese_final_norm( self.siamese_x_final_norm(hidden_states) + self.siamese_y_final_norm(siamese_y_states) ) else: hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( v for v in [hidden_states, None, all_hidden_states, all_attentions] if v is not None ) + (all_aux_stats,) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=None, hidden_states=all_hidden_states, attentions=all_attentions, ), all_aux_stats @torch.compiler.disable def compute_cce_loss( hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None ): """CCE loss excluded from torch.compile.""" processed_labels = labels.to(hidden_states.device) if pad_token_id is not None: processed_labels = torch.where( processed_labels == pad_token_id, torch.tensor(-100, dtype=processed_labels.dtype, device=processed_labels.device), processed_labels, ) return linear_cross_entropy( hidden_states, lm_head_weight, processed_labels, bias=lm_head_bias, shift=1, impl="cce_kahan_full_c", reduction="mean", ) class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin): """ Causal LM with NeoLLM backbone. When use_jtokm=True, the load-balancing auxiliary loss is computed from the per-layer JTok-M routing statistics and added to the cross-entropy loss: total_loss = CE_loss + L_aux where L_aux = λ · n_e · (1/L) · Σ_ℓ Σ_i p_i^ℓ · f_i^ℓ """ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: NeoLLMConfig): super().__init__(config) self.model = NeoLLMModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) if config.use_token_generator: self._tied_weights_keys = {} self.post_init() def get_input_embeddings(self): return self.model.get_input_embeddings() def set_input_embeddings(self, value): self.model.set_input_embeddings(value) def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, past_key_values=None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, **kwargs, ) -> dict: """ NeoLLM does not implement KV caching (always returns past_key_values=None). Transformers' default GenerationMixin.prepare_inputs_for_generation assumes KV cache is active and slices input_ids to only the newest token on every step past the prefill. Without a real cache that retains previous K/V states, the attention module only sees 1 key/value pair while the causal mask still spans the full context length — causing a shape mismatch in SDPA. This override always forwards the COMPLETE input_ids sequence so the model can recompute attention over the full context from scratch at every step. Generation is therefore slower (no caching benefit) but numerically correct. """ model_inputs: dict = {"input_ids": input_ids, "attention_mask": attention_mask} if inputs_embeds is not None and past_key_values is None: model_inputs["inputs_embeds"] = inputs_embeds return model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: model_out = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) # Unpack: model returns (BaseModelOutputWithPast, aux_stats_list) # or a tuple when return_dict=False if isinstance(model_out, tuple): outputs, all_aux_stats = model_out[0], model_out[-1] if isinstance(outputs, tuple): hidden_states = outputs[0] else: hidden_states = outputs.last_hidden_state else: outputs = model_out all_aux_stats = [] hidden_states = outputs.last_hidden_state loss = None if labels is not None: loss = compute_cce_loss( hidden_states, labels, self.lm_head.weight, getattr(self.lm_head, "bias", None), self.config.pad_token_id, ) # Add JTok-M load-balancing auxiliary loss if self.config.use_jtokm and all_aux_stats: aux_loss = compute_jtokm_aux_loss( all_aux_stats, n_e=self.config.jtokm_num_experts, weight=self.config.jtokm_aux_loss_weight, ) loss = loss + aux_loss logits = None else: slice_indices = ( slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep ) logits = self.lm_head(hidden_states[:, slice_indices, :]) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=None, hidden_states=outputs.hidden_states if hasattr(outputs, "hidden_states") else None, attentions=outputs.attentions if hasattr(outputs, "attentions") else None, ) # ==================== AUTOMODEL REGISTRATION ==================== __all__ = [ "NeoLLMForCausalLM", "NeoLLMModel", "NeoLLMPreTrainedModel", "NeoLLMConfig", "LeviathanGenerator", "LeviathanJTokM", "SpellingBeeEmbedding", "StackMemory", "FANLayer", "ScalarMultiplier", "VectorMultiplier", "LinearWithMultipliers", "MEAHeadRMSNorm", "HadamardOProj", "REPOModule", "RepoGrapePositioning", "RepoGoatPrior", ] AutoConfig.register("neollm", NeoLLMConfig) AutoModel.register(NeoLLMConfig, NeoLLMModel) AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)