Merge remote-tracking branch 'origin/feat/plugins' into dev/plugin-deploy

This commit is contained in:
Yeuoly
2024-12-04 15:40:39 +08:00
285 changed files with 8052 additions and 1912 deletions

View File

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
@ -37,7 +40,7 @@ class ModelConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
"""
Validate and set defaults for model config

View File

@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator
from typing import Any, Literal, Optional, Union, overload
from typing import Any, Literal, Mapping, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -23,6 +23,7 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity,
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
@ -33,16 +34,7 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
) -> Generator[dict | str, None, None]: ...
_dialogue_count: int
@overload
def generate(
@ -50,10 +42,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
args: Mapping,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
streaming: Literal[True] = True,
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
@ -61,20 +53,31 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
args: Mapping,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict[str, Any], Generator[dict | str, None, None]]: ...
streaming: Literal[False] = False,
) -> Mapping: ...
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
args: Mapping,
invoke_from: InvokeFrom,
stream: bool = True,
) -> dict[str, Any] | Generator[str | dict, None, None]:
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
"""
Generate App response.
@ -145,7 +148,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=stream,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
@ -161,12 +164,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream,
stream=streaming,
)
def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account | EndUser, args: dict, stream: bool = True
) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]:
self,
app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@ -195,7 +204,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
query="",
files=[],
user_id=user.id,
stream=stream,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
@ -212,7 +221,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
conversation=None,
stream=stream,
stream=streaming,
)
def _generate(
@ -224,7 +233,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True,
) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@ -248,6 +257,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
db.session.commit()
db.session.refresh(conversation)
# get conversation dialogue count
self._dialogue_count = get_thread_messages_length(conversation.id)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
@ -318,6 +330,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
queue_manager=queue_manager,
conversation=conversation,
message=message,
dialogue_count=self._dialogue_count,
)
runner.run()
@ -371,6 +384,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message=message,
user=user,
stream=stream,
dialogue_count=self._dialogue_count,
)
try:

View File

@ -39,12 +39,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
dialogue_count: int,
) -> None:
super().__init__(queue_manager)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
self._dialogue_count = dialogue_count
def run(self) -> None:
app_config = self.application_generate_entity.app_config
@ -122,19 +124,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
session.commit()
# Increment dialogue count.
self.conversation.dialogue_count += 1
conversation_dialogue_count = self.conversation.dialogue_count
db.session.commit()
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
SystemVariableKey.APP_ID: app_config.app_id,
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,

View File

@ -88,6 +88,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
message: Message,
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@ -98,6 +99,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:param message: message
:param user: user
:param stream: stream
:param dialogue_count: dialogue count
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
@ -114,7 +116,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation.dialogue_count,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
@ -125,6 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
self.total_tokens: int = 0
def process(self):
"""
@ -358,6 +361,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not workflow_run:
raise Exception("Workflow run not initialized.")
# FIXME for issue #11221 quick fix maybe have a better solution
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
@ -371,7 +376,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=self._conversation.id,

View File

@ -1,5 +1,6 @@
import uuid
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional
from core.agent.entities import AgentEntity
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
@ -85,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict:
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict:
"""
Validate for agent chat app model config

View File

@ -2,7 +2,7 @@ import contextvars
import logging
import threading
import uuid
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
@ -32,36 +32,37 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: dict,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: Literal[True] = True,
) -> Generator[dict | str, None, None]: ...
streaming: Literal[True] = True,
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
args: Mapping,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
streaming: Literal[False] = False,
) -> Mapping: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
args: Mapping,
invoke_from: InvokeFrom,
stream: bool = False,
) -> dict | Generator[dict | str, None, None]: ...
streaming: bool,
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
) -> Union[dict, Generator[dict | str, None, None]]:
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
"""
Generate App response.
@ -71,7 +72,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
if not stream:
if not streaming:
raise ValueError("Agent Chat App does not support blocking mode")
if not args.get("query"):
@ -102,7 +103,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# validate config
override_model_config_dict = AgentChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config")
tenant_id=app_model.tenant_id,
config=args["model_config"],
)
# always enable retriever resource in debugger mode
@ -147,7 +149,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=stream,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
call_depth=0,
@ -189,7 +191,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream,
stream=streaming,
)
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

View File

@ -141,7 +141,7 @@ class BaseAppGenerator:
return value
@classmethod
def convert_to_event_stream(cls, generator: Union[dict, Generator[dict | str, None, None]]):
def convert_to_event_stream(cls, generator: Union[Mapping, Generator[Mapping | str, None, None]]):
"""
Convert messages into event stream
"""
@ -151,7 +151,7 @@ class BaseAppGenerator:
def gen():
for message in generator:
if isinstance(message, dict):
if isinstance(message, (Mapping, dict)):
yield f"data: {json.dumps(message)}\n\n"
else:
yield f"event: {message}\n\n"

View File

@ -36,7 +36,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
streaming: Literal[True] = True,
) -> Generator[dict | str, None, None]: ...
@overload
@ -46,7 +46,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
streaming: Literal[False] = False,
) -> dict: ...
@overload
@ -56,7 +56,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = False,
streaming: bool = False,
) -> Union[dict, Generator[dict | str, None, None]]: ...
def generate(
@ -65,7 +65,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True,
streaming: bool = True,
) -> Union[dict, Generator[dict | str, None, None]]:
"""
Generate App response.
@ -152,7 +152,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
stream=stream,
stream=streaming,
)
# init generate records
@ -172,7 +172,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
@ -189,7 +189,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream,
stream=streaming,
)
return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

