Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

Author: -LAN-
Date: 2025-09-10 02:03:45 +08:00

99 changed files with 849 additions and 494 deletions

View File

@@ -21,7 +21,7 @@ class SensitiveWordAvoidanceConfigManager:
     @classmethod
     def validate_and_set_defaults(
-        cls, tenant_id, config: dict, only_structure_validate: bool = False
+        cls, tenant_id: str, config: dict, only_structure_validate: bool = False
     ) -> tuple[dict, list[str]]:
         if not config.get("sensitive_word_avoidance"):
             config["sensitive_word_avoidance"] = {"enabled": False}
@@ -38,7 +38,14 @@ class SensitiveWordAvoidanceConfigManager:
         if not only_structure_validate:
             typ = config["sensitive_word_avoidance"]["type"]
-            sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
+            if not isinstance(typ, str):
+                raise ValueError("sensitive_word_avoidance.type must be a string")
+            sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config")
+            if sensitive_word_avoidance_config is None:
+                sensitive_word_avoidance_config = {}
+            if not isinstance(sensitive_word_avoidance_config, dict):
+                raise ValueError("sensitive_word_avoidance.config must be a dict")
             ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
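Note that the hunk above changes behavior as well as types: the old code indexed `["config"]` directly, so a payload without that key raised `KeyError`, while the new code normalizes a missing or `None` sub-config to `{}` and rejects only non-dict values before delegating to `ModerationFactory.validate_config`. A minimal sketch of the new path (the payload is hypothetical):

payload = {"sensitive_word_avoidance": {"enabled": True, "type": "keywords"}}  # no "config" key

# Old behavior: payload["sensitive_word_avoidance"]["config"] raised KeyError.
# New behavior: a missing/None "config" is normalized to {} before validation.
sub_config = payload["sensitive_word_avoidance"].get("config")
if sub_config is None:
    sub_config = {}
if not isinstance(sub_config, dict):
    raise ValueError("sensitive_word_avoidance.config must be a dict")
assert sub_config == {}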

View File

@@ -25,10 +25,14 @@ class PromptTemplateConfigManager:
         if chat_prompt_config:
             chat_prompt_messages = []
             for message in chat_prompt_config.get("prompt", []):
+                text = message.get("text")
+                if not isinstance(text, str):
+                    raise ValueError("message text must be a string")
+                role = message.get("role")
+                if not isinstance(role, str):
+                    raise ValueError("message role must be a string")
                 chat_prompt_messages.append(
-                    AdvancedChatMessageEntity(
-                        **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
-                    )
+                    AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role))
                 )
             advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
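The rewrite validates `text` and `role` before constructing the entity and passes them as explicit keyword arguments; the old `**{...}` unpacking deferred failures to entity construction and was opaque to type checkers. A minimal sketch of the validation step (the message dict is hypothetical):

# Hypothetical entry as it would appear in chat_prompt_config["prompt"].
message = {"text": "Hello, {{name}}", "role": "user"}

text = message.get("text")
role = message.get("role")
# Fail fast on malformed entries; dict-unpacking hid these errors until
# entity construction and defeated static type checking of the arguments.
if not isinstance(text, str):
    raise ValueError("message text must be a string")
if not isinstance(role, str):
    raise ValueError("message role must be a string")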

View File

@@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, Any] = {
                 "event": sub_stream_response.event.value,
                 "conversation_id": chunk.conversation_id,
                 "message_id": chunk.message_id,
@@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

     @classmethod
@@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, Any] = {
                 "event": sub_stream_response.event.value,
                 "conversation_id": chunk.conversation_id,
                 "message_id": chunk.message_id,
@@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
-                sub_stream_response_dict = sub_stream_response.to_dict()
+                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
@@ -118,8 +118,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-                response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
+                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk
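These converter hunks (and the matching ones in the agent-chat, chat, completion, and workflow converters below) swap the hand-rolled `to_dict()` helper for Pydantic v2's native serializer; the helper itself is deleted in the task_entities hunks further down. `model_dump(mode="json")` coerces enums, datetimes, and UUIDs to JSON-safe primitives, which is what a streamed SSE chunk needs. A stand-in model (not the real `StreamResponse`) shows the difference between the two dump modes:

from datetime import datetime
from enum import Enum

from pydantic import BaseModel

class Status(str, Enum):
    OK = "ok"

class Chunk(BaseModel):  # stand-in model for illustration only
    status: Status
    at: datetime

chunk = Chunk(status=Status.OK, at=datetime(2025, 9, 10))
print(chunk.model_dump())              # python mode keeps Status.OK and the datetime object
print(chunk.model_dump(mode="json"))   # {'status': 'ok', 'at': '2025-09-10T00:00:00'}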

View File

@@ -169,7 +169,7 @@ class AdvancedChatAppGenerateTaskPipeline:
         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

-        if self._base_task_pipeline._stream:
+        if self._base_task_pipeline.stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)
@@ -297,13 +297,13 @@
     def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
         """Handle ping events."""
-        yield self._base_task_pipeline._ping_stream_response()
+        yield self._base_task_pipeline.ping_stream_response()

     def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
         """Handle error events."""
         with self._database_session() as session:
-            err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
-        yield self._base_task_pipeline._error_to_stream_response(err)
+            err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
+        yield self._base_task_pipeline.error_to_stream_response(err)

     def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
         """Handle workflow started events."""
@@ -594,10 +594,10 @@
                 workflow_execution=workflow_execution,
             )
             err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
-            err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id)
+            err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
         yield workflow_finish_resp
-        yield self._base_task_pipeline._error_to_stream_response(err)
+        yield self._base_task_pipeline.error_to_stream_response(err)

     def _handle_stop_event(
         self,
@@ -650,7 +650,7 @@
         """Handle advanced chat message end events."""
         self._ensure_graph_runtime_initialized(graph_runtime_state)

-        output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
+        output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
             self._task_state.answer
         )
         if output_moderation_answer:
@@ -846,7 +846,7 @@
         message.answer = answer_text
         message.updated_at = naive_utc_now()
-        message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
+        message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
         message.message_metadata = self._task_state.metadata.model_dump_json()
         message_files = [
             MessageFile(
@@ -902,9 +902,9 @@
         :param text: text
         :return: True if output moderation should direct output, otherwise False
         """
-        if self._base_task_pipeline._output_moderation_handler:
-            if self._base_task_pipeline._output_moderation_handler.should_direct_output():
-                self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
+        if self._base_task_pipeline.output_moderation_handler:
+            if self._base_task_pipeline.output_moderation_handler.should_direct_output():
+                self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output()
                 self._base_task_pipeline.queue_manager.publish(
                     QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
                 )
@@ -914,7 +914,7 @@
                 )
                 return True
             else:
-                self._base_task_pipeline._output_moderation_handler.append_new_token(text)
+                self._base_task_pipeline.output_moderation_handler.append_new_token(text)

         return False

View File

@@ -1,6 +1,6 @@
 import uuid
 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import Any, Optional, cast

 from core.agent.entities import AgentEntity
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
@@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
         return filtered_config

     @classmethod
-    def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
+    def validate_agent_mode_and_set_defaults(
+        cls, tenant_id: str, config: dict[str, Any]
+    ) -> tuple[dict[str, Any], list[str]]:
         """
         Validate agent_mode and set defaults for agent feature
@@ -170,30 +172,32 @@
         if not config.get("agent_mode"):
             config["agent_mode"] = {"enabled": False, "tools": []}

-        if not isinstance(config["agent_mode"], dict):
+        agent_mode = config["agent_mode"]
+        if not isinstance(agent_mode, dict):
             raise ValueError("agent_mode must be of object type")
+        # FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing
+        agent_mode = cast(dict[str, Any], agent_mode)

-        if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
-            config["agent_mode"]["enabled"] = False
+        if "enabled" not in agent_mode or not agent_mode["enabled"]:
+            agent_mode["enabled"] = False

-        if not isinstance(config["agent_mode"]["enabled"], bool):
+        if not isinstance(agent_mode["enabled"], bool):
             raise ValueError("enabled in agent_mode must be of boolean type")

-        if not config["agent_mode"].get("strategy"):
-            config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
+        if not agent_mode.get("strategy"):
+            agent_mode["strategy"] = PlanningStrategy.ROUTER.value

-        if config["agent_mode"]["strategy"] not in [
-            member.value for member in list(PlanningStrategy.__members__.values())
-        ]:
+        if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
             raise ValueError("strategy in agent_mode must be in the specified strategy list")

-        if not config["agent_mode"].get("tools"):
-            config["agent_mode"]["tools"] = []
+        if not agent_mode.get("tools"):
+            agent_mode["tools"] = []

-        if not isinstance(config["agent_mode"]["tools"], list):
+        if not isinstance(agent_mode["tools"], list):
             raise ValueError("tools in agent_mode must be a list of objects")

-        for tool in config["agent_mode"]["tools"]:
+        for tool in agent_mode["tools"]:
             key = list(tool.keys())[0]
             if key in OLD_TOOLS:
                 # old style, use tool name as key
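The `FIXME(-LAN-)` cast above works around checkers that narrow `agent_mode` after `isinstance(agent_mode, dict)` to a plain `dict` with unknown key/value types, losing `dict[str, Any]`. A condensed, self-contained sketch of the same trick (names hypothetical); the cast is purely static and has no runtime effect:

from typing import Any, cast

def normalize(config: dict[str, Any]) -> dict[str, Any]:
    agent_mode = config.get("agent_mode", {})
    if not isinstance(agent_mode, dict):
        raise ValueError("agent_mode must be of object type")
    # Some checkers narrow to dict[Unknown, Unknown] here; the cast
    # restores the declared key/value types for the code that follows.
    agent_mode = cast(dict[str, Any], agent_mode)
    agent_mode.setdefault("enabled", False)
    return agent_mode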

View File

@@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         response = cls.convert_blocking_full_response(blocking_response)

         metadata = response.get("metadata", {})
-        response["metadata"] = cls._get_simple_metadata(metadata)
+        if isinstance(metadata, dict):
+            response["metadata"] = cls._get_simple_metadata(metadata)
+        else:
+            response["metadata"] = {}

         return response
@@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

     @classmethod
@@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
-                sub_stream_response_dict = sub_stream_response.to_dict()
+                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
@@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

View File

@@ -32,6 +32,7 @@ class AppQueueManager:
         self._task_id = task_id
         self._user_id = user_id
         self._invoke_from = invoke_from
+        self.invoke_from = invoke_from  # Public accessor for invoke_from

         user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
         redis_client.setex(
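This adds a second public attribute alongside the private one, so `self._invoke_from` and `self.invoke_from` now hold the same value (the `WorkflowAppGenerateTaskPipeline` hunk below switches to the public name). A read-only property would expose one source of truth instead of two attributes that could drift; a sketch of that alternative, not what this commit does:

class AppQueueManager:
    def __init__(self, task_id: str, user_id: str, invoke_from: "InvokeFrom") -> None:
        self._task_id = task_id
        self._user_id = user_id
        self._invoke_from = invoke_from

    @property
    def invoke_from(self) -> "InvokeFrom":
        # Single source of truth: callers get a public name without a
        # duplicate attribute that could fall out of sync.
        return self._invoke_from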

View File

@@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         response = cls.convert_blocking_full_response(blocking_response)

         metadata = response.get("metadata", {})
-        response["metadata"] = cls._get_simple_metadata(metadata)
+        if isinstance(metadata, dict):
+            response["metadata"] = cls._get_simple_metadata(metadata)
+        else:
+            response["metadata"] = {}

         return response
@@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

     @classmethod
@@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
-                sub_stream_response_dict = sub_stream_response.to_dict()
+                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
@@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

View File

@@ -271,6 +271,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             raise MoreLikeThisDisabledError()

         app_model_config = message.app_model_config
+        if not app_model_config:
+            raise ValueError("Message app_model_config is None")
         override_model_config_dict = app_model_config.to_dict()
         model_dict = override_model_config_dict["model"]
         completion_params = model_dict.get("completion_params")

View File

@@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
         response = cls.convert_blocking_full_response(blocking_response)

         metadata = response.get("metadata", {})
-        response["metadata"] = cls._get_simple_metadata(metadata)
+        if isinstance(metadata, dict):
+            response["metadata"] = cls._get_simple_metadata(metadata)
+        else:
+            response["metadata"] = {}

         return response
@@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

     @classmethod
@@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
-                sub_stream_response_dict = sub_stream_response.to_dict()
+                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                 metadata = sub_stream_response_dict.get("metadata", {})
+                if not isinstance(metadata, dict):
+                    metadata = {}
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
             if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

View File

@@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
         :param blocking_response: blocking response
         :return:
         """
-        return dict(blocking_response.to_dict())
+        return blocking_response.model_dump()

     @classmethod
     def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse):  # type: ignore[override]
@@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, object] = {
                 "event": sub_stream_response.event.value,
                 "workflow_run_id": chunk.workflow_run_id,
             }
@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

     @classmethod
@@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, object] = {
                 "event": sub_stream_response.event.value,
                 "workflow_run_id": chunk.workflow_run_id,
             }
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                 response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

View File

@@ -133,7 +133,7 @@ class WorkflowAppGenerateTaskPipeline:
         self._application_generate_entity = application_generate_entity
         self._workflow_features_dict = workflow.features_dict
         self._workflow_run_id = ""
-        self._invoke_from = queue_manager._invoke_from
+        self._invoke_from = queue_manager.invoke_from
         self._draft_var_saver_factory = draft_var_saver_factory

     def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -142,7 +142,7 @@
         :return:
         """
         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
-        if self._base_task_pipeline._stream:
+        if self._base_task_pipeline.stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)
@@ -272,12 +272,12 @@
     def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
         """Handle ping events."""
-        yield self._base_task_pipeline._ping_stream_response()
+        yield self._base_task_pipeline.ping_stream_response()

     def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
         """Handle error events."""
-        err = self._base_task_pipeline._handle_error(event=event)
-        yield self._base_task_pipeline._error_to_stream_response(err)
+        err = self._base_task_pipeline.handle_error(event=event)
+        yield self._base_task_pipeline.error_to_stream_response(err)

     def _handle_workflow_started_event(
         self, event: QueueWorkflowStartedEvent, **kwargs

View File

@@ -123,7 +123,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
     """

     # app config
-    app_config: EasyUIBasedAppConfig
+    app_config: EasyUIBasedAppConfig = None  # type: ignore
     model_conf: ModelConfigWithCredentialsEntity

     query: Optional[str] = None
@@ -186,7 +186,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
     """

     # app config
-    app_config: WorkflowUIBasedAppConfig
+    app_config: WorkflowUIBasedAppConfig = None  # type: ignore

     workflow_run_id: Optional[str] = None
     query: str
@@ -218,7 +218,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
     """

     # app config
-    app_config: WorkflowUIBasedAppConfig
+    app_config: WorkflowUIBasedAppConfig = None  # type: ignore

     workflow_execution_id: str

 class SingleIterationRunEntity(BaseModel):
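Giving a required Pydantic field a `None` default with `# type: ignore` lets these entities be constructed before `app_config` is available, but any path that reads the field early now gets `None` at runtime with no checker warning. An `Optional` field plus an explicit guard is the type-honest variant; a sketch of that tradeoff with a stand-in model, not what this diff does:

from typing import Optional

from pydantic import BaseModel

class AppConfig(BaseModel):  # stand-in for WorkflowUIBasedAppConfig
    app_id: str

class GenerateEntity(BaseModel):
    # Honest typing: the field may legitimately be unset for a while.
    app_config: Optional[AppConfig] = None

    def require_app_config(self) -> AppConfig:
        if self.app_config is None:
            raise ValueError("app_config has not been set yet")
        return self.app_config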

View File

@@ -5,7 +5,6 @@ from typing import Any, Optional
 from pydantic import BaseModel, ConfigDict, Field

 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from core.model_runtime.utils.encoders import jsonable_encoder
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from core.workflow.entities import AgentNodeStrategyInit
 from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -90,9 +89,6 @@ class StreamResponse(BaseModel):
     event: StreamEvent
     task_id: str

-    def to_dict(self):
-        return jsonable_encoder(self)

 class ErrorStreamResponse(StreamResponse):
     """
@@ -685,9 +681,6 @@ class AppBlockingResponse(BaseModel):
     task_id: str

-    def to_dict(self):
-        return jsonable_encoder(self)

 class ChatbotAppBlockingResponse(AppBlockingResponse):
     """

View File

@@ -35,6 +35,9 @@ class AnnotationReplyFeature:
             collection_binding_detail = annotation_setting.collection_binding_detail
+            if not collection_binding_detail:
+                return None
+
             try:
                 score_threshold = annotation_setting.score_threshold or 1
                 embedding_provider_name = collection_binding_detail.provider_name

View File

@@ -1 +1,3 @@
 from .rate_limit import RateLimit
+
+__all__ = ["RateLimit"]

View File

@@ -19,7 +19,7 @@ class RateLimit:
     _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60  # recalculate request_count from request_detail every 5 minutes
     _instance_dict: dict[str, "RateLimit"] = {}

-    def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
+    def __new__(cls, client_id: str, max_active_requests: int):
         if client_id not in cls._instance_dict:
             instance = super().__new__(cls)
             cls._instance_dict[client_id] = instance
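Dropping the explicit `cls: type["RateLimit"]` annotation is a pure typing cleanup: `cls` already has that type implicitly, and some checkers flag the redundant form. For context, the surrounding `__new__` implements a per-client multiton, one shared instance per `client_id`; a condensed sketch of the pattern:

class PerClientSingleton:
    _instances: dict[str, "PerClientSingleton"] = {}

    def __new__(cls, client_id: str):
        # Reuse the existing instance for this client_id, if any.
        if client_id not in cls._instances:
            cls._instances[client_id] = super().__new__(cls)
        return cls._instances[client_id]

a = PerClientSingleton("tenant-1")
b = PerClientSingleton("tenant-1")
assert a is b  # same object for the same client_id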

View File

@@ -38,11 +38,11 @@
     ):
         self._application_generate_entity = application_generate_entity
         self.queue_manager = queue_manager
-        self._start_at = time.perf_counter()
-        self._output_moderation_handler = self._init_output_moderation()
-        self._stream = stream
+        self.start_at = time.perf_counter()
+        self.output_moderation_handler = self._init_output_moderation()
+        self.stream = stream

-    def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
+    def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
         logger.debug("error: %s", event.error)
         e = event.error
         err: Exception
@@ -86,7 +86,7 @@
         return message

-    def _error_to_stream_response(self, e: Exception):
+    def error_to_stream_response(self, e: Exception):
         """
         Error to stream response.
         :param e: exception
@@ -94,7 +94,7 @@
         """
         return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)

-    def _ping_stream_response(self) -> PingStreamResponse:
+    def ping_stream_response(self) -> PingStreamResponse:
         """
         Ping stream response.
         :return:
@@ -118,21 +118,21 @@
         )
         return None

-    def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
+    def handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
         """
         Handle output moderation when task finished.
         :param completion: completion
         :return:
         """
         # response moderation
-        if self._output_moderation_handler:
-            self._output_moderation_handler.stop_thread()
+        if self.output_moderation_handler:
+            self.output_moderation_handler.stop_thread()

-            completion, flagged = self._output_moderation_handler.moderation_completion(
+            completion, flagged = self.output_moderation_handler.moderation_completion(
                 completion=completion, public_event=False
             )

-            self._output_moderation_handler = None
+            self.output_moderation_handler = None

             if flagged:
                 return completion
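The theme of this file: members consumed by the sibling pipelines through `self._base_task_pipeline` (`handle_error`, `error_to_stream_response`, `ping_stream_response`, `handle_output_moderation_when_task_finished`, `start_at`, `stream`, `output_moderation_handler`) lose the leading underscore, while purely internal members such as `_init_output_moderation` keep it. A minimal, abbreviated illustration of the before/after contract:

class BasePipeline:
    def handle_error(self, event: Exception) -> Exception:  # public: part of the contract
        return event

    def _init_output_moderation(self):  # private: internal detail keeps its underscore
        return None

class ChatPipeline:
    def __init__(self, base: BasePipeline) -> None:
        self._base_task_pipeline = base

    def on_error(self, event: Exception) -> Exception:
        # Cross-object calls now go through a public name instead of
        # reaching into the collaborator's private API.
        return self._base_task_pipeline.handle_error(event)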

View File

@@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
         )
         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

-        if self._stream:
+        if self.stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)
@@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
             if isinstance(event, QueueErrorEvent):
                 with Session(db.engine) as session:
-                    err = self._handle_error(event=event, session=session, message_id=self._message_id)
+                    err = self.handle_error(event=event, session=session, message_id=self._message_id)
                     session.commit()
-                yield self._error_to_stream_response(err)
+                yield self.error_to_stream_response(err)
                 break
             elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                 if isinstance(event, QueueMessageEndEvent):
@@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
                     self._handle_stop(event)

                 # handle output moderation
-                output_moderation_answer = self._handle_output_moderation_when_task_finished(
+                output_moderation_answer = self.handle_output_moderation_when_task_finished(
                     cast(str, self._task_state.llm_result.message.content)
                 )
                 if output_moderation_answer:
@@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
             elif isinstance(event, QueueMessageReplaceEvent):
                 yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
             elif isinstance(event, QueuePingEvent):
-                yield self._ping_stream_response()
+                yield self.ping_stream_response()
             else:
                 continue

         if publisher:
@@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
         message.answer_tokens = usage.completion_tokens
         message.answer_unit_price = usage.completion_unit_price
         message.answer_price_unit = usage.completion_price_unit
-        message.provider_response_latency = time.perf_counter() - self._start_at
+        message.provider_response_latency = time.perf_counter() - self.start_at
         message.total_price = usage.total_price
         message.currency = usage.currency
         self._task_state.llm_result.usage.latency = message.provider_response_latency
@@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
         # transform usage
         model_type_instance = model_config.provider_model_bundle.model_type_instance
         model_type_instance = cast(LargeLanguageModel, model_type_instance)
-        self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
+        self._task_state.llm_result.usage = model_type_instance.calc_response_usage(
            model, credentials, prompt_tokens, completion_tokens
         )
@@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
         :param text: text
         :return: True if output moderation should direct output, otherwise False
         """
-        if self._output_moderation_handler:
-            if self._output_moderation_handler.should_direct_output():
+        if self.output_moderation_handler:
+            if self.output_moderation_handler.should_direct_output():
                 # stop subscribe new token when output moderation should direct output
-                self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
+                self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output()
                 self.queue_manager.publish(
                     QueueLLMChunkEvent(
                         chunk=LLMResultChunk(
@@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
                 )
                 return True
             else:
-                self._output_moderation_handler.append_new_token(text)
+                self.output_moderation_handler.append_new_token(text)

         return False
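For orientation, the moderation branch these renames touch behaves roughly as follows: streamed tokens are buffered into the handler until it decides to take over, at which point the buffered final output replaces the model's content. A schematic, self-contained sketch with a fake handler exposing just the methods used above:

class FakeModerationHandler:
    """Stand-in for the real output moderation handler (illustration only)."""

    def __init__(self, direct: bool) -> None:
        self._direct = direct
        self.buffer = ""

    def should_direct_output(self) -> bool:
        return self._direct

    def get_final_output(self) -> str:
        return "[moderated final answer]"

    def append_new_token(self, text: str) -> None:
        self.buffer += text

def handle_chunk(handler: FakeModerationHandler, text: str) -> str | None:
    # Mirrors the branch above: either moderation takes over and supplies
    # the full replacement answer, or the token is buffered for later checks.
    if handler.should_direct_output():
        return handler.get_final_output()
    handler.append_new_token(text)
    return None

assert handle_chunk(FakeModerationHandler(direct=False), "Hi") is None
assert handle_chunk(FakeModerationHandler(direct=True), "Hi") == "[moderated final answer]"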