Merge branch 'main' into feat/memory-orchestration-be

This commit is contained in:
Stream
2025-09-23 17:43:52 +08:00
59 changed files with 917 additions and 753 deletions

View File

@ -1,20 +1,11 @@
import logging
import psycogreen.gevent as pscycogreen_gevent # type: ignore import psycogreen.gevent as pscycogreen_gevent # type: ignore
from grpc.experimental import gevent as grpc_gevent # type: ignore from grpc.experimental import gevent as grpc_gevent # type: ignore
_logger = logging.getLogger(__name__)
def _log(message: str):
_logger.debug(message)
# grpc gevent # grpc gevent
grpc_gevent.init_gevent() grpc_gevent.init_gevent()
_log("gRPC patched with gevent.") print("gRPC patched with gevent.", flush=True) # noqa: T201
pscycogreen_gevent.patch_psycopg() pscycogreen_gevent.patch_psycopg()
_log("psycopg2 patched with gevent.") print("psycopg2 patched with gevent.", flush=True) # noqa: T201
from app import app, celery from app import app, celery

View File

@ -1448,41 +1448,52 @@ def transform_datasource_credentials():
notion_credentials_tenant_mapping[tenant_id] = [] notion_credentials_tenant_mapping[tenant_id] = []
notion_credentials_tenant_mapping[tenant_id].append(notion_credential) notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items(): for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
# check notion plugin is installed tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
installed_plugins = installer_manager.list_plugins(tenant_id) if not tenant:
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] continue
if notion_plugin_id not in installed_plugins_ids: try:
if notion_plugin_unique_identifier: # check notion plugin is installed
# install notion plugin installed_plugins = installer_manager.list_plugins(tenant_id)
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier]) installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
auth_count = 0 if notion_plugin_id not in installed_plugins_ids:
for notion_tenant_credential in notion_tenant_credentials: if notion_plugin_unique_identifier:
auth_count += 1 # install notion plugin
# get credential oauth params PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
access_token = notion_tenant_credential.access_token auth_count = 0
# notion info for notion_tenant_credential in notion_tenant_credentials:
notion_info = notion_tenant_credential.source_info auth_count += 1
workspace_id = notion_info.get("workspace_id") # get credential oauth params
workspace_name = notion_info.get("workspace_name") access_token = notion_tenant_credential.access_token
workspace_icon = notion_info.get("workspace_icon") # notion info
new_credentials = { notion_info = notion_tenant_credential.source_info
"integration_secret": encrypter.encrypt_token(tenant_id, access_token), workspace_id = notion_info.get("workspace_id")
"workspace_id": workspace_id, workspace_name = notion_info.get("workspace_name")
"workspace_name": workspace_name, workspace_icon = notion_info.get("workspace_icon")
"workspace_icon": workspace_icon, new_credentials = {
} "integration_secret": encrypter.encrypt_token(tenant_id, access_token),
datasource_provider = DatasourceProvider( "workspace_id": workspace_id,
provider="notion_datasource", "workspace_name": workspace_name,
tenant_id=tenant_id, "workspace_icon": workspace_icon,
plugin_id=notion_plugin_id, }
auth_type=oauth_credential_type.value, datasource_provider = DatasourceProvider(
encrypted_credentials=new_credentials, provider="notion_datasource",
name=f"Auth {auth_count}", tenant_id=tenant_id,
avatar_url=workspace_icon or "default", plugin_id=notion_plugin_id,
is_default=False, auth_type=oauth_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url=workspace_icon or "default",
is_default=False,
)
db.session.add(datasource_provider)
deal_notion_count += 1
except Exception as e:
click.echo(
click.style(
f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
)
) )
db.session.add(datasource_provider) continue
deal_notion_count += 1
db.session.commit() db.session.commit()
# deal firecrawl credentials # deal firecrawl credentials
deal_firecrawl_count = 0 deal_firecrawl_count = 0
@ -1495,37 +1506,48 @@ def transform_datasource_credentials():
firecrawl_credentials_tenant_mapping[tenant_id] = [] firecrawl_credentials_tenant_mapping[tenant_id] = []
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential) firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items(): for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
# check firecrawl plugin is installed tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
installed_plugins = installer_manager.list_plugins(tenant_id) if not tenant:
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] continue
if firecrawl_plugin_id not in installed_plugins_ids: try:
if firecrawl_plugin_unique_identifier: # check firecrawl plugin is installed
# install firecrawl plugin installed_plugins = installer_manager.list_plugins(tenant_id)
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier]) installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if firecrawl_plugin_id not in installed_plugins_ids:
if firecrawl_plugin_unique_identifier:
# install firecrawl plugin
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
auth_count = 0 auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials: for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1 auth_count += 1
# get credential api key # get credential api key
credentials_json = json.loads(firecrawl_tenant_credential.credentials) credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key") api_key = credentials_json.get("config", {}).get("api_key")
base_url = credentials_json.get("config", {}).get("base_url") base_url = credentials_json.get("config", {}).get("base_url")
new_credentials = { new_credentials = {
"firecrawl_api_key": api_key, "firecrawl_api_key": api_key,
"base_url": base_url, "base_url": base_url,
} }
datasource_provider = DatasourceProvider( datasource_provider = DatasourceProvider(
provider="firecrawl", provider="firecrawl",
tenant_id=tenant_id, tenant_id=tenant_id,
plugin_id=firecrawl_plugin_id, plugin_id=firecrawl_plugin_id,
auth_type=api_key_credential_type.value, auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials, encrypted_credentials=new_credentials,
name=f"Auth {auth_count}", name=f"Auth {auth_count}",
avatar_url="default", avatar_url="default",
is_default=False, is_default=False,
)
db.session.add(datasource_provider)
deal_firecrawl_count += 1
except Exception as e:
click.echo(
click.style(
f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
)
) )
db.session.add(datasource_provider) continue
deal_firecrawl_count += 1
db.session.commit() db.session.commit()
# deal jina credentials # deal jina credentials
deal_jina_count = 0 deal_jina_count = 0
@ -1538,36 +1560,45 @@ def transform_datasource_credentials():
jina_credentials_tenant_mapping[tenant_id] = [] jina_credentials_tenant_mapping[tenant_id] = []
jina_credentials_tenant_mapping[tenant_id].append(jina_credential) jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items(): for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
# check jina plugin is installed tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
installed_plugins = installer_manager.list_plugins(tenant_id) if not tenant:
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] continue
if jina_plugin_id not in installed_plugins_ids: try:
if jina_plugin_unique_identifier: # check jina plugin is installed
# install jina plugin installed_plugins = installer_manager.list_plugins(tenant_id)
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier) installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier]) if jina_plugin_id not in installed_plugins_ids:
if jina_plugin_unique_identifier:
# install jina plugin
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
auth_count = 0 auth_count = 0
for jina_tenant_credential in jina_tenant_credentials: for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1 auth_count += 1
# get credential api key # get credential api key
credentials_json = json.loads(jina_tenant_credential.credentials) credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key") api_key = credentials_json.get("config", {}).get("api_key")
new_credentials = { new_credentials = {
"integration_secret": api_key, "integration_secret": api_key,
} }
datasource_provider = DatasourceProvider( datasource_provider = DatasourceProvider(
provider="jina", provider="jina",
tenant_id=tenant_id, tenant_id=tenant_id,
plugin_id=jina_plugin_id, plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value, auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials, encrypted_credentials=new_credentials,
name=f"Auth {auth_count}", name=f"Auth {auth_count}",
avatar_url="default", avatar_url="default",
is_default=False, is_default=False,
)
db.session.add(datasource_provider)
deal_jina_count += 1
except Exception as e:
click.echo(
click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red")
) )
db.session.add(datasource_provider) continue
deal_jina_count += 1
db.session.commit() db.session.commit()
except Exception as e: except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@ -5,7 +5,7 @@ import logging
import os import os
import time import time
import requests import httpx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,10 +30,10 @@ class NacosHttpClient:
params = {} params = {}
try: try:
self._inject_auth_info(headers, params) self._inject_auth_info(headers, params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params) response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status() response.raise_for_status()
return response.text return response.text
except requests.RequestException as e: except httpx.RequestError as e:
return f"Request to Nacos failed: {e}" return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None: def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
@ -78,7 +78,7 @@ class NacosHttpClient:
params = {"username": self.username, "password": self.password} params = {"username": self.username, "password": self.password}
url = "http://" + self.server + "/nacos/v1/auth/login" url = "http://" + self.server + "/nacos/v1/auth/login"
try: try:
resp = requests.request("POST", url, headers=None, params=params) resp = httpx.request("POST", url, headers=None, params=params)
resp.raise_for_status() resp.raise_for_status()
response_data = resp.json() response_data = resp.json()
self.token = response_data.get("accessToken") self.token = response_data.get("accessToken")

View File

@ -1,6 +1,6 @@
import logging import logging
import requests import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource, fields
@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400 return {"error": "Invalid code"}, 400
try: try:
oauth_provider.get_access_token(code) oauth_provider.get_access_token(code)
except requests.HTTPError as e: except httpx.HTTPStatusError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )
@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
try: try:
oauth_provider.sync_data_source(binding_id) oauth_provider.sync_data_source(binding_id)
except requests.HTTPError as e: except httpx.HTTPStatusError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )

View File

@ -1,6 +1,6 @@
import logging import logging
import requests import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restx import Resource from flask_restx import Resource
from sqlalchemy import select from sqlalchemy import select
@ -101,8 +101,10 @@ class OAuthCallback(Resource):
try: try:
token = oauth_provider.get_access_token(code) token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token) user_info = oauth_provider.get_user_info(token)
except requests.RequestException as e: except httpx.RequestError as e:
error_text = e.response.text if e.response else str(e) error_text = str(e)
if isinstance(e, httpx.HTTPStatusError):
error_text = e.response.text
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400 return {"error": "OAuth process failed"}, 400

View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
import requests import httpx
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from packaging import version from packaging import version
@ -57,7 +57,11 @@ class VersionApi(Resource):
return result return result
try: try:
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10)) response = httpx.get(
check_update_url,
params={"current_version": args["current_version"]},
timeout=httpx.Timeout(connect=3, read=10),
)
except Exception as error: except Exception as error:
logger.warning("Check update version error: %s.", str(error)) logger.warning("Check update version error: %s.", str(error))
result["version"] = args["current_version"] result["version"] = args["current_version"]

View File

@ -90,29 +90,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")
if self.application_generate_entity.single_iteration_run: if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# if only single iteration run is requested # Handle single iteration or single loop run
graph_runtime_state = GraphRuntimeState( graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow, workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, single_iteration_run=self.application_generate_entity.single_iteration_run,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), single_loop_run=self.application_generate_entity.single_loop_run,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs

View File

@ -427,6 +427,9 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
call_depth=0, call_depth=0,
workflow_execution_id=str(uuid.uuid4()), workflow_execution_id=str(uuid.uuid4()),
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
) )
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -465,6 +468,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
variable_loader=var_loader, variable_loader=var_loader,
context=contextvars.copy_context(),
) )
def single_loop_generate( def single_loop_generate(
@ -559,6 +563,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
variable_loader=var_loader, variable_loader=var_loader,
context=contextvars.copy_context(),
) )
def _generate_worker( def _generate_worker(

View File

@ -86,29 +86,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
db.session.close() db.session.close()
# if only single iteration run is requested # if only single iteration run is requested
if self.application_generate_entity.single_iteration_run: if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph_runtime_state = GraphRuntimeState( # Handle single iteration or single loop run
variable_pool=VariablePool.empty(), graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
start_at=time.time(),
)
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow, workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, single_iteration_run=self.application_generate_entity.single_iteration_run,
user_inputs=self.application_generate_entity.single_iteration_run.inputs, single_loop_run=self.application_generate_entity.single_loop_run,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs

View File

@ -51,30 +51,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config) app_config = cast(WorkflowAppConfig, app_config)
# if only single iteration run is requested # if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run: if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# if only single iteration run is requested graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow, workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, single_iteration_run=self.application_generate_entity.single_iteration_run,
user_inputs=self.application_generate_entity.single_iteration_run.inputs, single_loop_run=self.application_generate_entity.single_loop_run,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs

View File

@ -1,3 +1,4 @@
import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, cast from typing import Any, cast
@ -119,15 +120,81 @@ class WorkflowBasedAppRunner:
return graph return graph
def _get_graph_and_variable_pool_of_single_iteration( def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
(either single iteration or single loop).
Args:
workflow: The workflow instance
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise
Returns:
A tuple containing (graph, variable_pool, graph_runtime_state)
Raises:
ValueError: If neither single_iteration_run nor single_loop_run is specified
"""
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
start_at=time.time(),
)
# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif single_loop_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state
def _get_graph_and_variable_pool_for_single_node_run(
self, self,
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user_inputs: dict, user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
) -> tuple[Graph, VariablePool]: ) -> tuple[Graph, VariablePool]:
""" """
Get variable pool of single iteration Get graph and variable pool for single node execution (iteration or loop).
Args:
workflow: The workflow instance
node_id: The node ID to execute
user_inputs: User inputs for the node
graph_runtime_state: The graph runtime state
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
node_type_label: Label for error messages ('iteration' or 'loop')
Returns:
A tuple containing (graph, variable_pool)
""" """
# fetch workflow graph # fetch workflow graph
graph_config = workflow.graph_dict graph_config = workflow.graph_dict
@ -145,18 +212,22 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list): if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list") raise ValueError("edges in workflow graph must be a list")
# filter nodes only in iteration # filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [ node_configs = [
node node
for node in graph_config.get("nodes", []) for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
] ]
graph_config["nodes"] = node_configs graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs] node_ids = [node.get("id") for node in node_configs]
# filter edges only in iteration # filter edges only in the specified node type
edge_configs = [ edge_configs = [
edge edge
for edge in graph_config.get("edges", []) for edge in graph_config.get("edges", [])
@ -190,30 +261,26 @@ class WorkflowBasedAppRunner:
raise ValueError("graph not found in workflow") raise ValueError("graph not found in workflow")
# fetch node config from node id # fetch node config from node id
iteration_node_config = None target_node_config = None
for node in node_configs: for node in node_configs:
if node.get("id") == node_id: if node.get("id") == node_id:
iteration_node_config = node target_node_config = node
break break
if not iteration_node_config: if not target_node_config:
raise ValueError("iteration node id not found in workflow graph") raise ValueError(f"{node_type_label} node id not found in workflow graph")
# Get node class # Get node class
node_type = NodeType(iteration_node_config.get("data", {}).get("type")) node_type = NodeType(target_node_config.get("data", {}).get("type"))
node_version = iteration_node_config.get("data", {}).get("version", "1") node_version = target_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool # Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = VariablePool( variable_pool = graph_runtime_state.variable_pool
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
try: try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=iteration_node_config graph_config=workflow.graph_dict, config=target_node_config
) )
except NotImplementedError: except NotImplementedError:
variable_mapping = {} variable_mapping = {}
@ -234,120 +301,44 @@ class WorkflowBasedAppRunner:
return graph, variable_pool return graph, variable_pool
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)
def _get_graph_and_variable_pool_of_single_loop( def _get_graph_and_variable_pool_of_single_loop(
self, self,
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user_inputs: dict, user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]: ) -> tuple[Graph, VariablePool]:
""" """
Get variable pool of single loop Get variable pool of single loop
""" """
# fetch workflow graph return self._get_graph_and_variable_pool_for_single_node_run(
graph_config = workflow.graph_dict workflow=workflow,
if not graph_config: node_id=node_id,
raise ValueError("workflow graph not found") user_inputs=user_inputs,
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in loop
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in loop
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
) )
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
loop_node_config = None
for node in node_configs:
if node.get("id") == node_id:
loop_node_config = node
break
if not loop_node_config:
raise ValueError("loop node id not found in workflow graph")
# Get node class
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
node_version = loop_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=loop_node_config
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
return graph, variable_pool
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent): def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
""" """
Handle event Handle event

View File

@ -205,16 +205,10 @@ class ProviderConfiguration(BaseModel):
""" """
Get custom provider record. Get custom provider record.
""" """
# get provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(Provider).where( stmt = select(Provider).where(
Provider.tenant_id == self.tenant_id, Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value, Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names), Provider.provider_name.in_(self._get_provider_names()),
) )
return session.execute(stmt).scalar_one_or_none() return session.execute(stmt).scalar_one_or_none()
@ -276,7 +270,7 @@ class ProviderConfiguration(BaseModel):
""" """
stmt = select(ProviderCredential.id).where( stmt = select(ProviderCredential.id).where(
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.credential_name == credential_name, ProviderCredential.credential_name == credential_name,
) )
if exclude_id: if exclude_id:
@ -324,7 +318,7 @@ class ProviderConfiguration(BaseModel):
try: try:
stmt = select(ProviderCredential).where( stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id, ProviderCredential.id == credential_id,
) )
credential_record = s.execute(stmt).scalar_one_or_none() credential_record = s.execute(stmt).scalar_one_or_none()
@ -374,7 +368,7 @@ class ProviderConfiguration(BaseModel):
session=session, session=session,
query_factory=lambda: select(ProviderCredential).where( query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
), ),
) )
@ -387,7 +381,7 @@ class ProviderConfiguration(BaseModel):
session=session, session=session,
query_factory=lambda: select(ProviderModelCredential).where( query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
), ),
@ -423,6 +417,16 @@ class ProviderConfiguration(BaseModel):
logger.warning("Error generating next credential name: %s", str(e)) logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1" return "API KEY 1"
def _get_provider_names(self):
"""
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
return provider_names
def create_provider_credential(self, credentials: dict, credential_name: str | None): def create_provider_credential(self, credentials: dict, credential_name: str | None):
""" """
Add custom provider credentials. Add custom provider credentials.
@ -501,7 +505,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where( stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id, ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
) )
# Get the credential record to update # Get the credential record to update
@ -554,7 +558,7 @@ class ProviderConfiguration(BaseModel):
# Find all load balancing configs that use this credential_id # Find all load balancing configs that use this credential_id
stmt = select(LoadBalancingModelConfig).where( stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == credential_source, LoadBalancingModelConfig.credential_source_type == credential_source,
) )
@ -591,7 +595,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where( stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id, ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
) )
# Get the credential record to update # Get the credential record to update
@ -602,7 +606,7 @@ class ProviderConfiguration(BaseModel):
# Check if this credential is used in load balancing configs # Check if this credential is used in load balancing configs
lb_stmt = select(LoadBalancingModelConfig).where( lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider", LoadBalancingModelConfig.credential_source_type == "provider",
) )
@ -624,7 +628,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the provider record # if this is the last credential, we need to delete the provider record
count_stmt = select(func.count(ProviderCredential.id)).where( count_stmt = select(func.count(ProviderCredential.id)).where(
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
) )
available_credentials_count = session.execute(count_stmt).scalar() or 0 available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record) session.delete(credential_record)
@ -668,7 +672,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where( stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id, ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
) )
credential_record = session.execute(stmt).scalar_one_or_none() credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record: if not credential_record:
@ -737,7 +741,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -784,7 +788,7 @@ class ProviderConfiguration(BaseModel):
""" """
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.credential_name == credential_name, ProviderModelCredential.credential_name == credential_name,
@ -860,7 +864,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -997,7 +1001,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -1042,7 +1046,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -1052,7 +1056,7 @@ class ProviderConfiguration(BaseModel):
lb_stmt = select(LoadBalancingModelConfig).where( lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model", LoadBalancingModelConfig.credential_source_type == "custom_model",
) )
@ -1075,7 +1079,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the custom model record # if this is the last credential, we need to delete the custom model record
count_stmt = select(func.count(ProviderModelCredential.id)).where( count_stmt = select(func.count(ProviderModelCredential.id)).where(
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -1115,7 +1119,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -1157,7 +1161,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -1204,15 +1208,9 @@ class ProviderConfiguration(BaseModel):
""" """
Get provider model setting. Get provider model setting.
""" """
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(ProviderModelSetting).where( stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.provider_name.in_(self._get_provider_names()),
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model, ProviderModelSetting.model_name == model,
) )
@ -1384,15 +1382,9 @@ class ProviderConfiguration(BaseModel):
return return
def _switch(s: Session): def _switch(s: Session):
# get preferred provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(TenantPreferredModelProvider).where( stmt = select(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names), TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
) )
preferred_model_provider = s.execute(stmt).scalars().first() preferred_model_provider = s.execute(stmt).scalars().first()

View File

@ -8,7 +8,7 @@ from collections import deque
from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime from datetime import datetime
import requests import httpx
from opentelemetry import trace as trace_api from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.resources import Resource
@ -65,13 +65,13 @@ class TraceClient:
def api_check(self): def api_check(self):
try: try:
response = requests.head(self.endpoint, timeout=5) response = httpx.head(self.endpoint, timeout=5)
if response.status_code == 405: if response.status_code == 405:
return True return True
else: else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code) logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False return False
except requests.RequestException as e: except httpx.RequestError as e:
logger.debug("AliyunTrace API check failed: %s", str(e)) logger.debug("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}") raise ValueError(f"AliyunTrace API check failed: {str(e)}")

View File

@ -513,6 +513,21 @@ class ProviderManager:
return provider_name_to_provider_load_balancing_model_configs_dict return provider_name_to_provider_load_balancing_model_configs_dict
@staticmethod
def _get_provider_names(provider_name: str) -> list[str]:
"""
provider_name: `openai` or `langgenius/openai/openai`
return: [`openai`, `langgenius/openai/openai`]
"""
provider_names = [provider_name]
model_provider_id = ModelProviderID(provider_name)
if model_provider_id.is_langgenius():
if "/" in provider_name:
provider_names.append(model_provider_id.provider_name)
else:
provider_names.append(str(model_provider_id))
return provider_names
@staticmethod @staticmethod
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]: def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
""" """
@ -525,7 +540,10 @@ class ProviderManager:
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
stmt = ( stmt = (
select(ProviderCredential) select(ProviderCredential)
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name) .where(
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
)
.order_by(ProviderCredential.created_at.desc()) .order_by(ProviderCredential.created_at.desc())
) )
@ -554,7 +572,7 @@ class ProviderManager:
select(ProviderModelCredential) select(ProviderModelCredential)
.where( .where(
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.tenant_id == tenant_id,
ProviderModelCredential.provider_name == provider_name, ProviderModelCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
ProviderModelCredential.model_name == model_name, ProviderModelCredential.model_name == model_name,
ProviderModelCredential.model_type == model_type, ProviderModelCredential.model_type == model_type,
) )

View File

@ -41,7 +41,8 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False) completed: bool = Field(default=False)
aborted: bool = Field(default=False) aborted: bool = Field(default=False)
error: GraphExecutionErrorState | None = Field(default=None) error: GraphExecutionErrorState | None = Field(default=None)
node_executions: list[NodeExecutionState] = Field(default_factory=list) exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None: def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
@ -103,7 +104,8 @@ class GraphExecution:
completed: bool = False completed: bool = False
aborted: bool = False aborted: bool = False
error: Exception | None = None error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict) node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
def start(self) -> None: def start(self) -> None:
"""Mark the graph execution as started.""" """Mark the graph execution as started."""
@ -172,6 +174,7 @@ class GraphExecution:
completed=self.completed, completed=self.completed,
aborted=self.aborted, aborted=self.aborted,
error=_serialize_error(self.error), error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states, node_executions=node_states,
) )
@ -195,6 +198,7 @@ class GraphExecution:
self.completed = state.completed self.completed = state.completed
self.aborted = state.aborted self.aborted = state.aborted
self.error = _deserialize_error(state.error) self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = { self.node_executions = {
item.node_id: NodeExecution( item.node_id: NodeExecution(
node_id=item.node_id, node_id=item.node_id,
@ -205,3 +209,7 @@ class GraphExecution:
) )
for item in state.node_executions for item in state.node_executions
} }
def record_node_failure(self) -> None:
"""Increment the count of node failures encountered during execution."""
self.exceptions_count += 1

View File

@ -3,11 +3,12 @@ Event handler implementations for different event types.
""" """
import logging import logging
from collections.abc import Mapping
from functools import singledispatchmethod from functools import singledispatchmethod
from typing import TYPE_CHECKING, final from typing import TYPE_CHECKING, final
from core.workflow.entities import GraphRuntimeState from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_events import ( from core.workflow.graph_events import (
GraphNodeEventBase, GraphNodeEventBase,
@ -122,13 +123,15 @@ class EventHandler:
""" """
# Track execution in domain model # Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
node_execution.mark_started(event.id) node_execution.mark_started(event.id)
# Track in response coordinator for stream ordering # Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id) self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event # Collect the event only for the first attempt; retries remain silent
self._event_collector.collect(event) if is_initial_attempt:
self._event_collector.collect(event)
@_dispatch.register @_dispatch.register
def _(self, event: NodeRunStreamChunkEvent) -> None: def _(self, event: NodeRunStreamChunkEvent) -> None:
@ -161,7 +164,7 @@ class EventHandler:
node_execution.mark_taken() node_execution.mark_taken()
# Store outputs in variable pool # Store outputs in variable pool
self._store_node_outputs(event) self._store_node_outputs(event.node_id, event.node_run_result.outputs)
# Forward to response coordinator and emit streaming events # Forward to response coordinator and emit streaming events
streaming_events = self._response_coordinator.intercept_event(event) streaming_events = self._response_coordinator.intercept_event(event)
@ -191,7 +194,7 @@ class EventHandler:
# Handle response node outputs # Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE: if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event) self._update_response_outputs(event.node_run_result.outputs)
# Collect the event # Collect the event
self._event_collector.collect(event) self._event_collector.collect(event)
@ -207,6 +210,7 @@ class EventHandler:
# Update domain model # Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error) node_execution.mark_failed(event.error)
self._graph_execution.record_node_failure()
result = self._error_handler.handle_node_failure(event) result = self._error_handler.handle_node_failure(event)
@ -227,10 +231,40 @@ class EventHandler:
Args: Args:
event: The node exception event event: The node exception event
""" """
# Node continues via fail-branch, so it's technically "succeeded" # Node continues via fail-branch/default-value, treat as completion
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken() node_execution.mark_taken()
# Persist outputs produced by the exception strategy (e.g. default values)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
node = self._graph.nodes[event.node_id]
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update response outputs if applicable
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event.node_run_result.outputs)
self._state_manager.finish_execution(event.node_id)
# Collect the exception event for observers
self._event_collector.collect(event)
@_dispatch.register @_dispatch.register
def _(self, event: NodeRunRetryEvent) -> None: def _(self, event: NodeRunRetryEvent) -> None:
""" """
@ -242,21 +276,31 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry() node_execution.increment_retry()
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None: # Finish the previous attempt before re-queuing the node
self._state_manager.finish_execution(event.node_id)
# Emit retry event for observers
self._event_collector.collect(event)
# Re-queue node for execution
self._state_manager.enqueue_node(event.node_id)
self._state_manager.start_execution(event.node_id)
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
""" """
Store node outputs in the variable pool. Store node outputs in the variable pool.
Args: Args:
event: The node succeeded event containing outputs event: The node succeeded event containing outputs
""" """
for variable_name, variable_value in event.node_run_result.outputs.items(): for variable_name, variable_value in outputs.items():
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value) self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None: def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
"""Update response outputs for response nodes.""" """Update response outputs for response nodes."""
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs # TODO: Design a mechanism for nodes to notify the engine about how to update outputs
# in runtime state, rather than allowing nodes to directly access runtime state. # in runtime state, rather than allowing nodes to directly access runtime state.
for key, value in event.node_run_result.outputs.items(): for key, value in outputs.items():
if key == "answer": if key == "answer":
existing = self._graph_runtime_state.get_output("answer", "") existing = self._graph_runtime_state.get_output("answer", "")
if existing: if existing:

View File

@ -5,6 +5,7 @@ Unified event manager for collecting and emitting events.
import threading import threading
import time import time
from collections.abc import Generator from collections.abc import Generator
from contextlib import contextmanager
from typing import final from typing import final
from core.workflow.graph_events import GraphEngineEvent from core.workflow.graph_events import GraphEngineEvent
@ -51,43 +52,23 @@ class ReadWriteLock:
"""Release a write lock.""" """Release a write lock."""
self._read_ready.release() self._read_ready.release()
def read_lock(self) -> "ReadLockContext": @contextmanager
def read_lock(self):
"""Return a context manager for read locking.""" """Return a context manager for read locking."""
return ReadLockContext(self) self.acquire_read()
try:
yield
finally:
self.release_read()
def write_lock(self) -> "WriteLockContext": @contextmanager
def write_lock(self):
"""Return a context manager for write locking.""" """Return a context manager for write locking."""
return WriteLockContext(self) self.acquire_write()
try:
yield
@final finally:
class ReadLockContext: self.release_write()
"""Context manager for read locks."""
def __init__(self, lock: ReadWriteLock) -> None:
self._lock = lock
def __enter__(self) -> "ReadLockContext":
self._lock.acquire_read()
return self
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
self._lock.release_read()
@final
class WriteLockContext:
"""Context manager for write locks."""
def __init__(self, lock: ReadWriteLock) -> None:
self._lock = lock
def __enter__(self) -> "WriteLockContext":
self._lock.acquire_write()
return self
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
self._lock.release_write()
@final @final

View File

@ -23,6 +23,7 @@ from core.workflow.graph_events import (
GraphNodeEventBase, GraphNodeEventBase,
GraphRunAbortedEvent, GraphRunAbortedEvent,
GraphRunFailedEvent, GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent, GraphRunStartedEvent,
GraphRunSucceededEvent, GraphRunSucceededEvent,
) )
@ -260,12 +261,23 @@ class GraphEngine:
if self._graph_execution.error: if self._graph_execution.error:
raise self._graph_execution.error raise self._graph_execution.error
else: else:
yield GraphRunSucceededEvent( outputs = self._graph_runtime_state.outputs
outputs=self._graph_runtime_state.outputs, exceptions_count = self._graph_execution.exceptions_count
) if exceptions_count > 0:
yield GraphRunPartialSucceededEvent(
exceptions_count=exceptions_count,
outputs=outputs,
)
else:
yield GraphRunSucceededEvent(
outputs=outputs,
)
except Exception as e: except Exception as e:
yield GraphRunFailedEvent(error=str(e)) yield GraphRunFailedEvent(
error=str(e),
exceptions_count=self._graph_execution.exceptions_count,
)
raise raise
finally: finally:

View File

@ -15,6 +15,7 @@ from core.workflow.graph_events import (
GraphEngineEvent, GraphEngineEvent,
GraphRunAbortedEvent, GraphRunAbortedEvent,
GraphRunFailedEvent, GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent, GraphRunStartedEvent,
GraphRunSucceededEvent, GraphRunSucceededEvent,
NodeRunExceptionEvent, NodeRunExceptionEvent,
@ -127,6 +128,13 @@ class DebugLoggingLayer(GraphEngineLayer):
if self.include_outputs and event.outputs: if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self.logger.warning("⚠️ Graph run partially succeeded")
if event.exceptions_count > 0:
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunFailedEvent): elif isinstance(event, GraphRunFailedEvent):
self.logger.error("❌ Graph run failed: %s", event.error) self.logger.error("❌ Graph run failed: %s", event.error)
if event.exceptions_count > 0: if event.exceptions_count > 0:
@ -138,6 +146,12 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs)) self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))
# Node-level events # Node-level events
# Retry before Started because Retry subclasses Started;
elif isinstance(event, NodeRunRetryEvent):
self.retry_count += 1
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
self.logger.warning(" Previous error: %s", event.error)
elif isinstance(event, NodeRunStartedEvent): elif isinstance(event, NodeRunStartedEvent):
self.node_count += 1 self.node_count += 1
self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type) self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)
@ -167,11 +181,6 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.warning("⚠️ Node exception handled: %s", event.node_id) self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
self.logger.warning(" Error: %s", event.error) self.logger.warning(" Error: %s", event.error)
elif isinstance(event, NodeRunRetryEvent):
self.retry_count += 1
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
self.logger.warning(" Previous error: %s", event.error)
elif isinstance(event, NodeRunStreamChunkEvent): elif isinstance(event, NodeRunStreamChunkEvent):
# Log stream chunks at debug level to avoid spam # Log stream chunks at debug level to avoid spam
final_indicator = " (FINAL)" if event.is_final else "" final_indicator = " (FINAL)" if event.is_final else ""

View File

@ -19,6 +19,7 @@ from core.workflow.enums import (
from core.workflow.graph_events import ( from core.workflow.graph_events import (
GraphNodeEventBase, GraphNodeEventBase,
GraphRunFailedEvent, GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent, GraphRunSucceededEvent,
) )
from core.workflow.node_events import ( from core.workflow.node_events import (
@ -372,43 +373,16 @@ class IterationNode(Node):
variable_mapping: dict[str, Sequence[str]] = { variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": typed_node_data.iterator_selector, f"{node_id}.input_selector": typed_node_data.iterator_selector,
} }
iteration_node_ids = set()
# init graph # Find all nodes that belong to this loop
from core.workflow.entities import GraphInitParams, GraphRuntimeState nodes = graph_config.get("nodes", [])
from core.workflow.graph import Graph for node in nodes:
from core.workflow.nodes.node_factory import DifyNodeFactory node_data = node.get("data", {})
if node_data.get("iteration_id") == node_id:
# Create minimal GraphInitParams for static analysis in_iteration_node_id = node.get("id")
graph_init_params = GraphInitParams( if in_iteration_node_id:
tenant_id="", iteration_node_ids.add(in_iteration_node_id)
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)
# Create minimal GraphRuntimeState for static analysis
from core.workflow.entities import VariablePool
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)
# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
iteration_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
# Get node configs from graph_config instead of non-existent node_id_config_mapping # Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
@ -444,9 +418,7 @@ class IterationNode(Node):
variable_mapping.update(sub_node_variable_mapping) variable_mapping.update(sub_node_variable_mapping)
# remove variable out from iteration # remove variable out from iteration
variable_mapping = { variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
}
return variable_mapping return variable_mapping
@ -485,7 +457,7 @@ class IterationNode(Node):
if isinstance(event, GraphNodeEventBase): if isinstance(event, GraphNodeEventBase):
self._append_iteration_info_to_event(event=event, iter_run_index=current_index) self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
yield event yield event
elif isinstance(event, GraphRunSucceededEvent): elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
result = variable_pool.get(self._node_data.output_selector) result = variable_pool.get(self._node_data.output_selector)
if result is None: if result is None:
outputs.append(None) outputs.append(None)

View File

@ -63,7 +63,7 @@ class RetrievalSetting(BaseModel):
Retrieval Setting. Retrieval Setting.
""" """
search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"] search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
top_k: int top_k: int
score_threshold: float | None = 0.5 score_threshold: float | None = 0.5
score_threshold_enabled: bool = False score_threshold_enabled: bool = False

View File

@ -1,3 +1,4 @@
import contextlib
import json import json
import logging import logging
from collections.abc import Callable, Generator, Mapping, Sequence from collections.abc import Callable, Generator, Mapping, Sequence
@ -127,11 +128,13 @@ class LoopNode(Node):
try: try:
reach_break_condition = False reach_break_condition = False
if break_conditions: if break_conditions:
_, _, reach_break_condition = condition_processor.process_conditions( with contextlib.suppress(ValueError):
variable_pool=self.graph_runtime_state.variable_pool, _, _, reach_break_condition = condition_processor.process_conditions(
conditions=break_conditions, variable_pool=self.graph_runtime_state.variable_pool,
operator=logical_operator, conditions=break_conditions,
) operator=logical_operator,
)
if reach_break_condition: if reach_break_condition:
loop_count = 0 loop_count = 0
cost_tokens = 0 cost_tokens = 0
@ -295,42 +298,11 @@ class LoopNode(Node):
variable_mapping = {} variable_mapping = {}
# init graph # Extract loop node IDs statically from graph_config
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
# Create minimal GraphInitParams for static analysis loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)
# Create minimal GraphRuntimeState for static analysis # Get node configs from graph_config
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)
# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
loop_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)
if not loop_graph:
raise ValueError("loop graph not found")
# Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items(): for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id: if sub_node_config.get("data", {}).get("loop_id") != node_id:
@ -371,12 +343,35 @@ class LoopNode(Node):
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
# remove variable out from loop # remove variable out from loop
variable_mapping = { variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
}
return variable_mapping return variable_mapping
@classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
"""
Extract node IDs that belong to a specific loop from graph configuration.
This method statically analyzes the graph configuration to find all nodes
that are part of the specified loop, without creating actual node instances.
:param graph_config: the complete graph configuration
:param loop_node_id: the ID of the loop node
:return: set of node IDs that belong to the loop
"""
loop_node_ids = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("loop_id") == loop_node_id:
node_id = node.get("id")
if node_id:
loop_node_ids.add(node_id)
return loop_node_ids
@staticmethod @staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value.""" """Get the appropriate segment type for a constant value."""

View File

@ -33,6 +33,7 @@ file_fields = {
"created_by": fields.String, "created_by": fields.String,
"created_at": TimestampField, "created_at": TimestampField,
"preview_url": fields.String, "preview_url": fields.String,
"source_url": fields.String,
} }

View File

@ -1,10 +1,32 @@
import psycogreen.gevent as pscycogreen_gevent # type: ignore import psycogreen.gevent as pscycogreen_gevent # type: ignore
from gevent import events as gevent_events
from grpc.experimental import gevent as grpc_gevent # type: ignore from grpc.experimental import gevent as grpc_gevent # type: ignore
# NOTE(QuantumGhost): here we cannot use post_fork to patch gRPC, as
# grpc_gevent.init_gevent must be called after patching stdlib.
# Gunicorn calls `post_init` before applying monkey patch.
# Use `post_init` to setup gRPC gevent support would cause deadlock and
# some other weird issues.
#
# ref:
# - https://github.com/grpc/grpc/blob/62533ea13879d6ee95c6fda11ec0826ca822c9dd/src/python/grpcio/grpc/experimental/gevent.py
# - https://github.com/gevent/gevent/issues/2060#issuecomment-3016768668
# - https://github.com/benoitc/gunicorn/blob/master/gunicorn/arbiter.py#L607-L613
def post_fork(server, worker):
def post_patch(event):
# this function is only called for gevent worker.
# from gevent docs (https://www.gevent.org/api/gevent.monkey.html):
# You can also subscribe to the events to provide additional patching beyond what gevent distributes, either for
# additional standard library modules, or for third-party packages. The suggested time to do this patching is in
# the subscriber for gevent.events.GeventDidPatchBuiltinModulesEvent.
if not isinstance(event, gevent_events.GeventDidPatchBuiltinModulesEvent):
return
# grpc gevent # grpc gevent
grpc_gevent.init_gevent() grpc_gevent.init_gevent()
server.log.info("gRPC patched with gevent.") print("gRPC patched with gevent.", flush=True) # noqa: T201
pscycogreen_gevent.patch_psycopg() pscycogreen_gevent.patch_psycopg()
server.log.info("psycopg2 patched with gevent.") print("psycopg2 patched with gevent.", flush=True) # noqa: T201
gevent_events.subscribers.append(post_patch)

View File

@ -1,7 +1,7 @@
import urllib.parse import urllib.parse
from dataclasses import dataclass from dataclasses import dataclass
import requests import httpx
@dataclass @dataclass
@ -58,7 +58,7 @@ class GitHubOAuth(OAuth):
"redirect_uri": self.redirect_uri, "redirect_uri": self.redirect_uri,
} }
headers = {"Accept": "application/json"} headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers) response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json() response_json = response.json()
access_token = response_json.get("access_token") access_token = response_json.get("access_token")
@ -70,11 +70,11 @@ class GitHubOAuth(OAuth):
def get_raw_user_info(self, token: str): def get_raw_user_info(self, token: str):
headers = {"Authorization": f"token {token}"} headers = {"Authorization": f"token {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers) response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status() response.raise_for_status()
user_info = response.json() user_info = response.json()
email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json() email_info = email_response.json()
primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
@ -112,7 +112,7 @@ class GoogleOAuth(OAuth):
"redirect_uri": self.redirect_uri, "redirect_uri": self.redirect_uri,
} }
headers = {"Accept": "application/json"} headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers) response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json() response_json = response.json()
access_token = response_json.get("access_token") access_token = response_json.get("access_token")
@ -124,7 +124,7 @@ class GoogleOAuth(OAuth):
def get_raw_user_info(self, token: str): def get_raw_user_info(self, token: str):
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers) response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()

View File

@ -1,7 +1,7 @@
import urllib.parse import urllib.parse
from typing import Any from typing import Any
import requests import httpx
from flask_login import current_user from flask_login import current_user
from sqlalchemy import select from sqlalchemy import select
@ -43,7 +43,7 @@ class NotionOAuth(OAuthDataSource):
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
headers = {"Accept": "application/json"} headers = {"Accept": "application/json"}
auth = (self.client_id, self.client_secret) auth = (self.client_id, self.client_secret)
response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response_json = response.json() response_json = response.json()
access_token = response_json.get("access_token") access_token = response_json.get("access_token")
@ -239,7 +239,7 @@ class NotionOAuth(OAuthDataSource):
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
} }
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json() response_json = response.json()
results.extend(response_json.get("results", [])) results.extend(response_json.get("results", []))
@ -254,7 +254,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
} }
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
message = response_json.get("message", "unknown error") message = response_json.get("message", "unknown error")
@ -270,7 +270,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
} }
response = requests.get(url=self._NOTION_BOT_USER, headers=headers) response = httpx.get(url=self._NOTION_BOT_USER, headers=headers)
response_json = response.json() response_json = response.json()
if "object" in response_json and response_json["object"] == "user": if "object" in response_json and response_json["object"] == "user":
user_type = response_json["type"] user_type = response_json["type"]
@ -294,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
} }
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json() response_json = response.json()
results.extend(response_json.get("results", [])) results.extend(response_json.get("results", []))

View File

@ -47,7 +47,7 @@ def upgrade():
sa.Column('plugin_id', sa.String(length=255), nullable=False), sa.Column('plugin_id', sa.String(length=255), nullable=False),
sa.Column('auth_type', sa.String(length=255), nullable=False), sa.Column('auth_type', sa.String(length=255), nullable=False),
sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False), sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column('avatar_url', sa.String(length=255), nullable=True), sa.Column('avatar_url', sa.Text(), nullable=True),
sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False), sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('expires_at', sa.Integer(), server_default='-1', nullable=False), sa.Column('expires_at', sa.Integer(), server_default='-1', nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),

View File

@ -35,7 +35,7 @@ class DatasourceProvider(Base):
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default") avatar_url: Mapped[str] = db.Column(db.Text, nullable=True, default="default")
is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
expires_at: Mapped[int] = db.Column(db.Integer, nullable=False, server_default="-1") expires_at: Mapped[int] = db.Column(db.Integer, nullable=False, server_default="-1")

View File

@ -1,6 +1,6 @@
[project] [project]
name = "dify-api" name = "dify-api"
version = "2.0.0-beta2" version = "1.9.0"
requires-python = ">=3.11,<3.13" requires-python = ">=3.11,<3.13"
dependencies = [ dependencies = [

View File

@ -1,6 +1,6 @@
import json import json
import requests import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -36,7 +36,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers): def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data) return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response): def _handle_error(self, response):
if response.status_code in {402, 409, 500}: if response.status_code in {402, 409, 500}:

View File

@ -1,6 +1,6 @@
import json import json
import requests import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers): def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data) return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response): def _handle_error(self, response):
if response.status_code in {402, 409, 500}: if response.status_code in {402, 409, 500}:

View File

@ -1,6 +1,6 @@
import json import json
import requests import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers): def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data) return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response): def _handle_error(self, response):
if response.status_code in {402, 409, 500}: if response.status_code in {402, 409, 500}:

View File

@ -1,7 +1,7 @@
import json import json
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -31,7 +31,7 @@ class WatercrawlAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "X-API-KEY": self.api_key} return {"Content-Type": "application/json", "X-API-KEY": self.api_key}
def _get_request(self, url, headers): def _get_request(self, url, headers):
return requests.get(url, headers=headers) return httpx.get(url, headers=headers)
def _handle_error(self, response): def _handle_error(self, response):
if response.status_code in {402, 409, 500}: if response.status_code in {402, 409, 500}:

View File

@ -83,7 +83,7 @@ class RetrievalSetting(BaseModel):
Retrieval Setting. Retrieval Setting.
""" """
search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"] search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
top_k: int top_k: int
score_threshold: float | None = 0.5 score_threshold: float | None = 0.5
score_threshold_enabled: bool = False score_threshold_enabled: bool = False

View File

@ -1,6 +1,6 @@
import os import os
import requests import httpx
class OperationService: class OperationService:
@ -12,7 +12,7 @@ class OperationService:
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}" url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers) response = httpx.request(method, url, json=json, params=params, headers=headers)
return response.json() return response.json()

View File

@ -3,7 +3,7 @@ import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
import requests import httpx
from flask_login import current_user from flask_login import current_user
from core.helper import encrypter from core.helper import encrypter
@ -216,7 +216,7 @@ class WebsiteService:
@classmethod @classmethod
def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]: def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
if not request.options.crawl_sub_pages: if not request.options.crawl_sub_pages:
response = requests.get( response = httpx.get(
f"https://r.jina.ai/{request.url}", f"https://r.jina.ai/{request.url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
) )
@ -224,7 +224,7 @@ class WebsiteService:
raise ValueError("Failed to crawl:") raise ValueError("Failed to crawl:")
return {"status": "active", "data": response.json().get("data")} return {"status": "active", "data": response.json().get("data")}
else: else:
response = requests.post( response = httpx.post(
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
json={ json={
"url": request.url, "url": request.url,
@ -287,7 +287,7 @@ class WebsiteService:
@classmethod @classmethod
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]: def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
response = requests.post( response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id}, json={"taskId": job_id},
@ -303,7 +303,7 @@ class WebsiteService:
} }
if crawl_status_data["status"] == "completed": if crawl_status_data["status"] == "completed":
response = requests.post( response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
@ -362,7 +362,7 @@ class WebsiteService:
@classmethod @classmethod
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None: def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
if not job_id: if not job_id:
response = requests.get( response = httpx.get(
f"https://r.jina.ai/{url}", f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
) )
@ -371,7 +371,7 @@ class WebsiteService:
return dict(response.json().get("data", {})) return dict(response.json().get("data", {}))
else: else:
# Get crawl status first # Get crawl status first
status_response = requests.post( status_response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id}, json={"taskId": job_id},
@ -381,7 +381,7 @@ class WebsiteService:
raise ValueError("Crawl job is not completed") raise ValueError("Crawl job is not completed")
# Get processed data # Get processed data
data_response = requests.post( data_response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},

View File

@ -1,8 +1,8 @@
import os import os
from typing import Literal from typing import Literal
import httpx
import pytest import pytest
import requests
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -27,13 +27,11 @@ class MockedHttp:
@classmethod @classmethod
def requests_request( def requests_request(
cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
) -> requests.Response: ) -> httpx.Response:
""" """
Mocked requests.request Mocked httpx.request
""" """
request = requests.PreparedRequest() request = httpx.Request(method, url)
request.method = method
request.url = url
if url.endswith("/tools"): if url.endswith("/tools"):
content = PluginDaemonBasicResponse[list[ToolProviderEntity]]( content = PluginDaemonBasicResponse[list[ToolProviderEntity]](
code=0, message="success", data=cls.list_tools() code=0, message="success", data=cls.list_tools()
@ -41,8 +39,7 @@ class MockedHttp:
else: else:
raise ValueError("") raise ValueError("")
response = requests.Response() response = httpx.Response(status_code=200)
response.status_code = 200
response.request = request response.request = request
response._content = content.encode("utf-8") response._content = content.encode("utf-8")
return response return response
@ -54,7 +51,7 @@ MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture @pytest.fixture
def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch): def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK_SWITCH: if MOCK_SWITCH:
monkeypatch.setattr(requests, "request", MockedHttp.requests_request) monkeypatch.setattr(httpx, "request", MockedHttp.requests_request)
def unpatch(): def unpatch():
monkeypatch.undo() monkeypatch.undo()

View File

@ -6,7 +6,7 @@ Test Clickzetta integration in Docker environment
import os import os
import time import time
import requests import httpx
from clickzetta import connect from clickzetta import connect
@ -66,7 +66,7 @@ def test_dify_api():
max_retries = 30 max_retries = 30
for i in range(max_retries): for i in range(max_retries):
try: try:
response = requests.get(f"{base_url}/console/api/health") response = httpx.get(f"{base_url}/console/api/health")
if response.status_code == 200: if response.status_code == 200:
print("✓ Dify API is ready") print("✓ Dify API is ready")
break break

View File

@ -173,7 +173,7 @@ class DifyTestContainers:
# Start Dify Plugin Daemon container for plugin management # Start Dify Plugin Daemon container for plugin management
# Dify Plugin Daemon provides plugin lifecycle management and execution # Dify Plugin Daemon provides plugin lifecycle management and execution
logger.info("Initializing Dify Plugin Daemon container...") logger.info("Initializing Dify Plugin Daemon container...")
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.2.0-local") self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.3.0-local")
self.dify_plugin_daemon.with_exposed_ports(5002) self.dify_plugin_daemon.with_exposed_ports(5002)
self.dify_plugin_daemon.env = { self.dify_plugin_daemon.env = {
"DB_HOST": db_host, "DB_HOST": db_host,

View File

@ -201,9 +201,9 @@ class TestOAuthCallback:
mock_db.session.rollback = MagicMock() mock_db.session.rollback = MagicMock()
# Import the real requests module to create a proper exception # Import the real requests module to create a proper exception
import requests import httpx
request_exception = requests.exceptions.RequestException("OAuth error") request_exception = httpx.RequestError("OAuth error")
request_exception.response = MagicMock() request_exception.response = MagicMock()
request_exception.response.text = str(exception) request_exception.response.text = str(exception)

View File

@ -0,0 +1,120 @@
"""Tests for graph engine event handlers."""
from __future__ import annotations
from datetime import datetime
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.event_management.event_handlers import EventHandler
from core.workflow.graph_engine.event_management.event_manager import EventManager
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import RetryConfig
class _StubEdgeProcessor:
"""Minimal edge processor stub for tests."""
class _StubErrorHandler:
"""Minimal error handler stub for tests."""
class _StubNode:
    """Bare-bones node double carrying only the attributes the state manager reads."""

    def __init__(self, node_id: str) -> None:
        # Identity and presentation.
        self.id = node_id
        self.title = "Stub Node"
        # Fresh-node execution defaults: state not yet determined, directly executable.
        self.state = NodeState.UNKNOWN
        self.execution_type = NodeExecutionType.EXECUTABLE
        # No error strategy configured and retries switched off.
        self.error_strategy = None
        self.retry = False
        self.retry_config = RetryConfig()
def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
    """Wire up an EventHandler backed entirely by in-memory collaborators.

    Returns the handler together with the EventManager (so tests can inspect
    collected events) and the GraphExecution (so tests can inspect per-node
    execution state such as retry counts).
    """
    stub_node = _StubNode(node_id)
    single_node_graph = Graph(
        nodes={node_id: stub_node}, edges={}, in_edges={}, out_edges={}, root_node=stub_node
    )
    pool = VariablePool()
    collector = EventManager()
    execution = GraphExecution(workflow_id="test-workflow")
    dispatcher = EventHandler(
        graph=single_node_graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=0.0),
        graph_execution=execution,
        response_coordinator=ResponseStreamCoordinator(variable_pool=pool, graph=single_node_graph),
        event_collector=collector,
        edge_processor=_StubEdgeProcessor(),
        state_manager=GraphStateManager(graph=single_node_graph, ready_queue=InMemoryReadyQueue()),
        error_handler=_StubErrorHandler(),
    )
    return dispatcher, collector, execution
def test_retry_does_not_emit_additional_start_event() -> None:
    """A retried node must surface one start event plus the retry event only."""
    node_id = "test-node"
    handler, event_manager, graph_execution = _build_event_handler(node_id)

    execution_id = "exec-1"
    node_type = NodeType.CODE
    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12; moving to
    # datetime.now(UTC) would need confirmation that the event model accepts
    # timezone-aware datetimes.
    started_at = datetime.utcnow()

    def make_start_event() -> NodeRunStartedEvent:
        # The initial start and the post-retry re-start share identical fields.
        return NodeRunStartedEvent(
            id=execution_id,
            node_id=node_id,
            node_type=node_type,
            node_title="Stub Node",
            start_at=started_at,
        )

    handler.dispatch(make_start_event())

    failure = NodeRunResult(
        status=WorkflowNodeExecutionStatus.FAILED,
        error="boom",
        error_type="TestError",
    )
    handler.dispatch(
        NodeRunRetryEvent(
            id=execution_id,
            node_id=node_id,
            node_type=node_type,
            node_title="Stub Node",
            start_at=started_at,
            error="boom",
            retry_index=1,
            node_run_result=failure,
        )
    )

    # The node begins executing again after the retry; this second start event
    # must not be collected a second time.
    handler.dispatch(make_start_event())

    emitted = [type(event) for event in event_manager._events]  # type: ignore[attr-defined]
    assert emitted == [NodeRunStartedEvent, NodeRunRetryEvent]

    assert graph_execution.get_or_create_node_execution(node_id).retry_count == 1

View File

@ -10,11 +10,18 @@ import time
from hypothesis import HealthCheck, given, settings from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st from hypothesis import strategies as st
from core.workflow.enums import ErrorStrategy
from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent from core.workflow.graph_events import (
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType
# Import the test framework from the new module # Import the test framework from the new module
from .test_mock_config import MockConfigBuilder
from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase
@ -721,3 +728,39 @@ def test_event_sequence_validation_with_table_tests():
else: else:
assert result.event_sequence_match is True assert result.event_sequence_match is True
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}" assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"
def test_graph_run_emits_partial_success_when_node_failure_recovered():
    """A run whose node failure is recovered ends in partial success, not full success."""
    table_runner = TableTestRunner()
    fixture = table_runner.workflow_runner.load_fixture("basic_chatflow")

    # Force the llm node to fail so the DEFAULT_VALUE error strategy kicks in.
    failing_llm_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build()
    graph, runtime_state = table_runner.workflow_runner.create_graph_from_fixture(
        fixture_data=fixture,
        query="hello",
        use_mock_factory=True,
        mock_config=failing_llm_config,
    )

    # Recover from the failure with a canned default output on the llm node.
    node_data = graph.nodes["llm"].get_base_node_data()
    node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
    node_data.default_value = [
        DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)
    ]

    emitted = list(
        GraphEngine(
            workflow_id="test_workflow",
            graph=graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        ).run()
    )

    assert isinstance(emitted[-1], GraphRunPartialSucceededEvent)
    partial = next(event for event in emitted if isinstance(event, GraphRunPartialSucceededEvent))
    assert partial.exceptions_count == 1
    assert partial.outputs.get("answer") == "fallback response"
    assert not any(isinstance(event, GraphRunSucceededEvent) for event in emitted)

View File

@ -1,65 +0,0 @@
import pytest
# Abort collection of this entire module: pytest.skip(..., allow_module_level=True)
# raises at import time, so none of the tests below execute. This is why the
# tests can reference helper names that have no import in this file.
pytest.skip(
    "Retry functionality is part of Phase 2 enhanced error handling - not implemented in MVP of queue-based engine",
    allow_module_level=True,
)
# Linear two-edge topology shared by the retry tests: start -> node -> answer.
# Edge ids follow the "<source>-<handle>-<target>-target" naming used elsewhere.
DEFAULT_VALUE_EDGE = [
    {
        "id": "start-source-node-target",
        "source": "start",
        "target": "node",
        "sourceHandle": "source",
    },
    {
        "id": "node-source-answer-target",
        "source": "node",
        "target": "answer",
        "sourceHandle": "source",
    },
]
def test_retry_default_value_partial_success():
    """A retrying HTTP node that falls back to its default value yields partial success."""
    # HTTP node configured with a default value and two enabled retries.
    http_node = ContinueOnErrorTestHelper.get_http_node(
        "default-value",
        [{"key": "result", "type": "string", "value": "http node got error response"}],
        retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
    )
    graph_config = {
        "edges": DEFAULT_VALUE_EDGE,
        "nodes": [
            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
            http_node,
        ],
    }

    engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
    emitted = list(engine.run())

    retry_count = sum(1 for event in emitted if isinstance(event, NodeRunRetryEvent))
    assert retry_count == 2
    assert emitted[-1].outputs == {"answer": "http node got error response"}
    assert any(isinstance(event, GraphRunPartialSucceededEvent) for event in emitted)
    assert len(emitted) == 11
def test_retry_failed():
    """When every retry attempt fails, the run ends with a failure event."""
    # HTTP node with retries enabled but no default value to fall back on.
    failing_http_node = ContinueOnErrorTestHelper.get_http_node(
        None,
        None,
        retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
    )
    graph_config = {
        "edges": DEFAULT_VALUE_EDGE,
        "nodes": [
            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
            failing_http_node,
        ],
    }

    engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
    emitted = list(engine.run())

    assert sum(1 for event in emitted if isinstance(event, NodeRunRetryEvent)) == 2
    assert any(isinstance(event, GraphRunFailedEvent) for event in emitted)
    assert len(emitted) == 8

View File

@ -1,8 +1,8 @@
import urllib.parse import urllib.parse
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import httpx
import pytest import pytest
import requests
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest):
({}, None, True), ({}, None, True),
], ],
) )
@patch("requests.post") @patch("httpx.post")
def test_should_retrieve_access_token( def test_should_retrieve_access_token(
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
): ):
@ -105,7 +105,7 @@ class TestGitHubOAuth(BaseOAuthTest):
), ),
], ],
) )
@patch("requests.get") @patch("httpx.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email): def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
user_response = MagicMock() user_response = MagicMock()
user_response.json.return_value = user_data user_response.json.return_value = user_data
@ -121,11 +121,11 @@ class TestGitHubOAuth(BaseOAuthTest):
assert user_info.name == user_data["name"] assert user_info.name == user_data["name"]
assert user_info.email == expected_email assert user_info.email == expected_email
@patch("requests.get") @patch("httpx.get")
def test_should_handle_network_errors(self, mock_get, oauth): def test_should_handle_network_errors(self, mock_get, oauth):
mock_get.side_effect = requests.exceptions.RequestException("Network error") mock_get.side_effect = httpx.RequestError("Network error")
with pytest.raises(requests.exceptions.RequestException): with pytest.raises(httpx.RequestError):
oauth.get_raw_user_info("test_token") oauth.get_raw_user_info("test_token")
@ -167,7 +167,7 @@ class TestGoogleOAuth(BaseOAuthTest):
({}, None, True), ({}, None, True),
], ],
) )
@patch("requests.post") @patch("httpx.post")
def test_should_retrieve_access_token( def test_should_retrieve_access_token(
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
): ):
@ -201,7 +201,7 @@ class TestGoogleOAuth(BaseOAuthTest):
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string ({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
], ],
) )
@patch("requests.get") @patch("httpx.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name): def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
mock_response.json.return_value = user_data mock_response.json.return_value = user_data
mock_get.return_value = mock_response mock_get.return_value = mock_response
@ -217,12 +217,12 @@ class TestGoogleOAuth(BaseOAuthTest):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"exception_type", "exception_type",
[ [
requests.exceptions.HTTPError, httpx.HTTPError,
requests.exceptions.ConnectionError, httpx.ConnectError,
requests.exceptions.Timeout, httpx.TimeoutException,
], ],
) )
@patch("requests.get") @patch("httpx.get")
def test_should_handle_http_errors(self, mock_get, oauth, exception_type): def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.raise_for_status.side_effect = exception_type("Error") mock_response.raise_for_status.side_effect = exception_type("Error")

View File

@ -6,8 +6,8 @@ import json
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import httpx
import pytest import pytest
import requests
from services.auth.api_key_auth_factory import ApiKeyAuthFactory from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.api_key_auth_service import ApiKeyAuthService from services.auth.api_key_auth_service import ApiKeyAuthService
@ -26,7 +26,7 @@ class TestAuthIntegration:
self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}} self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}
@patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token") @patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session): def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
"""Test complete authentication flow: request → validation → encryption → storage""" """Test complete authentication flow: request → validation → encryption → storage"""
@ -47,7 +47,7 @@ class TestAuthIntegration:
mock_session.add.assert_called_once() mock_session.add.assert_called_once()
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_cross_component_integration(self, mock_http): def test_cross_component_integration(self, mock_http):
"""Test factory → provider → HTTP call integration""" """Test factory → provider → HTTP call integration"""
mock_http.return_value = self._create_success_response() mock_http.return_value = self._create_success_response()
@ -97,7 +97,7 @@ class TestAuthIntegration:
assert "another_secret" not in factory_str assert "another_secret" not in factory_str
@patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token") @patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session): def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
"""Test concurrent authentication creation safety""" """Test concurrent authentication creation safety"""
@ -142,31 +142,31 @@ class TestAuthIntegration:
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)): with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input) ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_http_error_handling(self, mock_http): def test_http_error_handling(self, mock_http):
"""Test proper HTTP error handling""" """Test proper HTTP error handling"""
mock_response = Mock() mock_response = Mock()
mock_response.status_code = 401 mock_response.status_code = 401
mock_response.text = '{"error": "Unauthorized"}' mock_response.text = '{"error": "Unauthorized"}'
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Unauthorized") mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized")
mock_http.return_value = mock_response mock_http.return_value = mock_response
# PT012: Split into single statement for pytest.raises # PT012: Split into single statement for pytest.raises
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials) factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
with pytest.raises((requests.exceptions.HTTPError, Exception)): with pytest.raises((httpx.HTTPError, Exception)):
factory.validate_credentials() factory.validate_credentials()
@patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_network_failure_recovery(self, mock_http, mock_session): def test_network_failure_recovery(self, mock_http, mock_session):
"""Test system recovery from network failures""" """Test system recovery from network failures"""
mock_http.side_effect = requests.exceptions.RequestException("Network timeout") mock_http.side_effect = httpx.RequestError("Network timeout")
mock_session.add = Mock() mock_session.add = Mock()
mock_session.commit = Mock() mock_session.commit = Mock()
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
with pytest.raises(requests.exceptions.RequestException): with pytest.raises(httpx.RequestError):
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
mock_session.commit.assert_not_called() mock_session.commit.assert_not_called()

View File

@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import httpx
import pytest import pytest
import requests
from services.auth.firecrawl.firecrawl import FirecrawlAuth from services.auth.firecrawl.firecrawl import FirecrawlAuth
@ -64,7 +64,7 @@ class TestFirecrawlAuth:
FirecrawlAuth(credentials) FirecrawlAuth(credentials)
assert str(exc_info.value) == expected_error assert str(exc_info.value) == expected_error
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance): def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
"""Test successful credential validation""" """Test successful credential validation"""
mock_response = MagicMock() mock_response = MagicMock()
@ -95,7 +95,7 @@ class TestFirecrawlAuth:
(500, "Internal server error"), (500, "Internal server error"),
], ],
) )
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance): def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes""" """Test handling of various HTTP error codes"""
mock_response = MagicMock() mock_response = MagicMock()
@ -115,7 +115,7 @@ class TestFirecrawlAuth:
(401, "Not JSON", True, "Expecting value"), # JSON decode error (401, "Not JSON", True, "Expecting value"), # JSON decode error
], ],
) )
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_unexpected_errors( def test_should_handle_unexpected_errors(
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
): ):
@ -134,13 +134,13 @@ class TestFirecrawlAuth:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("exception_type", "exception_message"), ("exception_type", "exception_message"),
[ [
(requests.ConnectionError, "Network error"), (httpx.ConnectError, "Network error"),
(requests.Timeout, "Request timeout"), (httpx.TimeoutException, "Request timeout"),
(requests.ReadTimeout, "Read timeout"), (httpx.ReadTimeout, "Read timeout"),
(requests.ConnectTimeout, "Connection timeout"), (httpx.ConnectTimeout, "Connection timeout"),
], ],
) )
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance): def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts""" """Test handling of various network-related errors including timeouts"""
mock_post.side_effect = exception_type(exception_message) mock_post.side_effect = exception_type(exception_message)
@ -162,7 +162,7 @@ class TestFirecrawlAuth:
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}}) FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value) assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_use_custom_base_url_in_validation(self, mock_post): def test_should_use_custom_base_url_in_validation(self, mock_post):
"""Test that custom base URL is used in validation""" """Test that custom base URL is used in validation"""
mock_response = MagicMock() mock_response = MagicMock()
@ -179,12 +179,12 @@ class TestFirecrawlAuth:
assert result is True assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl" assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance): def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message""" """Test that timeout errors are handled gracefully with appropriate error message"""
mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds") mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
with pytest.raises(requests.Timeout) as exc_info: with pytest.raises(httpx.TimeoutException) as exc_info:
auth_instance.validate_credentials() auth_instance.validate_credentials()
# Verify the timeout exception is raised with original message # Verify the timeout exception is raised with original message

View File

@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import httpx
import pytest import pytest
import requests
from services.auth.jina.jina import JinaAuth from services.auth.jina.jina import JinaAuth
@ -35,7 +35,7 @@ class TestJinaAuth:
JinaAuth(credentials) JinaAuth(credentials)
assert str(exc_info.value) == "No API key provided" assert str(exc_info.value) == "No API key provided"
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_validate_valid_credentials_successfully(self, mock_post): def test_should_validate_valid_credentials_successfully(self, mock_post):
"""Test successful credential validation""" """Test successful credential validation"""
mock_response = MagicMock() mock_response = MagicMock()
@ -53,7 +53,7 @@ class TestJinaAuth:
json={"url": "https://example.com"}, json={"url": "https://example.com"},
) )
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_402_error(self, mock_post): def test_should_handle_http_402_error(self, mock_post):
"""Test handling of 402 Payment Required error""" """Test handling of 402 Payment Required error"""
mock_response = MagicMock() mock_response = MagicMock()
@ -68,7 +68,7 @@ class TestJinaAuth:
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required" assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_409_error(self, mock_post): def test_should_handle_http_409_error(self, mock_post):
"""Test handling of 409 Conflict error""" """Test handling of 409 Conflict error"""
mock_response = MagicMock() mock_response = MagicMock()
@ -83,7 +83,7 @@ class TestJinaAuth:
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error" assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_500_error(self, mock_post): def test_should_handle_http_500_error(self, mock_post):
"""Test handling of 500 Internal Server Error""" """Test handling of 500 Internal Server Error"""
mock_response = MagicMock() mock_response = MagicMock()
@ -98,7 +98,7 @@ class TestJinaAuth:
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error" assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_handle_unexpected_error_with_text_response(self, mock_post): def test_should_handle_unexpected_error_with_text_response(self, mock_post):
"""Test handling of unexpected errors with text response""" """Test handling of unexpected errors with text response"""
mock_response = MagicMock() mock_response = MagicMock()
@ -114,7 +114,7 @@ class TestJinaAuth:
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden" assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_handle_unexpected_error_without_text(self, mock_post): def test_should_handle_unexpected_error_without_text(self, mock_post):
"""Test handling of unexpected errors without text response""" """Test handling of unexpected errors without text response"""
mock_response = MagicMock() mock_response = MagicMock()
@ -130,15 +130,15 @@ class TestJinaAuth:
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404" assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_handle_network_errors(self, mock_post): def test_should_handle_network_errors(self, mock_post):
"""Test handling of network connection errors""" """Test handling of network connection errors"""
mock_post.side_effect = requests.ConnectionError("Network error") mock_post.side_effect = httpx.ConnectError("Network error")
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials) auth = JinaAuth(credentials)
with pytest.raises(requests.ConnectionError): with pytest.raises(httpx.ConnectError):
auth.validate_credentials() auth.validate_credentials()
def test_should_not_expose_api_key_in_error_messages(self): def test_should_not_expose_api_key_in_error_messages(self):

View File

@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import httpx
import pytest import pytest
import requests
from services.auth.watercrawl.watercrawl import WatercrawlAuth from services.auth.watercrawl.watercrawl import WatercrawlAuth
@ -64,7 +64,7 @@ class TestWatercrawlAuth:
WatercrawlAuth(credentials) WatercrawlAuth(credentials)
assert str(exc_info.value) == expected_error assert str(exc_info.value) == expected_error
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance): def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
"""Test successful credential validation""" """Test successful credential validation"""
mock_response = MagicMock() mock_response = MagicMock()
@ -87,7 +87,7 @@ class TestWatercrawlAuth:
(500, "Internal server error"), (500, "Internal server error"),
], ],
) )
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance): def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes""" """Test handling of various HTTP error codes"""
mock_response = MagicMock() mock_response = MagicMock()
@ -107,7 +107,7 @@ class TestWatercrawlAuth:
(401, "Not JSON", True, "Expecting value"), # JSON decode error (401, "Not JSON", True, "Expecting value"), # JSON decode error
], ],
) )
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_unexpected_errors( def test_should_handle_unexpected_errors(
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
): ):
@ -126,13 +126,13 @@ class TestWatercrawlAuth:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("exception_type", "exception_message"), ("exception_type", "exception_message"),
[ [
(requests.ConnectionError, "Network error"), (httpx.ConnectError, "Network error"),
(requests.Timeout, "Request timeout"), (httpx.TimeoutException, "Request timeout"),
(requests.ReadTimeout, "Read timeout"), (httpx.ReadTimeout, "Read timeout"),
(requests.ConnectTimeout, "Connection timeout"), (httpx.ConnectTimeout, "Connection timeout"),
], ],
) )
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance): def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts""" """Test handling of various network-related errors including timeouts"""
mock_get.side_effect = exception_type(exception_message) mock_get.side_effect = exception_type(exception_message)
@ -154,7 +154,7 @@ class TestWatercrawlAuth:
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}) WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value) assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_use_custom_base_url_in_validation(self, mock_get): def test_should_use_custom_base_url_in_validation(self, mock_get):
"""Test that custom base URL is used in validation""" """Test that custom base URL is used in validation"""
mock_response = MagicMock() mock_response = MagicMock()
@ -179,7 +179,7 @@ class TestWatercrawlAuth:
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"), ("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
], ],
) )
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url): def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
"""Test that urljoin is used correctly for URL construction with various base URLs""" """Test that urljoin is used correctly for URL construction with various base URLs"""
mock_response = MagicMock() mock_response = MagicMock()
@ -193,12 +193,12 @@ class TestWatercrawlAuth:
# Verify the correct URL was called # Verify the correct URL was called
assert mock_get.call_args[0][0] == expected_url assert mock_get.call_args[0][0] == expected_url
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance): def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message""" """Test that timeout errors are handled gracefully with appropriate error message"""
mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds") mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
with pytest.raises(requests.Timeout) as exc_info: with pytest.raises(httpx.TimeoutException) as exc_info:
auth_instance.validate_credentials() auth_instance.validate_credentials()
# Verify the timeout exception is raised with original message # Verify the timeout exception is raised with original message

4
api/uv.lock generated
View File

@ -1,5 +1,5 @@
version = 1 version = 1
revision = 2 revision = 3
requires-python = ">=3.11, <3.13" requires-python = ">=3.11, <3.13"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.12.4' and sys_platform == 'linux'", "python_full_version >= '3.12.4' and sys_platform == 'linux'",
@ -1273,7 +1273,7 @@ wheels = [
[[package]] [[package]]
name = "dify-api" name = "dify-api"
version = "2.0.0b2" version = "1.9.0"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "arize-phoenix-otel" }, { name = "arize-phoenix-otel" },

View File

@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
services: services:
# API service # API service
api: api:
image: langgenius/dify-api:2.0.0-beta.2 image: langgenius/dify-api:1.9.0
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -31,7 +31,7 @@ services:
# worker service # worker service
# The Celery worker for processing the queue. # The Celery worker for processing the queue.
worker: worker:
image: langgenius/dify-api:2.0.0-beta.2 image: langgenius/dify-api:1.9.0
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -58,7 +58,7 @@ services:
# worker_beat service # worker_beat service
# Celery beat for scheduling periodic tasks. # Celery beat for scheduling periodic tasks.
worker_beat: worker_beat:
image: langgenius/dify-api:2.0.0-beta.2 image: langgenius/dify-api:1.9.0
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -76,7 +76,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:2.0.0-beta.2 image: langgenius/dify-web:1.9.0
restart: always restart: always
environment: environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-} CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -177,7 +177,7 @@ services:
# plugin daemon # plugin daemon
plugin_daemon: plugin_daemon:
image: langgenius/dify-plugin-daemon:0.3.0b1-local image: langgenius/dify-plugin-daemon:0.3.0-local
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.

View File

@ -20,7 +20,17 @@ services:
ports: ports:
- "${EXPOSE_POSTGRES_PORT:-5432}:5432" - "${EXPOSE_POSTGRES_PORT:-5432}:5432"
healthcheck: healthcheck:
test: [ 'CMD', 'pg_isready', '-h', 'db', '-U', '${PGUSER:-postgres}', '-d', '${POSTGRES_DB:-dify}' ] test:
[
"CMD",
"pg_isready",
"-h",
"db",
"-U",
"${PGUSER:-postgres}",
"-d",
"${POSTGRES_DB:-dify}",
]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
@ -41,7 +51,11 @@ services:
ports: ports:
- "${EXPOSE_REDIS_PORT:-6379}:6379" - "${EXPOSE_REDIS_PORT:-6379}:6379"
healthcheck: healthcheck:
test: [ 'CMD-SHELL', 'redis-cli -a ${REDIS_PASSWORD:-difyai123456} ping | grep -q PONG' ] test:
[
"CMD-SHELL",
"redis-cli -a ${REDIS_PASSWORD:-difyai123456} ping | grep -q PONG",
]
# The DifySandbox # The DifySandbox
sandbox: sandbox:
@ -65,13 +79,13 @@ services:
- ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/dependencies:/dependencies
- ./volumes/sandbox/conf:/conf - ./volumes/sandbox/conf:/conf
healthcheck: healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:8194/health" ] test: ["CMD", "curl", "-f", "http://localhost:8194/health"]
networks: networks:
- ssrf_proxy_network - ssrf_proxy_network
# plugin daemon # plugin daemon
plugin_daemon: plugin_daemon:
image: langgenius/dify-plugin-daemon:0.3.0b1-local image: langgenius/dify-plugin-daemon:0.3.0-local
restart: always restart: always
env_file: env_file:
- ./middleware.env - ./middleware.env
@ -143,7 +157,12 @@ services:
volumes: volumes:
- ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
- ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
entrypoint: [ "sh", "-c", "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] entrypoint:
[
"sh",
"-c",
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
env_file: env_file:
- ./middleware.env - ./middleware.env
environment: environment:

View File

@ -593,7 +593,7 @@ x-shared-env: &shared-api-worker-env
services: services:
# API service # API service
api: api:
image: langgenius/dify-api:2.0.0-beta.2 image: langgenius/dify-api:1.9.0
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -622,7 +622,7 @@ services:
# worker service # worker service
# The Celery worker for processing the queue. # The Celery worker for processing the queue.
worker: worker:
image: langgenius/dify-api:2.0.0-beta.2 image: langgenius/dify-api:1.9.0
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -649,7 +649,7 @@ services:
# worker_beat service # worker_beat service
# Celery beat for scheduling periodic tasks. # Celery beat for scheduling periodic tasks.
worker_beat: worker_beat:
image: langgenius/dify-api:2.0.0-beta.2 image: langgenius/dify-api:1.9.0
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -667,7 +667,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:2.0.0-beta.2 image: langgenius/dify-web:1.9.0
restart: always restart: always
environment: environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-} CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -768,7 +768,7 @@ services:
# plugin daemon # plugin daemon
plugin_daemon: plugin_daemon:
image: langgenius/dify-plugin-daemon:0.3.0b1-local image: langgenius/dify-plugin-daemon:0.3.0-local
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.

View File

@ -1,4 +0,0 @@
GET /console/api/spec/schema-definitions
Host: cloud-rag.dify.dev
authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiNzExMDZhYTQtZWJlMC00NGMzLWI4NWYtMWQ4Mjc5ZTExOGZmIiwiZXhwIjoxNzU2MTkyNDE4LCJpc3MiOiJDTE9VRCIsInN1YiI6IkNvbnNvbGUgQVBJIFBhc3Nwb3J0In0.Yx_TMdWVXCp5YEoQ8WR90lRhHHKggxAQvEl5RUnkZuc
###

View File

@ -124,7 +124,7 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
const handleAddGroup = useCallback(() => { const handleAddGroup = useCallback(() => {
let maxInGroupName = 1 let maxInGroupName = 1
inputs.advanced_settings.groups.forEach((item) => { inputs.advanced_settings.groups.forEach((item) => {
const match = item.group_name.match(/(\d+)$/) const match = /(\d+)$/.exec(item.group_name)
if (match) { if (match) {
const num = Number.parseInt(match[1], 10) const num = Number.parseInt(match[1], 10)
if (num > maxInGroupName) if (num > maxInGroupName)

View File

@ -1,12 +1,18 @@
@import "preflight.css"; @import "preflight.css";
@tailwind base;
@tailwind components;
@import '../../themes/light.css'; @import '../../themes/light.css';
@import '../../themes/dark.css'; @import '../../themes/dark.css';
@import "../../themes/manual-light.css"; @import "../../themes/manual-light.css";
@import "../../themes/manual-dark.css"; @import "../../themes/manual-dark.css";
@import "../components/base/button/index.css";
@import "../components/base/action-button/index.css";
@import "../components/base/modal/index.css";
@tailwind base;
@tailwind components;
html { html {
color-scheme: light; color-scheme: light;
} }
@ -680,10 +686,6 @@ button:focus-within {
display: none; display: none;
} }
@import "../components/base/button/index.css";
@import "../components/base/action-button/index.css";
@import "../components/base/modal/index.css";
@tailwind utilities; @tailwind utilities;
@layer utilities { @layer utilities {

View File

@ -91,12 +91,10 @@ const remoteImageURLs = [hasSetWebPrefix ? new URL(`${process.env.NEXT_PUBLIC_WE
/** @type {import('next').NextConfig} */ /** @type {import('next').NextConfig} */
const nextConfig = { const nextConfig = {
basePath: process.env.NEXT_PUBLIC_BASE_PATH || '', basePath: process.env.NEXT_PUBLIC_BASE_PATH || '',
webpack: (config, { dev, isServer }) => { turbopack: {
if (dev) { rules: codeInspectorPlugin({
config.plugins.push(codeInspectorPlugin({ bundler: 'webpack' })) bundler: 'turbopack'
} })
return config
}, },
productionBrowserSourceMaps: false, // enable browser source map generation during the production build productionBrowserSourceMaps: false, // enable browser source map generation during the production build
// Configure pageExtensions to include md and mdx // Configure pageExtensions to include md and mdx
@ -112,6 +110,10 @@ const nextConfig = {
})), })),
}, },
experimental: { experimental: {
optimizePackageImports: [
'@remixicon/react',
'@heroicons/react'
],
}, },
// fix all before production. Now it slow the develop speed. // fix all before production. Now it slow the develop speed.
eslint: { eslint: {

View File

@ -1,6 +1,6 @@
{ {
"name": "dify-web", "name": "dify-web",
"version": "2.0.0-beta2", "version": "1.9.0",
"private": true, "private": true,
"packageManager": "pnpm@10.16.0", "packageManager": "pnpm@10.16.0",
"engines": { "engines": {
@ -19,7 +19,7 @@
"and_qq >= 14.9" "and_qq >= 14.9"
], ],
"scripts": { "scripts": {
"dev": "cross-env NODE_OPTIONS='--inspect' next dev", "dev": "cross-env NODE_OPTIONS='--inspect' next dev --turbopack",
"build": "next build", "build": "next build",
"build:docker": "next build && node scripts/optimize-standalone.js", "build:docker": "next build && node scripts/optimize-standalone.js",
"start": "cp -r .next/static .next/standalone/.next/static && cp -r public .next/standalone/public && cross-env PORT=$npm_config_port HOSTNAME=$npm_config_host node .next/standalone/server.js", "start": "cp -r .next/static .next/standalone/.next/static && cp -r public .next/standalone/public && cross-env PORT=$npm_config_port HOSTNAME=$npm_config_host node .next/standalone/server.js",
@ -203,7 +203,7 @@
"autoprefixer": "^10.4.20", "autoprefixer": "^10.4.20",
"babel-loader": "^10.0.0", "babel-loader": "^10.0.0",
"bing-translate-api": "^4.0.2", "bing-translate-api": "^4.0.2",
"code-inspector-plugin": "^0.18.1", "code-inspector-plugin": "1.2.9",
"cross-env": "^7.0.3", "cross-env": "^7.0.3",
"eslint": "^9.35.0", "eslint": "^9.35.0",
"eslint-config-next": "15.5.0", "eslint-config-next": "15.5.0",

129
web/pnpm-lock.yaml generated
View File

@ -519,8 +519,8 @@ importers:
specifier: ^4.0.2 specifier: ^4.0.2
version: 4.1.0 version: 4.1.0
code-inspector-plugin: code-inspector-plugin:
specifier: ^0.18.1 specifier: 1.2.9
version: 0.18.3 version: 1.2.9
cross-env: cross-env:
specifier: ^7.0.3 specifier: ^7.0.3
version: 7.0.3 version: 7.0.3
@ -1372,6 +1372,24 @@ packages:
'@clack/prompts@0.11.0': '@clack/prompts@0.11.0':
resolution: {integrity: sha512-pMN5FcrEw9hUkZA4f+zLlzivQSeQf5dRGJjSUbvVYDLvpKCdQx5OaknvKzgbtXOizhP+SJJJjqEbOe55uKKfAw==} resolution: {integrity: sha512-pMN5FcrEw9hUkZA4f+zLlzivQSeQf5dRGJjSUbvVYDLvpKCdQx5OaknvKzgbtXOizhP+SJJJjqEbOe55uKKfAw==}
'@code-inspector/core@1.2.9':
resolution: {integrity: sha512-A1w+G73HlTB6S8X6sA6tT+ziWHTAcTyH+7FZ1Sgd3ZLXF/E/jT+hgRbKposjXMwxcbodRc6hBG6UyiV+VxwE6Q==}
'@code-inspector/esbuild@1.2.9':
resolution: {integrity: sha512-DuyfxGupV43CN8YElIqynAniBtE86i037+3OVJYrm3jlJscXzbV98/kOzvu+VJQQvElcDgpgD6C/aGmPvFEiUg==}
'@code-inspector/mako@1.2.9':
resolution: {integrity: sha512-8N+MHdr64AnthLB4v+YGe8/9bgog3BnkxIW/fqX5iVS0X06mF7X1pxfZOD2bABVtv1tW25lRtNs5AgvYJs0vpg==}
'@code-inspector/turbopack@1.2.9':
resolution: {integrity: sha512-UVOUbqU6rpi5eOkrFamKrdeSWb0/OFFJQBaxbgs1RK5V5f4/iVwC5KjO2wkjv8cOGU4EppLfBVSBI1ysOo8S5A==}
'@code-inspector/vite@1.2.9':
resolution: {integrity: sha512-saIokJ3o3SdrHEgTEg1fbbowbKfh7J4mYtu0i1mVfah1b1UfdCF/iFHTEJ6SADMiY47TeNZTg0TQWTlU1AWPww==}
'@code-inspector/webpack@1.2.9':
resolution: {integrity: sha512-9YEykVrOIc0zMV7pyTyZhCprjScjn6gPPmxb4/OQXKCrP2fAm+NB188rg0s95e4sM7U3qRUpPA4NUH5F7Ogo+g==}
'@cspotcode/source-map-support@0.8.1': '@cspotcode/source-map-support@0.8.1':
resolution: {integrity: sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==} resolution: {integrity: sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==}
engines: {node: '>=12'} engines: {node: '>=12'}
@ -4425,11 +4443,8 @@ packages:
resolution: {integrity: sha512-QVb0dM5HvG+uaxitm8wONl7jltx8dqhfU33DcqtOZcLSVIKSDDLDi7+0LbAKiyI8hD9u42m2YxXSkMGWThaecQ==} resolution: {integrity: sha512-QVb0dM5HvG+uaxitm8wONl7jltx8dqhfU33DcqtOZcLSVIKSDDLDi7+0LbAKiyI8hD9u42m2YxXSkMGWThaecQ==}
engines: {iojs: '>= 1.0.0', node: '>= 0.12.0'} engines: {iojs: '>= 1.0.0', node: '>= 0.12.0'}
code-inspector-core@0.18.3: code-inspector-plugin@1.2.9:
resolution: {integrity: sha512-60pT2cPoguMTUYdN1MMpjoPUnuF0ud/u7M2y+Vqit/bniLEit9dySEWAVxLU/Ukc5ILrDeLKEttc6fCMl9RUrA==} resolution: {integrity: sha512-PGp/AQ03vaajimG9rn5+eQHGifrym5CSNLCViPtwzot7FM3MqEkGNqcvimH0FVuv3wDOcP5KvETAUSLf1BE3HA==}
code-inspector-plugin@0.18.3:
resolution: {integrity: sha512-d9oJXZUsnvfTaQDwFmDNA2F+AR/TXIxWg1rr8KGcEskltR2prbZsfuu1z70EAn4khpx0smfi/PvIIwNJQ7FAMw==}
collapse-white-space@2.1.0: collapse-white-space@2.1.0:
resolution: {integrity: sha512-loKTxY1zCOuG4j9f6EPnuyyYkf58RnhhWTvRoZEokgB+WbdXehfjFviyOVYkqzEWz1Q5kRiZdBYS5SwxbQYwzw==} resolution: {integrity: sha512-loKTxY1zCOuG4j9f6EPnuyyYkf58RnhhWTvRoZEokgB+WbdXehfjFviyOVYkqzEWz1Q5kRiZdBYS5SwxbQYwzw==}
@ -5055,9 +5070,6 @@ packages:
esast-util-from-js@2.0.1: esast-util-from-js@2.0.1:
resolution: {integrity: sha512-8Ja+rNJ0Lt56Pcf3TAmpBZjmx8ZcK5Ts4cAzIOjsjevg9oSXJnl6SUQ2EevU8tv3h6ZLWmoKL5H4fgWvdvfETw==} resolution: {integrity: sha512-8Ja+rNJ0Lt56Pcf3TAmpBZjmx8ZcK5Ts4cAzIOjsjevg9oSXJnl6SUQ2EevU8tv3h6ZLWmoKL5H4fgWvdvfETw==}
esbuild-code-inspector-plugin@0.18.3:
resolution: {integrity: sha512-FaPt5eFMtW1oXMWqAcqfAJByNagP1V/R9dwDDLQO29JmryMF35+frskTqy+G53whmTaVi19+TCrFqhNbMZH5ZQ==}
esbuild-register@3.6.0: esbuild-register@3.6.0:
resolution: {integrity: sha512-H2/S7Pm8a9CL1uhp9OvjwrBh5Pvx0H8qVOxNu8Wed9Y7qv56MPtq+GGM8RJpq6glYJn9Wspr8uw7l55uyinNeg==} resolution: {integrity: sha512-H2/S7Pm8a9CL1uhp9OvjwrBh5Pvx0H8qVOxNu8Wed9Y7qv56MPtq+GGM8RJpq6glYJn9Wspr8uw7l55uyinNeg==}
peerDependencies: peerDependencies:
@ -6413,8 +6425,8 @@ packages:
resolution: {integrity: sha512-MbjN408fEndfiQXbFQ1vnd+1NoLDsnQW41410oQBXiyXDMYH5z505juWa4KUE1LqxRC7DgOgZDbKLxHIwm27hA==} resolution: {integrity: sha512-MbjN408fEndfiQXbFQ1vnd+1NoLDsnQW41410oQBXiyXDMYH5z505juWa4KUE1LqxRC7DgOgZDbKLxHIwm27hA==}
engines: {node: '>=0.10'} engines: {node: '>=0.10'}
launch-ide@1.0.1: launch-ide@1.2.0:
resolution: {integrity: sha512-U7qBxSNk774PxWq4XbmRe0ThiIstPoa4sMH/OGSYxrFVvg8x3biXcF1fsH6wasDpEmEXMdINUrQhBdwsSgKyMg==} resolution: {integrity: sha512-7nXSPQOt3b2JT52Ge8jp4miFcY+nrUEZxNLWBzrEfjmByDTb9b5ytqMSwGhsNwY6Cntwop+6n7rWIFN0+S8PTw==}
layout-base@1.0.2: layout-base@1.0.2:
resolution: {integrity: sha512-8h2oVEZNktL4BH2JCOI90iD1yXwL6iNW7KcCKT2QZgQJR2vbqDsldCTPRU9NifTCqHZci57XvQQ15YTu+sTYPg==} resolution: {integrity: sha512-8h2oVEZNktL4BH2JCOI90iD1yXwL6iNW7KcCKT2QZgQJR2vbqDsldCTPRU9NifTCqHZci57XvQQ15YTu+sTYPg==}
@ -8693,9 +8705,6 @@ packages:
vfile@6.0.3: vfile@6.0.3:
resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==} resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==}
vite-code-inspector-plugin@0.18.3:
resolution: {integrity: sha512-178H73vbDUHE+JpvfAfioUHlUr7qXCYIEa2YNXtzenFQGOjtae59P1jjcxGfa6pPHEnOoaitb13K+0qxwhi/WA==}
vm-browserify@1.1.2: vm-browserify@1.1.2:
resolution: {integrity: sha512-2ham8XPWTONajOR0ohOKOHXkm3+gaBmGut3SRuu75xLd/RRaY6vqgh8NBYYk7+RW3u5AtzPQZG8F10LHkl0lAQ==} resolution: {integrity: sha512-2ham8XPWTONajOR0ohOKOHXkm3+gaBmGut3SRuu75xLd/RRaY6vqgh8NBYYk7+RW3u5AtzPQZG8F10LHkl0lAQ==}
@ -8754,9 +8763,6 @@ packages:
engines: {node: '>= 10.13.0'} engines: {node: '>= 10.13.0'}
hasBin: true hasBin: true
webpack-code-inspector-plugin@0.18.3:
resolution: {integrity: sha512-3782rsJhBnRiw0IpR6EqnyGDQoiSq0CcGeLJ52rZXlszYCe8igXtcujq7OhI0byaivWQ1LW7sXKyMEoVpBhq0w==}
webpack-dev-middleware@6.1.3: webpack-dev-middleware@6.1.3:
resolution: {integrity: sha512-A4ChP0Qj8oGociTs6UdlRUGANIGrCDL3y+pmQMc+dSsraXHCatFpmMey4mYELA+juqwUqwQsUgJJISXl1KWmiw==} resolution: {integrity: sha512-A4ChP0Qj8oGociTs6UdlRUGANIGrCDL3y+pmQMc+dSsraXHCatFpmMey4mYELA+juqwUqwQsUgJJISXl1KWmiw==}
engines: {node: '>= 14.15.0'} engines: {node: '>= 14.15.0'}
@ -9993,6 +9999,48 @@ snapshots:
picocolors: 1.1.1 picocolors: 1.1.1
sisteransi: 1.0.5 sisteransi: 1.0.5
'@code-inspector/core@1.2.9':
dependencies:
'@vue/compiler-dom': 3.5.17
chalk: 4.1.2
dotenv: 16.6.1
launch-ide: 1.2.0
portfinder: 1.0.37
transitivePeerDependencies:
- supports-color
'@code-inspector/esbuild@1.2.9':
dependencies:
'@code-inspector/core': 1.2.9
transitivePeerDependencies:
- supports-color
'@code-inspector/mako@1.2.9':
dependencies:
'@code-inspector/core': 1.2.9
transitivePeerDependencies:
- supports-color
'@code-inspector/turbopack@1.2.9':
dependencies:
'@code-inspector/core': 1.2.9
'@code-inspector/webpack': 1.2.9
transitivePeerDependencies:
- supports-color
'@code-inspector/vite@1.2.9':
dependencies:
'@code-inspector/core': 1.2.9
chalk: 4.1.1
transitivePeerDependencies:
- supports-color
'@code-inspector/webpack@1.2.9':
dependencies:
'@code-inspector/core': 1.2.9
transitivePeerDependencies:
- supports-color
'@cspotcode/source-map-support@0.8.1': '@cspotcode/source-map-support@0.8.1':
dependencies: dependencies:
'@jridgewell/trace-mapping': 0.3.9 '@jridgewell/trace-mapping': 0.3.9
@ -12799,7 +12847,7 @@ snapshots:
'@vue/compiler-core@3.5.17': '@vue/compiler-core@3.5.17':
dependencies: dependencies:
'@babel/parser': 7.28.0 '@babel/parser': 7.28.4
'@vue/shared': 3.5.17 '@vue/shared': 3.5.17
entities: 4.5.0 entities: 4.5.0
estree-walker: 2.0.2 estree-walker: 2.0.2
@ -13503,24 +13551,15 @@ snapshots:
co@4.6.0: {} co@4.6.0: {}
code-inspector-core@0.18.3: code-inspector-plugin@1.2.9:
dependencies: dependencies:
'@vue/compiler-dom': 3.5.17 '@code-inspector/core': 1.2.9
'@code-inspector/esbuild': 1.2.9
'@code-inspector/mako': 1.2.9
'@code-inspector/turbopack': 1.2.9
'@code-inspector/vite': 1.2.9
'@code-inspector/webpack': 1.2.9
chalk: 4.1.1 chalk: 4.1.1
dotenv: 16.6.1
launch-ide: 1.0.1
portfinder: 1.0.37
transitivePeerDependencies:
- supports-color
code-inspector-plugin@0.18.3:
dependencies:
chalk: 4.1.1
code-inspector-core: 0.18.3
dotenv: 16.6.1
esbuild-code-inspector-plugin: 0.18.3
vite-code-inspector-plugin: 0.18.3
webpack-code-inspector-plugin: 0.18.3
transitivePeerDependencies: transitivePeerDependencies:
- supports-color - supports-color
@ -14160,12 +14199,6 @@ snapshots:
esast-util-from-estree: 2.0.0 esast-util-from-estree: 2.0.0
vfile-message: 4.0.2 vfile-message: 4.0.2
esbuild-code-inspector-plugin@0.18.3:
dependencies:
code-inspector-core: 0.18.3
transitivePeerDependencies:
- supports-color
esbuild-register@3.6.0(esbuild@0.25.0): esbuild-register@3.6.0(esbuild@0.25.0):
dependencies: dependencies:
debug: 4.4.1 debug: 4.4.1
@ -16020,7 +16053,7 @@ snapshots:
dependencies: dependencies:
language-subtag-registry: 0.3.23 language-subtag-registry: 0.3.23
launch-ide@1.0.1: launch-ide@1.2.0:
dependencies: dependencies:
chalk: 4.1.2 chalk: 4.1.2
dotenv: 16.6.1 dotenv: 16.6.1
@ -18779,12 +18812,6 @@ snapshots:
'@types/unist': 3.0.3 '@types/unist': 3.0.3
vfile-message: 4.0.2 vfile-message: 4.0.2
vite-code-inspector-plugin@0.18.3:
dependencies:
code-inspector-core: 0.18.3
transitivePeerDependencies:
- supports-color
vm-browserify@1.1.2: {} vm-browserify@1.1.2: {}
void-elements@3.1.0: {} void-elements@3.1.0: {}
@ -18855,12 +18882,6 @@ snapshots:
- bufferutil - bufferutil
- utf-8-validate - utf-8-validate
webpack-code-inspector-plugin@0.18.3:
dependencies:
code-inspector-core: 0.18.3
transitivePeerDependencies:
- supports-color
webpack-dev-middleware@6.1.3(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)): webpack-dev-middleware@6.1.3(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
dependencies: dependencies:
colorette: 2.0.20 colorette: 2.0.20

View File

@ -26,6 +26,9 @@
"paths": { "paths": {
"@/*": [ "@/*": [
"./*" "./*"
],
"~@/*": [
"./*"
] ]
} }
}, },