View File

@ -2,7 +2,7 @@ import logging
import threading
import uuid
from collections.abc import Generator
from typing import Any, Literal, Union, overload
from typing import Any, Literal, Mapping, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -34,9 +34,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
args: Mapping,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
streaming: Literal[True] = True,
) -> Generator[str, None, None]: ...
@overload
@ -44,24 +44,24 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
args: Mapping,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
streaming: Literal[False] = False,
) -> Mapping: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
args: Mapping,
invoke_from: InvokeFrom,
stream: bool = False,
) -> dict | Generator[str, None, None]: ...
streaming: bool = False,
) -> Mapping | Generator[str, None, None]: ...
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
) -> Union[dict, Generator[str, None, None]]:
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True
) -> Union[Mapping, Generator[str, None, None]]:
"""
Generate App response.
@ -129,7 +129,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
query=query,
files=file_objs,
user_id=user.id,
stream=stream,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
@ -168,7 +168,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream,
stream=streaming,
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
@ -226,7 +226,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
) -> Union[Mapping, Generator[str, None, None]]:
"""
Generate App response.

View File

@ -33,15 +33,16 @@ class WorkflowAppGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
args: Mapping,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
streaming: Literal[True] = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[dict | str, None, None]: ...
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
@ -49,12 +50,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
args: Mapping,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
streaming: Literal[False] = False,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> dict: ...
) -> Mapping: ...
@overload
def generate(
@ -62,11 +63,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
args: Mapping,
invoke_from: InvokeFrom,
stream: bool = False,
streaming: bool,
call_depth: int = 0,
) -> dict | Generator[dict | str, None, None]: ...
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping | Generator[Mapping | str, None, None]: ...
def generate(
self,
@ -75,10 +77,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: bool = True,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
) -> Mapping | Generator[Mapping | str, None, None]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files
@ -113,7 +115,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
),
files=system_files,
user_id=user.id,
stream=stream,
stream=streaming,
invoke_from=invoke_from,
call_depth=call_depth,
trace_manager=trace_manager,
@ -130,7 +132,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
@ -142,7 +144,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[dict, Generator[str | dict, None, None]]:
"""
@ -184,13 +186,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream,
stream=streaming,
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account | EndUser, args: dict, stream: bool = True
self,
app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: dict,
streaming: bool = True,
) -> dict[str, Any] | Generator[str | dict, Any, None]:
"""
Generate App response.
@ -218,7 +226,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
inputs={},
files=[],
user_id=user.id,
stream=stream,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
@ -235,7 +243,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
stream=stream,
streaming=streaming,
)
def _generate_worker(

View File

@ -106,6 +106,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self.total_tokens: int = 0
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -319,6 +320,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not workflow_run:
raise Exception("Workflow run not initialized.")
# FIXME for issue #11221 quick fix maybe have a better solution
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
@ -332,7 +335,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,

View File

@ -43,7 +43,7 @@ from core.workflow.graph_engine.entities.event import (
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
@ -138,7 +138,8 @@ class WorkflowBasedAppRunner(AppRunner):
# Get node class
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping[node_type]
node_version = iteration_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(

View File

@ -1,9 +1,9 @@
import logging
import time
import uuid
from collections.abc import Callable, Generator
from collections.abc import Generator, Mapping
from datetime import timedelta
from typing import Optional, Union
from typing import Any, Optional, Union
from core.errors.error import AppInvokeQuotaExceededError
from extensions.ext_redis import redis_client
@ -88,20 +88,17 @@ class RateLimit:
def gen_request_key() -> str:
return str(uuid.uuid4())
def generate(self, generator: Union[Generator, callable, dict], request_id: str):
if isinstance(generator, dict):
def generate(self, generator: Union[Generator[str, None, None], Mapping[str, Any]], request_id: str):
if isinstance(generator, Mapping):
return generator
else:
return RateLimitGenerator(self, generator, request_id)
return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
class RateLimitGenerator:
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, Callable[[], Generator]], request_id: str):
def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
self.rate_limit = rate_limit
if callable(generator):
self.generator = generator()
else:
self.generator = generator
self.generator = generator
self.request_id = request_id
self.closed = False

View File

@ -340,7 +340,7 @@ class WorkflowCycleManage:
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,

View File

@ -7,13 +7,13 @@ from .models import (
)
__all__ = [
"FILE_MODEL_IDENTITY",
"ArrayFileAttribute",
"File",
"FileAttribute",
"FileBelongsTo",
"FileTransferMethod",
"FileType",
"FileUploadConfig",
"FileTransferMethod",
"FileBelongsTo",
"File",
"ImageConfig",
"FileAttribute",
"ArrayFileAttribute",
"FILE_MODEL_IDENTITY",
]

View File

@ -53,8 +53,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST:
if stream:
return response.iter_bytes()
return response
else:
logging.warning(

View File

@ -15,6 +15,5 @@ class SuggestedQuestionsAfterAnswerOutputParser:
json_obj = json.loads(action_match.group(0).strip())
else:
json_obj = []
print(f"Could not parse LLM output: {text}")
return json_obj

View File

@ -18,25 +18,25 @@ from .message_entities import (
from .model_entities import ModelPropertyKey
__all__ = [
"ImagePromptMessageContent",
"VideoPromptMessageContent",
"PromptMessage",
"PromptMessageRole",
"LLMUsage",
"ModelPropertyKey",
"AssistantPromptMessage",
"PromptMessage",
"PromptMessageContent",
"PromptMessageRole",
"SystemPromptMessage",
"TextPromptMessageContent",
"UserPromptMessage",
"PromptMessageTool",
"ToolPromptMessage",
"PromptMessageContentType",
"AudioPromptMessageContent",
"DocumentPromptMessageContent",
"ImagePromptMessageContent",
"LLMResult",
"LLMResultChunk",
"LLMResultChunkDelta",
"AudioPromptMessageContent",
"DocumentPromptMessageContent",
"LLMUsage",
"ModelPropertyKey",
"PromptMessage",
"PromptMessage",
"PromptMessageContent",
"PromptMessageContentType",
"PromptMessageRole",
"PromptMessageRole",
"PromptMessageTool",
"SystemPromptMessage",
"TextPromptMessageContent",
"ToolPromptMessage",
"UserPromptMessage",
"VideoPromptMessageContent",
]

View File

@ -0,0 +1,32 @@
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message
def get_thread_messages_length(conversation_id: str) -> int:
    """
    Count the messages that belong to the current thread of a conversation.

    Loads the id, parent_message_id and answer columns of every message in
    the conversation (newest first by created_at), passes those rows to
    ``extract_thread_messages`` and counts what comes back, skipping the
    leading entry when its answer is empty.

    :param conversation_id: id of the conversation whose messages are counted
    :return: number of thread messages, excluding a leading unanswered one
    """
    # Only the columns needed below are selected; rows come back newest first.
    rows = (
        db.session.query(Message.id, Message.parent_message_id, Message.answer)
        .filter(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.desc())
        .all()
    )

    # NOTE(review): presumably this follows parent_message_id links to keep
    # only the active thread -- confirm against extract_thread_messages.
    thread = extract_thread_messages(rows)

    # The first entry with an empty answer is the newly created message that
    # has not been answered yet, so it is not counted.
    newest_is_unanswered = bool(thread) and not thread[0].answer
    if newest_is_unanswered:
        thread.pop(0)

    return len(thread)

View File

@ -110,8 +110,12 @@ class RetrievalService:
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
all_documents = data_post_processor.invoke(
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k,
)
return all_documents
@classmethod
@ -178,7 +182,10 @@ class RetrievalService:
)
all_documents.extend(
data_post_processor.invoke(
query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
)
)
else:
@ -220,7 +227,10 @@ class RetrievalService:
)
all_documents.extend(
data_post_processor.invoke(
query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
)
)
else:

View File

@ -104,8 +104,7 @@ class OceanBaseVector(BaseVector):
val = int(row[6])
vals.append(val)
if len(vals) == 0:
print("ob_vector_memory_limit_percentage not found in parameters.")
exit(1)
raise ValueError("ob_vector_memory_limit_percentage not found in parameters.")
if any(val == 0 for val in vals):
try:
self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
@ -200,10 +199,10 @@ class OceanBaseVectorFactory(AbstractVectorFactory):
return OceanBaseVector(
collection_name,
OceanBaseVectorConfig(
host=dify_config.OCEANBASE_VECTOR_HOST,
port=dify_config.OCEANBASE_VECTOR_PORT,
user=dify_config.OCEANBASE_VECTOR_USER,
host=dify_config.OCEANBASE_VECTOR_HOST or "",
port=dify_config.OCEANBASE_VECTOR_PORT or 0,
user=dify_config.OCEANBASE_VECTOR_USER or "",
password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
database=dify_config.OCEANBASE_VECTOR_DATABASE,
database=dify_config.OCEANBASE_VECTOR_DATABASE or "",
),
)

View File

@ -230,7 +230,6 @@ class OracleVector(BaseVector):
except LookupError:
nltk.download("punkt")
nltk.download("stopwords")
print("run download")
e_str = re.sub(r"[^\w ]", "", query)
all_tokens = nltk.word_tokenize(e_str)
stop_words = stopwords.words("english")

View File

@ -64,7 +64,7 @@ class UpstashVector(BaseVector):
item_ids = []
for doc_id in ids:
ids = self.get_ids_by_metadata_field("doc_id", doc_id)
if id:
if ids:
item_ids += ids
self._delete_by_ids(ids=item_ids)
@ -95,9 +95,10 @@ class UpstashVector(BaseVector):
metadata = record.metadata
text = record.data
score = record.score
metadata["score"] = score
if score > score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
if metadata is not None and text is not None:
metadata["score"] = score
if score > score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@ -123,7 +124,7 @@ class UpstashVectorFactory(AbstractVectorFactory):
return UpstashVector(
collection_name=collection_name,
config=UpstashVectorConfig(
url=dify_config.UPSTASH_VECTOR_URL,
token=dify_config.UPSTASH_VECTOR_TOKEN,
url=dify_config.UPSTASH_VECTOR_URL or "",
token=dify_config.UPSTASH_VECTOR_TOKEN or "",
),
)

View File

@ -102,7 +102,8 @@ class CacheEmbedding(Embeddings):
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float")
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_text_embedding(
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY

View File

@ -86,7 +86,7 @@ class WordExtractor(BaseExtractor):
image_count += 1
if rel.is_external:
url = rel.reltype
response = ssrf_proxy.get(url, stream=True)
response = ssrf_proxy.get(url)
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
file_uuid = str(uuid.uuid4())

View File

@ -88,11 +88,11 @@ class WorkflowTool(Tool):
user=self._get_user(user_id),
args={"inputs": tool_parameters, "files": files},
invoke_from=self.runtime.invoke_from,
stream=False,
streaming=False,
call_depth=self.workflow_call_depth + 1,
workflow_thread_pool_id=self.thread_pool_id,
)
assert isinstance(result, dict)
data = result.get("data", {})
if data.get("error"):

View File

@ -32,32 +32,32 @@ from .variables import (
)
__all__ = [
"IntegerVariable",
"FloatVariable",
"ObjectVariable",
"SecretVariable",
"StringVariable",
"ArrayAnyVariable",
"Variable",
"SegmentType",
"SegmentGroup",
"Segment",
"NoneSegment",
"NoneVariable",
"IntegerSegment",
"FloatSegment",
"ObjectSegment",
"ArrayAnySegment",
"StringSegment",
"ArrayStringVariable",
"ArrayAnyVariable",
"ArrayFileSegment",
"ArrayFileVariable",
"ArrayNumberSegment",
"ArrayNumberVariable",
"ArrayObjectSegment",
"ArrayObjectVariable",
"ArraySegment",
"ArrayFileSegment",
"ArrayNumberSegment",
"ArrayObjectSegment",
"ArrayStringSegment",
"ArrayStringVariable",
"FileSegment",
"FileVariable",
"ArrayFileVariable",
"FloatSegment",
"FloatVariable",
"IntegerSegment",
"IntegerVariable",
"NoneSegment",
"NoneVariable",
"ObjectSegment",
"ObjectVariable",
"SecretVariable",
"Segment",
"SegmentGroup",
"SegmentType",
"StringSegment",
"StringVariable",
"Variable",
]

View File

@ -2,16 +2,19 @@ from enum import StrEnum
class SegmentType(StrEnum):
NONE = "none"
NUMBER = "number"
STRING = "string"
OBJECT = "object"
SECRET = "secret"
FILE = "file"
ARRAY_ANY = "array[any]"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
OBJECT = "object"
FILE = "file"
ARRAY_FILE = "array[file]"
NONE = "none"
GROUP = "group"

View File

@ -2,6 +2,6 @@ from .base_workflow_callback import WorkflowCallback
from .workflow_logging_callback import WorkflowLoggingCallback
__all__ = [
"WorkflowLoggingCallback",
"WorkflowCallback",
"WorkflowLoggingCallback",
]

View File

@ -39,7 +39,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@ -65,7 +65,6 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
self.submit_count -= 1
def check_is_full(self) -> None:
print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}")
if self.submit_count > self.max_submit_count:
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
@ -229,7 +228,8 @@ class GraphEngine:
# convert to specific node
node_type = NodeType(node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping[node_type]
node_version = node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None

View File

@ -1,4 +1,4 @@
from .answer_node import AnswerNode
from .entities import AnswerStreamGenerateRoute
__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"]
__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"]

View File

@ -153,7 +153,7 @@ class AnswerStreamGeneratorRouter:
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.CONVERSATION_VARIABLE_ASSIGNER,
NodeType.VARIABLE_ASSIGNER,
}:
answer_dependencies[answer_node_id].append(source_node_id)
else:

View File

@ -1,4 +1,4 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
from .node import BaseNode
__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"]
__all__ = ["BaseIterationNodeData", "BaseIterationState", "BaseNode", "BaseNodeData"]

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
version: str = "1"
class BaseIterationNodeData(BaseNodeData):

View File

@ -55,7 +55,9 @@ class BaseNode(Generic[GenericNodeData]):
raise ValueError("Node ID is required.")
self.node_id = node_id
self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))
node_data = self._node_data_cls.model_validate(config.get("data", {}))
self.node_data = cast(GenericNodeData, node_data)
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:

View File

@ -4,8 +4,8 @@ import json
import docx
import pandas as pd
import pypdfium2
import yaml
import pypdfium2 # type: ignore
import yaml # type: ignore
from unstructured.partition.api import partition_via_api
from unstructured.partition.email import partition_email
from unstructured.partition.epub import partition_epub
@ -113,7 +113,7 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
"""Extract text from a file based on its file extension."""
match file_extension:
case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml":
case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml" | ".vtt":
return _extract_text_from_plain_text(file_content)
case ".json":
return _extract_text_from_json(file_content)
@ -237,15 +237,17 @@ def _extract_text_from_csv(file_content: bytes) -> str:
def _extract_text_from_excel(file_content: bytes) -> str:
"""Extract text from an Excel file using pandas."""
try:
df = pd.read_excel(io.BytesIO(file_content))
# Drop rows where all elements are NaN
df.dropna(how="all", inplace=True)
# Convert DataFrame to Markdown table
markdown_table = df.to_markdown(index=False)
excel_file = pd.ExcelFile(io.BytesIO(file_content))
markdown_table = ""
for sheet_name in excel_file.sheet_names:
try:
df = excel_file.parse(sheet_name=sheet_name)
df.dropna(how="all", inplace=True)
# Create Markdown table two times to separate tables with a newline
markdown_table += df.to_markdown(index=False) + "\n\n"
except Exception as e:
continue
return markdown_table
except Exception as e:
raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e

