Compare commits

..

25 Commits

Author SHA1 Message Date
d34c95bf8e feat: convert components to dynamic imports for improved performance 2025-07-17 15:32:42 +08:00
1df1ffa2ec fix(apps): add translation and document title for Apps component 2025-07-17 14:08:00 +08:00
947dbd8854 chore(apps): move app list components to components folder 2025-07-17 14:00:29 +08:00
ce619287b3 feat(app-publisher): add relative time formatting for timestamps 2025-07-16 18:34:03 +08:00
cd1ec65286 feat: convert components to dynamic imports for improved performance 2025-07-16 16:51:55 +08:00
qfl
bdb9f29948 feat(app): support custom max_active_requests per app (#22073) 2025-07-16 15:31:19 +08:00
66cc1b4308 feat(variable-list): add drag-and-drop functionality for variables in code node (#22127) 2025-07-16 15:24:19 +08:00
d52fb18457 feat: auto-fill MCP server description with app description #22443 (#22477) 2025-07-16 15:03:33 +08:00
4a2169bd5f Chore/update gh template (#22480) 2025-07-16 14:22:51 +08:00
2c9ee54a16 fix aliyun trace session_id (#22468) 2025-07-16 13:56:44 +08:00
aef67ed7ec fix: add background color for chat bubble in light and dark themes (#22472) 2025-07-16 13:36:51 +08:00
ddfd8c8525 feat(api): add UUIDv7 implementation in SQL and Python (#22058)
This PR introduces UUIDv7 implementations in both Python and SQL to establish the foundation for migrating from UUIDv4 to UUIDv7 as proposed in #19754.

ID generation algorithm of existing models are not changed, and new models should use UUIDv7 for ID generation.

Close #19754.
2025-07-16 13:07:08 +08:00
2c1ab4879f refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)
refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)

This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization.

Key changes:
- Introduce distinct `value_type` tags for Integer and Float segments/variables
- Add `VariableUnion` and `SegmentUnion` types for proper type discrimination
- Leverage Pydantic's discriminated union feature for seamless serialization/deserialization
- Enable accurate serialization of data structures containing these types

Closes #22024.
2025-07-16 12:31:37 +08:00
229b4d621e Improve Tooltip UX by enabling delay by default (#21383) 2025-07-16 11:26:54 +08:00
0dee41c074 fix: When var value changed, PromptEditor should be reset (#22219) 2025-07-16 11:22:54 +08:00
bf542233a9 minor fix: using Pydantic model_validate instead of deprecated parse_obj (#22239)
Signed-off-by: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-07-16 10:57:08 +08:00
38106074b4 test: add comprehensive unit tests for console authentication and authorization decorators (#22439) 2025-07-16 10:07:01 +08:00
znn
1f4b3591ae adding tooltip for bindingCount (#22450)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2025-07-16 09:59:42 +08:00
7bf3d2c8bf fix(api): Fix potential thread leak in MCP BaseSession (#22169)
The `BaseSession` class in the `core/mcp/session` package uses `ThreadPoolExecutor` 
to run the receive loop but fails to properly clean up the executor and receiver 
future, leading to potential thread leaks.

This PR addresses this issue by:
- Initializing `_executor` and `_receiver_future` attributes to `None` for proper cleanup checks
- Adding graceful shutdown with a 5-second timeout in the `__exit__` method
- Ensuring the ThreadPoolExecutor is properly shut down to prevent resource leaks

This fix prevents memory leaks and hanging threads in long-running scenarios where 
multiple MCP sessions are created and destroyed.

Signed-off-by: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-16 00:01:44 +08:00
da53bf511f chore: add SQLALCHEMY_POOL_USE_LIFO option and missing SQLALCHEMY_POOL_PRE_PING env default value. (#22371) 2025-07-15 19:46:48 +08:00
7388fd1ec6 fix: Disable question editing in chat history (#22438) 2025-07-15 19:41:51 +08:00
b803eeb528 fix: Update condition items to support variable type acquisition (#22414) 2025-07-15 19:38:13 +08:00
14f79ee652 fix: create api workflow run repository error (#22422) 2025-07-15 16:12:02 +08:00
df89629e04 fix: conversatino statistic including data from debugger (#22412)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-07-15 15:45:45 +08:00
d427088ab5 fix: remove PickerPanel padding (#22419) 2025-07-15 15:37:13 +08:00
188 changed files with 4392 additions and 3987 deletions

View File

@ -8,13 +8,13 @@ body:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true
- label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue否则会被关闭。谢谢:)"
- label: I confirm that I am using English to submit this report, otherwise it will be closed.
required: true
- label: "Please do not modify this template :) and fill in all the required fields."
required: true
@ -42,20 +42,22 @@ body:
attributes:
label: Steps to reproduce
description: We highly suggest including screenshots and a bug report log. Please use the right markdown syntax for code blocks.
placeholder: Having detailed steps helps us reproduce the bug.
placeholder: Having detailed steps helps us reproduce the bug. If you have logs, please use fenced code blocks (triple backticks ```) to format them.
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: What were you expecting?
description: Describe what you expected to happen.
placeholder: What were you expecting? Please do not copy and paste the steps to reproduce here.
validations:
required: false
required: true
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: What happened instead?
description: Describe what actually happened.
placeholder: What happened instead? Please do not copy and paste the steps to reproduce here.
validations:
required: false

View File

@ -1,5 +1,11 @@
blank_issues_enabled: false
contact_links:
- name: "\U0001F4A1 Model Providers & Plugins"
url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose"
about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details.
- name: "\U0001F4AC Documentation Issues"
url: "https://github.com/langgenius/dify-docs/issues/new"
about: Report issues with the documentation, such as typos, outdated information, or missing content. Please provide the specific section and details of the issue.
- name: "\U0001F4E7 Discussions"
url: https://github.com/langgenius/dify/discussions/categories/general
about: General discussions and request help from the community
about: General discussions and seek help from the community

View File

@ -1,24 +0,0 @@
name: "📚 Documentation Issue"
description: Report issues in our documentation
labels:
- documentation
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue否则会被关闭。谢谢:)"
required: true
- label: "Please do not modify this template :) and fill in all the required fields."
required: true
- type: textarea
attributes:
label: Provide a description of requested docs changes
placeholder: Briefly describe which document needs to be corrected and why.
validations:
required: true

View File

@ -8,11 +8,11 @@ body:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue否则会被关闭。谢谢:)"
- label: I confirm that I am using English to submit this report, otherwise it will be closed.
required: true
- label: "Please do not modify this template :) and fill in all the required fields."
required: true

View File

@ -1,55 +0,0 @@
name: "🌐 Localization/Translation issue"
description: Report incorrect translations. [please use English :)]
labels:
- translation
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue否则会被关闭。谢谢:)"
required: true
- label: "Please do not modify this template :) and fill in all the required fields."
required: true
- type: input
attributes:
label: Dify version
description: Hover over system tray icon or look at Settings
validations:
required: true
- type: input
attributes:
label: Utility with translation issue
placeholder: Some area
description: Please input here the utility with the translation issue
validations:
required: true
- type: input
attributes:
label: 🌐 Language affected
placeholder: "German"
validations:
required: true
- type: textarea
attributes:
label: ❌ Actual phrase(s)
placeholder: What is there? Please include a screenshot as that is extremely helpful.
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected phrase(s)
placeholder: What was expected?
validations:
required: true
- type: textarea
attributes:
label: Why is the current translation wrong
placeholder: Why do you feel this is incorrect?
validations:
required: true

View File

@ -162,6 +162,11 @@ class DatabaseConfig(BaseSettings):
default=3600,
)
SQLALCHEMY_POOL_USE_LIFO: bool = Field(
description="If True, SQLAlchemy will use last-in-first-out way to retrieve connections from pool.",
default=False,
)
SQLALCHEMY_POOL_PRE_PING: bool = Field(
description="If True, enables connection pool pre-ping feature to check connections.",
default=False,
@ -199,6 +204,7 @@ class DatabaseConfig(BaseSettings):
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": connect_args,
"pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
}

View File

@ -151,6 +151,7 @@ class AppApi(Resource):
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
parser.add_argument("max_active_requests", type=int, location="json")
args = parser.parse_args()
app_service = AppService()

View File

@ -35,16 +35,20 @@ class AppMCPServerController(Resource):
@get_app_model
@marshal_with(app_server_fields)
def post(self, app_model):
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor:
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("description", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
args = parser.parse_args()
description = args.get("description")
if not description:
description = app_model.description or ""
server = AppMCPServer(
name=app_model.name,
description=args["description"],
description=description,
parameters=json.dumps(args["parameters"], ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
@ -65,14 +69,22 @@ class AppMCPServerController(Resource):
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=False, location="json")
args = parser.parse_args()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
if not server:
raise NotFound()
server.description = args["description"]
description = args.get("description")
if description is None:
pass
elif not description:
server.description = app_model.description or ""
else:
server.description = description
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
if args["status"]:
if args["status"] not in [status.value for status in AppMCPServerStatus]:

View File

@ -2,6 +2,7 @@ from datetime import datetime
from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restful import Resource, reqparse
@ -9,10 +10,11 @@ from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
from models.model import AppMode
from models import AppMode, Message
class DailyMessageStatistic(Resource):
@ -85,46 +87,41 @@ class DailyConversationStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT messages.conversation_id) AS conversation_count
FROM
messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
stmt = (
sa.select(
sa.func.date(
sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
).label("date"),
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
)
.select_from(Message)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
)
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
stmt = stmt.where(Message.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
stmt = stmt.where(Message.created_at < end_datetime_utc)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
stmt = stmt.group_by("date").order_by("date")
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
rs = conn.execute(stmt, {"tz": account.timezone})
for row in rs:
response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
return jsonify({"data": response_data})

View File

@ -68,13 +68,18 @@ def _create_pagination_parser():
return parser
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
return value_type.exposed_type().value
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda model: model.get_variable_type()),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"value_type": fields.String(attribute=_serialize_variable_type),
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
@ -90,7 +95,7 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"value_type": fields.String(attribute=_serialize_variable_type),
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
@ -396,7 +401,7 @@ class EnvironmentVariableCollectionApi(Resource):
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.value,
"value_type": v.value_type.exposed_type().value,
"value": v.value,
# Do not track edited for env vars.
"edited": False,

View File

@ -16,9 +16,10 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.moderation.base import ModerationError
from core.variables.variables import VariableUnion
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@ -64,7 +65,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not workflow:
raise ValueError("Workflow not initialized")
user_id = None
user_id: str | None = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
@ -136,23 +137,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
session.commit()
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
SystemVariableKey.APP_ID: app_config.app_id,
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id,
}
system_inputs = SystemVariable(
query=query,
files=files,
conversation_id=self.conversation.id,
user_id=user_id,
dialogue_count=self._dialogue_count,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_run_id,
)
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables),
)
# init graph

View File

@ -61,12 +61,12 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
@ -116,16 +116,16 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id,
},
workflow_system_variables=SystemVariable(
query=message.query,
files=application_generate_entity.files,
conversation_id=conversation.id,
user_id=user_session_id,
dialogue_count=dialogue_count,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_run_id,
),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),

View File

@ -11,7 +11,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@ -95,13 +95,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.APP_ID: app_config.app_id,
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
}
system_inputs = SystemVariable(
files=files,
user_id=user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
variable_pool = VariablePool(
system_variables=system_inputs,

View File

@ -54,10 +54,10 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
@ -107,13 +107,13 @@ class WorkflowAppGenerateTaskPipeline:
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id,
},
workflow_system_variables=SystemVariable(
files=application_generate_entity.files,
user_id=user_session_id,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_execution_id,
),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),

View File

@ -62,6 +62,7 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@ -166,7 +167,7 @@ class WorkflowBasedAppRunner(AppRunner):
# init variable pool
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
@ -263,7 +264,7 @@ class WorkflowBasedAppRunner(AppRunner):
# init variable pool
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)

View File

@ -240,7 +240,7 @@ def refresh_authorization(
response = requests.post(token_url, data=params)
if not response.ok:
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
return OAuthTokens.parse_obj(response.json())
return OAuthTokens.model_validate(response.json())
def register_client(

View File

@ -1,7 +1,7 @@
import logging
import queue
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from contextlib import ExitStack
from datetime import timedelta
from types import TracebackType
@ -171,23 +171,41 @@ class BaseSession(
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._exit_stack = ExitStack()
# Initialize executor and future to None for proper cleanup checks
self._executor: ThreadPoolExecutor | None = None
self._receiver_future: Future | None = None
def __enter__(self) -> Self:
self._executor = ThreadPoolExecutor()
# The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1
# ensures no unnecessary threads are created.
self._executor = ThreadPoolExecutor(max_workers=1)
self._receiver_future = self._executor.submit(self._receive_loop)
return self
def check_receiver_status(self) -> None:
if self._receiver_future.done():
"""`check_receiver_status` ensures that any exceptions raised during the
execution of `_receive_loop` are retrieved and propagated."""
if self._receiver_future and self._receiver_future.done():
self._receiver_future.result()
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
self._exit_stack.close()
self._read_stream.put(None)
self._write_stream.put(None)
# Wait for the receiver loop to finish
if self._receiver_future:
try:
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
except TimeoutError:
# If the receiver loop is still running after timeout, we'll force shutdown
pass
# Shutdown the executor
if self._executor:
self._executor.shutdown(wait=True)
def send_request(
self,
request: SendRequestT,

View File

@ -284,7 +284,8 @@ class AliyunDataTrace(BaseTraceInstance):
else:
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
return node_span
except Exception:
except Exception as e:
logging.debug(f"Error occurred in build_workflow_node_span: {e}", exc_info=True)
return None
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
@ -306,7 +307,7 @@ class AliyunDataTrace(BaseTraceInstance):
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
@ -381,7 +382,7 @@ class AliyunDataTrace(BaseTraceInstance):
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
@ -415,7 +416,7 @@ class AliyunDataTrace(BaseTraceInstance):
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",

View File

@ -158,7 +158,7 @@ class AdvancedPromptTransform(PromptTransform):
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
if self.with_variable_tmpl:
vp = VariablePool()
vp = VariablePool.empty()
for k, v in inputs.items():
if k.startswith("#"):
vp.add(k[1:-1].split("."), v)

View File

@ -1,9 +1,9 @@
import json
import sys
from collections.abc import Mapping, Sequence
from typing import Any
from typing import Annotated, Any, TypeAlias
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.file import File
@ -11,6 +11,11 @@ from .types import SegmentType
class Segment(BaseModel):
"""Segment is runtime type used during the execution of workflow.
Note: this class is abstract, you should use subclasses of this class instead.
"""
model_config = ConfigDict(frozen=True)
value_type: SegmentType
@ -73,7 +78,7 @@ class StringSegment(Segment):
class FloatSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER
value_type: SegmentType = SegmentType.FLOAT
value: float
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass.
@ -92,7 +97,7 @@ class FloatSegment(Segment):
class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER
value_type: SegmentType = SegmentType.INTEGER
value: int
@ -181,3 +186,46 @@ class ArrayFileSegment(ArraySegment):
@property
def text(self) -> str:
return ""
def get_segment_discriminator(v: Any) -> SegmentType | None:
if isinstance(v, Segment):
return v.value_type
elif isinstance(v, dict):
value_type = v.get("value_type")
if value_type is None:
return None
try:
seg_type = SegmentType(value_type)
except ValueError:
return None
return seg_type
else:
# return None if the discriminator value isn't found
return None
# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic.
# Use `Segment` for type hinting when serialization is not required.
#
# Note:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool.
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]
| Annotated[StringSegment, Tag(SegmentType.STRING)]
| Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
| Annotated[FileSegment, Tag(SegmentType.FILE)]
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
),
Discriminator(get_segment_discriminator),
]

View File

@ -1,8 +1,27 @@
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Optional
from core.file.models import File
class ArrayValidation(StrEnum):
"""Strategy for validating array elements"""
# Skip element validation (only check array container)
NONE = "none"
# Validate the first element (if array is non-empty)
FIRST = "first"
# Validate all elements in the array.
ALL = "all"
class SegmentType(StrEnum):
NUMBER = "number"
INTEGER = "integer"
FLOAT = "float"
STRING = "string"
OBJECT = "object"
SECRET = "secret"
@ -19,16 +38,141 @@ class SegmentType(StrEnum):
GROUP = "group"
def is_array_type(self):
def is_array_type(self) -> bool:
return self in _ARRAY_TYPES
@classmethod
def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
"""
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
Returns `None` if no appropriate `SegmentType` can be determined for the given `value`.
For example, this may occur if the input is a generic Python object of type `object`.
"""
if isinstance(value, list):
elem_types: set[SegmentType] = set()
for i in value:
segment_type = cls.infer_segment_type(i)
if segment_type is None:
return None
elem_types.add(segment_type)
if len(elem_types) != 1:
if elem_types.issubset(_NUMERICAL_TYPES):
return SegmentType.ARRAY_NUMBER
return SegmentType.ARRAY_ANY
elif all(i.is_array_type() for i in elem_types):
return SegmentType.ARRAY_ANY
match elem_types.pop():
case SegmentType.STRING:
return SegmentType.ARRAY_STRING
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return SegmentType.ARRAY_NUMBER
case SegmentType.OBJECT:
return SegmentType.ARRAY_OBJECT
case SegmentType.FILE:
return SegmentType.ARRAY_FILE
case SegmentType.NONE:
return SegmentType.ARRAY_ANY
case _:
# This should be unreachable.
raise ValueError(f"not supported value {value}")
if value is None:
return SegmentType.NONE
elif isinstance(value, int) and not isinstance(value, bool):
return SegmentType.INTEGER
elif isinstance(value, float):
return SegmentType.FLOAT
elif isinstance(value, str):
return SegmentType.STRING
elif isinstance(value, dict):
return SegmentType.OBJECT
elif isinstance(value, File):
return SegmentType.FILE
elif isinstance(value, str):
return SegmentType.STRING
else:
return None
def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool:
if not isinstance(value, list):
return False
# Skip element validation if array is empty
if len(value) == 0:
return True
if self == SegmentType.ARRAY_ANY:
return True
element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self]
if array_validation == ArrayValidation.NONE:
return True
elif array_validation == ArrayValidation.FIRST:
return element_type.is_valid(value[0])
else:
return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value)
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
"""
Check if a value matches the segment type.
Users of `SegmentType` should call this method, instead of using
`isinstance` manually.
Args:
value: The value to validate
array_validation: Validation strategy for array types (ignored for non-array types)
Returns:
True if the value matches the type under the given validation strategy
"""
if self.is_array_type():
return self._validate_array(value, array_validation)
elif self == SegmentType.NUMBER:
return isinstance(value, (int, float))
elif self == SegmentType.STRING:
return isinstance(value, str)
elif self == SegmentType.OBJECT:
return isinstance(value, dict)
elif self == SegmentType.SECRET:
return isinstance(value, str)
elif self == SegmentType.FILE:
return isinstance(value, File)
elif self == SegmentType.NONE:
return value is None
else:
raise AssertionError("this statement should be unreachable.")
def exposed_type(self) -> "SegmentType":
"""Returns the type exposed to the frontend.
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
"""
if self in (SegmentType.INTEGER, SegmentType.FLOAT):
return SegmentType.NUMBER
return self
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
# ARRAY_ANY does not have correpond element type.
SegmentType.ARRAY_STRING: SegmentType.STRING,
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
SegmentType.ARRAY_FILE: SegmentType.FILE,
}
_ARRAY_TYPES = frozenset(
[
list(_ARRAY_ELEMENT_TYPES_MAPPING.keys())
+ [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
]
)
_NUMERICAL_TYPES = frozenset(
[
SegmentType.NUMBER,
SegmentType.INTEGER,
SegmentType.FLOAT,
]
)

View File

@ -1,8 +1,8 @@
from collections.abc import Sequence
from typing import cast
from typing import Annotated, TypeAlias, cast
from uuid import uuid4
from pydantic import Field
from pydantic import Discriminator, Field, Tag
from core.helper import encrypter
@ -20,6 +20,7 @@ from .segments import (
ObjectSegment,
Segment,
StringSegment,
get_segment_discriminator,
)
from .types import SegmentType
@ -27,6 +28,10 @@ from .types import SegmentType
class Variable(Segment):
"""
A variable is a segment that has a name.
It is mainly used to store segments and their selector in VariablePool.
Note: this class is abstract, you should use subclasses of this class instead.
"""
id: str = Field(
@ -93,3 +98,28 @@ class FileVariable(FileSegment, Variable):
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
# Use `Variable` for type hinting when serialization is not required.
#
# Note:
# - All variants in `VariableUnion` must inherit from the `Variable` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
VariableUnion: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]
| Annotated[FloatVariable, Tag(SegmentType.FLOAT)]
| Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
| Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
| Annotated[FileVariable, Tag(SegmentType.FILE)]
| Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
),
Discriminator(get_segment_discriminator),
]

View File

@ -1,7 +1,7 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey
from core.workflow.system_variable import SystemVariable
from factories import variable_factory
VariableValue = Union[str, int, float, dict, list, File]
@ -23,31 +24,31 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: dict[str, dict[int, Segment]] = Field(
variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
# TODO: This user inputs is not used for pool.
# The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
user_inputs: Mapping[str, Any] = Field(
description="User inputs",
default_factory=dict,
)
system_variables: Mapping[SystemVariableKey, Any] = Field(
system_variables: SystemVariable = Field(
description="System variables",
default_factory=dict,
)
environment_variables: Sequence[Variable] = Field(
environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.",
default_factory=list,
)
conversation_variables: Sequence[Variable] = Field(
conversation_variables: Sequence[VariableUnion] = Field(
description="Conversation variables.",
default_factory=list,
)
def model_post_init(self, context: Any, /) -> None:
for key, value in self.system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Create a mapping from field names to SystemVariableKey enum values
self._add_system_variables(self.system_variables)
# Add environment variables to the variable pool
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
@ -83,8 +84,22 @@ class VariablePool(BaseModel):
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = variable
key, hash_key = self._selector_to_keys(selector)
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
return selector[0], hash(tuple(selector[1:]))
def _has(self, selector: Sequence[str]) -> bool:
key, hash_key = self._selector_to_keys(selector)
if key not in self.variable_dictionary:
return False
if hash_key not in self.variable_dictionary[key]:
return False
return True
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
@ -102,8 +117,8 @@ class VariablePool(BaseModel):
if len(selector) < MIN_SELECTORS_LENGTH:
return None
hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key)
key, hash_key = self._selector_to_keys(selector)
value: Segment | None = self.variable_dictionary[key].get(hash_key)
if value is None:
selector, attr = selector[:-1], selector[-1]
@ -136,8 +151,9 @@ class VariablePool(BaseModel):
if len(selector) == 1:
self.variable_dictionary[selector[0]] = {}
return
key, hash_key = self._selector_to_keys(selector)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]].pop(hash_key, None)
self.variable_dictionary[key].pop(hash_key, None)
def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template)
@ -154,3 +170,20 @@ class VariablePool(BaseModel):
if isinstance(segment, FileSegment):
return segment
return None
def _add_system_variables(self, system_variable: SystemVariable):
sys_var_mapping = system_variable.to_dict()
for key, value in sys_var_mapping.items():
if value is None:
continue
selector = (SYSTEM_VARIABLE_NODE_ID, key)
# If the system variable already exists, do not add it again.
# This ensures that we can keep the id of the system variables intact.
if self._has(selector):
continue
self.add(selector, value) # type: ignore
@classmethod
def empty(cls) -> "VariablePool":
"""Create an empty variable pool."""
return cls(system_variables=SystemVariable.empty())

View File

@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel):
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
"""llm usage info"""
# The `outputs` field stores the final output values generated by executing workflows or chatflows.
#
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
# after a serialization and deserialization round trip.
outputs: dict[str, Any] = {}
"""outputs"""
node_run_steps: int = 0
"""node run steps"""

View File

@ -1,11 +1,29 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional
from typing import Annotated, Any, Literal, Optional
from pydantic import BaseModel, Field
from pydantic import AfterValidator, BaseModel, Field
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
_VALID_VAR_TYPE = frozenset(
[
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
]
)
def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
if seg_type not in _VALID_VAR_TYPE:
raise ValueError(...)
return seg_type
class LoopVariableData(BaseModel):
"""
@ -13,7 +31,7 @@ class LoopVariableData(BaseModel):
"""
label: str
var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Optional[Any | list[str]] = None

View File

@ -7,14 +7,9 @@ from typing import TYPE_CHECKING, Any, Literal, cast
from configs import dify_config
from core.variables import (
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
IntegerSegment,
ObjectSegment,
Segment,
SegmentType,
StringSegment,
)
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -39,6 +34,7 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
@ -505,23 +501,21 @@ class LoopNode(BaseNode[LoopNodeData]):
return variable_mapping
@staticmethod
def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
"string": (StringSegment, SegmentType.STRING),
"number": (IntegerSegment, SegmentType.NUMBER),
"object": (ObjectSegment, SegmentType.OBJECT),
"array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
"array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
"array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
}
if var_type in ["array[string]", "array[number]", "array[object]"]:
if value:
if value and isinstance(value, str):
value = json.loads(value)
else:
value = []
segment_info = segment_mapping.get(var_type)
if not segment_info:
raise ValueError(f"Invalid variable type: {var_type}")
segment_class, value_type = segment_info
return segment_class(value=value, value_type=value_type)
try:
return build_segment_with_type(var_type, value)
except TypeMismatchError as type_exc:
# Attempt to parse the value as a JSON-encoded string, if applicable.
if not isinstance(value, str):
raise
try:
value = json.loads(value)
except ValueError:
raise type_exc
return build_segment_with_type(var_type, value)

View File

@ -16,7 +16,7 @@ class StartNode(BaseNode[StartNodeData]):
def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs.

View File

@ -130,6 +130,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
def get_zero_value(t: SegmentType):
# TODO(QuantumGhost): this should be a method of `SegmentType`.
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return variable_factory.build_segment([])
@ -137,6 +138,10 @@ def get_zero_value(t: SegmentType):
return variable_factory.build_segment({})
case SegmentType.STRING:
return variable_factory.build_segment("")
case SegmentType.INTEGER:
return variable_factory.build_segment(0)
case SegmentType.FLOAT:
return variable_factory.build_segment(0.0)
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case _:

View File

@ -1,5 +1,6 @@
from core.variables import SegmentType
# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy.
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,

View File

@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
case Operation.OVER_WRITE | Operation.CLEAR:
return True
case Operation.SET:
return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
return variable_type in {
SegmentType.OBJECT,
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.INTEGER,
SegmentType.FLOAT,
}
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
# Only number variable can be added, subtracted, multiplied or divided
return variable_type == SegmentType.NUMBER
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
case Operation.APPEND | Operation.EXTEND:
# Only array variable can be appended or extended
return variable_type in {
@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat
match variable_type:
case SegmentType.STRING | SegmentType.OBJECT:
return operation in {Operation.OVER_WRITE, Operation.SET}
case SegmentType.NUMBER:
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return operation in {
Operation.OVER_WRITE,
Operation.SET,
@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
case SegmentType.STRING:
return isinstance(value, str)
case SegmentType.NUMBER:
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
if not isinstance(value, int | float):
return False
if operation == Operation.DIVIDE and value == 0:

View File

@ -0,0 +1,89 @@
from collections.abc import Sequence
from typing import Any
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
from core.file.models import File
from core.workflow.enums import SystemVariableKey
class SystemVariable(BaseModel):
"""A model for managing system variables.
Fields with a value of `None` are treated as absent and will not be included
in the variable pool.
"""
model_config = ConfigDict(
extra="forbid",
serialize_by_alias=True,
validate_by_alias=True,
)
user_id: str | None = None
# Ideally, `app_id` and `workflow_id` should be required and not `None`.
# However, there are scenarios in the codebase where these fields are not set.
# To maintain compatibility, they are marked as optional here.
app_id: str | None = None
workflow_id: str | None = None
files: Sequence[File] = Field(default_factory=list)
# NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`.
# To maintain compatibility with existing workflows, it must be serialized
# as `workflow_run_id` in dictionaries or JSON objects, and also referenced
# as `workflow_run_id` in the variable pool.
workflow_execution_id: str | None = Field(
validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"),
serialization_alias="workflow_run_id",
default=None,
)
# Chatflow related fields.
query: str | None = None
conversation_id: str | None = None
dialogue_count: int | None = None
@model_validator(mode="before")
@classmethod
def validate_json_fields(cls, data):
if isinstance(data, dict):
# For JSON validation, only allow workflow_run_id
if "workflow_execution_id" in data and "workflow_run_id" not in data:
# This is likely from direct instantiation, allow it
return data
elif "workflow_execution_id" in data and "workflow_run_id" in data:
# Both present, remove workflow_execution_id
data = data.copy()
data.pop("workflow_execution_id")
return data
return data
@classmethod
def empty(cls) -> "SystemVariable":
return cls()
def to_dict(self) -> dict[SystemVariableKey, Any]:
# NOTE: This method is provided for compatibility with legacy code.
# New code should use the `SystemVariable` object directly instead of converting
# it to a dictionary, as this conversion results in the loss of type information
# for each key, making static analysis more difficult.
d: dict[SystemVariableKey, Any] = {
SystemVariableKey.FILES: self.files,
}
if self.user_id is not None:
d[SystemVariableKey.USER_ID] = self.user_id
if self.app_id is not None:
d[SystemVariableKey.APP_ID] = self.app_id
if self.workflow_id is not None:
d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id
if self.workflow_execution_id is not None:
d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id
if self.query is not None:
d[SystemVariableKey.QUERY] = self.query
if self.conversation_id is not None:
d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id
if self.dialogue_count is not None:
d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count
return d

View File

@ -26,6 +26,7 @@ from core.workflow.entities.workflow_node_execution import (
from core.workflow.enums import SystemVariableKey
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
@ -43,7 +44,7 @@ class WorkflowCycleManager:
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
workflow_system_variables: SystemVariable,
workflow_info: CycleManagerWorkflowInfo,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
@ -56,17 +57,22 @@ class WorkflowCycleManager:
def handle_workflow_run_start(self) -> WorkflowExecution:
inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items():
if key.value == "conversation":
continue
inputs[f"sys.{key.value}"] = value
# Iterate over SystemVariable fields using Pydantic's model_fields
if self._workflow_system_variables:
for field_name, value in self._workflow_system_variables.to_dict().items():
if field_name == SystemVariableKey.CONVERSATION_ID:
continue
inputs[f"sys.{field_name}"] = value
# handle special values
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4())
execution_id = str(
self._workflow_system_variables.workflow_execution_id if self._workflow_system_variables else None
) or str(uuid4())
execution = WorkflowExecution.new(
id_=execution_id,
workflow_id=self._workflow_info.workflow_id,

View File

@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from factories import file_factory
from models.enums import UserFrom
@ -254,7 +255,7 @@ class WorkflowEntry:
# init variable pool
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=[],
)

View File

@ -91,9 +91,13 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
result = StringVariable.model_validate(mapping)
case SegmentType.SECRET:
result = SecretVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, int):
case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int):
mapping = dict(mapping)
mapping["value_type"] = SegmentType.INTEGER
result = IntegerVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, float):
case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float):
mapping = dict(mapping)
mapping["value_type"] = SegmentType.FLOAT
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}")
@ -119,6 +123,8 @@ def infer_segment_type_from_value(value: Any, /) -> SegmentType:
def build_segment(value: Any, /) -> Segment:
# NOTE: If you have runtime type information available, consider using the `build_segment_with_type`
# below
if value is None:
return NoneSegment()
if isinstance(value, str):
@ -134,12 +140,17 @@ def build_segment(value: Any, /) -> Segment:
if isinstance(value, list):
items = [build_segment(item) for item in value]
types = {item.value_type for item in items}
if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items):
if all(isinstance(item, ArraySegment) for item in items):
return ArrayAnySegment(value=value)
elif len(types) != 1:
if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}):
return ArrayNumberSegment(value=value)
return ArrayAnySegment(value=value)
match types.pop():
case SegmentType.STRING:
return ArrayStringSegment(value=value)
case SegmentType.NUMBER:
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return ArrayNumberSegment(value=value)
case SegmentType.OBJECT:
return ArrayObjectSegment(value=value)
@ -153,6 +164,22 @@ def build_segment(value: Any, /) -> Segment:
raise ValueError(f"not supported value {value}")
_segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.NONE: NoneSegment,
SegmentType.STRING: StringSegment,
SegmentType.INTEGER: IntegerSegment,
SegmentType.FLOAT: FloatSegment,
SegmentType.FILE: FileSegment,
SegmentType.OBJECT: ObjectSegment,
# Array types
SegmentType.ARRAY_ANY: ArrayAnySegment,
SegmentType.ARRAY_STRING: ArrayStringSegment,
SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
SegmentType.ARRAY_FILE: ArrayFileSegment,
}
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
"""
Build a segment with explicit type checking.
@ -190,7 +217,7 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
if segment_type == SegmentType.NONE:
return NoneSegment()
else:
raise TypeMismatchError(f"Expected {segment_type}, but got None")
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None")
# Handle empty list special case for array types
if isinstance(value, list) and len(value) == 0:
@ -205,21 +232,25 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
elif segment_type == SegmentType.ARRAY_FILE:
return ArrayFileSegment(value=value)
else:
raise TypeMismatchError(f"Expected {segment_type}, but got empty list")
# Build segment using existing logic to infer actual type
inferred_segment = build_segment(value)
inferred_type = inferred_segment.value_type
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
inferred_type = SegmentType.infer_segment_type(value)
# Type compatibility checking
if inferred_type is None:
raise TypeMismatchError(
f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}"
)
if inferred_type == segment_type:
return inferred_segment
# Type mismatch - raise error with descriptive message
raise TypeMismatchError(
f"Type mismatch: expected {segment_type}, but value '{value}' "
f"(type: {type(value).__name__}) corresponds to {inferred_type}"
)
segment_class = _segment_factory[segment_type]
return segment_class(value_type=segment_type, value=value)
elif segment_type == SegmentType.NUMBER and inferred_type in (
SegmentType.INTEGER,
SegmentType.FLOAT,
):
segment_class = _segment_factory[inferred_type]
return segment_class(value_type=inferred_type, value=value)
else:
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
def segment_to_variable(
@ -247,6 +278,6 @@ def segment_to_variable(
name=name,
description=description,
value=segment.value,
selector=selector,
selector=list(selector),
),
)

View File

@ -0,0 +1,15 @@
from typing import TypedDict
from core.variables.segments import Segment
from core.variables.types import SegmentType
class _VarTypedDict(TypedDict, total=False):
value_type: SegmentType
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment):
return v.value_type.exposed_type().value
else:
return v["value_type"].exposed_type().value

View File

@ -188,6 +188,7 @@ app_detail_fields_with_site = {
"site": fields.Nested(site_fields),
"api_base_url": fields.String,
"use_icon_as_answer_icon": fields.Boolean,
"max_active_requests": fields.Integer,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,

View File

@ -2,10 +2,12 @@ from flask_restful import fields
from libs.helper import TimestampField
from ._value_type_serializer import serialize_value_type
conversation_variable_fields = {
"id": fields.String,
"name": fields.String,
"value_type": fields.String(attribute="value_type.value"),
"value_type": fields.String(attribute=serialize_value_type),
"value": fields.String,
"description": fields.String,
"created_at": TimestampField,

View File

@ -5,6 +5,8 @@ from core.variables import SecretVariable, SegmentType, Variable
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
from ._value_type_serializer import serialize_value_type
ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET)
@ -24,11 +26,16 @@ class EnvironmentVariableField(fields.Raw):
"id": value.id,
"name": value.name,
"value": value.value,
"value_type": value.value_type.value,
"value_type": value.value_type.exposed_type().value,
"description": value.description,
}
if isinstance(value, dict):
value_type = value.get("value_type")
value_type_str = value.get("value_type")
if not isinstance(value_type_str, str):
raise TypeError(
f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}"
)
value_type = SegmentType(value_type_str).exposed_type()
if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES:
raise ValueError(f"Unsupported environment variable value type: {value_type}")
return value
@ -37,7 +44,7 @@ class EnvironmentVariableField(fields.Raw):
conversation_variable_fields = {
"id": fields.String,
"name": fields.String,
"value_type": fields.String(attribute="value_type.value"),
"value_type": fields.String(attribute=serialize_value_type),
"value": fields.Raw,
"description": fields.String,
}

164
api/libs/uuid_utils.py Normal file
View File

@ -0,0 +1,164 @@
import secrets
import struct
import time
import uuid
# Reference for UUIDv7 specification:
# RFC 9562, Section 5.7 - https://www.rfc-editor.org/rfc/rfc9562.html#section-5.7
# Define the format for packing the timestamp as an unsigned 64-bit integer (big-endian).
#
# For details on the `struct.pack` format, refer to:
# https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
_PACK_TIMESTAMP = ">Q"
# Define the format for packing the 12-bit random data A (as specified in RFC 9562 Section 5.7)
# into an unsigned 16-bit integer (big-endian).
_PACK_RAND_A = ">H"
def _create_uuidv7_bytes(timestamp_ms: int, random_bytes: bytes) -> bytes:
"""Create UUIDv7 byte structure with given timestamp and random bytes.
This is a private helper function that handles the common logic for creating
UUIDv7 byte structure according to RFC 9562 specification.
UUIDv7 Structure:
- 48 bits: timestamp (milliseconds since Unix epoch)
- 12 bits: random data A (with version bits)
- 62 bits: random data B (with variant bits)
The function performs the following operations:
1. Creates a 128-bit (16-byte) UUID structure
2. Packs the timestamp into the first 48 bits (6 bytes)
3. Sets the version bits to 7 (0111) in the correct position
4. Sets the variant bits to 10 (binary) in the correct position
5. Fills the remaining bits with the provided random bytes
Args:
timestamp_ms: The timestamp in milliseconds since Unix epoch (48 bits).
random_bytes: Random bytes to use for the random portions (must be 10 bytes).
First 2 bytes are used for random data A (12 bits after version).
Last 8 bytes are used for random data B (62 bits after variant).
Returns:
A 16-byte bytes object representing the complete UUIDv7 structure.
Note:
This function assumes the random_bytes parameter is exactly 10 bytes.
The caller is responsible for providing appropriate random data.
"""
# Create the 128-bit UUID structure
uuid_bytes = bytearray(16)
# Pack timestamp (48 bits) into first 6 bytes
uuid_bytes[0:6] = struct.pack(_PACK_TIMESTAMP, timestamp_ms)[2:8] # Take last 6 bytes of 8-byte big-endian
# Next 16 bits: random data A (12 bits) + version (4 bits)
# Take first 2 random bytes and set version to 7
rand_a = struct.unpack(_PACK_RAND_A, random_bytes[0:2])[0]
# Clear the highest 4 bits to make room for the version field
# by performing a bitwise AND with 0x0FFF (binary: 0b0000_1111_1111_1111).
rand_a = rand_a & 0x0FFF
# Set the version field to 7 (binary: 0111) by performing a bitwise OR with 0x7000 (binary: 0b0111_0000_0000_0000).
rand_a = rand_a | 0x7000
uuid_bytes[6:8] = struct.pack(_PACK_RAND_A, rand_a)
# Last 64 bits: random data B (62 bits) + variant (2 bits)
# Use remaining 8 random bytes and set variant to 10 (binary)
uuid_bytes[8:16] = random_bytes[2:10]
# Set variant bits (first 2 bits of byte 8 should be '10')
uuid_bytes[8] = (uuid_bytes[8] & 0x3F) | 0x80 # Set variant to 10xxxxxx
return bytes(uuid_bytes)
def uuidv7(timestamp_ms: int | None = None) -> uuid.UUID:
"""Generate a UUID version 7 according to RFC 9562 specification.
UUIDv7 features a time-ordered value field derived from the widely
implemented and well known Unix Epoch timestamp source, the number of
milliseconds since midnight 1 Jan 1970 UTC, leap seconds excluded.
Structure:
- 48 bits: timestamp (milliseconds since Unix epoch)
- 12 bits: random data A (with version bits)
- 62 bits: random data B (with variant bits)
Args:
timestamp_ms: The timestamp used when generating UUID, use the current time if unspecified.
Should be an integer representing milliseconds since Unix epoch.
Returns:
A UUID object representing a UUIDv7.
Example:
>>> import time
>>> # Generate UUIDv7 with current time
>>> uuid_current = uuidv7()
>>> # Generate UUIDv7 with specific timestamp
>>> uuid_specific = uuidv7(int(time.time() * 1000))
"""
if timestamp_ms is None:
timestamp_ms = int(time.time() * 1000)
# Generate 10 random bytes for the random portions
random_bytes = secrets.token_bytes(10)
# Create UUIDv7 bytes using the helper function
uuid_bytes = _create_uuidv7_bytes(timestamp_ms, random_bytes)
return uuid.UUID(bytes=uuid_bytes)
def uuidv7_timestamp(id_: uuid.UUID) -> int:
"""Extract the timestamp from a UUIDv7.
UUIDv7 contains a 48-bit timestamp field representing milliseconds since
the Unix epoch (1970-01-01 00:00:00 UTC). This function extracts and
returns that timestamp as an integer representing milliseconds since the epoch.
Args:
id_: A UUID object that should be a UUIDv7 (version 7).
Returns:
The timestamp as an integer representing milliseconds since Unix epoch.
Raises:
ValueError: If the provided UUID is not version 7.
Example:
>>> uuid_v7 = uuidv7()
>>> timestamp = uuidv7_timestamp(uuid_v7)
>>> print(f"UUID was created at: {timestamp} ms")
"""
# Verify this is a UUIDv7
if id_.version != 7:
raise ValueError(f"Expected UUIDv7 (version 7), got version {id_.version}")
# Extract the UUID bytes
uuid_bytes = id_.bytes
# Extract the first 48 bits (6 bytes) as the timestamp in milliseconds
# Pad with 2 zero bytes at the beginning to make it 8 bytes for unpacking as Q (unsigned long long)
timestamp_bytes = b"\x00\x00" + uuid_bytes[0:6]
ts_in_ms = struct.unpack(_PACK_TIMESTAMP, timestamp_bytes)[0]
# Return timestamp directly in milliseconds as integer
assert isinstance(ts_in_ms, int)
return ts_in_ms
def uuidv7_boundary(timestamp_ms: int) -> uuid.UUID:
"""Generate a non-random uuidv7 with the given timestamp (first 48 bits) and
all random bits to 0. As the smallest possible uuidv7 for that timestamp,
it may be used as a boundary for partitions.
"""
# Use zero bytes for all random portions
zero_random_bytes = b"\x00" * 10
# Create UUIDv7 bytes using the helper function
uuid_bytes = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes)
return uuid.UUID(bytes=uuid_bytes)

View File

@ -0,0 +1,86 @@
"""add uuidv7 function in SQL
Revision ID: 1c9ba48be8e4
Revises: 58eb7bdb93fe
Create Date: 2025-07-02 23:32:38.484499
"""
"""
The functions in this files comes from https://github.com/dverite/postgres-uuidv7-sql/, with minor modifications.
LICENSE:
# Copyright and License
Copyright (c) 2024, Daniel Vérité
Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies.
In no event shall Daniel Vérité be liable to any party for direct, indirect, special, incidental, or consequential damages, including lost profits, arising out of the use of this software and its documentation, even if Daniel Vérité has been advised of the possibility of such damage.
Daniel Vérité specifically disclaims any warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose. The software provided hereunder is on an "AS IS" basis, and Daniel Vérité has no obligations to provide maintenance, support, updates, enhancements, or modifications.
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '1c9ba48be8e4'
down_revision = '58eb7bdb93fe'
branch_labels: None = None
depends_on: None = None
def upgrade():
# This implementation differs slightly from the original uuidv7 function in
# https://github.com/dverite/postgres-uuidv7-sql/.
# The ability to specify source timestamp has been removed because its type signature is incompatible with
# PostgreSQL 18's `uuidv7` function. This capability is rarely needed in practice, as IDs can be
# generated and controlled within the application layer.
op.execute(sa.text(r"""
/* Main function to generate a uuidv7 value with millisecond precision */
CREATE FUNCTION uuidv7() RETURNS uuid
AS
$$
-- Replace the first 48 bits of a uuidv4 with the current
-- number of milliseconds since 1970-01-01 UTC
-- and set the "ver" field to 7 by setting additional bits
SELECT encode(
set_bit(
set_bit(
overlay(uuid_send(gen_random_uuid()) placing
substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from
3)
from 1 for 6),
52, 1),
53, 1), 'hex')::uuid;
$$ LANGUAGE SQL VOLATILE PARALLEL SAFE;
COMMENT ON FUNCTION uuidv7 IS
'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness';
"""))
op.execute(sa.text(r"""
CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid
AS
$$
/* uuid fields: version=0b0111, variant=0b10 */
SELECT encode(
overlay('\x00000000000070008000000000000000'::bytea
placing substring(int8send(floor(extract(epoch from $1) * 1000)::bigint) from 3)
from 1 for 6),
'hex')::uuid;
$$ LANGUAGE SQL STABLE STRICT PARALLEL SAFE;
COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS
'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0. As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.';
"""
))
def downgrade():
op.execute(sa.text("DROP FUNCTION uuidv7"))
op.execute(sa.text("DROP FUNCTION uuidv7_boundary"))

View File

@ -12,6 +12,7 @@ from sqlalchemy import orm
from core.file.constants import maybe_file_object
from core.file.models import File
from core.variables import utils as variable_utils
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.enums import NodeType
from factories.variable_factory import TypeMismatchError, build_segment_with_type
@ -347,7 +348,7 @@ class Workflow(Base):
)
@property
def environment_variables(self) -> Sequence[Variable]:
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
@ -367,11 +368,15 @@ class Workflow(Base):
def decrypt_func(var):
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
else:
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
return var
else:
raise AssertionError("this statement should be unreachable.")
results = list(map(decrypt_func, results))
return results
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list(
map(decrypt_func, results)
)
return decrypted_results
@environment_variables.setter
def environment_variables(self, value: Sequence[Variable]):

View File

@ -29,11 +29,12 @@ from sqlalchemy.orm import Session, sessionmaker
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
logger = logging.getLogger(__name__)
class DifyAPISQLAlchemyWorkflowRunRepository:
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
"""
SQLAlchemy implementation of APIWorkflowRunRepository.

View File

@ -233,6 +233,7 @@ class AppService:
app.icon = args.get("icon")
app.icon_background = args.get("icon_background")
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)
app.max_active_requests = args.get("max_active_requests")
app.updated_by = current_user.id
app.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()

View File

@ -3,7 +3,7 @@ import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional
from typing import Any, Optional, cast
from uuid import uuid4
from sqlalchemy import select
@ -15,10 +15,10 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes import NodeType
@ -28,6 +28,7 @@ from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
@ -369,7 +370,7 @@ class WorkflowService:
else:
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs=user_inputs,
environment_variables=draft_workflow.environment_variables,
conversation_variables=[],
@ -685,36 +686,30 @@ def _setup_variable_pool(
):
# Only inject system variables for START node type.
if node_type == NodeType.START:
# Create a variable pool.
system_inputs: dict[SystemVariableKey, Any] = {
# From inputs:
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
# From workflow model
SystemVariableKey.APP_ID: workflow.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
# Randomly generated.
SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()),
}
system_variable = SystemVariable(
user_id=user_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
files=files or [],
workflow_execution_id=str(uuid.uuid4()),
)
# Only add chatflow-specific variables for non-workflow types
if workflow.type != WorkflowType.WORKFLOW.value:
system_inputs.update(
{
SystemVariableKey.QUERY: query,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.DIALOGUE_COUNT: 0,
}
)
system_variable.query = query
system_variable.conversation_id = conversation_id
system_variable.dialogue_count = 0
else:
system_inputs = {}
system_variable = SystemVariable.empty()
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), #
)
return variable_pool

View File

@ -9,12 +9,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@ -50,7 +50,7 @@ def init_code_node(code_config: dict):
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],

View File

@ -6,11 +6,11 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
@ -44,7 +44,7 @@ def init_http_node(config: dict):
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],

View File

@ -13,12 +13,12 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -62,12 +62,14 @@ def init_llm_node(config: dict) -> LLMNode:
# construct variable pool
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather today?",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
system_variables=SystemVariable(
user_id="aaa",
app_id=app_id,
workflow_id=workflow_id,
files=[],
query="what's the weather today?",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
conversation_variables=[],

View File

@ -8,11 +8,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
@ -64,12 +64,9 @@ def init_parameter_extractor_node(config: dict):
# construct variable pool
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
system_variables=SystemVariable(
user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa"
),
user_inputs={},
environment_variables=[],
conversation_variables=[],

View File

@ -6,11 +6,11 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@ -61,7 +61,7 @@ def test_execute_code(setup_code_executor_mock):
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],

View File

@ -6,12 +6,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -44,7 +44,7 @@ def init_tool_node(config: dict):
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],

View File

@ -88,6 +88,7 @@ def test_flask_configs(monkeypatch):
"pool_pre_ping": False,
"pool_recycle": 3600,
"pool_size": 30,
"pool_use_lifo": False,
}
assert config["CONSOLE_WEB_URL"] == "https://example.com"

View File

@ -0,0 +1,380 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_login import LoginManager, UserMixin
from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
from controllers.console.workspace.error import AccountNotInitializedError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
enterprise_license_required,
only_edition_cloud,
only_edition_enterprise,
only_edition_self_hosted,
setup_required,
)
from models.account import AccountStatus
from services.feature_service import LicenseStatus
class MockUser(UserMixin):
"""Simple User class for testing."""
def __init__(self, user_id: str):
self.id = user_id
self.current_tenant_id = "tenant123"
def get_id(self) -> str:
return self.id
def create_app_with_login():
"""Create a Flask app with LoginManager configured."""
app = Flask(__name__)
app.config["SECRET_KEY"] = "test-secret-key"
login_manager = LoginManager()
login_manager.init_app(app)
@login_manager.user_loader
def load_user(user_id: str):
return MockUser(user_id)
return app
class TestAccountInitialization:
"""Test account initialization decorator"""
def test_should_allow_initialized_account(self):
"""Test that initialized accounts can access protected views"""
# Arrange
mock_user = MagicMock()
mock_user.status = AccountStatus.ACTIVE
@account_initialization_required
def protected_view():
return "success"
# Act
with patch("controllers.console.wraps.current_user", mock_user):
result = protected_view()
# Assert
assert result == "success"
def test_should_reject_uninitialized_account(self):
"""Test that uninitialized accounts raise AccountNotInitializedError"""
# Arrange
mock_user = MagicMock()
mock_user.status = AccountStatus.UNINITIALIZED
@account_initialization_required
def protected_view():
return "success"
# Act & Assert
with patch("controllers.console.wraps.current_user", mock_user):
with pytest.raises(AccountNotInitializedError):
protected_view()
class TestEditionChecks:
"""Test edition-specific decorators"""
def test_only_edition_cloud_allows_cloud_edition(self):
"""Test cloud edition decorator allows CLOUD edition"""
# Arrange
@only_edition_cloud
def cloud_view():
return "cloud_success"
# Act
with patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"):
result = cloud_view()
# Assert
assert result == "cloud_success"
def test_only_edition_cloud_rejects_other_editions(self):
"""Test cloud edition decorator rejects non-CLOUD editions"""
# Arrange
app = Flask(__name__)
@only_edition_cloud
def cloud_view():
return "cloud_success"
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
with pytest.raises(Exception) as exc_info:
cloud_view()
assert exc_info.value.code == 404
def test_only_edition_enterprise_allows_when_enabled(self):
"""Test enterprise edition decorator allows when ENTERPRISE_ENABLED is True"""
# Arrange
@only_edition_enterprise
def enterprise_view():
return "enterprise_success"
# Act
with patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True):
result = enterprise_view()
# Assert
assert result == "enterprise_success"
def test_only_edition_self_hosted_allows_self_hosted(self):
"""Test self-hosted edition decorator allows SELF_HOSTED edition"""
# Arrange
@only_edition_self_hosted
def self_hosted_view():
return "self_hosted_success"
# Act
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
result = self_hosted_view()
# Assert
assert result == "self_hosted_success"
class TestBillingResourceLimits:
"""Test billing resource limit decorators"""
def test_should_allow_when_under_resource_limit(self):
"""Test that requests are allowed when under resource limits"""
# Arrange
mock_features = MagicMock()
mock_features.billing.enabled = True
mock_features.members.limit = 10
mock_features.members.size = 5
@cloud_edition_billing_resource_check("members")
def add_member():
return "member_added"
# Act
with patch("controllers.console.wraps.current_user"):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = add_member()
# Assert
assert result == "member_added"
def test_should_reject_when_over_resource_limit(self):
"""Test that requests are rejected when over resource limits"""
# Arrange
app = create_app_with_login()
mock_features = MagicMock()
mock_features.billing.enabled = True
mock_features.members.limit = 10
mock_features.members.size = 10
@cloud_edition_billing_resource_check("members")
def add_member():
return "member_added"
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
add_member()
assert exc_info.value.code == 403
assert "members has reached the limit" in str(exc_info.value.description)
def test_should_check_source_for_documents_limit(self):
"""Test document limit checks request source"""
# Arrange
app = create_app_with_login()
mock_features = MagicMock()
mock_features.billing.enabled = True
mock_features.documents_upload_quota.limit = 100
mock_features.documents_upload_quota.size = 100
@cloud_edition_billing_resource_check("documents")
def upload_document():
return "document_uploaded"
# Test 1: Should reject when source is datasets
with app.test_request_context("/?source=datasets"):
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
upload_document()
assert exc_info.value.code == 403
# Test 2: Should allow when source is not datasets
with app.test_request_context("/?source=other"):
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = upload_document()
assert result == "document_uploaded"
class TestRateLimiting:
"""Test rate limiting decorator"""
@patch("controllers.console.wraps.redis_client")
@patch("controllers.console.wraps.db")
def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis):
"""Test that requests within rate limit are allowed"""
# Arrange
mock_rate_limit = MagicMock()
mock_rate_limit.enabled = True
mock_rate_limit.limit = 10
mock_redis.zcard.return_value = 5 # 5 requests in window
@cloud_edition_billing_rate_limit_check("knowledge")
def knowledge_request():
return "knowledge_success"
# Act
with patch("controllers.console.wraps.current_user"):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):
result = knowledge_request()
# Assert
assert result == "knowledge_success"
mock_redis.zadd.assert_called_once()
mock_redis.zremrangebyscore.assert_called_once()
@patch("controllers.console.wraps.redis_client")
@patch("controllers.console.wraps.db")
def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis):
"""Test that requests over rate limit are rejected and logged"""
# Arrange
app = create_app_with_login()
mock_rate_limit = MagicMock()
mock_rate_limit.enabled = True
mock_rate_limit.limit = 10
mock_rate_limit.subscription_plan = "pro"
mock_redis.zcard.return_value = 11 # Over limit
mock_session = MagicMock()
mock_db.session = mock_session
@cloud_edition_billing_rate_limit_check("knowledge")
def knowledge_request():
return "knowledge_success"
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):
with pytest.raises(Exception) as exc_info:
knowledge_request()
# Verify error
assert exc_info.value.code == 403
assert "rate limit" in str(exc_info.value.description)
# Verify rate limit log was created
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
class TestSystemSetup:
"""Test system setup decorator"""
@patch("controllers.console.wraps.db")
def test_should_allow_when_setup_complete(self, mock_db):
"""Test that requests are allowed when setup is complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
@setup_required
def admin_view():
return "admin_success"
# Act
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
result = admin_view()
# Assert
assert result == "admin_success"
@patch("controllers.console.wraps.db")
@patch("controllers.console.wraps.os.environ.get")
def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
"""Test NotInitValidateError when INIT_PASSWORD is set but setup not complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = None # No setup
mock_environ_get.return_value = "some_password"
@setup_required
def admin_view():
return "admin_success"
# Act & Assert
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
with pytest.raises(NotInitValidateError):
admin_view()
@patch("controllers.console.wraps.db")
@patch("controllers.console.wraps.os.environ.get")
def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
"""Test NotSetupError when no INIT_PASSWORD and setup not complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = None # No setup
mock_environ_get.return_value = None # No INIT_PASSWORD
@setup_required
def admin_view():
return "admin_success"
# Act & Assert
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
with pytest.raises(NotSetupError):
admin_view()
class TestEnterpriseLicense:
"""Test enterprise license decorator"""
def test_should_allow_with_valid_license(self):
"""Test that valid licenses allow access"""
# Arrange
mock_settings = MagicMock()
mock_settings.license.status = LicenseStatus.ACTIVE
@enterprise_license_required
def enterprise_feature():
return "enterprise_success"
# Act
with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings):
result = enterprise_feature()
# Assert
assert result == "enterprise_success"
@pytest.mark.parametrize("invalid_status", [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST])
def test_should_reject_with_invalid_license(self, invalid_status):
"""Test that invalid licenses raise UnauthorizedAndForceLogout"""
# Arrange
mock_settings = MagicMock()
mock_settings.license.status = invalid_status
@enterprise_license_required
def enterprise_feature():
return "enterprise_success"
# Act & Assert
with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings):
with pytest.raises(UnauthorizedAndForceLogout) as exc_info:
enterprise_feature()
assert "license is invalid" in str(exc_info.value)

View File

@ -1,14 +1,49 @@
import dataclasses
from pydantic import BaseModel
from core.file import File, FileTransferMethod, FileType
from core.helper import encrypter
from core.variables import SecretVariable, StringVariable
from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
SegmentUnion,
StringSegment,
get_segment_discriminator,
)
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
VariableUnion,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.system_variable import SystemVariable
def test_segment_group_to_text():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
system_variables=SystemVariable(user_id="fake-user-id"),
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
@ -30,7 +65,7 @@ def test_segment_group_to_text():
def test_convert_constant_to_segment_group():
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@ -43,9 +78,7 @@ def test_convert_constant_to_segment_group():
def test_convert_variable_to_segment_group():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
system_variables=SystemVariable(user_id="fake-user-id"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@ -56,3 +89,297 @@ def test_convert_variable_to_segment_group():
assert segments_group.log == "fake-user-id"
assert isinstance(segments_group.value[0], StringVariable)
assert segments_group.value[0].value == "fake-user-id"
class _Segments(BaseModel):
segments: list[SegmentUnion]
class _Variables(BaseModel):
variables: list[VariableUnion]
def create_test_file(
file_type: FileType = FileType.DOCUMENT,
transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
filename: str = "test.txt",
extension: str = ".txt",
mime_type: str = "text/plain",
size: int = 1024,
) -> File:
"""Factory function to create File objects for testing"""
return File(
tenant_id="test-tenant",
type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
storage_key="test-storage-key",
)
class TestSegmentDumpAndLoad:
"""Test suite for segment and variable serialization/deserialization"""
def test_segments(self):
"""Test basic segment serialization compatibility"""
model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
json = model.model_dump_json()
print("Json: ", json)
loaded = _Segments.model_validate_json(json)
assert loaded == model
def test_segment_number(self):
"""Test number segment serialization compatibility"""
model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
json = model.model_dump_json()
print("Json: ", json)
loaded = _Segments.model_validate_json(json)
assert loaded == model
def test_variables(self):
"""Test variable serialization compatibility"""
model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
json = model.model_dump_json()
print("Json: ", json)
restored = _Variables.model_validate_json(json)
assert restored == model
def test_all_segments_serialization(self):
"""Test serialization/deserialization of all segment types"""
# Create one instance of each segment type
test_file = create_test_file()
all_segments: list[SegmentUnion] = [
NoneSegment(),
StringSegment(value="test string"),
IntegerSegment(value=42),
FloatSegment(value=3.14),
ObjectSegment(value={"key": "value", "number": 123}),
FileSegment(value=test_file),
ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]),
ArrayStringSegment(value=["hello", "world"]),
ArrayNumberSegment(value=[1, 2.5, 3]),
ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]),
ArrayFileSegment(value=[]), # Empty array to avoid file complexity
]
# Test serialization and deserialization
model = _Segments(segments=all_segments)
json_str = model.model_dump_json()
loaded = _Segments.model_validate_json(json_str)
# Verify all segments are preserved
assert len(loaded.segments) == len(all_segments)
for original, loaded_segment in zip(all_segments, loaded.segments):
assert type(loaded_segment) == type(original)
assert loaded_segment.value_type == original.value_type
# For file segments, compare key properties instead of exact equality
if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment):
orig_file = original.value
loaded_file = loaded_segment.value
assert isinstance(orig_file, File)
assert isinstance(loaded_file, File)
assert loaded_file.tenant_id == orig_file.tenant_id
assert loaded_file.type == orig_file.type
assert loaded_file.filename == orig_file.filename
else:
assert loaded_segment.value == original.value
def test_all_variables_serialization(self):
"""Test serialization/deserialization of all variable types"""
# Create one instance of each variable type
test_file = create_test_file()
all_variables: list[VariableUnion] = [
NoneVariable(name="none_var"),
StringVariable(value="test string", name="string_var"),
IntegerVariable(value=42, name="int_var"),
FloatVariable(value=3.14, name="float_var"),
ObjectVariable(value={"key": "value", "number": 123}, name="object_var"),
FileVariable(value=test_file, name="file_var"),
ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"),
ArrayStringVariable(value=["hello", "world"], name="array_string_var"),
ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"),
ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"),
ArrayFileVariable(value=[], name="array_file_var"), # Empty array to avoid file complexity
]
# Test serialization and deserialization
model = _Variables(variables=all_variables)
json_str = model.model_dump_json()
loaded = _Variables.model_validate_json(json_str)
# Verify all variables are preserved
assert len(loaded.variables) == len(all_variables)
for original, loaded_variable in zip(all_variables, loaded.variables):
assert type(loaded_variable) == type(original)
assert loaded_variable.value_type == original.value_type
assert loaded_variable.name == original.name
# For file variables, compare key properties instead of exact equality
if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable):
orig_file = original.value
loaded_file = loaded_variable.value
assert isinstance(orig_file, File)
assert isinstance(loaded_file, File)
assert loaded_file.tenant_id == orig_file.tenant_id
assert loaded_file.type == orig_file.type
assert loaded_file.filename == orig_file.filename
else:
assert loaded_variable.value == original.value
def test_segment_discriminator_function_for_segment_types(self):
"""Test the segment discriminator function"""
@dataclasses.dataclass
class TestCase:
segment: Segment
expected_segment_type: SegmentType
file1 = create_test_file()
file2 = create_test_file(filename="test2.txt")
cases = [
TestCase(
NoneSegment(),
SegmentType.NONE,
),
TestCase(
StringSegment(value=""),
SegmentType.STRING,
),
TestCase(
FloatSegment(value=0.0),
SegmentType.FLOAT,
),
TestCase(
IntegerSegment(value=0),
SegmentType.INTEGER,
),
TestCase(
ObjectSegment(value={}),
SegmentType.OBJECT,
),
TestCase(
FileSegment(value=file1),
SegmentType.FILE,
),
TestCase(
ArrayAnySegment(value=[0, 0.0, ""]),
SegmentType.ARRAY_ANY,
),
TestCase(
ArrayStringSegment(value=[""]),
SegmentType.ARRAY_STRING,
),
TestCase(
ArrayNumberSegment(value=[0, 0.0]),
SegmentType.ARRAY_NUMBER,
),
TestCase(
ArrayObjectSegment(value=[{}]),
SegmentType.ARRAY_OBJECT,
),
TestCase(
ArrayFileSegment(value=[file1, file2]),
SegmentType.ARRAY_FILE,
),
]
for test_case in cases:
segment = test_case.segment
assert get_segment_discriminator(segment) == test_case.expected_segment_type, (
f"get_segment_discriminator failed for type {type(segment)}"
)
model_dict = segment.model_dump(mode="json")
assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
f"get_segment_discriminator failed for serialized form of type {type(segment)}"
)
def test_variable_discriminator_function_for_variable_types(self):
"""Test the variable discriminator function"""
@dataclasses.dataclass
class TestCase:
variable: Variable
expected_segment_type: SegmentType
file1 = create_test_file()
file2 = create_test_file(filename="test2.txt")
cases = [
TestCase(
NoneVariable(name="none_var"),
SegmentType.NONE,
),
TestCase(
StringVariable(value="test", name="string_var"),
SegmentType.STRING,
),
TestCase(
FloatVariable(value=0.0, name="float_var"),
SegmentType.FLOAT,
),
TestCase(
IntegerVariable(value=0, name="int_var"),
SegmentType.INTEGER,
),
TestCase(
ObjectVariable(value={}, name="object_var"),
SegmentType.OBJECT,
),
TestCase(
FileVariable(value=file1, name="file_var"),
SegmentType.FILE,
),
TestCase(
SecretVariable(value="secret", name="secret_var"),
SegmentType.SECRET,
),
TestCase(
ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"),
SegmentType.ARRAY_ANY,
),
TestCase(
ArrayStringVariable(value=[""], name="array_string_var"),
SegmentType.ARRAY_STRING,
),
TestCase(
ArrayNumberVariable(value=[0, 0.0], name="array_number_var"),
SegmentType.ARRAY_NUMBER,
),
TestCase(
ArrayObjectVariable(value=[{}], name="array_object_var"),
SegmentType.ARRAY_OBJECT,
),
TestCase(
ArrayFileVariable(value=[file1, file2], name="array_file_var"),
SegmentType.ARRAY_FILE,
),
]
for test_case in cases:
variable = test_case.variable
assert get_segment_discriminator(variable) == test_case.expected_segment_type, (
f"get_segment_discriminator failed for type {type(variable)}"
)
model_dict = variable.model_dump(mode="json")
assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
f"get_segment_discriminator failed for serialized form of type {type(variable)}"
)
def test_invlaid_value_for_discriminator(self):
# Test invalid cases
assert get_segment_discriminator({"value_type": "invalid"}) is None
assert get_segment_discriminator({}) is None
assert get_segment_discriminator("not_a_dict") is None
assert get_segment_discriminator(42) is None
assert get_segment_discriminator(object) is None

