| import os |
|
|
| |
# Run mode: "serve" when JAX reports a GPU device, otherwise "template".
# An explicit SPACE_MODE environment variable skips auto-detection.
SPACE_MODE = os.getenv("SPACE_MODE")
if SPACE_MODE is None:
    try:
        import jax
        SPACE_MODE = "serve" if any(getattr(d, "platform", "") in ("gpu","cuda","rocm") for d in jax.devices()) else "template"
    except Exception:
        # Any failure to import JAX or list devices is treated as "no GPU".
        SPACE_MODE = "template"

# NOTE(review): when SPACE_MODE is auto-detected, jax.devices() above has already
# initialized JAX's backends, so JAX_PLATFORMS / XLA_FLAGS set below may not take
# effect (JAX generally reads them when backends are initialized). Setting
# SPACE_MODE explicitly avoids the early probe — confirm against the deployed JAX.
if SPACE_MODE != "serve":
    # Every mode other than "serve" is pinned to CPU.
    os.environ.setdefault("JAX_PLATFORMS", "cpu")
else:
    # GPU compiler tuning flags; setdefault keeps any XLA_FLAGS already provided.
    os.environ.setdefault(
        "XLA_FLAGS",
        " ".join([
            "--xla_gpu_enable_triton_gemm=true",
            "--xla_gpu_enable_latency_hiding_scheduler=true",
            "--xla_gpu_autotune_level=2",
        ])
    )

# Directory for the persistent JAX compilation cache; read back from the
# environment when the cache is initialized below.
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
|
import jax

# Request TF32 matmul precision; fall back to "high" if this JAX version
# rejects the "tensorfloat32" name.
try:
    jax.config.update("jax_default_matmul_precision", "tensorfloat32")
except Exception:
    jax.config.update("jax_default_matmul_precision", "high")

# Persistent XLA compilation cache so compiled executables survive restarts.
# Current JAX configures it through the `jax_compilation_cache_dir` option; the
# legacy compilation_cache.initialize_cache() entry point is deprecated, so it is
# only used as a fallback for older JAX. A failure is non-fatal but is logged
# rather than silently ignored.
try:
    jax.config.update("jax_compilation_cache_dir", os.environ["JAX_CACHE_DIR"])
except Exception:
    try:
        from jax.experimental.compilation_cache import compilation_cache as cc
        cc.initialize_cache(os.environ["JAX_CACHE_DIR"])
    except Exception as _cache_err:
        import logging
        logging.getLogger(__name__).warning(
            "JAX compilation cache not enabled: %s", _cache_err
        )
| |
|
|
|
|
|
|
| from magenta_rt import system, audio as au |
| import numpy as np |
| from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect, Query |
| from fastapi.responses import JSONResponse, FileResponse, HTMLResponse |
| import tempfile, io, base64, math, threading |
| from fastapi.middleware.cors import CORSMiddleware |
| from contextlib import contextmanager |
| import soundfile as sf |
| from math import gcd |
| from scipy.signal import resample_poly |
| from utils import ( |
| match_loudness_to_reference, stitch_generated, hard_trim_seconds, |
| apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail, |
| resample_and_snap, wav_bytes_base64 |
| ) |
|
|
| from jam_worker import JamWorker, JamParams, JamChunk |
| from one_shot_generation import generate_loop_continuation_with_mrt, generate_style_only_with_mrt |
|
|
| import uuid, threading |
|
|
| import logging |
|
|
| import gradio as gr |
| from typing import Optional, Union, Literal |
|
|
|
|
| import json, asyncio, base64 |
| import time |
|
|
|
|
|
|
| from starlette.websockets import WebSocketState |
try:
    # Exception type caught by send_json_safe on failed sends (presumably raised by
    # uvicorn when the client has gone away — confirm against the uvicorn version).
    from uvicorn.protocols.utils import ClientDisconnected
except Exception:
    # Fallback so `except ClientDisconnected` clauses still work when this uvicorn
    # version does not provide the class.
    class ClientDisconnected(Exception):
        pass
|
|
| import re, tarfile |
| from pathlib import Path |
| from huggingface_hub import snapshot_download, HfApi |
|
|
| from pydantic import BaseModel |
|
|
| from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect |
|
|
| def _gpu_probe() -> dict: |
| """ |
| Returns: |
| { |
| "ok": bool, |
| "backend": str | None, # "gpu" | "cpu" | "tpu" | None |
| "has_gpu": bool, |
| "devices": list[str], # e.g. ["gpu:0", "gpu:1"] |
| "error": str | None, |
| } |
| """ |
| try: |
| import jax |
| try: |
| backend = jax.default_backend() |
| except Exception: |
| from jax.lib import xla_bridge |
| backend = getattr(xla_bridge.get_backend(), "platform", None) |
|
|
| try: |
| devices = jax.devices() |
| has_gpu = any(getattr(d, "platform", "") in ("gpu", "cuda", "rocm") for d in devices) |
| dev_list = [f"{getattr(d, 'platform', '?')}:{getattr(d, 'id', '?')}" for d in devices] |
| return {"ok": True, "backend": backend, "has_gpu": has_gpu, "devices": dev_list, "error": None} |
| except Exception as e: |
| return {"ok": False, "backend": backend, "has_gpu": False, "devices": [], "error": f"jax.devices failed: {e}"} |
| except Exception as e: |
| return {"ok": False, "backend": None, "has_gpu": False, "devices": [], "error": f"jax import failed: {e}"} |
|
|
| |
| |
# Module-level mirrors of the fine-tune style assets owned by `asset_manager`.
# Refreshed by _sync_assets_globals_from_manager(); read by build_style_vector
# and the /model/* endpoints.
_ASSETS_REPO_ID: str | None = None
_MEAN_EMBED: np.ndarray | None = None
_CENTROIDS: np.ndarray | None = None

# Shared managers. asset_manager holds the fine-tune assets; model_selector
# validates /model/select requests and computes the env changes they imply.
asset_manager = AssetManager()
model_selector = ModelSelector(CheckpointManager(), asset_manager)
|
|
def _sync_assets_globals_from_manager():
    """Copy the asset manager's mean embedding, centroids and repo id into the module globals."""
    global _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
    mgr = asset_manager
    _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID = mgr.mean_embed, mgr.centroids, mgr.assets_repo_id
|
|
def _any_jam_running() -> bool:
    """Return True if at least one registered jam worker thread is still alive."""
    with jam_lock:
        for entry in jam_registry.values():
            if entry['worker'].is_alive():
                return True
        return False
|
|
def _stop_all_jams(timeout: float = 5.0):
    """
    Stop every registered jam worker and return its MRT slot to the pool.

    Workers are signalled and joined *outside* jam_lock, so a slow shutdown does
    not block the other jam endpoints. A worker that is still alive after
    `timeout` seconds keeps its registry entry and its slot, as /jam/stop does.
    This way the slot is never handed to a new session while the old worker may
    still be using that MRT.

    Args:
        timeout: seconds to wait for each worker thread to exit.
    """
    with jam_lock:
        sessions = list(jam_registry.items())

    # Signal everyone first so workers wind down in parallel, then join each.
    for _sid, info in sessions:
        if info['worker'].is_alive():
            info['worker'].stop()

    for sid, info in sessions:
        w = info['worker']
        if w.is_alive():
            w.join(timeout=timeout)
        if w.is_alive():
            logging.warning(
                "JamWorker %s did not stop within %.1fs; keeping its MRT slot reserved",
                sid, timeout,
            )
            continue
        with jam_lock:
            removed = jam_registry.pop(sid, None)
        # Only the caller that actually removed the entry releases the slot,
        # so a concurrent stop cannot release it twice.
        if removed is not None and removed.get('mrt_index') is not None:
            release_mrt(removed['mrt_index'])
|
|
|
|
async def send_json_safe(ws: WebSocket, obj) -> bool:
    """
    Serialize `obj` to JSON and send it as a text frame.

    Returns True on success. Returns False if either side of the socket is
    already disconnected, or if serialization or the send raises for any reason.
    """
    gone = WebSocketState.DISCONNECTED
    if gone in (ws.client_state, ws.application_state):
        return False
    try:
        await ws.send_text(json.dumps(obj))
    except Exception:
        # Covers WebSocketDisconnect, ClientDisconnected and RuntimeError
        # (send after close) as well as any other failure.
        return False
    return True
|
|
| |
| def _patch_t5x_for_gpu_coords(): |
| try: |
| import jax |
| from t5x import partitioning as _t5x_part |
|
|
| old_bounds = getattr(_t5x_part, "bounds_from_last_device", None) |
| old_getcoords = getattr(_t5x_part, "get_coords", None) |
|
|
| def _bounds_from_last_device_gpu_safe(last_device): |
| |
| core = getattr(last_device, "core_on_chip", None) |
| coords = getattr(last_device, "coords", None) |
| if coords is not None and core is not None: |
| x, y, z = coords |
| return x + 1, y + 1, z + 1, core + 1 |
| |
| return jax.host_count(), jax.local_device_count() |
|
|
| def _get_coords_gpu_safe(device): |
| core = getattr(device, "core_on_chip", None) |
| coords = getattr(device, "coords", None) |
| if coords is not None and core is not None: |
| return (*coords, core) |
| |
| return (device.process_index, device.id % jax.local_device_count()) |
|
|
| _t5x_part.bounds_from_last_device = _bounds_from_last_device_gpu_safe |
| _t5x_part.get_coords = _get_coords_gpu_safe |
| import logging; logging.info("Patched t5x.partitioning for GPU coords without core_on_chip.") |
| except Exception as e: |
| import logging; logging.exception("t5x GPU-coords patch failed: %s", e) |
|
|
| |
# Apply the t5x patch at import time; models are only constructed later, lazily.
_patch_t5x_for_gpu_coords()

# Active jam sessions: session_id -> {'worker': JamWorker, 'mrt_index': int}.
# Reads and writes go through jam_lock.
jam_registry: dict[str, dict] = {}
jam_lock = threading.Lock()
|
|
| |
| |
| |
|
|
class GlobalGenParams:
    """Lock-protected holder for the default temperature, topk and guidance_weight.

    Pool MRT instances are configured from these values when they are built,
    so already-built instances only change after a pool restart.
    """

    def __init__(self):
        self._lock = threading.RLock()
        self.temperature = 1.1
        self.topk = 40
        self.guidance_weight = 1.1

    def get(self):
        """Return a snapshot dict of the current defaults."""
        with self._lock:
            return dict(
                temperature=self.temperature,
                topk=self.topk,
                guidance_weight=self.guidance_weight,
            )

    def update(self, temperature=None, topk=None, guidance_weight=None):
        """Overwrite each argument that is not None and return the new snapshot.

        Values are coerced with float / int / float respectively. Pool MRTs
        only pick the change up after a pool restart.
        """
        fields = (
            ('temperature', temperature, float),
            ('topk', topk, int),
            ('guidance_weight', guidance_weight, float),
        )
        with self._lock:
            for attr, value, cast in fields:
                if value is not None:
                    setattr(self, attr, cast(value))
            return self.get()