View File

@ -1,4 +1,4 @@
from .end_node import EndNode
from .entities import EndStreamParam
__all__ = ["EndStreamParam", "EndNode"]
__all__ = ["EndNode", "EndStreamParam"]

View File

@ -14,11 +14,11 @@ class NodeType(StrEnum):
HTTP_REQUEST = "http-request"
TOOL = "tool"
VARIABLE_AGGREGATOR = "variable-aggregator"
VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
CONVERSATION_VARIABLE_ASSIGNER = "assigner"
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"

View File

@ -2,9 +2,9 @@ from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverRes
from .types import NodeEvent
__all__ = [
"ModelInvokeCompletedEvent",
"NodeEvent",
"RunCompletedEvent",
"RunRetrieverResourceEvent",
"RunStreamChunkEvent",
"NodeEvent",
"ModelInvokeCompletedEvent",
]

View File

@ -1,4 +1,4 @@
from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData
from .node import HttpRequestNode
__all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"]
__all__ = ["BodyData", "HttpRequestNode", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "HttpRequestNodeData"]

View File

@ -1,11 +1,9 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_extension
from os import path
from typing import Any
from configs import dify_config
from core.file import File, FileTransferMethod, FileType
from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
@ -107,6 +105,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
node_data: HttpRequestNodeData,
) -> Mapping[str, Sequence[str]]:
selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
if node_data.body:
@ -149,11 +148,6 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
content = response.content
if is_file and content_type:
# extract filename from url
filename = path.basename(url)
# extract extension if possible
extension = guess_extension(content_type) or ".bin"
tool_file = ToolFileManager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
@ -164,7 +158,6 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
mapping = {
"tool_file_id": tool_file.id,
"type": FileType.IMAGE.value,
"transfer_method": FileTransferMethod.TOOL_FILE.value,
}
file = file_factory.build_from_mapping(

View File

@ -117,7 +117,7 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
@ -163,7 +163,8 @@ class IterationNode(BaseNode[IterationNodeData]):
if self.node_data.is_parallel:
futures: list[Future] = []
q = Queue()
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
thread_pool = graph_engine.workflow_thread_pool_mapping[graph_engine.thread_pool_id]
thread_pool._max_workers = self.node_data.parallel_nums
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
@ -299,12 +300,13 @@ class IterationNode(BaseNode[IterationNodeData]):
# variable selector to variable mapping
try:
# Get node class
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping.get(node_type)
if not node_cls:
if node_type not in NODE_TYPE_CLASSES_MAPPING:
continue
node_version = sub_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config

View File

@ -197,7 +197,6 @@ class LLMNode(BaseNode[LLMNodeData]):
)
return
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run")
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,

