mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
use model_validate (#26182)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
@ -40,7 +40,7 @@ class AgentConfigManager:
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
}
|
||||
|
||||
agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
||||
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
|
||||
|
||||
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
|
||||
"react_router",
|
||||
|
||||
@ -116,7 +116,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
rag_pipeline_variables = []
|
||||
if workflow.rag_pipeline_variables:
|
||||
for v in workflow.rag_pipeline_variables:
|
||||
rag_pipeline_variable = RAGPipelineVariable(**v)
|
||||
rag_pipeline_variable = RAGPipelineVariable.model_validate(v)
|
||||
if (
|
||||
rag_pipeline_variable.belong_to_node_id
|
||||
in (self.application_generate_entity.start_node_id, "shared")
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
@ -11,11 +11,12 @@ class I18nObject(BaseModel):
|
||||
pt_BR: str | None = Field(default=None)
|
||||
ja_JP: str | None = Field(default=None)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
self.zh_Hans = self.zh_Hans or self.en_US
|
||||
self.pt_BR = self.pt_BR or self.en_US
|
||||
self.ja_JP = self.ja_JP or self.en_US
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
||||
|
||||
@ -5,7 +5,7 @@ from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -73,9 +73,8 @@ class ProviderConfiguration(BaseModel):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
if self.provider.provider not in original_provider_configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider] = []
|
||||
for configurate_method in self.provider.configurate_methods:
|
||||
@ -90,6 +89,7 @@ class ProviderConfiguration(BaseModel):
|
||||
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
|
||||
):
|
||||
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
|
||||
return self
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
|
||||
"""
|
||||
|
||||
@ -131,7 +131,7 @@ class CodeExecutor:
|
||||
if (code := response_data.get("code")) != 0:
|
||||
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")
|
||||
|
||||
response_code = CodeExecutionResponse(**response_data)
|
||||
response_code = CodeExecutionResponse.model_validate(response_data)
|
||||
|
||||
if response_code.data.error:
|
||||
raise CodeExecutionError(response_code.data.error)
|
||||
|
||||
@ -26,7 +26,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
|
||||
response.raise_for_status()
|
||||
|
||||
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
|
||||
|
||||
def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
||||
@ -41,7 +41,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
||||
result: list[MarketplacePluginDeclaration] = []
|
||||
for plugin in response.json()["data"]["plugins"]:
|
||||
try:
|
||||
result.append(MarketplacePluginDeclaration(**plugin))
|
||||
result.append(MarketplacePluginDeclaration.model_validate(plugin))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
@ -357,14 +357,16 @@ class IndexingRunner:
|
||||
raise ValueError("no notion import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"document": dataset_document,
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
},
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"document": dataset_document,
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
@ -378,14 +380,16 @@ class IndexingRunner:
|
||||
raise ValueError("no website import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
"url": data_source_info["url"],
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
},
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
"url": data_source_info["url"],
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
|
||||
@ -294,7 +294,7 @@ class ClientSession(
|
||||
method="completion/complete",
|
||||
params=types.CompleteRequestParams(
|
||||
ref=ref,
|
||||
argument=types.CompletionArgument(**argument),
|
||||
argument=types.CompletionArgument.model_validate(argument),
|
||||
),
|
||||
)
|
||||
),
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
@ -9,7 +9,8 @@ class I18nObject(BaseModel):
|
||||
zh_Hans: str | None = None
|
||||
en_US: str
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
if not self.zh_Hans:
|
||||
self.zh_Hans = self.en_US
|
||||
return self
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum, StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
@ -46,10 +46,11 @@ class FormOption(BaseModel):
|
||||
value: str
|
||||
show_on: list[FormShowOnObject] = []
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
if not self.label:
|
||||
self.label = I18nObject(en_US=self.value)
|
||||
return self
|
||||
|
||||
|
||||
class CredentialFormSchema(BaseModel):
|
||||
|
||||
@ -269,17 +269,17 @@ class ModelProviderFactory:
|
||||
}
|
||||
|
||||
if model_type == ModelType.LLM:
|
||||
return LargeLanguageModel(**init_params) # type: ignore
|
||||
return LargeLanguageModel.model_validate(init_params)
|
||||
elif model_type == ModelType.TEXT_EMBEDDING:
|
||||
return TextEmbeddingModel(**init_params) # type: ignore
|
||||
return TextEmbeddingModel.model_validate(init_params)
|
||||
elif model_type == ModelType.RERANK:
|
||||
return RerankModel(**init_params) # type: ignore
|
||||
return RerankModel.model_validate(init_params)
|
||||
elif model_type == ModelType.SPEECH2TEXT:
|
||||
return Speech2TextModel(**init_params) # type: ignore
|
||||
return Speech2TextModel.model_validate(init_params)
|
||||
elif model_type == ModelType.MODERATION:
|
||||
return ModerationModel(**init_params) # type: ignore
|
||||
return ModerationModel.model_validate(init_params)
|
||||
elif model_type == ModelType.TTS:
|
||||
return TTSModel(**init_params) # type: ignore
|
||||
return TTSModel.model_validate(init_params)
|
||||
|
||||
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
|
||||
"""
|
||||
|
||||
@ -51,7 +51,7 @@ class ApiModeration(Moderation):
|
||||
params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
|
||||
|
||||
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
|
||||
return ModerationInputsResult(**result)
|
||||
return ModerationInputsResult.model_validate(result)
|
||||
|
||||
return ModerationInputsResult(
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
@ -67,7 +67,7 @@ class ApiModeration(Moderation):
|
||||
params = ModerationOutputParams(app_id=self.app_id, text=text)
|
||||
|
||||
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
|
||||
return ModerationOutputsResult(**result)
|
||||
return ModerationOutputsResult.model_validate(result)
|
||||
|
||||
return ModerationOutputsResult(
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
|
||||
@ -84,15 +84,15 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
||||
|
||||
for i in range(len(v)):
|
||||
if v[i]["role"] == PromptMessageRole.USER.value:
|
||||
v[i] = UserPromptMessage(**v[i])
|
||||
v[i] = UserPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
|
||||
v[i] = AssistantPromptMessage(**v[i])
|
||||
v[i] = AssistantPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
|
||||
v[i] = SystemPromptMessage(**v[i])
|
||||
v[i] = SystemPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.TOOL.value:
|
||||
v[i] = ToolPromptMessage(**v[i])
|
||||
v[i] = ToolPromptMessage.model_validate(v[i])
|
||||
else:
|
||||
v[i] = PromptMessage(**v[i])
|
||||
v[i] = PromptMessage.model_validate(v[i])
|
||||
|
||||
return v
|
||||
|
||||
|
||||
@ -94,7 +94,7 @@ class BasePluginClient:
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
type_: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
@ -104,13 +104,13 @@ class BasePluginClient:
|
||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
"""
|
||||
for line in self._stream_request(method, path, params, headers, data, files):
|
||||
yield type(**json.loads(line)) # type: ignore
|
||||
yield type_(**json.loads(line)) # type: ignore
|
||||
|
||||
def _request_with_model(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
type_: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | None = None,
|
||||
params: dict | None = None,
|
||||
@ -120,13 +120,13 @@ class BasePluginClient:
|
||||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
"""
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
return type(**response.json()) # type: ignore
|
||||
return type_(**response.json()) # type: ignore
|
||||
|
||||
def _request_with_plugin_daemon_response(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
type_: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
@ -140,22 +140,22 @@ class BasePluginClient:
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
response.raise_for_status()
|
||||
except HTTPError as e:
|
||||
msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}"
|
||||
logger.exception(msg)
|
||||
logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path)
|
||||
raise e
|
||||
except Exception as e:
|
||||
msg = f"Failed to request plugin daemon, url: {path}"
|
||||
logger.exception(msg)
|
||||
logger.exception("Failed to request plugin daemon, url: %s", path)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
try:
|
||||
json_response = response.json()
|
||||
if transformer:
|
||||
json_response = transformer(json_response)
|
||||
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
|
||||
# https://stackoverflow.com/questions/59634937/variable-foo-class-is-not-valid-as-type-but-why
|
||||
rep = PluginDaemonBasicResponse[type_].model_validate(json_response) # type: ignore
|
||||
except Exception:
|
||||
msg = (
|
||||
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}],"
|
||||
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type_.__name__)}],"
|
||||
f" url: {path}"
|
||||
)
|
||||
logger.exception(msg)
|
||||
@ -163,7 +163,7 @@ class BasePluginClient:
|
||||
|
||||
if rep.code != 0:
|
||||
try:
|
||||
error = PluginDaemonError(**json.loads(rep.message))
|
||||
error = PluginDaemonError.model_validate(json.loads(rep.message))
|
||||
except Exception:
|
||||
raise ValueError(f"{rep.message}, code: {rep.code}")
|
||||
|
||||
@ -178,7 +178,7 @@ class BasePluginClient:
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
type_: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
@ -189,7 +189,7 @@ class BasePluginClient:
|
||||
"""
|
||||
for line in self._stream_request(method, path, params, headers, data, files):
|
||||
try:
|
||||
rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore
|
||||
rep = PluginDaemonBasicResponse[type_].model_validate_json(line) # type: ignore
|
||||
except (ValueError, TypeError):
|
||||
# TODO modify this when line_data has code and message
|
||||
try:
|
||||
@ -204,7 +204,7 @@ class BasePluginClient:
|
||||
if rep.code != 0:
|
||||
if rep.code == -500:
|
||||
try:
|
||||
error = PluginDaemonError(**json.loads(rep.message))
|
||||
error = PluginDaemonError.model_validate(json.loads(rep.message))
|
||||
except Exception:
|
||||
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
|
||||
|
||||
|
||||
@ -46,7 +46,9 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||
local_file_datasource_provider = PluginDatasourceProviderEntity.model_validate(
|
||||
self._get_local_file_datasource_provider()
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
|
||||
@ -104,7 +106,7 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
Fetch datasource provider for the given tenant and plugin.
|
||||
"""
|
||||
if provider_id == "langgenius/file/file":
|
||||
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||
return PluginDatasourceProviderEntity.model_validate(self._get_local_file_datasource_provider())
|
||||
|
||||
tool_provider_id = DatasourceProviderID(provider_id)
|
||||
|
||||
|
||||
@ -162,7 +162,7 @@ class PluginModelClient(BasePluginClient):
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
|
||||
type=LLMResultChunk,
|
||||
type_=LLMResultChunk,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
@ -208,7 +208,7 @@ class PluginModelClient(BasePluginClient):
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
|
||||
type=PluginLLMNumTokensResponse,
|
||||
type_=PluginLLMNumTokensResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
@ -250,7 +250,7 @@ class PluginModelClient(BasePluginClient):
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
|
||||
type=TextEmbeddingResult,
|
||||
type_=TextEmbeddingResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
@ -291,7 +291,7 @@ class PluginModelClient(BasePluginClient):
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
|
||||
type=PluginTextEmbeddingNumTokensResponse,
|
||||
type_=PluginTextEmbeddingNumTokensResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
@ -334,7 +334,7 @@ class PluginModelClient(BasePluginClient):
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
|
||||
type=RerankResult,
|
||||
type_=RerankResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
@ -378,7 +378,7 @@ class PluginModelClient(BasePluginClient):
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
|
||||
type=PluginStringResultResponse,
|
||||
type_=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
@ -422,7 +422,7 @@ class PluginModelClient(BasePluginClient):
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
|
||||
type=PluginVoicesResponse,
|
||||
type_=PluginVoicesResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
@ -466,7 +466,7 @@ class PluginModelClient(BasePluginClient):
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
|
||||
type=PluginStringResultResponse,
|
||||
type_=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
@ -506,7 +506,7 @@ class PluginModelClient(BasePluginClient):
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
|
||||
type=PluginBasicBooleanResponse,
|
||||
type_=PluginBasicBooleanResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
|
||||
@ -134,7 +134,7 @@ class RetrievalService:
|
||||
if not dataset:
|
||||
return []
|
||||
metadata_condition = (
|
||||
MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
|
||||
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
|
||||
)
|
||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
dataset.tenant_id,
|
||||
|
||||
@ -17,9 +17,6 @@ class NotionInfo(BaseModel):
|
||||
tenant_id: str
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class WebsiteInfo(BaseModel):
|
||||
"""
|
||||
@ -47,6 +44,3 @@ class ExtractSetting(BaseModel):
|
||||
website_info: WebsiteInfo | None = None
|
||||
document_model: str | None = None
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@ -38,11 +38,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError("No process rule found.")
|
||||
if process_rule.get("mode") == "automatic":
|
||||
automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
|
||||
rules = Rule(**automatic_rule)
|
||||
rules = Rule.model_validate(automatic_rule)
|
||||
else:
|
||||
if not process_rule.get("rules"):
|
||||
raise ValueError("No rules found in process rule.")
|
||||
rules = Rule(**process_rule.get("rules"))
|
||||
rules = Rule.model_validate(process_rule.get("rules"))
|
||||
# Split the text documents into nodes.
|
||||
if not rules.segmentation:
|
||||
raise ValueError("No segmentation found in rules.")
|
||||
|
||||
@ -40,7 +40,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError("No process rule found.")
|
||||
if not process_rule.get("rules"):
|
||||
raise ValueError("No rules found in process rule.")
|
||||
rules = Rule(**process_rule.get("rules"))
|
||||
rules = Rule.model_validate(process_rule.get("rules"))
|
||||
all_documents: list[Document] = []
|
||||
if rules.parent_mode == ParentMode.PARAGRAPH:
|
||||
# Split the text documents into nodes.
|
||||
@ -110,7 +110,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
child_documents = document.children
|
||||
if child_documents:
|
||||
formatted_child_documents = [
|
||||
Document(**child_document.model_dump()) for child_document in child_documents
|
||||
Document.model_validate(child_document.model_dump()) for child_document in child_documents
|
||||
]
|
||||
vector.create(formatted_child_documents)
|
||||
|
||||
@ -224,7 +224,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
return child_nodes
|
||||
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
|
||||
parent_childs = ParentChildStructureChunk(**chunks)
|
||||
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
||||
documents = []
|
||||
for parent_child in parent_childs.parent_child_chunks:
|
||||
metadata = {
|
||||
@ -274,7 +274,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
vector.create(all_child_documents)
|
||||
|
||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||
parent_childs = ParentChildStructureChunk(**chunks)
|
||||
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
||||
preview = []
|
||||
for parent_child in parent_childs.parent_child_chunks:
|
||||
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
|
||||
|
||||
@ -47,7 +47,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError("No process rule found.")
|
||||
if not process_rule.get("rules"):
|
||||
raise ValueError("No rules found in process rule.")
|
||||
rules = Rule(**process_rule.get("rules"))
|
||||
rules = Rule.model_validate(process_rule.get("rules"))
|
||||
splitter = self._get_splitter(
|
||||
processing_rule_mode=process_rule.get("mode"),
|
||||
max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
|
||||
@ -168,7 +168,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
return docs
|
||||
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
|
||||
qa_chunks = QAStructureChunk(**chunks)
|
||||
qa_chunks = QAStructureChunk.model_validate(chunks)
|
||||
documents = []
|
||||
for qa_chunk in qa_chunks.qa_chunks:
|
||||
metadata = {
|
||||
@ -191,7 +191,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError("Indexing technique must be high quality.")
|
||||
|
||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||
qa_chunks = QAStructureChunk(**chunks)
|
||||
qa_chunks = QAStructureChunk.model_validate(chunks)
|
||||
preview = []
|
||||
for qa_chunk in qa_chunks.qa_chunks:
|
||||
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
|
||||
|
||||
@ -90,7 +90,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
tools.append(
|
||||
assistant_tool_class(
|
||||
provider=provider,
|
||||
entity=ToolEntity(**tool),
|
||||
entity=ToolEntity.model_validate(tool),
|
||||
runtime=ToolRuntime(tenant_id=""),
|
||||
)
|
||||
)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
@ -11,11 +11,12 @@ class I18nObject(BaseModel):
|
||||
pt_BR: str | None = Field(default=None)
|
||||
ja_JP: str | None = Field(default=None)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@model_validator(mode="after")
|
||||
def _populate_missing_locales(self):
|
||||
self.zh_Hans = self.zh_Hans or self.en_US
|
||||
self.pt_BR = self.pt_BR or self.en_US
|
||||
self.ja_JP = self.ja_JP or self.en_US
|
||||
return self
|
||||
|
||||
def to_dict(self):
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
||||
|
||||
@ -54,7 +54,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
"""
|
||||
tools = []
|
||||
tools_data = json.loads(db_provider.tools)
|
||||
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
|
||||
remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data]
|
||||
user = db_provider.load_user()
|
||||
tools = [
|
||||
ToolEntity(
|
||||
|
||||
@ -1008,7 +1008,7 @@ class ToolManager:
|
||||
config = tool_configurations.get(parameter.name, {})
|
||||
if not (config and isinstance(config, dict) and config.get("value") is not None):
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {}))
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
|
||||
@ -105,10 +105,10 @@ class RedisChannel:
|
||||
command_type = CommandType(command_type_value)
|
||||
|
||||
if command_type == CommandType.ABORT:
|
||||
return AbortCommand(**data)
|
||||
return AbortCommand.model_validate(data)
|
||||
else:
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand(**data)
|
||||
return GraphEngineCommand.model_validate(data)
|
||||
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
@ -16,7 +16,7 @@ class EndNode(Node):
|
||||
_node_data: EndNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = EndNodeData(**data)
|
||||
self._node_data = EndNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
@ -18,7 +18,7 @@ class IterationStartNode(Node):
|
||||
_node_data: IterationStartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IterationStartNodeData(**data)
|
||||
self._node_data = IterationStartNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
@ -41,7 +41,7 @@ class ListOperatorNode(Node):
|
||||
_node_data: ListOperatorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ListOperatorNodeData(**data)
|
||||
self._node_data = ListOperatorNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
@ -18,7 +18,7 @@ class LoopEndNode(Node):
|
||||
_node_data: LoopEndNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopEndNodeData(**data)
|
||||
self._node_data = LoopEndNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
@ -18,7 +18,7 @@ class LoopStartNode(Node):
|
||||
_node_data: LoopStartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopStartNodeData(**data)
|
||||
self._node_data = LoopStartNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
@ -16,7 +16,7 @@ class StartNode(Node):
|
||||
_node_data: StartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = StartNodeData(**data)
|
||||
self._node_data = StartNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
@ -15,7 +15,7 @@ class VariableAggregatorNode(Node):
|
||||
_node_data: VariableAssignerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = VariableAssignerNodeData(**data)
|
||||
self._node_data = VariableAssignerNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
Reference in New Issue
Block a user