|
|
# Process-wide default sampling parameters.
_GLOBAL_GEN_PARAMS = GlobalGenParams()

# Pool of MagentaRT instances used by /generate, /generate_style and /jam/*.
# _MRT_AVAILABLE[i] is True while _MRT_POOL[i] is free. Slot checkout and release
# go through _MRT_POOL_LOCK. _POOL_INIT_LOCK serializes building the pool.
_MRT_POOL = []
_MRT_POOL_LOCK = threading.Lock()
_MRT_AVAILABLE = []
_POOL_INITIALIZED = False
_POOL_INIT_LOCK = threading.Lock()
|
|
def init_mrt_pool(pool_size=2):
    """
    (Re)build the MRT pool with `pool_size` MagentaRT instances configured from
    the current global sampling defaults.

    All new instances are constructed first and only then swapped into
    _MRT_POOL / _MRT_AVAILABLE. A construction failure therefore leaves the
    previous pool intact instead of half-cleared. Locking is the caller's job
    (see ensure_pool_initialized / reset_mrt_pool).

    Args:
        pool_size: number of instances, i.e. concurrent generation slots.
    """
    defaults = _GLOBAL_GEN_PARAMS.get()
    # Loop-invariant: every instance uses the same checkpoint and size tag.
    ckpt_dir = CheckpointManager.resolve_checkpoint_dir()
    tag = os.getenv("MRT_SIZE", "large")

    new_pool = []
    for _ in range(pool_size):
        mrt = system.MagentaRT(
            tag=tag,
            guidance_weight=defaults['guidance_weight'],
            device="gpu",
            checkpoint_dir=ckpt_dir,
            lazy=True,
        )
        mrt.temperature = defaults['temperature']
        mrt.topk = defaults['topk']
        new_pool.append(mrt)

    # Load fine-tune assets once if nothing has loaded them yet.
    if new_pool and asset_manager.mean_embed is None and asset_manager.centroids is None:
        repo = os.getenv("MRT_ASSETS_REPO") or os.getenv("MRT_CKPT_REPO")
        if repo:
            asset_manager.load_finetune_assets_from_hf(repo, None)
            _sync_assets_globals_from_manager()

    # Slice assignment keeps the same list objects that the rest of the module references.
    _MRT_POOL[:] = new_pool
    _MRT_AVAILABLE[:] = [True] * len(new_pool)
|
|
def ensure_pool_initialized():
    """Build the MRT pool on first use; later calls return immediately.

    The flag is re-checked under _POOL_INIT_LOCK so concurrent first requests
    build the pool only once.
    """
    global _POOL_INITIALIZED
    if _POOL_INITIALIZED:
        return
    with _POOL_INIT_LOCK:
        if _POOL_INITIALIZED:
            return
        init_mrt_pool(pool_size=2)
        _POOL_INITIALIZED = True
|
|
def get_available_mrt():
    """
    Claim the lowest-numbered free MRT slot.

    Returns (index, mrt) with the slot marked busy, or (None, None) when every
    slot is taken. Pair each successful call with release_mrt(index).
    """
    with _MRT_POOL_LOCK:
        free = next((i for i, ok in enumerate(_MRT_AVAILABLE) if ok), None)
        if free is None:
            return (None, None)
        _MRT_AVAILABLE[free] = False
        return (free, _MRT_POOL[free])
|
|
def release_mrt(index: int):
    """Mark pool slot `index` free again; out-of-range indices are ignored."""
    with _MRT_POOL_LOCK:
        if not (0 <= index < len(_MRT_AVAILABLE)):
            return
        _MRT_AVAILABLE[index] = True
|
|
def reset_mrt_pool():
    """
    Rebuild the pool so new instances pick up the current global sampling defaults.

    Holds both pool locks for the whole rebuild, so slot checkout and release
    block until it finishes. It does not stop jam sessions itself; callers are
    expected to do that first.
    """
    global _POOL_INITIALIZED
    with _POOL_INIT_LOCK, _MRT_POOL_LOCK:
        init_mrt_pool(pool_size=2)
        _POOL_INITIALIZED = True
|
|
| |
| |
| |
|
|
# Lazily built singleton MagentaRT used by get_mrt(), warmup and the /model/*
# endpoints; distinct from the _MRT_POOL instances used for generation.
_MRT = None
_MRT_LOCK = threading.Lock()
|
|
@contextmanager
def mrt_overrides(mrt, **kwargs):
    """
    Temporarily set attributes on `mrt` for the duration of a `with` block.

    Only attributes that already exist on `mrt` are changed; unknown names are
    ignored. On exit, whether normal or via an exception, every changed
    attribute gets its previous value back.
    """
    saved = {}
    try:
        for name, value in kwargs.items():
            if not hasattr(mrt, name):
                continue
            saved[name] = getattr(mrt, name)
            setattr(mrt, name, value)
        yield
    finally:
        for name, previous in saved.items():
            setattr(mrt, name, previous)
|
|
| |
# Optional dependency: pyloudnorm. _HAS_LOUDNORM records whether the import succeeded.
try:
    import pyloudnorm as pyln
    _HAS_LOUDNORM = True
except Exception:
    _HAS_LOUDNORM = False
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| def build_style_vector( |
| mrt, |
| *, |
| text_styles: list[str] | None = None, |
| text_weights: list[float] | None = None, |
| loop_embed: np.ndarray | None = None, |
| loop_weight: float | None = None, |
| mean_weight: float | None = None, |
| centroid_weights: list[float] | None = None, |
| ) -> np.ndarray: |
| """ |
| Returns a single style embedding combining: |
| - loop embedding (optional) |
| - one or more text style embeddings (optional) |
| - mean finetune embedding (optional) |
| - centroid embeddings (optional) |
| All weights are normalized so they sum to 1 if > 0. |
| """ |
| comps: list[np.ndarray] = [] |
| weights: list[float] = [] |
|
|
| |
| if loop_embed is not None and (loop_weight or 0) > 0: |
| comps.append(loop_embed.astype(np.float32, copy=False)) |
| weights.append(float(loop_weight)) |
|
|
| |
| if text_styles: |
| for i, s in enumerate(text_styles): |
| s = s.strip() |
| if not s: |
| continue |
| w = 1.0 |
| if text_weights and i < len(text_weights): |
| try: w = float(text_weights[i]) |
| except: w = 1.0 |
| if w <= 0: |
| continue |
| e = mrt.embed_style(s) |
| comps.append(e.astype(np.float32, copy=False)) |
| weights.append(w) |
|
|
| |
| if mean_weight and (_MEAN_EMBED is not None) and mean_weight > 0: |
| comps.append(_MEAN_EMBED) |
| weights.append(float(mean_weight)) |
|
|
| |
| if centroid_weights and _CENTROIDS is not None: |
| K = _CENTROIDS.shape[0] |
| for k, w in enumerate(centroid_weights[:K]): |
| try: w = float(w) |
| except: w = 0.0 |
| if w <= 0: |
| continue |
| comps.append(_CENTROIDS[k]) |
| weights.append(w) |
|
|
| if not comps: |
| |
| return mrt.embed_style("") |
|
|
| wsum = sum(weights) |
| if wsum <= 0: |
| return mrt.embed_style("") |
| weights = [w/wsum for w in weights] |
|
|
| |
| out = np.zeros_like(comps[0], dtype=np.float32) |
| for w, e in zip(weights, comps): |
| out += w * e.astype(np.float32, copy=False) |
| return out |
|
|
|
|
|
|
| |
| |
| |
app = FastAPI()

# Browser clients may call the API from any origin.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# _MRT / _MRT_LOCK are defined once, earlier in the module. Re-declaring them here
# silently replaced the lock object and duplicated the singleton state.
|
|
def get_mrt():
    """
    Return the singleton MagentaRT, constructing it on the first call.

    Construction is re-checked under _MRT_LOCK so concurrent callers build only
    one instance. On that first build, fine-tune assets are also loaded from
    MRT_ASSETS_REPO (or MRT_CKPT_REPO) if the asset manager has none yet.
    """
    global _MRT
    if _MRT is not None:
        return _MRT
    with _MRT_LOCK:
        if _MRT is None:
            checkpoint_dir = CheckpointManager.resolve_checkpoint_dir()
            _MRT = system.MagentaRT(
                tag=os.getenv("MRT_SIZE", "large"),
                guidance_weight=5.0,
                device="gpu",
                checkpoint_dir=checkpoint_dir,
                lazy=True,
            )
            if asset_manager.mean_embed is None and asset_manager.centroids is None:
                assets_repo = os.getenv("MRT_ASSETS_REPO") or os.getenv("MRT_CKPT_REPO")
                if assets_repo:
                    asset_manager.load_finetune_assets_from_hf(assets_repo, None)
                    _sync_assets_globals_from_manager()
    return _MRT
|
|
# Set to True by _mrt_warmup() after a successful warmup generation. _WARMUP_LOCK
# serializes warmup runs and the places that reset the flag.
_WARMED = False
_WARMUP_LOCK = threading.Lock()
|
|
def _mrt_warmup():
    """
    Build a minimal, bar-aligned silent context and run one generate_chunk
    to trigger XLA JIT & autotune so first real request is fast.

    Runs on the singleton model from get_mrt() and holds _WARMUP_LOCK throughout.
    It is a no-op once _WARMED is True. Errors are logged and swallowed, in which
    case _WARMED stays False.
    """
    global _WARMED
    with _WARMUP_LOCK:
        if _WARMED:
            return
        try:
            mrt = get_mrt()

            codec_fps = float(mrt.codec.frame_rate)
            ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
            sr = int(mrt.sample_rate)

            # Fixed tempo and meter. The audio is silence, so these only drive bar alignment.
            bpm = 120.0
            beats_per_bar = 4

            samples = int(max(1, round(ctx_seconds * sr)))
            silent = np.zeros((samples, 2), dtype=np.float32)

            tmp_path = None
            try:
                # Round-trip through a WAV file to exercise the same loader used by /jam/start.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    # Record the path before writing so `finally` also removes the
                    # file when the write itself fails (delete=False never cleans up).
                    tmp_path = tmp.name
                    sf.write(tmp_path, silent, sr, subtype="PCM_16")

                loop = au.Waveform.from_file(tmp_path).resample(sr).as_stereo()
                ctx_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)

                tokens_full = mrt.codec.encode(ctx_tail).astype(np.int32)
                tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
                context_tokens = make_bar_aligned_context(
                    tokens,
                    bpm=bpm,
                    fps=codec_fps,
                    ctx_frames=mrt.config.context_length_frames,
                    beats_per_bar=beats_per_bar,
                )

                state = mrt.init_state()
                state.context_tokens = context_tokens
                style_vec = mrt.embed_style("warmup")

                # Output is discarded; only the compile/autotune side effect matters.
                _wav, _state = mrt.generate_chunk(state=state, style=style_vec)

                logging.info("MagentaRT warmup complete.")
            finally:
                if tmp_path is not None:
                    try:
                        os.unlink(tmp_path)
                    except OSError:
                        pass

            _WARMED = True
        except Exception as e:
            logging.exception("MagentaRT warmup failed (continuing without warmup): %s", e)
|
|
|
|
| |
| |
| |
|
|
@app.on_event("startup")
def _boot():
    """
    FastAPI startup hook.

    Preloads fine-tune assets when MRT_ASSETS_REPO or MRT_CKPT_REPO is set. Then,
    unless MRT_WARMUP is "0", starts _mrt_warmup on a background daemon thread.
    """
    repo = os.getenv("MRT_ASSETS_REPO") or os.getenv("MRT_CKPT_REPO")
    if not repo:
        logging.info("Startup asset load: no repo env set; skipping.")
    else:
        ok, msg = asset_manager.load_finetune_assets_from_hf(repo, None)
        _sync_assets_globals_from_manager()
        logging.info("Startup asset load from %s: %s", repo, "ok" if ok else msg)

    if os.getenv("MRT_WARMUP", "1") == "0":
        return
    warmup_thread = threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True)
    warmup_thread.start()
|
|
@app.get("/model/status")
def model_status():
    """Describe the singleton MRT (building it if needed): tag, codec/config timings and checkpoint env."""
    mrt = get_mrt()
    tag = getattr(mrt, "_tag", "unknown")
    frame_rate = float(mrt.codec.frame_rate)
    cfg = mrt.config
    return {
        "tag": tag,
        "using_checkpoint_dir": True,
        "codec_frame_rate": frame_rate,
        "decoder_rvq_depth": int(cfg.decoder_codec_rvq_depth),
        "context_seconds": float(cfg.context_length),
        "chunk_seconds": float(cfg.chunk_length),
        "crossfade_seconds": float(cfg.crossfade_length),
        "selected_step": os.getenv("MRT_CKPT_STEP"),
        "repo": os.getenv("MRT_CKPT_REPO"),
    }
|
|
@app.post("/model/swap")
def model_swap(step: int = Form(...)):
    """
    Point the singleton MRT at checkpoint `step`.

    Sets MRT_CKPT_STEP and drops the cached singleton, so the next get_mrt()
    rebuilds it. The warmup flag is cleared as well, as /model/select does.
    Otherwise /model/config would report a warmup that belonged to the old
    model, and _mrt_warmup() would skip the new one.

    NOTE(review): the _MRT_POOL instances used by /generate and /jam are not
    rebuilt here.
    """
    global _MRT, _WARMED
    os.environ["MRT_CKPT_STEP"] = str(step)
    with _MRT_LOCK:
        _MRT = None
    with _WARMUP_LOCK:
        _WARMED = False
    return {"reloaded": True, "step": step}
|
|
@app.post("/model/assets/load")
def model_assets_load(repo_id: str = Form(None)):
    """
    Load fine-tune assets (mean embedding / centroids) from `repo_id` through the
    asset manager, then mirror them into the module globals.

    Returns the loader's ok/message plus a summary of what is now loaded.
    """
    ok, msg = asset_manager.load_finetune_assets_from_hf(repo_id, get_mrt())
    # Same helper the startup hook and model constructors use to mirror the manager.
    _sync_assets_globals_from_manager()
    return {
        "ok": ok,
        "message": msg,
        "repo_id": _ASSETS_REPO_ID,
        "mean": _MEAN_EMBED is not None,
        "centroids": None if _CENTROIDS is None else int(_CENTROIDS.shape[0]),
    }
|
|
@app.get("/model/assets/status")
def model_assets_status():
    """Report which fine-tune assets are loaded and the style embedding size (None if it cannot be read)."""
    try:
        embedding_dim = int(get_mrt().style_model.config.embedding_dim)
    except Exception:
        embedding_dim = None
    have_centroids = _CENTROIDS is not None
    return {
        "repo_id": _ASSETS_REPO_ID,
        "mean_loaded": _MEAN_EMBED is not None,
        "centroids_loaded": have_centroids,
        "centroid_count": int(_CENTROIDS.shape[0]) if have_centroids else None,
        "embedding_dim": embedding_dim,
    }
|
|
@app.get("/model/config")
def model_config():
    """
    Lightweight config snapshot:
    - never calls get_mrt() (no model build / no downloads)
    - never calls snapshot_download()
    - reports whether a model instance is currently loaded in memory
    - best-effort local checkpoint presence (no network)
    """
    with _MRT_LOCK:
        loaded = (_MRT is not None)

    size = os.getenv("MRT_SIZE", "large")
    repo = os.getenv("MRT_CKPT_REPO")
    rev = os.getenv("MRT_CKPT_REV", "main")
    step = os.getenv("MRT_CKPT_STEP")
    assets = os.getenv("MRT_ASSETS_REPO")

    # Best-effort scan of the local checkpoint caches for a "checkpoint_<step>" directory.
    local_ckpt = None
    if step:
        target = f"checkpoint_{step}"
        try:
            for root in ("/home/appuser/.cache/mrt_ckpt/extracted",
                         "/home/appuser/.cache/mrt_ckpt/repo"):
                base = Path(root)
                if not base.exists():
                    continue
                # Compare names literally. `step` comes from the environment and must
                # not be used as a glob pattern, where '*', '?' or '[' act as wildcards.
                local_ckpt = next(
                    (str(d) for d in base.rglob("checkpoint_*")
                     if d.name == target and d.is_dir()),
                    None,
                )
                if local_ckpt is not None:
                    break
        except Exception:
            local_ckpt = None

    return {
        "size": size,
        "repo": repo,
        "revision": rev,
        "selected_step": step,
        "assets_repo": assets,

        "loaded": loaded,
        "active_jam": _any_jam_running(),
        "local_checkpoint_dir": local_ckpt,

        "mean_loaded": (_MEAN_EMBED is not None),
        "centroids_loaded": (_CENTROIDS is not None),
        "centroid_count": (None if _CENTROIDS is None else int(_CENTROIDS.shape[0])),
        "warmup_done": bool(_WARMED),
    }
|
|
@app.get("/model/checkpoints")
def model_checkpoints(repo_id: str, revision: str = "main"):
    """List checkpoint steps in `repo_id` at `revision`; `latest` is the last listed step (None if none)."""
    steps = CheckpointManager.list_ckpt_steps(repo_id, revision)
    latest = steps[-1] if steps else None
    return {"repo": repo_id, "revision": revision, "steps": steps, "latest": latest}
|
|
@app.post("/model/select")
def model_select(req: ModelSelect):
    """
    Swap model/checkpoint/assets. If req.prewarm is True, run the full bar-aligned warmup
    (_mrt_warmup) synchronously so we only report warmed once the new model is actually ready.

    Steps:
      1. Validate the request (400 on error). With dry_run, report the validation only.
      2. Refuse with 409 while a jam is running, unless req.stop_active, in which
         case all jams are stopped. This happens BEFORE any state is mutated.
      3. Apply the env changes, then drop the singleton MRT and the warmup flag.
      4. Optionally sync assets and prewarm.

    On an unexpected error, every environment variable the swap may have touched
    is restored, the singleton is dropped, and a 500 is raised.

    NOTE(review): only the singleton MRT is reset. The _MRT_POOL instances used by
    /generate and /jam keep their previous checkpoint until the pool is rebuilt.
    """
    global _MRT, _WARMED

    success, validation_result = model_selector.validate_selection(req)
    if not success:
        if "error" in validation_result:
            raise HTTPException(status_code=400, detail=validation_result["error"])
        return {"ok": False, **validation_result}

    validation_result["active_jam"] = _any_jam_running()

    if req.dry_run:
        return {"ok": True, "dry_run": True, **validation_result}

    # Gate on running jams first, so a 409 leaves assets and env untouched.
    if _any_jam_running():
        if req.stop_active:
            _stop_all_jams()
        else:
            raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")

    if isinstance(req.step, str) and req.step.lower() == "none":
        # Base model selected: drop any fine-tune assets.
        asset_manager.mean_embed = None
        asset_manager.centroids = None
        asset_manager.assets_repo_id = None
        _sync_assets_globals_from_manager()

    env_changes = model_selector.prepare_env_changes(req, validation_result)

    # Snapshot the standard MRT_* variables plus any extra key prepare_env_changes
    # returns, so a failure can restore everything that was touched.
    tracked = {"MRT_SIZE", "MRT_CKPT_REPO", "MRT_CKPT_REV", "MRT_CKPT_STEP", "MRT_ASSETS_REPO"}
    old_env = {k: os.getenv(k) for k in tracked | set(env_changes)}

    try:
        for key, value in env_changes.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = str(value)

        with _MRT_LOCK:
            _MRT = None
        with _WARMUP_LOCK:
            _WARMED = False

        if req.sync_assets and validation_result.get("assets_repo"):
            ok, msg = asset_manager.load_finetune_assets_from_hf(
                validation_result["assets_repo"],
                None,
            )
            if ok:
                _sync_assets_globals_from_manager()
            else:
                logging.warning("Asset sync skipped/failed: %s", msg)

        if req.prewarm:
            _mrt_warmup()

        return {
            "ok": True,
            **validation_result,
            "warmup_done": bool(_WARMED),
        }

    except Exception as e:
        for k, v in old_env.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v

        with _MRT_LOCK:
            _MRT = None
        with _WARMUP_LOCK:
            _WARMED = False
        logging.exception("Model select failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Model select failed: {e}")
