diff --git a/docs/.nav.yml b/docs/.nav.yml index 327755582a..c103ed476d 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -44,6 +44,7 @@ nav: - contributing/model/registration.md - contributing/model/tests.md - contributing/model/multimodal.md + - contributing/model/transcription.md - CI: contributing/ci - Design Documents: design - API Reference: diff --git a/docs/contributing/model/README.md b/docs/contributing/model/README.md index 0ca77fa499..6c013738ac 100644 --- a/docs/contributing/model/README.md +++ b/docs/contributing/model/README.md @@ -15,6 +15,7 @@ Read through these pages for a step-by-step guide: - [Registering a Model](registration.md) - [Unit Testing](tests.md) - [Multi-Modal Support](multimodal.md) +- [Speech-to-Text Support](transcription.md) !!! tip If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues) diff --git a/docs/contributing/model/transcription.md b/docs/contributing/model/transcription.md new file mode 100644 index 0000000000..bba85adf55 --- /dev/null +++ b/docs/contributing/model/transcription.md @@ -0,0 +1,273 @@ +# Speech-to-Text (Transcription/Translation) Support + +This document walks you through the steps to add support for speech-to-text (ASR) models to vLLM’s transcription and translation APIs by implementing [SupportsTranscription][vllm.model_executor.models.interfaces.SupportsTranscription]. +Please refer to the [supported models](../../models/supported_models.md#transcription) for further guidance. + +## 1. Update the base vLLM model + +It is assumed you have already implemented your model in vLLM according to the basic model guide. Extend your model with the [SupportsTranscription][vllm.model_executor.models.interfaces.SupportsTranscription] interface and implement the following class attributes and methods. + +- Declare supported languages and capabilities: + + ??? code + + ```python + from typing import ClassVar, Mapping, Optional, Literal + import numpy as np + import torch + from torch import nn + + from vllm.config import ModelConfig, SpeechToTextConfig + from vllm.inputs.data import PromptType + from vllm.model_executor.models.interfaces import SupportsTranscription + + class YourASRModel(nn.Module, SupportsTranscription): + # Map of ISO 639-1 language codes to language names + supported_languages: ClassVar[Mapping[str, str]] = { + "en": "English", + "it": "Italian", + # ... add more as needed + } + + # If your model only supports audio-conditioned generation + # (no text-only generation), enable this flag. + supports_transcription_only: ClassVar[bool] = True + ``` + + - The `supported_languages` mapping is validated at init time. + - Set `supports_transcription_only=True` if the model should not serve text generation (eg Whisper). + +- Provide an ASR configuration via [get_speech_to_text_config][vllm.model_executor.models.interfaces.SupportsTranscription.get_speech_to_text_config]. + This is for controlling general behavior of the API when serving your model: + + ??? code + + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_speech_to_text_config( + cls, + model_config: ModelConfig, + task_type: Literal["transcribe", "translate"], + ) -> SpeechToTextConfig: + return SpeechToTextConfig( + sample_rate=16_000, + max_audio_clip_s=30, + # Set to None to disable server-side chunking if your + # model/processor handles it already + min_energy_split_window_size=None, + ) + ``` + + See the “Audio preprocessing and chunking” section for what each field controls. + +- Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.data.PromptType]. There are two common patterns: + + #### A. Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n) + + Return a dict containing `multi_modal_data` with the audio, and either a `prompt` string or `prompt_token_ids`: + + ??? code + + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: + # Example with a free-form instruction prompt + task_word = "Transcribe" if task_type == "transcribe" else "Translate" + prompt = ( + "user\n" + f"{task_word} this audio: " + "\nmodel\n" + ) + + return { + "multi_modal_data": {"audio": (audio, stt_config.sample_rate)}, + "prompt": prompt, + } + ``` + + For further clarification on multi modal inputs, please refer to [Multi-Modal Inputs](../../features/multimodal_inputs.md). + + #### B. Encoder–decoder audio-only (e.g., Whisper) + + Return a dict with separate `encoder_prompt` and `decoder_prompt` entries: + + ??? code + + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: + if language is None: + raise ValueError("Language must be specified") + + prompt = { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": (audio, stt_config.sample_rate), + }, + }, + "decoder_prompt": ( + (f"<|prev|>{request_prompt}" if request_prompt else "") + + f"<|startoftranscript|><|{language}|>" + + f"<|{task_type}|><|notimestamps|>" + ), + } + return cast(PromptType, prompt) + ``` + + +- (Optional) Language validation via [validate_language][vllm.model_executor.models.interfaces.SupportsTranscription.validate_language] + + If your model requires a language and you want a default, override this method (see Whisper): + + ```python + @classmethod + def validate_language(cls, language: Optional[str]) -> Optional[str]: + if language is None: + logger.warning( + "Defaulting to language='en'. If you wish to transcribe audio in a different language, pass the `language` field.") + language = "en" + return super().validate_language(language) + ``` + +- (Optional) Token accounting for streaming via [get_num_audio_tokens][vllm.model_executor.models.interfaces.SupportsTranscription.get_num_audio_tokens] + + Provide a fast duration→token estimate to improve streaming usage statistics: + + ??? code + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> Optional[int]: + # Return None if unknown; otherwise return an estimate. + return int(audio_duration_s * stt_config.sample_rate // 320) # example + ``` + + +## 2. Audio preprocessing and chunking + +The API server takes care of basic audio I/O and optional chunking before building prompts: + +- Resampling: Input audio is resampled to `SpeechToTextConfig.sample_rate` using `librosa`. +- Chunking: If `SpeechToTextConfig.allow_audio_chunking` is True and the duration exceeds `max_audio_clip_s`, the server splits the audio into overlapping chunks and generates a prompt per chunk. Overlap is controlled by `overlap_chunk_second`. +- Energy-aware splitting: When `min_energy_split_window_size` is set, the server finds low-energy regions to minimize cutting within words. + +Relevant server logic: +??? code + ```python + # vllm/entrypoints/openai/speech_to_text.py + async def _preprocess_speech_to_text(...): + language = self.model_cls.validate_language(request.language) + ... + y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate) + duration = librosa.get_duration(y=y, sr=sr) + do_split_audio = (self.asr_config.allow_audio_chunking + and duration > self.asr_config.max_audio_clip_s) + chunks = [y] if not do_split_audio else self._split_audio(y, int(sr)) + prompts = [] + for chunk in chunks: + prompt = self.model_cls.get_generation_prompt( + audio=chunk, + stt_config=self.asr_config, + model_config=self.model_config, + language=language, + task_type=self.task_type, + request_prompt=request.prompt, + to_language=to_language, + ) + prompts.append(prompt) + return prompts, duration + ``` + +## 3. Exposing tasks automatically + +- vLLM automatically advertises transcription support if your model implements the interface: + +```python +if supports_transcription(model): + if model.supports_transcription_only: + return ["transcription"] + supported_tasks.append("transcription") +``` + +- When enabled, the server initializes the transcription and translation handlers: + +```python +state.openai_serving_transcription = OpenAIServingTranscription(...) if "transcription" in supported_tasks else None +state.openai_serving_translation = OpenAIServingTranslation(...) if "transcription" in supported_tasks else None +``` + +No extra registration is required beyond having your model class available via the model registry and implementing `SupportsTranscription`. + +## 4. Examples in-tree + +- Whisper encoder–decoder (audio-only): +- Voxtral decoder-only (audio embeddings + LLM): +- Gemma3n decoder-only with fixed instruction prompt: + +## 5. Test with the API + +Once your model implements `SupportsTranscription`, you can test the endpoints (API mimics OpenAI): + +- Transcription (ASR): + + ```bash + curl -s -X POST \ + -H "Authorization: Bearer $VLLM_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F "file=@/path/to/audio.wav" \ + -F "model=$MODEL_ID" \ + http://localhost:8000/v1/audio/transcriptions + ``` + +- Translation (source → English unless otherwise supported): + + ```bash + curl -s -X POST \ + -H "Authorization: Bearer $VLLM_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F "file=@/path/to/audio.wav" \ + -F "model=$MODEL_ID" \ + http://localhost:8000/v1/audio/translations + ``` +Or check out more examples in . + +!!! note + +- If your model handles chunking internally (e.g., via its processor or encoder), set `min_energy_split_window_size=None` in the returned `SpeechToTextConfig` to disable server-side chunking. +- Implementing `get_num_audio_tokens` improves accuracy of streaming usage metrics (`prompt_tokens`) without an extra forward pass. +- For multilingual behavior, keep `supported_languages` aligned with actual model capabilities. \ No newline at end of file