Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

This commit is contained in:
-LAN-
2025-09-10 02:03:45 +08:00
99 changed files with 849 additions and 494 deletions

View File

@ -1318,7 +1318,7 @@ class RegisterService:
def get_invitation_if_token_valid(
cls, workspace_id: Optional[str], email: str, token: str
) -> Optional[dict[str, Any]]:
invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None
@ -1355,7 +1355,7 @@ class RegisterService:
}
@classmethod
def _get_invitation_by_token(
def get_invitation_by_token(
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
) -> Optional[dict[str, str]]:
if workspace_id is not None and email is not None:

View File

@ -349,7 +349,7 @@ class AppAnnotationService:
try:
# Skip the first row
df = pd.read_csv(file, dtype=str)
df = pd.read_csv(file.stream, dtype=str)
result = []
for _, row in df.iterrows():
content = {"question": row.iloc[0], "answer": row.iloc[1]}
@ -463,15 +463,23 @@ class AppAnnotationService:
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
if collection_binding_detail:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
else:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {},
}
return {"enabled": False}
@classmethod
@ -506,15 +514,23 @@ class AppAnnotationService:
collection_binding_detail = annotation_setting.collection_binding_detail
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
if collection_binding_detail:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
else:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {},
}
@classmethod
def clear_all_annotations(cls, app_id: str):

View File

@ -407,6 +407,7 @@ class ClearFreePlanTenantExpiredLogs:
datetime.timedelta(hours=1),
]
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)

View File

@ -134,11 +134,14 @@ class DatasetService:
# Check if tag_ids is not empty to avoid WHERE false condition
if tag_ids and len(tag_ids) > 0:
target_ids = TagService.get_target_ids_by_tag_ids(
"knowledge",
tenant_id, # ty: ignore [invalid-argument-type]
tag_ids,
)
if tenant_id is not None:
target_ids = TagService.get_target_ids_by_tag_ids(
"knowledge",
tenant_id,
tag_ids,
)
else:
target_ids = []
if target_ids and len(target_ids) > 0:
query = query.where(Dataset.id.in_(target_ids))
else:
@ -987,7 +990,8 @@ class DocumentService:
for document in documents
if document.data_source_type == "upload_file" and document.data_source_info_dict
]
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
if dataset.doc_form is not None:
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
for document in documents:
db.session.delete(document)
@ -2688,56 +2692,6 @@ class SegmentService:
return paginated_segments.items, paginated_segments.total
@classmethod
def update_segment_by_id(
    cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str
) -> tuple[DocumentSegment, Document]:
    """Update a document segment by its ID with validation and checks.

    Validates, in order: the dataset exists and belongs to the tenant, the
    dataset's model settings are usable, the parent document exists, the
    embedding model is available (high-quality datasets only), and the
    segment exists within the tenant. Only then applies the update.

    :param tenant_id: workspace/tenant that must own the dataset and segment
    :param dataset_id: id of the dataset containing the document
    :param document_id: id of the document containing the segment
    :param segment_id: id of the segment to update
    :param segment_data: raw update payload, validated via
        ``segment_create_args_validate`` then parsed into ``SegmentUpdateArgs``
    :param user_id: id of the acting user (currently unused beyond auditing;
        kept for interface compatibility)
    :return: tuple of (updated segment, parent document)
    :raises NotFound: if the dataset, document, or segment cannot be found
    :raises ValueError: if no embedding model is available or the provider
        token is not initialized
    """
    # check dataset (scoped to the tenant so cross-tenant ids 404)
    dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
    if not dataset:
        raise NotFound("Dataset not found.")
    # check user's model setting
    DatasetService.check_dataset_model_setting(dataset)
    # check document
    document = DocumentService.get_document(dataset_id, document_id)
    if not document:
        raise NotFound("Document not found.")
    # check embedding model setting if high quality
    if dataset.indexing_technique == "high_quality":
        try:
            model_manager = ModelManager()
            model_manager.get_model_instance(
                # FIX: was tenant_id=user_id — model instances are scoped to the
                # tenant (workspace), not the acting user; every other lookup in
                # this method uses tenant_id for scoping.
                tenant_id=tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )
        except LLMBadRequestError:
            raise ValueError(
                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
            )
        except ProviderTokenNotInitError as ex:
            raise ValueError(ex.description)
    # check segment (tenant-scoped, same as the dataset lookup above)
    segment = (
        db.session.query(DocumentSegment)
        .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
        .first()
    )
    if not segment:
        raise NotFound("Segment not found.")
    # validate and update segment
    cls.segment_create_args_validate(segment_data, document)
    updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset)
    return updated_segment, document
@classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
"""Get a segment by its ID."""

View File

@ -181,7 +181,7 @@ class ExternalDatasetService:
do http request depending on api bundle
"""
kwargs = {
kwargs: dict[str, Any] = {
"url": settings.url,
"headers": settings.headers,
"follow_redirects": True,

View File

@ -1,7 +1,7 @@
import hashlib
import os
import uuid
from typing import Any, Literal, Union
from typing import Literal, Union
from werkzeug.exceptions import NotFound
@ -35,7 +35,7 @@ class FileService:
filename: str,
content: bytes,
mimetype: str,
user: Union[Account, EndUser, Any],
user: Union[Account, EndUser],
source: Literal["datasets"] | None = None,
source_url: str = "",
) -> UploadFile:

View File

@ -165,7 +165,7 @@ class ModelLoadBalancingService:
try:
if load_balancing_config.encrypted_config:
credentials = json.loads(load_balancing_config.encrypted_config)
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
else:
credentials = {}
except JSONDecodeError:
@ -180,11 +180,13 @@ class ModelLoadBalancingService:
for variable in credential_secret_variables:
if variable in credentials:
try:
credentials[variable] = encrypter.decrypt_token_with_decoding(
credentials.get(variable), # ty: ignore [invalid-argument-type]
decoding_rsa_key,
decoding_cipher_rsa,
)
token_value = credentials.get(variable)
if isinstance(token_value, str):
credentials[variable] = encrypter.decrypt_token_with_decoding(
token_value,
decoding_rsa_key,
decoding_cipher_rsa,
)
except ValueError:
pass
@ -345,8 +347,9 @@ class ModelLoadBalancingService:
credential_id = config.get("credential_id")
enabled = config.get("enabled")
credential_record: ProviderCredential | ProviderModelCredential | None = None
if credential_id:
credential_record: ProviderCredential | ProviderModelCredential | None = None
if config_from == "predefined-model":
credential_record = (
db.session.query(ProviderCredential)

View File

@ -100,6 +100,7 @@ class PluginMigration:
datetime.timedelta(hours=1),
]
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)

View File

@ -223,8 +223,8 @@ class BuiltinToolManageService:
"""
add builtin tool provider
"""
try:
with Session(db.engine) as session:
with Session(db.engine) as session:
try:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@ -285,9 +285,9 @@ class BuiltinToolManageService:
session.add(db_provider)
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod

View File

@ -18,6 +18,7 @@ from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import SimplePromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.nodes import NodeType
from events.app_event import app_was_created
from extensions.ext_database import db
@ -420,7 +421,11 @@ class WorkflowConverter:
query_in_prompt=False,
)
template = prompt_template_config["prompt_template"].template
prompt_template_obj = prompt_template_config["prompt_template"]
if not isinstance(prompt_template_obj, PromptTemplateParser):
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
template = prompt_template_obj.template
if not template:
prompts = []
else:
@ -457,7 +462,11 @@ class WorkflowConverter:
query_in_prompt=False,
)
template = prompt_template_config["prompt_template"].template
prompt_template_obj = prompt_template_config["prompt_template"]
if not isinstance(prompt_template_obj, PromptTemplateParser):
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
template = prompt_template_obj.template
template = self._replace_template_variables(
template=template,
variables=start_node["data"]["variables"],
@ -467,6 +476,9 @@ class WorkflowConverter:
prompts = {"text": template}
prompt_rules = prompt_template_config["prompt_rules"]
if not isinstance(prompt_rules, dict):
raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")
role_prefix = {
"user": prompt_rules.get("human_prefix", "Human"),
"assistant": prompt_rules.get("assistant_prefix", "Assistant"),

View File

@ -783,11 +783,13 @@ class WorkflowService:
WorkflowNodeExecutionStatus.EXCEPTION,
)
error = node_run_result.error if not run_succeeded else None
return node, node_run_result, run_succeeded, error
except WorkflowNodeRunFailedError as e:
return e._node, None, False, e._error
node = e.node
run_succeeded = False
node_run_result = None
error = e.error
return node, node_run_result, run_succeeded, error
def _apply_error_strategy(self, node: Node, node_run_result: NodeRunResult) -> NodeRunResult:
"""Apply error strategy when node execution fails."""

View File

@ -12,7 +12,7 @@ class WorkspaceService:
def get_tenant_info(cls, tenant: Tenant):
if not tenant:
return None
tenant_info = {
tenant_info: dict[str, object] = {
"id": tenant.id,
"name": tenant.name,
"plan": tenant.plan,