| |
|
|
|
|
| |
| |
| |
|
|
|
|
|
|
@app.post("/generate")
def generate(
    loop_audio: UploadFile = File(...),
    bpm: float = Form(...),
    bars: int = Form(8),
    beats_per_bar: int = Form(4),
    styles: str = Form("acid house"),
    style_weights: str = Form(""),
    loop_weight: float = Form(1.0),
    loudness_mode: str = Form("auto"),
    loudness_headroom_db: float = Form(1.0),
    guidance_weight: Optional[float] = Form(None),
    temperature: Optional[float] = Form(None),
    topk: Optional[int] = Form(None),
    target_sample_rate: int | None = Form(None),
    intro_bars_to_drop: int = Form(0),
):
    """
    Continue an uploaded loop for `bars` bars at `bpm` and return it as base64 WAV.

    One pool MRT is checked out for the generation (503 if every slot is busy).
    guidance_weight / temperature / topk fall back to the global defaults when
    omitted. The output is resampled to `target_sample_rate` (default: the
    upload's sample rate) and snapped to bars * seconds_per_bar seconds.

    An empty upload yields {"error": "Empty file"} with HTTP 200.
    """
    ensure_pool_initialized()

    mrt_index, mrt = get_available_mrt()
    if mrt is None:
        raise HTTPException(status_code=503, detail="All slots busy, retry shortly")

    tmp_path = None
    try:
        defaults = _GLOBAL_GEN_PARAMS.get()
        guidance_weight = guidance_weight if guidance_weight is not None else defaults['guidance_weight']
        temperature = temperature if temperature is not None else defaults['temperature']
        topk = topk if topk is not None else defaults['topk']

        data = loop_audio.file.read()
        if not data:
            return {"error": "Empty file"}
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp_path = tmp.name
            tmp.write(data)

        # Read the upload's sample rate now; the temp file is deleted in `finally`.
        input_sr = int(sf.info(tmp_path).samplerate)

        extra_styles = [s for s in (styles.split(",") if styles else []) if s.strip()]
        weights = [float(x) for x in style_weights.split(",")] if style_weights else None

        with mrt_overrides(mrt,
                           guidance_weight=guidance_weight,
                           temperature=temperature,
                           topk=topk):
            wav, loud_stats = generate_loop_continuation_with_mrt(
                mrt,
                input_wav_path=tmp_path,
                bpm=bpm,
                extra_styles=extra_styles,
                style_weights=weights,
                bars=bars,
                beats_per_bar=beats_per_bar,
                loop_weight=loop_weight,
                loudness_mode=loudness_mode,
                loudness_headroom_db=loudness_headroom_db,
                intro_bars_to_drop=intro_bars_to_drop,
            )
    finally:
        release_mrt(mrt_index)
        # delete=False means the upload copy would otherwise stay on disk after every request.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    target_sr = int(target_sample_rate or input_sr)

    # Resample to the target rate and snap to the expected bar-aligned duration.
    cur_sr = int(mrt.sample_rate)
    x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
    seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
    expected_secs = float(bars) * seconds_per_bar
    x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=expected_secs)

    audio_b64, total_samples, channels = wav_bytes_base64(x, target_sr)
    loop_duration_seconds = total_samples / float(target_sr)

    metadata = {
        "bpm": int(round(bpm)),
        "bars": int(bars),
        "beats_per_bar": int(beats_per_bar),
        "styles": extra_styles,
        "style_weights": weights,
        "loop_weight": loop_weight,
        "loudness": loud_stats,
        "sample_rate": int(target_sr),
        "channels": int(channels),
        "crossfade_seconds": mrt.config.crossfade_length,
        "total_samples": int(total_samples),
        "seconds_per_bar": seconds_per_bar,
        "loop_duration_seconds": loop_duration_seconds,
        "guidance_weight": guidance_weight,
        "temperature": temperature,
        "topk": topk,
    }
    return {"audio_base64": audio_b64, "metadata": metadata}
|
|
| |
|
|
@app.post("/generate_style")
def generate_style(
    bpm: float = Form(...),
    bars: int = Form(8),
    beats_per_bar: int = Form(4),
    styles: str = Form("warmup"),
    style_weights: str = Form(""),
    guidance_weight: Optional[float] = Form(None),
    temperature: Optional[float] = Form(None),
    topk: Optional[int] = Form(None),
    target_sample_rate: int | None = Form(None),
    intro_bars_to_drop: int = Form(0),
):
    """
    Style-only, bar-aligned generation (no input audio).
    Seeds with 10s of silent context; outputs exactly `bars` at the requested BPM.

    One pool MRT is checked out for the generation (503 if every slot is busy).
    Omitted sampling parameters fall back to the global defaults. The result is
    resampled to `target_sample_rate` (default: the model's own rate).
    """
    ensure_pool_initialized()

    slot, mrt = get_available_mrt()
    if mrt is None:
        raise HTTPException(status_code=503, detail="All slots busy, retry shortly")

    try:
        defaults = _GLOBAL_GEN_PARAMS.get()
        if guidance_weight is None:
            guidance_weight = defaults['guidance_weight']
        if temperature is None:
            temperature = defaults['temperature']
        if topk is None:
            topk = defaults['topk']

        with mrt_overrides(mrt, guidance_weight=guidance_weight,
                           temperature=temperature, topk=topk):
            wav, _ = generate_style_only_with_mrt(
                mrt,
                bpm=bpm,
                bars=bars,
                beats_per_bar=beats_per_bar,
                styles=styles,
                style_weights=style_weights,
                intro_bars_to_drop=intro_bars_to_drop,
            )

        cur_sr = int(mrt.sample_rate)
        target_sr = int(target_sample_rate or cur_sr)
    finally:
        release_mrt(slot)

    samples = wav.samples
    audio = samples if samples.ndim == 2 else samples[:, None]

    seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
    audio = resample_and_snap(
        audio, cur_sr=cur_sr, target_sr=target_sr, seconds=float(bars) * seconds_per_bar
    )

    audio_b64, total_samples, channels = wav_bytes_base64(audio, target_sr)

    style_names = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
    parsed_weights = [float(y) for y in style_weights.split(",")] if style_weights else None
    metadata = {
        "bpm": int(round(bpm)),
        "bars": int(bars),
        "beats_per_bar": int(beats_per_bar),
        "styles": style_names,
        "style_weights": parsed_weights,
        "sample_rate": int(target_sr),
        "channels": int(channels),
        "crossfade_seconds": mrt.config.crossfade_length,
        "seconds_per_bar": seconds_per_bar,
        "loop_duration_seconds": total_samples / float(target_sr),
        "guidance_weight": guidance_weight,
        "temperature": temperature,
        "topk": topk,
    }
    return {"audio_base64": audio_b64, "metadata": metadata}
|
|
|
|
| |
| |
| |
|
|
@app.post("/jam/start")
def jam_start(
    loop_audio: UploadFile = File(...),
    bpm: float = Form(...),
    bars_per_chunk: int = Form(4),
    beats_per_bar: int = Form(4),
    styles: str = Form(""),
    style_weights: str = Form(""),
    loop_weight: float = Form(1.0),
    mean: float = Form(0.0),
    centroid_weights: str = Form(""),
    loudness_mode: str = Form("auto"),
    loudness_headroom_db: float = Form(1.0),
    guidance_weight: Optional[float] = Form(None),
    temperature: Optional[float] = Form(None),
    topk: Optional[int] = Form(None),
    target_sample_rate: int | None = Form(None),
):
    """
    Start a background JamWorker that continues the uploaded loop in
    `bars_per_chunk`-bar chunks.

    A pool MRT is checked out for the whole session (429 if none is free) and is
    released by /jam/stop. The style vector blends the loop-tail embedding, the
    text styles, the fine-tune mean and the centroids (see build_style_vector).
    Malformed weight lists are ignored rather than rejected.

    Returns {"session_id", "slot"}. If anything fails before the worker is
    running, the registry entry is removed and the slot released. The temporary
    copy of the upload is always deleted.
    """
    ensure_pool_initialized()

    mrt_index, mrt = get_available_mrt()
    if mrt is None:
        raise HTTPException(status_code=429, detail="All slots busy (max 2 concurrent JAM sessions)")

    sid = None
    tmp_path = None
    try:
        asset_manager.ensure_assets_loaded(mrt)

        defaults = _GLOBAL_GEN_PARAMS.get()
        guidance_weight = guidance_weight if guidance_weight is not None else defaults['guidance_weight']
        temperature = temperature if temperature is not None else defaults['temperature']
        topk = topk if topk is not None else defaults['topk']

        data = loop_audio.file.read()
        if not data:
            raise HTTPException(status_code=400, detail="Empty file")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp_path = tmp.name
            tmp.write(data)

        loop = au.Waveform.from_file(tmp_path).resample(mrt.sample_rate).as_stereo()

        # Bar-aligned tail of the loop, sized to the model's context window.
        codec_fps = float(mrt.codec.frame_rate)
        ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
        loop_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)

        text_list = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
        try:
            tw = [float(x) for x in style_weights.split(",")] if style_weights else []
        except ValueError:
            tw = []
        try:
            cw = [float(x) for x in centroid_weights.split(",")] if centroid_weights else []
        except ValueError:
            cw = []

        loop_tail_embed = mrt.embed_style(loop_tail)

        style_vec = build_style_vector(
            mrt,
            text_styles=text_list,
            text_weights=tw,
            loop_embed=loop_tail_embed,
            loop_weight=float(loop_weight),
            mean_weight=float(mean),
            centroid_weights=cw,
        ).astype(np.float32, copy=False)

        inp_info = sf.info(tmp_path)
        input_sr = int(inp_info.samplerate)
        target_sr = int(target_sample_rate or input_sr)

        params = JamParams(
            bpm=bpm,
            beats_per_bar=beats_per_bar,
            bars_per_chunk=bars_per_chunk,
            target_sr=target_sr,
            loudness_mode=loudness_mode,
            headroom_db=loudness_headroom_db,
            style_vec=style_vec,
            ref_loop=loop_tail,
            combined_loop=loop,
            guidance_weight=guidance_weight,
            temperature=temperature,
            topk=topk,
        )

        worker = JamWorker(mrt, params)
        sid = str(uuid.uuid4())
        with jam_lock:
            jam_registry[sid] = {
                'worker': worker,
                'mrt_index': mrt_index,
            }
        worker.start()

        return {"session_id": sid, "slot": mrt_index}

    except Exception:
        # Remove a half-registered session, so a later stop or cleanup cannot
        # release this slot a second time after it has been reassigned.
        if sid is not None:
            with jam_lock:
                jam_registry.pop(sid, None)
        release_mrt(mrt_index)
        raise
    finally:
        # The worker only holds in-memory waveforms, so the upload copy can go now.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
|
|
@app.get("/jam/next")
def jam_next(session_id: str):
    """
    Get the next sequential chunk in the jam session.
    This ensures chunks are delivered in order without gaps.

    Returns 404 if the session is unknown or its worker has exited, and 408 if
    the worker's get_next_chunk() returns None.
    """
    with jam_lock:
        info = jam_registry.get(session_id)
    if info is None or not info['worker'].is_alive():
        raise HTTPException(status_code=404, detail="Session not found")

    # Called outside jam_lock: this call may wait for the worker.
    chunk = info['worker'].get_next_chunk()
    if chunk is None:
        raise HTTPException(status_code=408, detail="Chunk not ready within timeout")

    payload = {
        "index": chunk.index,
        "audio_base64": chunk.audio_base64,
        "metadata": chunk.metadata,
    }
    return {"chunk": payload}
|
|
@app.post("/jam/consume")
def jam_consume(session_id: str = Form(...), chunk_index: int = Form(...)):
    """
    Mark a chunk as consumed by the frontend.
    This helps the worker manage its buffer and generation flow.

    Returns 404 if the session is unknown or its worker is no longer alive.
    """
    with jam_lock:
        entry = jam_registry.get(session_id)
    if entry is None or not entry['worker'].is_alive():
        raise HTTPException(status_code=404, detail="Session not found")

    entry['worker'].mark_chunk_consumed(chunk_index)
    return {"consumed": chunk_index}
