mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
This commit is contained in:
@ -1,10 +1,10 @@
|
||||
import enum
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_login import UserMixin
|
||||
from flask_login import UserMixin # type: ignore[import-untyped]
|
||||
from sqlalchemy import DateTime, String, func, select
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
|
||||
|
||||
@ -225,11 +225,11 @@ class Tenant(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def custom_config_dict(self) -> dict:
|
||||
def custom_config_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.custom_config) if self.custom_config else {}
|
||||
|
||||
@custom_config_dict.setter
|
||||
def custom_config_dict(self, value: dict):
|
||||
def custom_config_dict(self, value: dict[str, Any]) -> None:
|
||||
self.custom_config = json.dumps(value)
|
||||
|
||||
|
||||
|
||||
@ -318,7 +318,7 @@ class DatasetProcessRule(Base):
|
||||
"segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
|
||||
}
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"dataset_id": self.dataset_id,
|
||||
@ -327,7 +327,7 @@ class DatasetProcessRule(Base):
|
||||
}
|
||||
|
||||
@property
|
||||
def rules_dict(self):
|
||||
def rules_dict(self) -> dict[str, Any] | None:
|
||||
try:
|
||||
return json.loads(self.rules) if self.rules else None
|
||||
except JSONDecodeError:
|
||||
@ -427,7 +427,7 @@ class Document(Base):
|
||||
def data_source_info_dict(self) -> dict[str, Any]:
|
||||
if self.data_source_info:
|
||||
try:
|
||||
data_source_info_dict = json.loads(self.data_source_info)
|
||||
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
|
||||
except JSONDecodeError:
|
||||
data_source_info_dict = {}
|
||||
|
||||
@ -435,10 +435,10 @@ class Document(Base):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def data_source_detail_dict(self):
|
||||
def data_source_detail_dict(self) -> dict[str, Any]:
|
||||
if self.data_source_info:
|
||||
if self.data_source_type == "upload_file":
|
||||
data_source_info_dict = json.loads(self.data_source_info)
|
||||
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
|
||||
@ -457,7 +457,8 @@ class Document(Base):
|
||||
}
|
||||
}
|
||||
elif self.data_source_type in {"notion_import", "website_crawl"}:
|
||||
return json.loads(self.data_source_info)
|
||||
result: dict[str, Any] = json.loads(self.data_source_info)
|
||||
return result
|
||||
return {}
|
||||
|
||||
@property
|
||||
@ -503,7 +504,7 @@ class Document(Base):
|
||||
return self.updated_at
|
||||
|
||||
@property
|
||||
def doc_metadata_details(self):
|
||||
def doc_metadata_details(self) -> list[dict[str, Any]] | None:
|
||||
if self.doc_metadata:
|
||||
document_metadatas = (
|
||||
db.session.query(DatasetMetadata)
|
||||
@ -513,9 +514,9 @@ class Document(Base):
|
||||
)
|
||||
.all()
|
||||
)
|
||||
metadata_list = []
|
||||
metadata_list: list[dict[str, Any]] = []
|
||||
for metadata in document_metadatas:
|
||||
metadata_dict = {
|
||||
metadata_dict: dict[str, Any] = {
|
||||
"id": metadata.id,
|
||||
"name": metadata.name,
|
||||
"type": metadata.type,
|
||||
@ -529,13 +530,13 @@ class Document(Base):
|
||||
return None
|
||||
|
||||
@property
|
||||
def process_rule_dict(self):
|
||||
if self.dataset_process_rule_id:
|
||||
def process_rule_dict(self) -> dict[str, Any] | None:
|
||||
if self.dataset_process_rule_id and self.dataset_process_rule:
|
||||
return self.dataset_process_rule.to_dict()
|
||||
return None
|
||||
|
||||
def get_built_in_fields(self):
|
||||
built_in_fields = []
|
||||
def get_built_in_fields(self) -> list[dict[str, Any]]:
|
||||
built_in_fields: list[dict[str, Any]] = []
|
||||
built_in_fields.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
@ -578,7 +579,7 @@ class Document(Base):
|
||||
)
|
||||
return built_in_fields
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
@ -624,13 +625,13 @@ class Document(Base):
|
||||
"data_source_info_dict": self.data_source_info_dict,
|
||||
"average_segment_length": self.average_segment_length,
|
||||
"dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
|
||||
"dataset": self.dataset.to_dict() if self.dataset else None,
|
||||
"dataset": None, # Dataset class doesn't have a to_dict method
|
||||
"segment_count": self.segment_count,
|
||||
"hit_count": self.hit_count,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
def from_dict(cls, data: dict[str, Any]):
|
||||
return cls(
|
||||
id=data.get("id"),
|
||||
tenant_id=data.get("tenant_id"),
|
||||
@ -743,46 +744,48 @@ class DocumentSegment(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def child_chunks(self):
|
||||
process_rule = self.document.dataset_process_rule
|
||||
if process_rule.mode == "hierarchical":
|
||||
rules = Rule(**process_rule.rules_dict)
|
||||
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
return child_chunks or []
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
def child_chunks(self) -> list[Any]:
|
||||
if not self.document:
|
||||
return []
|
||||
process_rule = self.document.dataset_process_rule
|
||||
if process_rule and process_rule.mode == "hierarchical":
|
||||
rules_dict = process_rule.rules_dict
|
||||
if rules_dict:
|
||||
rules = Rule(**rules_dict)
|
||||
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
return child_chunks or []
|
||||
return []
|
||||
|
||||
def get_child_chunks(self):
|
||||
process_rule = self.document.dataset_process_rule
|
||||
if process_rule.mode == "hierarchical":
|
||||
rules = Rule(**process_rule.rules_dict)
|
||||
if rules.parent_mode:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
return child_chunks or []
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
def get_child_chunks(self) -> list[Any]:
|
||||
if not self.document:
|
||||
return []
|
||||
process_rule = self.document.dataset_process_rule
|
||||
if process_rule and process_rule.mode == "hierarchical":
|
||||
rules_dict = process_rule.rules_dict
|
||||
if rules_dict:
|
||||
rules = Rule(**rules_dict)
|
||||
if rules.parent_mode:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
return child_chunks or []
|
||||
return []
|
||||
|
||||
@property
|
||||
def sign_content(self):
|
||||
def sign_content(self) -> str:
|
||||
return self.get_sign_content()
|
||||
|
||||
def get_sign_content(self):
|
||||
signed_urls = []
|
||||
def get_sign_content(self) -> str:
|
||||
signed_urls: list[tuple[int, int, str]] = []
|
||||
text = self.content
|
||||
|
||||
# For data before v0.10.0
|
||||
@ -943,17 +946,22 @@ class DatasetKeywordTable(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def keyword_table_dict(self):
|
||||
def keyword_table_dict(self) -> dict[str, set[Any]] | None:
|
||||
class SetDecoder(json.JSONDecoder):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(object_hook=self.object_hook, *args, **kwargs)
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
def object_hook(dct: Any) -> Any:
|
||||
if isinstance(dct, dict):
|
||||
result: dict[str, Any] = {}
|
||||
items = cast(dict[str, Any], dct).items()
|
||||
for keyword, node_idxs in items:
|
||||
if isinstance(node_idxs, list):
|
||||
result[keyword] = set(cast(list[Any], node_idxs))
|
||||
else:
|
||||
result[keyword] = node_idxs
|
||||
return result
|
||||
return dct
|
||||
|
||||
def object_hook(self, dct):
|
||||
if isinstance(dct, dict):
|
||||
for keyword, node_idxs in dct.items():
|
||||
if isinstance(node_idxs, list):
|
||||
dct[keyword] = set(node_idxs)
|
||||
return dct
|
||||
super().__init__(object_hook=object_hook, *args, **kwargs)
|
||||
|
||||
# get dataset
|
||||
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
|
||||
@ -1079,7 +1087,7 @@ class ExternalKnowledgeApis(Base):
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
@ -1092,14 +1100,14 @@ class ExternalKnowledgeApis(Base):
|
||||
}
|
||||
|
||||
@property
|
||||
def settings_dict(self):
|
||||
def settings_dict(self) -> dict[str, Any] | None:
|
||||
try:
|
||||
return json.loads(self.settings) if self.settings else None
|
||||
except JSONDecodeError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def dataset_bindings(self):
|
||||
def dataset_bindings(self) -> list[dict[str, Any]]:
|
||||
external_knowledge_bindings = (
|
||||
db.session.query(ExternalKnowledgeBindings)
|
||||
.where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
|
||||
@ -1107,7 +1115,7 @@ class ExternalKnowledgeApis(Base):
|
||||
)
|
||||
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
|
||||
datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
|
||||
dataset_bindings = []
|
||||
dataset_bindings: list[dict[str, Any]] = []
|
||||
for dataset in datasets:
|
||||
dataset_bindings.append({"id": dataset.id, "name": dataset.name})
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from flask_login import UserMixin # type: ignore[import-untyped]
|
||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
@ -18,7 +18,7 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from core.file import helpers as file_helpers
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from libs.helper import generate_string
|
||||
from libs.helper import generate_string # type: ignore[import-not-found]
|
||||
|
||||
from .account import Account, Tenant
|
||||
from .base import Base
|
||||
@ -98,7 +98,7 @@ class App(Base):
|
||||
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
@property
|
||||
def desc_or_prompt(self):
|
||||
def desc_or_prompt(self) -> str:
|
||||
if self.description:
|
||||
return self.description
|
||||
else:
|
||||
@ -109,12 +109,12 @@ class App(Base):
|
||||
return ""
|
||||
|
||||
@property
|
||||
def site(self):
|
||||
def site(self) -> Optional["Site"]:
|
||||
site = db.session.query(Site).where(Site.app_id == self.id).first()
|
||||
return site
|
||||
|
||||
@property
|
||||
def app_model_config(self):
|
||||
def app_model_config(self) -> Optional["AppModelConfig"]:
|
||||
if self.app_model_config_id:
|
||||
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
|
||||
@ -130,11 +130,11 @@ class App(Base):
|
||||
return None
|
||||
|
||||
@property
|
||||
def api_base_url(self):
|
||||
def api_base_url(self) -> str:
|
||||
return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"
|
||||
|
||||
@property
|
||||
def tenant(self):
|
||||
def tenant(self) -> Optional[Tenant]:
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return tenant
|
||||
|
||||
@ -162,9 +162,8 @@ class App(Base):
|
||||
return str(self.mode)
|
||||
|
||||
@property
|
||||
def deleted_tools(self) -> list:
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
def deleted_tools(self) -> list[dict[str, str]]:
|
||||
from core.tools.tool_manager import ToolManager, ToolProviderType
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
# get agent mode tools
|
||||
@ -244,7 +243,7 @@ class App(Base):
|
||||
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
|
||||
}
|
||||
|
||||
deleted_tools = []
|
||||
deleted_tools: list[dict[str, str]] = []
|
||||
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
@ -277,7 +276,7 @@ class App(Base):
|
||||
return deleted_tools
|
||||
|
||||
@property
|
||||
def tags(self):
|
||||
def tags(self) -> list["Tag"]:
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
@ -293,7 +292,7 @@ class App(Base):
|
||||
return tags or []
|
||||
|
||||
@property
|
||||
def author_name(self):
|
||||
def author_name(self) -> Optional[str]:
|
||||
if self.created_by:
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
if account:
|
||||
@ -336,20 +335,20 @@ class AppModelConfig(Base):
|
||||
file_upload = mapped_column(sa.Text)
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
def app(self) -> Optional[App]:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
@property
|
||||
def model_dict(self) -> dict:
|
||||
def model_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.model) if self.model else {}
|
||||
|
||||
@property
|
||||
def suggested_questions_list(self) -> list:
|
||||
def suggested_questions_list(self) -> list[str]:
|
||||
return json.loads(self.suggested_questions) if self.suggested_questions else []
|
||||
|
||||
@property
|
||||
def suggested_questions_after_answer_dict(self) -> dict:
|
||||
def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
json.loads(self.suggested_questions_after_answer)
|
||||
if self.suggested_questions_after_answer
|
||||
@ -357,19 +356,19 @@ class AppModelConfig(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def speech_to_text_dict(self) -> dict:
|
||||
def speech_to_text_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}
|
||||
|
||||
@property
|
||||
def text_to_speech_dict(self) -> dict:
|
||||
def text_to_speech_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}
|
||||
|
||||
@property
|
||||
def retriever_resource_dict(self) -> dict:
|
||||
def retriever_resource_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> dict:
|
||||
def annotation_reply_dict(self) -> dict[str, Any]:
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
|
||||
)
|
||||
@ -392,11 +391,11 @@ class AppModelConfig(Base):
|
||||
return {"enabled": False}
|
||||
|
||||
@property
|
||||
def more_like_this_dict(self) -> dict:
|
||||
def more_like_this_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
|
||||
|
||||
@property
|
||||
def sensitive_word_avoidance_dict(self) -> dict:
|
||||
def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
json.loads(self.sensitive_word_avoidance)
|
||||
if self.sensitive_word_avoidance
|
||||
@ -404,15 +403,15 @@ class AppModelConfig(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def external_data_tools_list(self) -> list[dict]:
|
||||
def external_data_tools_list(self) -> list[dict[str, Any]]:
|
||||
return json.loads(self.external_data_tools) if self.external_data_tools else []
|
||||
|
||||
@property
|
||||
def user_input_form_list(self):
|
||||
def user_input_form_list(self) -> list[dict[str, Any]]:
|
||||
return json.loads(self.user_input_form) if self.user_input_form else []
|
||||
|
||||
@property
|
||||
def agent_mode_dict(self) -> dict:
|
||||
def agent_mode_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
json.loads(self.agent_mode)
|
||||
if self.agent_mode
|
||||
@ -420,17 +419,17 @@ class AppModelConfig(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def chat_prompt_config_dict(self) -> dict:
|
||||
def chat_prompt_config_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
|
||||
|
||||
@property
|
||||
def completion_prompt_config_dict(self) -> dict:
|
||||
def completion_prompt_config_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
|
||||
|
||||
@property
|
||||
def dataset_configs_dict(self) -> dict:
|
||||
def dataset_configs_dict(self) -> dict[str, Any]:
|
||||
if self.dataset_configs:
|
||||
dataset_configs: dict = json.loads(self.dataset_configs)
|
||||
dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
|
||||
if "retrieval_model" not in dataset_configs:
|
||||
return {"retrieval_model": "single"}
|
||||
else:
|
||||
@ -440,7 +439,7 @@ class AppModelConfig(Base):
|
||||
}
|
||||
|
||||
@property
|
||||
def file_upload_dict(self) -> dict:
|
||||
def file_upload_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
json.loads(self.file_upload)
|
||||
if self.file_upload
|
||||
@ -454,7 +453,7 @@ class AppModelConfig(Base):
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"opening_statement": self.opening_statement,
|
||||
"suggested_questions": self.suggested_questions_list,
|
||||
@ -548,7 +547,7 @@ class RecommendedApp(Base):
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
def app(self) -> Optional[App]:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
@ -572,12 +571,12 @@ class InstalledApp(Base):
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
def app(self) -> Optional[App]:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
@property
|
||||
def tenant(self):
|
||||
def tenant(self) -> Optional[Tenant]:
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return tenant
|
||||
|
||||
@ -624,7 +623,7 @@ class Conversation(Base):
|
||||
mode: Mapped[str] = mapped_column(String(255))
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
summary = mapped_column(sa.Text)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
|
||||
_inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
|
||||
introduction = mapped_column(sa.Text)
|
||||
system_instruction = mapped_column(sa.Text)
|
||||
system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
@ -654,7 +653,7 @@ class Conversation(Base):
|
||||
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
def inputs(self) -> dict[str, Any]:
|
||||
inputs = self._inputs.copy()
|
||||
|
||||
# Convert file mapping to File object
|
||||
@ -662,22 +661,39 @@ class Conversation(Base):
|
||||
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
|
||||
from factories import file_factory
|
||||
|
||||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
value["tool_file_id"] = value["related_id"]
|
||||
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
value["upload_file_id"] = value["related_id"]
|
||||
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
|
||||
elif isinstance(value, list) and all(
|
||||
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
):
|
||||
inputs[key] = []
|
||||
for item in value:
|
||||
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
item["tool_file_id"] = item["related_id"]
|
||||
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
item["upload_file_id"] = item["related_id"]
|
||||
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
|
||||
value_dict = cast(dict[str, Any], value)
|
||||
if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
value_dict["tool_file_id"] = value_dict["related_id"]
|
||||
elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
value_dict["upload_file_id"] = value_dict["related_id"]
|
||||
tenant_id = cast(str, value_dict.get("tenant_id", ""))
|
||||
inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id)
|
||||
elif isinstance(value, list):
|
||||
value_list = cast(list[Any], value)
|
||||
if all(
|
||||
isinstance(item, dict)
|
||||
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
for item in value_list
|
||||
):
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
item_dict["tool_file_id"] = item_dict["related_id"]
|
||||
elif item_dict["transfer_method"] in [
|
||||
FileTransferMethod.LOCAL_FILE,
|
||||
FileTransferMethod.REMOTE_URL,
|
||||
]:
|
||||
item_dict["upload_file_id"] = item_dict["related_id"]
|
||||
tenant_id = cast(str, item_dict.get("tenant_id", ""))
|
||||
file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id))
|
||||
inputs[key] = file_list
|
||||
|
||||
return inputs
|
||||
|
||||
@ -687,8 +703,10 @@ class Conversation(Base):
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, File):
|
||||
inputs[k] = v.model_dump()
|
||||
elif isinstance(v, list) and all(isinstance(item, File) for item in v):
|
||||
inputs[k] = [item.model_dump() for item in v]
|
||||
elif isinstance(v, list):
|
||||
v_list = cast(list[Any], v)
|
||||
if all(isinstance(item, File) for item in v_list):
|
||||
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
|
||||
self._inputs = inputs
|
||||
|
||||
@property
|
||||
@ -828,7 +846,7 @@ class Conversation(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
def app(self) -> Optional[App]:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
return session.query(App).where(App.id == self.app_id).first()
|
||||
|
||||
@ -842,7 +860,7 @@ class Conversation(Base):
|
||||
return None
|
||||
|
||||
@property
|
||||
def from_account_name(self):
|
||||
def from_account_name(self) -> Optional[str]:
|
||||
if self.from_account_id:
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
if account:
|
||||
@ -851,10 +869,10 @@ class Conversation(Base):
|
||||
return None
|
||||
|
||||
@property
|
||||
def in_debug_mode(self):
|
||||
def in_debug_mode(self) -> bool:
|
||||
return self.override_model_configs is not None
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
@ -900,7 +918,7 @@ class Message(Base):
|
||||
model_id = mapped_column(String(255), nullable=True)
|
||||
override_model_configs = mapped_column(sa.Text)
|
||||
conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
|
||||
_inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
|
||||
query: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
message = mapped_column(sa.JSON, nullable=False)
|
||||
message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
@ -927,28 +945,45 @@ class Message(Base):
|
||||
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
def inputs(self) -> dict[str, Any]:
|
||||
inputs = self._inputs.copy()
|
||||
for key, value in inputs.items():
|
||||
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
|
||||
from factories import file_factory
|
||||
|
||||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
value["tool_file_id"] = value["related_id"]
|
||||
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
value["upload_file_id"] = value["related_id"]
|
||||
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
|
||||
elif isinstance(value, list) and all(
|
||||
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
):
|
||||
inputs[key] = []
|
||||
for item in value:
|
||||
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
item["tool_file_id"] = item["related_id"]
|
||||
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
item["upload_file_id"] = item["related_id"]
|
||||
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
|
||||
value_dict = cast(dict[str, Any], value)
|
||||
if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
value_dict["tool_file_id"] = value_dict["related_id"]
|
||||
elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
value_dict["upload_file_id"] = value_dict["related_id"]
|
||||
tenant_id = cast(str, value_dict.get("tenant_id", ""))
|
||||
inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id)
|
||||
elif isinstance(value, list):
|
||||
value_list = cast(list[Any], value)
|
||||
if all(
|
||||
isinstance(item, dict)
|
||||
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
for item in value_list
|
||||
):
|
||||
file_list: list[File] = []
|
||||
for item in value_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
item_dict["tool_file_id"] = item_dict["related_id"]
|
||||
elif item_dict["transfer_method"] in [
|
||||
FileTransferMethod.LOCAL_FILE,
|
||||
FileTransferMethod.REMOTE_URL,
|
||||
]:
|
||||
item_dict["upload_file_id"] = item_dict["related_id"]
|
||||
tenant_id = cast(str, item_dict.get("tenant_id", ""))
|
||||
file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id))
|
||||
inputs[key] = file_list
|
||||
return inputs
|
||||
|
||||
@inputs.setter
|
||||
@ -957,8 +992,10 @@ class Message(Base):
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, File):
|
||||
inputs[k] = v.model_dump()
|
||||
elif isinstance(v, list) and all(isinstance(item, File) for item in v):
|
||||
inputs[k] = [item.model_dump() for item in v]
|
||||
elif isinstance(v, list):
|
||||
v_list = cast(list[Any], v)
|
||||
if all(isinstance(item, File) for item in v_list):
|
||||
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
|
||||
self._inputs = inputs
|
||||
|
||||
@property
|
||||
@ -1086,15 +1123,15 @@ class Message(Base):
|
||||
return None
|
||||
|
||||
@property
|
||||
def in_debug_mode(self):
|
||||
def in_debug_mode(self) -> bool:
|
||||
return self.override_model_configs is not None
|
||||
|
||||
@property
|
||||
def message_metadata_dict(self) -> dict:
|
||||
def message_metadata_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.message_metadata) if self.message_metadata else {}
|
||||
|
||||
@property
|
||||
def agent_thoughts(self):
|
||||
def agent_thoughts(self) -> list["MessageAgentThought"]:
|
||||
return (
|
||||
db.session.query(MessageAgentThought)
|
||||
.where(MessageAgentThought.message_id == self.id)
|
||||
@ -1103,11 +1140,11 @@ class Message(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def retriever_resources(self):
|
||||
def retriever_resources(self) -> Any | list[Any]:
|
||||
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
|
||||
|
||||
@property
|
||||
def message_files(self):
|
||||
def message_files(self) -> list[dict[str, Any]]:
|
||||
from factories import file_factory
|
||||
|
||||
message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
|
||||
@ -1115,7 +1152,7 @@ class Message(Base):
|
||||
if not current_app:
|
||||
raise ValueError(f"App {self.app_id} not found")
|
||||
|
||||
files = []
|
||||
files: list[File] = []
|
||||
for message_file in message_files:
|
||||
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value:
|
||||
if message_file.upload_file_id is None:
|
||||
@ -1162,7 +1199,7 @@ class Message(Base):
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
result = [
|
||||
result: list[dict[str, Any]] = [
|
||||
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
|
||||
for (file, message_file) in zip(files, message_files)
|
||||
]
|
||||
@ -1179,7 +1216,7 @@ class Message(Base):
|
||||
|
||||
return None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
@ -1203,7 +1240,7 @@ class Message(Base):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
def from_dict(cls, data: dict[str, Any]) -> "Message":
|
||||
return cls(
|
||||
id=data["id"],
|
||||
app_id=data["app_id"],
|
||||
@ -1253,7 +1290,7 @@ class MessageFeedback(Base):
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
return account
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"app_id": str(self.app_id),
|
||||
@ -1438,7 +1475,18 @@ class EndUser(Base, UserMixin):
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
external_user_id = mapped_column(String(255), nullable=True)
|
||||
name = mapped_column(String(255))
|
||||
is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
_is_anonymous: Mapped[bool] = mapped_column(
|
||||
"is_anonymous", sa.Boolean, nullable=False, server_default=sa.text("true")
|
||||
)
|
||||
|
||||
@property
|
||||
def is_anonymous(self) -> Literal[False]:
|
||||
return False
|
||||
|
||||
@is_anonymous.setter
|
||||
def is_anonymous(self, value: bool) -> None:
|
||||
self._is_anonymous = value
|
||||
|
||||
session_id: Mapped[str] = mapped_column()
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
@ -1464,7 +1512,7 @@ class AppMCPServer(Base):
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@staticmethod
|
||||
def generate_server_code(n):
|
||||
def generate_server_code(n: int) -> str:
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
|
||||
@ -1521,7 +1569,7 @@ class Site(Base):
|
||||
self._custom_disclaimer = value
|
||||
|
||||
@staticmethod
|
||||
def generate_code(n):
|
||||
def generate_code(n: int) -> str:
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while db.session.query(Site).where(Site.code == result).count() > 0:
|
||||
@ -1552,7 +1600,7 @@ class ApiToken(Base):
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@staticmethod
|
||||
def generate_api_key(prefix, n):
|
||||
def generate_api_key(prefix: str, n: int) -> str:
|
||||
while True:
|
||||
result = prefix + generate_string(n)
|
||||
if db.session.scalar(select(exists().where(ApiToken.token == result))):
|
||||
@ -1716,7 +1764,7 @@ class MessageAgentThought(Base):
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
@property
|
||||
def files(self) -> list:
|
||||
def files(self) -> list[Any]:
|
||||
if self.message_files:
|
||||
return cast(list[Any], json.loads(self.message_files))
|
||||
else:
|
||||
@ -1727,32 +1775,32 @@ class MessageAgentThought(Base):
|
||||
return self.tool.split(";") if self.tool else []
|
||||
|
||||
@property
|
||||
def tool_labels(self) -> dict:
|
||||
def tool_labels(self) -> dict[str, Any]:
|
||||
try:
|
||||
if self.tool_labels_str:
|
||||
return cast(dict, json.loads(self.tool_labels_str))
|
||||
return cast(dict[str, Any], json.loads(self.tool_labels_str))
|
||||
else:
|
||||
return {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def tool_meta(self) -> dict:
|
||||
def tool_meta(self) -> dict[str, Any]:
|
||||
try:
|
||||
if self.tool_meta_str:
|
||||
return cast(dict, json.loads(self.tool_meta_str))
|
||||
return cast(dict[str, Any], json.loads(self.tool_meta_str))
|
||||
else:
|
||||
return {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def tool_inputs_dict(self) -> dict:
|
||||
def tool_inputs_dict(self) -> dict[str, Any]:
|
||||
tools = self.tools
|
||||
try:
|
||||
if self.tool_input:
|
||||
data = json.loads(self.tool_input)
|
||||
result = {}
|
||||
result: dict[str, Any] = {}
|
||||
for tool in tools:
|
||||
if tool in data:
|
||||
result[tool] = data[tool]
|
||||
@ -1768,12 +1816,12 @@ class MessageAgentThought(Base):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def tool_outputs_dict(self):
|
||||
def tool_outputs_dict(self) -> dict[str, Any]:
|
||||
tools = self.tools
|
||||
try:
|
||||
if self.observation:
|
||||
data = json.loads(self.observation)
|
||||
result = {}
|
||||
result: dict[str, Any] = {}
|
||||
for tool in tools:
|
||||
if tool in data:
|
||||
result[tool] = data[tool]
|
||||
@ -1871,14 +1919,14 @@ class TraceAppConfig(Base):
|
||||
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
|
||||
@property
|
||||
def tracing_config_dict(self):
|
||||
def tracing_config_dict(self) -> dict[str, Any]:
|
||||
return self.tracing_config or {}
|
||||
|
||||
@property
|
||||
def tracing_config_str(self):
|
||||
def tracing_config_str(self) -> str:
|
||||
return json.dumps(self.tracing_config_dict)
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
|
||||
@ -17,7 +17,7 @@ class ProviderType(Enum):
|
||||
SYSTEM = "system"
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
def value_of(value: str) -> "ProviderType":
|
||||
for member in ProviderType:
|
||||
if member.value == value:
|
||||
return member
|
||||
@ -35,7 +35,7 @@ class ProviderQuotaType(Enum):
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
def value_of(value: str) -> "ProviderQuotaType":
|
||||
for member in ProviderQuotaType:
|
||||
if member.value == value:
|
||||
return member
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -26,15 +26,15 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
# system level tool oauth client params (client_id, client_secret, etc.)
|
||||
class ToolOAuthSystemClient(Base):
|
||||
class ToolOAuthSystemClient(TypeBase):
|
||||
__tablename__ = "tool_oauth_system_clients"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
|
||||
sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
plugin_id = mapped_column(String(512), nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# oauth params of the tool provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
@ -58,8 +58,8 @@ class ToolOAuthTenantClient(Base):
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
@property
|
||||
def oauth_params(self) -> dict:
|
||||
return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
|
||||
def oauth_params(self) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
|
||||
|
||||
|
||||
class BuiltinToolProvider(Base):
|
||||
@ -100,8 +100,8 @@ class BuiltinToolProvider(Base):
|
||||
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict:
|
||||
return cast(dict, json.loads(self.encrypted_credentials))
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], json.loads(self.encrypted_credentials))
|
||||
|
||||
|
||||
class ApiToolProvider(Base):
|
||||
@ -154,8 +154,8 @@ class ApiToolProvider(Base):
|
||||
return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict:
|
||||
return dict(json.loads(self.credentials_str))
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
return dict[str, Any](json.loads(self.credentials_str))
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
@ -299,9 +299,9 @@ class MCPToolProvider(Base):
|
||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict:
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
try:
|
||||
return cast(dict, json.loads(self.encrypted_credentials)) or {}
|
||||
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@ -341,12 +341,12 @@ class MCPToolProvider(Base):
|
||||
return mask_url(self.decrypted_server_url)
|
||||
|
||||
@property
|
||||
def decrypted_credentials(self) -> dict:
|
||||
def decrypted_credentials(self) -> dict[str, Any]:
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
provider_controller = MCPToolProviderController._from_db(self)
|
||||
provider_controller = MCPToolProviderController.from_db(self)
|
||||
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
@ -354,7 +354,7 @@ class MCPToolProvider(Base):
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return encrypter.decrypt(self.credentials) # type: ignore
|
||||
return encrypter.decrypt(self.credentials)
|
||||
|
||||
|
||||
class ToolModelInvoke(Base):
|
||||
@ -422,11 +422,11 @@ class ToolConversationVariables(Base):
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def variables(self) -> Any:
|
||||
def variables(self):
|
||||
return json.loads(self.variables_str)
|
||||
|
||||
|
||||
class ToolFile(Base):
|
||||
class ToolFile(TypeBase):
|
||||
"""This table stores file metadata generated in workflows,
|
||||
not only files created by agent.
|
||||
"""
|
||||
@ -437,19 +437,19 @@ class ToolFile(Base):
|
||||
sa.Index("tool_file_conversation_id_idx", "conversation_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# conversation user id
|
||||
user_id: Mapped[str] = mapped_column(StringUUID)
|
||||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
# conversation id
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||
# file key
|
||||
file_key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# mime type
|
||||
mimetype: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# original url
|
||||
original_url: Mapped[str] = mapped_column(String(2048), nullable=True)
|
||||
original_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True, default=None)
|
||||
# name
|
||||
name: Mapped[str] = mapped_column(default="")
|
||||
# size
|
||||
|
||||
@ -1,29 +1,34 @@
|
||||
import enum
|
||||
from typing import Generic, TypeVar
|
||||
import uuid
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlalchemy import CHAR, VARCHAR, TypeDecorator
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
|
||||
|
||||
class StringUUID(TypeDecorator):
|
||||
class StringUUID(TypeDecorator[uuid.UUID | str | None]):
|
||||
impl = CHAR
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
elif dialect.name == "postgresql":
|
||||
return str(value)
|
||||
else:
|
||||
return value.hex
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value.hex
|
||||
return value
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(UUID())
|
||||
else:
|
||||
return dialect.type_descriptor(CHAR(36))
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return str(value)
|
||||
@ -32,7 +37,7 @@ class StringUUID(TypeDecorator):
|
||||
_E = TypeVar("_E", bound=enum.StrEnum)
|
||||
|
||||
|
||||
class EnumText(TypeDecorator, Generic[_E]):
|
||||
class EnumText(TypeDecorator[_E | None], Generic[_E]):
|
||||
impl = VARCHAR
|
||||
cache_ok = True
|
||||
|
||||
@ -50,28 +55,25 @@ class EnumText(TypeDecorator, Generic[_E]):
|
||||
# leave some rooms for future longer enum values.
|
||||
self._length = max(max_enum_value_len, 20)
|
||||
|
||||
def process_bind_param(self, value: _E | str | None, dialect):
|
||||
def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, self._enum_class):
|
||||
return value.value
|
||||
elif isinstance(value, str):
|
||||
self._enum_class(value)
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")
|
||||
# Since _E is bound to StrEnum which inherits from str, at this point value must be str
|
||||
self._enum_class(value)
|
||||
return value
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
return dialect.type_descriptor(VARCHAR(self._length))
|
||||
|
||||
def process_result_value(self, value, dialect) -> _E | None:
|
||||
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
|
||||
if value is None:
|
||||
return value
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(f"expected str, got {type(value)}")
|
||||
# Type annotation guarantees value is str at this point
|
||||
return self._enum_class(value)
|
||||
|
||||
def compare_values(self, x, y):
|
||||
def compare_values(self, x: _E | None, y: _E | None) -> bool:
|
||||
if x is None or y is None:
|
||||
return x is y
|
||||
return x == y
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -232,7 +232,7 @@ class Workflow(Base):
|
||||
raise WorkflowDataError("nodes not found in workflow graph")
|
||||
|
||||
try:
|
||||
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
|
||||
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
|
||||
except StopIteration:
|
||||
raise NodeNotFoundError(node_id)
|
||||
assert isinstance(node_config, dict)
|
||||
@ -290,14 +290,14 @@ class Workflow(Base):
|
||||
return self._features
|
||||
|
||||
@features.setter
|
||||
def features(self, value: str) -> None:
|
||||
def features(self, value: str):
|
||||
self._features = value
|
||||
|
||||
@property
|
||||
def features_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.features) if self.features else {}
|
||||
|
||||
def user_input_form(self, to_old_structure: bool = False) -> list:
|
||||
def user_input_form(self, to_old_structure: bool = False) -> list[Any]:
|
||||
# get start node from graph
|
||||
if not self.graph:
|
||||
return []
|
||||
@ -314,7 +314,7 @@ class Workflow(Base):
|
||||
variables: list[Any] = start_node.get("data", {}).get("variables", [])
|
||||
|
||||
if to_old_structure:
|
||||
old_structure_variables = []
|
||||
old_structure_variables: list[dict[str, Any]] = []
|
||||
for variable in variables:
|
||||
old_structure_variables.append({variable["type"]: variable})
|
||||
|
||||
@ -360,9 +360,7 @@ class Workflow(Base):
|
||||
|
||||
@property
|
||||
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||
# TODO: find some way to init `self._environment_variables` when instance created.
|
||||
if self._environment_variables is None:
|
||||
self._environment_variables = "{}"
|
||||
# _environment_variables is guaranteed to be non-None due to server_default="{}"
|
||||
|
||||
# Use workflow.tenant_id to avoid relying on request user in background threads
|
||||
tenant_id = self.tenant_id
|
||||
@ -376,17 +374,18 @@ class Workflow(Base):
|
||||
]
|
||||
|
||||
# decrypt secret variables value
|
||||
def decrypt_func(var):
|
||||
def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
|
||||
if isinstance(var, SecretVariable):
|
||||
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
|
||||
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
|
||||
return var
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
# Other variable types are not supported for environment variables
|
||||
raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}")
|
||||
|
||||
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list(
|
||||
map(decrypt_func, results)
|
||||
)
|
||||
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [
|
||||
decrypt_func(var) for var in results
|
||||
]
|
||||
return decrypted_results
|
||||
|
||||
@environment_variables.setter
|
||||
@ -414,7 +413,7 @@ class Workflow(Base):
|
||||
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
|
||||
|
||||
# encrypt secret variables value
|
||||
def encrypt_func(var):
|
||||
def encrypt_func(var: Variable) -> Variable:
|
||||
if isinstance(var, SecretVariable):
|
||||
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
|
||||
else:
|
||||
@ -445,16 +444,14 @@ class Workflow(Base):
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[Variable]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = "{}"
|
||||
# _conversation_variables is guaranteed to be non-None due to server_default="{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
||||
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@conversation_variables.setter
|
||||
def conversation_variables(self, value: Sequence[Variable]) -> None:
|
||||
def conversation_variables(self, value: Sequence[Variable]):
|
||||
self._conversation_variables = json.dumps(
|
||||
{var.name: var.model_dump() for var in value},
|
||||
ensure_ascii=False,
|
||||
@ -609,7 +606,7 @@ class WorkflowRun(Base):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "WorkflowRun":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
|
||||
return cls(
|
||||
id=data.get("id"),
|
||||
tenant_id=data.get("tenant_id"),
|
||||
@ -695,7 +692,8 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
__tablename__ = "workflow_node_executions"
|
||||
|
||||
@declared_attr
|
||||
def __table_args__(cls): # noqa
|
||||
@classmethod
|
||||
def __table_args__(cls) -> Any:
|
||||
return (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
@ -732,7 +730,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
# MyPy may flag the following line because it doesn't recognize that
|
||||
# the `declared_attr` decorator passes the receiving class as the first
|
||||
# argument to this method, allowing us to reference class attributes.
|
||||
cls.created_at.desc(), # type: ignore
|
||||
cls.created_at.desc(),
|
||||
),
|
||||
)
|
||||
|
||||
@ -820,15 +818,15 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
return json.loads(self.execution_metadata) if self.execution_metadata else {}
|
||||
|
||||
@property
|
||||
def extras(self):
|
||||
def extras(self) -> dict[str, Any]:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
extras = {}
|
||||
extras: dict[str, Any] = {}
|
||||
if self.execution_metadata_dict:
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict:
|
||||
tool_info = self.execution_metadata_dict["tool_info"]
|
||||
tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"]
|
||||
extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_type=tool_info["provider_type"],
|
||||
@ -1068,7 +1066,7 @@ class ConversationVariable(Base):
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None:
|
||||
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str):
|
||||
self.id = id
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
@ -1252,7 +1250,7 @@ class WorkflowDraftVariable(Base):
|
||||
# making this attribute harder to access from outside the class.
|
||||
__value: Segment | None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
The constructor of `WorkflowDraftVariable` is not intended for
|
||||
direct use outside this file. Its solo purpose is setup private state
|
||||
@ -1270,15 +1268,15 @@ class WorkflowDraftVariable(Base):
|
||||
self.__value = None
|
||||
|
||||
def get_selector(self) -> list[str]:
|
||||
selector = json.loads(self.selector)
|
||||
selector: Any = json.loads(self.selector)
|
||||
if not isinstance(selector, list):
|
||||
logger.error(
|
||||
"invalid selector loaded from database, type=%s, value=%s",
|
||||
type(selector),
|
||||
type(selector).__name__,
|
||||
self.selector,
|
||||
)
|
||||
raise ValueError("invalid selector.")
|
||||
return selector
|
||||
return cast(list[str], selector)
|
||||
|
||||
def _set_selector(self, value: list[str]):
|
||||
self.selector = json.dumps(value)
|
||||
@ -1288,7 +1286,7 @@ class WorkflowDraftVariable(Base):
|
||||
return self.build_segment_with_type(self.value_type, value)
|
||||
|
||||
@staticmethod
|
||||
def rebuild_file_types(value: Any) -> Any:
|
||||
def rebuild_file_types(value: Any):
|
||||
# NOTE(QuantumGhost): Temporary workaround for structured data handling.
|
||||
# By this point, `output` has been converted to dict by
|
||||
# `WorkflowEntry.handle_special_values`, so we need to
|
||||
@ -1301,15 +1299,17 @@ class WorkflowDraftVariable(Base):
|
||||
# `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging.
|
||||
if isinstance(value, dict):
|
||||
if not maybe_file_object(value):
|
||||
return value
|
||||
return cast(Any, value)
|
||||
return File.model_validate(value)
|
||||
elif isinstance(value, list) and value:
|
||||
first = value[0]
|
||||
value_list = cast(list[Any], value)
|
||||
first: Any = value_list[0]
|
||||
if not maybe_file_object(first):
|
||||
return value
|
||||
return [File.model_validate(i) for i in value]
|
||||
return cast(Any, value)
|
||||
file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list]
|
||||
return cast(Any, file_list)
|
||||
else:
|
||||
return value
|
||||
return cast(Any, value)
|
||||
|
||||
@classmethod
|
||||
def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment:
|
||||
|
||||
Reference in New Issue
Block a user