"""Custom ANEMLL conversion that takes inputs_embeds instead of input_ids. Required for Mega-ASR: at inference we scatter audio encoder outputs at <|audio_pad|> positions BEFORE the LLM, then feed pre-embedded hidden_states to the decoder. The default ANEMLL conversion has embed_tokens baked in (takes input_ids); we need it bypassed. This script: 1. Loads QwenForCausalLM via ANEMLL's loader 2. Monkey-patches QwenModel.forward to accept an optional inputs_embeds arg 3. Defines a fresh Wrapper that exposes inputs_embeds as the first input 4. Traces + converts via ct.convert with LUT-4 palettization in postprocess 5. Saves the resulting .mlpackage Reuses ANEMLL's QwenConverter postprocessing (LUT-4 quantization, state declarations) by calling its methods after the inputs are swapped. """ from __future__ import annotations import argparse import os import sys from pathlib import Path sys.path.insert(0, "/tmp/Anemll") import numpy as np import torch import torch.nn as nn import coremltools as ct import coremltools.optimize as cto # Apply the local coremltools _cast patch we made earlier (now resident in the # env's installed file; nothing to do here, just import). def patch_qwen_for_inputs_embeds(): """Monkey-patch QwenModel.forward + QwenForCausalLM.forward to accept inputs_embeds. When the caller passes a float tensor in the input_ids slot, treat it as pre-embedded hidden_states and skip embed_tokens. Also relax the strict 2D shape assert in QwenForCausalLM. """ from anemll.models import qwen_model as qm orig_model_forward = qm.QwenModel.forward def model_forward_or_embeds( self, input_ids, causal_mask, position_ids, current_pos, IN_PREFILL: bool = False, ): if input_ids.dtype in (torch.float16, torch.float32, torch.bfloat16): hidden_states = input_ids if IN_PREFILL: rotary_emb = self.get_rotary_embedding_prefill(position_ids) else: rotary_emb = self.get_rotary_embeddings_s(current_pos) hidden_states = self.process_layers( hidden_states, position_ids, causal_mask, current_pos, rotary_emb, start_layer=0, end_layer=None, IN_PREFILL=IN_PREFILL, ) hidden_states = self.norm(hidden_states) return hidden_states return orig_model_forward(self, input_ids, causal_mask, position_ids, current_pos, IN_PREFILL=IN_PREFILL) qm.QwenModel.forward = model_forward_or_embeds # Also patch QwenForCausalLM.forward — it asserts input_ids must be 2D # (line 1050 in qwen_model.py). For inputs_embeds (3D), skip that. orig_causal_forward = qm.QwenForCausalLM.forward def causal_forward_or_embeds( self, input_ids, update_mask, position_ids, causal_mask, current_pos, IN_PREFILL: bool = False, ): if input_ids.dtype in (torch.float16, torch.float32, torch.bfloat16): # Pre-embedded path — call QwenModel directly, bypass the 2D assert hidden_states = self.model( input_ids, causal_mask, position_ids, current_pos, IN_PREFILL=IN_PREFILL, ) # Replicate the lm-head projection logic from the original forward # (single-token decode case) if not IN_PREFILL and current_pos is not None: seq_len = hidden_states.shape[1] if seq_len == 1: pos_tensor = torch.tensor([0], device=hidden_states.device, dtype=torch.long) else: if isinstance(current_pos, torch.Tensor): pos_tensor = current_pos if current_pos.dim() > 0 else current_pos.unsqueeze(0) else: pos_tensor = torch.tensor([current_pos], device=hidden_states.device, dtype=torch.long) hidden_states = torch.index_select(hidden_states, dim=1, index=pos_tensor) # Use the same Conv2d / 16-way split as the original hidden_states = hidden_states.permute(0, 2, 1).unsqueeze(2).to(qm.MODEL_DTYPE) outs = tuple( getattr(self, f"lm_head16_{k}")(hidden_states).squeeze(2).transpose(1, 2) for k in range(1, 17) ) return outs return orig_causal_forward( self, input_ids, update_mask, position_ids, causal_mask, current_pos, IN_PREFILL=IN_PREFILL, ) qm.QwenForCausalLM.forward = causal_forward_or_embeds print("[patch] QwenModel + QwenForCausalLM now accept float inputs_embeds") def main(): ap = argparse.ArgumentParser() ap.add_argument("--model", required=True, type=Path) ap.add_argument("--output", required=True, type=Path, help="Output .mlpackage path") ap.add_argument("--lut", type=int, default=4) ap.add_argument("--per-channel", type=int, default=8) ap.add_argument("--context-length", type=int, default=512) ap.add_argument("--hidden-size", type=int, default=2048) args = ap.parse_args() patch_qwen_for_inputs_embeds() from anemll.models.qwen_model import ( QwenForCausalLM, QwenConfig, MODEL_DTYPE, TEST_DEVICE, ) from anemll.ane_converter import qwen_converter as qc # Force CoreML mode flags import anemll.models.qwen_model as qm qm.ENABLE_COREML = True # Load config + model import json cfg = json.load(open(args.model / "config.json")) cfg["context_length"] = args.context_length cfg["state_length"] = args.context_length config = QwenConfig(**cfg) model = QwenForCausalLM(config, enable_coreml=True) model.load_pretrained_weights(str(args.model)) model.eval() for p in model.parameters(): p.requires_grad = False print(f"Model loaded: hidden={config.hidden_size}, layers={config.num_hidden_layers}") # Custom wrapper taking inputs_embeds class WrapperEmbeds(torch.nn.Module): def __init__(self, model): super().__init__() self.model = model def forward(self, inputs_embeds, position_ids, causal_mask, current_pos, update_mask): return self.model( input_ids=inputs_embeds, # float tensor → triggers the patched path update_mask=update_mask, position_ids=position_ids, causal_mask=causal_mask, current_pos=current_pos, IN_PREFILL=False, ) wrapper = WrapperEmbeds(model).eval() # Build sample inputs for tracing sample_inputs_embeds = torch.zeros( (1, 1, config.hidden_size), dtype=torch.float16, device=TEST_DEVICE, ) sample_position_ids = torch.zeros((1,), dtype=torch.int32, device=TEST_DEVICE) sample_causal_mask = torch.zeros( (1, 1, 1, args.context_length), dtype=torch.float16, device=TEST_DEVICE, ) sample_current_pos = torch.zeros((1,), dtype=torch.int32, device=TEST_DEVICE) sample_update_mask = torch.zeros( (1, 1, args.context_length, 1), dtype=torch.float16, device=TEST_DEVICE, ) print("Tracing ...") traced = torch.jit.trace( wrapper, (sample_inputs_embeds, sample_position_ids, sample_causal_mask, sample_current_pos, sample_update_mask), ) print("Trace done. Converting to CoreML (fp16) ...") # ANEMLL declares the KV cache as a state via GetTransformerStates states = qc.QwenConverter.GetTransformerStates(model, prefix="model.model.") mlmodel = ct.convert( traced, inputs=[ ct.TensorType(name="inputs_embeds", shape=sample_inputs_embeds.shape, dtype=np.float16), ct.TensorType(name="position_ids", shape=sample_position_ids.shape, dtype=np.int32), ct.TensorType(name="causal_mask", shape=sample_causal_mask.shape, dtype=np.float16), ct.TensorType(name="current_pos", shape=sample_current_pos.shape, dtype=np.int32), ct.TensorType(name="update_mask", shape=sample_update_mask.shape, dtype=np.float16), ], outputs=[ct.TensorType(name=f"logits{i+1}", dtype=np.float16) for i in range(16)], states=states, minimum_deployment_target=ct.target.iOS18, # fp32 compute (activations) — fp16 overflows in Qwen3-ASR's RMSNorm/attention. # Matches aoiandroid's finding for the same base model. compute_precision=ct.precision.FLOAT32, compute_units=ct.ComputeUnit.CPU_AND_NE, convert_to="mlprogram", skip_model_load=True, ) if args.lut and args.lut < 16: print(f"Applying LUT-{args.lut} palettization (per_channel={args.per_channel}) ...") config_palette = cto.coreml.OpPalettizerConfig( nbits=args.lut, mode="kmeans", granularity="per_grouped_channel", group_size=args.per_channel, ) pal_config = cto.coreml.OptimizationConfig(global_config=config_palette) mlmodel = cto.coreml.palettize_weights(mlmodel, pal_config) args.output.parent.mkdir(parents=True, exist_ok=True) mlmodel.save(str(args.output)) print(f"Saved: {args.output}") if __name__ == "__main__": main()