|
|
|
|
|
|
@app.post("/jam/stop")
def jam_stop(session_id: str = Body(..., embed=True)):
    """
    Stop a jam session and return its MRT slot to the pool.

    Waits up to 10s for the worker thread. If the worker is still alive after
    that, the session stays registered and its slot stays reserved, so the MRT
    is never shared with a new session; the client should retry.

    Raises 404 if the session is unknown.
    """
    with jam_lock:
        session_info = jam_registry.get(session_id)
    if session_info is None:
        raise HTTPException(status_code=404, detail="Session not found")

    worker = session_info['worker']
    worker.stop()
    worker.join(timeout=10.0)
    if worker.is_alive():
        logging.warning(
            "JamWorker %s did not stop within timeout - keeping MRT slot reserved", session_id
        )
        return {"stopped": False, "timeout": True, "message": "Worker did not stop in time, retry /jam/stop"}

    # Remove first, and release only if this call removed the entry. Two
    # concurrent stops (or a concurrent stop_all) then cannot release the slot
    # twice, possibly after it was already handed to a new session.
    with jam_lock:
        removed = jam_registry.pop(session_id, None)
    if removed is not None and removed.get('mrt_index') is not None:
        release_mrt(removed['mrt_index'])
    return {"stopped": True}
|
|
@app.post("/jam/stop_all")
def jam_stop_all():
    """
    Force stop all active jam sessions (nuclear option for cleanup).

    Every registered worker that is alive is asked to stop and joined for up
    to 2 s; its id is reported in ``stopped_sessions``. Workers that have
    exited are removed from the registry and their MRT slot is released.

    A worker still alive after the join is left registered, with its slot
    reserved (same policy as /jam/stop), and is listed under ``timed_out``.
    This prevents a new session from sharing a model with a running worker.
    """
    stopped_sessions = []
    timed_out = []

    # Snapshot under the lock, but stop/join outside it: holding jam_lock
    # across joins would stall every other jam endpoint meanwhile.
    with jam_lock:
        snapshot = list(jam_registry.items())

    for session_id, session_info in snapshot:
        worker = session_info['worker']
        mrt_index = session_info['mrt_index']

        if worker.is_alive():
            worker.stop()
            worker.join(timeout=2.0)
            stopped_sessions.append(session_id)

        if worker.is_alive():
            timed_out.append(session_id)
            continue

        # Release only if we are the ones removing the entry (avoids double
        # release when racing with /jam/stop or /jam/cleanup).
        with jam_lock:
            owned = jam_registry.get(session_id) is session_info
            if owned:
                del jam_registry[session_id]
        # Websocket-created sessions have mrt_index=None and own no pool slot.
        if owned and mrt_index is not None:
            release_mrt(mrt_index)

    return {
        "stopped_sessions": stopped_sessions,
        "count": len(stopped_sessions),
        "timed_out": timed_out,
    }
|
|
@app.post("/jam/cleanup")
def jam_cleanup(force: bool = Form(False), idle_threshold_seconds: float = Form(300.0)):
    """
    Stop stale or orphaned jam sessions.

    - force=False: stops sessions whose worker thread has already exited, plus
      sessions idle for more than ``idle_threshold_seconds`` (default 5 min).
      Idleness is measured from ``worker.last_activity_at`` when the worker
      has that attribute.
    - force=True: stops ALL sessions regardless of activity (nuclear option).

    Each stop waits up to 10 s for the worker. A worker still alive after that
    stays registered with its MRT slot reserved (same policy as /jam/stop) and
    is reported under ``timed_out`` so the caller can retry.
    """
    stopped_sessions = []
    kept_sessions = []
    timed_out = []
    current_time = time.time()

    # Snapshot under the lock, then stop/join outside it: a 10 s join per
    # worker while holding jam_lock would block every other jam endpoint.
    with jam_lock:
        snapshot = list(jam_registry.items())

    for session_id, session_info in snapshot:
        worker = session_info['worker']
        mrt_index = session_info['mrt_index']

        idle_time = 0
        last_activity = getattr(worker, 'last_activity_at', None)
        if last_activity is not None:
            idle_time = current_time - last_activity

        report = {
            "session_id": session_id,
            "idle_seconds": round(idle_time, 1),
            "slot": mrt_index,
        }

        # A worker that has exited can no longer serve requests (the other
        # endpoints answer 404 for it) but still holds its MRT slot.
        should_stop = (
            force
            or not worker.is_alive()
            or (last_activity is not None and idle_time > idle_threshold_seconds)
        )
        if not should_stop:
            kept_sessions.append(report)
            continue

        if worker.is_alive():
            worker.stop()
            worker.join(timeout=10.0)
        if worker.is_alive():
            timed_out.append(report)
            continue

        # Only the caller that removes the entry releases the slot (no double
        # release when racing with /jam/stop or /jam/stop_all).
        with jam_lock:
            owned = jam_registry.get(session_id) is session_info
            if owned:
                del jam_registry[session_id]
        if owned:
            # Websocket-created sessions are registered with mrt_index=None.
            if mrt_index is not None:
                release_mrt(mrt_index)
            stopped_sessions.append(report)

    return {
        "stopped": stopped_sessions,
        "kept": kept_sessions,
        "timed_out": timed_out,
        "stopped_count": len(stopped_sessions),
        "kept_count": len(kept_sessions),
        "force": force,
        "idle_threshold_seconds": idle_threshold_seconds
    }
|
|
@app.post("/jam/update")
def jam_update(
    session_id: str = Form(...),

    # Sampling knobs; None leaves the worker's current value unchanged.
    guidance_weight: Optional[float] = Form(None),
    temperature: Optional[float] = Form(None),
    topk: Optional[int] = Form(None),

    # Style steering: comma-separated text prompts and matching weights.
    styles: str = Form(""),
    style_weights: str = Form(""),
    loop_weight: Optional[float] = Form(None),
    use_current_mix_as_style: bool = Form(False),

    # Optional steering assets (mean embedding weight, centroid weights CSV).
    mean: Optional[float] = Form(None),
    centroid_weights: str = Form(""),
):
    """
    Update sampling knobs and/or the style vector of a running jam session.

    Knobs are forwarded to ``worker.update_knobs`` when any is given. The style
    vector is rebuilt with ``build_style_vector`` only when a style input is
    present: text styles, ``mean``, ``centroid_weights``, or
    ``use_current_mix_as_style``. The last of these embeds
    ``worker.params.combined_loop``, weighted by ``loop_weight`` (default 1.0).

    Weight lists are comma-separated floats; blank entries are ignored and an
    unparsable list is treated as empty. Centroid weights beyond the number of
    loaded centroids are dropped.

    Raises:
        HTTPException(404): unknown session, or its worker is no longer alive.
    """
    with jam_lock:
        session_info = jam_registry.get(session_id)
    if session_info is None:
        raise HTTPException(status_code=404, detail="Session not found")

    worker = session_info['worker']
    if not worker.is_alive():
        raise HTTPException(status_code=404, detail="Session not found")

    # Use the worker's own model rather than _MRT_POOL[mrt_index]: sessions
    # created over the websocket are registered with mrt_index=None, which
    # made the pool lookup fail for them.
    with worker._lock:
        mrt = worker.mrt
    asset_manager.ensure_assets_loaded(mrt)

    if any(v is not None for v in (guidance_weight, temperature, topk)):
        worker.update_knobs(
            guidance_weight=guidance_weight,
            temperature=temperature,
            topk=topk
        )

    wants_style_update = (
        use_current_mix_as_style
        or (styles.strip() != "")
        or (mean is not None)
        or (centroid_weights.strip() != "")
    )
    if not wants_style_update:
        return {"ok": True}

    def _parse_floats(csv: str) -> list:
        """Parse comma-separated floats, skipping blanks; [] if any entry is invalid."""
        try:
            return [float(x) for x in csv.split(",") if x.strip()] if csv else []
        except ValueError:
            return []

    text_list = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
    tw = _parse_floats(style_weights)
    cw = _parse_floats(centroid_weights)

    # Never pass more centroid weights than there are loaded centroids.
    max_c = 0 if _CENTROIDS is None else int(_CENTROIDS.shape[0])
    if max_c and len(cw) > max_c:
        cw = cw[:max_c]

    # Snapshot the inputs under the worker lock; embedding happens outside it.
    with worker._lock:
        combined_loop = worker.params.combined_loop if use_current_mix_as_style else None
        lw = None
        if use_current_mix_as_style:
            lw = 1.0 if (loop_weight is None) else float(loop_weight)
        mrt = worker.mrt

    loop_embed = None
    if combined_loop is not None:
        loop_embed = mrt.embed_style(combined_loop)

    style_vec = build_style_vector(
        mrt,
        text_styles=text_list,
        text_weights=tw,
        loop_embed=loop_embed,
        loop_weight=lw,
        mean_weight=(None if mean is None else float(mean)),
        centroid_weights=cw,
    ).astype(np.float32, copy=False)

    with worker._lock:
        worker.params.style_vec = style_vec

    return {"ok": True}
|
|
|
|
@app.post("/jam/reseed")
def jam_reseed(session_id: str = Form(...), loop_audio: UploadFile = File(None)):
    """
    Reseed a running jam session's context from audio.

    The source is the uploaded ``loop_audio`` if one is given; otherwise it is
    the worker's internal ``_stream`` buffer. Either way the audio is converted
    to a stereo ``au.Waveform`` at the model's sample rate and handed to
    ``worker.reseed_from_waveform``.

    Raises:
        HTTPException(404): unknown session or dead worker.
        HTTPException(400): empty upload, or no internal stream to reseed from.
    """
    with jam_lock:
        session_info = jam_registry.get(session_id)
    if session_info is None:
        raise HTTPException(status_code=404, detail="Session not found")

    worker = session_info['worker']
    if not worker.is_alive():
        raise HTTPException(status_code=404, detail="Session not found")

    if loop_audio is not None:
        data = loop_audio.file.read()
        if not data:
            raise HTTPException(status_code=400, detail="Empty file")

        # au.Waveform.from_file needs a path, so spool the upload to disk, and
        # always delete it afterwards (delete=False files otherwise leak).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(data)
            path = tmp.name
        try:
            wav = au.Waveform.from_file(path).resample(worker.mrt.sample_rate).as_stereo()
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass
    else:
        # No upload: reuse the worker's own accumulated output stream.
        s = getattr(worker, "_stream", None)
        if s is None or s.shape[0] == 0:
            raise HTTPException(status_code=400, detail="No internal stream to reseed from")
        wav = au.Waveform(s.astype(np.float32, copy=False), int(worker.mrt.sample_rate)).as_stereo()

    worker.reseed_from_waveform(wav)
    return {"ok": True}
