mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
Merge branch 'feat/collaboration' into deploy/dev
This commit is contained in:
@ -26,10 +26,9 @@ class ApiKeyAuthService:
|
||||
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
|
||||
args["credentials"]["config"]["api_key"] = api_key
|
||||
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
|
||||
data_source_api_key_binding.tenant_id = tenant_id
|
||||
data_source_api_key_binding.category = args["category"]
|
||||
data_source_api_key_binding.provider = args["provider"]
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
|
||||
)
|
||||
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
@ -48,6 +47,8 @@ class ApiKeyAuthService:
|
||||
)
|
||||
if not data_source_api_key_bindings:
|
||||
return None
|
||||
if not data_source_api_key_bindings.credentials:
|
||||
return None
|
||||
credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
return credentials
|
||||
|
||||
|
||||
@ -1470,7 +1470,7 @@ class DocumentService:
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
if not dataset.retrieval_model:
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 4,
|
||||
@ -1752,7 +1752,7 @@ class DocumentService:
|
||||
# dataset.collection_binding_id = dataset_collection_binding.id
|
||||
# if not dataset.retrieval_model:
|
||||
# default_retrieval_model = {
|
||||
# "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
# "search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
# "reranking_enable": False,
|
||||
# "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
# "top_k": 2,
|
||||
@ -2205,7 +2205,7 @@ class DocumentService:
|
||||
retrieval_model = knowledge_config.retrieval_model
|
||||
else:
|
||||
retrieval_model = RetrievalModel(
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH,
|
||||
reranking_enable=False,
|
||||
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
|
||||
top_k=4,
|
||||
|
||||
@ -3,6 +3,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class ParentMode(StrEnum):
|
||||
FULL_DOC = "full-doc"
|
||||
@ -95,7 +97,7 @@ class WeightModel(BaseModel):
|
||||
|
||||
|
||||
class RetrievalModel(BaseModel):
|
||||
search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
|
||||
search_method: RetrievalMethod
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModel | None = None
|
||||
reranking_mode: str | None = None
|
||||
|
||||
@ -2,6 +2,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
icon: str
|
||||
@ -83,7 +85,7 @@ class RetrievalSetting(BaseModel):
|
||||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
|
||||
search_method: RetrievalMethod
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
@ -27,7 +27,7 @@ from core.model_runtime.entities.provider_entities import (
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class CustomConfigurationStatus(Enum):
|
||||
class CustomConfigurationStatus(StrEnum):
|
||||
"""
|
||||
Enum class for custom configuration status.
|
||||
"""
|
||||
|
||||
@ -88,9 +88,9 @@ class ExternalDatasetService:
|
||||
else:
|
||||
raise ValueError(f"invalid endpoint: {endpoint}")
|
||||
try:
|
||||
response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
|
||||
except Exception:
|
||||
raise ValueError(f"failed to connect to the endpoint: {endpoint}")
|
||||
response = ssrf_proxy.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
|
||||
except Exception as e:
|
||||
raise ValueError(f"failed to connect to the endpoint: {endpoint}") from e
|
||||
if response.status_code == 502:
|
||||
raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}")
|
||||
if response.status_code == 404:
|
||||
|
||||
@ -63,7 +63,7 @@ class HitTestingService:
|
||||
if metadata_condition and not document_ids_filter:
|
||||
return cls.compact_retrieve_response(query, [])
|
||||
all_documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k", 4),
|
||||
|
||||
@ -9,6 +9,7 @@ from flask_login import current_user
|
||||
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
|
||||
@ -164,7 +165,7 @@ class RagPipelineTransformService:
|
||||
if retrieval_model:
|
||||
retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
|
||||
if indexing_technique == "economy":
|
||||
retrieval_setting.search_method = "keyword_search"
|
||||
retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH
|
||||
knowledge_configuration.retrieval_model = retrieval_setting
|
||||
else:
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
|
||||
@ -148,7 +148,7 @@ class ApiToolManageService:
|
||||
description=extra_info.get("description", ""),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str={},
|
||||
credentials_str="{}",
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
)
|
||||
|
||||
@ -683,7 +683,7 @@ class BuiltinToolManageService:
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
original_params = encrypter.decrypt(custom_client_params.oauth_params)
|
||||
new_params: dict = {
|
||||
new_params = {
|
||||
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
||||
for key, value in client_params.items()
|
||||
}
|
||||
|
||||
@ -188,6 +188,8 @@ class MCPToolManageService:
|
||||
raise
|
||||
|
||||
user = mcp_provider.load_user()
|
||||
if not mcp_provider.icon:
|
||||
raise ValueError("MCP provider icon is required")
|
||||
return ToolProviderApiEntity(
|
||||
id=mcp_provider.id,
|
||||
name=mcp_provider.name,
|
||||
|
||||
@ -152,7 +152,8 @@ class ToolTransformService:
|
||||
|
||||
if decrypt_credentials:
|
||||
credentials = db_provider.credentials
|
||||
|
||||
if not db_provider.tenant_id:
|
||||
raise ValueError(f"Required tenant_id is missing for BuiltinToolProvider with id {db_provider.id}")
|
||||
# init tool configuration
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
|
||||
Reference in New Issue
Block a user