| |
| from __future__ import annotations |
|
|
| import os |
|
|
| import threading, time |
| from dataclasses import dataclass |
| from fractions import Fraction |
| from typing import Optional, Dict, Tuple, List |
|
|
| import numpy as np |
| from magenta_rt import audio as au |
|
|
| from utils import ( |
| StreamingResampler, |
| match_loudness_to_reference, |
| make_bar_aligned_context, |
| take_bar_aligned_tail, |
| wav_bytes_base64, |
| ) |
|
|
| def _dbg_rms_dbfs(x: np.ndarray) -> float: |
| |
| if x.ndim == 2: |
| x = x.mean(axis=1) |
| r = float(np.sqrt(np.mean(x * x) + 1e-12)) |
| return 20.0 * np.log10(max(r, 1e-12)) |
|
|
| def _dbg_rms_dbfs_model(x: np.ndarray) -> float: |
| |
| |
| if x.ndim == 2: |
| x = x.mean(axis=1) |
| r = float(np.sqrt(np.mean(x * x) + 1e-12)) |
| return 20.0 * np.log10(max(r, 1e-12)) |
|
|
| def _dbg_shape(x): |
| return tuple(x.shape) if hasattr(x, "shape") else ("-",) |
|
|
| |
| |
| |
|
|
@dataclass
class JamParams:
    """Configuration for one jam session.

    The first four fields are required; the rest carry defaults. Field order is
    part of the positional constructor and must not change.
    """

    # --- musical grid / output format (required) ---
    bpm: float
    beats_per_bar: int
    bars_per_chunk: int
    target_sr: int

    # --- loudness matching of emitted chunks ---
    loudness_mode: str = "auto"  # "none" disables matching
    headroom_db: float = 1.0

    # --- conditioning inputs ---
    style_vec: Optional[np.ndarray] = None
    ref_loop: Optional[au.Waveform] = None
    combined_loop: Optional[au.Waveform] = None

    # --- sampling knobs forwarded to the model ---
    guidance_weight: float = 1.1
    temperature: float = 1.1
    topk: int = 40

    # Seconds over which the active style vector moves toward `style_vec`.
    style_ramp_seconds: float = 8.0
|
|
|
|
@dataclass
class JamChunk:
    """One emitted chunk: its sequence number, encoded audio and metadata."""

    # Zero-based position of the chunk in the emitted sequence.
    index: int
    # Base64 text of the encoded audio for this chunk.
    audio_base64: str
    # Free-form description of the chunk (tempo, sample rate, knobs, ...).
    metadata: dict
|
|
|
|
| |
| |
| |
|
|
class BarClock:
    """Sample-domain bar clock with drift-free absolute boundaries.

    Bar length is held as an exact ``Fraction`` of samples and every boundary is
    computed from the absolute bar number, so rounding never accumulates.
    """

    def __init__(self, target_sr: int, bpm: float, beats_per_bar: int, base_offset_samples: int = 0):
        """Store the grid parameters and derive the exact bar length in samples."""
        self.sr = int(target_sr)
        # Parsing the decimal text keeps e.g. 100.5 exact instead of using
        # its binary floating-point approximation.
        self.bpm = Fraction(str(bpm))
        self.beats_per_bar = int(beats_per_bar)
        samples_per_bar_at_1bpm = Fraction(self.sr * 60 * self.beats_per_bar, 1)
        self.bar_samps = samples_per_bar_at_1bpm / self.bpm
        self.base = int(base_offset_samples)

    def bounds_for_chunk(self, chunk_index: int, bars_per_chunk: int) -> Tuple[int, int]:
        """Return ``(start, end)`` sample indices of chunk ``chunk_index``.

        Each edge is rounded independently from its exact position, so the end
        of chunk ``k`` always equals the start of chunk ``k + 1``.
        """
        first_bar = chunk_index * bars_per_chunk
        last_bar = (chunk_index + 1) * bars_per_chunk
        exact_edges = (
            self.base + self.bar_samps * first_bar,
            self.base + self.bar_samps * last_bar,
        )
        start, end = (int(round(edge)) for edge in exact_edges)
        return start, end

    def seconds_per_bar(self) -> float:
        """Return the duration of one bar in seconds."""
        seconds_per_beat = 60.0 / float(self.bpm)
        return float(self.beats_per_bar) * seconds_per_beat