|
|
@app.post("/jam/reseed_splice")
def jam_reseed_splice(
    session_id: str = Form(...),
    anchor_bars: float = Form(2.0),
    combined_audio: UploadFile = File(None),
):
    """
    Splice-reseed a running jam session, anchored on ``anchor_bars`` bars.

    The source is the uploaded ``combined_audio`` if one is given; otherwise it
    is the worker's internal ``_stream`` buffer. Either way the audio is
    converted to a stereo ``au.Waveform`` at the model's sample rate and passed
    to ``worker.reseed_splice``.

    Raises:
        HTTPException(404): unknown session or dead worker.
        HTTPException(400): empty upload, or no audio available.
    """
    with jam_lock:
        session_info = jam_registry.get(session_id)
    if session_info is None:
        raise HTTPException(status_code=404, detail="Session not found")

    worker = session_info['worker']
    if not worker.is_alive():
        raise HTTPException(status_code=404, detail="Session not found")

    if combined_audio is not None:
        data = combined_audio.file.read()
        if not data:
            raise HTTPException(status_code=400, detail="Empty combined_audio")

        # Spool to a temp path for au.Waveform.from_file; always remove it
        # afterwards (delete=False files are otherwise never cleaned up).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(data)
            path = tmp.name
        try:
            wav = au.Waveform.from_file(path).resample(worker.mrt.sample_rate).as_stereo()
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass
    else:
        # No upload: splice from the worker's own accumulated output stream.
        s = getattr(worker, "_stream", None)
        if s is None or s.shape[0] == 0:
            raise HTTPException(status_code=400, detail="No audio available to reseed from")
        wav = au.Waveform(s.astype(np.float32, copy=False), int(worker.mrt.sample_rate)).as_stereo()

    worker.reseed_splice(wav, anchor_bars=float(anchor_bars))
    return {"ok": True, "anchor_bars": float(anchor_bars)}
|
|
@app.get("/jam/status")
def jam_status(session_id: str):
    """
    Report progress and timing for one jam session.

    Responds 404 for an unknown ``session_id``. Generation/delivery counters
    and the params object are snapshotted under ``worker._lock``. ``running``
    reflects whether the worker thread is alive; a dead worker is still
    reported rather than rejected.

    NOTE(review): reads ``worker.outbox`` and ``worker._last_delivered_index``;
    another /jam/sessions handler in this file reads ``_outbox`` and
    ``_next_to_deliver`` instead — confirm which names JamWorker defines.
    """
    with jam_lock:
        entry = jam_registry.get(session_id)

    if entry is None:
        raise HTTPException(status_code=404, detail="Session not found")

    w = entry['worker']
    is_running = w.is_alive()

    with w._lock:
        generated_idx = int(w.idx)
        delivered_idx = int(w._last_delivered_index)
        queue_len = len(w.outbox)
        jp = w.params

    ahead = generated_idx - delivered_idx
    bar_seconds = jp.beats_per_bar * (60.0 / jp.bpm)
    chunk_seconds = jp.bars_per_chunk * bar_seconds

    return {
        "running": is_running,
        "last_generated_index": generated_idx,
        "last_delivered_index": delivered_idx,
        "buffer_ahead": ahead,
        "queued_chunks": queue_len,
        "bpm": jp.bpm,
        "beats_per_bar": jp.beats_per_bar,
        "bars_per_chunk": jp.bars_per_chunk,
        "seconds_per_bar": bar_seconds,
        "chunk_duration_seconds": chunk_seconds,
        "target_sample_rate": jp.target_sr,
        "last_chunk_started_at": w.last_chunk_started_at,
        "last_chunk_completed_at": w.last_chunk_completed_at,
    }
|
|
@app.get("/jam/sessions")
def jam_sessions():
    """
    List every registered JAM session with monitoring metadata.

    Ages are computed against a single ``time.time()`` snapshot. Idle time is
    measured from ``last_activity_at``, falling back to ``created_at``, then to
    "now" (idle 0), depending on which attributes the worker has. Counters are
    read under each worker's ``_lock``.
    """
    now = time.time()
    rows = []

    with jam_lock:
        for sid, entry in jam_registry.items():
            w = entry['worker']
            slot = entry['mrt_index']

            has_created = hasattr(w, 'created_at')
            uptime = (now - w.created_at) if has_created else 0
            if hasattr(w, 'last_activity_at'):
                last_seen = w.last_activity_at
            elif has_created:
                last_seen = w.created_at
            else:
                last_seen = now
            idle = now - last_seen

            with w._lock:
                generated = int(w.idx)
                delivered = int(w._last_delivered_index)
                backlog = len(w.outbox)

            rows.append({
                "session_id": sid,
                "mrt_slot": slot,
                "running": w.is_alive(),
                "uptime_seconds": round(uptime, 1),
                "idle_seconds": round(idle, 1),
                "chunks_generated": generated,
                "chunks_delivered": delivered,
                "chunks_queued": backlog,
                "bpm": w.params.bpm,
                "bars_per_chunk": w.params.bars_per_chunk,
                "created_at": w.created_at if has_created else None,
                "last_activity_at": last_seen,
            })

    return {
        "sessions": rows,
        "total_active": len(rows),
        "mrt_pool_size": len(_MRT_POOL),
    }
|
|
@app.get("/jam/sessions")
def jam_sessions():
    """List all active JAM sessions with metadata for monitoring.

    NOTE(review): GET /jam/sessions is already registered by an earlier
    handler in this file. Starlette tries routes in registration order, so
    this second handler is presumably never reached over HTTP; it only rebinds
    the module-level name ``jam_sessions``. It also reads different JamWorker
    attributes than the earlier copy and /jam/status: ``_next_to_deliver`` /
    ``_outbox`` here versus ``_last_delivered_index`` / ``outbox`` there.
    Confirm which names JamWorker actually defines and delete the stale
    handler.
    """
    sessions = []
    # One timestamp for all sessions so ages are mutually comparable.
    current_time = time.time()

    with jam_lock:
        for session_id, session_info in jam_registry.items():
            worker = session_info['worker']
            mrt_index = session_info['mrt_index']

            # Uptime is 0 when the worker has no created_at attribute.
            uptime = current_time - worker.created_at if hasattr(worker, 'created_at') else 0
            # Fallback chain: last activity -> creation time -> now (idle 0).
            last_activity = worker.last_activity_at if hasattr(worker, 'last_activity_at') else (worker.created_at if hasattr(worker, 'created_at') else current_time)
            idle_time = current_time - last_activity

            # Read counters under the worker's own lock for a consistent snapshot.
            with worker._lock:
                last_generated = int(worker.idx)
                # "Last delivered" is derived from the next index to deliver.
                last_delivered = int(worker._next_to_deliver) - 1
                queued = len(worker._outbox)

            sessions.append({
                "session_id": session_id,
                "mrt_slot": mrt_index,
                "running": worker.is_alive(),
                "uptime_seconds": round(uptime, 1),
                "idle_seconds": round(idle_time, 1),
                "chunks_generated": last_generated,
                "chunks_delivered": last_delivered,
                "chunks_queued": queued,
                "bpm": worker.params.bpm,
                "bars_per_chunk": worker.params.bars_per_chunk,
                "created_at": worker.created_at if hasattr(worker, 'created_at') else None,
                "last_activity_at": last_activity,
            })

    return {
        "sessions": sessions,
        "total_active": len(sessions),
        "mrt_pool_size": len(_MRT_POOL),
    }
|
|
@app.get("/health")
def health():
    """
    Readiness probe.

    Returns, in order of precedence:

    - 503 ``template_mode`` when ``SPACE_MODE`` is not "serve";
    - 503 ``gpu_unavailable`` when ``_gpu_probe()`` does not report an ok,
      GPU-backed JAX runtime;
    - 200 with warm-up state (``_WARMED``) and the count of live jam workers.
    """
    if SPACE_MODE != "serve":
        return JSONResponse(
            status_code=503,
            content={
                "ok": False,
                "status": "template_mode",
                "message": "This Space is a GPU template. Duplicate it and select an L40s/A100-class runtime to use the API.",
                "mode": SPACE_MODE,
            },
        )

    probe = _gpu_probe()
    gpu_ready = probe["ok"] and probe["has_gpu"] and probe.get("backend") == "gpu"
    if not gpu_ready:
        return JSONResponse(
            status_code=503,
            content={
                "ok": False,
                "status": "gpu_unavailable",
                "message": "GPU is not visible to JAX. Select a GPU runtime (e.g., L40s) to serve.",
                "probe": probe,
                "mode": SPACE_MODE,
            },
        )

    is_warm = bool(_WARMED)
    with jam_lock:
        live_workers = [info for info in jam_registry.values() if info['worker'].is_alive()]

    return {
        "ok": True,
        "status": "ready" if is_warm else "initializing",
        "mode": SPACE_MODE,
        "warmed": is_warm,
        "active_jams": len(live_workers),
        "probe": probe,
    }
|
|
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """
    Print one line when a request arrives and one with its status on the way out.

    The ``X-Request-ID`` header (or "-") tags both lines. Exceptions raised
    downstream are printed and then re-raised unchanged.
    """
    request_id = request.headers.get("X-Request-ID", "-")
    path = request.url.path
    print(f"📥 {request.method} {path}?{request.url.query} [rid={request_id}]")

    try:
        response = await call_next(request)
    except Exception as exc:
        print(f"💥 exception for {path} [rid={request_id}]: {exc}")
        raise

    print(f"📤 {response.status_code} {path} [rid={request_id}]")
    return response
|
|
|
|
|
|
|
|
|
|
| |
| |
| |
|
|
|
|
def _ws_load_wav_b64(b64_audio, sample_rate, with_source_sr=False):
    """
    Decode base64 audio from a websocket message into a stereo ``au.Waveform``
    resampled to ``sample_rate``.

    The bytes go through a temporary ``.wav`` file because
    ``au.Waveform.from_file`` (and ``sf.info``) read from a path. The file is
    always deleted afterwards so repeated uploads don't accumulate on disk.

    Returns:
        ``(waveform, source_sr)``: ``source_sr`` is the uploaded file's sample
        rate from ``sf.info`` when ``with_source_sr`` is true, otherwise None.
    """
    data = base64.b64decode(b64_audio)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(data)
        path = tmp.name
    try:
        wav = au.Waveform.from_file(path).resample(sample_rate).as_stereo()
        source_sr = int(sf.info(path).samplerate) if with_source_sr else None
        return wav, source_sr
    finally:
        try:
            os.unlink(path)
        except OSError:
            pass


