Compare commits

1 commit

SHA1: d3eddd6ef1
Message: initial
Signed-off-by: Roger Wang <ywang@roblox.com>
Date: 2025-04-01 16:06:59 -07:00

3 changed files with 395 additions and 52 deletions

View File: vllm/entrypoints/openai/api_server.py

@@ -67,6 +67,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               TokenizeResponse,
                                               TranscriptionRequest,
                                               TranscriptionResponse,
+                                              TranslationRequest,
+                                              TranslationResponse,
                                               UnloadLoRAAdapterRequest)
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
@@ -80,7 +82,7 @@ from vllm.entrypoints.openai.serving_score import ServingScores
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.entrypoints.openai.serving_transcription import (
-    OpenAIServingTranscription)
+    OpenAIServingTranscription, OpenAIServingTranslation)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
                                     with_cancellation)
@@ -383,6 +385,10 @@ def transcription(request: Request) -> OpenAIServingTranscription:
     return request.app.state.openai_serving_transcription
 
 
+def translation(request: Request) -> OpenAIServingTranslation:
+    return request.app.state.openai_serving_translation
+
+
 def engine_client(request: Request) -> EngineClient:
     return request.app.state.engine_client
@@ -625,6 +631,31 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
     return StreamingResponse(content=generator, media_type="text/event-stream")
 
 
+@router.post("/v1/audio/translations")
+@with_cancellation
+@load_aware_call
+async def create_translations(request: Annotated[TranslationRequest,
+                                                 Form()],
+                              raw_request: Request):
+    handler = translation(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Translations API")
+
+    audio_data = await request.file.read()
+    generator = await handler.create_translation(audio_data, request,
+                                                 raw_request)
+
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif isinstance(generator, TranslationResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
 @router.post("/rerank", dependencies=[Depends(validate_json_request)])
 @with_cancellation
 @load_aware_call
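
With the route above in place, the endpoint mirrors OpenAI's Translations API. A minimal sketch of exercising it with the official openai Python client, assuming a Whisper-style speech-to-text model is being served; the model name, port, and audio file below are illustrative, not part of this change:

# Sketch: call the new /v1/audio/translations route via the OpenAI client.
# Assumes something like `vllm serve openai/whisper-large-v3` is running
# locally; the model name, URL, and file name are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("speech_in_french.wav", "rb") as audio_file:
    translation = client.audio.translations.create(
        model="openai/whisper-large-v3",
        file=audio_file,
    )

print(translation.text)  # English translation of the spoken audio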

View File: vllm/entrypoints/openai/protocol.py

@@ -1652,3 +1652,196 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
     words: Optional[list[TranscriptionWord]] = None
     """Extracted words and their corresponding timestamps."""
+
+
+class TranslationResponseStreamChoice(OpenAIBaseModel):
+    delta: DeltaMessage
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None
+
+
+class TranslationStreamResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
+    object: Literal["translation.chunk"] = "translation.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: list[TranslationResponseStreamChoice]
+    usage: Optional[UsageInfo] = Field(default=None)
+
+
+class TranslationRequest(OpenAIBaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/audio/createTranslation
+
+    file: UploadFile
+    """
+    The audio file object (not file name) to translate, in one of these
+    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
+    """
+
+    model: Optional[str] = None
+    """ID of the model to use.
+    """
+
+    language: Optional[str] = None
+    """The language of the input audio.
+    Supplying the input language in
+    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
+    will improve accuracy and latency.
+    """
+
+    prompt: str = Field(default="")
+    """An optional text to guide the model's style or continue a previous audio
+    segment.
+    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
+    should match the audio language.
+    """
+
+    response_format: AudioResponseFormat = Field(default="json")
+    """
+    The format of the output, in one of these options: `json`, `text`, `srt`,
+    `verbose_json`, or `vtt`.
+    """
+
+    ## TODO (varun) : Support if set to 0, certain thresholds are met !!
+    temperature: float = Field(default=0.0)
+    """The sampling temperature, between 0 and 1.
+    Higher values like 0.8 will make the output more random, while lower values
+    like 0.2 will make it more focused / deterministic. If set to 0, the model
+    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
+    to automatically increase the temperature until certain thresholds are hit.
+    """
+
+    timestamp_granularities: list[Literal["word", "segment"]] = Field(
+        alias="timestamp_granularities[]", default=[])
+    """The timestamp granularities to populate for this translation.
+    `response_format` must be set `verbose_json` to use timestamp granularities.
+    Either or both of these options are supported: `word`, or `segment`. Note:
+    There is no additional latency for segment timestamps, but generating word
+    timestamps incurs additional latency.
+    """
+
+    stream: Optional[bool] = False
+    """Custom field not present in the original OpenAI definition. When set,
+    it will enable output to be streamed in a similar fashion as the Chat
+    Completion endpoint.
+    """
+    # Flattened stream option to simplify form data.
+    stream_include_usage: Optional[bool] = False
+    stream_continuous_usage_stats: Optional[bool] = False
+
+    # Default sampling parameters for translation requests.
+    _DEFAULT_SAMPLING_PARAMS: dict = {
+        "temperature": 0,
+    }
+
+    def to_sampling_params(
+            self,
+            default_max_tokens: int,
+            default_sampling_params: Optional[dict] = None) -> SamplingParams:
+        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+        max_tokens = default_max_tokens
+
+        if default_sampling_params is None:
+            default_sampling_params = {}
+        # Default parameters
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+
+        return SamplingParams.from_optional(temperature=temperature,
+                                            max_tokens=max_tokens,
+                                            output_kind=RequestOutputKind.DELTA
+                                            if self.stream \
+                                            else RequestOutputKind.FINAL_ONLY)
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_stream_options(cls, data):
+        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
+        stream = data.get("stream", False)
+        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
+            raise ValueError(
+                "Stream options can only be defined when `stream=True`.")
+        return data
+
+
+# Translation response objects
+class TranslationResponse(OpenAIBaseModel):
+    text: str
+    """The translated text."""
+
+
+class TranslationWord(OpenAIBaseModel):
+    end: float
+    """End time of the word in seconds."""
+
+    start: float
+    """Start time of the word in seconds."""
+
+    word: str
+    """The text content of the word."""
+
+
+class TranslationSegment(OpenAIBaseModel):
+    id: int
+    """Unique identifier of the segment."""
+
+    avg_logprob: float
+    """Average logprob of the segment.
+    If the value is lower than -1, consider the logprobs failed.
+    """
+
+    compression_ratio: float
+    """Compression ratio of the segment.
+    If the value is greater than 2.4, consider the compression failed.
+    """
+
+    end: float
+    """End time of the segment in seconds."""
+
+    no_speech_prob: float
+    """Probability of no speech in the segment.
+    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
+    this segment silent.
+    """
+
+    seek: int
+    """Seek offset of the segment."""
+
+    start: float
+    """Start time of the segment in seconds."""
+
+    temperature: float
+    """Temperature parameter used for generating the segment."""
+
+    text: str
+    """Text content of the segment."""
+
+    tokens: list[int]
+    """Array of token IDs for the text content."""
+
+
+class TranslationResponseVerbose(OpenAIBaseModel):
+    duration: str
+    """The duration of the input audio."""
+
+    language: str
+    """The language of the input audio."""
+
+    text: str
+    """The translated text."""
+
+    segments: Optional[list[TranslationSegment]] = None
+    """Segments of the translated text and their corresponding details."""
+
+    words: Optional[list[TranslationWord]] = None
+    """Extracted words and their corresponding timestamps."""

View File: vllm/entrypoints/openai/serving_transcription.py

@@ -4,7 +4,7 @@ import io
 import time
 from collections.abc import AsyncGenerator
 from math import ceil
-from typing import Final, Optional, Union, cast
+from typing import Callable, Optional, Union, cast
 
 from fastapi import Request
 
@@ -14,7 +14,8 @@ from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
     TranscriptionResponse, TranscriptionResponseStreamChoice,
-    TranscriptionStreamResponse, UsageInfo)
+    TranscriptionStreamResponse, TranslationRequest, TranslationResponse,
+    TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.inputs.data import PromptType
@@ -30,7 +31,7 @@ except ImportError:
 
 logger = init_logger(__name__)
 
-# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages
+# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
 # TODO these configs should live somewhere with the model so we can support
 # additional ones
@@ -144,16 +145,19 @@ ISO639_1_OTHER_LANGS = {
 
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
 
 
-class OpenAIServingTranscription(OpenAIServing):
+class OpenAISpeechToText(OpenAIServing):
+    """Base class for speech-to-text operations like transcription and
+    translation."""
 
     def __init__(
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
         models: OpenAIServingModels,
         *,
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
+        task_type: str = "transcribe",  # or "translate"
     ):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
@@ -167,15 +171,16 @@ class OpenAIServingTranscription(OpenAIServing):
         self.max_audio_clip_s = processor.feature_extractor.chunk_length
         self.model_sr = processor.feature_extractor.sampling_rate
         self.hop_length = processor.feature_extractor.hop_length
+        self.task_type = task_type
 
         if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
                 self.default_sampling_params)
 
-    async def _preprocess_transcription(
+    async def _preprocess_speech_to_text(
         self,
-        request: TranscriptionRequest,
+        request: Union[TranscriptionRequest, TranslationRequest],
         audio_data: bytes,
     ) -> tuple[PromptType, float]:
         # Validate request
@@ -218,21 +223,22 @@ class OpenAIServingTranscription(OpenAIServing):
                 },
             },
             "decoder_prompt":
-            f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
+            (f"<|startoftranscript|>{lang_token}"
+             f"<|{self.task_type}|><|notimestamps|>{request.prompt}")
         }
         return cast(PromptType, prompt), duration
 
-    # TODO (varun) : Make verbose response work !
-    async def create_transcription(
-        self, audio_data: bytes, request: TranscriptionRequest,
-        raw_request: Request
-    ) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
-               ErrorResponse]:
-        """Transcription API similar to OpenAI's API.
-
-        See https://platform.openai.com/docs/api-reference/audio/createTranscription
-        for the API specification. This API mimics the OpenAI transcription API.
-        """
+    async def _create_speech_to_text(
+        self,
+        audio_data: bytes,
+        request: Union[TranscriptionRequest, TranslationRequest],
+        raw_request: Request,
+        response_class: Union[TranscriptionResponse, TranslationResponse],
+        stream_generator_method: Callable,
+    ) -> Union[Union[TranscriptionResponse, TranslationResponse],
+               AsyncGenerator[str, None], ErrorResponse]:
+        """Base method for speech-to-text operations like transcription and
+        translation."""
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
@@ -247,7 +253,7 @@ class OpenAIServingTranscription(OpenAIServing):
             return self.create_error_response(
                 "Currently only support response_format `text` or `json`")
 
-        request_id = f"trsc-{self._base_request_id(raw_request)}"
+        request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
 
         request_metadata = RequestResponseMetadata(request_id=request_id)
         if raw_request:
@@ -261,13 +267,14 @@ class OpenAIServingTranscription(OpenAIServing):
         if lora_request:
             return self.create_error_response(
-                "Currently do not support LoRA for Transcription.")
+                "Currently do not support LoRA for "
+                f"{self.task_type.title()}.")
         if prompt_adapter_request:
             return self.create_error_response(
-                "Currently do not support PromptAdapter for Transcription."
-            )
+                f"Currently do not support PromptAdapter for "
+                f"{self.task_type.title()}.")
 
-        prompt, duration_s = await self._preprocess_transcription(
+        prompt, duration_s = await self._preprocess_speech_to_text(
             request=request,
             audio_data=audio_data,
         )
@@ -300,31 +307,36 @@ class OpenAIServingTranscription(OpenAIServing):
             return self.create_error_response(str(e))
 
         if request.stream:
-            return self.transcription_stream_generator(request,
-                                                       result_generator,
-                                                       request_id,
-                                                       request_metadata,
-                                                       duration_s)
+            return stream_generator_method(request, result_generator,
                                            request_id, request_metadata,
                                            duration_s)
         # Non-streaming response.
         try:
             assert result_generator is not None
             async for op in result_generator:
                 result = op
-            return TranscriptionResponse(text=result.outputs[0].text)
+            return response_class(text=result.outputs[0].text)
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
-    async def transcription_stream_generator(
-            self, request: TranscriptionRequest,
-            result_generator: AsyncGenerator[RequestOutput, None],
-            request_id: str, request_metadata: RequestResponseMetadata,
-            audio_duration_s: float) -> AsyncGenerator[str, None]:
+    async def _speech_to_text_stream_generator(
+        self,
+        request: Union[TranscriptionRequest, TranslationRequest],
+        result_generator: AsyncGenerator[RequestOutput, None],
+        request_id: str,
+        request_metadata: RequestResponseMetadata,
+        audio_duration_s: float,
+        chunk_object_type: str,
+        response_stream_choice_class: Union[TranscriptionResponseStreamChoice,
+                                            TranslationResponseStreamChoice],
+        stream_response_class: Union[TranscriptionStreamResponse,
+                                     TranslationStreamResponse],
+    ) -> AsyncGenerator[str, None]:
         created_time = int(time.time())
         model_name = request.model
-        chunk_object_type: Final = "transcription.chunk"
 
         completion_tokens = 0
         num_prompt_tokens = 0
@@ -361,20 +373,20 @@ class OpenAIServingTranscription(OpenAIServing):
 
                 if output.finish_reason is None:
                     # Still generating, send delta update.
-                    choice_data = TranscriptionResponseStreamChoice(
+                    choice_data = response_stream_choice_class(
                         delta=delta_message)
                 else:
                     # Model is finished generating.
-                    choice_data = TranscriptionResponseStreamChoice(
+                    choice_data = response_stream_choice_class(
                         delta=delta_message,
                         finish_reason=output.finish_reason,
                         stop_reason=output.stop_reason)
 
-                chunk = TranscriptionStreamResponse(id=request_id,
-                                                    object=chunk_object_type,
-                                                    created=created_time,
-                                                    choices=[choice_data],
-                                                    model=model_name)
+                chunk = stream_response_class(id=request_id,
+                                              object=chunk_object_type,
+                                              created=created_time,
+                                              choices=[choice_data],
+                                              model=model_name)
 
                 # handle usage stats if requested & if continuous
                 if include_continuous_usage:
@@ -395,7 +407,7 @@ class OpenAIServingTranscription(OpenAIServing):
                     total_tokens=num_prompt_tokens +
                     completion_tokens)
 
-                final_usage_chunk = TranscriptionStreamResponse(
+                final_usage_chunk = stream_response_class(
                     id=request_id,
                     object=chunk_object_type,
                     created=created_time,
@@ -414,8 +426,115 @@ class OpenAIServingTranscription(OpenAIServing):
 
         except Exception as e:
             # TODO: Use a vllm-specific Validation Error
-            logger.exception("Error in chat completion stream generator.")
+            logger.exception("Error in %s stream generator.", self.task_type)
             data = self.create_streaming_error_response(str(e))
             yield f"data: {data}\n\n"
         # Send the final done message after all response.n are finished
         yield "data: [DONE]\n\n"
+
+
+class OpenAIServingTranscription(OpenAISpeechToText):
+    """Handles transcription requests."""
+
+    def __init__(
+        self,
+        engine_client: EngineClient,
+        model_config: ModelConfig,
+        models: OpenAIServingModels,
+        *,
+        request_logger: Optional[RequestLogger],
+        return_tokens_as_token_ids: bool = False,
+    ):
+        super().__init__(engine_client=engine_client,
+                         model_config=model_config,
+                         models=models,
+                         request_logger=request_logger,
+                         return_tokens_as_token_ids=return_tokens_as_token_ids,
+                         task_type="transcribe")
+
+    async def create_transcription(
+        self, audio_data: bytes, request: TranscriptionRequest,
+        raw_request: Request
+    ) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
+               ErrorResponse]:
+        """Transcription API similar to OpenAI's API.
+
+        See https://platform.openai.com/docs/api-reference/audio/createTranscription
+        for the API specification. This API mimics the OpenAI transcription API.
+        """
+        return await self._create_speech_to_text(
+            audio_data=audio_data,
+            request=request,
+            raw_request=raw_request,
+            response_class=TranscriptionResponse,
+            stream_generator_method=self.transcription_stream_generator,
+        )
+
+    async def transcription_stream_generator(
+            self, request: TranscriptionRequest,
+            result_generator: AsyncGenerator[RequestOutput, None],
+            request_id: str, request_metadata: RequestResponseMetadata,
+            audio_duration_s: float) -> AsyncGenerator[str, None]:
+        return await self._speech_to_text_stream_generator(
+            request=request,
+            result_generator=result_generator,
+            request_id=request_id,
+            request_metadata=request_metadata,
+            audio_duration_s=audio_duration_s,
+            chunk_object_type="transcription.chunk",
+            response_stream_choice_class=TranscriptionResponseStreamChoice,
+            stream_response_class=TranscriptionStreamResponse,
+        )
+
+
+class OpenAIServingTranslation(OpenAISpeechToText):
+    """Handles translation requests."""
+
+    def __init__(
+        self,
+        engine_client: EngineClient,
+        model_config: ModelConfig,
+        models: OpenAIServingModels,
+        *,
+        request_logger: Optional[RequestLogger],
+        return_tokens_as_token_ids: bool = False,
+    ):
+        super().__init__(engine_client=engine_client,
+                         model_config=model_config,
+                         models=models,
+                         request_logger=request_logger,
+                         return_tokens_as_token_ids=return_tokens_as_token_ids,
+                         task_type="translate")
+
+    async def create_translation(
+        self, audio_data: bytes, request: TranslationRequest,
+        raw_request: Request
+    ) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]:
+        """Translation API similar to OpenAI's API.
+
+        See https://platform.openai.com/docs/api-reference/audio/createTranslation
+        for the API specification. This API mimics the OpenAI translation API.
+        """
+        return await self._create_speech_to_text(
+            audio_data=audio_data,
+            request=request,
+            raw_request=raw_request,
+            response_class=TranslationResponse,
+            stream_generator_method=self.translation_stream_generator,
+        )
+
+    async def translation_stream_generator(
+            self, request: TranslationRequest,
+            result_generator: AsyncGenerator[RequestOutput, None],
+            request_id: str, request_metadata: RequestResponseMetadata,
+            audio_duration_s: float) -> AsyncGenerator[str, None]:
+        return await self._speech_to_text_stream_generator(
+            request=request,
+            result_generator=result_generator,
+            request_id=request_id,
+            request_metadata=request_metadata,
+            audio_duration_s=audio_duration_s,
+            chunk_object_type="translation.chunk",
+            response_stream_choice_class=TranslationResponseStreamChoice,
+            stream_response_class=TranslationStreamResponse,
+        )
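
The refactor hinges on the Whisper-style decoder prompt: _preprocess_speech_to_text now splices task_type into the special-token sequence, so transcription and translation differ only by that one token (plus the chunk object type and response classes passed in by the subclasses). A standalone sketch of the prompt construction, with illustrative values only:

# Sketch of the decoder prompt assembled in _preprocess_speech_to_text;
# the lang_token and prompt values below are illustrative.
def build_decoder_prompt(task_type: str, lang_token: str, prompt: str = "") -> str:
    return (f"<|startoftranscript|>{lang_token}"
            f"<|{task_type}|><|notimestamps|>{prompt}")

print(build_decoder_prompt("transcribe", "<|fr|>"))
# <|startoftranscript|><|fr|><|transcribe|><|notimestamps|>
print(build_decoder_prompt("translate", "<|fr|>"))
# <|startoftranscript|><|fr|><|translate|><|notimestamps|>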