View File

@ -0,0 +1,60 @@
from core.variables.types import SegmentType
class TestSegmentTypeIsArrayType:
"""
Test class for SegmentType.is_array_type method.
Provides comprehensive coverage of all SegmentType values to ensure
correct identification of array and non-array types.
"""
def test_is_array_type(self):
"""
Test that all SegmentType enum values are covered in our test cases.
Ensures comprehensive coverage by verifying that every SegmentType
value is tested for the is_array_type method.
"""
# Arrange
all_segment_types = set(SegmentType)
expected_array_types = [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
]
expected_non_array_types = [
SegmentType.INTEGER,
SegmentType.FLOAT,
SegmentType.NUMBER,
SegmentType.STRING,
SegmentType.OBJECT,
SegmentType.SECRET,
SegmentType.FILE,
SegmentType.NONE,
SegmentType.GROUP,
]
for seg_type in expected_array_types:
assert seg_type.is_array_type()
for seg_type in expected_non_array_types:
assert not seg_type.is_array_type()
# Act & Assert
covered_types = set(expected_array_types) | set(expected_non_array_types)
assert covered_types == set(SegmentType), "All SegmentType values should be covered in tests"
def test_all_enum_values_are_supported(self):
"""
Test that all enum values are supported and return boolean values.
Validates that every SegmentType enum value can be processed by
is_array_type method and returns a boolean value.
"""
enum_values: list[SegmentType] = list(SegmentType)
for seg_type in enum_values:
is_array = seg_type.is_array_type()
assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"

