Compare commits: main...whisper-tr (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | d3eddd6ef1 |  |
vllm/entrypoints/openai/api_server.py:

```diff
@@ -67,6 +67,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               TokenizeResponse,
                                               TranscriptionRequest,
                                               TranscriptionResponse,
+                                              TranslationRequest,
+                                              TranslationResponse,
                                               UnloadLoRAAdapterRequest)
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
@@ -80,7 +82,7 @@ from vllm.entrypoints.openai.serving_score import ServingScores
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.entrypoints.openai.serving_transcription import (
-    OpenAIServingTranscription)
+    OpenAIServingTranscription, OpenAIServingTranslation)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
                                     with_cancellation)
@@ -383,6 +385,10 @@ def transcription(request: Request) -> OpenAIServingTranscription:
     return request.app.state.openai_serving_transcription
 
 
+def translation(request: Request) -> OpenAIServingTranslation:
+    return request.app.state.openai_serving_translation
+
+
 def engine_client(request: Request) -> EngineClient:
     return request.app.state.engine_client
 
@@ -625,6 +631,31 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
     return StreamingResponse(content=generator, media_type="text/event-stream")
 
 
+@router.post("/v1/audio/translations")
+@with_cancellation
+@load_aware_call
+async def create_translations(request: Annotated[TranslationRequest,
+                                                 Form()],
+                              raw_request: Request):
+    handler = translation(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Translations API")
+
+    audio_data = await request.file.read()
+    generator = await handler.create_translation(audio_data, request,
+                                                 raw_request)
+
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+
+    elif isinstance(generator, TranslationResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
 @router.post("/rerank", dependencies=[Depends(validate_json_request)])
 @with_cancellation
 @load_aware_call
```
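For quick validation, a minimal client sketch for the new route. The path, form fields, and response shape come from `create_translations` above and the `TranslationRequest`/`TranslationResponse` models below; the server address, the `openai/whisper-large-v3` model name, and the local `sample.wav` file are placeholder assumptions, not part of the commit.

```python
# Hypothetical smoke test for the new endpoint; assumes a vLLM server is
# already running locally with a Whisper-style speech-to-text model loaded.
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/audio/translations",
        files={"file": ("sample.wav", f, "audio/wav")},
        data={"model": "openai/whisper-large-v3", "response_format": "json"},
    )
print(resp.json())  # non-streaming JSON body, e.g. {"text": "..."}
```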
vllm/entrypoints/openai/protocol.py:

```diff
@@ -1652,3 +1652,196 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
 
     words: Optional[list[TranscriptionWord]] = None
     """Extracted words and their corresponding timestamps."""
+
+
+class TranslationResponseStreamChoice(OpenAIBaseModel):
+    delta: DeltaMessage
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None
+
+
+class TranslationStreamResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
+    object: Literal["translation.chunk"] = "translation.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: list[TranslationResponseStreamChoice]
+    usage: Optional[UsageInfo] = Field(default=None)
+
+
+class TranslationRequest(OpenAIBaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/audio/createTranslation
+
+    file: UploadFile
+    """
+    The audio file object (not file name) to translate, in one of these
+    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
+    """
+
+    model: Optional[str] = None
+    """ID of the model to use.
+    """
+
+    language: Optional[str] = None
+    """The language of the input audio.
+
+    Supplying the input language in
+    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
+    will improve accuracy and latency.
+    """
+
+    prompt: str = Field(default="")
+    """An optional text to guide the model's style or continue a previous audio
+    segment.
+
+    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
+    should match the audio language.
+    """
+
+    response_format: AudioResponseFormat = Field(default="json")
+    """
+    The format of the output, in one of these options: `json`, `text`, `srt`,
+    `verbose_json`, or `vtt`.
+    """
+
+    ## TODO (varun) : Support if set to 0, certain thresholds are met !!
+    temperature: float = Field(default=0.0)
+    """The sampling temperature, between 0 and 1.
+
+    Higher values like 0.8 will make the output more random, while lower values
+    like 0.2 will make it more focused / deterministic. If set to 0, the model
+    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
+    to automatically increase the temperature until certain thresholds are hit.
+    """
+
+    timestamp_granularities: list[Literal["word", "segment"]] = Field(
+        alias="timestamp_granularities[]", default=[])
+    """The timestamp granularities to populate for this translation.
+
+    `response_format` must be set `verbose_json` to use timestamp granularities.
+    Either or both of these options are supported: `word`, or `segment`. Note:
+    There is no additional latency for segment timestamps, but generating word
+    timestamps incurs additional latency.
+    """
+
+    stream: Optional[bool] = False
+    """Custom field not present in the original OpenAI definition. When set,
+    it will enable output to be streamed in a similar fashion as the Chat
+    Completion endpoint.
+    """
+    # Flattened stream option to simplify form data.
+    stream_include_usage: Optional[bool] = False
+    stream_continuous_usage_stats: Optional[bool] = False
+
+    # Default sampling parameters for translation requests.
+    _DEFAULT_SAMPLING_PARAMS: dict = {
+        "temperature": 0,
+    }
+
+    def to_sampling_params(
+            self,
+            default_max_tokens: int,
+            default_sampling_params: Optional[dict] = None) -> SamplingParams:
+        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+        max_tokens = default_max_tokens
+
+        if default_sampling_params is None:
+            default_sampling_params = {}
+        # Default parameters
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+
+        return SamplingParams.from_optional(temperature=temperature,
+                                            max_tokens=max_tokens,
+                                            output_kind=RequestOutputKind.DELTA
+                                            if self.stream \
+                                            else RequestOutputKind.FINAL_ONLY)
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_stream_options(cls, data):
+        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
+        stream = data.get("stream", False)
+        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
+            raise ValueError(
+                "Stream options can only be defined when `stream=True`.")
+
+        return data
+
+
+# Translation response objects
+class TranslationResponse(OpenAIBaseModel):
+    text: str
+    """The translated text."""
+
+
+class TranslationWord(OpenAIBaseModel):
+    end: float
+    """End time of the word in seconds."""
+
+    start: float
+    """Start time of the word in seconds."""
+
+    word: str
+    """The text content of the word."""
+
+
+class TranslationSegment(OpenAIBaseModel):
+    id: int
+    """Unique identifier of the segment."""
+
+    avg_logprob: float
+    """Average logprob of the segment.
+
+    If the value is lower than -1, consider the logprobs failed.
+    """
+
+    compression_ratio: float
+    """Compression ratio of the segment.
+
+    If the value is greater than 2.4, consider the compression failed.
+    """
+
+    end: float
+    """End time of the segment in seconds."""
+
+    no_speech_prob: float
+    """Probability of no speech in the segment.
+
+    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
+    this segment silent.
+    """
+
+    seek: int
+    """Seek offset of the segment."""
+
+    start: float
+    """Start time of the segment in seconds."""
+
+    temperature: float
+    """Temperature parameter used for generating the segment."""
+
+    text: str
+    """Text content of the segment."""
+
+    tokens: list[int]
+    """Array of token IDs for the text content."""
+
+
+class TranslationResponseVerbose(OpenAIBaseModel):
+    duration: str
+    """The duration of the input audio."""
+
+    language: str
+    """The language of the input audio."""
+
+    text: str
+    """The translated text."""
+
+    segments: Optional[list[TranslationSegment]] = None
+    """Segments of the translated text and their corresponding details."""
+
+    words: Optional[list[TranslationWord]] = None
+    """Extracted words and their corresponding timestamps."""
```
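One behavior worth illustrating: because `validate_stream_options` runs in `mode="before"`, the flattened stream options are checked before any field validation, so setting one without `stream=True` fails fast. A hedged sketch, assuming vLLM (and thus this merged `TranslationRequest`) is importable; the payload is a placeholder, not a real upload:

```python
# Sketch (not part of the commit): the before-validator rejects stream
# options when `stream` is absent or False. pydantic v2 surfaces the
# raised ValueError as a ValidationError.
import pydantic
from vllm.entrypoints.openai.protocol import TranslationRequest

try:
    TranslationRequest.model_validate({
        "file": None,  # placeholder; a real request carries an UploadFile
        "stream_include_usage": True,
    })
except pydantic.ValidationError as err:
    print(err)  # "Stream options can only be defined when `stream=True`."
```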
vllm/entrypoints/openai/serving_transcription.py:

```diff
@@ -4,7 +4,7 @@ import io
 import time
 from collections.abc import AsyncGenerator
 from math import ceil
-from typing import Final, Optional, Union, cast
+from typing import Callable, Optional, Union, cast
 
 from fastapi import Request
 
@@ -14,7 +14,8 @@ from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
     TranscriptionResponse, TranscriptionResponseStreamChoice,
-    TranscriptionStreamResponse, UsageInfo)
+    TranscriptionStreamResponse, TranslationRequest, TranslationResponse,
+    TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.inputs.data import PromptType
@@ -30,7 +31,7 @@ except ImportError:
 
 logger = init_logger(__name__)
 
-# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages
+# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
 # TODO these configs should live somewhere with the model so we can support
 # additional ones
 
@@ -144,16 +145,19 @@ ISO639_1_OTHER_LANGS = {
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
 
 
-class OpenAIServingTranscription(OpenAIServing):
+class OpenAISpeechToText(OpenAIServing):
+    """Base class for speech-to-text operations like transcription and
+    translation."""
 
     def __init__(
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
         models: OpenAIServingModels,
         *,
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
+        task_type: str = "transcribe",  # or "translate"
     ):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
@@ -167,15 +171,16 @@ class OpenAIServingTranscription(OpenAIServing):
         self.max_audio_clip_s = processor.feature_extractor.chunk_length
         self.model_sr = processor.feature_extractor.sampling_rate
         self.hop_length = processor.feature_extractor.hop_length
+        self.task_type = task_type
 
         if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
                 self.default_sampling_params)
 
-    async def _preprocess_transcription(
+    async def _preprocess_speech_to_text(
         self,
-        request: TranscriptionRequest,
+        request: Union[TranscriptionRequest, TranslationRequest],
         audio_data: bytes,
     ) -> tuple[PromptType, float]:
         # Validate request
@@ -218,21 +223,22 @@ class OpenAIServingTranscription(OpenAIServing):
                 },
             },
             "decoder_prompt":
-            f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
+            (f"<|startoftranscript|>{lang_token}"
+             f"<|{self.task_type}|><|notimestamps|>{request.prompt}")
         }
         return cast(PromptType, prompt), duration
 
-    # TODO (varun) : Make verbose response work !
-    async def create_transcription(
-        self, audio_data: bytes, request: TranscriptionRequest,
-        raw_request: Request
-    ) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
-               ErrorResponse]:
-        """Transcription API similar to OpenAI's API.
-
-        See https://platform.openai.com/docs/api-reference/audio/createTranscription
-        for the API specification. This API mimics the OpenAI transcription API.
-        """
+    async def _create_speech_to_text(
+        self,
+        audio_data: bytes,
+        request: Union[TranscriptionRequest, TranslationRequest],
+        raw_request: Request,
+        response_class: Union[TranscriptionResponse, TranslationResponse],
+        stream_generator_method: Callable,
+    ) -> Union[Union[TranscriptionResponse, TranslationResponse],
+               AsyncGenerator[str, None], ErrorResponse]:
+        """Base method for speech-to-text operations like transcription and
+        translation."""
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
```
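The only change to the decoder prompt is that the hardcoded `<|transcribe|>` token is replaced with one derived from `task_type`. A standalone illustration of the string the base class renders (the language token and prompt values here are example assumptions, filled in from the request during preprocessing in the real code):

```python
# Illustration only: the Whisper-style decoder prompt rendered per task.
# For translation, the language token names the *source* audio language
# (French here) and the model emits English text.
task_type = "translate"  # or "transcribe"
lang_token = "<|fr|>"    # derived from request.language during preprocessing
request_prompt = ""      # TranslationRequest.prompt defaults to ""
decoder_prompt = (f"<|startoftranscript|>{lang_token}"
                  f"<|{task_type}|><|notimestamps|>{request_prompt}")
print(decoder_prompt)
# -> <|startoftranscript|><|fr|><|translate|><|notimestamps|>
```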
The remaining hunks of vllm/entrypoints/openai/serving_transcription.py:

```diff
@@ -247,7 +253,7 @@ class OpenAIServingTranscription(OpenAIServing):
             return self.create_error_response(
                 "Currently only support response_format `text` or `json`")
 
-        request_id = f"trsc-{self._base_request_id(raw_request)}"
+        request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
 
         request_metadata = RequestResponseMetadata(request_id=request_id)
         if raw_request:
@@ -261,13 +267,14 @@ class OpenAIServingTranscription(OpenAIServing):
 
             if lora_request:
                 return self.create_error_response(
-                    "Currently do not support LoRA for Transcription.")
+                    "Currently do not support LoRA for "
+                    f"{self.task_type.title()}.")
             if prompt_adapter_request:
                 return self.create_error_response(
-                    "Currently do not support PromptAdapter for Transcription."
-                )
+                    f"Currently do not support PromptAdapter for "
+                    f"{self.task_type.title()}.")
 
-            prompt, duration_s = await self._preprocess_transcription(
+            prompt, duration_s = await self._preprocess_speech_to_text(
                 request=request,
                 audio_data=audio_data,
             )
@@ -300,31 +307,36 @@ class OpenAIServingTranscription(OpenAIServing):
             return self.create_error_response(str(e))
 
         if request.stream:
-            return self.transcription_stream_generator(request,
-                                                       result_generator,
-                                                       request_id,
-                                                       request_metadata,
-                                                       duration_s)
+            return stream_generator_method(request, result_generator,
+                                           request_id, request_metadata,
+                                           duration_s)
         # Non-streaming response.
         try:
             assert result_generator is not None
             async for op in result_generator:
                 result = op
-            return TranscriptionResponse(text=result.outputs[0].text)
+            return response_class(text=result.outputs[0].text)
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
-    async def transcription_stream_generator(
-        self, request: TranscriptionRequest,
-        result_generator: AsyncGenerator[RequestOutput, None],
-        request_id: str, request_metadata: RequestResponseMetadata,
-        audio_duration_s: float) -> AsyncGenerator[str, None]:
+    async def _speech_to_text_stream_generator(
+        self,
+        request: Union[TranscriptionRequest, TranslationRequest],
+        result_generator: AsyncGenerator[RequestOutput, None],
+        request_id: str,
+        request_metadata: RequestResponseMetadata,
+        audio_duration_s: float,
+        chunk_object_type: str,
+        response_stream_choice_class: Union[TranscriptionResponseStreamChoice,
+                                            TranslationResponseStreamChoice],
+        stream_response_class: Union[TranscriptionStreamResponse,
+                                     TranslationStreamResponse],
+    ) -> AsyncGenerator[str, None]:
         created_time = int(time.time())
         model_name = request.model
-        chunk_object_type: Final = "transcription.chunk"
 
         completion_tokens = 0
         num_prompt_tokens = 0
@@ -361,20 +373,20 @@ class OpenAIServingTranscription(OpenAIServing):
 
                 if output.finish_reason is None:
                     # Still generating, send delta update.
-                    choice_data = TranscriptionResponseStreamChoice(
+                    choice_data = response_stream_choice_class(
                         delta=delta_message)
                 else:
                     # Model is finished generating.
-                    choice_data = TranscriptionResponseStreamChoice(
+                    choice_data = response_stream_choice_class(
                         delta=delta_message,
                         finish_reason=output.finish_reason,
                         stop_reason=output.stop_reason)
 
-                chunk = TranscriptionStreamResponse(id=request_id,
-                                                    object=chunk_object_type,
-                                                    created=created_time,
-                                                    choices=[choice_data],
-                                                    model=model_name)
+                chunk = stream_response_class(id=request_id,
+                                              object=chunk_object_type,
+                                              created=created_time,
+                                              choices=[choice_data],
+                                              model=model_name)
 
                 # handle usage stats if requested & if continuous
                 if include_continuous_usage:
@@ -395,7 +407,7 @@ class OpenAIServingTranscription(OpenAIServing):
                         total_tokens=num_prompt_tokens +
                         completion_tokens)
 
-                final_usage_chunk = TranscriptionStreamResponse(
+                final_usage_chunk = stream_response_class(
                     id=request_id,
                     object=chunk_object_type,
                     created=created_time,
@@ -414,8 +426,115 @@ class OpenAIServingTranscription(OpenAIServing):
 
         except Exception as e:
             # TODO: Use a vllm-specific Validation Error
-            logger.exception("Error in chat completion stream generator.")
+            logger.exception("Error in %s stream generator.", self.task_type)
             data = self.create_streaming_error_response(str(e))
             yield f"data: {data}\n\n"
         # Send the final done message after all response.n are finished
         yield "data: [DONE]\n\n"
+
+
+class OpenAIServingTranscription(OpenAISpeechToText):
+    """Handles transcription requests."""
+
+    def __init__(
+        self,
+        engine_client: EngineClient,
+        model_config: ModelConfig,
+        models: OpenAIServingModels,
+        *,
+        request_logger: Optional[RequestLogger],
+        return_tokens_as_token_ids: bool = False,
+    ):
+        super().__init__(engine_client=engine_client,
+                         model_config=model_config,
+                         models=models,
+                         request_logger=request_logger,
+                         return_tokens_as_token_ids=return_tokens_as_token_ids,
+                         task_type="transcribe")
+
+    async def create_transcription(
+        self, audio_data: bytes, request: TranscriptionRequest,
+        raw_request: Request
+    ) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
+               ErrorResponse]:
+        """Transcription API similar to OpenAI's API.
+
+        See https://platform.openai.com/docs/api-reference/audio/createTranscription
+        for the API specification. This API mimics the OpenAI transcription API.
+        """
+        return await self._create_speech_to_text(
+            audio_data=audio_data,
+            request=request,
+            raw_request=raw_request,
+            response_class=TranscriptionResponse,
+            stream_generator_method=self.transcription_stream_generator,
+        )
+
+    async def transcription_stream_generator(
+        self, request: TranscriptionRequest,
+        result_generator: AsyncGenerator[RequestOutput, None],
+        request_id: str, request_metadata: RequestResponseMetadata,
+        audio_duration_s: float) -> AsyncGenerator[str, None]:
+        return await self._speech_to_text_stream_generator(
+            request=request,
+            result_generator=result_generator,
+            request_id=request_id,
+            request_metadata=request_metadata,
+            audio_duration_s=audio_duration_s,
+            chunk_object_type="transcription.chunk",
+            response_stream_choice_class=TranscriptionResponseStreamChoice,
+            stream_response_class=TranscriptionStreamResponse,
+        )
+
+
+class OpenAIServingTranslation(OpenAISpeechToText):
+    """Handles translation requests."""
+
+    def __init__(
+        self,
+        engine_client: EngineClient,
+        model_config: ModelConfig,
+        models: OpenAIServingModels,
+        *,
+        request_logger: Optional[RequestLogger],
+        return_tokens_as_token_ids: bool = False,
+    ):
+        super().__init__(engine_client=engine_client,
+                         model_config=model_config,
+                         models=models,
+                         request_logger=request_logger,
+                         return_tokens_as_token_ids=return_tokens_as_token_ids,
+                         task_type="translate")
+
+    async def create_translation(
+        self, audio_data: bytes, request: TranslationRequest,
+        raw_request: Request
+    ) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]:
+        """Translation API similar to OpenAI's API.
+
+        See https://platform.openai.com/docs/api-reference/audio/createTranslation
+        for the API specification. This API mimics the OpenAI translation API.
+        """
+        return await self._create_speech_to_text(
+            audio_data=audio_data,
+            request=request,
+            raw_request=raw_request,
+            response_class=TranslationResponse,
+            stream_generator_method=self.translation_stream_generator,
+        )
+
+    async def translation_stream_generator(
+        self, request: TranslationRequest,
+        result_generator: AsyncGenerator[RequestOutput, None],
+        request_id: str, request_metadata: RequestResponseMetadata,
+        audio_duration_s: float) -> AsyncGenerator[str, None]:
+        return await self._speech_to_text_stream_generator(
+            request=request,
+            result_generator=result_generator,
+            request_id=request_id,
+            request_metadata=request_metadata,
+            audio_duration_s=audio_duration_s,
+            chunk_object_type="translation.chunk",
+            response_stream_choice_class=TranslationResponseStreamChoice,
+            stream_response_class=TranslationStreamResponse,
+        )
```
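Since both subclasses funnel into `_speech_to_text_stream_generator`, the translation stream is plain SSE: `translation.chunk` objects followed by a final `data: [DONE]` sentinel, exactly mirroring transcription. A hedged consumer sketch, reusing the same assumed server, model name, and audio file as the earlier example:

```python
# Sketch (not part of the commit): read the SSE stream produced when the
# form sets stream=true. Chunk payloads follow TranslationStreamResponse.
import json
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/audio/translations",
        files={"file": ("sample.wav", f, "audio/wav")},
        data={"model": "openai/whisper-large-v3", "stream": "true"},
        stream=True,
    )
for line in resp.iter_lines():
    if not line.startswith(b"data: "):
        continue  # skip blank keep-alive lines
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    for choice in chunk.get("choices", []):  # usage-only chunks have none
        print(choice["delta"].get("content") or "", end="", flush=True)
```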