| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from dataclasses import dataclass, field, is_dataclass |
| from typing import List, Optional, Union |
|
|
|
|
| @dataclass |
| class MultitalkerTranscriptionConfig: |
| """ |
| Configuration for Multi-talker transcription with an ASR model and a diarization model. |
| """ |
| |
| diar_model: Optional[str] = None |
| diar_pretrained_name: Optional[str] = None |
| max_num_of_spks: Optional[int] = 4 |
| parallel_speaker_strategy: bool = True |
| masked_asr: bool = True |
| mask_preencode: bool = False |
| cache_gating: bool = True |
| cache_gating_buffer_size: int = 2 |
| single_speaker_mode: bool = False |
|
|
| |
| session_len_sec: float = -1 |
| num_workers: int = 8 |
| random_seed: Optional[int] = None |
| log: bool = True |
|
|
| |
| streaming_mode: bool = True |
| spkcache_len: int = 188 |
| spkcache_refresh_rate: int = 0 |
| fifo_len: int = 188 |
| chunk_len: int = 0 |
| chunk_left_context: int = 0 |
| chunk_right_context: int = 0 |
|
|
| |
| cuda: Optional[int] = None |
| allow_mps: bool = False |
| matmul_precision: str = "highest" |
|
|
| |
| asr_model: Optional[str] = None |
| device: str = 'cuda' |
| audio_file: Optional[str] = None |
| manifest_file: Optional[str] = None |
| use_amp: bool = True |
| debug_mode: bool = False |
| batch_size: int = 32 |
| chunk_size: int = -1 |
| shift_size: int = -1 |
| left_chunks: int = 2 |
| online_normalization: bool = False |
| output_path: Optional[str] = None |
| pad_and_drop_preencoded: bool = False |
| set_decoder: Optional[str] = None |
| att_context_size: Optional[List[int]] = field(default_factory=lambda: [70, 13]) |
| generate_realtime_scripts: bool = False |
|
|
| word_window: int = 50 |
| sent_break_sec: float = 30.0 |
| fix_prev_words_count: int = 5 |
| update_prev_words_sentence: int = 5 |
| left_frame_shift: int = -1 |
| right_frame_shift: int = 0 |
| min_sigmoid_val: float = 1e-2 |
| discarded_frames: int = 8 |
| print_time: bool = True |
| print_sample_indices: List[int] = field(default_factory=lambda: [0]) |
| colored_text: bool = True |
| real_time_mode: bool = False |
| print_path: Optional[str] = None |
|
|
| ignored_initial_frame_steps: int = 5 |
| verbose: bool = False |
|
|
| feat_len_sec: float = 0.01 |
| finetune_realtime_ratio: float = 0.01 |
|
|
| spk_supervision: str = "diar" |
| binary_diar_preds: bool = False |
|
|
| @staticmethod |
| def init_diar_model(cfg, diar_model): |
| |
| diar_model.streaming_mode = cfg.streaming_mode |
| diar_model.sortformer_modules.chunk_len = cfg.chunk_len if cfg.chunk_len > 0 else 6 |
| diar_model.sortformer_modules.spkcache_len = cfg.spkcache_len |
| diar_model.sortformer_modules.chunk_left_context = cfg.chunk_left_context |
| diar_model.sortformer_modules.chunk_right_context = cfg.chunk_right_context if cfg.chunk_right_context > 0 else 7 |
| diar_model.sortformer_modules.fifo_len = cfg.fifo_len |
| diar_model.sortformer_modules.log = cfg.log |
| diar_model.sortformer_modules.spkcache_refresh_rate = cfg.spkcache_refresh_rate |
| return diar_model |