[Frontend][Docs] Transcription API streaming (#13301)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@ -3,12 +3,14 @@
|
||||
# imports for guided decoding tests
|
||||
import io
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import openai
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
from openai._base_client import AsyncAPIClient
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
@ -120,3 +122,73 @@ async def test_completion_endpoints():
|
||||
res = await client.completions.create(model=model_name, prompt="Hello")
|
||||
assert res.code == 400
|
||||
assert res.message == "The model does not support Completions API"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_response(winning_call):
|
||||
model_name = "openai/whisper-small"
|
||||
server_args = ["--enforce-eager"]
|
||||
transcription = ""
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
res_no_stream = await client.audio.transcriptions.create(
|
||||
model=model_name,
|
||||
file=winning_call,
|
||||
response_format="json",
|
||||
language="en",
|
||||
temperature=0.0)
|
||||
# Unfortunately this only works when the openai client is patched
|
||||
# to use streaming mode, not exposed in the transcription api.
|
||||
original_post = AsyncAPIClient.post
|
||||
|
||||
async def post_with_stream(*args, **kwargs):
|
||||
kwargs['stream'] = True
|
||||
return await original_post(*args, **kwargs)
|
||||
|
||||
with patch.object(AsyncAPIClient, "post", new=post_with_stream):
|
||||
client = remote_server.get_async_client()
|
||||
res = await client.audio.transcriptions.create(
|
||||
model=model_name,
|
||||
file=winning_call,
|
||||
language="en",
|
||||
temperature=0.0,
|
||||
extra_body=dict(stream=True))
|
||||
# Reconstruct from chunks and validate
|
||||
async for chunk in res:
|
||||
# just a chunk
|
||||
text = chunk.choices[0]['delta']['content']
|
||||
transcription += text
|
||||
|
||||
assert transcription == res_no_stream.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_options(winning_call):
|
||||
model_name = "openai/whisper-small"
|
||||
server_args = ["--enforce-eager"]
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
original_post = AsyncAPIClient.post
|
||||
|
||||
async def post_with_stream(*args, **kwargs):
|
||||
kwargs['stream'] = True
|
||||
return await original_post(*args, **kwargs)
|
||||
|
||||
with patch.object(AsyncAPIClient, "post", new=post_with_stream):
|
||||
client = remote_server.get_async_client()
|
||||
res = await client.audio.transcriptions.create(
|
||||
model=model_name,
|
||||
file=winning_call,
|
||||
language="en",
|
||||
temperature=0.0,
|
||||
extra_body=dict(stream=True,
|
||||
stream_include_usage=True,
|
||||
stream_continuous_usage_stats=True))
|
||||
final = False
|
||||
continuous = True
|
||||
async for chunk in res:
|
||||
if not len(chunk.choices):
|
||||
# final usage sent
|
||||
final = True
|
||||
else:
|
||||
continuous = continuous and hasattr(chunk, 'usage')
|
||||
assert final and continuous
|
||||
|
||||
Reference in New Issue
Block a user