|
|
|
|
| |
| |
| |
|
|
class JamWorker(threading.Thread):
    """Generates continuous audio with MagentaRT, spools it at target SR,
    and emits *sample-accurate*, bar-aligned chunks (no FPS drift).

    The thread repeatedly asks ``mrt`` for a model-rate chunk, resamples it into
    ``self._spool`` at ``params.target_sr`` and slices the spool on the absolute
    boundaries supplied by ``BarClock``. Finished chunks are stored in
    ``self._outbox`` and handed out in order by :meth:`get_next_chunk`.

    Locks:
        ``self._lock`` is taken for knob updates, the style-vector step and
        context (re)seeding. ``self._cv`` is taken when the outbox, the
        delivery counters or the buffer-ahead limit are written.

    NOTE(review): ``self._spool`` is only ever appended to in this class, so it
    grows for the lifetime of the worker.
    """

    # Codec frame rate of the most recently constructed worker (set in __init__).
    FRAMES_PER_SECOND: float | None = None

    def __init__(self, mrt, params: JamParams):
        """Prepare model state, resampler, spool and delivery bookkeeping.

        Args:
            mrt: Model system object. This class reads ``codec.frame_rate``,
                ``config`` and ``sample_rate`` from it and calls
                ``init_state``, ``codec.encode`` and ``generate_chunk``.
            params: Jam configuration; kept by reference and mutated by
                :meth:`update_knobs`.
        """
        super().__init__(daemon=True)
        self.mrt = mrt
        self.params = params

        # Re-entrant so helpers may be called while the lock is already held.
        self._lock = threading.RLock()

        # Fresh model state plus the sampling knobs from params.
        self.state = self.mrt.init_state()
        self.mrt.guidance_weight = float(self.params.guidance_weight)
        self.mrt.temperature = float(self.params.temperature)
        self.mrt.topk = int(self.params.topk)

        # Codec / context geometry.
        self._codec_fps = float(self.mrt.codec.frame_rate)
        JamWorker.FRAMES_PER_SECOND = self._codec_fps
        self._ctx_frames = int(self.mrt.config.context_length_frames)
        self._ctx_seconds = self._ctx_frames / self._codec_fps

        self._model_stream: Optional[np.ndarray] = None
        self._model_sr = int(self.mrt.sample_rate)

        # Private copy of the style vector; run() ramps it toward params.style_vec.
        self._style_vec = (
            None
            if self.params.style_vec is None
            else np.array(self.params.style_vec, dtype=np.float32, copy=True)
        )
        # Duration of one model chunk in seconds (used for the style ramp step).
        self._chunk_secs = (
            self.mrt.config.chunk_length_frames * self.mrt.config.frame_length_samples
        ) / float(self._model_sr)

        # Resampler is only needed when the output rate differs from the model rate.
        if int(self.params.target_sr) != int(self._model_sr):
            self._rs = StreamingResampler(self._model_sr, int(self.params.target_sr), channels=2)
        else:
            self._rs = None
        self._spool = np.zeros((0, 2), dtype=np.float32)
        self._spool_written = 0

        # Previous chunk's model-rate tail and how many target-rate samples it
        # occupied in the spool (see _append_model_chunk_and_spool).
        self._pending_tail_model = None
        self._pending_tail_target_len = 0

        self._bar_clock = BarClock(
            self.params.target_sr,
            self.params.bpm,
            self.params.beats_per_bar,
            base_offset_samples=0,
        )

        # Delivery bookkeeping.
        self.idx = 0  # next chunk index to emit
        self._next_to_deliver = 0  # next chunk index get_next_chunk() hands out
        self._last_consumed_index = -1

        self._outbox: Dict[int, JamChunk] = {}
        self._cv = threading.Condition()

        self._stop_event = threading.Event()
        self._max_buffer_ahead = 1

        # Timing info for external monitoring.
        self.created_at = time.time()
        self.last_chunk_started_at: Optional[float] = None
        self.last_chunk_completed_at: Optional[float] = None
        self.last_activity_at = time.time()

        # Context changes queued by reseed_splice(), applied in _emit_ready().
        self._pending_reseed: Optional[dict] = None
        self._pending_token_splice: Optional[dict] = None

        # Always defined, even when no loop is installed below.
        self._original_context_tokens: Optional[np.ndarray] = None

        if self.params.combined_loop is not None:
            self._install_context_from_loop(self.params.combined_loop)

    # ------------------------------------------------------------------
    # Public control surface
    # ------------------------------------------------------------------

    def set_buffer_seconds(self, seconds: float):
        """Clamp how far ahead we allow, in *seconds* of audio."""
        chunk_secs = float(self.params.bars_per_chunk) * self._bar_clock.seconds_per_bar()
        max_chunks = max(0, int(round(seconds / max(chunk_secs, 1e-6))))
        with self._cv:
            self._max_buffer_ahead = max_chunks

    def set_buffer_chunks(self, k: int):
        """Clamp how far ahead we allow, in whole chunks (negative -> 0)."""
        with self._cv:
            self._max_buffer_ahead = max(0, int(k))

    def stop(self):
        """Ask the generation loop in :meth:`run` to exit."""
        self._stop_event.set()

    def get_next_chunk(self, timeout: float = 30.0) -> Optional[JamChunk]:
        """Return the next chunk in delivery order, or None on timeout.

        Args:
            timeout: Maximum number of seconds to wait for the chunk.

        Returns:
            The chunk whose index equals the internal delivery counter (which
            is then advanced), or None if it did not arrive in time.
        """
        # Monotonic clock: the deadline must not move if wall-clock time is adjusted.
        deadline = time.monotonic() + timeout
        with self._cv:
            while True:
                chunk = self._outbox.get(self._next_to_deliver)
                if chunk is not None:
                    self._next_to_deliver += 1
                    return chunk
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return None
                # Short waits so a missed notification cannot stall us for long.
                self._cv.wait(timeout=min(0.25, remaining))

    def mark_chunk_consumed(self, chunk_index: int):
        """Record that the client consumed ``chunk_index`` and prune the outbox.

        Chunks more than one index behind the newest consumed chunk are dropped.
        """
        with self._cv:
            self._last_consumed_index = max(self._last_consumed_index, int(chunk_index))
            cutoff = self._last_consumed_index - 1
            for stale in [k for k in self._outbox if k < cutoff]:
                self._outbox.pop(stale, None)

    def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
        """Update any subset of the sampling knobs and push them to the model."""
        with self._lock:
            if guidance_weight is not None:
                self.params.guidance_weight = float(guidance_weight)
            if temperature is not None:
                self.params.temperature = float(temperature)
            if topk is not None:
                self.params.topk = int(topk)

            self.mrt.guidance_weight = float(self.params.guidance_weight)
            self.mrt.temperature = float(self.params.temperature)
            self.mrt.topk = int(self.params.topk)

    # ------------------------------------------------------------------
    # Context tokens
    # ------------------------------------------------------------------

    def _expected_token_shape(self) -> Tuple[int, int]:
        """Return ``(context_length_frames, rvq_depth)`` for context tokens."""
        F = int(self._ctx_frames)
        D = int(self.mrt.config.decoder_codec_rvq_depth)
        return F, D

    def _coerce_tokens(self, toks: np.ndarray) -> np.ndarray:
        """Force tokens to (context_length_frames, rvq_depth), padding/trimming as needed.

        Missing columns repeat the last column. Missing frames are filled by
        prepending copies of the last frame (safer than zeros for RVQ stacks).
        Extra columns are dropped from the right, extra frames from the front.
        Input with no elements becomes an all-zero context.

        Returns:
            An ``int32`` array of shape ``(F, D)``.
        """
        F, D = self._expected_token_shape()
        toks = np.asarray(toks)
        if toks.ndim != 2:
            toks = np.atleast_2d(toks)
        if toks.size == 0:
            # Covers zero frames *and* zero columns; the latter previously
            # slipped through and produced an (F, 0) result.
            toks = np.zeros((1, D), dtype=np.int32)

        # Depth (columns).
        if toks.shape[1] > D:
            toks = toks[:, :D]
        elif toks.shape[1] < D:
            pad_cols = np.tile(toks[:, -1:], (1, D - toks.shape[1]))
            toks = np.concatenate([toks, pad_cols], axis=1)

        # Frames (rows).
        if toks.shape[0] < F:
            pad = np.repeat(toks[-1:, :], F - toks.shape[0], axis=0)
            toks = np.concatenate([pad, toks], axis=0)
        elif toks.shape[0] > F:
            toks = toks[-F:, :]

        if toks.dtype != np.int32:
            toks = toks.astype(np.int32, copy=False)
        return toks

    def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
        """Build *exactly* context_length_frames worth of tokens from ``loop``.

        The loop is resampled to the model rate and tiled in whole copies until
        it covers the context duration; the last ``ctx_seconds`` of that tiling
        are encoded. Because only whole copies are tiled, the context always
        ends exactly where the loop ends (the loop is presumably bar-aligned at
        its end - TODO confirm with callers). The encoded tokens are finally
        forced to the expected ``(frames, depth)`` shape by `_coerce_tokens`.

        Returns:
            ``int32`` tokens of shape ``(context_length_frames, rvq_depth)``.
        """
        wav = loop.as_stereo().resample(self._model_sr)
        data = wav.samples.astype(np.float32, copy=False)
        if data.ndim == 1:
            data = data[:, None]
        if data.shape[0] == 0:
            # Nothing to tile: fall back to a single silent stereo sample.
            data = np.zeros((1, 2), dtype=np.float32)

        sr = int(self._model_sr)
        ctx_samps = int(round(float(self._ctx_seconds) * sr))

        reps = max(1, int(np.ceil(ctx_samps / float(data.shape[0]))))
        tiled = np.tile(data, (reps, 1))
        ctx = tiled[tiled.shape[0] - ctx_samps:]

        exact = au.Waveform(ctx, sr)
        tokens_full = self.mrt.codec.encode(exact).astype(np.int32)
        depth = int(self.mrt.config.decoder_codec_rvq_depth)
        # Last resort: pad/trim to the exact frame count the model expects.
        return self._coerce_tokens(tokens_full[:, :depth])

    def _install_context_from_loop(self, loop: au.Waveform):
        """Start from a fresh model state whose context is encoded from ``loop``."""
        context_tokens = self._encode_exact_context_tokens(loop)
        s = self.mrt.init_state()
        s.context_tokens = context_tokens
        self.state = s
        self._original_context_tokens = np.copy(context_tokens)

    def reseed_from_waveform(self, wav: au.Waveform):
        """Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
        # Encode outside the lock; only the state swap needs to be atomic.
        context_tokens = self._encode_exact_context_tokens(wav)
        with self._lock:
            s = self.mrt.init_state()
            s.context_tokens = context_tokens
            self.state = s
            self._model_stream = None
            self._original_context_tokens = np.copy(context_tokens)

    def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
        """Queue a *seamless* reseed by token splicing instead of full restart.

        We compute a fresh, bar-locked context token tensor of exact length
        (e.g., 250 frames), then splice only the *tail* corresponding to
        `anchor_bars` so generation continues smoothly without resetting state.
        The splice is applied by `_emit_ready` after the next emitted chunk.

        Args:
            recent_wav: Audio to encode as the new context.
            anchor_bars: Number of bars (at least 1 is used) taken from the end
                of the new context.
        """
        new_ctx = self._encode_exact_context_tokens(recent_wav)
        F, D = self._expected_token_shape()

        # Convert the anchor length from bars to codec frames, clamped to [1, F].
        spb = self._bar_clock.seconds_per_bar()
        frames_per_bar = max(1, int(round(self._codec_fps * spb)))
        splice_frames = max(1, min(int(round(max(1.0, float(anchor_bars)) * frames_per_bar)), F))

        with self._lock:
            cur = getattr(self.state, "context_tokens", None)
            if cur is None:
                # No context to splice into: queue a full reseed instead.
                self._pending_reseed = {"ctx": new_ctx}
                return
            cur = self._coerce_tokens(cur)

            # Keep the head of the current context, take the tail of the new one.
            left = cur[:F - splice_frames, :]
            right = new_ctx[F - splice_frames:, :]
            spliced = np.concatenate([left, right], axis=0)
            spliced = self._coerce_tokens(spliced)

            self._pending_token_splice = {
                "tokens": spliced,
                "debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar},
            }

    # ------------------------------------------------------------------
    # Spooling
    # ------------------------------------------------------------------

    def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
        """
        Append one MagentaRT chunk into the target-SR spool with an energy-aware,
        deferred-overwrite crossfade to avoid writing near-silence at bar edges.

        Key behavior:
          - Append BODY and TAIL of *this* chunk right away (resampled to target SR).
          - Keep THIS chunk's model-rate TAIL (+ its target-SR length if appended) to repair the
            previous boundary on the *next* call by mixing (prev_tail*cos + new_head*sin).
          - When the correction length Lpop would be 0 (e.g., tail produced no target samples last time),
            we APPEND the mixed-overlap to bridge the gap instead of overwriting 0 samples.
          - When overwriting, we guard against writing ultra-quiet audio by normalizing the
            mixed-overlap up (bounded) if it's >20 dB below the existing spool end.

        The resampler is stateful, so the call order mixed -> body -> tail matters.
        """
        import math  # `math` is not imported at module level

        def _rms_dbfs(x: np.ndarray) -> float:
            # RMS in dBFS of the channel-mean signal; -120 dB for empty input.
            if x.size == 0:
                return -120.0
            if x.ndim == 2 and x.shape[1] > 1:
                x_m = x.mean(axis=1, dtype=np.float32)
            else:
                x_m = x.astype(np.float32, copy=False).reshape(-1)
            # Non-finite samples count as silence.
            x_m = np.nan_to_num(x_m, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
            r = float(np.sqrt(np.mean(x_m * x_m) + 1e-12))
            return 20.0 * math.log10(max(r, 1e-12))

        def to_target(y: np.ndarray) -> np.ndarray:
            # Pass-through when model SR already equals target SR.
            return y if self._rs is None else self._rs.process(y, final=False)

        # Normalize the incoming chunk to float32 stereo, shape (n, 2).
        s = wav.samples.astype(np.float32, copy=False)
        if s.ndim == 1:
            s = s[:, None]
        if s.shape[1] == 1:
            s = np.repeat(s, 2, axis=1)

        n_samps = int(s.shape[0])

        # Crossfade length in model-rate samples (0 if the config has none).
        try:
            xfade_s = float(self.mrt.config.crossfade_length)
        except Exception:
            xfade_s = 0.0
        xfade_n = int(round(max(0.0, xfade_s) * float(self._model_sr)))

        # Split into head / body / tail; too-short chunks are treated as all body.
        if xfade_n > 0 and n_samps >= (2 * xfade_n):
            head_m = s[:xfade_n, :]
            body_m = s[xfade_n:n_samps - xfade_n, :]
            tail_m = s[n_samps - xfade_n:, :]
        else:
            head_m = np.zeros((0, 2), dtype=np.float32)
            body_m = s
            tail_m = np.zeros((0, 2), dtype=np.float32)

        # Repair the previous boundary: mix the stored tail with this head.
        if (self._pending_tail_model is not None) and (xfade_n > 0) and (n_samps >= xfade_n):
            tail_prev_m = self._pending_tail_model
            head_now_m = head_m

            # Defensive: force both segments to stereo.
            if tail_prev_m.shape[1] != 2:
                if tail_prev_m.ndim == 1:
                    tail_prev_m = tail_prev_m[:, None]
                tail_prev_m = np.repeat(tail_prev_m[:, :1], 2, axis=1)
            if head_now_m.shape[1] != 2:
                if head_now_m.ndim == 1:
                    head_now_m = head_now_m[:, None]
                head_now_m = np.repeat(head_now_m[:, :1], 2, axis=1)

            # Use a quarter-length fade when either side is below -45 dBFS.
            tail_r = _rms_dbfs(tail_prev_m)
            head_r = _rms_dbfs(head_now_m)
            xfade_use = int(xfade_n)
            if min(tail_r, head_r) < -45.0:
                xfade_use = max(1, xfade_n // 4)

            Lm = min(xfade_use, tail_prev_m.shape[0], head_now_m.shape[0])
            if Lm > 0:
                # Equal-power (cos/sin) crossfade at model rate.
                t = np.linspace(0.0, math.pi / 2.0, Lm, endpoint=False, dtype=np.float32)[:, None]
                cosw = np.cos(t, dtype=np.float32)
                sinw = np.sin(t, dtype=np.float32)
                mixed_m = tail_prev_m[-Lm:, :] * cosw + head_now_m[:Lm, :] * sinw

                y_mixed = to_target(mixed_m)
                Lcorr = int(y_mixed.shape[0])

                if Lcorr > 0:
                    # Never overwrite more than the previous tail occupied.
                    Lpop = int(min(self._pending_tail_target_len, self._spool.shape[0], Lcorr))

                    if Lpop > 0:
                        prev_end = self._spool[-Lpop:, :]
                        new_seg = y_mixed[-Lpop:, :]

                        prev_r = _rms_dbfs(prev_end)
                        new_r = _rms_dbfs(new_seg)

                        # Lift a far-too-quiet replacement by at most 20 dB,
                        # aiming 6 dB below the audio it replaces.
                        if new_r < (prev_r - 20.0):
                            lift_db = max(0.0, min(20.0, (prev_r - 6.0) - new_r))
                            scale = 10.0 ** (lift_db / 20.0)
                            new_seg = np.clip(new_seg * scale, -1.0, 1.0).astype(np.float32, copy=False)

                        self._spool[-Lpop:, :] = new_seg
                        print(f"[append] mixedOverlap len={Lpop} rms={_rms_dbfs(new_seg):+.1f} dBFS")
                    else:
                        # Nothing to overwrite: bridge the gap by appending.
                        self._spool = np.concatenate([self._spool, y_mixed], axis=0)
                        self._spool_written += int(y_mixed.shape[0])
                        print(f"[append] mixedOverlap len={y_mixed.shape[0]} rms={_rms_dbfs(y_mixed):+.1f} dBFS")

            # The previous tail has been consumed.
            self._pending_tail_model = None
            self._pending_tail_target_len = 0

        # Append this chunk's body.
        y_body = to_target(body_m) if body_m.size else np.zeros((0, 2), dtype=np.float32)
        if y_body.size:
            self._spool = np.concatenate([self._spool, y_body], axis=0)
            self._spool_written += int(y_body.shape[0])
        print(f"[append] body len={y_body.shape[0] if y_body.size else 0} rms={_rms_dbfs(y_body):+.1f} dBFS")

        # Append this chunk's tail and remember how many target samples it took.
        y_tail = to_target(tail_m) if tail_m.size else np.zeros((0, 2), dtype=np.float32)
        if y_tail.size:
            self._spool = np.concatenate([self._spool, y_tail], axis=0)
            self._spool_written += int(y_tail.shape[0])
            self._pending_tail_target_len = int(y_tail.shape[0])
        else:
            self._pending_tail_target_len = 0
        print(f"[append] tail len={y_tail.shape[0] if y_tail.size else 0} rms={_rms_dbfs(y_tail):+.1f} dBFS")

        # Keep the model-rate tail for the next call's boundary repair.
        self._pending_tail_model = tail_m if tail_m.size else None

    # ------------------------------------------------------------------
    # Emission
    # ------------------------------------------------------------------

    def _should_generate_next_chunk(self) -> bool:
        """Return True while generation is within the allowed buffer-ahead window.

        A chunk handed out by `get_next_chunk` counts as consumed even if the
        client never called `mark_chunk_consumed` for it.
        """
        implicit_consumed = self._next_to_deliver - 1
        horizon_anchor = max(self._last_consumed_index, implicit_consumed)
        return self.idx <= (horizon_anchor + self._max_buffer_ahead)

    def _install_pending_context(self) -> None:
        """Apply a queued token splice, or else a queued full reseed, if any."""
        with self._lock:
            if self._pending_token_splice is not None:
                spliced = self._coerce_tokens(self._pending_token_splice["tokens"])
                try:
                    # Seamless path: swap the context on the live state.
                    self.state.context_tokens = spliced
                except Exception:
                    # Deliberate best-effort fallback: if the live state rejects
                    # the assignment, restart from a fresh state instead.
                    new_state = self.mrt.init_state()
                    new_state.context_tokens = spliced
                    self.state = new_state
                    self._model_stream = None
                self._pending_token_splice = None
            elif self._pending_reseed is not None:
                ctx = self._coerce_tokens(self._pending_reseed["ctx"])
                new_state = self.mrt.init_state()
                new_state.context_tokens = ctx
                self.state = new_state
                self._model_stream = None
                self._pending_reseed = None

    def _emit_ready(self):
        """Emit next chunk(s) if the spool has enough samples."""
        while True:
            start, end = self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk)
            if end > self._spool_written:
                break
            loop = self._spool[start:end]

            # Optional loudness match against the same span of the combined loop.
            if self.params.loudness_mode != "none" and self.params.combined_loop is not None:
                sr = int(self.params.target_sr)

                comb = self.params.combined_loop.as_stereo().resample(sr).samples.astype(np.float32, copy=False)
                if comb.ndim == 1:
                    comb = comb[:, None]
                if comb.shape[1] == 1:
                    comb = np.repeat(comb, 2, axis=1)

                need = end - start
                if comb.shape[0] > 0 and need > 0:
                    # Read `need` samples starting at the absolute position,
                    # wrapping around the loop as many times as required.
                    ref_slice = np.take(comb, np.arange(start, end), axis=0, mode="wrap")

                    ref = au.Waveform(ref_slice, sr)
                    tgt = au.Waveform(loop.copy(), sr)

                    matched, _stats = match_loudness_to_reference(
                        ref,
                        tgt,
                        method=self.params.loudness_mode,
                        headroom_db=self.params.headroom_db,
                    )
                    loop = matched.samples

            audio_b64, total_samples, channels = wav_bytes_base64(loop, int(self.params.target_sr))
            meta = {
                "bpm": float(self.params.bpm),
                "bars": int(self.params.bars_per_chunk),
                "beats_per_bar": int(self.params.beats_per_bar),
                "sample_rate": int(self.params.target_sr),
                "channels": int(channels),
                "total_samples": int(total_samples),
                "seconds_per_bar": self._bar_clock.seconds_per_bar(),
                "loop_duration_seconds": self.params.bars_per_chunk * self._bar_clock.seconds_per_bar(),
                "guidance_weight": float(self.params.guidance_weight),
                "temperature": float(self.params.temperature),
                "topk": int(self.params.topk),
            }
            chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)

            if os.getenv("MRT_DEBUG_RMS", "0") == "1":
                spb = self._bar_clock.bar_samps
                seg = int(max(1, spb // 4))
                rms = [float(np.sqrt(np.mean(loop[i:i+seg]**2))) for i in range(0, loop.shape[0], seg)]
                print(f"[emit idx={self.idx}] quarter-bar RMS: {rms[:8]}")

            with self._cv:
                self._outbox[self.idx] = chunk
                self._cv.notify_all()
            self.idx += 1

            # Queued context changes take effect right after a finished chunk.
            self._install_pending_context()

    # ------------------------------------------------------------------
    # Thread body
    # ------------------------------------------------------------------

    def run(self):
        """Generate, spool and emit chunks until :meth:`stop` is called."""
        while not self._stop_event.is_set():
            if not self._should_generate_next_chunk():
                # Far enough ahead: still emit what is ready, then idle briefly.
                self._emit_ready()
                time.sleep(0.01)
                continue

            # Step the private style vector toward the requested one.
            with self._lock:
                target = self.params.style_vec
                if target is None:
                    style_to_use = None
                else:
                    if self._style_vec is None:
                        self._style_vec = np.array(target, dtype=np.float32, copy=True)
                    else:
                        ramp = float(self.params.style_ramp_seconds or 0.0)
                        # Fraction of the remaining distance covered per model chunk.
                        step = 1.0 if ramp <= 0.0 else min(1.0, self._chunk_secs / ramp)
                        self._style_vec += step * (np.asarray(target, dtype=np.float32) - self._style_vec)
                    style_to_use = self._style_vec

            self.last_chunk_started_at = time.time()
            self.last_activity_at = time.time()

            wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_to_use)

            self.last_chunk_completed_at = time.time()
            self.last_activity_at = time.time()

            self._append_model_chunk_and_spool(wav)
            self._emit_ready()

        # Flush whatever the resampler still holds. There is no resampler when
        # the target rate equals the model rate.
        if self._rs is not None:
            tail = self._rs.process(np.zeros((0, 2), np.float32), final=True)
            if tail.size:
                self._spool = np.concatenate([self._spool, tail], axis=0)
                self._spool_written += tail.shape[0]

        # One last emit attempt with the flushed samples.
        self._emit_ready()
|
|