View File

@ -11,6 +11,7 @@ from core.variables import (
SegmentType,
StringVariable,
)
from core.variables.variables import Variable
def test_frozen_variables():
@ -75,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object():
var = StringVariable(name="text", value="text")
var: Variable = StringVariable(name="text", value="text")
assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42

View File

@ -0,0 +1,146 @@
import time
from decimal import Decimal
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
from core.workflow.system_variable import SystemVariable
def create_test_graph_runtime_state() -> GraphRuntimeState:
"""Factory function to create a GraphRuntimeState with non-empty values for testing."""
# Create a variable pool with system variables
system_vars = SystemVariable(
user_id="test_user_123",
app_id="test_app_456",
workflow_id="test_workflow_789",
workflow_execution_id="test_execution_001",
query="test query",
conversation_id="test_conv_123",
dialogue_count=5,
)
variable_pool = VariablePool(system_variables=system_vars)
# Add some variables to the variable pool
variable_pool.add(["test_node", "test_var"], "test_value")
variable_pool.add(["another_node", "another_var"], 42)
# Create LLM usage with realistic values
llm_usage = LLMUsage(
prompt_tokens=150,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.15"),
completion_tokens=75,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.15"),
total_tokens=225,
total_price=Decimal("0.30"),
currency="USD",
latency=1.25,
)
# Create runtime route state with some node states
node_run_state = RuntimeRouteState()
node_state = node_run_state.create_node_state("test_node_1")
node_run_state.add_route(node_state.id, "target_node_id")
return GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
total_tokens=100,
llm_usage=llm_usage,
outputs={
"string_output": "test result",
"int_output": 42,
"float_output": 3.14,
"list_output": ["item1", "item2", "item3"],
"dict_output": {"key1": "value1", "key2": 123},
"nested_dict": {"level1": {"level2": ["nested", "list", 456]}},
},
node_run_steps=5,
node_run_state=node_run_state,
)
def test_basic_round_trip_serialization():
"""Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged."""
# Create a state with non-empty values
original_state = create_test_graph_runtime_state()
# Serialize to JSON and deserialize back
json_data = original_state.model_dump_json()
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
# Core test: ensure the round-trip preserves all values
assert deserialized_state == original_state
# Serialize to JSON and deserialize back
dict_data = original_state.model_dump(mode="python")
deserialized_state = GraphRuntimeState.model_validate(dict_data)
assert deserialized_state == original_state
# Serialize to JSON and deserialize back
dict_data = original_state.model_dump(mode="json")
deserialized_state = GraphRuntimeState.model_validate(dict_data)
assert deserialized_state == original_state
def test_outputs_field_round_trip():
"""Test the problematic outputs field maintains values through round-trip serialization."""
original_state = create_test_graph_runtime_state()
# Serialize and deserialize
json_data = original_state.model_dump_json()
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
# Verify the outputs field specifically maintains its values
assert deserialized_state.outputs == original_state.outputs
assert deserialized_state == original_state
def test_empty_outputs_round_trip():
"""Test round-trip serialization with empty outputs field."""
variable_pool = VariablePool.empty()
original_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
outputs={}, # Empty outputs
)
json_data = original_state.model_dump_json()
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
assert deserialized_state == original_state
def test_llm_usage_round_trip():
# Create LLM usage with specific decimal values
llm_usage = LLMUsage(
prompt_tokens=100,
prompt_unit_price=Decimal("0.0015"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.15"),
completion_tokens=50,
completion_unit_price=Decimal("0.003"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.15"),
total_tokens=150,
total_price=Decimal("0.30"),
currency="USD",
latency=2.5,
)
json_data = llm_usage.model_dump_json()
deserialized = LLMUsage.model_validate_json(json_data)
assert deserialized == llm_usage
dict_data = llm_usage.model_dump(mode="python")
deserialized = LLMUsage.model_validate(dict_data)
assert deserialized == llm_usage
dict_data = llm_usage.model_dump(mode="json")
deserialized = LLMUsage.model_validate(dict_data)
assert deserialized == llm_usage

View File

@ -0,0 +1,401 @@
import json
import uuid
from datetime import UTC, datetime
import pytest
from pydantic import ValidationError
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45)
class TestRouteNodeStateSerialization:
"""Test cases for RouteNodeState Pydantic serialization/deserialization."""
def _test_route_node_state(self):
"""Test comprehensive RouteNodeState serialization with all core fields validation."""
node_run_result = NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"input_key": "input_value"},
outputs={"output_key": "output_value"},
)
node_state = RouteNodeState(
node_id="comprehensive_test_node",
start_at=_TEST_DATETIME,
finished_at=_TEST_DATETIME,
status=RouteNodeState.Status.SUCCESS,
node_run_result=node_run_result,
index=5,
paused_at=_TEST_DATETIME,
paused_by="user_123",
failed_reason="test_reason",
)
return node_state
def test_route_node_state_comprehensive_field_validation(self):
"""Test comprehensive RouteNodeState serialization with all core fields validation."""
node_state = self._test_route_node_state()
serialized = node_state.model_dump()
# Comprehensive validation of all RouteNodeState fields
assert serialized["node_id"] == "comprehensive_test_node"
assert serialized["status"] == RouteNodeState.Status.SUCCESS
assert serialized["start_at"] == _TEST_DATETIME
assert serialized["finished_at"] == _TEST_DATETIME
assert serialized["paused_at"] == _TEST_DATETIME
assert serialized["paused_by"] == "user_123"
assert serialized["failed_reason"] == "test_reason"
assert serialized["index"] == 5
assert "id" in serialized
assert isinstance(serialized["id"], str)
uuid.UUID(serialized["id"]) # Validate UUID format
# Validate nested NodeRunResult structure
assert serialized["node_run_result"] is not None
assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED
assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"}
assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"}
def test_route_node_state_minimal_required_fields(self):
"""Test RouteNodeState with only required fields, focusing on defaults."""
node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME)
serialized = node_state.model_dump()
# Focus on required fields and default values (not re-testing all fields)
assert serialized["node_id"] == "minimal_node"
assert serialized["start_at"] == _TEST_DATETIME
assert serialized["status"] == RouteNodeState.Status.RUNNING # Default status
assert serialized["index"] == 1 # Default index
assert serialized["node_run_result"] is None # Default None
json = node_state.model_dump_json()
deserialized = RouteNodeState.model_validate_json(json)
assert deserialized == node_state
def test_route_node_state_deserialization_from_dict(self):
"""Test RouteNodeState deserialization from dictionary data."""
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
test_id = str(uuid.uuid4())
dict_data = {
"id": test_id,
"node_id": "deserialized_node",
"start_at": test_datetime,
"status": "success",
"finished_at": test_datetime,
"index": 3,
}
node_state = RouteNodeState.model_validate(dict_data)
# Focus on deserialization accuracy
assert node_state.id == test_id
assert node_state.node_id == "deserialized_node"
assert node_state.start_at == test_datetime
assert node_state.status == RouteNodeState.Status.SUCCESS
assert node_state.finished_at == test_datetime
assert node_state.index == 3
def test_route_node_state_round_trip_consistency(self):
node_states = (
self._test_route_node_state(),
RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME),
)
for node_state in node_states:
json = node_state.model_dump_json()
deserialized = RouteNodeState.model_validate_json(json)
assert deserialized == node_state
dict_ = node_state.model_dump(mode="python")
deserialized = RouteNodeState.model_validate(dict_)
assert deserialized == node_state
dict_ = node_state.model_dump(mode="json")
deserialized = RouteNodeState.model_validate(dict_)
assert deserialized == node_state
class TestRouteNodeStateEnumSerialization:
"""Dedicated tests for RouteNodeState Status enum serialization behavior."""
def test_status_enum_model_dump_behavior(self):
"""Test Status enum serialization in model_dump() returns enum objects."""
for status_enum in RouteNodeState.Status:
node_state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=status_enum)
serialized = node_state.model_dump(mode="python")
assert serialized["status"] == status_enum
serialized = node_state.model_dump(mode="json")
assert serialized["status"] == status_enum.value
def test_status_enum_json_serialization_behavior(self):
"""Test Status enum serialization in JSON returns string values."""
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
enum_to_string_mapping = {
RouteNodeState.Status.RUNNING: "running",
RouteNodeState.Status.SUCCESS: "success",
RouteNodeState.Status.FAILED: "failed",
RouteNodeState.Status.PAUSED: "paused",
RouteNodeState.Status.EXCEPTION: "exception",
}
for status_enum, expected_string in enum_to_string_mapping.items():
node_state = RouteNodeState(node_id="json_enum_test", start_at=test_datetime, status=status_enum)
json_data = json.loads(node_state.model_dump_json())
assert json_data["status"] == expected_string
def test_status_enum_deserialization_from_string(self):
"""Test Status enum deserialization from string values."""
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
string_to_enum_mapping = {
"running": RouteNodeState.Status.RUNNING,
"success": RouteNodeState.Status.SUCCESS,
"failed": RouteNodeState.Status.FAILED,
"paused": RouteNodeState.Status.PAUSED,
"exception": RouteNodeState.Status.EXCEPTION,
}
for status_string, expected_enum in string_to_enum_mapping.items():
dict_data = {
"node_id": "enum_deserialize_test",
"start_at": test_datetime,
"status": status_string,
}
node_state = RouteNodeState.model_validate(dict_data)
assert node_state.status == expected_enum
class TestRuntimeRouteStateSerialization:
"""Test cases for RuntimeRouteState Pydantic serialization/deserialization."""
_NODE1_ID = "node_1"
_ROUTE_STATE1_ID = str(uuid.uuid4())
_NODE2_ID = "node_2"
_ROUTE_STATE2_ID = str(uuid.uuid4())
_NODE3_ID = "node_3"
_ROUTE_STATE3_ID = str(uuid.uuid4())
def _get_runtime_route_state(self):
# Create node states with different configurations
node_state_1 = RouteNodeState(
id=self._ROUTE_STATE1_ID,
node_id=self._NODE1_ID,
start_at=_TEST_DATETIME,
index=1,
)
node_state_2 = RouteNodeState(
id=self._ROUTE_STATE2_ID,
node_id=self._NODE2_ID,
start_at=_TEST_DATETIME,
status=RouteNodeState.Status.SUCCESS,
finished_at=_TEST_DATETIME,
index=2,
)
node_state_3 = RouteNodeState(
id=self._ROUTE_STATE3_ID,
node_id=self._NODE3_ID,
start_at=_TEST_DATETIME,
status=RouteNodeState.Status.FAILED,
failed_reason="Test failure",
index=3,
)
runtime_state = RuntimeRouteState(
routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]},
node_state_mapping={
node_state_1.id: node_state_1,
node_state_2.id: node_state_2,
node_state_3.id: node_state_3,
},
)
return runtime_state
def test_runtime_route_state_comprehensive_structure_validation(self):
"""Test comprehensive RuntimeRouteState serialization with full structure validation."""
runtime_state = self._get_runtime_route_state()
serialized = runtime_state.model_dump()
# Comprehensive validation of RuntimeRouteState structure
assert "routes" in serialized
assert "node_state_mapping" in serialized
assert isinstance(serialized["routes"], dict)
assert isinstance(serialized["node_state_mapping"], dict)
# Validate routes dictionary structure and content
assert len(serialized["routes"]) == 2
assert self._ROUTE_STATE1_ID in serialized["routes"]
assert self._ROUTE_STATE2_ID in serialized["routes"]
assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID]
assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID]
# Validate node_state_mapping dictionary structure and content
assert len(serialized["node_state_mapping"]) == 3
for state_id in [
self._ROUTE_STATE1_ID,
self._ROUTE_STATE2_ID,
self._ROUTE_STATE3_ID,
]:
assert state_id in serialized["node_state_mapping"]
node_data = serialized["node_state_mapping"][state_id]
node_state = runtime_state.node_state_mapping[state_id]
assert node_data["node_id"] == node_state.node_id
assert node_data["status"] == node_state.status
assert node_data["index"] == node_state.index
def test_runtime_route_state_empty_collections(self):
"""Test RuntimeRouteState with empty collections, focusing on default behavior."""
runtime_state = RuntimeRouteState()
serialized = runtime_state.model_dump()
# Focus on default empty collection behavior
assert serialized["routes"] == {}
assert serialized["node_state_mapping"] == {}
assert isinstance(serialized["routes"], dict)
assert isinstance(serialized["node_state_mapping"], dict)
def test_runtime_route_state_json_serialization_structure(self):
"""Test RuntimeRouteState JSON serialization structure."""
node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME)
runtime_state = RuntimeRouteState(
routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state}
)
json_str = runtime_state.model_dump_json()
json_data = json.loads(json_str)
# Focus on JSON structure validation
assert isinstance(json_str, str)
assert isinstance(json_data, dict)
assert "routes" in json_data
assert "node_state_mapping" in json_data
assert json_data["routes"]["source"] == ["target1", "target2"]
assert node_state.id in json_data["node_state_mapping"]
def test_runtime_route_state_deserialization_from_dict(self):
"""Test RuntimeRouteState deserialization from dictionary data."""
node_id = str(uuid.uuid4())
dict_data = {
"routes": {"source_node": ["target_node_1", "target_node_2"]},
"node_state_mapping": {
node_id: {
"id": node_id,
"node_id": "test_node",
"start_at": _TEST_DATETIME,
"status": "running",
"index": 1,
}
},
}
runtime_state = RuntimeRouteState.model_validate(dict_data)
# Focus on deserialization accuracy
assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]}
assert len(runtime_state.node_state_mapping) == 1
assert node_id in runtime_state.node_state_mapping
deserialized_node = runtime_state.node_state_mapping[node_id]
assert deserialized_node.node_id == "test_node"
assert deserialized_node.status == RouteNodeState.Status.RUNNING
assert deserialized_node.index == 1
def test_runtime_route_state_round_trip_consistency(self):
"""Test RuntimeRouteState round-trip serialization consistency."""
original = self._get_runtime_route_state()
# Dictionary round trip
dict_data = original.model_dump(mode="python")
reconstructed = RuntimeRouteState.model_validate(dict_data)
assert reconstructed == original
dict_data = original.model_dump(mode="json")
reconstructed = RuntimeRouteState.model_validate(dict_data)
assert reconstructed == original
# JSON round trip
json_str = original.model_dump_json()
json_reconstructed = RuntimeRouteState.model_validate_json(json_str)
assert json_reconstructed == original
class TestSerializationEdgeCases:
"""Test edge cases and error conditions for serialization/deserialization."""
def test_invalid_status_deserialization(self):
"""Test deserialization with invalid status values."""
test_datetime = _TEST_DATETIME
invalid_data = {
"node_id": "invalid_test",
"start_at": test_datetime,
"status": "invalid_status",
}
with pytest.raises(ValidationError) as exc_info:
RouteNodeState.model_validate(invalid_data)
assert "status" in str(exc_info.value)
def test_missing_required_fields_deserialization(self):
"""Test deserialization with missing required fields."""
incomplete_data = {"id": str(uuid.uuid4())}
with pytest.raises(ValidationError) as exc_info:
RouteNodeState.model_validate(incomplete_data)
error_str = str(exc_info.value)
assert "node_id" in error_str or "start_at" in error_str
def test_invalid_datetime_deserialization(self):
"""Test deserialization with invalid datetime values."""
invalid_data = {
"node_id": "datetime_test",
"start_at": "invalid_datetime",
"status": "running",
}
with pytest.raises(ValidationError) as exc_info:
RouteNodeState.model_validate(invalid_data)
assert "start_at" in str(exc_info.value)
def test_invalid_routes_structure_deserialization(self):
"""Test RuntimeRouteState deserialization with invalid routes structure."""
invalid_data = {
"routes": "invalid_routes_structure", # Should be dict
"node_state_mapping": {},
}
with pytest.raises(ValidationError) as exc_info:
RuntimeRouteState.model_validate(invalid_data)
assert "routes" in str(exc_info.value)
def test_timezone_handling_in_datetime_fields(self):
"""Test timezone handling in datetime field serialization."""
utc_datetime = datetime.now(UTC)
naive_datetime = utc_datetime.replace(tzinfo=None)
node_state = RouteNodeState(node_id="timezone_test", start_at=naive_datetime)
dict_ = node_state.model_dump()
assert dict_["start_at"] == naive_datetime
# Test round trip
reconstructed = RouteNodeState.model_validate(dict_)
assert reconstructed.start_at == naive_datetime
assert reconstructed.start_at.tzinfo is None
json = node_state.model_dump_json()
reconstructed = RouteNodeState.model_validate_json(json)
assert reconstructed.start_at == naive_datetime
assert reconstructed.start_at.tzinfo is None