View File

@ -1,3 +1,5 @@
from collections.abc import Mapping
from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code import CodeNode
@ -16,26 +18,87 @@ from core.workflow.nodes.start import StartNode
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
node_type_classes_mapping: dict[NodeType, type[BaseNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.ITERATION_START: IterationStartNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode,
NodeType.LIST_OPERATOR: ListOperatorNode,
LATEST_VERSION = "latest"
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,
"1": StartNode,
},
NodeType.END: {
LATEST_VERSION: EndNode,
"1": EndNode,
},
NodeType.ANSWER: {
LATEST_VERSION: AnswerNode,
"1": AnswerNode,
},
NodeType.LLM: {
LATEST_VERSION: LLMNode,
"1": LLMNode,
},
NodeType.KNOWLEDGE_RETRIEVAL: {
LATEST_VERSION: KnowledgeRetrievalNode,
"1": KnowledgeRetrievalNode,
},
NodeType.IF_ELSE: {
LATEST_VERSION: IfElseNode,
"1": IfElseNode,
},
NodeType.CODE: {
LATEST_VERSION: CodeNode,
"1": CodeNode,
},
NodeType.TEMPLATE_TRANSFORM: {
LATEST_VERSION: TemplateTransformNode,
"1": TemplateTransformNode,
},
NodeType.QUESTION_CLASSIFIER: {
LATEST_VERSION: QuestionClassifierNode,
"1": QuestionClassifierNode,
},
NodeType.HTTP_REQUEST: {
LATEST_VERSION: HttpRequestNode,
"1": HttpRequestNode,
},
NodeType.TOOL: {
LATEST_VERSION: ToolNode,
"1": ToolNode,
},
NodeType.VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
},
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
}, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: {
LATEST_VERSION: IterationNode,
"1": IterationNode,
},
NodeType.ITERATION_START: {
LATEST_VERSION: IterationStartNode,
"1": IterationStartNode,
},
NodeType.PARAMETER_EXTRACTOR: {
LATEST_VERSION: ParameterExtractorNode,
"1": ParameterExtractorNode,
},
NodeType.VARIABLE_ASSIGNER: {
LATEST_VERSION: VariableAssignerNodeV2,
"1": VariableAssignerNodeV1,
"2": VariableAssignerNodeV2,
},
NodeType.DOCUMENT_EXTRACTOR: {
LATEST_VERSION: DocumentExtractorNode,
"1": DocumentExtractorNode,
},
NodeType.LIST_OPERATOR: {
LATEST_VERSION: ListOperatorNode,
"1": ListOperatorNode,
},
}

