merge main

This commit is contained in:
Joel
2024-10-25 11:26:41 +08:00
parent bdb990eb90
commit 3e011109ad
124 changed files with 9664 additions and 0 deletions

View File

@ -0,0 +1,117 @@
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
class BaseNode(ABC):
    """
    Abstract base class for workflow nodes.

    Concrete nodes declare ``_node_data_cls`` (used to build ``node_data`` from the
    ``"data"`` entry of the node config) and ``_node_type``, and implement ``_run``.
    """

    _node_data_cls: type[BaseNodeData]
    _node_type: NodeType

    def __init__(
        self,
        id: str,
        config: Mapping[str, Any],
        graph_init_params: GraphInitParams,
        graph: Graph,
        graph_runtime_state: GraphRuntimeState,
        previous_node_id: Optional[str] = None,
        thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        Initialise the node from its config and the graph-level parameters.

        :param id: identifier stored as ``self.id``
        :param config: node config; must contain a non-empty ``"id"`` and may contain ``"data"``
        :param graph_init_params: source of tenant/app/workflow/user attributes
        :param graph: graph the node belongs to
        :param graph_runtime_state: runtime state shared by the graph run
        :param previous_node_id: id of the previous node, if any
        :param thread_pool_id: id of the thread pool, if any
        :raises ValueError: if ``config`` has no (truthy) ``"id"``
        """
        self.id = id

        # copy the graph-level parameters onto the node
        init = graph_init_params
        self.tenant_id = init.tenant_id
        self.app_id = init.app_id
        self.workflow_type = init.workflow_type
        self.workflow_id = init.workflow_id
        self.graph_config = init.graph_config
        self.user_id = init.user_id
        self.user_from = init.user_from
        self.invoke_from = init.invoke_from
        self.workflow_call_depth = init.call_depth

        self.graph = graph
        self.graph_runtime_state = graph_runtime_state
        self.previous_node_id = previous_node_id
        self.thread_pool_id = thread_pool_id

        configured_id = config.get("id")
        if not configured_id:
            raise ValueError("Node ID is required.")
        self.node_id = configured_id

        # node data is built from the (optional) "data" entry of the config
        raw_data = config.get("data", {})
        self.node_data = self._node_data_cls(**raw_data)

    @abstractmethod
    def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
        """
        Execute the node; implemented by subclasses.

        :return: either a final ``NodeRunResult`` or a generator of events
        """
        raise NotImplementedError

    def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
        """
        Entry point for running the node.

        :return: generator of events; a plain ``NodeRunResult`` returned by ``_run``
            is wrapped in a single ``RunCompletedEvent``
        """
        outcome = self._run()
        if not isinstance(outcome, NodeRunResult):
            # _run produced its own event stream; forward it unchanged
            yield from outcome
            return
        yield RunCompletedEvent(run_result=outcome)

    @classmethod
    def extract_variable_selector_to_variable_mapping(
        cls, graph_config: Mapping[str, Any], config: dict
    ) -> Mapping[str, Sequence[str]]:
        """
        Build the variable-selector mapping for a node config.

        :param graph_config: graph config
        :param config: node config; must contain a non-empty ``"id"``
        :return: mapping produced by ``_extract_variable_selector_to_variable_mapping``
        :raises ValueError: if ``config`` has no (truthy) ``"id"``
        """
        configured_id = config.get("id")
        if not configured_id:
            raise ValueError("Node ID is required when extracting variable selector to variable mapping.")

        parsed_data = cls._node_data_cls(**config.get("data", {}))
        return cls._extract_variable_selector_to_variable_mapping(
            graph_config=graph_config,
            node_id=configured_id,
            node_data=parsed_data,
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData
    ) -> Mapping[str, Sequence[str]]:
        """
        Hook for subclasses; the base implementation reports no variables.

        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return: an empty mapping
        """
        return {}

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Return the default config of the node.

        :param filters: filter by node config parameters (unused by the base implementation)
        :return: an empty dict
        """
        return {}

    @property
    def node_type(self) -> NodeType:
        """
        Node type declared by the class.

        :return: the value of ``_node_type``
        """
        return self._node_type

View File

@ -0,0 +1,20 @@
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
class RunCompletedEvent(BaseModel):
    """Event carrying the ``NodeRunResult`` of a node run."""

    # result of the run (required)
    run_result: NodeRunResult = Field(..., description="run result")
class RunStreamChunkEvent(BaseModel):
    """Event carrying one chunk of streamed text."""

    # text of this chunk (required)
    chunk_content: str = Field(..., description="chunk content")
    # selector of the variable the chunk belongs to (required)
    from_variable_selector: list[str] = Field(..., description="from variable selector")
class RunRetrieverResourceEvent(BaseModel):
    """Event carrying retriever resources together with the context text."""

    # retriever resources as plain dicts (required; may be an empty list)
    retriever_resources: list[dict] = Field(..., description="retriever resources")
    # context text (required)
    context: str = Field(..., description="context")
# Union of the node event types defined in this module.
RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent

View File

@ -0,0 +1,343 @@
import json
from copy import deepcopy
from random import randint
from typing import Any, Optional, Union
from urllib.parse import urlencode
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeBody,
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class HttpExecutorResponse:
    """
    Thin wrapper around an ``httpx.Response`` exposing headers, body, status and size.
    """

    headers: dict[str, str]
    response: httpx.Response

    def __init__(self, response: httpx.Response):
        self.response = response
        # headers are only copied when the wrapped object really is an httpx response
        if isinstance(response, httpx.Response):
            self.headers = dict(response.headers)
        else:
            self.headers = {}

    def _checked_response(self) -> httpx.Response:
        """
        Return the wrapped response, raising ``ValueError`` if it is not an ``httpx.Response``.
        """
        if not isinstance(self.response, httpx.Response):
            raise ValueError(f"Invalid response type {type(self.response)}")
        return self.response

    @property
    def is_file(self) -> bool:
        """
        Whether the content type mentions ``image``, ``audio`` or ``video``.
        """
        content_type = self.get_content_type()
        for marker in ("image", "audio", "video"):
            if marker in content_type:
                return True
        return False

    def get_content_type(self) -> str:
        """
        Return the ``content-type`` header, or an empty string when it is absent.
        """
        return self.headers.get("content-type", "")

    def extract_file(self) -> tuple[str, bytes]:
        """
        Return ``(content type, raw body)`` for file responses, ``("", b"")`` otherwise.
        """
        if not self.is_file:
            return "", b""
        return self.get_content_type(), self.body

    @property
    def content(self) -> str:
        """Response body as text."""
        return self._checked_response().text

    @property
    def body(self) -> bytes:
        """Response body as raw bytes."""
        return self._checked_response().content

    @property
    def status_code(self) -> int:
        """HTTP status code of the response."""
        return self._checked_response().status_code

    @property
    def size(self) -> int:
        """Length of the raw body in bytes."""
        return len(self.body)

    @property
    def readable_size(self) -> str:
        """Body size formatted as bytes, KB or MB (two decimals for KB/MB)."""
        byte_count = self.size
        if byte_count >= 1024 * 1024:
            return f"{(byte_count / 1024 / 1024):.2f} MB"
        if byte_count >= 1024:
            return f"{(byte_count / 1024):.2f} KB"
        return f"{byte_count} bytes"
class HttpExecutor:
    """
    Build and send the HTTP request described by an ``HttpRequestNodeData``.

    On construction the url, params, headers and body templates are rendered
    (when a variable pool is given) and the variable selectors they reference are
    collected in ``variable_selectors``. ``invoke`` sends the request through
    ``ssrf_proxy`` and ``to_raw_request`` renders a textual form of the request
    with the authorization header masked.
    """

    server_url: str
    method: str
    authorization: HttpRequestNodeAuthorization
    params: dict[str, Any]
    headers: dict[str, Any]
    body: Union[None, str]
    files: Union[None, dict[str, Any]]
    boundary: str
    variable_selectors: list[VariableSelector]
    timeout: HttpRequestNodeTimeout

    def __init__(
        self,
        node_data: HttpRequestNodeData,
        timeout: HttpRequestNodeTimeout,
        variable_pool: Optional[VariablePool] = None,
    ):
        """
        :param node_data: node data holding url, method, authorization, params, headers and body
        :param timeout: connect/read/write timeouts used for the request
        :param variable_pool: pool used to render templates; when omitted templates are
            kept verbatim and only their selectors are collected
        """
        self.server_url = node_data.url
        self.method = node_data.method
        self.authorization = node_data.authorization
        self.timeout = timeout
        self.params = {}
        self.headers = {}
        self.body = None
        self.files = None
        # always defined so the attribute exists even when the body is not form-data
        self.boundary = ""

        # init template
        self.variable_selectors = []
        self._init_template(node_data, variable_pool)

    @staticmethod
    def _is_json_body(body: HttpRequestNodeBody) -> bool:
        """
        Check whether ``body`` is of type ``json`` and its data parses as JSON.
        """
        if not (body and body.type == "json" and body.data):
            return False
        try:
            json.loads(body.data)
        except (TypeError, ValueError, RecursionError):
            # unparsable (JSONDecodeError is a ValueError), wrong type, or too deeply nested;
            # KeyboardInterrupt/SystemExit are deliberately not swallowed
            return False
        return True

    @staticmethod
    def _escape_json_string(text: str) -> str:
        """
        Escape ``text`` so it can be placed between the quotes of a JSON string.

        Uses ``json.dumps`` so that backslashes and control characters are escaped
        in addition to double quotes and newlines.
        """
        return json.dumps(text, ensure_ascii=False)[1:-1]

    @staticmethod
    def _to_dict(convert_text: str) -> dict[str, str]:
        """
        Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}`

        Blank lines are skipped, a line without ``:`` maps to an empty value, keys
        are stripped and values are kept exactly as written after the first ``:``.
        """
        result = {}
        for line in convert_text.split("\n"):
            if not line.strip():
                continue
            key, _, value = line.partition(":")
            result[key.strip()] = value
        return result

    def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
        """
        Render url, params, headers and body and collect the referenced variable selectors.
        """
        # extract all template in url
        self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)

        # extract all template in params
        params, params_variable_selectors = self._format_template(node_data.params, variable_pool)
        self.params = self._to_dict(params)

        # extract all template in headers
        headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)
        self.headers = self._to_dict(headers)

        # extract all template in body
        body_data_variable_selectors = []
        if node_data.body:
            # values are JSON-escaped only when the raw template itself is valid JSON
            is_valid_json = self._is_json_body(node_data.body)

            body_data = node_data.body.data or ""
            if body_data:
                body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json)

            # a Content-Type supplied in the headers wins over the body-type default
            content_type_is_set = any(key.lower() == "content-type" for key in self.headers)
            if node_data.body.type == "json" and not content_type_is_set:
                self.headers["Content-Type"] = "application/json"
            elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set:
                self.headers["Content-Type"] = "application/x-www-form-urlencoded"

            if node_data.body.type in {"form-data", "x-www-form-urlencoded"}:
                body = self._to_dict(body_data)

                if node_data.body.type == "form-data":
                    self.files = {k: ("", v) for k, v in body.items()}
                    # 16 random lowercase ASCII letters
                    suffix = "".join(chr(randint(97, 122)) for _ in range(16))
                    self.boundary = f"----WebKitFormBoundary{suffix}"
                    # the multipart header always replaces any user supplied Content-Type
                    self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
                else:
                    self.body = urlencode(body)
            elif node_data.body.type in {"json", "raw-text"}:
                self.body = body_data
            elif node_data.body.type == "none":
                self.body = ""

        self.variable_selectors = (
            server_url_variable_selectors
            + params_variable_selectors
            + headers_variable_selectors
            + body_data_variable_selectors
        )

    def _assembling_headers(self) -> dict[str, Any]:
        """
        Return a copy of the headers with the authorization header added.

        :raises ValueError: for ``api-key`` authorization without config or api key
        """
        headers = deepcopy(self.headers) or {}
        if self.authorization.type != "api-key":
            return headers

        config = self.authorization.config
        if config is None:
            raise ValueError("self.authorization config is required")
        if config.api_key is None:
            raise ValueError("api_key is required")

        # fall back to the standard header name without touching the stored config
        header_name = config.header or "Authorization"
        if config.type == "bearer":
            headers[header_name] = f"Bearer {config.api_key}"
        elif config.type == "basic":
            headers[header_name] = f"Basic {config.api_key}"
        elif config.type == "custom":
            headers[header_name] = config.api_key

        return headers

    def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutorResponse:
        """
        Wrap the response and enforce the configured maximum size.

        :raises ValueError: if ``response`` is not an ``httpx.Response`` or the body is too large
        """
        if not isinstance(response, httpx.Response):
            raise ValueError(f"Invalid response type {type(response)}")
        executor_response = HttpExecutorResponse(response)

        # file responses get the binary limit, everything else the text limit
        is_file = executor_response.is_file
        threshold_size = (
            dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if is_file else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
        )
        if executor_response.size > threshold_size:
            raise ValueError(
                f'{"File" if is_file else "Text"} size is too large,'
                f' max size is {threshold_size / 1024 / 1024:.2f} MB,'
                f' but current size is {executor_response.readable_size}.'
            )

        return executor_response

    def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
        """
        Send the request through ``ssrf_proxy`` using the configured method.

        :raises ValueError: if the method is not one of get/head/post/put/delete/patch
        """
        kwargs = {
            "url": self.server_url,
            "headers": headers,
            "params": self.params,
            "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
            "follow_redirects": True,
        }

        if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
            raise ValueError(f"Invalid http method {self.method}")
        # ssrf_proxy exposes one function per lowercase method name
        return getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)

    def invoke(self) -> HttpExecutorResponse:
        """
        Assemble the headers, send the request and validate the response.
        """
        headers = self._assembling_headers()
        response = self._do_http_request(headers)
        return self._validate_and_parse_response(response)

    def to_raw_request(self) -> str:
        """
        Render the request as text, masking the authorization header value with ``*``.

        :raises ValueError: propagated from ``_assembling_headers``
        """
        server_url = self.server_url
        if self.params:
            server_url += f"?{urlencode(self.params)}"

        parts = [f"{self.method.upper()} {server_url} HTTP/1.1\n"]

        headers = self._assembling_headers()

        # name of the header to mask; only api-key authorization adds one
        masked_header = None
        if self.authorization.type == "api-key":
            masked_header = "Authorization"
            if self.authorization.config and self.authorization.config.header:
                masked_header = self.authorization.config.header
            masked_header = masked_header.lower()

        for k, v in headers.items():
            if masked_header is not None and k.lower() == masked_header:
                parts.append(f'{k}: {"*" * len(v)}\n')
            else:
                parts.append(f"{k}: {v}\n")

        parts.append("\n")

        # if files, use multipart/form-data with boundary
        if self.files:
            boundary = self.boundary
            parts.append(f"--{boundary}")
            for k, v in self.files.items():
                parts.append(f'\nContent-Disposition: form-data; name="{k}"\n\n')
                parts.append(f"{v[1]}\n")
                parts.append(f"--{boundary}")
            parts.append("--")
        else:
            parts.append(self.body or "")

        return "".join(parts)

    def _format_template(
        self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False
    ) -> tuple[str, list[VariableSelector]]:
        """
        Render ``template`` with values from ``variable_pool``.

        :param template: template text
        :param variable_pool: pool to read values from; when falsy the template is returned unchanged
        :param escape_quotes: JSON-escape string values (used when the template is a JSON body)
        :return: rendered text and the variable selectors found in the template
        :raises ValueError: if a referenced variable is not in the pool
        """
        variable_template_parser = VariableTemplateParser(template=template)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        if not variable_pool:
            return template, variable_selectors

        variable_value_mapping = {}
        for variable_selector in variable_selectors:
            variable = variable_pool.get_any(variable_selector.value_selector)
            if variable is None:
                raise ValueError(f"Variable {variable_selector.variable} not found")
            if escape_quotes and isinstance(variable, str):
                value = self._escape_json_string(variable)
            else:
                value = variable
            variable_value_mapping[variable_selector.variable] = value

        return variable_template_parser.format(variable_value_mapping), variable_selectors