@app.websocket("/ws/jam")
async def ws_jam(websocket: WebSocket):
    """
    Websocket jam endpoint.

    Client messages are JSON objects with a ``type`` field:

    - ``start``: with ``session_id``, attach to an existing session.
      ``mode="bar"`` without one creates a JamWorker from ``loop_audio_b64``.
      Any other mode (default ``"rt"``) starts a realtime loop that streams
      ``mrt.generate_chunk`` output directly.
    - ``update``: change sampling knobs / style (bar mode goes through
      ``jam_update``; rt mode updates this socket's realtime state).
    - ``consume`` / ``reseed`` / ``reseed_splice``: bar-mode session control.
      These use the message's ``session_id``, falling back to the current one.
    - ``stop``: stop the realtime loop (rt mode) and close the socket.
    - ``ping``: answered with ``pong``.

    Audio is sent either as base64 in a ``chunk`` message or, with
    ``binary_audio``, as a binary frame followed by a ``chunk_meta`` message.
    Recoverable problems are reported as ``{"type": "error"}`` and the socket
    stays open.
    """
    await websocket.accept()
    sid = None
    worker = None
    binary_audio = False
    mode = "rt"
    pump_task = None  # bar-mode chunk forwarder; cancelled on exit

    async def send_json(obj):
        return await send_json_safe(websocket, obj)

    async def stop_rt_loop():
        """Stop the realtime generation task, if any, and wait for it to end."""
        websocket._rt_running = False
        task = getattr(websocket, "_rt_task", None)
        websocket._rt_task = None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def rt_loop():
        """Generate and stream realtime chunks until stopped or a send fails."""
        try:
            mrt = websocket._mrt
            chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
            target_next = time.perf_counter()
            while websocket._rt_running:
                mrt.guidance_weight = websocket._rt_guid
                mrt.temperature = websocket._rt_temp
                mrt.topk = websocket._rt_topk

                # Move the current style toward the target by one chunk's worth of the ramp.
                ramp = float(getattr(websocket, "_style_ramp_s", 0.0) or 0.0)
                if ramp <= 0.0:
                    websocket._style_cur = websocket._style_tgt
                else:
                    step = min(1.0, chunk_secs / ramp)
                    websocket._style_cur = websocket._style_cur + step * (websocket._style_tgt - websocket._style_cur)

                # generate_chunk is a long blocking model call. Running it in a
                # worker thread keeps the event loop free for update/stop/ping
                # messages and for other connections.
                wav, new_state = await asyncio.to_thread(
                    mrt.generate_chunk, state=websocket._state, style=websocket._style_cur
                )
                websocket._state = new_state

                x = wav.samples.astype(np.float32, copy=False)
                buf = io.BytesIO()
                sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")

                if binary_audio:
                    try:
                        await websocket.send_bytes(buf.getvalue())
                        ok = await send_json({"type": "chunk_meta", "metadata": {"sample_rate": mrt.sample_rate}})
                    except Exception:
                        ok = False
                else:
                    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
                    ok = await send_json({"type": "chunk", "audio_base64": b64,
                                          "metadata": {"sample_rate": mrt.sample_rate}})
                if not ok:
                    break

                # "realtime" pacing: schedule chunks one chunk-length apart,
                # waking 20 ms early.
                if getattr(websocket, "_pace", "asap") == "realtime":
                    target_next += chunk_secs
                    sleep_s = max(0.0, target_next - time.perf_counter() - 0.02)
                    if sleep_s > 0:
                        await asyncio.sleep(sleep_s)
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            # Previously swallowed silently, leaving the client waiting forever.
            print(f"💥 rt loop error: {exc}")
            try:
                await send_json({"type": "error", "error": f"rt generation failed: {exc}"})
            except Exception:
                pass

    async def pump(jam_worker):
        """Forward chunks from a bar-mode worker until it exits or a send fails."""
        # jam_worker is bound at task creation, so later reassignments in the
        # message loop cannot redirect this pump to a different worker.
        while jam_worker.is_alive():
            # get_next_chunk blocks for up to 60 s; keep it off the event loop.
            chunk = await asyncio.to_thread(jam_worker.get_next_chunk, timeout=60.0)
            if chunk is None:
                continue
            if binary_audio:
                try:
                    await websocket.send_bytes(base64.b64decode(chunk.audio_base64))
                except Exception:
                    break
                ok = await send_json({"type": "chunk_meta", "index": chunk.index, "metadata": chunk.metadata})
            else:
                ok = await send_json({"type": "chunk", "index": chunk.index,
                                      "audio_base64": chunk.audio_base64, "metadata": chunk.metadata})
            if not ok:
                break

    try:
        while True:
            raw = await websocket.receive_text()
            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                await send_json({"type": "error", "error": "Invalid JSON"})
                continue
            if not isinstance(msg, dict):
                await send_json({"type": "error", "error": "Message must be a JSON object"})
                continue
            mtype = msg.get("type")

            if mtype == "start":
                binary_audio = bool(msg.get("binary_audio", False))
                mode = msg.get("mode", "rt")
                params = msg.get("params", {}) or {}
                sid = msg.get("session_id")

                if sid:
                    # Attach to an existing session (e.g. one created over HTTP).
                    with jam_lock:
                        session_info = jam_registry.get(sid)
                    if session_info is None or not session_info['worker'].is_alive():
                        await send_json({"type": "error", "error": "Session not found"})
                        continue
                    worker = session_info['worker']

                elif mode == "bar":
                    loop_b64 = msg.get("loop_audio_b64")
                    if not loop_b64:
                        await send_json({"type": "error", "error": "loop_audio_b64 required for mode=bar when no session_id"})
                        continue

                    mrt = get_mrt()
                    loudness_mode = params.get("loudness_mode", "none")
                    headroom_db = float(params.get("headroom_db", 1.0))
                    loop, input_sr = _ws_load_wav_b64(loop_b64, mrt.sample_rate, with_source_sr=True)

                    # Condition on a bar-aligned tail no longer than the model context.
                    codec_fps = float(mrt.codec.frame_rate)
                    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
                    bpm = float(params.get("bpm", 120.0))
                    bpb = int(params.get("beats_per_bar", 4))
                    loop_tail = take_bar_aligned_tail(loop, bpm, bpb, ctx_seconds)

                    # Style = normalized weighted mix of the loop embedding and text styles.
                    embeds, weights = [mrt.embed_style(loop_tail)], [float(params.get("loop_weight", 1.0))]
                    extra = [s for s in (params.get("styles", "").split(",")) if s.strip()]
                    sw = [float(x) for x in params.get("style_weights", "").split(",") if x.strip()]
                    for i, s in enumerate(extra):
                        embeds.append(mrt.embed_style(s.strip()))
                        weights.append(sw[i] if i < len(sw) else 1.0)
                    wsum = sum(weights) or 1.0
                    weights = [w / wsum for w in weights]
                    style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)

                    # Output sample rate defaults to the uploaded file's own rate.
                    target_sr = int(params.get("target_sr", input_sr))

                    jp = JamParams(
                        bpm=bpm, beats_per_bar=bpb, bars_per_chunk=int(params.get("bars_per_chunk", 8)),
                        target_sr=target_sr,
                        loudness_mode=loudness_mode, headroom_db=headroom_db,
                        style_vec=style_vec,
                        ref_loop=None if loudness_mode == "none" else loop_tail,
                        combined_loop=loop,
                        guidance_weight=float(params.get("guidance_weight", 1.1)),
                        temperature=float(params.get("temperature", 1.1)),
                        topk=int(params.get("topk", 40)),
                    )
                    new_worker = JamWorker(get_mrt(), jp)
                    new_sid = str(uuid.uuid4())

                    # Only one bar-mode jam may run at a time. Decide under the
                    # lock but never await while holding it: jam_lock is
                    # presumably a threading lock, and awaiting inside it would
                    # stall the event loop for anyone else acquiring it.
                    with jam_lock:
                        busy = any(info['worker'].is_alive() for info in jam_registry.values())
                        if not busy:
                            jam_registry[new_sid] = {'worker': new_worker, 'mrt_index': None}
                            new_worker.start()
                    if busy:
                        worker = None
                        sid = None
                        # Do not fall through to "started" after rejecting.
                        await send_json({"type": "error", "error": "A jam is already running"})
                        continue
                    worker, sid = new_worker, new_sid

                else:
                    # Realtime mode: this socket drives mrt.generate_chunk directly.
                    # A repeated "start" must not leave an older loop generating.
                    await stop_rt_loop()

                    mrt = get_mrt()
                    state = mrt.init_state()

                    # Seed the context with one context-length of encoded silence.
                    codec_fps = float(mrt.codec.frame_rate)
                    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
                    sr = int(mrt.sample_rate)
                    samples = int(max(1, round(ctx_seconds * sr)))
                    silent = au.Waveform(np.zeros((samples, 2), np.float32), sr)
                    tokens = mrt.codec.encode(silent).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
                    state.context_tokens = tokens

                    asset_manager.ensure_assets_loaded(get_mrt())
                    styles_str = params.get("styles", "warmup") or ""
                    style_weights_str = params.get("style_weights", "") or ""
                    mean_w = float(params.get("mean", 0.0) or 0.0)
                    cw_str = str(params.get("centroid_weights", "") or "")

                    text_list = [s.strip() for s in styles_str.split(",") if s.strip()]
                    try:
                        text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
                    except ValueError:
                        text_w = []
                    try:
                        cw = [float(x) for x in cw_str.split(",") if x.strip() != ""]
                    except ValueError:
                        cw = []

                    # Never pass more centroid weights than there are centroids.
                    if _CENTROIDS is not None and len(cw) > int(_CENTROIDS.shape[0]):
                        cw = cw[: int(_CENTROIDS.shape[0])]

                    style_vec = build_style_vector(
                        mrt,
                        text_styles=text_list,
                        text_weights=text_w,
                        loop_embed=None,
                        loop_weight=None,
                        mean_weight=mean_w,
                        centroid_weights=cw,
                    )

                    # Realtime state lives on the websocket object itself.
                    websocket._mrt = mrt
                    websocket._state = state
                    websocket._style_cur = style_vec
                    websocket._style_tgt = style_vec
                    websocket._style_ramp_s = float(params.get("style_ramp_seconds", 0.0))

                    websocket._rt_mean = mean_w
                    websocket._rt_centroid_weights = cw
                    websocket._rt_running = True
                    websocket._rt_sr = sr
                    websocket._rt_topk = int(params.get("topk", 40))
                    websocket._rt_temp = float(params.get("temperature", 1.1))
                    websocket._rt_guid = float(params.get("guidance_weight", 1.1))
                    websocket._pace = params.get("pace", "asap")

                    assets_ok = (_MEAN_EMBED is not None) or (_CENTROIDS is not None)
                    await send_json({"type": "started", "mode": "rt", "steering_assets": "loaded" if assets_ok else "none"})

                    websocket._rt_task = asyncio.create_task(rt_loop())
                    continue

                await send_json({"type": "started", "session_id": sid, "mode": mode})

                if mode == "bar" and worker is not None:
                    if pump_task is not None:
                        pump_task.cancel()
                    pump_task = asyncio.create_task(pump(worker))

            elif mtype == "update":
                if mode == "bar":
                    if not sid:
                        # Keep the socket open (a bare `return` used to close it).
                        await send_json({"type": "error", "error": "No session_id yet"})
                        continue
                    try:
                        raw_mean = msg.get("mean")
                        res = jam_update(
                            session_id=sid,
                            guidance_weight=msg.get("guidance_weight"),
                            temperature=msg.get("temperature"),
                            topk=msg.get("topk"),
                            styles=msg.get("styles") or "",
                            style_weights=msg.get("style_weights") or "",
                            loop_weight=msg.get("loop_weight"),
                            use_current_mix_as_style=bool(msg.get("use_current_mix_as_style", False)),
                            # Pass these explicitly: jam_update is a FastAPI
                            # endpoint, so omitted arguments would default to
                            # its Form(...) marker objects, not None / "".
                            mean=None if raw_mean in (None, "") else float(raw_mean),
                            centroid_weights=str(msg.get("centroid_weights") or ""),
                        )
                    except HTTPException as exc:
                        await send_json({"type": "error", "error": str(exc.detail)})
                        continue
                    except (TypeError, ValueError) as exc:
                        await send_json({"type": "error", "error": f"Invalid update: {exc}"})
                        continue
                    await send_json({"type": "status", **res})
                else:
                    if getattr(websocket, "_mrt", None) is None:
                        await send_json({"type": "error", "error": "Realtime stream not started"})
                        continue

                    websocket._rt_temp = float(msg.get("temperature", websocket._rt_temp))
                    websocket._rt_topk = int(msg.get("topk", websocket._rt_topk))
                    websocket._rt_guid = float(msg.get("guidance_weight", websocket._rt_guid))

                    if "mean" in msg and msg["mean"] is not None:
                        try:
                            websocket._rt_mean = float(msg["mean"])
                        except (TypeError, ValueError):
                            websocket._rt_mean = 0.0

                    if "centroid_weights" in msg:
                        cw = [w.strip() for w in str(msg["centroid_weights"]).split(",") if w.strip() != ""]
                        try:
                            websocket._rt_centroid_weights = [float(x) for x in cw]
                        except ValueError:
                            websocket._rt_centroid_weights = []

                    styles_str = msg.get("styles", None)
                    style_weights_str = msg.get("style_weights", "")

                    text_list = [s.strip() for s in (styles_str.split(",") if styles_str else []) if s.strip()]
                    try:
                        text_w = [float(x) for x in style_weights_str.split(",") if x.strip()] if style_weights_str else []
                    except ValueError:
                        text_w = []

                    asset_manager.ensure_assets_loaded(get_mrt())
                    # The rt loop ramps _style_cur toward this new target.
                    websocket._style_tgt = build_style_vector(
                        websocket._mrt,
                        text_styles=text_list,
                        text_weights=text_w,
                        loop_embed=None,
                        loop_weight=None,
                        mean_weight=float(websocket._rt_mean),
                        centroid_weights=websocket._rt_centroid_weights,
                    )

                    if "style_ramp_seconds" in msg:
                        try:
                            websocket._style_ramp_s = float(msg["style_ramp_seconds"])
                        except (TypeError, ValueError):
                            pass
                    await send_json({"type": "status", "updated": "rt-knobs+style"})

            elif mtype == "consume" and mode == "bar":
                with jam_lock:
                    session_info = jam_registry.get(msg.get("session_id") or sid)
                if session_info is not None:
                    session_info['worker'].mark_chunk_consumed(int(msg.get("chunk_index", -1)))

            elif mtype == "reseed" and mode == "bar":
                with jam_lock:
                    session_info = jam_registry.get(msg.get("session_id") or sid)
                if session_info is None or not session_info['worker'].is_alive():
                    await send_json({"type": "error", "error": "Session not found"})
                    continue
                target = session_info['worker']
                loop_b64 = msg.get("loop_audio_b64")
                if not loop_b64:
                    await send_json({"type": "error", "error": "loop_audio_b64 required"})
                    continue
                wav, _ = _ws_load_wav_b64(loop_b64, target.mrt.sample_rate)
                target.reseed_from_waveform(wav)
                await send_json({"type": "status", "reseeded": True})

            elif mtype == "reseed_splice" and mode == "bar":
                with jam_lock:
                    session_info = jam_registry.get(msg.get("session_id") or sid)
                if session_info is None or not session_info['worker'].is_alive():
                    await send_json({"type": "error", "error": "Session not found"})
                    continue
                target = session_info['worker']
                anchor = float(msg.get("anchor_bars", 2.0))
                b64 = msg.get("combined_audio_b64")
                if b64:
                    wav, _ = _ws_load_wav_b64(b64, target.mrt.sample_rate)
                    target.reseed_splice(wav, anchor_bars=anchor)
                else:
                    # No audio supplied: splice from the session's original combined loop.
                    target.reseed_splice(target.params.combined_loop, anchor_bars=anchor)
                await send_json({"type": "status", "splice": anchor})

            elif mtype == "stop":
                if mode == "rt":
                    await stop_rt_loop()
                await send_json({"type": "stopped"})
                break

            elif mtype == "ping":
                await send_json({"type": "pong"})

            else:
                await send_json({"type": "error", "error": f"Unknown type {mtype}"})

    except WebSocketDisconnect:
        pass
    except Exception as e:
        try:
            await send_json({"type": "error", "error": str(e)})
        except Exception:
            pass
    finally:
        # Stop background work tied to this socket so nothing keeps generating
        # or sending after the client is gone. Bar-mode workers stay registered
        # in jam_registry and are stopped via /jam/stop.
        websocket._rt_running = False
        rt_task = getattr(websocket, "_rt_task", None)
        if rt_task is not None:
            rt_task.cancel()
        if pump_task is not None:
            pump_task.cancel()
        try:
            if websocket.client_state != WebSocketState.DISCONNECTED:
                await websocket.close()
        except Exception:
            pass
