Merge branch 'fix/chore-fix' into dev/plugin-deploy

Author: Yeuoly
Date:   2024-12-27 17:38:10 +08:00
133 changed files with 2051 additions and 1282 deletions

View File

@@ -369,11 +369,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
 with Session(db.engine) as session:
     workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
-    response_finish = self._workflow_node_finish_to_stream_response(
-        event=event,
-        task_id=self._application_generate_entity.task_id,
-        workflow_node_execution=workflow_node_execution,
-    )
+    response_finish = self._workflow_node_finish_to_stream_response(
+        session=session,
+        event=event,
+        task_id=self._application_generate_entity.task_id,
+        workflow_node_execution=workflow_node_execution,
+    )
 if response_finish:
     yield response_finish
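Note: the fix threads the already-open SQLAlchemy session into the stream-response builder, so the response is assembled in the same transaction that recorded the failed node execution, instead of the helper opening its own. A minimal sketch of the implied signature; only the keyword names appear in the diff, so the types and the return annotation here are assumptions:

from typing import Optional
from sqlalchemy.orm import Session

def _workflow_node_finish_to_stream_response(
    self,
    *,
    session: Session,  # the caller's open session; no new session or commit in here (assumption)
    event,  # the queue event that finished the node
    task_id: str,
    workflow_node_execution,  # the ORM row loaded in this same session
) -> Optional["NodeFinishStreamResponse"]:  # hypothetical return type
    ...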

View File

@@ -114,9 +114,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
     }

     self._task_state = WorkflowTaskState()
-    self._wip_workflow_node_executions = {}
-    self._wip_workflow_agent_logs = {}
     self.total_tokens: int = 0
     self._workflow_run_id = ""

 def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
     """

View File

@@ -67,8 +67,6 @@ class WorkflowCycleManage:
 _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
 _task_state: WorkflowTaskState
 _workflow_system_variables: dict[SystemVariableKey, Any]
-_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
-_wip_workflow_agent_logs: dict[str, list[AgentLogStreamResponse.Data]]

 def _handle_workflow_run_start(
     self,
@@ -313,33 +311,11 @@ class WorkflowCycleManage:
 inputs = WorkflowEntry.handle_special_values(event.inputs)
 process_data = WorkflowEntry.handle_special_values(event.process_data)
 outputs = WorkflowEntry.handle_special_values(event.outputs)
-execution_metadata_dict = event.execution_metadata
-if self._wip_workflow_agent_logs.get(workflow_node_execution.id):
-    if not execution_metadata_dict:
-        execution_metadata_dict = {}
-    execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get(
-        workflow_node_execution.id, []
-    )
+execution_metadata_dict = dict(event.execution_metadata or {})
 execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
 finished_at = datetime.now(UTC).replace(tzinfo=None)
 elapsed_time = (finished_at - event.start_at).total_seconds()
-db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
-    {
-        WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
-        WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
-        WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
-        WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
-        WorkflowNodeExecution.execution_metadata: execution_metadata,
-        WorkflowNodeExecution.finished_at: finished_at,
-        WorkflowNodeExecution.elapsed_time: elapsed_time,
-    }
-)
-db.session.commit()
-db.session.close()

 process_data = WorkflowEntry.handle_special_values(event.process_data)

 workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
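Note: this hunk and the next replace the bulk db.session.query(...).update(...) followed by an immediate commit()/close() with plain attribute assignment on the ORM instance, leaving transaction control to the enclosing session. A sketch of the resulting update style; the field list mirrors the removed update() dict, while the exact commit point is an assumption:

# set the same fields the removed update() dict wrote, directly on the
# instance; the surrounding session flushes/commits them later (assumption)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time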
@@ -372,35 +348,9 @@ class WorkflowCycleManage:
 outputs = WorkflowEntry.handle_special_values(event.outputs)
 finished_at = datetime.now(UTC).replace(tzinfo=None)
 elapsed_time = (finished_at - event.start_at).total_seconds()
-execution_metadata_dict = event.execution_metadata
-if self._wip_workflow_agent_logs.get(workflow_node_execution.id):
-    if not execution_metadata_dict:
-        execution_metadata_dict = {}
-    execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get(
-        workflow_node_execution.id, []
-    )
-execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
-db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
-    {
-        WorkflowNodeExecution.status: (
-            WorkflowNodeExecutionStatus.FAILED.value
-            if not isinstance(event, QueueNodeExceptionEvent)
-            else WorkflowNodeExecutionStatus.EXCEPTION.value
-        ),
-        WorkflowNodeExecution.error: event.error,
-        WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
-        WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
-        WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
-        WorkflowNodeExecution.finished_at: finished_at,
-        WorkflowNodeExecution.elapsed_time: elapsed_time,
-        WorkflowNodeExecution.execution_metadata: execution_metadata,
-    }
-)
+execution_metadata = (
+    json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
+)
-db.session.commit()
-db.session.close()

 process_data = WorkflowEntry.handle_special_values(event.process_data)

 workflow_node_execution.status = (
     WorkflowNodeExecutionStatus.FAILED.value
@@ -889,41 +839,10 @@ class WorkflowCycleManage:
 :param event: agent log event
 :return:
 """
-node_execution = self._wip_workflow_node_executions.get(event.node_execution_id)
-if not node_execution:
-    raise Exception(f"Workflow node execution not found: {event.node_execution_id}")
-node_execution_id = node_execution.id
-original_agent_logs = self._wip_workflow_agent_logs.get(node_execution_id, [])
-# try to find the log with the same id
-for log in original_agent_logs:
-    if log.id == event.id:
-        # update the log
-        log.status = event.status
-        log.error = event.error
-        log.data = event.data
-        break
-else:
-    # append the log
-    original_agent_logs.append(
-        AgentLogStreamResponse.Data(
-            id=event.id,
-            parent_id=event.parent_id,
-            node_execution_id=node_execution_id,
-            error=event.error,
-            status=event.status,
-            data=event.data,
-            label=event.label,
-        )
-    )
-self._wip_workflow_agent_logs[node_execution_id] = original_agent_logs

 return AgentLogStreamResponse(
     task_id=task_id,
     data=AgentLogStreamResponse.Data(
-        node_execution_id=node_execution_id,
+        node_execution_id=event.node_execution_id,
         id=event.id,
         parent_id=event.parent_id,
         label=event.label,
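Note: with the work-in-progress buffers gone, an agent log event is forwarded verbatim instead of being merged into an accumulated per-node log list. A hedged sketch of the simplified handler; every field comes from the diff, while the method name and enclosing signature are assumptions:

def _handle_workflow_agent_log(self, task_id: str, event) -> AgentLogStreamResponse:
    # forward the event directly; no lookup in _wip_workflow_node_executions
    # and no update/append bookkeeping anymore
    return AgentLogStreamResponse(
        task_id=task_id,
        data=AgentLogStreamResponse.Data(
            node_execution_id=event.node_execution_id,
            id=event.id,
            parent_id=event.parent_id,
            label=event.label,
            error=event.error,
            status=event.status,
            data=event.data,
        ),
    )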

View File

@@ -31,7 +31,6 @@ from core.rag.splitter.fixed_text_splitter import (
 )
 from core.rag.splitter.text_splitter import TextSplitter
 from core.tools.utils.rag_web_reader import get_image_upload_file_ids
-from core.tools.utils.text_processing_utils import remove_leading_symbols
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage

View File

@@ -1,770 +0,0 @@
import copy
import json
import logging
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
import tiktoken
from openai import AzureOpenAI, Stream
from openai.types import Completion
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageFunction,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
from core.model_runtime.utils import helper
logger = logging.getLogger(__name__)
class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
return self._chat_generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
else:
# text completion model
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=user,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
base_model_name = self._get_base_model_name(credentials)
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not model_entity:
raise ValueError(f"Base Model Name {base_model_name} is invalid")
model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
if model_mode == LLMMode.CHAT.value:
# chat model
return self._num_tokens_from_messages(credentials, prompt_messages, tools)
else:
# text completion model, do not support tool calling
content = prompt_messages[0].content
assert isinstance(content, str)
return self._num_tokens_from_string(credentials, content)
def validate_credentials(self, model: str, credentials: dict) -> None:
if "openai_api_base" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required")
if "openai_api_key" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
if "base_model_name" not in credentials:
raise CredentialsValidateFailedError("Base Model Name is required")
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
if model.startswith("o1"):
client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}],
model=model,
temperature=1,
max_completion_tokens=20,
stream=False,
)
elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}],
model=model,
temperature=0,
max_tokens=20,
stream=False,
)
else:
# text completion model
client.completions.create(
prompt="ping",
model=model,
temperature=0,
max_tokens=20,
stream=False,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
return ai_model_entity.entity if ai_model_entity else None
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
extra_model_kwargs = {}
if stop:
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs["user"] = user
# text completion model
response = client.completions.create(
prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(
self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage]
):
assistant_text = response.choices[0].text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
else:
# calculate num tokens
content = prompt_messages[0].content
assert isinstance(content, str)
prompt_tokens = self._num_tokens_from_string(credentials, content)
completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=response.model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
system_fingerprint=response.system_fingerprint,
)
return result
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage]
) -> Generator:
full_text = ""
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.text is None or delta.text == ""):
continue
# transform assistant message to prompt message
text = delta.text or ""
assistant_prompt_message = AssistantPromptMessage(content=text)
full_text += text
if delta.finish_reason is not None:
# calculate num tokens
if chunk.usage:
# transform usage
prompt_tokens = chunk.usage.prompt_tokens
completion_tokens = chunk.usage.completion_tokens
else:
# calculate num tokens
content = prompt_messages[0].content
assert isinstance(content, str)
prompt_tokens = self._num_tokens_from_string(credentials, content)
completion_tokens = self._num_tokens_from_string(credentials, full_text)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=chunk.model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage,
),
)
else:
yield LLMResultChunk(
model=chunk.model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
),
)
def _chat_generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
response_format = model_parameters.get("response_format")
if response_format:
if response_format == "json_schema":
json_schema = model_parameters.get("json_schema")
if not json_schema:
raise ValueError("Must define JSON Schema when the response format is json_schema")
try:
schema = json.loads(json_schema)
except:
raise ValueError(f"not correct json_schema format: {json_schema}")
model_parameters.pop("json_schema")
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
else:
model_parameters["response_format"] = {"type": response_format}
extra_model_kwargs = {}
if tools:
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
if stop:
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs["user"] = user
# clear illegal prompt messages
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
block_as_stream = False
if model.startswith("o1"):
if "max_tokens" in model_parameters:
model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
del model_parameters["max_tokens"]
if stream:
block_as_stream = True
stream = False
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]
if "stop" in extra_model_kwargs:
del extra_model_kwargs["stop"]
# chat model
response = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
model=model,
stream=stream,
**model_parameters,
**extra_model_kwargs,
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
return block_result
def _handle_chat_block_as_stream_response(
self,
block_result: LLMResult,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Handle llm chat response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:param stop: stop words
:return: llm response chunk generator
"""
text = block_result.message.content
text = cast(str, text)
if stop:
text = self.enforce_stop_tokens(text, stop)
yield LLMResultChunk(
model=block_result.model,
prompt_messages=prompt_messages,
system_fingerprint=block_result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=text),
finish_reason="stop",
usage=block_result.usage,
),
)
def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Clear illegal prompt messages for OpenAI API
:param model: model name
:param prompt_messages: prompt messages
:return: cleaned prompt messages
"""
checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"]
if model in checklist:
# count how many user messages are there
user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
if user_message_count > 1:
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
item.data
if item.type == PromptMessageContentType.TEXT
else "[IMAGE]"
if item.type == PromptMessageContentType.IMAGE
else ""
for item in prompt_message.content
]
)
if model.startswith("o1"):
system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
if system_message_count > 0:
new_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message, SystemPromptMessage):
prompt_message = UserPromptMessage(
content=prompt_message.content,
name=prompt_message.name,
)
new_prompt_messages.append(prompt_message)
prompt_messages = new_prompt_messages
return prompt_messages
def _handle_chat_generate_response(
self,
model: str,
credentials: dict,
response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
):
assistant_message = response.choices[0].message
assistant_message_tool_calls = assistant_message.tool_calls
# extract tool calls from response
tool_calls = []
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=response.model or model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
system_fingerprint=response.system_fingerprint,
)
return result
def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
):
index = 0
full_assistant_content = ""
real_model = model
system_fingerprint = None
completion = ""
tool_calls = []
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
# NOTE: For fix https://github.com/langgenius/dify/issues/5790
if delta.delta is None:
continue
# extract tool calls from response
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
# Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
if delta.finish_reason is None and not delta.delta.content:
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
full_assistant_content += delta.delta.content or ""
real_model = chunk.model
system_fingerprint = chunk.system_fingerprint
completion += delta.delta.content or ""
yield LLMResultChunk(
model=real_model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
),
)
index += 1
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
full_assistant_prompt_message = AssistantPromptMessage(content=completion)
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=real_model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage
),
)
@staticmethod
def _update_tool_calls(
tool_calls: list[AssistantPromptMessage.ToolCall],
tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]],
) -> None:
if tool_calls_response:
for response_tool_call in tool_calls_response:
if isinstance(response_tool_call, ChatCompletionMessageToolCall):
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id, type=response_tool_call.type, function=function
)
tool_calls.append(tool_call)
elif isinstance(response_tool_call, ChoiceDeltaToolCall):
index = response_tool_call.index
if index < len(tool_calls):
tool_calls[index].id = response_tool_call.id or tool_calls[index].id
tool_calls[index].type = response_tool_call.type or tool_calls[index].type
if response_tool_call.function:
tool_calls[index].function.name = (
response_tool_call.function.name or tool_calls[index].function.name
)
tool_calls[index].function.arguments += response_tool_call.function.arguments or ""
else:
assert response_tool_call.id is not None
assert response_tool_call.type is not None
assert response_tool_call.function is not None
assert response_tool_call.function.name is not None
assert response_tool_call.function.arguments is not None
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id, type=response_tool_call.type, function=function
)
tool_calls.append(tool_call)
@staticmethod
def _convert_prompt_message_to_dict(message: PromptMessage):
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
assert message.content is not None
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
# message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
# fix azure when enable json schema cant process content = "" in assistant fix with None
if not message.content:
message_dict["content"] = None
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "tool",
"name": message.name,
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["name"] = message.name
return message_dict
def _num_tokens_from_string(
self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None
) -> int:
try:
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = len(encoding.encode(text))
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
return num_tokens
def _num_tokens_from_messages(
self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
model = credentials["base_model_name"]
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
if model.startswith("gpt-35-turbo-0301"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or "o1" in model:
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
# TODO: The current token calculation method for the image type is not implemented,
# which need to download the image and then get the resolution for calculation,
# and will increase the request delay
if isinstance(value, list):
text = ""
for item in value:
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
if key == "tool_calls":
for tool_call in value:
assert isinstance(tool_call, dict)
for t_key, t_value in tool_call.items():
num_tokens += len(encoding.encode(t_key))
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(t_key))
num_tokens += len(encoding.encode(t_value))
else:
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
return num_tokens
@staticmethod
def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int:
num_tokens = 0
for tool in tools:
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode("function"))
# calculate num tokens for function object
num_tokens += len(encoding.encode("name"))
num_tokens += len(encoding.encode(tool.name))
num_tokens += len(encoding.encode("description"))
num_tokens += len(encoding.encode(tool.description))
parameters = tool.parameters
num_tokens += len(encoding.encode("parameters"))
if "title" in parameters:
num_tokens += len(encoding.encode("title"))
num_tokens += len(encoding.encode(parameters["title"]))
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode(parameters["type"]))
if "properties" in parameters:
num_tokens += len(encoding.encode("properties"))
for key, value in parameters["properties"].items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if "required" in parameters:
num_tokens += len(encoding.encode("required"))
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str):
for ai_model_entity in LLM_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
ai_model_entity_copy.entity.model = model
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
def _get_base_model_name(self, credentials: dict) -> str:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
return base_model_name

View File

@@ -1,5 +1,5 @@
 import re
-from typing import Optional
+from typing import Optional, cast

 class JiebaKeywordTableHandler:

@@ -8,18 +8,20 @@ class JiebaKeywordTableHandler:
         from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS

-        jieba.analyse.default_tfidf.stop_words = STOPWORDS
+        jieba.analyse.default_tfidf.stop_words = STOPWORDS  # type: ignore

     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
         import jieba  # type: ignore
         import jieba.analyse  # type: ignore

         keywords = jieba.analyse.extract_tags(
             sentence=text,
             topK=max_keywords_per_chunk,
         )
+        # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
+        keywords = cast(list[str], keywords)

-        return set(self._expand_tokens_with_subtokens(keywords))
+        return set(self._expand_tokens_with_subtokens(set(keywords)))

     def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""

View File

@@ -1,355 +0,0 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from copy import deepcopy
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
ToolInvokeFrom,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
ToolRuntimeImageVariable,
ToolRuntimeVariable,
ToolRuntimeVariablePool,
)
from core.tools.tool_file_manager import ToolFileManager
if TYPE_CHECKING:
from core.file.models import File
class Tool(BaseModel, ABC):
identity: Optional[ToolIdentity] = None
parameters: Optional[list[ToolParameter]] = None
description: Optional[ToolDescription] = None
is_team_authorization: bool = False
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
return v or []
class Runtime(BaseModel):
"""
Meta data of a tool call processing
"""
def __init__(self, **data: Any):
super().__init__(**data)
if not self.runtime_parameters:
self.runtime_parameters = {}
tenant_id: Optional[str] = None
tool_id: Optional[str] = None
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: Optional[dict[str, Any]] = None
runtime_parameters: Optional[dict[str, Any]] = None
runtime: Optional[Runtime] = None
variables: Optional[ToolRuntimeVariablePool] = None
def __init__(self, **data: Any):
super().__init__(**data)
class VariableKey(StrEnum):
IMAGE = "image"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
CUSTOM = "custom"
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
"""
fork a new tool with meta data
:param meta: the meta data of a tool call processing, tenant_id is required
:return: the new tool
"""
return self.__class__(
identity=self.identity.model_copy() if self.identity else None,
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.model_copy() if self.description else None,
runtime=Tool.Runtime(**runtime),
)
@abstractmethod
def tool_provider_type(self) -> ToolProviderType:
"""
get the tool provider type
:return: the tool provider type
"""
def load_variables(self, variables: ToolRuntimeVariablePool | None) -> None:
"""
load variables from database
:param conversation_id: the conversation id
"""
self.variables = variables
def set_image_variable(self, variable_name: str, image_key: str) -> None:
"""
set an image variable
"""
if not self.variables:
return
if self.identity is None:
return
self.variables.set_file(self.identity.name, variable_name, image_key)
def set_text_variable(self, variable_name: str, text: str) -> None:
"""
set a text variable
"""
if not self.variables:
return
if self.identity is None:
return
self.variables.set_text(self.identity.name, variable_name, text)
def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
"""
get a variable
:param name: the name of the variable
:return: the variable
"""
if not self.variables:
return None
if isinstance(name, Enum):
name = name.value
for variable in self.variables.pool:
if variable.name == name:
return variable
return None
def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
"""
get the default image variable
:return: the image variable
"""
if not self.variables:
return None
return self.get_variable(self.VariableKey.IMAGE)
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
"""
get a variable file
:param name: the name of the variable
:return: the variable file
"""
variable = self.get_variable(name)
if not variable:
return None
if not isinstance(variable, ToolRuntimeImageVariable):
return None
message_file_id = variable.value
# get file binary
file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
if not file_binary:
return None
return file_binary[0]
def list_variables(self) -> list[ToolRuntimeVariable]:
"""
list all variables
:return: the variables
"""
if not self.variables:
return []
return self.variables.pool
def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
"""
list all image variables
:return: the image variables
"""
if not self.variables:
return []
result = []
for variable in self.variables.pool:
if variable.name.startswith(self.VariableKey.IMAGE.value):
result.append(variable)
return result
def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]:
# update tool_parameters
# TODO: Fix type error.
if self.runtime is None:
return []
if self.runtime.runtime_parameters:
# Convert Mapping to dict before updating
tool_parameters = dict(tool_parameters)
tool_parameters.update(self.runtime.runtime_parameters)
# try parse tool parameters into the correct type
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
result = self._invoke(
user_id=user_id,
tool_parameters=tool_parameters,
)
if not isinstance(result, list):
result = [result]
if not all(isinstance(message, ToolInvokeMessage) for message in result):
raise ValueError(
f"Invalid return type from {self.__class__.__name__}._invoke method. "
"Expected ToolInvokeMessage or list of ToolInvokeMessage."
)
return result
def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> dict[str, Any]:
"""
Transform tool parameters type
"""
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
result: dict[str, Any] = deepcopy(dict(tool_parameters))
for parameter in self.parameters or []:
if parameter.name in tool_parameters:
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
return result
@abstractmethod
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
pass
def validate_credentials(
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
) -> str | None:
"""
validate the credentials
:param credentials: the credentials
:param parameters: the parameters
:param format_only: only return the formatted
"""
pass
def get_runtime_parameters(self) -> list[ToolParameter]:
"""
get the runtime parameters
interface for developer to dynamic change the parameters of a tool depends on the variables pool
:return: the runtime parameters
"""
return self.parameters or []
def get_all_runtime_parameters(self) -> list[ToolParameter]:
"""
get all runtime parameters
:return: all runtime parameters
"""
parameters = self.parameters or []
parameters = parameters.copy()
user_parameters = self.get_runtime_parameters()
user_parameters = user_parameters.copy()
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
found = False
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
found = True
break
if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else:
# add new parameter
parameters.append(parameter)
return parameters
def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage:
"""
create an image message
:param image: the url of the image
:return: the image message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as)
def create_file_message(self, file: "File") -> ToolInvokeMessage:
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="")
def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
"""
create a link message
:param link: the url of the link
:return: the link message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, message=link, save_as=save_as)
def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage:
"""
create a text message
:param text: the text
:return: the text message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as)
def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage:
"""
create a blob message
:param blob: the blob
:return: the blob message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=blob,
meta=meta or {},
save_as=save_as,
)
def create_json_message(self, object: dict) -> ToolInvokeMessage:
"""
create a json message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object)

View File

@@ -613,10 +613,10 @@ class Graph(BaseModel):
 for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
     # check which node is after
     if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping):
-        if node_id in merge_branch_node_ids:
+        if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
             del merge_branch_node_ids[node_id2]
     elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping):
-        if node_id2 in merge_branch_node_ids:
+        if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
             del merge_branch_node_ids[node_id]

 branches_merge_node_ids: dict[str, str] = {}
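Note: both branches now verify that both duplicate end nodes are still tracked before deleting either one. The old guards checked one key but deleted the other, which raises KeyError when only one node of the pair is present in merge_branch_node_ids. A minimal standalone illustration:

# hypothetical state: only one node of the pair is tracked
merge_branch_node_ids = {"node_a": ["branch_1"]}
node_id, node_id2 = "node_a", "node_b"

# old guard: `if node_id in merge_branch_node_ids` would pass, and the
# delete below would then raise KeyError("node_b")
if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
    del merge_branch_node_ids[node_id2]  # only runs when both keys exist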

View File

@@ -5,7 +5,7 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session

 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
-from core.file import File, FileTransferMethod, FileType
+from core.file import File, FileTransferMethod
 from core.plugin.manager.exc import PluginDaemonClientSideError
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
 from core.tools.tool_engine import ToolEngine
@@ -189,10 +189,12 @@ class ToolNode(BaseNode[ToolNodeData]):
     conversation_id=None,
 )

-files: list[File] = []
 text = ""
+files: list[File] = []
 json: list[dict] = []
+agent_logs: list[AgentLog] = []
 variables: dict[str, Any] = {}

 for message in message_stream:
@@ -239,14 +241,16 @@ class ToolNode(BaseNode[ToolNodeData]):
 tool_file = session.scalar(stmt)
 if tool_file is None:
     raise ToolFileError(f"tool file {tool_file_id} not exists")

+mapping = {
+    "tool_file_id": tool_file_id,
+    "transfer_method": FileTransferMethod.TOOL_FILE,
+}
 files.append(
-    File(
+    file_factory.build_from_mapping(
+        mapping=mapping,
         tenant_id=self.tenant_id,
-        type=FileType.IMAGE,
-        transfer_method=FileTransferMethod.TOOL_FILE,
-        related_id=tool_file_id,
-        extension=None,
-        mime_type=message.meta.get("mime_type", "application/octet-stream"),
     )
 )
 elif message.type == ToolInvokeMessage.MessageType.TEXT:
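Note: blob outputs from tools are no longer hard-coded as FileType.IMAGE. The File is now produced by file_factory.build_from_mapping, which resolves the concrete file from the tool-file record referenced in the mapping. Only mapping and tenant_id appear in the diff; that the file's type, extension, and MIME type are derived from the stored tool file is an assumption:

mapping = {
    "tool_file_id": tool_file_id,
    "transfer_method": FileTransferMethod.TOOL_FILE,
}
# build_from_mapping looks up the tool file row and fills in type/extension/
# mime type from the stored record (assumption), instead of forcing IMAGE
file = file_factory.build_from_mapping(mapping=mapping, tenant_id=self.tenant_id)
files.append(file)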