mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
Merge branch 'main' into feat/mcp-06-18
This commit is contained in:
@ -127,7 +127,7 @@ class AccountService:
|
||||
if not account:
|
||||
return None
|
||||
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
if account.status == AccountStatus.BANNED:
|
||||
raise Unauthorized("Account is banned.")
|
||||
|
||||
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
|
||||
@ -178,7 +178,7 @@ class AccountService:
|
||||
if not account:
|
||||
raise AccountPasswordError("Invalid email or password.")
|
||||
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
if account.status == AccountStatus.BANNED:
|
||||
raise AccountLoginError("Account is banned.")
|
||||
|
||||
if password and invite_token and account.password is None:
|
||||
@ -193,8 +193,8 @@ class AccountService:
|
||||
if account.password is None or not compare_password(password, account.password, account.password_salt):
|
||||
raise AccountPasswordError("Invalid email or password.")
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
if account.status == AccountStatus.PENDING:
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
db.session.commit()
|
||||
@ -246,10 +246,8 @@ class AccountService:
|
||||
)
|
||||
)
|
||||
|
||||
account = Account()
|
||||
account.email = email
|
||||
account.name = name
|
||||
|
||||
password_to_set = None
|
||||
salt_to_set = None
|
||||
if password:
|
||||
valid_password(password)
|
||||
|
||||
@ -261,14 +259,18 @@ class AccountService:
|
||||
password_hashed = hash_password(password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
password_to_set = base64_password_hashed
|
||||
salt_to_set = base64_salt
|
||||
|
||||
account.interface_language = interface_language
|
||||
account.interface_theme = interface_theme
|
||||
|
||||
# Set timezone based on language
|
||||
account.timezone = language_timezone_mapping.get(interface_language, "UTC")
|
||||
account = Account(
|
||||
name=name,
|
||||
email=email,
|
||||
password=password_to_set,
|
||||
password_salt=salt_to_set,
|
||||
interface_language=interface_language,
|
||||
interface_theme=interface_theme,
|
||||
timezone=language_timezone_mapping.get(interface_language, "UTC"),
|
||||
)
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
@ -355,7 +357,7 @@ class AccountService:
|
||||
@staticmethod
|
||||
def close_account(account: Account):
|
||||
"""Close account"""
|
||||
account.status = AccountStatus.CLOSED.value
|
||||
account.status = AccountStatus.CLOSED
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
@ -395,8 +397,8 @@ class AccountService:
|
||||
if ip_address:
|
||||
AccountService.update_login_info(account=account, ip_address=ip_address)
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
if account.status == AccountStatus.PENDING:
|
||||
account.status = AccountStatus.ACTIVE
|
||||
db.session.commit()
|
||||
|
||||
access_token = AccountService.get_account_jwt_token(account=account)
|
||||
@ -764,7 +766,7 @@ class AccountService:
|
||||
if not account:
|
||||
return None
|
||||
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
if account.status == AccountStatus.BANNED:
|
||||
raise Unauthorized("Account is banned.")
|
||||
|
||||
return account
|
||||
@ -1028,7 +1030,7 @@ class TenantService:
|
||||
@staticmethod
|
||||
def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
|
||||
"""Create tenant member"""
|
||||
if role == TenantAccountRole.OWNER.value:
|
||||
if role == TenantAccountRole.OWNER:
|
||||
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]):
|
||||
logger.error("Tenant %s has already an owner.", tenant.id)
|
||||
raise Exception("Tenant already has an owner.")
|
||||
@ -1313,7 +1315,7 @@ class RegisterService:
|
||||
password=password,
|
||||
is_setup=is_setup,
|
||||
)
|
||||
account.status = AccountStatus.ACTIVE.value if not status else status.value
|
||||
account.status = status or AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
if open_id is not None and provider is not None:
|
||||
@ -1374,7 +1376,7 @@ class RegisterService:
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
|
||||
# Support resend invitation email when the account is pending status
|
||||
if account.status != AccountStatus.PENDING.value:
|
||||
if account.status != AccountStatus.PENDING:
|
||||
raise AccountAlreadyInTenantError("Account already in tenant.")
|
||||
|
||||
token = cls.generate_invite_token(tenant, account)
|
||||
|
||||
@ -29,6 +29,7 @@ from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from events.app_event import app_model_config_was_updated, app_was_created
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import variable_factory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, AppMode
|
||||
from models.model import AppModelConfig
|
||||
from models.workflow import Workflow
|
||||
@ -439,6 +440,7 @@ class AppDslService:
|
||||
app.icon = icon
|
||||
app.icon_background = icon_background or app_data.get("icon_background", app.icon_background)
|
||||
app.updated_by = account.id
|
||||
app.updated_at = naive_utc_now()
|
||||
else:
|
||||
if account.current_tenant_id is None:
|
||||
raise ValueError("Current tenant is not set")
|
||||
@ -494,7 +496,7 @@ class AppDslService:
|
||||
unique_hash = None
|
||||
graph = workflow_data.get("graph", {})
|
||||
for node in graph.get("nodes", []):
|
||||
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
dataset_ids = node["data"].get("dataset_ids", [])
|
||||
node["data"]["dataset_ids"] = [
|
||||
decrypted_id
|
||||
@ -584,17 +586,17 @@ class AppDslService:
|
||||
if not node_data:
|
||||
continue
|
||||
data_type = node_data.get("type", "")
|
||||
if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
if data_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
dataset_ids = node_data.get("dataset_ids", [])
|
||||
node_data["dataset_ids"] = [
|
||||
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
|
||||
for dataset_id in dataset_ids
|
||||
]
|
||||
# filter credential id from tool node
|
||||
if not include_secret and data_type == NodeType.TOOL.value:
|
||||
if not include_secret and data_type == NodeType.TOOL:
|
||||
node_data.pop("credential_id", None)
|
||||
# filter credential id from agent node
|
||||
if not include_secret and data_type == NodeType.AGENT.value:
|
||||
if not include_secret and data_type == NodeType.AGENT:
|
||||
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
|
||||
tool.pop("credential_id", None)
|
||||
|
||||
@ -658,32 +660,32 @@ class AppDslService:
|
||||
try:
|
||||
typ = node.get("data", {}).get("type")
|
||||
match typ:
|
||||
case NodeType.TOOL.value:
|
||||
tool_entity = ToolNodeData(**node["data"])
|
||||
case NodeType.TOOL:
|
||||
tool_entity = ToolNodeData.model_validate(node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
|
||||
)
|
||||
case NodeType.LLM.value:
|
||||
llm_entity = LLMNodeData(**node["data"])
|
||||
case NodeType.LLM:
|
||||
llm_entity = LLMNodeData.model_validate(node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
|
||||
)
|
||||
case NodeType.QUESTION_CLASSIFIER.value:
|
||||
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
|
||||
case NodeType.QUESTION_CLASSIFIER:
|
||||
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
question_classifier_entity.model.provider
|
||||
),
|
||||
)
|
||||
case NodeType.PARAMETER_EXTRACTOR.value:
|
||||
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
|
||||
case NodeType.PARAMETER_EXTRACTOR:
|
||||
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
parameter_extractor_entity.model.provider
|
||||
),
|
||||
)
|
||||
case NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
|
||||
case NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
|
||||
if knowledge_retrieval_entity.retrieval_mode == "multiple":
|
||||
if knowledge_retrieval_entity.multiple_retrieval_config:
|
||||
if (
|
||||
@ -773,7 +775,7 @@ class AppDslService:
|
||||
"""
|
||||
Returns the leaked dependencies in current workspace
|
||||
"""
|
||||
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
|
||||
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
|
||||
if not dependencies:
|
||||
return []
|
||||
|
||||
|
||||
@ -26,10 +26,9 @@ class ApiKeyAuthService:
|
||||
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
|
||||
args["credentials"]["config"]["api_key"] = api_key
|
||||
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
|
||||
data_source_api_key_binding.tenant_id = tenant_id
|
||||
data_source_api_key_binding.category = args["category"]
|
||||
data_source_api_key_binding.provider = args["provider"]
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
|
||||
)
|
||||
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
@ -48,6 +47,8 @@ class ApiKeyAuthService:
|
||||
)
|
||||
if not data_source_api_key_bindings:
|
||||
return None
|
||||
if not data_source_api_key_bindings.credentials:
|
||||
return None
|
||||
credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
return credentials
|
||||
|
||||
|
||||
@ -1470,7 +1470,7 @@ class DocumentService:
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
if not dataset.retrieval_model:
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 4,
|
||||
@ -1752,7 +1752,7 @@ class DocumentService:
|
||||
# dataset.collection_binding_id = dataset_collection_binding.id
|
||||
# if not dataset.retrieval_model:
|
||||
# default_retrieval_model = {
|
||||
# "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
# "search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
# "reranking_enable": False,
|
||||
# "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
# "top_k": 2,
|
||||
@ -2205,7 +2205,7 @@ class DocumentService:
|
||||
retrieval_model = knowledge_config.retrieval_model
|
||||
else:
|
||||
retrieval_model = RetrievalModel(
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH,
|
||||
reranking_enable=False,
|
||||
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
|
||||
top_k=4,
|
||||
|
||||
@ -646,7 +646,7 @@ class DatasourceProviderService:
|
||||
name=db_provider_name,
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
auth_type=CredentialType.API_KEY.value,
|
||||
auth_type=CredentialType.API_KEY,
|
||||
encrypted_credentials=credentials,
|
||||
)
|
||||
session.add(datasource_provider)
|
||||
@ -674,7 +674,7 @@ class DatasourceProviderService:
|
||||
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
|
||||
if credential_form_schema.type.value == FormType.SECRET_INPUT:
|
||||
secret_input_form_variables.append(credential_form_schema.name)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
|
||||
class BaseRequest:
|
||||
proxies = {
|
||||
proxies: Mapping[str, str] | None = {
|
||||
"http": "",
|
||||
"https": "",
|
||||
}
|
||||
@ -13,10 +15,31 @@ class BaseRequest:
|
||||
secret_key_header = ""
|
||||
|
||||
@classmethod
|
||||
def send_request(cls, method, endpoint, json=None, params=None):
|
||||
def _build_mounts(cls) -> dict[str, httpx.BaseTransport] | None:
|
||||
if not cls.proxies:
|
||||
return None
|
||||
|
||||
mounts: dict[str, httpx.BaseTransport] = {}
|
||||
for scheme, value in cls.proxies.items():
|
||||
if not value:
|
||||
continue
|
||||
key = f"{scheme}://" if not scheme.endswith("://") else scheme
|
||||
mounts[key] = httpx.HTTPTransport(proxy=value)
|
||||
return mounts or None
|
||||
|
||||
@classmethod
|
||||
def send_request(
|
||||
cls,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
json: Any | None = None,
|
||||
params: Mapping[str, Any] | None = None,
|
||||
) -> Any:
|
||||
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies)
|
||||
mounts = cls._build_mounts()
|
||||
with httpx.Client(mounts=mounts) as client:
|
||||
response = client.request(method, url, json=json, params=params, headers=headers)
|
||||
return response.json()
|
||||
|
||||
|
||||
|
||||
@ -70,7 +70,7 @@ class EnterpriseService:
|
||||
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
|
||||
if not data:
|
||||
raise ValueError("No data found.")
|
||||
return WebAppSettings(**data)
|
||||
return WebAppSettings.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]:
|
||||
@ -100,7 +100,7 @@ class EnterpriseService:
|
||||
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
|
||||
if not data:
|
||||
raise ValueError("No data found.")
|
||||
return WebAppSettings(**data)
|
||||
return WebAppSettings.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def update_app_access_mode(cls, app_id: str, access_mode: str):
|
||||
|
||||
@ -3,6 +3,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class ParentMode(StrEnum):
|
||||
FULL_DOC = "full-doc"
|
||||
@ -95,7 +97,7 @@ class WeightModel(BaseModel):
|
||||
|
||||
|
||||
class RetrievalModel(BaseModel):
|
||||
search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
|
||||
search_method: RetrievalMethod
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModel | None = None
|
||||
reranking_mode: str | None = None
|
||||
|
||||
@ -2,6 +2,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
icon: str
|
||||
@ -83,7 +85,7 @@ class RetrievalSetting(BaseModel):
|
||||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
|
||||
search_method: RetrievalMethod
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from enum import Enum
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import (
|
||||
@ -26,7 +27,7 @@ from core.model_runtime.entities.provider_entities import (
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class CustomConfigurationStatus(Enum):
|
||||
class CustomConfigurationStatus(StrEnum):
|
||||
"""
|
||||
Enum class for custom configuration status.
|
||||
"""
|
||||
@ -71,7 +72,7 @@ class ProviderResponse(BaseModel):
|
||||
icon_large: I18nObject | None = None
|
||||
background: str | None = None
|
||||
help: ProviderHelpEntity | None = None
|
||||
supported_model_types: list[ModelType]
|
||||
supported_model_types: Sequence[ModelType]
|
||||
configurate_methods: list[ConfigurateMethod]
|
||||
provider_credential_schema: ProviderCredentialSchema | None = None
|
||||
model_credential_schema: ModelCredentialSchema | None = None
|
||||
@ -82,9 +83,8 @@ class ProviderResponse(BaseModel):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
url_prefix = (
|
||||
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
|
||||
)
|
||||
@ -97,6 +97,7 @@ class ProviderResponse(BaseModel):
|
||||
self.icon_large = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class ProviderWithModelsResponse(BaseModel):
|
||||
@ -112,9 +113,8 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
status: CustomConfigurationStatus
|
||||
models: list[ProviderModelWithStatusEntity]
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
url_prefix = (
|
||||
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
|
||||
)
|
||||
@ -127,6 +127,7 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
self.icon_large = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class SimpleProviderEntityResponse(SimpleProviderEntity):
|
||||
@ -136,9 +137,8 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
|
||||
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
url_prefix = (
|
||||
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
|
||||
)
|
||||
@ -151,6 +151,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
|
||||
self.icon_large = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class DefaultModelResponse(BaseModel):
|
||||
|
||||
@ -88,9 +88,9 @@ class ExternalDatasetService:
|
||||
else:
|
||||
raise ValueError(f"invalid endpoint: {endpoint}")
|
||||
try:
|
||||
response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
|
||||
except Exception:
|
||||
raise ValueError(f"failed to connect to the endpoint: {endpoint}")
|
||||
response = ssrf_proxy.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
|
||||
except Exception as e:
|
||||
raise ValueError(f"failed to connect to the endpoint: {endpoint}") from e
|
||||
if response.status_code == 502:
|
||||
raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}")
|
||||
if response.status_code == 404:
|
||||
|
||||
@ -15,7 +15,7 @@ from models.dataset import Dataset, DatasetQuery
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 4,
|
||||
@ -46,7 +46,7 @@ class HitTestingService:
|
||||
|
||||
from core.app.app_config.entities import MetadataFilteringCondition
|
||||
|
||||
metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions)
|
||||
metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
||||
|
||||
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
|
||||
dataset_ids=[dataset.id],
|
||||
@ -63,7 +63,7 @@ class HitTestingService:
|
||||
if metadata_condition and not document_ids_filter:
|
||||
return cls.compact_retrieve_response(query, [])
|
||||
all_documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k", 4),
|
||||
|
||||
@ -123,7 +123,7 @@ class OpsService:
|
||||
config_class: type[BaseTracingConfig] = provider_config["config_class"]
|
||||
other_keys: list[str] = provider_config["other_keys"]
|
||||
|
||||
default_config_instance: BaseTracingConfig = config_class(**tracing_config)
|
||||
default_config_instance = config_class.model_validate(tracing_config)
|
||||
for key in other_keys:
|
||||
if key in tracing_config and tracing_config[key] == "":
|
||||
tracing_config[key] = getattr(default_config_instance, key, None)
|
||||
|
||||
@ -242,7 +242,7 @@ class PluginMigration:
|
||||
if data.get("type") == "tool":
|
||||
provider_name = data.get("provider_name")
|
||||
provider_type = data.get("provider_type")
|
||||
if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value:
|
||||
if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN:
|
||||
result.append(ToolProviderID(provider_name).plugin_id)
|
||||
|
||||
return result
|
||||
@ -269,9 +269,9 @@ class PluginMigration:
|
||||
for tool in agent_config["tools"]:
|
||||
if isinstance(tool, dict):
|
||||
try:
|
||||
tool_entity = AgentToolEntity(**tool)
|
||||
tool_entity = AgentToolEntity.model_validate(tool)
|
||||
if (
|
||||
tool_entity.provider_type == ToolProviderType.BUILT_IN.value
|
||||
tool_entity.provider_type == ToolProviderType.BUILT_IN
|
||||
and tool_entity.provider_id not in excluded_providers
|
||||
):
|
||||
result.append(ToolProviderID(tool_entity.provider_id).plugin_id)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
|
||||
@ -43,7 +43,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/pipeline-templates/{template_id}"
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
data: dict = response.json()
|
||||
@ -58,7 +58,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/pipeline-templates?language={language}"
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"fetch pipeline templates failed, status code: {response.status_code}")
|
||||
|
||||
|
||||
@ -358,7 +358,7 @@ class RagPipelineService:
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") == "knowledge-index":
|
||||
knowledge_configuration = node.get("data", {})
|
||||
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration)
|
||||
|
||||
# update dataset
|
||||
dataset = pipeline.retrieve_dataset(session=session)
|
||||
@ -873,7 +873,7 @@ class RagPipelineService:
|
||||
variable_pool = node_instance.graph_runtime_state.variable_pool
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
if invoke_from:
|
||||
if invoke_from.value == InvokeFrom.PUBLISHED.value:
|
||||
if invoke_from.value == InvokeFrom.PUBLISHED:
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if document_id:
|
||||
document = db.session.query(Document).where(Document.id == document_id.value).first()
|
||||
|
||||
@ -288,7 +288,7 @@ class RagPipelineDslService:
|
||||
dataset_id = None
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") == "knowledge-index":
|
||||
knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
|
||||
if (
|
||||
dataset
|
||||
and pipeline.is_published
|
||||
@ -426,7 +426,7 @@ class RagPipelineDslService:
|
||||
dataset_id = None
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") == "knowledge-index":
|
||||
knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
|
||||
if not dataset:
|
||||
dataset = Dataset(
|
||||
tenant_id=account.current_tenant_id,
|
||||
@ -556,7 +556,7 @@ class RagPipelineDslService:
|
||||
|
||||
graph = workflow_data.get("graph", {})
|
||||
for node in graph.get("nodes", []):
|
||||
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
dataset_ids = node["data"].get("dataset_ids", [])
|
||||
node["data"]["dataset_ids"] = [
|
||||
decrypted_id
|
||||
@ -613,7 +613,7 @@ class RagPipelineDslService:
|
||||
tenant_id=pipeline.tenant_id,
|
||||
app_id=pipeline.id,
|
||||
features="{}",
|
||||
type=WorkflowType.RAG_PIPELINE.value,
|
||||
type=WorkflowType.RAG_PIPELINE,
|
||||
version="draft",
|
||||
graph=json.dumps(graph),
|
||||
created_by=account.id,
|
||||
@ -689,17 +689,17 @@ class RagPipelineDslService:
|
||||
if not node_data:
|
||||
continue
|
||||
data_type = node_data.get("type", "")
|
||||
if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
if data_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
dataset_ids = node_data.get("dataset_ids", [])
|
||||
node["data"]["dataset_ids"] = [
|
||||
self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
|
||||
for dataset_id in dataset_ids
|
||||
]
|
||||
# filter credential id from tool node
|
||||
if not include_secret and data_type == NodeType.TOOL.value:
|
||||
if not include_secret and data_type == NodeType.TOOL:
|
||||
node_data.pop("credential_id", None)
|
||||
# filter credential id from agent node
|
||||
if not include_secret and data_type == NodeType.AGENT.value:
|
||||
if not include_secret and data_type == NodeType.AGENT:
|
||||
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
|
||||
tool.pop("credential_id", None)
|
||||
|
||||
@ -733,36 +733,36 @@ class RagPipelineDslService:
|
||||
try:
|
||||
typ = node.get("data", {}).get("type")
|
||||
match typ:
|
||||
case NodeType.TOOL.value:
|
||||
tool_entity = ToolNodeData(**node["data"])
|
||||
case NodeType.TOOL:
|
||||
tool_entity = ToolNodeData.model_validate(node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
|
||||
)
|
||||
case NodeType.DATASOURCE.value:
|
||||
datasource_entity = DatasourceNodeData(**node["data"])
|
||||
case NodeType.DATASOURCE:
|
||||
datasource_entity = DatasourceNodeData.model_validate(node["data"])
|
||||
if datasource_entity.provider_type != "local_file":
|
||||
dependencies.append(datasource_entity.plugin_id)
|
||||
case NodeType.LLM.value:
|
||||
llm_entity = LLMNodeData(**node["data"])
|
||||
case NodeType.LLM:
|
||||
llm_entity = LLMNodeData.model_validate(node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
|
||||
)
|
||||
case NodeType.QUESTION_CLASSIFIER.value:
|
||||
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
|
||||
case NodeType.QUESTION_CLASSIFIER:
|
||||
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
question_classifier_entity.model.provider
|
||||
),
|
||||
)
|
||||
case NodeType.PARAMETER_EXTRACTOR.value:
|
||||
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
|
||||
case NodeType.PARAMETER_EXTRACTOR:
|
||||
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
parameter_extractor_entity.model.provider
|
||||
),
|
||||
)
|
||||
case NodeType.KNOWLEDGE_INDEX.value:
|
||||
knowledge_index_entity = KnowledgeConfiguration(**node["data"])
|
||||
case NodeType.KNOWLEDGE_INDEX:
|
||||
knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"])
|
||||
if knowledge_index_entity.indexing_technique == "high_quality":
|
||||
if knowledge_index_entity.embedding_model_provider:
|
||||
dependencies.append(
|
||||
@ -782,8 +782,8 @@ class RagPipelineDslService:
|
||||
knowledge_index_entity.retrieval_model.reranking_model.reranking_provider_name
|
||||
),
|
||||
)
|
||||
case NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
|
||||
case NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
|
||||
if knowledge_retrieval_entity.retrieval_mode == "multiple":
|
||||
if knowledge_retrieval_entity.multiple_retrieval_config:
|
||||
if (
|
||||
@ -873,7 +873,7 @@ class RagPipelineDslService:
|
||||
"""
|
||||
Returns the leaked dependencies in current workspace
|
||||
"""
|
||||
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
|
||||
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
|
||||
if not dependencies:
|
||||
return []
|
||||
|
||||
@ -927,7 +927,7 @@ class RagPipelineDslService:
|
||||
account = cast(Account, current_user)
|
||||
rag_pipeline_import_info: RagPipelineImportInfo = self.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
|
||||
dataset=None,
|
||||
dataset_name=rag_pipeline_dataset_create_entity.name,
|
||||
|
||||
@ -9,6 +9,7 @@ from flask_login import current_user
|
||||
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
|
||||
@ -156,15 +157,15 @@ class RagPipelineTransformService:
|
||||
self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
|
||||
):
|
||||
knowledge_configuration_dict = node.get("data", {})
|
||||
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
|
||||
|
||||
if indexing_technique == "high_quality":
|
||||
knowledge_configuration.embedding_model = dataset.embedding_model
|
||||
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
|
||||
if retrieval_model:
|
||||
retrieval_setting = RetrievalSetting(**retrieval_model)
|
||||
retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
|
||||
if indexing_technique == "economy":
|
||||
retrieval_setting.search_method = "keyword_search"
|
||||
retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH
|
||||
knowledge_configuration.retrieval_model = retrieval_setting
|
||||
else:
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
@ -214,7 +215,7 @@ class RagPipelineTransformService:
|
||||
tenant_id=pipeline.tenant_id,
|
||||
app_id=pipeline.id,
|
||||
features="{}",
|
||||
type=WorkflowType.RAG_PIPELINE.value,
|
||||
type=WorkflowType.RAG_PIPELINE,
|
||||
version="draft",
|
||||
graph=json.dumps(graph),
|
||||
created_by=current_user.id,
|
||||
@ -226,7 +227,7 @@ class RagPipelineTransformService:
|
||||
tenant_id=pipeline.tenant_id,
|
||||
app_id=pipeline.id,
|
||||
features="{}",
|
||||
type=WorkflowType.RAG_PIPELINE.value,
|
||||
type=WorkflowType.RAG_PIPELINE,
|
||||
version=str(datetime.now(UTC).replace(tzinfo=None)),
|
||||
graph=json.dumps(graph),
|
||||
created_by=current_user.id,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
|
||||
@ -43,7 +43,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/apps/{app_id}"
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
data: dict = response.json()
|
||||
@ -58,7 +58,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/apps?language={language}"
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
|
||||
|
||||
|
||||
@ -148,7 +148,7 @@ class ApiToolManageService:
|
||||
description=extra_info.get("description", ""),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str={},
|
||||
credentials_str="{}",
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
)
|
||||
@ -277,7 +277,7 @@ class ApiToolManageService:
|
||||
provider.icon = json.dumps(icon)
|
||||
provider.schema = schema
|
||||
provider.description = extra_info.get("description", "")
|
||||
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
|
||||
provider.schema_type_str = ApiProviderSchemaType.OPENAPI
|
||||
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
||||
provider.privacy_policy = privacy_policy
|
||||
provider.custom_disclaimer = custom_disclaimer
|
||||
@ -393,7 +393,7 @@ class ApiToolManageService:
|
||||
icon="",
|
||||
schema=schema,
|
||||
description="",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI.value,
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
@ -683,7 +683,7 @@ class BuiltinToolManageService:
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
original_params = encrypter.decrypt(custom_client_params.oauth_params)
|
||||
new_params: dict = {
|
||||
new_params = {
|
||||
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
||||
for key, value in client_params.items()
|
||||
}
|
||||
|
||||
@ -50,16 +50,16 @@ class ToolTransformService:
|
||||
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
|
||||
)
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN.value:
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
return str(url_prefix / "builtin" / provider_name / "icon")
|
||||
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
|
||||
elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}:
|
||||
try:
|
||||
if isinstance(icon, str):
|
||||
return json.loads(icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
elif provider_type == ToolProviderType.MCP.value:
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return icon
|
||||
return ""
|
||||
|
||||
@ -152,7 +152,8 @@ class ToolTransformService:
|
||||
|
||||
if decrypt_credentials:
|
||||
credentials = db_provider.credentials
|
||||
|
||||
if not db_provider.tenant_id:
|
||||
raise ValueError(f"Required tenant_id is missing for BuiltinToolProvider with id {db_provider.id}")
|
||||
# init tool configuration
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
@ -381,6 +382,7 @@ class ToolTransformService:
|
||||
labels=labels or [],
|
||||
)
|
||||
else:
|
||||
assert tool.operation_id
|
||||
return ToolApiEntity(
|
||||
author=tool.author,
|
||||
name=tool.operation_id or "",
|
||||
|
||||
@ -134,7 +134,7 @@ class VectorService:
|
||||
)
|
||||
# use full doc mode to generate segment's child chunk
|
||||
processing_rule_dict = processing_rule.to_dict()
|
||||
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
|
||||
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC
|
||||
documents = index_processor.transform(
|
||||
[document],
|
||||
embedding_model_instance=embedding_model_instance,
|
||||
|
||||
@ -36,7 +36,7 @@ class WebAppAuthService:
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
if account.status == AccountStatus.BANNED:
|
||||
raise AccountLoginError("Account is banned.")
|
||||
|
||||
if account.password is None or not compare_password(password, account.password, account.password_salt):
|
||||
@ -56,7 +56,7 @@ class WebAppAuthService:
|
||||
if not account:
|
||||
return None
|
||||
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
if account.status == AccountStatus.BANNED:
|
||||
raise Unauthorized("Account is banned.")
|
||||
|
||||
return account
|
||||
|
||||
@ -228,7 +228,7 @@ class WorkflowConverter:
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "START",
|
||||
"type": NodeType.START.value,
|
||||
"type": NodeType.START,
|
||||
"variables": [jsonable_encoder(v) for v in variables],
|
||||
},
|
||||
}
|
||||
@ -273,7 +273,7 @@ class WorkflowConverter:
|
||||
inputs[v.variable] = "{{#start." + v.variable + "#}}"
|
||||
|
||||
request_body = {
|
||||
"point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
|
||||
"point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
|
||||
"params": {
|
||||
"app_id": app_model.id,
|
||||
"tool_variable": tool_variable,
|
||||
@ -290,7 +290,7 @@ class WorkflowConverter:
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": f"HTTP REQUEST {api_based_extension.name}",
|
||||
"type": NodeType.HTTP_REQUEST.value,
|
||||
"type": NodeType.HTTP_REQUEST,
|
||||
"method": "post",
|
||||
"url": api_based_extension.api_endpoint,
|
||||
"authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}},
|
||||
@ -308,7 +308,7 @@ class WorkflowConverter:
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": f"Parse {api_based_extension.name} Response",
|
||||
"type": NodeType.CODE.value,
|
||||
"type": NodeType.CODE,
|
||||
"variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}],
|
||||
"code_language": "python3",
|
||||
"code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads("
|
||||
@ -348,7 +348,7 @@ class WorkflowConverter:
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "KNOWLEDGE RETRIEVAL",
|
||||
"type": NodeType.KNOWLEDGE_RETRIEVAL.value,
|
||||
"type": NodeType.KNOWLEDGE_RETRIEVAL,
|
||||
"query_variable_selector": query_variable_selector,
|
||||
"dataset_ids": dataset_config.dataset_ids,
|
||||
"retrieval_mode": retrieve_config.retrieve_strategy.value,
|
||||
@ -396,16 +396,16 @@ class WorkflowConverter:
|
||||
:param external_data_variable_node_mapping: external data variable node mapping
|
||||
"""
|
||||
# fetch start and knowledge retrieval node
|
||||
start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"]))
|
||||
start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START, graph["nodes"]))
|
||||
knowledge_retrieval_node = next(
|
||||
filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None
|
||||
filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL, graph["nodes"]), None
|
||||
)
|
||||
|
||||
role_prefix = None
|
||||
prompts: Any | None = None
|
||||
|
||||
# Chat Model
|
||||
if model_config.mode == LLMMode.CHAT.value:
|
||||
if model_config.mode == LLMMode.CHAT:
|
||||
if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
if not prompt_template.simple_prompt_template:
|
||||
raise ValueError("Simple prompt template is required")
|
||||
@ -517,7 +517,7 @@ class WorkflowConverter:
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "LLM",
|
||||
"type": NodeType.LLM.value,
|
||||
"type": NodeType.LLM,
|
||||
"model": {
|
||||
"provider": model_config.provider,
|
||||
"name": model_config.model,
|
||||
@ -572,7 +572,7 @@ class WorkflowConverter:
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "END",
|
||||
"type": NodeType.END.value,
|
||||
"type": NodeType.END,
|
||||
"outputs": [{"variable": "result", "value_selector": ["llm", "text"]}],
|
||||
},
|
||||
}
|
||||
@ -586,7 +586,7 @@ class WorkflowConverter:
|
||||
return {
|
||||
"id": "answer",
|
||||
"position": None,
|
||||
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
|
||||
"data": {"title": "ANSWER", "type": NodeType.ANSWER, "answer": "{{#llm.text#}}"},
|
||||
}
|
||||
|
||||
def _create_edge(self, source: str, target: str):
|
||||
|
||||
@ -569,7 +569,7 @@ class WorkflowDraftVariableService:
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from=InvokeFrom.DEBUGGER.value,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
from_source="console",
|
||||
from_end_user_id=None,
|
||||
from_account_id=account_id,
|
||||
|
||||
@ -74,7 +74,7 @@ class WorkflowRunService:
|
||||
return self._workflow_run_repo.get_paginated_workflow_runs(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
limit=limit,
|
||||
last_id=last_id,
|
||||
)
|
||||
|
||||
@ -1006,7 +1006,7 @@ def _setup_variable_pool(
|
||||
)
|
||||
|
||||
# Only add chatflow-specific variables for non-workflow types
|
||||
if workflow.type != WorkflowType.WORKFLOW.value:
|
||||
if workflow.type != WorkflowType.WORKFLOW:
|
||||
system_variable.query = query
|
||||
system_variable.conversation_id = conversation_id
|
||||
system_variable.dialogue_count = 1
|
||||
|
||||
Reference in New Issue
Block a user