|
|
|
|
@app.get("/ping")
def ping():
    """Liveness check: always answers ``{"ok": True}``."""
    payload = {"ok": True}
    return payload
|
|
@app.get("/", response_class=Response)
def read_root():
    """
    Serve the API documentation page.

    Returns ``documentation.html`` from this module's directory as text/html.
    If that file is missing, returns a small placeholder page instead.
    """
    doc_path = Path(__file__).parent / "documentation.html"
    try:
        body = doc_path.read_text(encoding='utf-8')
    except FileNotFoundError:
        body = """
<!DOCTYPE html>
<html><body>
<h1>MagentaRT Research API</h1>
<p>Documentation file not found. Please check documentation.html</p>
</body></html>
"""
    return Response(content=body, media_type="text/html")
|
|
| |
| |
| |
|
|
@app.get("/config/generation")
async def get_generation_config():
    """
    Return the current global generation defaults.

    The values (temperature, topk, guidance_weight) are read from
    ``_GLOBAL_GEN_PARAMS``. They are applied when the MRT pool is initialized,
    so they affect new requests.
    """
    current = _GLOBAL_GEN_PARAMS.get()
    return current
|
|
@app.put("/config/generation")
async def update_generation_config(
    temperature: Optional[float] = None,
    topk: Optional[int] = None,
    guidance_weight: Optional[float] = None
):
    """
    Change the global defaults for temperature, topk and guidance_weight.

    Only the stored defaults are updated here; the MRT pool must be rebuilt
    via POST /config/generation/apply for them to take effect. Explicit
    per-request parameters still override these defaults.

    Returns whatever ``_GLOBAL_GEN_PARAMS.update`` reports, plus a reminder note.
    """
    changed = _GLOBAL_GEN_PARAMS.update(
        temperature=temperature,
        topk=topk,
        guidance_weight=guidance_weight,
    )
    reminder = "Changes require pool restart. Call POST /config/generation/apply to apply."
    return {"updated": changed, "note": reminder}
|
|
@app.post("/config/generation/apply")
async def apply_generation_config():
    """
    Restart the MRT pool so new global generation defaults take effect.

    Steps:
      1. Collect JAM sessions whose worker thread is still alive.
      2. If any exist, respond 409 (stop them first via /jam/stop).
      3. Otherwise rebuild the pool with ``reset_mrt_pool()`` and report the
         now-active defaults.

    NOTE(review): the liveness check is not held across the reset, so a
    session started between steps 1 and 3 is not detected.
    """
    import asyncio  # local import keeps this block self-contained

    with jam_lock:
        active_sessions = [sid for sid, info in jam_registry.items() if info['worker'].is_alive()]

    if active_sessions:
        raise HTTPException(
            status_code=409,
            detail=f"Cannot restart: {len(active_sessions)} active JAM session(s). Stop them first via /jam/stop"
        )

    # reset_mrt_pool() is synchronous (presumably it rebuilds models, which can
    # take a long time). Awaiting it in a worker thread keeps this async
    # endpoint from blocking the event loop, and with it every other HTTP and
    # websocket client.
    await asyncio.to_thread(reset_mrt_pool)

    return {
        "status": "applied",
        "params": _GLOBAL_GEN_PARAMS.get(),
        "message": "MRT pool restarted with new parameters"
    }
|
|
@app.get("/config/generation/pool_status")
async def get_pool_status():
    """
    Snapshot the MRT pool state.

    Under ``_MRT_POOL_LOCK``, reads the pool size, a copy of the per-slot
    availability, the initialization flag and the current generation defaults.
    """
    with _MRT_POOL_LOCK:
        status = {
            "pool_size": len(_MRT_POOL),
            "available": _MRT_AVAILABLE.copy(),
            "initialized": _POOL_INITIALIZED,
            "params": _GLOBAL_GEN_PARAMS.get(),
        }
    return status
|
|
| |
| |
| |
|
|
@app.get("/lil_demo_540p.mp4")
def demo_video():
    """Serve the bundled demo clip that sits next to this module."""
    video_path = Path(__file__).parent / "lil_demo_540p.mp4"
    return FileResponse(video_path, media_type="video/mp4")
|
|
@app.get("/tester", response_class=HTMLResponse)
def tester():
    """Serve the realtime tester page from this module's directory, uncached."""
    page_path = Path(__file__).parent / "magentaRT_rt_tester.html"
    page = page_path.read_text(encoding="utf-8")
    no_cache = {"Cache-Control": "no-store"}
    return HTMLResponse(page, headers=no_cache)
|
|
@app.get("/magenta_prompts.js")
def magenta_prompts_js():
    """Serve the prompt helper script next to this module, uncached."""
    script_path = Path(__file__).parent / "magenta_prompts.js"
    no_cache = {"Cache-Control": "no-store"}
    return FileResponse(script_path, media_type="text/javascript", headers=no_cache)