View File

@ -0,0 +1,165 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_extension
from os import path
from typing import Any, cast
from configs import dify_config
from core.app.segments import parser
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse
from models.workflow import WorkflowNodeExecutionStatus
# Timeout applied when a node does not configure its own.
# NOTE(review): the defaults are taken from the HTTP_REQUEST_MAX_* settings,
# i.e. default and maximum coincide — confirm this is intended.
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
    connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
    read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
    write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
)
class HttpRequestNode(BaseNode):
    """
    Workflow node that performs an HTTP request through ``HttpExecutor``.
    """

    _node_data_cls = HttpRequestNodeData
    _node_type = NodeType.HTTP_REQUEST

    @classmethod
    def get_default_config(cls, filters: dict | None = None) -> dict:
        """
        Return the default config of an HTTP request node.

        :param filters: unused
        :return: default method, authorization, body and timeout settings
        """
        return {
            "type": "http-request",
            "config": {
                "method": "get",
                "authorization": {
                    "type": "no-auth",
                },
                "body": {"type": "none"},
                "timeout": {
                    **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
                    "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
                    "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
                    "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
                },
            },
        }

    def _run(self) -> NodeRunResult:
        """
        Send the request and return its outcome.

        :return: a SUCCEEDED result with status code, body, headers and files, or a
            FAILED result carrying the error message
        """
        from copy import deepcopy

        # Work on a copy: the API key template is replaced by its rendered value and
        # timeout defaults are filled in below; neither should leak into self.node_data.
        node_data: HttpRequestNodeData = cast(HttpRequestNodeData, deepcopy(self.node_data))
        # TODO: Switch to use segment directly
        if node_data.authorization.config and node_data.authorization.config.api_key:
            node_data.authorization.config.api_key = parser.convert_template(
                template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool
            ).text

        # init http executor
        http_executor = None
        try:
            http_executor = HttpExecutor(
                node_data=node_data,
                timeout=self._get_request_timeout(node_data),
                variable_pool=self.graph_runtime_state.variable_pool,
            )

            # invoke http executor
            response = http_executor.invoke()
        except Exception as e:
            process_data = {}
            if http_executor:
                try:
                    process_data = {
                        "request": http_executor.to_raw_request(),
                    }
                except Exception:
                    # Rendering re-assembles the headers and can fail for the same reason
                    # the request did; report the original error instead of raising here.
                    logging.warning("Failed to render raw request of a failed HTTP request node", exc_info=True)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                process_data=process_data,
            )

        files = self.extract_files(http_executor.server_url, response)

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={
                "status_code": response.status_code,
                # the text body is omitted when the response was stored as a file
                "body": response.content if not files else "",
                "headers": response.headers,
                "files": files,
            },
            process_data={
                "request": http_executor.to_raw_request(),
            },
        )

    @staticmethod
    def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout:
        """
        Return the node's timeout with unset (falsy) parts replaced by the defaults.

        The timeout object of ``node_data`` is updated in place and returned; when
        the node has no timeout at all, ``HTTP_REQUEST_DEFAULT_TIMEOUT`` is returned.
        """
        timeout = node_data.timeout
        if timeout is None:
            return HTTP_REQUEST_DEFAULT_TIMEOUT

        timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect
        timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read
        timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write
        return timeout

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping

        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return: mapping of ``"<node_id>.<variable>"`` to the variable's value selector;
            an empty mapping if extraction fails (the failure is logged)
        """
        try:
            # without a variable pool the executor only collects selectors
            http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)
            return {
                node_id + "." + variable_selector.variable: variable_selector.value_selector
                for variable_selector in http_executor.variable_selectors
            }
        except Exception as e:
            logging.exception("Failed to extract variable selector to variable mapping: %s", e)
            return {}

    def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]:
        """
        Store a file response as a tool file and describe it as a ``FileVar``.

        :param url: request url; its basename is used as the file name
        :param response: executor response
        :return: a single-element list for file responses, otherwise an empty list
        """
        files = []
        mimetype, file_binary = response.extract_file()

        if mimetype:
            # extract filename from url
            filename = path.basename(url)
            # extract extension if possible
            extension = guess_extension(mimetype) or ".bin"

            tool_file = ToolFileManager.create_file_by_raw(
                user_id=self.user_id,
                tenant_id=self.tenant_id,
                conversation_id=None,
                file_binary=file_binary,
                mimetype=mimetype,
            )

            # NOTE(review): every file response is tagged FileType.IMAGE, including
            # audio/video content types — confirm this is intended.
            files.append(
                FileVar(
                    tenant_id=self.tenant_id,
                    type=FileType.IMAGE,
                    transfer_method=FileTransferMethod.TOOL_FILE,
                    related_id=tool_file.id,
                    filename=filename,
                    extension=extension,
                    mime_type=mimetype,
                )
            )

        return files

View File

@ -0,0 +1,774 @@
import json
from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, cast
from pydantic import BaseModel
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
)
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class ModelInvokeCompleted(BaseModel):
    """
    Model invoke completed

    Internal event carrying the final outcome of a model invocation.
    """

    # complete generated text
    text: str
    # usage reported for the invocation
    usage: LLMUsage
    # finish reason, if one was reported
    finish_reason: Optional[str] = None
class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
    """
    Run node

    Resolves inputs, files, context, model and memory, invokes the model and
    streams its chunks, then finishes with a ``RunCompletedEvent``.

    :return: generator of retriever-resource and stream-chunk events followed by
        exactly one ``RunCompletedEvent`` (FAILED if any step raised)
    """
    # work on a copy: the prompt template is rewritten below
    node_data = cast(LLMNodeData, deepcopy(self.node_data))
    variable_pool = self.graph_runtime_state.variable_pool

    node_inputs = None
    process_data = None

    try:
        # init messages template
        node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)

        # fetch variables and fetch values from variable pool
        inputs = self._fetch_inputs(node_data, variable_pool)

        # fetch jinja2 inputs
        jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)

        # merge inputs
        inputs.update(jinja_inputs)

        node_inputs = {}

        # fetch files
        files = self._fetch_files(node_data, variable_pool)

        if files:
            node_inputs["#files#"] = [file.to_dict() for file in files]

        # fetch context value
        generator = self._fetch_context(node_data, variable_pool)
        context = None
        for event in generator:
            if isinstance(event, RunRetrieverResourceEvent):
                context = event.context
                yield event

        if context:
            node_inputs["#context#"] = context  # type: ignore

        # fetch model config
        model_instance, model_config = self._fetch_model_config(node_data.model)

        # fetch memory
        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)

        # fetch prompt messages
        prompt_messages, stop = self._fetch_prompt_messages(
            node_data=node_data,
            query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None,
            query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
            inputs=inputs,
            files=files,
            context=context,
            memory=memory,
            model_config=model_config,
        )

        process_data = {
            "model_mode": model_config.mode,
            "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                model_mode=model_config.mode, prompt_messages=prompt_messages
            ),
            "model_provider": model_config.provider,
            "model_name": model_config.model,
        }

        # handle invoke result
        generator = self._invoke_llm(
            node_data_model=node_data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            stop=stop,
        )

        result_text = ""
        usage = LLMUsage.empty_usage()
        finish_reason = None
        # Drain the generator instead of breaking on the completion event: a generator
        # abandoned at its yield never executes the code after that yield, so any
        # post-stream work done by `_invoke_llm` (such as quota deduction) would be skipped.
        for event in generator:
            if isinstance(event, RunStreamChunkEvent):
                yield event
            elif isinstance(event, ModelInvokeCompleted):
                result_text = event.text
                usage = event.usage
                finish_reason = event.finish_reason
    except Exception as e:
        yield RunCompletedEvent(
            run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=node_inputs,
                process_data=process_data,
            )
        )
        return

    outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}

    yield RunCompletedEvent(
        run_result=NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            process_data=process_data,
            outputs=outputs,
            metadata={
                NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                NodeRunMetadataKey.CURRENCY: usage.currency,
            },
            llm_usage=usage,
        )
    )
def _invoke_llm(
    self,
    node_data_model: ModelConfig,
    model_instance: ModelInstance,
    prompt_messages: list[PromptMessage],
    stop: Optional[list[str]] = None,
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
    """
    Invoke large language model

    Streams the model output and deducts quota for the reported usage.

    :param node_data_model: node data model
    :param model_instance: model instance
    :param prompt_messages: prompt messages
    :param stop: stop
    :return: generator of the events produced by ``_handle_invoke_result``
    """
    # NOTE(review): presumably closes the session so it is not held open during the
    # model call — confirm.
    db.session.close()

    invoke_result = model_instance.invoke_llm(
        prompt_messages=prompt_messages,
        model_parameters=node_data_model.completion_params,
        stop=stop,
        stream=True,
        user=self.user_id,
    )

    usage = LLMUsage.empty_usage()
    quota_deducted = False
    for event in self._handle_invoke_result(invoke_result=invoke_result):
        if isinstance(event, ModelInvokeCompleted):
            usage = event.usage
            if not quota_deducted:
                # Deduct before handing the completion event over: a consumer may stop
                # iterating once it receives it, and code after this loop would then
                # never run.
                self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                quota_deducted = True
        yield event

    if not quota_deducted:
        # no completion event was produced; keep deducting with the (empty) usage
        self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
def _handle_invoke_result(
    self, invoke_result: LLMResult | Generator
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
    """
    Handle invoke result

    Turns a streamed invoke result into chunk events followed by one completion event.

    :param invoke_result: invoke result; a non-streamed ``LLMResult`` produces no events
    :return: one ``RunStreamChunkEvent`` per streamed result, then a ``ModelInvokeCompleted``
        with the concatenated text, the first reported usage (or an empty usage) and the
        first reported finish reason
    """
    if isinstance(invoke_result, LLMResult):
        return

    full_text = ""
    usage = None
    finish_reason = None
    for result in invoke_result:
        text = result.delta.message.content
        full_text += text

        yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])

        # keep the first usage / finish reason that shows up in the stream
        if not usage and result.delta.usage:
            usage = result.delta.usage

        if not finish_reason and result.delta.finish_reason:
            finish_reason = result.delta.finish_reason

    if not usage:
        usage = LLMUsage.empty_usage()

    yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason)
def _transform_chat_messages(
    self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
    """
    Use the jinja2 text as the prompt text wherever the jinja2 edition is selected.

    The given object(s) are modified in place.

    :param messages: a completion prompt template or a list of chat messages
    :return: the same ``messages`` object
    """

    def prefer_jinja_text(entry) -> None:
        # only override when a non-empty jinja2 text is present
        if entry.edition_type == "jinja2" and entry.jinja2_text:
            entry.text = entry.jinja2_text

    if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
        prefer_jinja_text(messages)
    else:
        for entry in messages:
            prefer_jinja_text(entry)

    return messages
def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
"""
Fetch jinja inputs
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
variables = {}
if not node_data.prompt_config:
return variables
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable = variable_selector.variable
value = variable_pool.get_any(variable_selector.value_selector)
def parse_dict(d: dict) -> str:
"""
Parse dict into string
"""
# check if it's a context structure
if "metadata" in d and "_source" in d["metadata"] and "content" in d:
return d["content"]
# else, parse the dict
try:
return json.dumps(d, ensure_ascii=False)
except Exception:
return str(d)
if isinstance(value, str):
value = value
elif isinstance(value, list):
result = ""
for item in value:
if isinstance(item, dict):
result += parse_dict(item)
elif isinstance(item, str):
result += item
elif isinstance(item, int | float):
result += str(item)
else:
result += str(item)
result += "\n"
value = result.strip()
elif isinstance(value, dict):
value = parse_dict(value)
elif isinstance(value, int | float):
value = str(value)
else:
value = str(value)
variables[variable] = value
return variables
def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
    """
    Resolve every variable referenced by the prompt template and the memory query template.

    :param node_data: node data
    :param variable_pool: variable pool
    :return: mapping of variable name to the value found in the pool
    :raises ValueError: if a referenced variable is not in the pool
    """

    def selectors_of(text):
        return VariableTemplateParser(template=text).extract_variable_selectors()

    def resolve_into(target, selectors) -> None:
        for selector in selectors:
            resolved = variable_pool.get_any(selector.value_selector)
            if resolved is None:
                raise ValueError(f"Variable {selector.variable} not found")
            target[selector.variable] = resolved

    # collect the selectors of the prompt template (chat message list or completion template)
    prompt_template = node_data.prompt_template
    prompt_selectors = []
    if isinstance(prompt_template, list):
        for prompt in prompt_template:
            prompt_selectors.extend(selectors_of(prompt.text))
    elif isinstance(prompt_template, CompletionModelPromptTemplate):
        prompt_selectors = selectors_of(prompt_template.text)

    inputs = {}
    resolve_into(inputs, prompt_selectors)

    # the memory query template may reference additional variables
    memory = node_data.memory
    if memory and memory.query_prompt_template:
        resolve_into(inputs, selectors_of(memory.query_prompt_template))

    return inputs
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
"""
Fetch files
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
if not node_data.vision.enabled:
return []
files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value])
if not files:
return []
return files
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
"""
Fetch context
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
if not node_data.context.enabled:
return
if not node_data.context.variable_selector:
return
context_value = variable_pool.get_any(node_data.context.variable_selector)
if context_value:
if isinstance(context_value, str):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
elif isinstance(context_value, list):
context_str = ""
original_retriever_resource = []
for item in context_value:
if isinstance(item, str):
context_str += item + "\n"
else:
if "content" not in item:
raise ValueError(f"Invalid context structure: {item}")
context_str += item["content"] + "\n"
retriever_resource = self._convert_to_original_retriever_resource(item)
if retriever_resource:
original_retriever_resource.append(retriever_resource)
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource, context=context_str.strip()
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
"""
Convert to original retriever resource, temp.
:param context_dict: context dict
:return:
"""
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
and context_dict["metadata"]["_source"] == "knowledge"
):
metadata = context_dict.get("metadata", {})
source = {
"position": metadata.get("position"),
"dataset_id": metadata.get("dataset_id"),
"dataset_name": metadata.get("dataset_name"),
"document_id": metadata.get("document_id"),
"document_name": metadata.get("document_name"),
"data_source_type": metadata.get("document_data_source_type"),
"segment_id": metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"),
"score": metadata.get("score"),
"hit_count": metadata.get("segment_hit_count"),
"word_count": metadata.get("segment_word_count"),
"segment_position": metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"),
}
return source
return None
def _fetch_model_config(
    self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    """
    Resolve the LLM model instance for this node and build its credential-bearing config.

    :param node_data_model: model section of the node data (provider, name, mode, completion params)
    :return: tuple of (model instance, model config with credentials)
    :raises ValueError: if the provider model or its schema cannot be found, or the mode is missing
    :raises ProviderTokenNotInitError: if the provider model status is NO_CONFIGURE
    :raises ModelCurrentlyNotSupportError: if the provider model status is NO_PERMISSION
    :raises QuotaExceededError: if the provider model status is QUOTA_EXCEEDED
    """
    provider_name = node_data_model.provider
    model_name = node_data_model.name

    model_instance = ModelManager().get_model_instance(
        tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
    )
    provider_model_bundle = model_instance.provider_model_bundle
    model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
    model_credentials = model_instance.credentials

    # Make sure the model is known to the provider and currently usable.
    provider_model = provider_model_bundle.configuration.get_provider_model(
        model=model_name, model_type=ModelType.LLM
    )
    if provider_model is None:
        raise ValueError(f"Model {model_name} not exist.")

    status = provider_model.status
    if status == ModelStatus.NO_CONFIGURE:
        raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
    if status == ModelStatus.NO_PERMISSION:
        raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
    if status == ModelStatus.QUOTA_EXCEEDED:
        raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

    # NOTE: "stop" is removed from the node's own completion_params dict (in place),
    # so after this call that dict no longer contains a "stop" entry.
    completion_params = node_data_model.completion_params
    stop = completion_params.pop("stop", [])

    model_mode = node_data_model.mode
    if not model_mode:
        raise ValueError("LLM mode is required.")

    model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
    if not model_schema:
        raise ValueError(f"Model {model_name} not exist.")

    model_config = ModelConfigWithCredentialsEntity(
        provider=provider_name,
        model=model_name,
        model_schema=model_schema,
        mode=model_mode,
        provider_model_bundle=provider_model_bundle,
        credentials=model_credentials,
        parameters=completion_params,
        stop=stop,
    )
    return model_instance, model_config
def _fetch_memory(
self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
"""
Fetch memory
:param node_data_memory: node data memory
:param variable_pool: variable pool
:return:
"""
if not node_data_memory:
return None
# get conversation id
conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value])
if conversation_id is None:
return None
# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def _fetch_prompt_messages(
    self,
    node_data: LLMNodeData,
    query: Optional[str],
    query_prompt_template: Optional[str],
    inputs: dict[str, str],
    files: list["FileVar"],
    context: Optional[str],
    memory: Optional[TokenBufferMemory],
    model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
    """
    Render the node's prompt template into prompt messages and filter their content.

    Empty messages are dropped. For messages with non-string content, only text parts and
    (when vision is enabled) image parts are kept.

    :param node_data: node data (prompt template, memory config, vision config)
    :param query: user query; None is passed on as an empty string
    :param query_prompt_template: query prompt template
    :param inputs: inputs
    :param files: files
    :param context: context
    :param memory: memory
    :param model_config: model config; its ``stop`` value is returned alongside the messages
    :return: tuple of (filtered prompt messages, stop sequences)
    :raises ValueError: if no prompt message is left after filtering
    """
    transformer = AdvancedPromptTransform(with_variable_tmpl=True)
    raw_messages = transformer.get_prompt(
        prompt_template=node_data.prompt_template,
        inputs=inputs,
        query=query or "",
        files=files,
        context=context,
        memory_config=node_data.memory,
        memory=memory,
        model_config=model_config,
        query_prompt_template=query_prompt_template,
    )

    vision_enabled = node_data.vision.enabled
    vision_detail = node_data.vision.configs.detail if node_data.vision.configs else None

    kept_messages = []
    for message in raw_messages:
        if message.is_empty():
            continue

        if not isinstance(message.content, str):
            kept_parts = []
            for part in message.content:
                is_usable_image = (
                    vision_enabled
                    and part.type == PromptMessageContentType.IMAGE
                    and isinstance(part, ImagePromptMessageContent)
                )
                if is_usable_image:
                    # Override vision config if LLM node has vision config
                    if vision_detail:
                        part.detail = ImagePromptMessageContent.DETAIL(vision_detail)
                elif part.type != PromptMessageContentType.TEXT:
                    continue
                kept_parts.append(part)

            # Several parts: keep them as a list. A single text part: collapse to its string.
            # Otherwise (no parts kept, or a single image) the content is left as it was.
            if len(kept_parts) > 1:
                message.content = kept_parts
            elif kept_parts and kept_parts[0].type == PromptMessageContentType.TEXT:
                message.content = kept_parts[0].data

        kept_messages.append(message)

    if not kept_messages:
        raise ValueError(
            "No prompt found in the LLM configuration. "
            "Please ensure a prompt is properly configured before proceeding."
        )

    return kept_messages, model_config.stop