View File

@ -235,7 +235,7 @@ class ParameterExtractorNode(LLMNode):
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
text = invoke_result.message.content
if not isinstance(text, str):
if not isinstance(text, str | None):
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
usage = invoke_result.usage

View File

@ -1,4 +1,4 @@
from .entities import QuestionClassifierNodeData
from .question_classifier_node import QuestionClassifierNode
__all__ = ["QuestionClassifierNodeData", "QuestionClassifierNode"]
__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"]

View File

@ -1,8 +0,0 @@
from .node import VariableAssignerNode
from .node_data import VariableAssignerData, WriteMode
__all__ = [
"VariableAssignerNode",
"VariableAssignerData",
"WriteMode",
]

View File

@ -0,0 +1,4 @@
class VariableOperatorNodeError(Exception):
    """Base error type for the variable assigner nodes, don't use directly.

    Subclass it for concrete failure modes so callers can catch the whole
    family with a single ``except VariableOperatorNodeError`` clause.
    """

View File

@ -0,0 +1,19 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables import Variable
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db
from models import ConversationVariable
def update_conversation_variable(conversation_id: str, variable: Variable):
    """Persist ``variable`` into its stored conversation-variable row.

    The row is selected by both ``variable.id`` and ``conversation_id``; its
    ``data`` attribute is replaced with ``variable.model_dump_json()`` and the
    session is committed.

    Raises:
        VariableOperatorNodeError: if no matching row is found.
    """
    with Session(db.engine) as session:
        query = select(ConversationVariable).where(
            ConversationVariable.id == variable.id,
            ConversationVariable.conversation_id == conversation_id,
        )
        record = session.scalar(query)
        if not record:
            raise VariableOperatorNodeError("conversation variable not found in the database")
        record.data = variable.model_dump_json()
        session.commit()

View File

@ -1,2 +0,0 @@
class VariableAssignerNodeError(Exception):
pass

View File

@ -0,0 +1,3 @@
from .node import VariableAssignerNode
__all__ = ["VariableAssignerNode"]

View File

@ -1,40 +1,36 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables import SegmentType, Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode, BaseNodeData
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
from models import ConversationVariable
from models.workflow import WorkflowNodeExecutionStatus
from .exc import VariableAssignerNodeError
from .node_data import VariableAssignerData, WriteMode
class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
_node_type = NodeType.VARIABLE_ASSIGNER
def _run(self) -> NodeRunResult:
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError("assigned variable not found")
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError("input value not found")
raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError("input value not found")
raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={"value": updated_value})
@ -43,7 +39,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}")
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
@ -52,8 +48,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableAssignerNodeError("conversation_id not found")
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
raise VariableOperatorNodeError("conversation_id not found")
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -63,18 +59,6 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
)
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
@ -86,4 +70,4 @@ def get_zero_value(t: SegmentType):
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f"unsupported variable type: {t}")
raise VariableOperatorNodeError(f"unsupported variable type: {t}")

View File

@ -1,6 +1,5 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Optional
from core.workflow.nodes.base import BaseNodeData
@ -12,8 +11,6 @@ class WriteMode(StrEnum):
class VariableAssignerData(BaseNodeData):
title: str = "Variable Assigner"
desc: Optional[str] = "Assign a value to a variable"
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]

View File

@ -0,0 +1,3 @@
from .node import VariableAssignerNode
__all__ = ["VariableAssignerNode"]

View File

