mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
refactor app
This commit is contained in:
282
api/core/app/app_queue_manager.py
Normal file
282
api/core/app/app_queue_manager.py
Normal file
@ -0,0 +1,282 @@
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.entities.queue_entities import (
|
||||
AnnotationReplyEvent,
|
||||
AppQueueEvent,
|
||||
QueueAgentMessageEvent,
|
||||
QueueAgentThoughtEvent,
|
||||
QueueErrorEvent,
|
||||
QueueMessage,
|
||||
QueueMessageEndEvent,
|
||||
QueueMessageEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import MessageAgentThought, MessageFile
|
||||
|
||||
|
||||
class PublishFrom(Enum):
|
||||
APPLICATION_MANAGER = 1
|
||||
TASK_PIPELINE = 2
|
||||
|
||||
|
||||
class AppQueueManager:
|
||||
def __init__(self, task_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
conversation_id: str,
|
||||
app_mode: str,
|
||||
message_id: str) -> None:
|
||||
if not user_id:
|
||||
raise ValueError("user is required")
|
||||
|
||||
self._task_id = task_id
|
||||
self._user_id = user_id
|
||||
self._invoke_from = invoke_from
|
||||
self._conversation_id = str(conversation_id)
|
||||
self._app_mode = app_mode
|
||||
self._message_id = str(message_id)
|
||||
|
||||
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||
redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
|
||||
|
||||
q = queue.Queue()
|
||||
|
||||
self._q = q
|
||||
|
||||
def listen(self) -> Generator:
|
||||
"""
|
||||
Listen to queue
|
||||
:return:
|
||||
"""
|
||||
# wait for 10 minutes to stop listen
|
||||
listen_timeout = 600
|
||||
start_time = time.time()
|
||||
last_ping_time = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = self._q.get(timeout=1)
|
||||
if message is None:
|
||||
break
|
||||
|
||||
yield message
|
||||
except queue.Empty:
|
||||
continue
|
||||
finally:
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time >= listen_timeout or self._is_stopped():
|
||||
# publish two messages to make sure the client can receive the stop signal
|
||||
# and stop listening after the stop signal processed
|
||||
self.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
self.stop_listen()
|
||||
|
||||
if elapsed_time // 10 > last_ping_time:
|
||||
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
|
||||
last_ping_time = elapsed_time // 10
|
||||
|
||||
def stop_listen(self) -> None:
|
||||
"""
|
||||
Stop listen to queue
|
||||
:return:
|
||||
"""
|
||||
self._q.put(None)
|
||||
|
||||
def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish chunk message to channel
|
||||
|
||||
:param chunk: chunk
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageEvent(
|
||||
chunk=chunk
|
||||
), pub_from)
|
||||
|
||||
def publish_agent_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish agent chunk message to channel
|
||||
|
||||
:param chunk: chunk
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueAgentMessageEvent(
|
||||
chunk=chunk
|
||||
), pub_from)
|
||||
|
||||
def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish message replace
|
||||
:param text: text
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageReplaceEvent(
|
||||
text=text
|
||||
), pub_from)
|
||||
|
||||
def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish retriever resources
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)
|
||||
|
||||
def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish annotation reply
|
||||
:param message_annotation_id: message annotation id
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
|
||||
|
||||
def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish message end
|
||||
:param llm_result: llm result
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
|
||||
self.stop_listen()
|
||||
|
||||
def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish agent thought
|
||||
:param message_agent_thought: message agent thought
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=message_agent_thought.id
|
||||
), pub_from)
|
||||
|
||||
def publish_message_file(self, message_file: MessageFile, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish agent thought
|
||||
:param message_file: message file
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file.id
|
||||
), pub_from)
|
||||
|
||||
def publish_error(self, e, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish error
|
||||
:param e: error
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueErrorEvent(
|
||||
error=e
|
||||
), pub_from)
|
||||
self.stop_listen()
|
||||
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
:param pub_from:
|
||||
:return:
|
||||
"""
|
||||
self._check_for_sqlalchemy_models(event.dict())
|
||||
|
||||
message = QueueMessage(
|
||||
task_id=self._task_id,
|
||||
message_id=self._message_id,
|
||||
conversation_id=self._conversation_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(event, QueueStopEvent):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
@classmethod
|
||||
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
|
||||
"""
|
||||
Set task stop flag
|
||||
:return:
|
||||
"""
|
||||
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
|
||||
if result is None:
|
||||
return
|
||||
|
||||
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||
if result.decode('utf-8') != f"{user_prefix}-{user_id}":
|
||||
return
|
||||
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
def _is_stopped(self) -> bool:
|
||||
"""
|
||||
Check if task is stopped
|
||||
:return:
|
||||
"""
|
||||
stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id)
|
||||
result = redis_client.get(stopped_cache_key)
|
||||
if result is not None:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _generate_task_belong_cache_key(cls, task_id: str) -> str:
|
||||
"""
|
||||
Generate task belong cache key
|
||||
:param task_id: task id
|
||||
:return:
|
||||
"""
|
||||
return f"generate_task_belong:{task_id}"
|
||||
|
||||
@classmethod
|
||||
def _generate_stopped_cache_key(cls, task_id: str) -> str:
|
||||
"""
|
||||
Generate stopped cache key
|
||||
:param task_id: task id
|
||||
:return:
|
||||
"""
|
||||
return f"generate_task_stopped:{task_id}"
|
||||
|
||||
def _check_for_sqlalchemy_models(self, data: Any):
|
||||
# from entity to dict or list
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
self._check_for_sqlalchemy_models(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
self._check_for_sqlalchemy_models(item)
|
||||
else:
|
||||
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
|
||||
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
|
||||
"that cause thread safety issues is not allowed.")
|
||||
|
||||
|
||||
class ConversationTaskStoppedException(Exception):
|
||||
pass
|
||||
Reference in New Issue
Block a user