@classmethod
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
    """
    Charge the tenant's system-provider quota for one LLM invocation.

    Nothing is deducted unless the provider configuration uses the SYSTEM provider type.

    :param tenant_id: tenant id
    :param model_instance: model instance whose provider and model name determine the charge
    :param usage: usage; ``total_tokens`` is charged when the quota unit is TOKENS
    :return:
    """
    provider_configuration = model_instance.provider_model_bundle.configuration
    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return

    system_configuration = provider_configuration.system_configuration

    # Find the unit of the currently active quota type.
    quota_unit = None
    for quota_configuration in system_configuration.quota_configurations:
        if quota_configuration.quota_type != system_configuration.current_quota_type:
            continue
        # A limit of -1 skips deduction entirely (presumably "unlimited" - confirm).
        if quota_configuration.quota_limit == -1:
            return
        quota_unit = quota_configuration.quota_unit
        break

    if not quota_unit:
        return

    if quota_unit == QuotaUnit.TOKENS:
        used_quota = usage.total_tokens
    elif quota_unit == QuotaUnit.CREDITS:
        # Models whose name contains "gpt-4" cost 20 credits, all others 1.
        used_quota = 20 if "gpt-4" in model_instance.model else 1
    else:
        used_quota = 1

    if used_quota is None:
        return

    # Only rows that still have quota left (quota_limit > quota_used) are updated.
    provider_rows = db.session.query(Provider).filter(
        Provider.tenant_id == tenant_id,
        Provider.provider_name == model_instance.provider,
        Provider.provider_type == ProviderType.SYSTEM.value,
        Provider.quota_type == system_configuration.current_quota_type.value,
        Provider.quota_limit > Provider.quota_used,
    )
    provider_rows.update({"quota_used": Provider.quota_used + used_quota})
    db.session.commit()
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
prompt_template = node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list):
for prompt in prompt_template:
if prompt.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
else:
if prompt_template.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
memory = node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
).extract_variable_selectors()
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if node_data.context.enabled:
variable_mapping["#context#"] = node_data.context.variable_selector
if node_data.vision.enabled:
variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value]
if node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
if node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, list):
for prompt in prompt_template:
if prompt.edition_type == "jinja2":
enable_jinja = True
break
else:
if prompt_template.edition_type == "jinja2":
enable_jinja = True
if enable_jinja:
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
    """
    Return the default configuration of the LLM node.

    :param filters: filter by node config parameters (not read here)
    :return: dict with the node type and default prompt templates for chat and completion models
    """
    chat_model_defaults = {
        "prompts": [
            {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"},
        ]
    }

    completion_prompt_text = (
        "Here is the chat histories between human and assistant, "
        "inside <histories></histories> XML tags.\n\n"
        "<histories>\n{{#histories#}}\n</histories>\n\n\n"
        "Human: {{#sys.query#}}\n\nAssistant:"
    )
    completion_model_defaults = {
        "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
        "prompt": {"text": completion_prompt_text, "edition_type": "basic"},
        "stop": ["Human:"],
    }

    prompt_templates = {
        "chat_model": chat_model_defaults,
        "completion_model": completion_model_defaults,
    }
    return {"type": "llm", "config": {"prompt_templates": prompt_templates}}