Merge remote-tracking branch 'origin/main' into feat/trigger

yessenia committed 2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

View File

@@ -24,7 +24,7 @@ class ToolParameterConfigurationManager:
     def __init__(
         self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
-    ) -> None:
+    ):
         self.tenant_id = tenant_id
         self.tool_runtime = tool_runtime
         self.provider_name = provider_name

View File

@@ -3,6 +3,7 @@ from typing import Any

 from flask import Flask, current_app
 from pydantic import BaseModel, Field
+from sqlalchemy import select

 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.model_manager import ModelManager
@@ -85,17 +86,14 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
         document_context_list = []
         index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
-        segments = (
-            db.session.query(DocumentSegment)
-            .where(
-                DocumentSegment.dataset_id.in_(self.dataset_ids),
-                DocumentSegment.completed_at.isnot(None),
-                DocumentSegment.status == "completed",
-                DocumentSegment.enabled == True,
-                DocumentSegment.index_node_id.in_(index_node_ids),
-            )
-            .all()
-        )
+        document_segment_stmt = select(DocumentSegment).where(
+            DocumentSegment.dataset_id.in_(self.dataset_ids),
+            DocumentSegment.completed_at.isnot(None),
+            DocumentSegment.status == "completed",
+            DocumentSegment.enabled == True,
+            DocumentSegment.index_node_id.in_(index_node_ids),
+        )
+        segments = db.session.scalars(document_segment_stmt).all()

         if segments:
             index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
@@ -112,15 +110,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                 resource_number = 1
                 for segment in sorted_segments:
                     dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
-                    document = (
-                        db.session.query(Document)
-                        .where(
-                            Document.id == segment.document_id,
-                            Document.enabled == True,
-                            Document.archived == False,
-                        )
-                        .first()
-                    )
+                    document_stmt = select(Document).where(
+                        Document.id == segment.document_id,
+                        Document.enabled == True,
+                        Document.archived == False,
+                    )
+                    document = db.session.scalar(document_stmt)
                     if dataset and document:
                         source = RetrievalSourceMetadata(
                             position=resource_number,
@@ -162,9 +157,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
         hit_callbacks: list[DatasetIndexToolCallbackHandler],
     ):
         with flask_app.app_context():
-            dataset = (
-                db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
-            )
+            stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id)
+            dataset = db.session.scalar(stmt)
             if not dataset:
                 return []
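
The hunks above replace the legacy 1.x Query API (db.session.query(...).where(...).all()/.first()) with 2.0-style select() statements executed through Session.scalars() / Session.scalar(). A minimal self-contained sketch of the same pattern, assuming SQLAlchemy 2.0; the Widget model and in-memory engine are illustrative stand-ins, not objects from this repository:

# Minimal sketch of the query-style migration applied in the hunks above.
from sqlalchemy import Integer, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Widget(Base):
    __tablename__ = "widgets"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    status: Mapped[str] = mapped_column(String(32))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Widget(status="completed"), Widget(status="draft")])
    session.commit()

    # Legacy 1.x Query API, the style removed above:
    legacy = session.query(Widget).where(Widget.status == "completed").all()

    # 2.0 style: build a Select statement, then execute it with scalars()/scalar().
    stmt = select(Widget).where(Widget.status == "completed")
    widgets = session.scalars(stmt).all()  # all matching Widget rows
    first = session.scalar(stmt)           # first matching Widget, or None

    assert len(legacy) == len(widgets) == 1 and first is not None

Because scalars()/scalar() already return ORM objects, the rewritten call sites can drop the wrapping parentheses and the chained .all()/.first() on a Query.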

View File

@@ -1,7 +1,5 @@
-from abc import abstractmethod
-from typing import Optional
+from abc import ABC, abstractmethod

-from msal_extensions.persistence import ABC  # type: ignore
 from pydantic import BaseModel, ConfigDict

 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@@ -14,7 +12,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
     description: str = "use this to retrieve a dataset. "
     tenant_id: str
     top_k: int = 4
-    score_threshold: Optional[float] = None
+    score_threshold: float | None = None
     hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
     return_resource: bool
     retriever_from: str
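
Two things change above: the accidental msal_extensions.persistence import is replaced by the real abc.ABC, and Optional[float] becomes the PEP 604 spelling float | None. A hedged sketch, assuming Python 3.10+ and Pydantic v2; the classes here are illustrative, not this repository's:

# PEP 604 unions are interchangeable with typing.Optional, including on model fields.
from abc import ABC, abstractmethod
from typing import Optional

from pydantic import BaseModel


class RetrieverBase(BaseModel, ABC):
    top_k: int = 4
    score_threshold: float | None = None      # new spelling
    legacy_threshold: Optional[float] = None  # old spelling, same runtime meaning

    @abstractmethod
    def _run(self, query: str) -> str: ...


class EchoRetriever(RetrieverBase):
    def _run(self, query: str) -> str:
        return query


print(float | None == Optional[float])                   # True
print(EchoRetriever(score_threshold=0.5).model_dump())   # both fields accept and serialize identically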

View File

@@ -1,6 +1,7 @@
-from typing import Any, Optional, cast
+from typing import Any, cast

 from pydantic import BaseModel, Field
+from sqlalchemy import select

 from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
 from core.rag.datasource.retrieval_service import RetrievalService
@@ -36,7 +37,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
     args_schema: type[BaseModel] = DatasetRetrieverToolInput
     description: str = "use this to retrieve a dataset. "
     dataset_id: str
-    user_id: Optional[str] = None
+    user_id: str | None = None
     retrieve_config: DatasetRetrieveConfigEntity
     inputs: dict
@@ -56,9 +57,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
         )

     def _run(self, query: str) -> str:
-        dataset = (
-            db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
-        )
+        dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
+        dataset = db.session.scalar(dataset_stmt)
         if not dataset:
             return ""
@@ -188,15 +188,12 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             for record in records:
                 segment = record.segment
                 dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
-                document = (
-                    db.session.query(DatasetDocument)  # type: ignore
-                    .where(
-                        DatasetDocument.id == segment.document_id,
-                        DatasetDocument.enabled == True,
-                        DatasetDocument.archived == False,
-                    )
-                    .first()
-                )
+                dataset_document_stmt = select(DatasetDocument).where(
+                    DatasetDocument.id == segment.document_id,
+                    DatasetDocument.enabled == True,
+                    DatasetDocument.archived == False,
+                )
+                document = db.session.scalar(dataset_document_stmt)  # type: ignore
                 if dataset and document:
                     source = RetrievalSourceMetadata(
                         dataset_id=dataset.id,

View File

@@ -1,5 +1,5 @@
 from collections.abc import Generator
-from typing import Any, Optional
+from typing import Any

 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -20,7 +20,7 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 class DatasetRetrieverTool(Tool):
-    def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
+    def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool):
         super().__init__(entity, runtime)

         self.retrieval_tool = retrieval_tool
@@ -87,9 +87,9 @@ class DatasetRetrieverTool(Tool):
     def get_runtime_parameters(
         self,
-        conversation_id: Optional[str] = None,
-        app_id: Optional[str] = None,
-        message_id: Optional[str] = None,
+        conversation_id: str | None = None,
+        app_id: str | None = None,
+        message_id: str | None = None,
     ) -> list[ToolParameter]:
         return [
             ToolParameter(
@@ -112,9 +112,9 @@ class DatasetRetrieverTool(Tool):
         self,
         user_id: str,
         tool_parameters: dict[str, Any],
-        conversation_id: Optional[str] = None,
-        app_id: Optional[str] = None,
-        message_id: Optional[str] = None,
+        conversation_id: str | None = None,
+        app_id: str | None = None,
+        message_id: str | None = None,
     ) -> Generator[ToolInvokeMessage, None, None]:
         """
         invoke dataset retriever tool

View File

@@ -3,7 +3,6 @@ from collections.abc import Generator
 from datetime import date, datetime
 from decimal import Decimal
 from mimetypes import guess_extension
-from typing import Optional
 from uuid import UUID

 import numpy as np
@@ -60,7 +59,7 @@ class ToolFileMessageTransformer:
         messages: Generator[ToolInvokeMessage, None, None],
         user_id: str,
         tenant_id: str,
-        conversation_id: Optional[str] = None,
+        conversation_id: str | None = None,
     ) -> Generator[ToolInvokeMessage, None, None]:
         """
         Transform tool message and handle file download
@@ -165,5 +164,5 @@ class ToolFileMessageTransformer:
             yield message

     @classmethod
-    def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
+    def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str:
         return f"/files/tools/{tool_file_id}{extension or '.bin'}"
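
The extension or '.bin' fallback above means a missing extension still yields a routable path. A standalone usage sketch of that one helper, with made-up IDs:

# Behaviour of the helper shown above, lifted out of the class for illustration.
def get_tool_file_url(tool_file_id: str, extension: str | None) -> str:
    return f"/files/tools/{tool_file_id}{extension or '.bin'}"


print(get_tool_file_url("6f1c", ".png"))  # /files/tools/6f1c.png
print(get_tool_file_url("6f1c", None))    # /files/tools/6f1c.bin
print(get_tool_file_url("6f1c", ""))      # /files/tools/6f1c.bin (empty string is falsy)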

View File

@@ -5,7 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models.
 """

 import json
-from typing import Optional, cast
+from typing import cast

 from core.model_manager import ModelManager
 from core.model_runtime.entities.llm_entities import LLMResult
@@ -51,7 +51,7 @@ class ModelInvocationUtils:
         if not schema:
             raise InvokeModelError("No model schema found")

-        max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
+        max_tokens: int | None = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)

         if max_tokens is None:
             return 2048

View File

@@ -2,7 +2,6 @@ import re
 from json import dumps as json_dumps
 from json import loads as json_loads
 from json.decoder import JSONDecodeError
-from typing import Optional

 from flask import request
 from requests import get
@@ -198,9 +197,9 @@ class ApiBasedToolSchemaParser:
         return bundles

     @staticmethod
-    def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
+    def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
         parameter = parameter or {}
-        typ: Optional[str] = None
+        typ: str | None = None

         if parameter.get("format") == "binary":
             return ToolParameter.ToolParameterType.FILE
@@ -242,7 +241,7 @@ class ApiBasedToolSchemaParser:
         return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)

     @staticmethod
-    def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
+    def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None):
         warning = warning or {}
         """
         parse swagger to openapi

View File

@@ -2,7 +2,7 @@ import base64
 import hashlib
 import logging
 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import Any

 from Crypto.Cipher import AES
 from Crypto.Random import get_random_bytes
@@ -28,7 +28,7 @@ class SystemOAuthEncrypter:
     using AES-CBC mode with a key derived from the application's SECRET_KEY.
     """

-    def __init__(self, secret_key: Optional[str] = None):
+    def __init__(self, secret_key: str | None = None):
         """
         Initialize the OAuth encrypter.
@@ -130,7 +130,7 @@ class SystemOAuthEncrypter:

 # Factory function for creating encrypter instances
-def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
+def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
     """
     Create an OAuth encrypter instance.
@@ -144,7 +144,7 @@ def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
 # Global encrypter instance (for backward compatibility)
-_oauth_encrypter: Optional[SystemOAuthEncrypter] = None
+_oauth_encrypter: SystemOAuthEncrypter | None = None


 def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
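
The class docstring above says tokens are encrypted with AES-CBC under a key derived from the application's SECRET_KEY; the implementation itself is outside this hunk. A hedged sketch of that general approach with pycryptodome, where the SHA-256 key derivation, random IV, and base64 layout are illustrative assumptions rather than this module's actual scheme:

# Illustrative AES-CBC round trip; NOT the repository's SystemOAuthEncrypter.
import base64
import hashlib

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad


def encrypt(plaintext: str, secret_key: str) -> str:
    key = hashlib.sha256(secret_key.encode()).digest()  # assumed 32-byte key derivation
    iv = get_random_bytes(16)
    cipher = AES.new(key, AES.MODE_CBC, iv)
    ciphertext = cipher.encrypt(pad(plaintext.encode(), AES.block_size))
    return base64.b64encode(iv + ciphertext).decode()    # assumed IV-prefixed layout


def decrypt(token: str, secret_key: str) -> str:
    key = hashlib.sha256(secret_key.encode()).digest()
    raw = base64.b64decode(token)
    iv, ciphertext = raw[:16], raw[16:]
    cipher = AES.new(key, AES.MODE_CBC, iv)
    return unpad(cipher.decrypt(ciphertext), AES.block_size).decode()


token = encrypt("oauth-client-secret", "app-secret-key")
print(decrypt(token, "app-secret-key"))  # oauth-client-secret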

View File

@@ -2,7 +2,7 @@ import mimetypes
 import re
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Optional, cast
+from typing import Any, cast
 from urllib.parse import unquote

 import chardet
@@ -27,7 +27,7 @@ def page_result(text: str, cursor: int, max_length: int) -> str:
     return text[cursor : cursor + max_length]


-def get_url(url: str, user_agent: Optional[str] = None) -> str:
+def get_url(url: str, user_agent: str | None = None) -> str:
     """Fetch URL and return the contents as a string."""
     headers = {
         "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"

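Only the header setup of get_url is visible in this hunk; the snippet below is a simplified stand-in for the user_agent override it suggests, using plain requests. The timeout, example URL, and the fetching logic are assumptions, not taken from the module:

# Simplified stand-in for get_url(); real fetching logic is not shown in the hunk.
import requests


def get_url(url: str, user_agent: str | None = None) -> str:
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
    }
    if user_agent:
        headers["User-Agent"] = user_agent  # caller-supplied UA wins over the default
    response = requests.get(url, headers=headers, timeout=(10, 30))
    response.raise_for_status()
    return response.text


print(get_url("https://example.com")[:60])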
View File

@@ -1,4 +1,5 @@
 import logging
+from functools import lru_cache
 from pathlib import Path
 from typing import Any
@@ -8,28 +9,25 @@ from yaml import YAMLError

 logger = logging.getLogger(__name__)


-def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
-    """
-    Safe loading a YAML file
-    :param file_path: the path of the YAML file
-    :param ignore_error:
-        if True, return default_value if error occurs and the error will be logged in debug level
-        if False, raise error if error occurs
-    :param default_value: the value returned when errors ignored
-    :return: an object of the YAML content
-    """
+def _load_yaml_file(*, file_path: str):
     if not file_path or not Path(file_path).exists():
-        if ignore_error:
-            return default_value
-        else:
-            raise FileNotFoundError(f"File not found: {file_path}")
+        raise FileNotFoundError(f"File not found: {file_path}")

     with open(file_path, encoding="utf-8") as yaml_file:
         try:
             yaml_content = yaml.safe_load(yaml_file)
-            return yaml_content or default_value
+            return yaml_content
         except Exception as e:
-            if ignore_error:
-                return default_value
-            else:
-                raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
+            raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
+
+
+@lru_cache(maxsize=128)
+def load_yaml_file_cached(file_path: str) -> Any:
+    """
+    Cached version of load_yaml_file for static configuration files.
+    Only use for files that don't change during runtime (e.g., position files)
+
+    :param file_path: the path of the YAML file
+    :return: an object of the YAML content
+    """
+    return _load_yaml_file(file_path=file_path)
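
load_yaml_file_cached above gets its caching from functools.lru_cache, keyed on the file_path string, which is why its docstring restricts it to files that never change at runtime. A self-contained sketch of that behaviour; the temp file and the simplified loader stand in for the real _load_yaml_file:

# Demonstrates lru_cache keyed on the path: the second call returns the cached object.
import tempfile
from functools import lru_cache
from pathlib import Path

import yaml


@lru_cache(maxsize=128)
def load_yaml_file_cached(file_path: str) -> dict:
    # simplified stand-in for the module's _load_yaml_file
    return yaml.safe_load(Path(file_path).read_text(encoding="utf-8")) or {}


with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write("position:\n  - google\n  - bing\n")
    path = f.name

first = load_yaml_file_cached(path)
second = load_yaml_file_cached(path)
assert first is second                     # same cached object, no second disk read
print(load_yaml_file_cached.cache_info())  # hits=1, misses=1

Because the cached value is returned by reference, every caller shares the same parsed object, which is acceptable for static position files but would be unsafe for YAML that callers mutate.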