@ -0,0 +1,11 @@
from core.variables import SegmentType
# Value a variable is reset to by the CLEAR operation, keyed by segment type.
# Every array type that accepts APPEND/EXTEND must be listed here as well,
# otherwise clearing such a variable fails with a KeyError at lookup time.
EMPTY_VALUE_MAPPING = {
    SegmentType.STRING: "",
    SegmentType.NUMBER: 0,
    SegmentType.OBJECT: {},
    SegmentType.ARRAY_ANY: [],
    SegmentType.ARRAY_STRING: [],
    SegmentType.ARRAY_NUMBER: [],
    SegmentType.ARRAY_OBJECT: [],
    SegmentType.ARRAY_FILE: [],
}

View File

@ -0,0 +1,20 @@
from collections.abc import Sequence
from typing import Any
from pydantic import BaseModel
from core.workflow.nodes.base import BaseNodeData
from .enums import InputType, Operation
class VariableOperationItem(BaseModel):
    """One operation to apply to one variable."""

    # Selector of the variable being modified.
    variable_selector: Sequence[str]
    # Whether ``value`` is a constant or refers to another variable.
    input_type: InputType
    # Operation to apply to the selected variable.
    operation: Operation
    # Operand of the operation; optional and defaults to ``None``.
    value: Any | None = None
class VariableAssignerNodeData(BaseNodeData):
    """Node data of the version "2" variable assigner: a list of operations."""

    # Overrides the base default so this data model reports version "2".
    version: str = "2"
    # Operations carried by this node.
    items: Sequence[VariableOperationItem]

View File

@ -0,0 +1,18 @@
from enum import StrEnum
class Operation(StrEnum):
OVER_WRITE = "over-write"
CLEAR = "clear"
APPEND = "append"
EXTEND = "extend"
SET = "set"
ADD = "+="
SUBTRACT = "-="
MULTIPLY = "*="
DIVIDE = "/="
class InputType(StrEnum):
VARIABLE = "variable"
CONSTANT = "constant"

View File

@ -0,0 +1,31 @@
from collections.abc import Sequence
from typing import Any
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from .enums import InputType, Operation
class OperationNotSupportedError(VariableOperatorNodeError):
    """Raised when an operation cannot be applied to a variable of the given type."""

    # NOTE(review): the keyword is spelled ``varialbe_type`` and is part of the
    # public signature; callers pass it by name, so it is kept as is.
    def __init__(self, *, operation: Operation, varialbe_type: str):
        message = f"Operation {operation} is not supported for type {varialbe_type}"
        super().__init__(message)
class InputTypeNotSupportedError(VariableOperatorNodeError):
    """Raised when an input type is not accepted for the requested operation."""

    def __init__(self, *, input_type: InputType, operation: Operation):
        message = f"Input type {input_type} is not supported for operation {operation}"
        super().__init__(message)
class VariableNotFoundError(VariableOperatorNodeError):
    """Raised when a variable selector does not resolve to a variable."""

    def __init__(self, *, variable_selector: Sequence[str]):
        message = f"Variable {variable_selector} not found"
        super().__init__(message)
class InvalidInputValueError(VariableOperatorNodeError):
    """Raised when the operand supplied for an operation is not acceptable."""

    def __init__(self, *, value: Any):
        message = f"Invalid input value {value}"
        super().__init__(message)
class ConversationIDNotFoundError(VariableOperatorNodeError):
    """Raised when no conversation id is available for a conversation variable update."""

    def __init__(self):
        message = "conversation_id not found"
        super().__init__(message)

View File

@ -0,0 +1,91 @@
from typing import Any
from core.variables import SegmentType
from .enums import Operation
def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
    """Return whether ``operation`` may be applied to a variable of ``variable_type``.

    Unknown operations are reported as unsupported.
    """
    # Overwriting and clearing are accepted regardless of the variable type.
    if operation in (Operation.OVER_WRITE, Operation.CLEAR):
        return True

    if operation == Operation.SET:
        return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}

    # Only number variable can be added, subtracted, multiplied or divided
    if operation in (Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE):
        return variable_type == SegmentType.NUMBER

    # Only array variable can be appended or extended
    if operation in (Operation.APPEND, Operation.EXTEND):
        array_types = {
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_FILE,
        }
        return variable_type in array_types

    return False
def is_variable_input_supported(*, operation: Operation):
    """Return whether ``operation`` accepts another variable as its input.

    SET and the arithmetic operations do not; every other operation does.
    """
    constant_only_operations = {
        Operation.SET,
        Operation.ADD,
        Operation.SUBTRACT,
        Operation.MULTIPLY,
        Operation.DIVIDE,
    }
    return operation not in constant_only_operations
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
    """Return whether ``operation`` accepts a constant input for ``variable_type``.

    Only string, object and number variables take constants; any other
    variable type yields ``False``.
    """
    if variable_type in (SegmentType.STRING, SegmentType.OBJECT):
        return operation in {Operation.OVER_WRITE, Operation.SET}

    if variable_type == SegmentType.NUMBER:
        # Numbers additionally accept the arithmetic operations.
        number_operations = {
            Operation.OVER_WRITE,
            Operation.SET,
            Operation.ADD,
            Operation.SUBTRACT,
            Operation.MULTIPLY,
            Operation.DIVIDE,
        }
        return operation in number_operations

    return False
def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any):
    """Return whether ``value`` is an acceptable operand.

    The check depends on both the target ``variable_type`` and the
    ``operation``. CLEAR needs no operand and is always valid. Combinations
    not listed below are rejected.
    """
    if operation == Operation.CLEAR:
        return True

    # Scalar and object targets: the operand must have the matching Python type.
    if variable_type == SegmentType.STRING:
        return isinstance(value, str)
    if variable_type == SegmentType.NUMBER:
        if not isinstance(value, int | float):
            return False
        # Division by zero is rejected up front.
        return not (operation == Operation.DIVIDE and value == 0)
    if variable_type == SegmentType.OBJECT:
        return isinstance(value, dict)

    # Array targets: look up the element types allowed for this array kind.
    element_types = {
        SegmentType.ARRAY_ANY: (str, float, int, dict),
        SegmentType.ARRAY_STRING: (str,),
        SegmentType.ARRAY_NUMBER: (int, float),
        SegmentType.ARRAY_OBJECT: (dict,),
    }.get(variable_type)
    if element_types is None:
        return False

    # Array & Append: the operand is a single element.
    if operation == Operation.APPEND:
        return isinstance(value, element_types)
    # Array & Extend / Overwrite: the operand is a list of elements.
    if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
        return isinstance(value, list) and all(isinstance(element, element_types) for element in value)
    return False

View File

@ -0,0 +1,159 @@
import json
from typing import Any
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from models.workflow import WorkflowNodeExecutionStatus
from . import helpers
from .constants import EMPTY_VALUE_MAPPING
from .entities import VariableAssignerNodeData
from .enums import InputType, Operation
from .exc import (
ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidInputValueError,
OperationNotSupportedError,
VariableNotFoundError,
)
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
process_data = {}
# NOTE: This node has no outputs
updated_variables: list[Variable] = []
try:
for item in self.node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part
# Check if variable exists
if not isinstance(variable, Variable):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation):
raise OperationNotSupportedError(operation=item.operation, varialbe_type=variable.value_type)
# Check if variable input is supported
if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported(
operation=item.operation
):
raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation)
# Check if constant input is supported
if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported(
variable_type=variable.value_type, operation=item.operation
):
raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation)
# Get value from variable pool
if (
item.input_type == InputType.VARIABLE
and item.operation != Operation.CLEAR
and item.value is not None
):
value = self.graph_runtime_state.variable_pool.get(item.value)
if value is None:
raise VariableNotFoundError(variable_selector=item.value)
# Skip if value is NoneSegment
if value.value_type == SegmentType.NONE:
continue
item.value = value.value
# If set string / bytes / bytearray to object, try convert string to object.
if (
item.operation == Operation.SET
and variable.value_type == SegmentType.OBJECT
and isinstance(item.value, str | bytes | bytearray)
):
try:
item.value = json.loads(item.value)
except json.JSONDecodeError:
raise InvalidInputValueError(value=item.value)
# Check if input value is valid
if not helpers.is_input_value_valid(
variable_type=variable.value_type, operation=item.operation, value=item.value
):
raise InvalidInputValueError(value=item.value)
# ==================== Execution Part
updated_value = self._handle_item(
variable=variable,
operation=item.operation,
value=item.value,
)
variable = variable.model_copy(update={"value": updated_value})
updated_variables.append(variable)
except VariableOperatorNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
error=str(e),
)
# Update variables
for variable in updated_variables:
self.graph_runtime_state.variable_pool.add(variable.selector, variable)
process_data[variable.name] = variable.value
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
common_helpers.update_conversation_variable(
conversation_id=conversation_id,
variable=variable,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
)
def _handle_item(
self,
*,
variable: Variable,
operation: Operation,
value: Any,
):
match operation:
case Operation.OVER_WRITE:
return value
case Operation.CLEAR:
return EMPTY_VALUE_MAPPING[variable.value_type]
case Operation.APPEND:
return variable.value + [value]
case Operation.EXTEND:
return variable.value + value
case Operation.SET:
return value
case Operation.ADD:
return variable.value + value
case Operation.SUBTRACT:
return variable.value - value
case Operation.MULTIPLY:
return variable.value * value
case Operation.DIVIDE:
return variable.value / value
case _:
raise OperationNotSupportedError(operation=operation, varialbe_type=variable.value_type)

View File

@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
@ -19,7 +19,7 @@ from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from factories import file_factory
from models.enums import UserFrom
from models.workflow import (
@ -145,11 +145,8 @@ class WorkflowEntry:
# Get node class
node_type = NodeType(node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping.get(node_type)
node_cls = cast(type[BaseNode], node_cls)
if not node_cls:
raise ValueError(f"Node class not found for node type {node_type}")
node_version = node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(environment_variables=workflow.environment_variables)