View File

@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
GraphRunFailedEvent,
@ -27,6 +26,7 @@ from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -171,7 +171,8 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
system_variables=SystemVariable(user_id="aaa", app_id="1", workflow_id="1", files=[]),
user_inputs={"query": "hi"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
@ -293,12 +294,12 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
system_variables=SystemVariable(
user_id="aaa",
files=[],
query="what's the weather in SF",
conversation_id="abababa",
),
user_inputs={},
)
@ -474,12 +475,12 @@ def test_run_branch(mock_close, mock_remove):
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "hi",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
system_variables=SystemVariable(
user_id="aaa",
files=[],
query="hi",
conversation_id="abababa",
),
user_inputs={"uid": "takato"},
)
@ -804,18 +805,22 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
system_variables=SystemVariable(
user_id="aaa",
files=[],
),
user_inputs={"query": "hi"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

View File

@ -5,11 +5,11 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -51,7 +51,7 @@ def test_execute_answer():
# construct variable pool
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
)

View File

@ -3,7 +3,6 @@ from collections.abc import Generator
from datetime import UTC, datetime
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
@ -15,6 +14,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.system_variable import SystemVariable
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
@ -180,12 +180,12 @@ def test_process():
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
system_variables=SystemVariable(
user_id="aaa",
files=[],
query="what's the weather in SF",
conversation_id="abababa",
),
user_inputs={},
)

View File

@ -7,12 +7,13 @@ from core.workflow.nodes.http_request import (
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.system_variable import SystemVariable
def test_executor_with_json_body_and_number_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "number"], 42)
@ -65,7 +66,7 @@ def test_executor_with_json_body_and_number_variable():
def test_executor_with_json_body_and_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@ -120,7 +121,7 @@ def test_executor_with_json_body_and_object_variable():
def test_executor_with_json_body_and_nested_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@ -174,7 +175,7 @@ def test_executor_with_json_body_and_nested_object_variable():
def test_extract_selectors_from_template_with_newline():
variable_pool = VariablePool()
variable_pool = VariablePool(system_variables=SystemVariable.empty())
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
@ -201,7 +202,7 @@ def test_extract_selectors_from_template_with_newline():
def test_executor_with_form_data():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
@ -280,7 +281,11 @@ def test_init_headers():
authorization=HttpRequestNodeAuthorization(type="no-auth"),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
return Executor(
node_data=node_data,
timeout=timeout,
variable_pool=VariablePool(system_variables=SystemVariable.empty()),
)
executor = create_executor("aa\n cc:")
executor._init_headers()
@ -310,7 +315,11 @@ def test_init_params():
authorization=HttpRequestNodeAuthorization(type="no-auth"),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
return Executor(
node_data=node_data,
timeout=timeout,
variable_pool=VariablePool(system_variables=SystemVariable.empty()),
)
# Test basic key-value pairs
executor = create_executor("key1:value1\nkey2:value2")

View File

@ -15,6 +15,7 @@ from core.workflow.nodes.http_request import (
HttpRequestNodeBody,
HttpRequestNodeData,
)
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -40,7 +41,7 @@ def test_http_request_node_binary_file(monkeypatch):
),
)
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(
@ -128,7 +129,7 @@ def test_http_request_node_form_with_file(monkeypatch):
),
)
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(
@ -223,7 +224,7 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
)
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)

View File

@ -7,7 +7,6 @@ from core.variables.segments import ArrayAnySegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@ -15,6 +14,7 @@ from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -151,12 +151,12 @@ def test_run():
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)
@ -368,12 +368,12 @@ def test_run_parallel():
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)
@ -584,12 +584,12 @@ def test_iteration_run_in_parallel_mode():
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)
@ -808,12 +808,12 @@ def test_iteration_run_error_handle():
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)

View File

@ -36,6 +36,7 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.provider import ProviderType
from models.workflow import WorkflowType
@ -104,7 +105,7 @@ def graph() -> Graph:
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
return GraphRuntimeState(
@ -181,7 +182,7 @@ def test_fetch_files_with_file_segment():
related_id="1",
storage_key="",
)
variable_pool = VariablePool()
variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], file)
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@ -209,7 +210,7 @@ def test_fetch_files_with_array_file_segment():
storage_key="",
),
]
variable_pool = VariablePool()
variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@ -217,7 +218,7 @@ def test_fetch_files_with_array_file_segment():
def test_fetch_files_with_none_segment():
variable_pool = VariablePool()
variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], NoneSegment())
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@ -225,7 +226,7 @@ def test_fetch_files_with_none_segment():
def test_fetch_files_with_array_any_segment():
variable_pool = VariablePool()
variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@ -233,7 +234,7 @@ def test_fetch_files_with_array_any_segment():
def test_fetch_files_with_non_existent_variable():
variable_pool = VariablePool()
variable_pool = VariablePool.empty()
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []

View File

@ -5,11 +5,11 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -53,7 +53,7 @@ def test_execute_answer():
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],

View File

@ -5,7 +5,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
NodeRunExceptionEvent,
@ -17,6 +16,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -167,12 +167,12 @@ class ContinueOnErrorTestHelper:
"""Helper method to create a graph engine instance for testing"""
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "clear",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
system_variables=SystemVariable(
user_id="aaa",
files=[],
query="clear",
conversation_id="abababa",
),
user_inputs=user_inputs or {"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

View File

@ -7,12 +7,12 @@ from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.system_variable import SystemVariable
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
from extensions.ext_database import db
from models.enums import UserFrom
@ -37,9 +37,7 @@ def test_execute_if_else_result_true():
)
# construct variable pool
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={}
)
pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={})
pool.add(["start", "array_contains"], ["ab", "def"])
pool.add(["start", "array_not_contains"], ["ac", "def"])
pool.add(["start", "contains"], "cabcde")
@ -157,7 +155,7 @@ def test_execute_if_else_result_false():
# construct variable pool
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
)

View File

@ -15,6 +15,7 @@ from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.system_variable import SystemVariable
from models import UserFrom, WorkflowType
@ -34,7 +35,7 @@ def _create_tool_node():
version="1",
)
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
node = ToolNode(

View File

@ -7,12 +7,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -68,7 +68,7 @@ def test_overwrite_string_variable():
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
system_variables=SystemVariable(conversation_id=conversation_id),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@ -165,7 +165,7 @@ def test_append_variable_to_array():
conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
system_variables=SystemVariable(conversation_id=conversation_id),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@ -256,7 +256,7 @@ def test_clear_array():
conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
system_variables=SystemVariable(conversation_id=conversation_id),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],

View File

@ -5,12 +5,12 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -109,7 +109,7 @@ def test_remove_first_from_array():
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@ -196,7 +196,7 @@ def test_remove_last_from_array():
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@ -275,7 +275,7 @@ def test_remove_first_from_empty_array():
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@ -354,7 +354,7 @@ def test_remove_last_from_empty_array():
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],

View File

@ -0,0 +1,251 @@
import json
from typing import Any
import pytest
from pydantic import ValidationError
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.workflow.system_variable import SystemVariable
# Test data constants for SystemVariable serialization tests
VALID_BASE_DATA: dict[str, Any] = {
"user_id": "a20f06b1-8703-45ab-937c-860a60072113",
"app_id": "661bed75-458d-49c9-b487-fda0762677b9",
"workflow_id": "d31f2136-b292-4ae0-96d4-1e77894a4f43",
}
COMPLETE_VALID_DATA: dict[str, Any] = {
**VALID_BASE_DATA,
"query": "test query",
"files": [],
"conversation_id": "91f1eb7d-69f4-4d7b-b82f-4003d51744b9",
"dialogue_count": 5,
"workflow_run_id": "eb4704b5-2274-47f2-bfcd-0452daa82cb5",
}
def create_test_file() -> File:
"""Create a test File object for serialization tests."""
return File(
tenant_id="test-tenant-id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test-file-id",
filename="test.txt",
extension=".txt",
mime_type="text/plain",
size=1024,
storage_key="test-storage-key",
)
class TestSystemVariableSerialization:
"""Focused tests for SystemVariable serialization/deserialization logic."""
def test_basic_deserialization(self):
"""Test successful deserialization from JSON structure with all fields correctly mapped."""
# Test with complete data
system_var = SystemVariable(**COMPLETE_VALID_DATA)
# Verify all fields are correctly mapped
assert system_var.user_id == COMPLETE_VALID_DATA["user_id"]
assert system_var.app_id == COMPLETE_VALID_DATA["app_id"]
assert system_var.workflow_id == COMPLETE_VALID_DATA["workflow_id"]
assert system_var.query == COMPLETE_VALID_DATA["query"]
assert system_var.conversation_id == COMPLETE_VALID_DATA["conversation_id"]
assert system_var.dialogue_count == COMPLETE_VALID_DATA["dialogue_count"]
assert system_var.workflow_execution_id == COMPLETE_VALID_DATA["workflow_run_id"]
assert system_var.files == []
# Test with minimal data (only required fields)
minimal_var = SystemVariable(**VALID_BASE_DATA)
assert minimal_var.user_id == VALID_BASE_DATA["user_id"]
assert minimal_var.app_id == VALID_BASE_DATA["app_id"]
assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"]
assert minimal_var.query is None
assert minimal_var.conversation_id is None
assert minimal_var.dialogue_count is None
assert minimal_var.workflow_execution_id is None
assert minimal_var.files == []
def test_alias_handling(self):
"""Test workflow_execution_id vs workflow_run_id alias resolution - core deserialization logic."""
workflow_id = "eb4704b5-2274-47f2-bfcd-0452daa82cb5"
# Test workflow_run_id only (preferred alias)
data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
system_var1 = SystemVariable(**data_run_id)
assert system_var1.workflow_execution_id == workflow_id
# Test workflow_execution_id only (direct field name)
data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
system_var2 = SystemVariable(**data_execution_id)
assert system_var2.workflow_execution_id == workflow_id
# Test both present - workflow_run_id should take precedence
data_both = {
**VALID_BASE_DATA,
"workflow_execution_id": "should-be-ignored",
"workflow_run_id": workflow_id,
}
system_var3 = SystemVariable(**data_both)
assert system_var3.workflow_execution_id == workflow_id
# Test neither present - should be None
system_var4 = SystemVariable(**VALID_BASE_DATA)
assert system_var4.workflow_execution_id is None
def test_serialization_round_trip(self):
"""Test that serialize → deserialize produces the same result with alias handling."""
# Create original SystemVariable
original = SystemVariable(**COMPLETE_VALID_DATA)
# Serialize to dict
serialized = original.model_dump(mode="json")
# Verify alias is used in serialization (workflow_run_id, not workflow_execution_id)
assert "workflow_run_id" in serialized
assert "workflow_execution_id" not in serialized
assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
# Deserialize back
deserialized = SystemVariable(**serialized)
# Verify all fields match after round-trip
assert deserialized.user_id == original.user_id
assert deserialized.app_id == original.app_id
assert deserialized.workflow_id == original.workflow_id
assert deserialized.query == original.query
assert deserialized.conversation_id == original.conversation_id
assert deserialized.dialogue_count == original.dialogue_count
assert deserialized.workflow_execution_id == original.workflow_execution_id
assert list(deserialized.files) == list(original.files)
def test_json_round_trip(self):
"""Test JSON serialization/deserialization consistency with proper structure."""
# Create original SystemVariable
original = SystemVariable(**COMPLETE_VALID_DATA)
# Serialize to JSON string
json_str = original.model_dump_json()
# Parse JSON and verify structure
json_data = json.loads(json_str)
assert "workflow_run_id" in json_data
assert "workflow_execution_id" not in json_data
assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
# Deserialize from JSON data
deserialized = SystemVariable(**json_data)
# Verify key fields match after JSON round-trip
assert deserialized.workflow_execution_id == original.workflow_execution_id
assert deserialized.user_id == original.user_id
assert deserialized.app_id == original.app_id
assert deserialized.workflow_id == original.workflow_id
def test_files_field_deserialization(self):
"""Test deserialization with File objects in the files field - SystemVariable specific logic."""
# Test with empty files list
data_empty = {**VALID_BASE_DATA, "files": []}
system_var_empty = SystemVariable(**data_empty)
assert system_var_empty.files == []
# Test with single File object
test_file = create_test_file()
data_single = {**VALID_BASE_DATA, "files": [test_file]}
system_var_single = SystemVariable(**data_single)
assert len(system_var_single.files) == 1
assert system_var_single.files[0].filename == "test.txt"
assert system_var_single.files[0].tenant_id == "test-tenant-id"
# Test with multiple File objects
file1 = File(
tenant_id="tenant1",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file1",
filename="doc1.txt",
storage_key="key1",
)
file2 = File(
tenant_id="tenant2",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
filename="image.jpg",
storage_key="key2",
)
data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
system_var_multiple = SystemVariable(**data_multiple)
assert len(system_var_multiple.files) == 2
assert system_var_multiple.files[0].filename == "doc1.txt"
assert system_var_multiple.files[1].filename == "image.jpg"
# Verify files field serialization/deserialization
serialized = system_var_multiple.model_dump(mode="json")
deserialized = SystemVariable(**serialized)
assert len(deserialized.files) == 2
assert deserialized.files[0].filename == "doc1.txt"
assert deserialized.files[1].filename == "image.jpg"
def test_alias_serialization_consistency(self):
"""Test that alias handling works consistently in both serialization directions."""
workflow_id = "test-workflow-id"
# Create with workflow_run_id (alias)
data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
system_var = SystemVariable(**data_with_alias)
# Serialize and verify alias is used
serialized = system_var.model_dump()
assert serialized["workflow_run_id"] == workflow_id
assert "workflow_execution_id" not in serialized
# Deserialize and verify field mapping
deserialized = SystemVariable(**serialized)
assert deserialized.workflow_execution_id == workflow_id
# Test JSON serialization path
json_serialized = json.loads(system_var.model_dump_json())
assert json_serialized["workflow_run_id"] == workflow_id
assert "workflow_execution_id" not in json_serialized
json_deserialized = SystemVariable(**json_serialized)
assert json_deserialized.workflow_execution_id == workflow_id
def test_model_validator_serialization_logic(self):
"""Test the custom model validator behavior for serialization scenarios."""
workflow_id = "test-workflow-execution-id"
# Test direct instantiation with workflow_execution_id (should work)
data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
system_var1 = SystemVariable(**data1)
assert system_var1.workflow_execution_id == workflow_id
# Test serialization of the above (should use alias)
serialized1 = system_var1.model_dump()
assert "workflow_run_id" in serialized1
assert serialized1["workflow_run_id"] == workflow_id
# Test both present - workflow_run_id takes precedence (validator logic)
data2 = {
**VALID_BASE_DATA,
"workflow_execution_id": "should-be-removed",
"workflow_run_id": workflow_id,
}
system_var2 = SystemVariable(**data2)
assert system_var2.workflow_execution_id == workflow_id
# Verify serialization consistency
serialized2 = system_var2.model_dump()
assert serialized2["workflow_run_id"] == workflow_id
def test_constructor_with_extra_key():
# Test that SystemVariable should forbid extra keys
with pytest.raises(ValidationError):
# This should fail because there is an unexpected key.
SystemVariable(invalid_key=1) # type: ignore

View File

@ -1,17 +1,43 @@
import uuid
from collections import defaultdict
import pytest
from pydantic import ValidationError
from core.file import File, FileTransferMethod, FileType
from core.variables import FileSegment, StringSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
)
from core.variables.variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
StringVariable,
VariableUnion,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.system_variable import SystemVariable
from factories.variable_factory import build_segment, segment_to_variable
@pytest.fixture
def pool():
return VariablePool(system_variables={}, user_inputs={})
return VariablePool(
system_variables=SystemVariable(user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"),
user_inputs={},
)
@pytest.fixture
@ -52,18 +78,28 @@ def test_use_long_selector(pool):
class TestVariablePool:
def test_constructor(self):
pool = VariablePool()
# Test with minimal required SystemVariable
minimal_system_vars = SystemVariable(
user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
)
pool = VariablePool(system_variables=minimal_system_vars)
# Test with all parameters
pool = VariablePool(
variable_dictionary={},
user_inputs={},
system_variables={},
system_variables=minimal_system_vars,
environment_variables=[],
conversation_variables=[],
)
# Test with more complex SystemVariable
complex_system_vars = SystemVariable(
user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
)
pool = VariablePool(
user_inputs={"key": "value"},
system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"},
system_variables=complex_system_vars,
environment_variables=[
segment_to_variable(
segment=build_segment(1),
@ -80,6 +116,323 @@ class TestVariablePool:
],
)
def test_constructor_with_invalid_system_variable_key(self):
with pytest.raises(ValidationError):
VariablePool(system_variables={"invalid_key": "value"}) # type: ignore
def test_get_system_variables(self):
sys_var = SystemVariable(
user_id="test_user_id",
app_id="test_app_id",
workflow_id="test_workflow_id",
workflow_execution_id="test_execution_123",
query="test query",
conversation_id="test_conv_id",
dialogue_count=5,
)
pool = VariablePool(system_variables=sys_var)
kv = [
("user_id", sys_var.user_id),
("app_id", sys_var.app_id),
("workflow_id", sys_var.workflow_id),
("workflow_run_id", sys_var.workflow_execution_id),
("query", sys_var.query),
("conversation_id", sys_var.conversation_id),
("dialogue_count", sys_var.dialogue_count),
]
for key, expected_value in kv:
segment = pool.get([SYSTEM_VARIABLE_NODE_ID, key])
assert segment is not None
assert segment.value == expected_value
class TestVariablePoolSerialization:
"""Test cases for VariablePool serialization and deserialization using Pydantic's built-in methods.
These tests focus exclusively on serialization/deserialization logic to ensure that
VariablePool data can be properly serialized to dictionaries/JSON and reconstructed
while preserving all data integrity.
"""
_NODE1_ID = "node_1"
_NODE2_ID = "node_2"
_NODE3_ID = "node_3"
def _create_pool_without_file(self):
# Create comprehensive system variables
system_vars = SystemVariable(
user_id="test_user_id",
app_id="test_app_id",
workflow_id="test_workflow_id",
workflow_execution_id="test_execution_123",
query="test query",
conversation_id="test_conv_id",
dialogue_count=5,
)
# Create environment variables with all types including ArrayFileVariable
env_vars: list[VariableUnion] = [
StringVariable(
id="env_string_id",
name="env_string",
value="env_string_value",
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_string"],
),
IntegerVariable(
id="env_integer_id",
name="env_integer",
value=1,
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_integer"],
),
FloatVariable(
id="env_float_id",
name="env_float",
value=1.0,
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_float"],
),
]
# Create conversation variables with complex data
conv_vars: list[VariableUnion] = [
StringVariable(
id="conv_string_id",
name="conv_string",
value="conv_string_value",
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_string"],
),
IntegerVariable(
id="conv_integer_id",
name="conv_integer",
value=1,
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_integer"],
),
FloatVariable(
id="conv_float_id",
name="conv_float",
value=1.0,
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_float"],
),
ObjectVariable(
id="conv_object_id",
name="conv_object",
value={"key": "value", "nested": {"data": 123}},
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_object"],
),
ArrayStringVariable(
id="conv_array_string_id",
name="conv_array_string",
value=["conv_array_string_value"],
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_string"],
),
ArrayNumberVariable(
id="conv_array_number_id",
name="conv_array_number",
value=[1, 1.0],
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_number"],
),
ArrayObjectVariable(
id="conv_array_object_id",
name="conv_array_object",
value=[{"a": 1}, {"b": "2"}],
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_object"],
),
]
# Create comprehensive user inputs
user_inputs = {
"string_input": "test_value",
"number_input": 42,
"object_input": {"nested": {"key": "value"}},
"array_input": ["item1", "item2", "item3"],
}
# Create VariablePool
pool = VariablePool(
system_variables=system_vars,
user_inputs=user_inputs,
environment_variables=env_vars,
conversation_variables=conv_vars,
)
return pool
def _add_node_data_to_pool(self, pool: VariablePool, with_file=False):
test_file = File(
tenant_id="test_tenant_id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_related_id",
remote_url="test_url",
filename="test_file.txt",
storage_key="test_storage_key",
)
# Add various segment types to variable dictionary
pool.add((self._NODE1_ID, "string_var"), StringSegment(value="test_string"))
pool.add((self._NODE1_ID, "int_var"), IntegerSegment(value=123))
pool.add((self._NODE1_ID, "float_var"), FloatSegment(value=45.67))
pool.add((self._NODE1_ID, "object_var"), ObjectSegment(value={"test": "data"}))
if with_file:
pool.add((self._NODE1_ID, "file_var"), FileSegment(value=test_file))
pool.add((self._NODE1_ID, "none_var"), NoneSegment())
# Add array segments including ArrayFileVariable
pool.add((self._NODE2_ID, "array_string"), ArrayStringSegment(value=["a", "b", "c"]))
pool.add((self._NODE2_ID, "array_number"), ArrayNumberSegment(value=[1, 2, 3]))
pool.add((self._NODE2_ID, "array_object"), ArrayObjectSegment(value=[{"a": 1}, {"b": 2}]))
if with_file:
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
# Add nested variables
pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
def test_system_variables(self):
sys_vars = SystemVariable(
user_id="test_user_id",
app_id="test_app_id",
workflow_id="test_workflow_id",
workflow_execution_id="test_execution_123",
query="test query",
conversation_id="test_conv_id",
dialogue_count=5,
)
pool = VariablePool(system_variables=sys_vars)
json = pool.model_dump_json()
pool2 = VariablePool.model_validate_json(json)
assert pool2.system_variables == sys_vars
for mode in ["json", "python"]:
dict_ = pool.model_dump(mode=mode)
pool2 = VariablePool.model_validate(dict_)
assert pool2.system_variables == sys_vars
def test_pool_without_file_vars(self):
pool = self._create_pool_without_file()
json = pool.model_dump_json()
pool2 = pool.model_validate_json(json)
assert pool2.system_variables == pool.system_variables
assert pool2.conversation_variables == pool.conversation_variables
assert pool2.environment_variables == pool.environment_variables
assert pool2.user_inputs == pool.user_inputs
assert pool2.variable_dictionary == pool.variable_dictionary
assert pool2 == pool
def test_basic_dictionary_round_trip(self):
"""Test basic round-trip serialization: model_dump() → model_validate()"""
# Create a comprehensive VariablePool with all data types
original_pool = self._create_pool_without_file()
self._add_node_data_to_pool(original_pool)
# Serialize to dictionary using Pydantic's model_dump()
serialized_data = original_pool.model_dump()
# Verify serialized data structure
assert isinstance(serialized_data, dict)
assert "system_variables" in serialized_data
assert "user_inputs" in serialized_data
assert "environment_variables" in serialized_data
assert "conversation_variables" in serialized_data
assert "variable_dictionary" in serialized_data
# Deserialize back using Pydantic's model_validate()
reconstructed_pool = VariablePool.model_validate(serialized_data)
# Verify data integrity is preserved
self._assert_pools_equal(original_pool, reconstructed_pool)
def test_json_round_trip(self):
"""Test JSON round-trip serialization: model_dump_json() → model_validate_json()"""
# Create a comprehensive VariablePool with all data types
original_pool = self._create_pool_without_file()
self._add_node_data_to_pool(original_pool)
# Serialize to JSON string using Pydantic's model_dump_json()
json_data = original_pool.model_dump_json()
# Verify JSON is valid string
assert isinstance(json_data, str)
assert len(json_data) > 0
# Deserialize back using Pydantic's model_validate_json()
reconstructed_pool = VariablePool.model_validate_json(json_data)
# Verify data integrity is preserved
self._assert_pools_equal(original_pool, reconstructed_pool)
def test_complex_data_serialization(self):
"""Test serialization of complex data structures including ArrayFileVariable"""
original_pool = self._create_pool_without_file()
self._add_node_data_to_pool(original_pool, with_file=True)
# Test dictionary round-trip
dict_data = original_pool.model_dump()
reconstructed_dict = VariablePool.model_validate(dict_data)
# Test JSON round-trip
json_data = original_pool.model_dump_json()
reconstructed_json = VariablePool.model_validate_json(json_data)
# Verify both reconstructed pools are equivalent
self._assert_pools_equal(reconstructed_dict, reconstructed_json)
# TODO: assert the data for file object...
def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None:
"""Assert that two VariablePools contain equivalent data"""
# Compare system variables
assert pool1.system_variables == pool2.system_variables
# Compare user inputs
assert dict(pool1.user_inputs) == dict(pool2.user_inputs)
# Compare environment variables count
assert pool1.environment_variables == pool2.environment_variables
# Compare conversation variables count
assert pool1.conversation_variables == pool2.conversation_variables
# Test key variable retrievals to ensure functionality is preserved
test_selectors = [
(SYSTEM_VARIABLE_NODE_ID, "user_id"),
(SYSTEM_VARIABLE_NODE_ID, "app_id"),
(ENVIRONMENT_VARIABLE_NODE_ID, "env_string"),
(ENVIRONMENT_VARIABLE_NODE_ID, "env_number"),
(CONVERSATION_VARIABLE_NODE_ID, "conv_string"),
(self._NODE1_ID, "string_var"),
(self._NODE1_ID, "int_var"),
(self._NODE1_ID, "float_var"),
(self._NODE2_ID, "array_string"),
(self._NODE2_ID, "array_number"),
(self._NODE3_ID, "nested", "deep", "var"),
]
for selector in test_selectors:
val1 = pool1.get(selector)
val2 = pool2.get(selector)
# Both should exist or both should be None
assert (val1 is None) == (val2 is None)
if val1 is not None and val2 is not None:
# Values should be equal
assert val1.value == val2.value
# Value types should be the same (more important than exact class type)
assert val1.value_type == val2.value_type
def test_variable_pool_deserialization_default_dict(self):
variable_pool = VariablePool(
user_inputs={"a": 1, "b": "2"},
system_variables=SystemVariable(workflow_id=str(uuid.uuid4())),
environment_variables=[
StringVariable(name="str_var", value="a"),
],
conversation_variables=[IntegerVariable(name="int_var", value=1)],
)
assert isinstance(variable_pool.variable_dictionary, defaultdict)
json = variable_pool.model_dump_json()
loaded = VariablePool.model_validate_json(json)
assert isinstance(loaded.variable_dictionary, defaultdict)
loaded.add(["non_exist_node", "a"], 1)
pool_dict = variable_pool.model_dump()
loaded = VariablePool.model_validate(pool_dict)
assert isinstance(loaded.variable_dictionary, defaultdict)
loaded.add(["non_exist_node", "a"], 1)

View File

@ -18,10 +18,10 @@ from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from models.enums import CreatorUserRole
from models.model import AppMode
@ -67,14 +67,14 @@ def real_app_generate_entity():
@pytest.fixture
def real_workflow_system_variables():
return {
SystemVariableKey.QUERY: "test query",
SystemVariableKey.CONVERSATION_ID: "test-conversation-id",
SystemVariableKey.USER_ID: "test-user-id",
SystemVariableKey.APP_ID: "test-app-id",
SystemVariableKey.WORKFLOW_ID: "test-workflow-id",
SystemVariableKey.WORKFLOW_EXECUTION_ID: "test-workflow-run-id",
}
return SystemVariable(
query="test query",
conversation_id="test-conversation-id",
user_id="test-user-id",
app_id="test-app-id",
workflow_id="test-workflow-id",
workflow_execution_id="test-workflow-run-id",
)
@pytest.fixture

View File

@ -10,7 +10,7 @@ class TestAppendVariablesRecursively:
def test_append_simple_dict_value(self):
"""Test appending a simple dictionary value"""
pool = VariablePool()
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["output"]
variable_value = {"name": "John", "age": 30}
@ -33,7 +33,7 @@ class TestAppendVariablesRecursively:
def test_append_object_segment_value(self):
"""Test appending an ObjectSegment value"""
pool = VariablePool()
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["result"]
@ -60,7 +60,7 @@ class TestAppendVariablesRecursively:
def test_append_nested_dict_value(self):
"""Test appending a nested dictionary value"""
pool = VariablePool()
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["data"]
@ -97,7 +97,7 @@ class TestAppendVariablesRecursively:
def test_append_non_dict_value(self):
"""Test appending a non-dictionary value (should not recurse)"""
pool = VariablePool()
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["simple"]
variable_value = "simple_string"
@ -114,7 +114,7 @@ class TestAppendVariablesRecursively:
def test_append_segment_non_object_value(self):
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
pool = VariablePool()
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["text"]
variable_value = StringSegment(value="Hello World")
@ -132,7 +132,7 @@ class TestAppendVariablesRecursively:
def test_append_empty_dict_value(self):
"""Test appending an empty dictionary value"""
pool = VariablePool()
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["empty"]
variable_value: dict[str, Any] = {}

View File

@ -505,8 +505,8 @@ def test_build_segment_type_for_scalar():
size=1000,
)
cases = [
TestCase(0, SegmentType.NUMBER),
TestCase(0.0, SegmentType.NUMBER),
TestCase(0, SegmentType.INTEGER),
TestCase(0.0, SegmentType.FLOAT),
TestCase("", SegmentType.STRING),
TestCase(file, SegmentType.FILE),
]
@ -531,14 +531,14 @@ class TestBuildSegmentWithType:
result = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result, IntegerSegment)
assert result.value == 42
assert result.value_type == SegmentType.NUMBER
assert result.value_type == SegmentType.INTEGER
def test_number_type_float(self):
"""Test building a number segment with float value."""
result = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result, FloatSegment)
assert result.value == 3.14
assert result.value_type == SegmentType.NUMBER
assert result.value_type == SegmentType.FLOAT
def test_object_type(self):
"""Test building an object segment with correct type."""
@ -652,14 +652,14 @@ class TestBuildSegmentWithType:
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, None)
assert "Expected string, but got None" in str(exc_info.value)
assert "expected string, but got None" in str(exc_info.value)
def test_type_mismatch_empty_list_to_non_array(self):
"""Test type mismatch when expecting non-array type but getting empty list."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, [])
assert "Expected string, but got empty list" in str(exc_info.value)
assert "expected string, but got empty list" in str(exc_info.value)
def test_type_mismatch_object_to_array(self):
"""Test type mismatch when expecting array but getting object."""
@ -674,19 +674,19 @@ class TestBuildSegmentWithType:
# Integer should work
result_int = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result_int, IntegerSegment)
assert result_int.value_type == SegmentType.NUMBER
assert result_int.value_type == SegmentType.INTEGER
# Float should work
result_float = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result_float, FloatSegment)
assert result_float.value_type == SegmentType.NUMBER
assert result_float.value_type == SegmentType.FLOAT
@pytest.mark.parametrize(
("segment_type", "value", "expected_class"),
[
(SegmentType.STRING, "test", StringSegment),
(SegmentType.NUMBER, 42, IntegerSegment),
(SegmentType.NUMBER, 3.14, FloatSegment),
(SegmentType.INTEGER, 42, IntegerSegment),
(SegmentType.FLOAT, 3.14, FloatSegment),
(SegmentType.OBJECT, {}, ObjectSegment),
(SegmentType.NONE, None, NoneSegment),
(SegmentType.ARRAY_STRING, [], ArrayStringSegment),
@ -857,5 +857,5 @@ class TestBuildSegmentValueErrors:
# Verify they are processed as integers, not as errors
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
assert true_segment.value_type == SegmentType.NUMBER
assert false_segment.value_type == SegmentType.NUMBER
assert true_segment.value_type == SegmentType.INTEGER
assert false_segment.value_type == SegmentType.INTEGER

View File

@ -0,0 +1,351 @@
import struct
import time
import uuid
from unittest import mock
import pytest
from hypothesis import given
from hypothesis import strategies as st
from libs.uuid_utils import _create_uuidv7_bytes, uuidv7, uuidv7_boundary, uuidv7_timestamp
# Tests for private helper function _create_uuidv7_bytes
def test_create_uuidv7_bytes_basic_structure():
"""Test basic byte structure creation."""
timestamp_ms = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
random_bytes = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x11\x22"
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
# Should be exactly 16 bytes
assert len(result) == 16
assert isinstance(result, bytes)
# Create UUID from bytes to verify it's valid
uuid_obj = uuid.UUID(bytes=result)
assert uuid_obj.version == 7
def test_create_uuidv7_bytes_timestamp_encoding():
"""Test timestamp is correctly encoded in first 48 bits."""
timestamp_ms = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
random_bytes = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
# Extract timestamp from first 6 bytes
timestamp_bytes = b"\x00\x00" + result[0:6]
extracted_timestamp = struct.unpack(">Q", timestamp_bytes)[0]
assert extracted_timestamp == timestamp_ms
def test_create_uuidv7_bytes_version_bits():
"""Test version bits are set to 7."""
timestamp_ms = 1609459200000
random_bytes = b"\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00" # Set first 2 bytes to all 1s
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
# Extract version from bytes 6-7
version_and_rand_a = struct.unpack(">H", result[6:8])[0]
version = (version_and_rand_a >> 12) & 0x0F
assert version == 7
def test_create_uuidv7_bytes_variant_bits():
"""Test variant bits are set correctly."""
timestamp_ms = 1609459200000
random_bytes = b"\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00" # Set byte 8 to all 1s
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
# Check variant bits in byte 8 (should be 10xxxxxx)
variant_byte = result[8]
variant_bits = (variant_byte >> 6) & 0b11
assert variant_bits == 0b10 # Should be binary 10
def test_create_uuidv7_bytes_random_data():
"""Test random bytes are placed correctly."""
timestamp_ms = 1609459200000
random_bytes = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x11\x22"
result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
# Check random data A (12 bits from bytes 6-7, excluding version)
version_and_rand_a = struct.unpack(">H", result[6:8])[0]
rand_a = version_and_rand_a & 0x0FFF
expected_rand_a = struct.unpack(">H", random_bytes[0:2])[0] & 0x0FFF
assert rand_a == expected_rand_a
# Check random data B (bytes 8-15, with variant bits preserved)
# Byte 8 should have variant bits set but preserve lower 6 bits
expected_byte_8 = (random_bytes[2] & 0x3F) | 0x80
assert result[8] == expected_byte_8
# Bytes 9-15 should match random_bytes[3:10]
assert result[9:16] == random_bytes[3:10]
def test_create_uuidv7_bytes_zero_random():
"""Test with zero random bytes (boundary case)."""
timestamp_ms = 1609459200000
zero_random_bytes = b"\x00" * 10
result = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes)
# Should still be valid UUIDv7
uuid_obj = uuid.UUID(bytes=result)
assert uuid_obj.version == 7
# Version bits should be 0x7000
version_and_rand_a = struct.unpack(">H", result[6:8])[0]
assert version_and_rand_a == 0x7000
# Variant byte should be 0x80 (variant bits + zero random bits)
assert result[8] == 0x80
# Remaining bytes should be zero
assert result[9:16] == b"\x00" * 7
def test_uuidv7_basic_generation():
"""Test basic UUID generation produces valid UUIDv7."""
result = uuidv7()
# Should be a UUID object
assert isinstance(result, uuid.UUID)
# Should be version 7
assert result.version == 7
# Should have correct variant (RFC 4122 variant)
# Variant bits should be 10xxxxxx (0x80-0xBF range)
variant_byte = result.bytes[8]
assert (variant_byte >> 6) == 0b10
def test_uuidv7_with_custom_timestamp():
"""Test UUID generation with custom timestamp."""
custom_timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
result = uuidv7(custom_timestamp)
assert isinstance(result, uuid.UUID)
assert result.version == 7
# Extract and verify timestamp
extracted_timestamp = uuidv7_timestamp(result)
assert isinstance(extracted_timestamp, int)
assert extracted_timestamp == custom_timestamp # Exact match for integer milliseconds
def test_uuidv7_with_none_timestamp(monkeypatch):
"""Test UUID generation with None timestamp uses current time."""
mock_time = 1609459200
mock_time_func = mock.Mock(return_value=mock_time)
monkeypatch.setattr("time.time", mock_time_func)
result = uuidv7(None)
assert isinstance(result, uuid.UUID)
assert result.version == 7
# Should use the mocked current time (converted to milliseconds)
assert mock_time_func.called
extracted_timestamp = uuidv7_timestamp(result)
assert extracted_timestamp == mock_time * 1000 # 1609459200.0 * 1000
def test_uuidv7_time_ordering():
"""Test that sequential UUIDs have increasing timestamps."""
# Generate UUIDs with incrementing timestamps (in milliseconds)
timestamp1 = 1609459200000 # 2021-01-01 00:00:00 UTC
timestamp2 = 1609459201000 # 2021-01-01 00:00:01 UTC
timestamp3 = 1609459202000 # 2021-01-01 00:00:02 UTC
uuid1 = uuidv7(timestamp1)
uuid2 = uuidv7(timestamp2)
uuid3 = uuidv7(timestamp3)
# Extract timestamps
ts1 = uuidv7_timestamp(uuid1)
ts2 = uuidv7_timestamp(uuid2)
ts3 = uuidv7_timestamp(uuid3)
# Should be in ascending order
assert ts1 < ts2 < ts3
# UUIDs should be lexicographically ordered by their string representation
# due to time-ordering property of UUIDv7
uuid_strings = [str(uuid1), str(uuid2), str(uuid3)]
assert uuid_strings == sorted(uuid_strings)
def test_uuidv7_uniqueness():
"""Test that multiple calls generate different UUIDs."""
# Generate multiple UUIDs with the same timestamp (in milliseconds)
timestamp = 1609459200000
uuids = [uuidv7(timestamp) for _ in range(100)]
# All should be unique despite same timestamp (due to random bits)
assert len(set(uuids)) == 100
# All should have the same extracted timestamp
for uuid_obj in uuids:
extracted_ts = uuidv7_timestamp(uuid_obj)
assert extracted_ts == timestamp
def test_uuidv7_timestamp_error_handling_wrong_version():
"""Test error handling for non-UUIDv7 inputs."""
uuid_v4 = uuid.uuid4()
with pytest.raises(ValueError) as exc_ctx:
uuidv7_timestamp(uuid_v4)
assert "Expected UUIDv7 (version 7)" in str(exc_ctx.value)
assert f"got version {uuid_v4.version}" in str(exc_ctx.value)
@given(st.integers(max_value=2**48 - 1, min_value=0))
def test_uuidv7_timestamp_round_trip(timestamp_ms):
# Generate UUID with timestamp
uuid_obj = uuidv7(timestamp_ms)
# Extract timestamp back
extracted_timestamp = uuidv7_timestamp(uuid_obj)
# Should match exactly for integer millisecond timestamps
assert extracted_timestamp == timestamp_ms
def test_uuidv7_timestamp_edge_cases():
"""Test timestamp extraction with edge case values."""
# Test with very small timestamp
small_timestamp = 1 # 1ms after epoch
uuid_small = uuidv7(small_timestamp)
extracted_small = uuidv7_timestamp(uuid_small)
assert extracted_small == small_timestamp
# Test with large timestamp (year 2038+)
large_timestamp = 2147483647000 # 2038-01-19 03:14:07 UTC in milliseconds
uuid_large = uuidv7(large_timestamp)
extracted_large = uuidv7_timestamp(uuid_large)
assert extracted_large == large_timestamp
def test_uuidv7_boundary_basic_generation():
"""Test basic boundary UUID generation with a known timestamp."""
timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
result = uuidv7_boundary(timestamp)
# Should be a UUID object
assert isinstance(result, uuid.UUID)
# Should be version 7
assert result.version == 7
# Should have correct variant (RFC 4122 variant)
# Variant bits should be 10xxxxxx (0x80-0xBF range)
variant_byte = result.bytes[8]
assert (variant_byte >> 6) == 0b10
def test_uuidv7_boundary_timestamp_extraction():
"""Test that boundary UUID timestamp can be extracted correctly."""
timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
boundary_uuid = uuidv7_boundary(timestamp)
# Extract timestamp using existing function
extracted_timestamp = uuidv7_timestamp(boundary_uuid)
# Should match exactly
assert extracted_timestamp == timestamp
def test_uuidv7_boundary_deterministic():
"""Test that boundary UUIDs are deterministic for same timestamp."""
timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
# Generate multiple boundary UUIDs with same timestamp
uuid1 = uuidv7_boundary(timestamp)
uuid2 = uuidv7_boundary(timestamp)
uuid3 = uuidv7_boundary(timestamp)
# Should all be identical
assert uuid1 == uuid2 == uuid3
assert str(uuid1) == str(uuid2) == str(uuid3)
def test_uuidv7_boundary_is_minimum():
"""Test that boundary UUID is lexicographically smaller than regular UUIDs."""
timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
# Generate boundary UUID
boundary_uuid = uuidv7_boundary(timestamp)
# Generate multiple regular UUIDs with same timestamp
regular_uuids = [uuidv7(timestamp) for _ in range(50)]
# Boundary UUID should be lexicographically smaller than all regular UUIDs
boundary_str = str(boundary_uuid)
for regular_uuid in regular_uuids:
regular_str = str(regular_uuid)
assert boundary_str < regular_str, f"Boundary {boundary_str} should be < regular {regular_str}"
# Also test with bytes comparison
boundary_bytes = boundary_uuid.bytes
for regular_uuid in regular_uuids:
regular_bytes = regular_uuid.bytes
assert boundary_bytes < regular_bytes
def test_uuidv7_boundary_different_timestamps():
"""Test that boundary UUIDs with different timestamps are ordered correctly."""
timestamp1 = 1609459200000 # 2021-01-01 00:00:00 UTC
timestamp2 = 1609459201000 # 2021-01-01 00:00:01 UTC
timestamp3 = 1609459202000 # 2021-01-01 00:00:02 UTC
uuid1 = uuidv7_boundary(timestamp1)
uuid2 = uuidv7_boundary(timestamp2)
uuid3 = uuidv7_boundary(timestamp3)
# Extract timestamps to verify
ts1 = uuidv7_timestamp(uuid1)
ts2 = uuidv7_timestamp(uuid2)
ts3 = uuidv7_timestamp(uuid3)
# Should be in ascending order
assert ts1 < ts2 < ts3
# UUIDs should be lexicographically ordered
uuid_strings = [str(uuid1), str(uuid2), str(uuid3)]
assert uuid_strings == sorted(uuid_strings)
# Bytes should also be ordered
assert uuid1.bytes < uuid2.bytes < uuid3.bytes
def test_uuidv7_boundary_edge_cases():
"""Test boundary UUID generation with edge case timestamp values."""
# Test with timestamp 0 (Unix epoch)
epoch_uuid = uuidv7_boundary(0)
assert isinstance(epoch_uuid, uuid.UUID)
assert epoch_uuid.version == 7
assert uuidv7_timestamp(epoch_uuid) == 0
# Test with very large timestamp values
large_timestamp = 2147483647000 # 2038-01-19 03:14:07 UTC in milliseconds
large_uuid = uuidv7_boundary(large_timestamp)
assert isinstance(large_uuid, uuid.UUID)
assert large_uuid.version == 7
assert uuidv7_timestamp(large_uuid) == large_timestamp
# Test with current time
current_time = int(time.time() * 1000)
current_uuid = uuidv7_boundary(current_time)
assert isinstance(current_uuid, uuid.UUID)
assert current_uuid.version == 7
assert uuidv7_timestamp(current_uuid) == current_time

View File

@ -214,6 +214,10 @@ SQLALCHEMY_POOL_SIZE=30
SQLALCHEMY_POOL_RECYCLE=3600
# Whether to print SQL, default is false.
SQLALCHEMY_ECHO=false
# If True, will test connections for liveness upon each checkout
SQLALCHEMY_POOL_PRE_PING=false
# Whether to enable the Last in first out option or use default FIFO queue if is false
SQLALCHEMY_POOL_USE_LIFO=false
# Maximum number of connections to the database
# Default is 100

View File

@ -56,6 +56,8 @@ x-shared-env: &shared-api-worker-env
SQLALCHEMY_POOL_SIZE: ${SQLALCHEMY_POOL_SIZE:-30}
SQLALCHEMY_POOL_RECYCLE: ${SQLALCHEMY_POOL_RECYCLE:-3600}
SQLALCHEMY_ECHO: ${SQLALCHEMY_ECHO:-false}
SQLALCHEMY_POOL_PRE_PING: ${SQLALCHEMY_POOL_PRE_PING:-false}
SQLALCHEMY_POOL_USE_LIFO: ${SQLALCHEMY_POOL_USE_LIFO:-false}
POSTGRES_MAX_CONNECTIONS: ${POSTGRES_MAX_CONNECTIONS:-100}
POSTGRES_SHARED_BUFFERS: ${POSTGRES_SHARED_BUFFERS:-128MB}
POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB}

View File

@ -1,5 +1,3 @@
'use client'
import WorkflowApp from '@/app/components/workflow-app'
const Page = () => {

View File

@ -1,3 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8 4V8M8 8V12M8 8H12M8 8H4" stroke="#6B7280" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 206 B

View File

@ -1,4 +0,0 @@
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M0.631586 8.25C0.631586 6.46656 2.04586 5 3.8158 5C5.58573 5 7.00001 6.46656 7.00001 8.25C7.00001 10.0334 5.58573 11.5 3.8158 11.5C3.45197 11.5 3.10149 11.4375 2.77474 11.3222C2.72073 11.3031 2.68723 11.2913 2.66266 11.2832C2.65821 11.2817 2.65456 11.2806 2.65164 11.2796L2.64892 11.2799C2.63177 11.2818 2.60839 11.285 2.56507 11.2909L1.06766 11.4954C0.905637 11.5175 0.743029 11.459 0.632239 11.3387C0.521449 11.2185 0.476481 11.0516 0.511825 10.8919L0.817497 9.51109C0.828118 9.46311 0.833802 9.43722 0.837453 9.41817C0.83766 9.4171 0.838022 9.41517 0.838022 9.41517C0.837114 9.412 0.835963 9.40808 0.834525 9.40332C0.826292 9.37605 0.814183 9.33888 0.794499 9.27863C0.688657 8.95463 0.631586 8.60857 0.631586 8.25Z" fill="#98A2B3"/>
<path d="M2.57377 4.1863C2.96256 4.06535 3.37698 4 3.80894 4C6.16566 4 8.00006 5.94534 8.00006 8.24999C8.00006 8.65682 7.9429 9.05245 7.8358 9.42816C8.10681 9.37948 8.36964 9.30678 8.6219 9.21229C8.65748 9.19897 8.69298 9.18534 8.72893 9.17304C8.75795 9.17641 8.78684 9.18093 8.81574 9.18517L10.4222 9.42065C10.498 9.43179 10.5841 9.44444 10.6591 9.4487C10.7422 9.45343 10.8713 9.45292 11.0081 9.39408C11.1789 9.32061 11.3164 9.18628 11.3938 9.01716C11.4558 8.88174 11.4593 8.75269 11.4564 8.66955C11.4539 8.59442 11.4433 8.5081 11.4339 8.43202L11.2309 6.78307C11.2256 6.7402 11.2229 6.71768 11.2213 6.70118C11.23 6.66505 11.2466 6.6301 11.2598 6.59546C11.4492 6.09896 11.5526 5.56093 11.5526 5C11.5526 2.51163 9.52304 0.5 7.02632 0.5C4.80843 0.5 2.95915 2.08742 2.57377 4.1863Z" fill="#98A2B3"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.6 KiB

View File

@ -1,3 +0,0 @@
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M14.1667 6.66634H15.8333C16.2754 6.66634 16.6993 6.84194 17.0118 7.1545C17.3244 7.46706 17.5 7.89098 17.5 8.33301V13.333C17.5 13.775 17.3244 14.199 17.0118 14.5115C16.6993 14.8241 16.2754 14.9997 15.8333 14.9997H14.1667V18.333L10.8333 14.9997H7.5C7.28111 14.9999 7.06433 14.9569 6.86211 14.8731C6.6599 14.7893 6.47623 14.6663 6.32167 14.5113M6.32167 14.5113L9.16667 11.6663H12.5C12.942 11.6663 13.366 11.4907 13.6785 11.1782C13.9911 10.8656 14.1667 10.4417 14.1667 9.99967V4.99967C14.1667 4.55765 13.9911 4.13372 13.6785 3.82116C13.366 3.5086 12.942 3.33301 12.5 3.33301H4.16667C3.72464 3.33301 3.30072 3.5086 2.98816 3.82116C2.67559 4.13372 2.5 4.55765 2.5 4.99967V9.99967C2.5 10.4417 2.67559 10.8656 2.98816 11.1782C3.30072 11.4907 3.72464 11.6663 4.16667 11.6663H5.83333V14.9997L6.32167 14.5113Z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 1002 B

View File

@ -1,4 +0,0 @@
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M6.5 1.00779C6.5 0.994638 6.5 0.988062 6.49943 0.976137C6.48764 0.729248 6.27052 0.51224 6.02363 0.50056C6.01171 0.499996 6.0078 0.499998 6.00001 0.5H4.37933C3.97686 0.499995 3.64468 0.49999 3.37409 0.522098C3.09304 0.545061 2.83469 0.594343 2.59202 0.717989C2.2157 0.909735 1.90973 1.2157 1.71799 1.59202C1.59434 1.83469 1.54506 2.09304 1.5221 2.37409C1.49999 2.64468 1.49999 2.97686 1.5 3.37934V8.62066C1.49999 9.02313 1.49999 9.35532 1.5221 9.62591C1.54506 9.90696 1.59434 10.1653 1.71799 10.408C1.90973 10.7843 2.2157 11.0903 2.59202 11.282C2.83469 11.4057 3.09304 11.4549 3.37409 11.4779C3.64468 11.5 3.97686 11.5 4.37934 11.5H7.62066C8.02314 11.5 8.35532 11.5 8.62591 11.4779C8.90696 11.4549 9.16531 11.4057 9.40798 11.282C9.78431 11.0903 10.0903 10.7843 10.282 10.408C10.4057 10.1653 10.4549 9.90696 10.4779 9.62591C10.5 9.35532 10.5 9.02314 10.5 8.62066V4.99997C10.5 4.9922 10.5 4.98832 10.4994 4.97641C10.4878 4.72949 10.2707 4.51236 10.0238 4.50057C10.0119 4.50001 10.0054 4.50001 9.99225 4.50001L7.78404 4.50001C7.65786 4.50002 7.53496 4.50004 7.43089 4.49153C7.31659 4.48219 7.18172 4.46016 7.04601 4.39101C6.85785 4.29514 6.70487 4.14216 6.609 3.954C6.53985 3.81828 6.51781 3.68342 6.50848 3.56912C6.49997 3.46504 6.49999 3.34215 6.5 3.21596L6.5 1.00779ZM4 6.5C3.72386 6.5 3.5 6.72386 3.5 7C3.5 7.27614 3.72386 7.5 4 7.5H8C8.27614 7.5 8.5 7.27614 8.5 7C8.5 6.72386 8.27614 6.5 8 6.5H4ZM4 8.5C3.72386 8.5 3.5 8.72386 3.5 9C3.5 9.27614 3.72386 9.5 4 9.5H7C7.27614 9.5 7.5 9.27614 7.5 9C7.5 8.72386 7.27614 8.5 7 8.5H4Z" fill="#98A2B3"/>
<path d="M9.45398 3.5C9.60079 3.5 9.67419 3.5 9.73432 3.46314C9.81925 3.41107 9.87002 3.28842 9.84674 3.19157C9.83025 3.12299 9.78238 3.07516 9.68665 2.97952L8.02049 1.31336C7.92484 1.21762 7.87701 1.16975 7.80843 1.15326C7.71158 1.12998 7.58893 1.18075 7.53687 1.26567C7.5 1.3258 7.5 1.39921 7.5 1.54602L7.5 3.09998C7.5 3.23999 7.5 3.30999 7.52725 3.36347C7.55122 3.41051 7.58946 3.44876 7.6365 3.47272C7.68998 3.49997 7.75998 3.49997 7.9 3.49998L9.45398 3.5Z" fill="#98A2B3"/>
</svg>

Before

Width:  |  Height:  |  Size: 2.1 KiB

View File

@ -1,3 +0,0 @@
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M16.25 11.875V9.6875C16.25 8.1342 14.9908 6.875 13.4375 6.875H12.1875C11.6697 6.875 11.25 6.45527 11.25 5.9375V4.6875C11.25 3.1342 9.9908 1.875 8.4375 1.875H6.875M6.875 12.5H13.125M6.875 15H10M8.75 1.875H4.6875C4.16973 1.875 3.75 2.29473 3.75 2.8125V17.1875C3.75 17.7053 4.16973 18.125 4.6875 18.125H15.3125C15.8303 18.125 16.25 17.7053 16.25 17.1875V9.375C16.25 5.23286 12.8921 1.875 8.75 1.875Z" stroke="#1F2A37" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 595 B

View File

@ -1,3 +0,0 @@
<svg width="26" height="26" viewBox="0 0 26 26" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M22.0101 4.50191C20.3529 3.74154 18.5759 3.18133 16.7179 2.86048C16.6841 2.85428 16.6503 2.86976 16.6328 2.90071C16.4043 3.30719 16.1511 3.83748 15.9738 4.25429C13.9754 3.95511 11.9873 3.95511 10.0298 4.25429C9.85253 3.82822 9.59019 3.30719 9.36062 2.90071C9.34319 2.87079 9.30939 2.85532 9.27555 2.86048C7.41857 3.18031 5.64152 3.74051 3.98335 4.50191C3.96899 4.5081 3.95669 4.51843 3.94852 4.53183C0.577841 9.56755 -0.345529 14.4795 0.107445 19.3306C0.109495 19.3543 0.122817 19.377 0.141265 19.3914C2.36514 21.0246 4.51935 22.0161 6.63355 22.6732C6.66739 22.6836 6.70324 22.6712 6.72477 22.6433C7.22489 21.9604 7.6707 21.2402 8.05293 20.4829C8.07549 20.4386 8.05396 20.386 8.00785 20.3684C7.30073 20.1002 6.6274 19.7731 5.97971 19.4017C5.92848 19.3718 5.92437 19.2985 5.9715 19.2635C6.1078 19.1613 6.24414 19.0551 6.37428 18.9478C6.39783 18.9282 6.43064 18.924 6.45833 18.9364C10.7134 20.8791 15.32 20.8791 19.5249 18.9364C19.5525 18.923 19.5854 18.9272 19.6099 18.9467C19.7401 19.054 19.8764 19.1613 20.0137 19.2635C20.0609 19.2985 20.0578 19.3718 20.0066 19.4017C19.3589 19.7804 18.6855 20.1002 17.9774 20.3674C17.9313 20.3849 17.9108 20.4386 17.9333 20.4829C18.3238 21.2392 18.7696 21.9593 19.2605 22.6423C19.281 22.6712 19.3179 22.6836 19.3517 22.6732C21.4761 22.0161 23.6303 21.0246 25.8542 19.3914C25.8737 19.377 25.886 19.3553 25.8881 19.3316C26.4302 13.7232 24.98 8.85156 22.0439 4.53286C22.0367 4.51843 22.0245 4.5081 22.0101 4.50191ZM8.68836 16.3768C7.40729 16.3768 6.35173 15.2007 6.35173 13.7563C6.35173 12.3119 7.38682 11.1358 8.68836 11.1358C10.0001 11.1358 11.0455 12.3222 11.025 13.7563C11.025 15.2007 9.98986 16.3768 8.68836 16.3768ZM17.3276 16.3768C16.0466 16.3768 14.991 15.2007 14.991 13.7563C14.991 12.3119 16.0261 11.1358 17.3276 11.1358C18.6394 11.1358 19.6847 12.3222 19.6643 13.7563C19.6643 15.2007 18.6394 16.3768 17.3276 16.3768Z" fill="#5865F2"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.9 KiB

View File

@ -1,17 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_131_1011)">
<path fill-rule="evenodd" clip-rule="evenodd" d="M12.0003 0.5C9.15149 0.501478 6.39613 1.51046 4.22687 3.34652C2.05761 5.18259 0.615903 7.72601 0.159545 10.522C-0.296814 13.318 0.261927 16.1842 1.73587 18.6082C3.20981 21.0321 5.50284 22.8558 8.20493 23.753C8.80105 23.8636 9.0256 23.4941 9.0256 23.18C9.0256 22.8658 9.01367 21.955 9.0097 20.9592C5.6714 21.6804 4.96599 19.5505 4.96599 19.5505C4.42152 18.1674 3.63464 17.8039 3.63464 17.8039C2.54571 17.065 3.71611 17.0788 3.71611 17.0788C4.92227 17.1637 5.55616 18.3097 5.55616 18.3097C6.62521 20.1333 8.36389 19.6058 9.04745 19.2976C9.15475 18.5251 9.46673 17.9995 9.8105 17.7012C7.14383 17.4008 4.34204 16.3774 4.34204 11.8054C4.32551 10.6197 4.76802 9.47305 5.57801 8.60268C5.45481 8.30236 5.04348 7.08923 5.69524 5.44143C5.69524 5.44143 6.7027 5.12135 8.9958 6.66444C10.9627 6.12962 13.0379 6.12962 15.0047 6.66444C17.2958 5.12135 18.3013 5.44143 18.3013 5.44143C18.9551 7.08528 18.5437 8.29841 18.4205 8.60268C19.2331 9.47319 19.6765 10.6218 19.6585 11.8094C19.6585 16.3912 16.8507 17.4008 14.1801 17.6952C14.6093 18.0667 14.9928 18.7918 14.9928 19.9061C14.9928 21.5026 14.9789 22.7868 14.9789 23.18C14.9789 23.4981 15.1955 23.8695 15.8035 23.753C18.5059 22.8557 20.7992 21.0317 22.2731 18.6073C23.747 16.183 24.3055 13.3163 23.8486 10.5201C23.3917 7.7238 21.9493 5.18035 19.7793 3.34461C17.6093 1.50886 14.8533 0.500541 12.0042 0.5H12.0003Z" fill="#191717"/>
<path d="M4.54444 17.6321C4.5186 17.6914 4.42322 17.7092 4.34573 17.6677C4.26823 17.6262 4.21061 17.5491 4.23843 17.4879C4.26625 17.4266 4.35964 17.4108 4.43714 17.4523C4.51463 17.4938 4.57424 17.5729 4.54444 17.6321Z" fill="#191717"/>
<path d="M5.03123 18.1714C4.99008 18.192 4.943 18.1978 4.89805 18.1877C4.8531 18.1776 4.81308 18.1523 4.78483 18.1161C4.70734 18.0331 4.69143 17.9185 4.75104 17.8671C4.81066 17.8157 4.91797 17.8395 4.99546 17.9224C5.07296 18.0054 5.09084 18.12 5.03123 18.1714Z" fill="#191717"/>
<path d="M5.50425 18.857C5.43072 18.9084 5.30553 18.857 5.23598 18.7543C5.21675 18.7359 5.20146 18.7138 5.19101 18.6893C5.18056 18.6649 5.17517 18.6386 5.17517 18.612C5.17517 18.5855 5.18056 18.5592 5.19101 18.5347C5.20146 18.5103 5.21675 18.4882 5.23598 18.4698C5.3095 18.4204 5.4347 18.4698 5.50425 18.5705C5.57379 18.6713 5.57578 18.8057 5.50425 18.857V18.857Z" fill="#191717"/>
<path d="M6.14612 19.5207C6.08054 19.5939 5.94741 19.5741 5.83812 19.4753C5.72883 19.3765 5.70299 19.2422 5.76857 19.171C5.83414 19.0999 5.96727 19.1197 6.08054 19.2165C6.1938 19.3133 6.21566 19.4496 6.14612 19.5207V19.5207Z" fill="#191717"/>
<path d="M7.04617 19.9081C7.01637 20.001 6.88124 20.0425 6.74612 20.003C6.611 19.9635 6.52158 19.8528 6.54741 19.758C6.57325 19.6631 6.71036 19.6197 6.84747 19.6631C6.98457 19.7066 7.07201 19.8113 7.04617 19.9081Z" fill="#191717"/>
<path d="M8.02783 19.9752C8.02783 20.072 7.91656 20.155 7.77349 20.1569C7.63042 20.1589 7.51318 20.0799 7.51318 19.9831C7.51318 19.8863 7.62445 19.8033 7.76752 19.8013C7.91059 19.7993 8.02783 19.8764 8.02783 19.9752Z" fill="#191717"/>
<path d="M8.9419 19.8232C8.95978 19.92 8.86042 20.0207 8.71735 20.0445C8.57428 20.0682 8.4491 20.0109 8.43121 19.916C8.41333 19.8212 8.51666 19.7185 8.65576 19.6928C8.79485 19.6671 8.92401 19.7264 8.9419 19.8232Z" fill="#191717"/>
</g>
<defs>
<clipPath id="clip0_131_1011">
<rect width="24" height="24" fill="white"/>
</clipPath>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 3.4 KiB

View File

@ -1,3 +0,0 @@
<svg width="13" height="14" viewBox="0 0 13 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5.41663 3.75033H3.24996C2.96264 3.75033 2.68709 3.86446 2.48393 4.06763C2.28076 4.27079 2.16663 4.54634 2.16663 4.83366V10.2503C2.16663 10.5376 2.28076 10.8132 2.48393 11.0164C2.68709 11.2195 2.96264 11.3337 3.24996 11.3337H8.66663C8.95394 11.3337 9.22949 11.2195 9.43266 11.0164C9.63582 10.8132 9.74996 10.5376 9.74996 10.2503V8.08366M7.58329 2.66699H10.8333M10.8333 2.66699V5.91699M10.8333 2.66699L5.41663 8.08366" stroke="#9CA3AF" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 596 B

View File

@ -1,3 +0,0 @@
<svg width="13" height="14" viewBox="0 0 13 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5.41663 3.75008H3.24996C2.96264 3.75008 2.68709 3.86422 2.48393 4.06738C2.28076 4.27055 2.16663 4.5461 2.16663 4.83341V10.2501C2.16663 10.5374 2.28076 10.8129 2.48393 11.0161C2.68709 11.2193 2.96264 11.3334 3.24996 11.3334H8.66663C8.95394 11.3334 9.22949 11.2193 9.43266 11.0161C9.63582 10.8129 9.74996 10.5374 9.74996 10.2501V8.08341M7.58329 2.66675H10.8333M10.8333 2.66675V5.91675M10.8333 2.66675L5.41663 8.08341" stroke="#1C64F2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 595 B

View File

@ -1,3 +0,0 @@
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M7 2.5L10.5 6M10.5 6L7 9.5M10.5 6H1.5" stroke="#1C64F2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 217 B

View File

@ -1,12 +0,0 @@
'use client'
import useDocumentTitle from '@/hooks/use-document-title'
import { useTranslation } from 'react-i18next'
export default function DatasetsLayout({ children }: { children: React.ReactNode }) {
const { t } = useTranslation()
useDocumentTitle(t('common.menus.apps'))
return (<>
{children}
</>)
}

View File

@ -1,32 +1,8 @@
'use client'
import { useTranslation } from 'react-i18next'
import { RiDiscordFill, RiGithubFill } from '@remixicon/react'
import Link from 'next/link'
import style from '../list.module.css'
import Apps from './Apps'
import { useEducationInit } from '@/app/education-apply/hooks'
import { useGlobalPublicStore } from '@/context/global-public-context'
import Apps from '@/app/components/apps'
const AppList = () => {
const { t } = useTranslation()
useEducationInit()
const { systemFeatures } = useGlobalPublicStore()
return (
<div className='relative flex h-0 shrink-0 grow flex-col overflow-y-auto bg-background-body'>
<Apps />
{!systemFeatures.branding.enabled && <footer className='shrink-0 grow-0 px-12 py-6'>
<h3 className='text-gradient text-xl font-semibold leading-tight'>{t('app.join')}</h3>
<p className='system-sm-regular mt-1 text-text-tertiary'>{t('app.communityIntro')}</p>
<div className='mt-3 flex items-center gap-2'>
<Link className={style.socialMediaLink} target='_blank' rel='noopener noreferrer' href='https://github.com/langgenius/dify'>
<RiGithubFill className='h-5 w-5 text-text-tertiary' />
</Link>
<Link className={style.socialMediaLink} target='_blank' rel='noopener noreferrer' href='https://discord.gg/FngNHpbcY7'>
<RiDiscordFill className='h-5 w-5 text-text-tertiary' />
</Link>
</div>
</footer>}
</div >
<Apps />
)
}

View File

@ -62,7 +62,6 @@ const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => {
<Tooltip
position='right'
noDecoration
needsDelay
popupContent={
<LinkedAppsPanel
relatedApps={relatedApps.data}
@ -87,7 +86,6 @@ const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => {
<Tooltip
position='right'
noDecoration
needsDelay
popupContent={
<div className='w-[240px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-4'>
<div className='inline-flex rounded-lg border-[0.5px] border-components-panel-border-subtle bg-background-default-subtle p-2'>

View File

@ -1,217 +0,0 @@
.listItem {
@apply col-span-1 bg-white border-2 border-solid border-transparent rounded-xl shadow-xs min-h-[160px] flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg;
}
.listItem.newItemCard {
@apply outline outline-1 outline-gray-200 -outline-offset-1 hover:shadow-sm hover:bg-white;
background-color: rgba(229, 231, 235, 0.5);
}
.listItem.selectable {
@apply relative bg-gray-50 outline outline-1 outline-gray-200 -outline-offset-1 shadow-none hover:bg-none hover:shadow-none hover:outline-primary-200 transition-colors;
}
.listItem.selectable * {
@apply relative;
}
.listItem.selectable::before {
content: "";
@apply absolute top-0 left-0 block w-full h-full rounded-lg pointer-events-none opacity-0 transition-opacity duration-200 ease-in-out hover:opacity-100;
background: linear-gradient(0deg,
rgba(235, 245, 255, 0.5),
rgba(235, 245, 255, 0.5)),
#ffffff;
}
.listItem.selectable:hover::before {
@apply opacity-100;
}
.listItem.selected {
@apply border-primary-600 hover:border-primary-600 border-2;
}
.listItem.selected::before {
@apply opacity-100;
}
.appIcon {
@apply flex items-center justify-center w-8 h-8 bg-pink-100 rounded-lg grow-0 shrink-0;
}
.appIcon.medium {
@apply w-9 h-9;
}
.appIcon.large {
@apply w-10 h-10;
}
.newItemIcon {
@apply flex items-center justify-center w-8 h-8 transition-colors duration-200 ease-in-out border border-gray-200 rounded-lg hover:bg-white grow-0 shrink-0;
}
.listItem:hover .newItemIcon {
@apply bg-gray-50 border-primary-100;
}
.newItemCard .newItemIcon {
@apply bg-gray-100;
}
.newItemCard:hover .newItemIcon {
@apply bg-white;
}
.selectable .newItemIcon {
@apply bg-gray-50;
}
.selectable:hover .newItemIcon {
@apply bg-primary-50;
}
.newItemIconImage {
@apply grow-0 shrink-0 block w-4 h-4 bg-center bg-contain transition-colors duration-200 ease-in-out;
color: #1f2a37;
}
.listItem:hover .newIconImage {
@apply text-primary-600;
}
.newItemIconAdd {
background-image: url("./apps/assets/add.svg");
}
/* .newItemIconChat {
background-image: url("~@/app/components/base/icons/assets/public/header-nav/studio/Robot.svg");
}
.selected .newItemIconChat {
background-image: url("~@/app/components/base/icons/assets/public/header-nav/studio/Robot-Active.svg");
} */
.newItemIconComplete {
background-image: url("./apps/assets/completion.svg");
}
.listItemTitle {
@apply flex pt-[14px] px-[14px] pb-3 h-[66px] items-center gap-3 grow-0 shrink-0;
}
.listItemHeading {
@apply relative h-8 text-sm font-medium leading-8 grow;
}
.listItemHeadingContent {
@apply absolute top-0 left-0 w-full h-full overflow-hidden text-ellipsis whitespace-nowrap;
}
.actionIconWrapper {
@apply hidden h-8 w-8 p-2 rounded-md border-none hover:bg-gray-100 !important;
}
.listItem:hover .actionIconWrapper {
@apply !inline-flex;
}
.deleteDatasetIcon {
@apply hidden grow-0 shrink-0 basis-8 w-8 h-8 rounded-lg transition-colors duration-200 ease-in-out bg-white border border-gray-200 hover:bg-gray-100 bg-center bg-no-repeat;
background-size: 16px;
background-image: url('~@/assets/delete.svg');
}
.listItem:hover .deleteDatasetIcon {
@apply block;
}
.listItemDescription {
@apply mb-3 px-[14px] h-9 text-xs leading-normal text-gray-500 line-clamp-2;
}
.listItemDescription.noClip {
@apply line-clamp-none;
}
.listItemFooter {
@apply flex items-center flex-wrap min-h-[42px] px-[14px] pt-2 pb-[10px];
}
.listItemFooter.datasetCardFooter {
@apply flex items-center gap-4 text-xs text-gray-500;
}
.listItemStats {
@apply flex items-center gap-1;
}
.listItemFooterIcon {
@apply block w-3 h-3 bg-center bg-contain;
}
.solidChatIcon {
background-image: url("./apps/assets/chat-solid.svg");
}
.solidCompletionIcon {
background-image: url("./apps/assets/completion-solid.svg");
}
.newItemCardHeading {
@apply transition-colors duration-200 ease-in-out;
}
.listItem:hover .newItemCardHeading {
@apply text-primary-600;
}
.listItemLink {
@apply inline-flex items-center gap-1 text-xs text-gray-400 transition-colors duration-200 ease-in-out;
}
.listItem:hover .listItemLink {
@apply text-primary-600;
}
.linkIcon {
@apply block w-[13px] h-[13px] bg-center bg-contain;
background-image: url("./apps/assets/link.svg");
}
.linkIcon.grayLinkIcon {
background-image: url("./apps/assets/link-gray.svg");
}
.listItem:hover .grayLinkIcon {
background-image: url("./apps/assets/link.svg");
}
.rightIcon {
@apply block w-[13px] h-[13px] bg-center bg-contain;
background-image: url("./apps/assets/right-arrow.svg");
}
.socialMediaLink {
@apply flex items-center justify-center w-8 h-8 cursor-pointer hover:opacity-80 transition-opacity duration-200 ease-in-out;
}
.socialMediaIcon {
@apply block w-6 h-6 bg-center bg-contain;
}
/* #region new app dialog */
.newItemCaption {
@apply inline-flex items-center mb-2 text-sm font-medium;
}
/* #endregion new app dialog */
.unavailable {
@apply opacity-50;
}
.listItem:hover .unavailable {
@apply opacity-100;
}

View File

@ -12,23 +12,17 @@ import {
RiFileUploadLine,
} from '@remixicon/react'
import AppIcon from '../base/app-icon'
import SwitchAppModal from '../app/switch-app-modal'
import cn from '@/utils/classnames'
import Confirm from '@/app/components/base/confirm'
import { useStore as useAppStore } from '@/app/components/app/store'
import { ToastContext } from '@/app/components/base/toast'
import AppsContext, { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps'
import DuplicateAppModal from '@/app/components/app/duplicate-modal'
import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal'
import CreateAppModal from '@/app/components/explore/create-app-modal'
import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { getRedirection } from '@/utils/app-redirection'
import UpdateDSLModal from '@/app/components/workflow/update-dsl-modal'
import type { EnvironmentVariable } from '@/app/components/workflow/types'
import DSLExportConfirmModal from '@/app/components/workflow/dsl-export-confirm-modal'
import { fetchWorkflowDraft } from '@/service/workflow'
import ContentDialog from '@/app/components/base/content-dialog'
import Button from '@/app/components/base/button'
@ -36,6 +30,26 @@ import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overvie
import Divider from '../base/divider'
import type { Operation } from './app-operations'
import AppOperations from './app-operations'
import dynamic from 'next/dynamic'
const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), {
ssr: false,
})
const CreateAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), {
ssr: false,
})
const DuplicateAppModal = dynamic(() => import('@/app/components/app/duplicate-modal'), {
ssr: false,
})
const Confirm = dynamic(() => import('@/app/components/base/confirm'), {
ssr: false,
})
const UpdateDSLModal = dynamic(() => import('@/app/components/workflow/update-dsl-modal'), {
ssr: false,
})
const DSLExportConfirmModal = dynamic(() => import('@/app/components/workflow/dsl-export-confirm-modal'), {
ssr: false,
})
export type IAppInfoProps = {
expand: boolean
@ -71,6 +85,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
icon_background,
description,
use_icon_as_answer_icon,
max_active_requests,
}) => {
if (!appDetail)
return
@ -83,6 +98,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
icon_background,
description,
use_icon_as_answer_icon,
max_active_requests,
})
setShowEditModal(false)
notify({
@ -350,6 +366,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
appDescription={appDetail.description}
appMode={appDetail.mode}
appUseIconAsAnswerIcon={appDetail.use_icon_as_answer_icon}
max_active_requests={appDetail.max_active_requests ?? null}
show={showEditModal}
onConfirm={onEdit}
onHide={() => setShowEditModal(false)}

View File

@ -6,6 +6,7 @@ import {
} from 'react'
import { useTranslation } from 'react-i18next'
import dayjs from 'dayjs'
import relativeTime from 'dayjs/plugin/relativeTime'
import {
RiArrowDownSLine,
RiArrowRightSLine,
@ -48,6 +49,7 @@ import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/acces
import { AccessMode } from '@/models/access-control'
import { fetchAppDetail } from '@/service/apps'
import { useGlobalPublicStore } from '@/context/global-public-context'
dayjs.extend(relativeTime)
export type AppPublisherProps = {
disabled?: boolean
@ -116,6 +118,7 @@ const AppPublisher = ({
}
}, [appAccessSubjects, appDetail])
const language = useGetLanguage()
const formatTimeFromNow = useCallback((time: number) => {
return dayjs(time).locale(language === 'zh_Hans' ? 'zh-cn' : language.replace('_', '-')).fromNow()
}, [language])
@ -180,8 +183,7 @@ const AppPublisher = ({
if (publishDisabled || published)
return
handlePublish()
},
{ exactMatch: true, useCapture: true })
}, { exactMatch: true, useCapture: true })
return (
<>

View File

@ -20,8 +20,8 @@ const SuggestedAction = ({ icon, link, disabled, children, className, onClick, .
target='_blank'
rel='noreferrer'
className={classNames(
'flex justify-start items-center gap-2 py-2 px-2.5 bg-background-section-burn rounded-lg text-text-secondary transition-colors [&:not(:first-child)]:mt-1',
disabled ? 'shadow-xs opacity-30 cursor-not-allowed' : 'text-text-secondary hover:bg-state-accent-hover hover:text-text-accent cursor-pointer',
'flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1',
disabled ? 'cursor-not-allowed opacity-30 shadow-xs' : 'cursor-pointer text-text-secondary hover:bg-state-accent-hover hover:text-text-accent',
className,
)}
onClick={handleClick}

View File

@ -18,6 +18,7 @@ import AppIcon from '@/app/components/base/app-icon'
import Button from '@/app/components/base/button'
import Indicator from '@/app/components/header/indicator'
import Switch from '@/app/components/base/switch'
import Toast from '@/app/components/base/toast'
import ConfigContext from '@/context/debug-configuration'
import type { AgentTool } from '@/types/app'
import { type Collection, CollectionType } from '@/app/components/tools/types'
@ -25,6 +26,8 @@ import { MAX_TOOLS_NUM } from '@/config'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import Tooltip from '@/app/components/base/tooltip'
import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other'
import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials'
import { updateBuiltInToolCredential } from '@/service/tools'
import cn from '@/utils/classnames'
import ToolPicker from '@/app/components/workflow/block-selector/tool-picker'
import type { ToolDefaultValue, ToolValue } from '@/app/components/workflow/block-selector/types'
@ -54,7 +57,13 @@ const AgentTools: FC = () => {
const formattingChangedDispatcher = useFormattingChangedDispatcher()
const [currentTool, setCurrentTool] = useState<AgentToolWithMoreInfo>(null)
const currentCollection = useMemo(() => {
if (!currentTool) return null
const collection = collectionList.find(collection => canFindTool(collection.id, currentTool?.provider_id) && collection.type === currentTool?.provider_type)
return collection
}, [currentTool, collectionList])
const [isShowSettingTool, setIsShowSettingTool] = useState(false)
const [isShowSettingAuth, setShowSettingAuth] = useState(false)
const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => {
const collection = collectionList.find(
collection =>
@ -91,6 +100,17 @@ const AgentTools: FC = () => {
formattingChangedDispatcher()
}
const handleToolAuthSetting = (value: AgentToolWithMoreInfo) => {
const newModelConfig = produce(modelConfig, (draft) => {
const tool = (draft.agentConfig.tools).find((item: any) => item.provider_id === value?.collection?.id && item.tool_name === value?.tool_name)
if (tool)
(tool as AgentTool).notAuthor = false
})
setModelConfig(newModelConfig)
setIsShowSettingTool(false)
formattingChangedDispatcher()
}
const [isDeleting, setIsDeleting] = useState<number>(-1)
const getToolValue = (tool: ToolDefaultValue) => {
return {
@ -124,20 +144,6 @@ const AgentTools: FC = () => {
return item.provider_name
}
const handleAuthorizationItemClick = useCallback((credentialId: string) => {
const newModelConfig = produce(modelConfig, (draft) => {
const tool = (draft.agentConfig.tools).find((item: any) => item.provider_id === currentTool?.provider_id)
if (tool)
(tool as AgentTool).credential_id = credentialId
})
setCurrentTool({
...currentTool,
credential_id: credentialId,
} as any)
setModelConfig(newModelConfig)
formattingChangedDispatcher()
}, [currentTool, modelConfig, setModelConfig, formattingChangedDispatcher])
return (
<>
<Panel
@ -203,7 +209,6 @@ const AgentTools: FC = () => {
<span className='text-text-tertiary'>{item.tool_label}</span>
{!item.isDeleted && (
<Tooltip
needsDelay
popupContent={
<div className='w-[180px]'>
<div className='mb-1.5 text-text-secondary'>{item.tool_name}</div>
@ -226,7 +231,6 @@ const AgentTools: FC = () => {
<div className='mr-2 flex items-center'>
<Tooltip
popupContent={t('tools.toolRemoved')}
needsDelay
>
<div className='mr-1 cursor-pointer rounded-md p-1 hover:bg-black/5'>
<AlertTriangle className='h-4 w-4 text-[#F79009]' />
@ -253,7 +257,6 @@ const AgentTools: FC = () => {
{!item.notAuthor && (
<Tooltip
popupContent={t('tools.setBuiltInTools.infoAndSetting')}
needsDelay
>
<div className='cursor-pointer rounded-md p-1 hover:bg-black/5' onClick={() => {
setCurrentTool(item)
@ -296,7 +299,7 @@ const AgentTools: FC = () => {
{item.notAuthor && (
<Button variant='secondary' size='small' onClick={() => {
setCurrentTool(item)
setIsShowSettingTool(true)
setShowSettingAuth(true)
}}>
{t('tools.notAuthorized')}
<Indicator className='ml-2' color='orange' />
@ -316,8 +319,21 @@ const AgentTools: FC = () => {
isModel={currentTool?.collection?.type === CollectionType.model}
onSave={handleToolSettingChange}
onHide={() => setIsShowSettingTool(false)}
credentialId={currentTool?.credential_id}
onAuthorizationItemClick={handleAuthorizationItemClick}
/>
)}
{isShowSettingAuth && (
<ConfigCredential
collection={currentCollection as any}
onCancel={() => setShowSettingAuth(false)}
onSaved={async (value) => {
await updateBuiltInToolCredential((currentCollection as any).name, value)
Toast.notify({
type: 'success',
message: t('common.api.actionSuccess'),
})
handleToolAuthSetting(currentTool)
setShowSettingAuth(false)
}}
/>
)}
</>

View File

@ -14,6 +14,7 @@ import Icon from '@/app/components/plugins/card/base/card-icon'
import OrgInfo from '@/app/components/plugins/card/base/org-info'
import Description from '@/app/components/plugins/card/base/description'
import TabSlider from '@/app/components/base/tab-slider-plain'
import Button from '@/app/components/base/button'
import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form'
import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
@ -24,10 +25,6 @@ import I18n from '@/context/i18n'
import { getLanguage } from '@/i18n/language'
import cn from '@/utils/classnames'
import type { ToolWithProvider } from '@/app/components/workflow/types'
import {
AuthCategory,
PluginAuthInAgent,
} from '@/app/components/plugins/plugin-auth'
type Props = {
showBackButton?: boolean
@ -39,8 +36,6 @@ type Props = {
readonly?: boolean
onHide: () => void
onSave?: (value: Record<string, any>) => void
credentialId?: string
onAuthorizationItemClick?: (id: string) => void
}
const SettingBuiltInTool: FC<Props> = ({
@ -53,8 +48,6 @@ const SettingBuiltInTool: FC<Props> = ({
readonly,
onHide,
onSave,
credentialId,
onAuthorizationItemClick,
}) => {
const { locale } = useContext(I18n)
const language = getLanguage(locale)
@ -204,20 +197,8 @@ const SettingBuiltInTool: FC<Props> = ({
</div>
<div className='system-md-semibold mt-1 text-text-primary'>{currTool?.label[language]}</div>
{!!currTool?.description[language] && (
<Description className='mb-2 mt-3 h-auto' text={currTool.description[language]} descriptionLineRows={2}></Description>
<Description className='mt-3' text={currTool.description[language]} descriptionLineRows={2}></Description>
)}
{
collection.allow_delete && collection.type === CollectionType.builtIn && (
<PluginAuthInAgent
pluginPayload={{
provider: collection.name,
category: AuthCategory.tool,
}}
credentialId={credentialId}
onAuthorizationItemClick={onAuthorizationItemClick}
/>
)
}
</div>
{/* form */}
<div className='h-full'>

View File

@ -177,7 +177,7 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
<div className='flex justify-between border-t border-divider-subtle p-4 pt-3'>
<Button className='w-[72px]' onClick={onClear}>{t('common.operation.clear')}</Button>
{canNotRun && (
<Tooltip popupContent={t('appDebug.otherError.promptNoBeEmpty')} needsDelay>
<Tooltip popupContent={t('appDebug.otherError.promptNoBeEmpty')}>
<Button
variant="primary"
disabled={canNotRun}

View File

@ -65,6 +65,44 @@ const AppTypeSelector = ({ value, onChange }: AppSelectorProps) => {
export default AppTypeSelector
type AppTypeIconProps = {
type: AppMode
style?: React.CSSProperties
className?: string
wrapperClassName?: string
}
export const AppTypeIcon = React.memo(({ type, className, wrapperClassName, style }: AppTypeIconProps) => {
const wrapperClassNames = cn('inline-flex h-5 w-5 items-center justify-center rounded-md border border-divider-regular', wrapperClassName)
const iconClassNames = cn('h-3.5 w-3.5 text-components-avatar-shape-fill-stop-100', className)
if (type === 'chat') {
return <div style={style} className={cn(wrapperClassNames, 'bg-components-icon-bg-blue-solid')}>
<ChatBot className={iconClassNames} />
</div>
}
if (type === 'agent-chat') {
return <div style={style} className={cn(wrapperClassNames, 'bg-components-icon-bg-violet-solid')}>
<Logic className={iconClassNames} />
</div>
}
if (type === 'advanced-chat') {
return <div style={style} className={cn(wrapperClassNames, 'bg-components-icon-bg-blue-light-solid')}>
<BubbleTextMod className={iconClassNames} />
</div>
}
if (type === 'workflow') {
return <div style={style} className={cn(wrapperClassNames, 'bg-components-icon-bg-indigo-solid')}>
<RiExchange2Fill className={iconClassNames} />
</div>
}
if (type === 'completion') {
return <div style={style} className={cn(wrapperClassNames, 'bg-components-icon-bg-teal-solid')}>
<ListSparkle className={iconClassNames} />
</div>
}
return null
})
function AppTypeSelectTrigger({ values }: { values: AppSelectorProps['value'] }) {
const { t } = useTranslation()
if (!values || values.length === 0) {
@ -108,44 +146,6 @@ function AppTypeSelectorItem({ checked, type, onClick }: AppTypeSelectorItemProp
</li>
}
type AppTypeIconProps = {
type: AppMode
style?: React.CSSProperties
className?: string
wrapperClassName?: string
}
export function AppTypeIcon({ type, className, wrapperClassName, style }: AppTypeIconProps) {
const wrapperClassNames = cn('inline-flex h-5 w-5 items-center justify-center rounded-md border border-divider-regular', wrapperClassName)
const iconClassNames = cn('h-3.5 w-3.5 text-components-avatar-shape-fill-stop-100', className)
if (type === 'chat') {
return <div style={style} className={cn(wrapperClassNames, 'bg-components-icon-bg-blue-solid')}>
<ChatBot className={iconClassNames} />
</div>
}
if (type === 'agent-chat') {
return <div style={style} className={cn(wrapperClassNames, 'bg-components-icon-bg-violet-solid')}>
<Logic className={iconClassNames} />
</div>
}
if (type === 'advanced-chat') {
return <div style={style} className={cn(wrapperClassNames, 'bg-components-icon-bg-blue-light-solid')}>
<BubbleTextMod className={iconClassNames} />
</div>
}
if (type === 'workflow') {
return <div style={style} className={cn(wrapperClassNames, 'bg-components-icon-bg-indigo-solid')}>
<RiExchange2Fill className={iconClassNames} />
</div>
}
if (type === 'completion') {
return <div style={style} className={cn(wrapperClassNames, 'bg-components-icon-bg-teal-solid')}>
<ListSparkle className={iconClassNames} />
</div>
}
return null
}
type AppTypeLabelProps = {
type: AppMode
className?: string

Some files were not shown because too many files have changed in this diff Show More