refactor: enforce return object in app generator

This commit is contained in:
Yeuoly
2024-08-29 19:49:57 +08:00
parent a073de44e9
commit ec711d094d
14 changed files with 140 additions and 71 deletions

View File

@ -48,6 +48,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream: Literal[False] = False,
) -> dict: ...
@overload
def generate(
self, app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = False,
) -> dict | Generator[str, None, None]: ...
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,

View File

@ -1,4 +1,3 @@
import json
from collections.abc import Generator
from typing import cast
@ -51,7 +50,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
-> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -76,11 +75,11 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
-> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -111,4 +110,4 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
else:
response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk