mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
Merge branch 'main' into feat/memory-orchestration-be
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
from typing import Literal, cast
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@ -74,6 +75,9 @@ class DatasetConfigManager:
|
||||
return None
|
||||
query_variable = config.get("dataset_query_variable")
|
||||
|
||||
metadata_model_config_dict = dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_conditions_dict = dataset_configs.get("metadata_filtering_conditions")
|
||||
|
||||
if dataset_configs["retrieval_model"] == "single":
|
||||
return DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
@ -82,18 +86,23 @@ class DatasetConfigManager:
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
|
||||
if dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_mode=cast(
|
||||
Literal["disabled", "automatic", "manual"],
|
||||
dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
),
|
||||
metadata_model_config=ModelConfig(**metadata_model_config_dict)
|
||||
if isinstance(metadata_model_config_dict, dict)
|
||||
else None,
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
||||
**dataset_configs.get("metadata_filtering_conditions", {})
|
||||
)
|
||||
if dataset_configs.get("metadata_filtering_conditions")
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
|
||||
if isinstance(metadata_filtering_conditions_dict, dict)
|
||||
else None,
|
||||
),
|
||||
)
|
||||
else:
|
||||
score_threshold_val = dataset_configs.get("score_threshold")
|
||||
reranking_model_val = dataset_configs.get("reranking_model")
|
||||
weights_val = dataset_configs.get("weights")
|
||||
|
||||
return DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
@ -101,22 +110,23 @@ class DatasetConfigManager:
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
top_k=dataset_configs.get("top_k", 4),
|
||||
score_threshold=dataset_configs.get("score_threshold")
|
||||
if dataset_configs.get("score_threshold_enabled", False)
|
||||
top_k=int(dataset_configs.get("top_k", 4)),
|
||||
score_threshold=float(score_threshold_val)
|
||||
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
|
||||
else None,
|
||||
reranking_model=dataset_configs.get("reranking_model"),
|
||||
weights=dataset_configs.get("weights"),
|
||||
reranking_enabled=dataset_configs.get("reranking_enabled", True),
|
||||
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
|
||||
weights=weights_val if isinstance(weights_val, dict) else None,
|
||||
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
|
||||
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
||||
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
|
||||
if dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_mode=cast(
|
||||
Literal["disabled", "automatic", "manual"],
|
||||
dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
),
|
||||
metadata_model_config=ModelConfig(**metadata_model_config_dict)
|
||||
if isinstance(metadata_model_config_dict, dict)
|
||||
else None,
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
||||
**dataset_configs.get("metadata_filtering_conditions", {})
|
||||
)
|
||||
if dataset_configs.get("metadata_filtering_conditions")
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
|
||||
if isinstance(metadata_filtering_conditions_dict, dict)
|
||||
else None,
|
||||
),
|
||||
)
|
||||
@ -134,18 +144,17 @@ class DatasetConfigManager:
|
||||
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
|
||||
|
||||
# dataset_configs
|
||||
if not config.get("dataset_configs"):
|
||||
config["dataset_configs"] = {"retrieval_model": "single"}
|
||||
if "dataset_configs" not in config or not config.get("dataset_configs"):
|
||||
config["dataset_configs"] = {}
|
||||
config["dataset_configs"]["retrieval_model"] = config["dataset_configs"].get("retrieval_model", "single")
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
if not config["dataset_configs"].get("datasets"):
|
||||
if "datasets" not in config["dataset_configs"] or not config["dataset_configs"].get("datasets"):
|
||||
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
|
||||
|
||||
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
|
||||
"datasets", {}
|
||||
).get("datasets")
|
||||
need_manual_query_datasets = config.get("dataset_configs", {}).get("datasets", {}).get("datasets")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
@ -166,8 +175,8 @@ class DatasetConfigManager:
|
||||
:param config: app model config args
|
||||
"""
|
||||
# Extract dataset config for legacy compatibility
|
||||
if not config.get("agent_mode"):
|
||||
config["agent_mode"] = {"enabled": False, "tools": []}
|
||||
if "agent_mode" not in config or not config.get("agent_mode"):
|
||||
config["agent_mode"] = {}
|
||||
|
||||
if not isinstance(config["agent_mode"], dict):
|
||||
raise ValueError("agent_mode must be of object type")
|
||||
@ -180,19 +189,22 @@ class DatasetConfigManager:
|
||||
raise ValueError("enabled in agent_mode must be of boolean type")
|
||||
|
||||
# tools
|
||||
if not config["agent_mode"].get("tools"):
|
||||
if "tools" not in config["agent_mode"] or not config["agent_mode"].get("tools"):
|
||||
config["agent_mode"]["tools"] = []
|
||||
|
||||
if not isinstance(config["agent_mode"]["tools"], list):
|
||||
raise ValueError("tools in agent_mode must be a list of objects")
|
||||
|
||||
# strategy
|
||||
if not config["agent_mode"].get("strategy"):
|
||||
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
|
||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||
|
||||
has_datasets = False
|
||||
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
|
||||
for tool in config["agent_mode"]["tools"]:
|
||||
if config.get("agent_mode", {}).get("strategy") in {
|
||||
PlanningStrategy.ROUTER.value,
|
||||
PlanningStrategy.REACT_ROUTER.value,
|
||||
}:
|
||||
for tool in config.get("agent_mode", {}).get("tools", []):
|
||||
key = list(tool.keys())[0]
|
||||
if key == "dataset":
|
||||
# old style, use tool name as key
|
||||
@ -217,7 +229,7 @@ class DatasetConfigManager:
|
||||
|
||||
has_datasets = True
|
||||
|
||||
need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"]
|
||||
need_manual_query_datasets = has_datasets and config.get("agent_mode", {}).get("enabled")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import logging
|
||||
import queue
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from enum import IntEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
from configs import dify_config
|
||||
@ -18,6 +20,8 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PublishFrom(IntEnum):
|
||||
APPLICATION_MANAGER = auto()
|
||||
@ -35,9 +39,8 @@ class AppQueueManager:
|
||||
self.invoke_from = invoke_from # Public accessor for invoke_from
|
||||
|
||||
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||
redis_client.setex(
|
||||
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
|
||||
)
|
||||
self._task_belong_cache_key = AppQueueManager._generate_task_belong_cache_key(self._task_id)
|
||||
redis_client.setex(self._task_belong_cache_key, 1800, f"{user_prefix}-{self._user_id}")
|
||||
|
||||
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
||||
|
||||
@ -79,9 +82,21 @@ class AppQueueManager:
|
||||
Stop listen to queue
|
||||
:return:
|
||||
"""
|
||||
self._clear_task_belong_cache()
|
||||
self._q.put(None)
|
||||
|
||||
def publish_error(self, e, pub_from: PublishFrom):
|
||||
def _clear_task_belong_cache(self) -> None:
|
||||
"""
|
||||
Remove the task belong cache key once listening is finished.
|
||||
"""
|
||||
try:
|
||||
redis_client.delete(self._task_belong_cache_key)
|
||||
except RedisError:
|
||||
logger.exception(
|
||||
"Failed to clear task belong cache for task %s (key: %s)", self._task_id, self._task_belong_cache_key
|
||||
)
|
||||
|
||||
def publish_error(self, e, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish error
|
||||
:param e: error
|
||||
|
||||
@ -107,7 +107,6 @@ class MessageCycleManager:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
|
||||
db.session.merge(conversation)
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from openai import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Import InvokeFrom locally to avoid circular import
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
@ -74,7 +74,7 @@ class TextPromptMessageContent(PromptMessageContent):
|
||||
Model class for text prompt message content.
|
||||
"""
|
||||
|
||||
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
|
||||
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore
|
||||
data: str
|
||||
|
||||
|
||||
@ -95,11 +95,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
||||
|
||||
|
||||
class VideoPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
|
||||
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore
|
||||
|
||||
|
||||
class AudioPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
|
||||
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore
|
||||
|
||||
|
||||
class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
@ -111,12 +111,12 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
LOW = auto()
|
||||
HIGH = auto()
|
||||
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
||||
|
||||
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
|
||||
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore
|
||||
|
||||
|
||||
PromptMessageContentUnionTypes = Annotated[
|
||||
|
||||
@ -15,7 +15,7 @@ class GPT2Tokenizer:
|
||||
use gpt2 tokenizer to get num tokens
|
||||
"""
|
||||
_tokenizer = GPT2Tokenizer.get_encoder()
|
||||
tokens = _tokenizer.encode(text)
|
||||
tokens = _tokenizer.encode(text) # type: ignore
|
||||
return len(tokens)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -196,15 +196,15 @@ def jsonable_encoder(
|
||||
return encoder(obj)
|
||||
|
||||
try:
|
||||
data = dict(obj)
|
||||
data = dict(obj) # type: ignore
|
||||
except Exception as e:
|
||||
errors: list[Exception] = []
|
||||
errors.append(e)
|
||||
try:
|
||||
data = vars(obj)
|
||||
data = vars(obj) # type: ignore
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
raise ValueError(errors) from e
|
||||
raise ValueError(str(errors)) from e
|
||||
return jsonable_encoder(
|
||||
data,
|
||||
by_alias=by_alias,
|
||||
|
||||
@ -3,7 +3,8 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event, Status, StatusCode
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
|
||||
@ -155,7 +155,10 @@ class OpsTraceManager:
|
||||
if key in tracing_config:
|
||||
if "*" in tracing_config[key]:
|
||||
# If the key contains '*', retain the original value from the current config
|
||||
new_config[key] = current_trace_config.get(key, tracing_config[key])
|
||||
if current_trace_config:
|
||||
new_config[key] = current_trace_config.get(key, tracing_config[key])
|
||||
else:
|
||||
new_config[key] = tracing_config[key]
|
||||
else:
|
||||
# Otherwise, encrypt the key
|
||||
new_config[key] = encrypt_token(tenant_id, tracing_config[key])
|
||||
|
||||
@ -62,7 +62,8 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
self,
|
||||
):
|
||||
try:
|
||||
project_url = f"https://wandb.ai/{self.weave_client._project_id()}"
|
||||
project_identifier = f"{self.entity}/{self.project_name}" if self.entity else self.project_name
|
||||
project_url = f"https://wandb.ai/{project_identifier}"
|
||||
return project_url
|
||||
except Exception as e:
|
||||
logger.debug("Weave get run url failed: %s", str(e))
|
||||
@ -424,7 +425,23 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
raise ValueError(f"Weave API check failed: {str(e)}")
|
||||
|
||||
def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
|
||||
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
|
||||
inputs = run_data.inputs
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
elif not isinstance(inputs, dict):
|
||||
inputs = {"inputs": str(inputs)}
|
||||
|
||||
attributes = run_data.attributes
|
||||
if attributes is None:
|
||||
attributes = {}
|
||||
elif not isinstance(attributes, dict):
|
||||
attributes = {"attributes": str(attributes)}
|
||||
|
||||
call = self.weave_client.create_call(
|
||||
op=run_data.op,
|
||||
inputs=inputs,
|
||||
attributes=attributes,
|
||||
)
|
||||
self.calls[run_data.id] = call
|
||||
if parent_run_id:
|
||||
self.calls[run_data.id].parent_id = parent_run_id
|
||||
@ -432,6 +449,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
def finish_call(self, run_data: WeaveTraceModel):
|
||||
call = self.calls.get(run_data.id)
|
||||
if call:
|
||||
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)
|
||||
exception = Exception(run_data.exception) if run_data.exception else None
|
||||
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=exception)
|
||||
else:
|
||||
raise ValueError(f"Call with id {run_data.id} not found")
|
||||
|
||||
@ -106,7 +106,9 @@ class RetrievalService:
|
||||
if exceptions:
|
||||
raise ValueError(";\n".join(exceptions))
|
||||
|
||||
# Deduplicate documents for hybrid search to avoid duplicate chunks
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
all_documents = cls._deduplicate_documents(all_documents)
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
|
||||
)
|
||||
@ -143,6 +145,40 @@ class RetrievalService:
|
||||
)
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
|
||||
"""Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
|
||||
if not documents:
|
||||
return documents
|
||||
|
||||
unique_documents = []
|
||||
seen_doc_ids = set()
|
||||
|
||||
for document in documents:
|
||||
# For dify provider documents, use doc_id for deduplication
|
||||
if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
|
||||
doc_id = document.metadata["doc_id"]
|
||||
if doc_id not in seen_doc_ids:
|
||||
seen_doc_ids.add(doc_id)
|
||||
unique_documents.append(document)
|
||||
# If duplicate, keep the one with higher score
|
||||
elif "score" in document.metadata:
|
||||
# Find existing document with same doc_id and compare scores
|
||||
for i, existing_doc in enumerate(unique_documents):
|
||||
if (
|
||||
existing_doc.metadata
|
||||
and existing_doc.metadata.get("doc_id") == doc_id
|
||||
and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
|
||||
):
|
||||
unique_documents[i] = document
|
||||
break
|
||||
else:
|
||||
# For non-dify documents, use content-based deduplication
|
||||
if document not in unique_documents:
|
||||
unique_documents.append(document)
|
||||
|
||||
return unique_documents
|
||||
|
||||
@classmethod
|
||||
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
|
||||
with Session(db.engine) as session:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from openai import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
|
||||
|
||||
@ -18,6 +18,10 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
|
||||
retriever_from: str
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""Use the tool."""
|
||||
return self._run(query)
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, query: str) -> str:
|
||||
"""Use the tool.
|
||||
|
||||
@ -124,7 +124,7 @@ class DatasetRetrieverTool(Tool):
|
||||
yield self.create_text_message(text="please input query")
|
||||
else:
|
||||
# invoke dataset retriever tool
|
||||
result = self.retrieval_tool._run(query=query)
|
||||
result = self.retrieval_tool.run(query=query)
|
||||
yield self.create_text_message(text=result)
|
||||
|
||||
def validate_credentials(
|
||||
|
||||
@ -2,6 +2,7 @@ import re
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
@ -127,34 +128,34 @@ class ApiBasedToolSchemaParser:
|
||||
if "allOf" in prop_dict:
|
||||
del prop_dict["allOf"]
|
||||
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
|
||||
parameters.append(tool)
|
||||
parameters.append(tool)
|
||||
|
||||
# check if parameters is duplicated
|
||||
parameters_count = {}
|
||||
@ -241,7 +242,9 @@ class ApiBasedToolSchemaParser:
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None):
|
||||
def parse_swagger_to_openapi(
|
||||
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> dict[str, Any]:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
@ -257,7 +260,7 @@ class ApiBasedToolSchemaParser:
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
openapi = {
|
||||
converted_openapi: dict[str, Any] = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
@ -275,7 +278,7 @@ class ApiBasedToolSchemaParser:
|
||||
|
||||
# convert paths
|
||||
for path, path_item in swagger["paths"].items():
|
||||
openapi["paths"][path] = {}
|
||||
converted_openapi["paths"][path] = {}
|
||||
for method, operation in path_item.items():
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
@ -286,7 +289,7 @@ class ApiBasedToolSchemaParser:
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = {
|
||||
converted_openapi["paths"][path][method] = {
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
@ -295,13 +298,14 @@ class ApiBasedToolSchemaParser:
|
||||
}
|
||||
|
||||
if "requestBody" in operation:
|
||||
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
converted_openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
|
||||
# convert definitions
|
||||
for name, definition in swagger["definitions"].items():
|
||||
openapi["components"]["schemas"][name] = definition
|
||||
if "definitions" in swagger:
|
||||
for name, definition in swagger["definitions"].items():
|
||||
converted_openapi["components"]["schemas"][name] = definition
|
||||
|
||||
return openapi
|
||||
return converted_openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(
|
||||
|
||||
@ -191,11 +191,22 @@ class VariablePool(BaseModel):
|
||||
"""Extract the actual value from an ObjectSegment."""
|
||||
return obj.value if isinstance(obj, ObjectSegment) else obj
|
||||
|
||||
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str):
|
||||
"""Get a nested attribute from a dictionary-like object."""
|
||||
if not isinstance(obj, dict):
|
||||
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None:
|
||||
"""
|
||||
Get a nested attribute from a dictionary-like object.
|
||||
|
||||
Args:
|
||||
obj: The dictionary-like object to search.
|
||||
attr: The key to look up.
|
||||
|
||||
Returns:
|
||||
Segment | None:
|
||||
The corresponding Segment built from the attribute value if the key exists,
|
||||
otherwise None.
|
||||
"""
|
||||
if not isinstance(obj, dict) or attr not in obj:
|
||||
return None
|
||||
return obj.get(attr)
|
||||
return variable_factory.build_segment(obj.get(attr))
|
||||
|
||||
def remove(self, selector: Sequence[str], /):
|
||||
"""
|
||||
|
||||
@ -20,6 +20,7 @@ class ModelInvokeCompletedEvent(NodeEventBase):
|
||||
usage: LLMUsage
|
||||
finish_reason: str | None = None
|
||||
reasoning_content: str | None = None
|
||||
structured_output: dict | None = None
|
||||
|
||||
|
||||
class RunRetryEvent(NodeEventBase):
|
||||
|
||||
@ -87,7 +87,7 @@ class Executor:
|
||||
node_data.authorization.config.api_key
|
||||
).text
|
||||
|
||||
self.url: str = node_data.url
|
||||
self.url = node_data.url
|
||||
self.method = node_data.method
|
||||
self.auth = node_data.authorization
|
||||
self.timeout = timeout
|
||||
@ -349,11 +349,10 @@ class Executor:
|
||||
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
|
||||
"ssl_verify": self.ssl_verify,
|
||||
"follow_redirects": True,
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](**request_args)
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
raise HttpRequestNodeError(str(e)) from e
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
|
||||
@ -165,6 +165,8 @@ class HttpRequestNode(Node):
|
||||
body_type = typed_node_data.body.type
|
||||
data = typed_node_data.body.data
|
||||
match body_type:
|
||||
case "none":
|
||||
pass
|
||||
case "binary":
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("invalid body data, should have only one item")
|
||||
|
||||
@ -83,7 +83,7 @@ class IfElseNode(Node):
|
||||
else:
|
||||
# TODO: Update database then remove this
|
||||
# Fallback to old structure if cases are not defined
|
||||
input_conditions, group_result, final_result = _should_not_use_old_function( # ty: ignore [deprecated]
|
||||
input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated]
|
||||
condition_processor=condition_processor,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=self._node_data.conditions or [],
|
||||
|
||||
@ -10,6 +10,8 @@ from typing_extensions import TypeIs
|
||||
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
@ -217,6 +219,13 @@ class IterationNode(Node):
|
||||
graph_engine=graph_engine,
|
||||
)
|
||||
|
||||
# Sync conversation variables after each iteration completes
|
||||
self._sync_conversation_variables_from_snapshot(
|
||||
self._extract_conversation_variable_snapshot(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
||||
)
|
||||
)
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
@ -235,7 +244,10 @@ class IterationNode(Node):
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all iteration tasks
|
||||
future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {}
|
||||
future_to_index: dict[
|
||||
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
|
||||
int,
|
||||
] = {}
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
yield IterationNextEvent(index=index)
|
||||
future = executor.submit(
|
||||
@ -252,7 +264,7 @@ class IterationNode(Node):
|
||||
index = future_to_index[future]
|
||||
try:
|
||||
result = future.result()
|
||||
iter_start_at, events, output_value, tokens_used = result
|
||||
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
|
||||
|
||||
# Update outputs at the correct index
|
||||
outputs[index] = output_value
|
||||
@ -264,6 +276,9 @@ class IterationNode(Node):
|
||||
self.graph_runtime_state.total_tokens += tokens_used
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
# Sync conversation variables after iteration completion
|
||||
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
|
||||
|
||||
except Exception as e:
|
||||
# Handle errors based on error_handle_mode
|
||||
match self._node_data.error_handle_mode:
|
||||
@ -288,7 +303,7 @@ class IterationNode(Node):
|
||||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
@ -307,8 +322,17 @@ class IterationNode(Node):
|
||||
|
||||
# Get the output value from the temporary outputs list
|
||||
output_value = outputs_temp[0] if outputs_temp else None
|
||||
conversation_snapshot = self._extract_conversation_variable_snapshot(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
||||
)
|
||||
|
||||
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
|
||||
return (
|
||||
iter_start_at,
|
||||
events,
|
||||
output_value,
|
||||
graph_engine.graph_runtime_state.total_tokens,
|
||||
conversation_snapshot,
|
||||
)
|
||||
|
||||
def _handle_iteration_success(
|
||||
self,
|
||||
@ -430,6 +454,23 @@ class IterationNode(Node):
|
||||
|
||||
return variable_mapping
|
||||
|
||||
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
|
||||
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
|
||||
|
||||
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
|
||||
parent_pool = self.graph_runtime_state.variable_pool
|
||||
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
|
||||
current_keys = set(parent_conversations.keys())
|
||||
snapshot_keys = set(snapshot.keys())
|
||||
|
||||
for removed_key in current_keys - snapshot_keys:
|
||||
parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key))
|
||||
|
||||
for name, variable in snapshot.items():
|
||||
parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable)
|
||||
|
||||
def _append_iteration_info_to_event(
|
||||
self,
|
||||
event: GraphNodeEventBase,
|
||||
|
||||
@ -136,6 +136,11 @@ class KnowledgeIndexNode(Node):
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if not document:
|
||||
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
|
||||
doc_id_value = document.id
|
||||
ds_id_value = dataset.id
|
||||
dataset_name_value = dataset.name
|
||||
document_name_value = document.name
|
||||
created_at_value = document.created_at
|
||||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
|
||||
@ -161,16 +166,16 @@ class KnowledgeIndexNode(Node):
|
||||
document.word_count = (
|
||||
db.session.query(func.sum(DocumentSegment.word_count))
|
||||
.where(
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == doc_id_value,
|
||||
DocumentSegment.dataset_id == ds_id_value,
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
db.session.add(document)
|
||||
# update document segment status
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == doc_id_value,
|
||||
DocumentSegment.dataset_id == ds_id_value,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
@ -182,13 +187,13 @@ class KnowledgeIndexNode(Node):
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"dataset_id": ds_id_value,
|
||||
"dataset_name": dataset_name_value,
|
||||
"batch": batch.value,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"display_status": document.indexing_status,
|
||||
"document_id": doc_id_value,
|
||||
"document_name": document_name_value,
|
||||
"created_at": created_at_value.timestamp(),
|
||||
"display_status": "completed",
|
||||
}
|
||||
|
||||
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
|
||||
|
||||
@ -107,7 +107,7 @@ class KnowledgeRetrievalNode(Node):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs: list[File] = []
|
||||
self._file_outputs = []
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
|
||||
@ -161,6 +161,8 @@ class ListOperatorNode(Node):
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
if isinstance(condition.value, str):
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
elif isinstance(condition.value, bool):
|
||||
raise ValueError(f"File filter expects a string value, got {type(condition.value)}")
|
||||
else:
|
||||
value = condition.value
|
||||
filter_func = _get_file_filter_func(
|
||||
|
||||
@ -46,7 +46,7 @@ class LLMFileSaver(tp.Protocol):
|
||||
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
|
||||
and `tar.gz` are not.
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError()
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
"""save_remote_url saves the file from a remote url returned by LLM.
|
||||
@ -56,7 +56,7 @@ class LLMFileSaver(tp.Protocol):
|
||||
:param url: the url of the file.
|
||||
:param file_type: the file type of the file, check `FileType` enum for reference.
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|
||||
|
||||
@ -27,6 +27,7 @@ from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
LLMStructuredOutput,
|
||||
LLMUsage,
|
||||
)
|
||||
@ -134,7 +135,7 @@ class LLMNode(Node):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs: list[File] = []
|
||||
self._file_outputs = []
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
@ -172,6 +173,7 @@ class LLMNode(Node):
|
||||
node_inputs: dict[str, Any] = {}
|
||||
process_data: dict[str, Any] = {}
|
||||
result_text = ""
|
||||
clean_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
reasoning_content = None
|
||||
@ -285,6 +287,13 @@ class LLMNode(Node):
|
||||
# Extract clean text from <think> tags
|
||||
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
|
||||
|
||||
# Process structured output if available from the event.
|
||||
structured_output = (
|
||||
LLMStructuredOutput(structured_output=event.structured_output)
|
||||
if event.structured_output
|
||||
else None
|
||||
)
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
break
|
||||
@ -1060,7 +1069,7 @@ class LLMNode(Node):
|
||||
@staticmethod
|
||||
def handle_blocking_result(
|
||||
*,
|
||||
invoke_result: LLMResult,
|
||||
invoke_result: LLMResult | LLMResultWithStructuredOutput,
|
||||
saver: LLMFileSaver,
|
||||
file_outputs: list["File"],
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
@ -1091,6 +1100,8 @@ class LLMNode(Node):
|
||||
finish_reason=None,
|
||||
# Reasoning content for workflow variables and downstream nodes
|
||||
reasoning_content=reasoning_content,
|
||||
# Pass structured output if enabled
|
||||
structured_output=getattr(invoke_result, "structured_output", None),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -179,6 +179,6 @@ CHAT_EXAMPLE = [
|
||||
"required": ["food"],
|
||||
},
|
||||
},
|
||||
"assistant": {"text": "I need to output a valid JSON object.", "json": {"result": "apple pie"}},
|
||||
"assistant": {"text": "I need to output a valid JSON object.", "json": {"food": "apple pie"}},
|
||||
},
|
||||
]
|
||||
|
||||
@ -68,7 +68,7 @@ class QuestionClassifierNode(Node):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs: list[File] = []
|
||||
self._file_outputs = []
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
@ -111,9 +111,9 @@ class QuestionClassifierNode(Node):
|
||||
query = variable.value if variable else None
|
||||
variables = {"query": query}
|
||||
# fetch model config
|
||||
model_instance, model_config = LLMNode._fetch_model_config(
|
||||
node_data_model=node_data.model,
|
||||
model_instance, model_config = llm_utils.fetch_model_config(
|
||||
tenant_id=self.tenant_id,
|
||||
node_data_model=node_data.model,
|
||||
)
|
||||
# fetch memory
|
||||
memory = llm_utils.fetch_memory(
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
@ -9,7 +9,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
|
||||
|
||||
class TemplateTransformNode(Node):
|
||||
|
||||
@ -416,4 +416,8 @@ class WorkflowEntry:
|
||||
|
||||
# append variable and value to variable pool
|
||||
if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
|
||||
# In single run, the input_value is set as the LLM's structured output value within the variable_pool.
|
||||
if len(variable_key_list) == 2 and variable_key_list[0] == "structured_output":
|
||||
input_value = {variable_key_list[1]: input_value}
|
||||
variable_key_list = variable_key_list[0:1]
|
||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
||||
|
||||
Reference in New Issue
Block a user