mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 03:07:39 +08:00
Merge branch 'main' into jzh
This commit is contained in:
@ -346,89 +346,6 @@ class PublishedRagPipelineRunApi(Resource):
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_rag_pipeline
|
||||
# def post(self, pipeline: Pipeline, node_id: str):
|
||||
# """
|
||||
# Run rag pipeline datasource
|
||||
# """
|
||||
# # The role of the current user in the ta table must be admin, owner, or editor
|
||||
# if not current_user.has_edit_permission:
|
||||
# raise Forbidden()
|
||||
#
|
||||
# if not isinstance(current_user, Account):
|
||||
# raise Forbidden()
|
||||
#
|
||||
# parser = (reqparse.RequestParser()
|
||||
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
||||
# .add_argument("datasource_type", type=str, required=True, location="json")
|
||||
# )
|
||||
# args = parser.parse_args()
|
||||
#
|
||||
# job_id = args.get("job_id")
|
||||
# if job_id == None:
|
||||
# raise ValueError("missing job_id")
|
||||
# datasource_type = args.get("datasource_type")
|
||||
# if datasource_type == None:
|
||||
# raise ValueError("missing datasource_type")
|
||||
#
|
||||
# rag_pipeline_service = RagPipelineService()
|
||||
# result = rag_pipeline_service.run_datasource_workflow_node_status(
|
||||
# pipeline=pipeline,
|
||||
# node_id=node_id,
|
||||
# job_id=job_id,
|
||||
# account=current_user,
|
||||
# datasource_type=datasource_type,
|
||||
# is_published=True
|
||||
# )
|
||||
#
|
||||
# return result
|
||||
|
||||
|
||||
# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_rag_pipeline
|
||||
# def post(self, pipeline: Pipeline, node_id: str):
|
||||
# """
|
||||
# Run rag pipeline datasource
|
||||
# """
|
||||
# # The role of the current user in the ta table must be admin, owner, or editor
|
||||
# if not current_user.has_edit_permission:
|
||||
# raise Forbidden()
|
||||
#
|
||||
# if not isinstance(current_user, Account):
|
||||
# raise Forbidden()
|
||||
#
|
||||
# parser = (reqparse.RequestParser()
|
||||
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
||||
# .add_argument("datasource_type", type=str, required=True, location="json")
|
||||
# )
|
||||
# args = parser.parse_args()
|
||||
#
|
||||
# job_id = args.get("job_id")
|
||||
# if job_id == None:
|
||||
# raise ValueError("missing job_id")
|
||||
# datasource_type = args.get("datasource_type")
|
||||
# if datasource_type == None:
|
||||
# raise ValueError("missing datasource_type")
|
||||
#
|
||||
# rag_pipeline_service = RagPipelineService()
|
||||
# result = rag_pipeline_service.run_datasource_workflow_node_status(
|
||||
# pipeline=pipeline,
|
||||
# node_id=node_id,
|
||||
# job_id=job_id,
|
||||
# account=current_user,
|
||||
# datasource_type=datasource_type,
|
||||
# is_published=False
|
||||
# )
|
||||
#
|
||||
# return result
|
||||
#
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
|
||||
@ -7,7 +7,8 @@ import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, jsonify, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@ -33,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict
|
||||
action: str
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
payload["expiration_time"] = int(form.expiration_time.timestamp())
|
||||
@ -84,10 +90,7 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
@ -107,8 +110,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
selected_action_id=payload.action,
|
||||
form_data=payload.inputs,
|
||||
submission_user_id=current_user.id,
|
||||
)
|
||||
|
||||
|
||||
@ -7,7 +7,8 @@ import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@ -23,6 +24,12 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict
|
||||
action: str
|
||||
|
||||
|
||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_submit_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
@ -112,10 +119,7 @@ class HumanInputFormApi(Resource):
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
@ -135,8 +139,8 @@ class HumanInputFormApi(Resource):
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
selected_action_id=payload.action,
|
||||
form_data=payload.inputs,
|
||||
submission_end_user_id=None,
|
||||
# submission_end_user_id=_end_user.id,
|
||||
)
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Literal, Optional, TypedDict
|
||||
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
|
||||
|
||||
|
||||
class DatasourceApiEntity(BaseModel):
|
||||
@ -20,6 +20,23 @@ class DatasourceApiEntity(BaseModel):
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
|
||||
|
||||
class DatasourceProviderApiEntityDict(TypedDict):
|
||||
id: str
|
||||
author: str
|
||||
name: str
|
||||
plugin_id: str | None
|
||||
plugin_unique_identifier: str | None
|
||||
description: I18nObjectDict
|
||||
icon: str | dict
|
||||
label: I18nObjectDict
|
||||
type: str
|
||||
team_credentials: dict | None
|
||||
is_team_authorization: bool
|
||||
allow_delete: bool
|
||||
datasources: list[Any]
|
||||
labels: list[str]
|
||||
|
||||
|
||||
class DatasourceProviderApiEntity(BaseModel):
|
||||
id: str
|
||||
author: str
|
||||
@ -42,7 +59,7 @@ class DatasourceProviderApiEntity(BaseModel):
|
||||
def convert_none_to_empty_list(cls, v):
|
||||
return v if v is not None else []
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self) -> DatasourceProviderApiEntityDict:
|
||||
# -------------
|
||||
# overwrite datasource parameter types for temp fix
|
||||
datasources = jsonable_encoder(self.datasources)
|
||||
@ -53,7 +70,7 @@ class DatasourceProviderApiEntity(BaseModel):
|
||||
parameter["type"] = "files"
|
||||
# -------------
|
||||
|
||||
return {
|
||||
result: DatasourceProviderApiEntityDict = {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
"name": self.name,
|
||||
@ -69,3 +86,4 @@ class DatasourceProviderApiEntity(BaseModel):
|
||||
"datasources": datasources,
|
||||
"labels": self.labels,
|
||||
}
|
||||
return result
|
||||
|
||||
@ -146,7 +146,7 @@ def discover_protected_resource_metadata(
|
||||
return ProtectedResourceMetadata.model_validate(response.json())
|
||||
elif response.status_code == 404:
|
||||
continue # Try next URL
|
||||
except (RequestError, ValidationError):
|
||||
except (RequestError, ValidationError, json.JSONDecodeError):
|
||||
continue # Try next URL
|
||||
|
||||
return None
|
||||
@ -166,7 +166,7 @@ def discover_oauth_authorization_server_metadata(
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
elif response.status_code == 404:
|
||||
continue # Try next URL
|
||||
except (RequestError, ValidationError):
|
||||
except (RequestError, ValidationError, json.JSONDecodeError):
|
||||
continue # Try next URL
|
||||
|
||||
return None
|
||||
@ -276,7 +276,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
else:
|
||||
return False, ""
|
||||
return False, ""
|
||||
except RequestError:
|
||||
except (RequestError, json.JSONDecodeError, IndexError):
|
||||
# Not support resource discovery, fall back to well-known OAuth metadata
|
||||
return False, ""
|
||||
|
||||
|
||||
@ -61,27 +61,28 @@ class TokenBufferMemory:
|
||||
:param is_user_message: whether this is a user message
|
||||
:return: PromptMessage
|
||||
"""
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
app = self.conversation.app
|
||||
if not app:
|
||||
raise ValueError("App not found for conversation")
|
||||
match self.conversation.mode:
|
||||
case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
|
||||
app = self.conversation.app
|
||||
if not app:
|
||||
raise ValueError("App not found for conversation")
|
||||
|
||||
if not message.workflow_run_id:
|
||||
raise ValueError("Workflow run ID not found")
|
||||
if not message.workflow_run_id:
|
||||
raise ValueError("Workflow run ID not found")
|
||||
|
||||
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
|
||||
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
|
||||
)
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
else:
|
||||
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
||||
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
|
||||
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
|
||||
)
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
case _:
|
||||
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
||||
|
||||
detail = ImagePromptMessageContent.DETAIL.HIGH
|
||||
if file_extra_config and app_record:
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import Any
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from pydantic import BaseModel, model_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@ -19,6 +20,16 @@ from models.dataset import Dataset
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuaweiElasticsearchParamsDict(TypedDict, total=False):
|
||||
hosts: list[str]
|
||||
verify_certs: bool
|
||||
ssl_show_warn: bool
|
||||
request_timeout: int
|
||||
retry_on_timeout: bool
|
||||
max_retries: int
|
||||
basic_auth: tuple[str, str]
|
||||
|
||||
|
||||
def create_ssl_context() -> ssl.SSLContext:
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel):
|
||||
raise ValueError("config HOSTS is required")
|
||||
return values
|
||||
|
||||
def to_elasticsearch_params(self) -> dict[str, Any]:
|
||||
params = {
|
||||
"hosts": self.hosts.split(","),
|
||||
"verify_certs": False,
|
||||
"ssl_show_warn": False,
|
||||
"request_timeout": 30000,
|
||||
"retry_on_timeout": True,
|
||||
"max_retries": 10,
|
||||
}
|
||||
def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict:
|
||||
params = HuaweiElasticsearchParamsDict(
|
||||
hosts=self.hosts.split(","),
|
||||
verify_certs=False,
|
||||
ssl_show_warn=False,
|
||||
request_timeout=30000,
|
||||
retry_on_timeout=True,
|
||||
max_retries=10,
|
||||
)
|
||||
if self.username and self.password:
|
||||
params["basic_auth"] = (self.username, self.password)
|
||||
return params
|
||||
|
||||
@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers
|
||||
from opensearchpy.helpers import BulkIndexError
|
||||
from pydantic import BaseModel, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field"
|
||||
UGC_INDEX_PREFIX = "ugc_index"
|
||||
|
||||
|
||||
class LindormOpenSearchParamsDict(TypedDict, total=False):
|
||||
hosts: str | None
|
||||
use_ssl: bool
|
||||
pool_maxsize: int
|
||||
timeout: int
|
||||
http_auth: tuple[str, str]
|
||||
|
||||
|
||||
class LindormVectorStoreConfig(BaseModel):
|
||||
hosts: str | None
|
||||
username: str | None = None
|
||||
@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel):
|
||||
raise ValueError("config PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
|
||||
params: dict[str, Any] = {
|
||||
"hosts": self.hosts,
|
||||
"use_ssl": False,
|
||||
"pool_maxsize": 128,
|
||||
"timeout": 30,
|
||||
}
|
||||
def to_opensearch_params(self) -> LindormOpenSearchParamsDict:
|
||||
params = LindormOpenSearchParamsDict(
|
||||
hosts=self.hosts,
|
||||
use_ssl=False,
|
||||
pool_maxsize=128,
|
||||
timeout=30,
|
||||
)
|
||||
if self.username and self.password:
|
||||
params["http_auth"] = (self.username, self.password)
|
||||
return params
|
||||
|
||||
@ -6,6 +6,7 @@ from uuid import uuid4
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||
from opensearchpy.helpers import BulkIndexError
|
||||
from pydantic import BaseModel, model_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from configs.middleware.vdb.opensearch_config import AuthMethod
|
||||
@ -21,6 +22,20 @@ from models.dataset import Dataset
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _OpenSearchHostDict(TypedDict):
|
||||
host: str
|
||||
port: int
|
||||
|
||||
|
||||
class OpenSearchParamsDict(TypedDict, total=False):
|
||||
hosts: list[_OpenSearchHostDict]
|
||||
use_ssl: bool
|
||||
verify_certs: bool
|
||||
connection_class: type
|
||||
pool_maxsize: int
|
||||
http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth
|
||||
|
||||
|
||||
class OpenSearchConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel):
|
||||
service=self.aws_service, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
|
||||
params = {
|
||||
"hosts": [{"host": self.host, "port": self.port}],
|
||||
"use_ssl": self.secure,
|
||||
"verify_certs": self.verify_certs,
|
||||
"connection_class": Urllib3HttpConnection,
|
||||
"pool_maxsize": 20,
|
||||
}
|
||||
def to_opensearch_params(self) -> OpenSearchParamsDict:
|
||||
params = OpenSearchParamsDict(
|
||||
hosts=[{"host": self.host, "port": self.port}],
|
||||
use_ssl=self.secure,
|
||||
verify_certs=self.verify_certs,
|
||||
connection_class=Urllib3HttpConnection,
|
||||
pool_maxsize=20,
|
||||
)
|
||||
|
||||
if self.auth_method == "basic":
|
||||
logger.info("Using basic authentication for OpenSearch Vector DB")
|
||||
|
||||
@ -5,12 +5,30 @@ from typing import Any
|
||||
import pytz # type: ignore[import-untyped]
|
||||
from celery import Celery, Task
|
||||
from celery.schedules import crontab
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
def get_celery_ssl_options() -> dict[str, Any] | None:
|
||||
class _CelerySentinelKwargsDict(TypedDict):
|
||||
socket_timeout: float | None
|
||||
password: str | None
|
||||
|
||||
|
||||
class CelerySentinelTransportDict(TypedDict):
|
||||
master_name: str | None
|
||||
sentinel_kwargs: _CelerySentinelKwargsDict
|
||||
|
||||
|
||||
class CelerySSLOptionsDict(TypedDict):
|
||||
ssl_cert_reqs: int
|
||||
ssl_ca_certs: str | None
|
||||
ssl_certfile: str | None
|
||||
ssl_keyfile: str | None
|
||||
|
||||
|
||||
def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
|
||||
"""Get SSL configuration for Celery broker/backend connections."""
|
||||
# Only apply SSL if we're using Redis as broker/backend
|
||||
if not dify_config.BROKER_USE_SSL:
|
||||
@ -33,26 +51,24 @@ def get_celery_ssl_options() -> dict[str, Any] | None:
|
||||
|
||||
ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)
|
||||
|
||||
ssl_options = {
|
||||
"ssl_cert_reqs": ssl_cert_reqs,
|
||||
"ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
|
||||
"ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
|
||||
"ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
|
||||
}
|
||||
|
||||
return ssl_options
|
||||
return CelerySSLOptionsDict(
|
||||
ssl_cert_reqs=ssl_cert_reqs,
|
||||
ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS,
|
||||
ssl_certfile=dify_config.REDIS_SSL_CERTFILE,
|
||||
ssl_keyfile=dify_config.REDIS_SSL_KEYFILE,
|
||||
)
|
||||
|
||||
|
||||
def get_celery_broker_transport_options() -> dict[str, Any]:
|
||||
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
|
||||
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
|
||||
if dify_config.CELERY_USE_SENTINEL:
|
||||
return {
|
||||
"master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
|
||||
"sentinel_kwargs": {
|
||||
"socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
|
||||
"password": dify_config.CELERY_SENTINEL_PASSWORD,
|
||||
},
|
||||
}
|
||||
return CelerySentinelTransportDict(
|
||||
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
|
||||
sentinel_kwargs=_CelerySentinelKwargsDict(
|
||||
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
|
||||
password=dify_config.CELERY_SENTINEL_PASSWORD,
|
||||
),
|
||||
)
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
@ -674,28 +674,24 @@ class AppModelConfig(TypeBase):
|
||||
def suggested_questions_list(self) -> list[str]:
|
||||
return json.loads(self.suggested_questions) if self.suggested_questions else []
|
||||
|
||||
def _get_enabled_config(self, value: str | None, *, default_enabled: bool = False) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled})
|
||||
|
||||
@property
|
||||
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
|
||||
return cast(
|
||||
EnabledConfig,
|
||||
json.loads(self.suggested_questions_after_answer)
|
||||
if self.suggested_questions_after_answer
|
||||
else {"enabled": False},
|
||||
)
|
||||
return self._get_enabled_config(self.suggested_questions_after_answer)
|
||||
|
||||
@property
|
||||
def speech_to_text_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
|
||||
return self._get_enabled_config(self.speech_to_text)
|
||||
|
||||
@property
|
||||
def text_to_speech_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
|
||||
return self._get_enabled_config(self.text_to_speech)
|
||||
|
||||
@property
|
||||
def retriever_resource_dict(self) -> EnabledConfig:
|
||||
return cast(
|
||||
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
|
||||
)
|
||||
return self._get_enabled_config(self.retriever_resource, default_enabled=True)
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
@ -722,7 +718,7 @@ class AppModelConfig(TypeBase):
|
||||
|
||||
@property
|
||||
def more_like_this_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
|
||||
return self._get_enabled_config(self.more_like_this)
|
||||
|
||||
@property
|
||||
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
|
||||
@ -902,7 +898,7 @@ class InstalledApp(TypeBase):
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
|
||||
class TrialApp(Base):
|
||||
class TrialApp(TypeBase):
|
||||
__tablename__ = "trial_apps"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
|
||||
@ -911,18 +907,26 @@ class TrialApp(Base):
|
||||
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=gen_uuidv4_string)
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
|
||||
)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
insert_default=func.current_timestamp(),
|
||||
server_default=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class AccountTrialAppRecord(Base):
|
||||
class AccountTrialAppRecord(TypeBase):
|
||||
__tablename__ = "account_trial_app_records"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
|
||||
@ -930,11 +934,19 @@ class AccountTrialAppRecord(Base):
|
||||
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
|
||||
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
|
||||
)
|
||||
id = mapped_column(StringUUID, default=gen_uuidv4_string)
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
|
||||
)
|
||||
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
insert_default=func.current_timestamp(),
|
||||
server_default=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
|
||||
@ -66,12 +66,15 @@ def build_file_from_stored_mapping(
|
||||
record_id = resolve_file_record_id(mapping)
|
||||
transfer_method = FileTransferMethod.value_of(mapping["transfer_method"])
|
||||
|
||||
if transfer_method == FileTransferMethod.TOOL_FILE and record_id:
|
||||
mapping["tool_file_id"] = record_id
|
||||
elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id:
|
||||
mapping["upload_file_id"] = record_id
|
||||
elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id:
|
||||
mapping["datasource_file_id"] = record_id
|
||||
match transfer_method:
|
||||
case FileTransferMethod.TOOL_FILE if record_id:
|
||||
mapping["tool_file_id"] = record_id
|
||||
case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL if record_id:
|
||||
mapping["upload_file_id"] = record_id
|
||||
case FileTransferMethod.DATASOURCE_FILE if record_id:
|
||||
mapping["datasource_file_id"] = record_id
|
||||
case _:
|
||||
pass
|
||||
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
|
||||
remote_url = mapping.get("remote_url")
|
||||
|
||||
@ -467,61 +467,67 @@ class AppDslService:
|
||||
)
|
||||
|
||||
# Initialize app based on mode
|
||||
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_data = data.get("workflow")
|
||||
if not workflow_data or not isinstance(workflow_data, dict):
|
||||
raise ValueError("Missing workflow data for workflow/advanced chat app")
|
||||
match app_mode:
|
||||
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
|
||||
workflow_data = data.get("workflow")
|
||||
if not workflow_data or not isinstance(workflow_data, dict):
|
||||
raise ValueError("Missing workflow data for workflow/advanced chat app")
|
||||
|
||||
environment_variables_list = workflow_data.get("environment_variables", [])
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
conversation_variables_list = workflow_data.get("conversation_variables", [])
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
environment_variables_list = workflow_data.get("environment_variables", [])
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
conversation_variables_list = workflow_data.get("conversation_variables", [])
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj)
|
||||
for obj in conversation_variables_list
|
||||
]
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
|
||||
if current_draft_workflow:
|
||||
unique_hash = current_draft_workflow.unique_hash
|
||||
else:
|
||||
unique_hash = None
|
||||
graph = workflow_data.get("graph", {})
|
||||
for node in graph.get("nodes", []):
|
||||
if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
dataset_ids = node["data"].get("dataset_ids", [])
|
||||
node["data"]["dataset_ids"] = [
|
||||
decrypted_id
|
||||
for dataset_id in dataset_ids
|
||||
if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
|
||||
]
|
||||
workflow_service.sync_draft_workflow(
|
||||
app_model=app,
|
||||
graph=workflow_data.get("graph", {}),
|
||||
features=workflow_data.get("features", {}),
|
||||
unique_hash=unique_hash,
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}:
|
||||
# Initialize model config
|
||||
model_config = data.get("model_config")
|
||||
if not model_config or not isinstance(model_config, dict):
|
||||
raise ValueError("Missing model_config for chat/agent-chat/completion app")
|
||||
# Initialize or update model config
|
||||
if not app.app_model_config:
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id, created_by=account.id, updated_by=account.id
|
||||
).from_model_config_dict(cast(AppModelConfigDict, model_config))
|
||||
app_model_config.id = str(uuid4())
|
||||
app.app_model_config_id = app_model_config.id
|
||||
workflow_service = WorkflowService()
|
||||
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
|
||||
if current_draft_workflow:
|
||||
unique_hash = current_draft_workflow.unique_hash
|
||||
else:
|
||||
unique_hash = None
|
||||
graph = workflow_data.get("graph", {})
|
||||
for node in graph.get("nodes", []):
|
||||
if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
dataset_ids = node["data"].get("dataset_ids", [])
|
||||
node["data"]["dataset_ids"] = [
|
||||
decrypted_id
|
||||
for dataset_id in dataset_ids
|
||||
if (
|
||||
decrypted_id := self.decrypt_dataset_id(
|
||||
encrypted_data=dataset_id, tenant_id=app.tenant_id
|
||||
)
|
||||
)
|
||||
]
|
||||
workflow_service.sync_draft_workflow(
|
||||
app_model=app,
|
||||
graph=workflow_data.get("graph", {}),
|
||||
features=workflow_data.get("features", {}),
|
||||
unique_hash=unique_hash,
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
case AppMode.CHAT | AppMode.AGENT_CHAT | AppMode.COMPLETION:
|
||||
# Initialize model config
|
||||
model_config = data.get("model_config")
|
||||
if not model_config or not isinstance(model_config, dict):
|
||||
raise ValueError("Missing model_config for chat/agent-chat/completion app")
|
||||
# Initialize or update model config
|
||||
if not app.app_model_config:
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id, created_by=account.id, updated_by=account.id
|
||||
).from_model_config_dict(cast(AppModelConfigDict, model_config))
|
||||
app_model_config.id = str(uuid4())
|
||||
app.app_model_config_id = app_model_config.id
|
||||
|
||||
self._session.add(app_model_config)
|
||||
app_model_config_was_updated.send(app, app_model_config=app_model_config)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
self._session.add(app_model_config)
|
||||
app_model_config_was_updated.send(app, app_model_config=app_model_config)
|
||||
case _:
|
||||
raise ValueError("Invalid app mode")
|
||||
return app
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -132,8 +132,8 @@ class FileService:
|
||||
return file_size <= file_size_limit
|
||||
|
||||
def get_file_base64(self, file_id: str) -> str:
|
||||
upload_file = (
|
||||
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file = self._session_maker(expire_on_commit=False).scalar(
|
||||
select(UploadFile).where(UploadFile.id == file_id).limit(1)
|
||||
)
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
@ -178,7 +178,7 @@ class FileService:
|
||||
Return a short text preview extracted from a document file.
|
||||
"""
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
@ -200,7 +200,7 @@ class FileService:
|
||||
if not result:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
@ -220,7 +220,7 @@ class FileService:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
@ -231,7 +231,7 @@ class FileService:
|
||||
|
||||
def get_public_image_preview(self, file_id: str):
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
@ -247,7 +247,7 @@ class FileService:
|
||||
|
||||
def get_file_content(self, file_id: str) -> str:
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file: UploadFile | None = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from constants.languages import languages
|
||||
@ -8,16 +10,43 @@ from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
|
||||
from services.recommend_app.recommend_app_type import RecommendAppType
|
||||
|
||||
|
||||
class RecommendedAppItemDict(TypedDict):
|
||||
id: str
|
||||
app: App | None
|
||||
app_id: str
|
||||
description: Any
|
||||
copyright: Any
|
||||
privacy_policy: Any
|
||||
custom_disclaimer: str
|
||||
category: str
|
||||
position: int
|
||||
is_listed: bool
|
||||
|
||||
|
||||
class RecommendedAppsResultDict(TypedDict):
|
||||
recommended_apps: list[RecommendedAppItemDict]
|
||||
categories: list[str]
|
||||
|
||||
|
||||
class RecommendedAppDetailDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
icon: Any
|
||||
icon_background: str | None
|
||||
mode: str
|
||||
export_data: str
|
||||
|
||||
|
||||
class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
"""
|
||||
Retrieval recommended app from database
|
||||
"""
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
def get_recommended_apps_and_categories(self, language: str) -> RecommendedAppsResultDict:
|
||||
result = self.fetch_recommended_apps_from_db(language)
|
||||
return result
|
||||
|
||||
def get_recommend_app_detail(self, app_id: str):
|
||||
def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None:
|
||||
result = self.fetch_recommended_app_detail_from_db(app_id)
|
||||
return result
|
||||
|
||||
@ -25,7 +54,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return RecommendAppType.DATABASE
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_db(cls, language: str):
|
||||
def fetch_recommended_apps_from_db(cls, language: str) -> RecommendedAppsResultDict:
|
||||
"""
|
||||
Fetch recommended apps from db.
|
||||
:param language: language
|
||||
@ -41,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
).all()
|
||||
|
||||
categories = set()
|
||||
recommended_apps_result = []
|
||||
recommended_apps_result: list[RecommendedAppItemDict] = []
|
||||
for recommended_app in recommended_apps:
|
||||
app = recommended_app.app
|
||||
if not app or not app.is_public:
|
||||
@ -51,7 +80,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
if not site:
|
||||
continue
|
||||
|
||||
recommended_app_result = {
|
||||
recommended_app_result: RecommendedAppItemDict = {
|
||||
"id": recommended_app.id,
|
||||
"app": recommended_app.app,
|
||||
"app_id": recommended_app.app_id,
|
||||
@ -67,10 +96,10 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
|
||||
categories.add(recommended_app.category)
|
||||
|
||||
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
|
||||
return RecommendedAppsResultDict(recommended_apps=recommended_apps_result, categories=sorted(categories))
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None:
|
||||
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> RecommendedAppDetailDict | None:
|
||||
"""
|
||||
Fetch recommended app detail from db.
|
||||
:param app_id: App ID
|
||||
@ -89,11 +118,11 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
if not app_model or not app_model.is_public:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": app_model.id,
|
||||
"name": app_model.name,
|
||||
"icon": app_model.icon,
|
||||
"icon_background": app_model.icon_background,
|
||||
"mode": app_model.mode,
|
||||
"export_data": AppDslService.export_dsl(app_model=app_model),
|
||||
}
|
||||
return RecommendedAppDetailDict(
|
||||
id=app_model.id,
|
||||
name=app_model.name,
|
||||
icon=app_model.icon,
|
||||
icon_background=app_model.icon_background,
|
||||
mode=app_model.mode,
|
||||
export_data=AppDslService.export_dsl(app_model=app_model),
|
||||
)
|
||||
|
||||
@ -104,32 +104,32 @@ class WebhookService:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# Get webhook trigger
|
||||
webhook_trigger = (
|
||||
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).first()
|
||||
webhook_trigger = session.scalar(
|
||||
select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).limit(1)
|
||||
)
|
||||
if not webhook_trigger:
|
||||
raise ValueError(f"Webhook not found: {webhook_id}")
|
||||
|
||||
if is_debug:
|
||||
workflow = (
|
||||
session.query(Workflow)
|
||||
.filter(
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.app_id == webhook_trigger.app_id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.order_by(Workflow.created_at.desc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
# Check if the corresponding AppTrigger exists
|
||||
app_trigger = (
|
||||
session.query(AppTrigger)
|
||||
.filter(
|
||||
app_trigger = session.scalar(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.app_id == webhook_trigger.app_id,
|
||||
AppTrigger.node_id == webhook_trigger.node_id,
|
||||
AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not app_trigger:
|
||||
@ -146,14 +146,14 @@ class WebhookService:
|
||||
raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}")
|
||||
|
||||
# Get workflow
|
||||
workflow = (
|
||||
session.query(Workflow)
|
||||
.filter(
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.app_id == webhook_trigger.app_id,
|
||||
Workflow.version != Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.order_by(Workflow.created_at.desc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}")
|
||||
|
||||
@ -3,6 +3,7 @@ import time
|
||||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -26,43 +27,42 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
if action == "upgrade":
|
||||
dataset_documents = (
|
||||
session.query(DatasetDocument)
|
||||
.where(
|
||||
dataset_documents = session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
if dataset_documents:
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id.in_(dataset_documents_ids))
|
||||
.values(indexing_status="indexing")
|
||||
)
|
||||
session.commit()
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
try:
|
||||
# add from vector index
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if segments:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
@ -81,32 +81,36 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
|
||||
# clean keywords
|
||||
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
|
||||
index_processor.load(dataset, documents, with_keywords=False)
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="completed")
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="error", error=str(e))
|
||||
)
|
||||
session.commit()
|
||||
elif action == "update":
|
||||
dataset_documents = (
|
||||
session.query(DatasetDocument)
|
||||
.where(
|
||||
dataset_documents = session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
# add new index
|
||||
if dataset_documents:
|
||||
# update document status
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id.in_(dataset_documents_ids))
|
||||
.values(indexing_status="indexing")
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@ -116,15 +120,14 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
|
||||
for dataset_document in dataset_documents:
|
||||
# update from vector index
|
||||
try:
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if segments:
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
@ -173,13 +176,17 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
|
||||
index_processor.load(
|
||||
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
|
||||
)
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="completed")
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="error", error=str(e))
|
||||
)
|
||||
session.commit()
|
||||
else:
|
||||
|
||||
@ -862,6 +862,15 @@ class TestAuthOrchestration:
|
||||
result = discover_protected_resource_metadata(None, "https://api.example.com")
|
||||
assert result is None
|
||||
|
||||
# JSONDecodeError (non-JSON 200 response)
|
||||
mock_get.side_effect = None
|
||||
bad_json_response = Mock()
|
||||
bad_json_response.status_code = 200
|
||||
bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
|
||||
mock_get.return_value = bad_json_response
|
||||
result = discover_protected_resource_metadata(None, "https://api.example.com")
|
||||
assert result is None
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_oauth_authorization_server_metadata(self, mock_get):
|
||||
# Success
|
||||
@ -892,6 +901,14 @@ class TestAuthOrchestration:
|
||||
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
|
||||
assert result is None
|
||||
|
||||
# JSONDecodeError (non-JSON 200 response)
|
||||
bad_json_response = Mock()
|
||||
bad_json_response.status_code = 200
|
||||
bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
|
||||
mock_get.return_value = bad_json_response
|
||||
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
|
||||
assert result is None
|
||||
|
||||
def test_get_effective_scope(self):
|
||||
prm = ProtectedResourceMetadata(
|
||||
resource="https://api.example.com",
|
||||
@ -997,6 +1014,24 @@ class TestAuthOrchestration:
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is False
|
||||
|
||||
# Case 6: JSONDecodeError (non-JSON 200 response)
|
||||
mock_get.side_effect = None
|
||||
bad_json_res = Mock()
|
||||
bad_json_res.status_code = 200
|
||||
bad_json_res.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
|
||||
mock_get.return_value = bad_json_res
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is False
|
||||
assert url == ""
|
||||
|
||||
# Case 7: Empty authorization_servers array (IndexError)
|
||||
empty_res = Mock()
|
||||
empty_res.status_code = 200
|
||||
empty_res.json.return_value = {"authorization_servers": []}
|
||||
mock_get.return_value = empty_res
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is False
|
||||
|
||||
def test_discover_oauth_metadata(self):
|
||||
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
|
||||
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
|
||||
|
||||
@ -165,7 +165,7 @@ class TestFileService:
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.key = "test_key"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
mock_db_session.scalar.return_value = upload_file
|
||||
|
||||
with patch("services.file_service.storage") as mock_storage:
|
||||
mock_storage.load_once.return_value = b"test content"
|
||||
@ -178,7 +178,7 @@ class TestFileService:
|
||||
mock_storage.load_once.assert_called_once_with("test_key")
|
||||
|
||||
def test_get_file_base64_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.query().where().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
with pytest.raises(NotFound, match="File not found"):
|
||||
file_service.get_file_base64("non_existent")
|
||||
|
||||
@ -215,7 +215,7 @@ class TestFileService:
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.extension = "pdf"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
mock_db_session.scalar.return_value = upload_file
|
||||
|
||||
with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract:
|
||||
mock_extract.return_value = "Extracted text content"
|
||||
@ -227,7 +227,7 @@ class TestFileService:
|
||||
assert result == "Extracted text content"
|
||||
|
||||
def test_get_file_preview_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.query().where().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
with pytest.raises(NotFound, match="File not found"):
|
||||
file_service.get_file_preview("non_existent")
|
||||
|
||||
@ -235,7 +235,7 @@ class TestFileService:
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.extension = "exe"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
mock_db_session.scalar.return_value = upload_file
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
file_service.get_file_preview("file_id")
|
||||
|
||||
@ -246,7 +246,7 @@ class TestFileService:
|
||||
upload_file.extension = "jpg"
|
||||
upload_file.mime_type = "image/jpeg"
|
||||
upload_file.key = "key"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
mock_db_session.scalar.return_value = upload_file
|
||||
|
||||
with (
|
||||
patch("services.file_service.file_helpers.verify_image_signature") as mock_verify,
|
||||
@ -269,7 +269,7 @@ class TestFileService:
|
||||
file_service.get_image_preview("file_id", "ts", "nonce", "sign")
|
||||
|
||||
def test_get_image_preview_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.query().where().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
|
||||
mock_verify.return_value = True
|
||||
with pytest.raises(NotFound, match="File not found or signature is invalid"):
|
||||
@ -279,7 +279,7 @@ class TestFileService:
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.extension = "txt"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
mock_db_session.scalar.return_value = upload_file
|
||||
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
|
||||
mock_verify.return_value = True
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
@ -289,7 +289,7 @@ class TestFileService:
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.key = "key"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
mock_db_session.scalar.return_value = upload_file
|
||||
|
||||
with (
|
||||
patch("services.file_service.file_helpers.verify_file_signature") as mock_verify,
|
||||
@ -309,7 +309,7 @@ class TestFileService:
|
||||
file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign")
|
||||
|
||||
def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.query().where().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify:
|
||||
mock_verify.return_value = True
|
||||
with pytest.raises(NotFound, match="File not found or signature is invalid"):
|
||||
@ -321,7 +321,7 @@ class TestFileService:
|
||||
upload_file.extension = "png"
|
||||
upload_file.mime_type = "image/png"
|
||||
upload_file.key = "key"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
mock_db_session.scalar.return_value = upload_file
|
||||
|
||||
with patch("services.file_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = b"image content"
|
||||
@ -330,7 +330,7 @@ class TestFileService:
|
||||
assert mime == "image/png"
|
||||
|
||||
def test_get_public_image_preview_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.query().where().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
with pytest.raises(NotFound, match="File not found or signature is invalid"):
|
||||
file_service.get_public_image_preview("file_id")
|
||||
|
||||
@ -338,7 +338,7 @@ class TestFileService:
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.extension = "txt"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
mock_db_session.scalar.return_value = upload_file
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
file_service.get_public_image_preview("file_id")
|
||||
|
||||
@ -346,7 +346,7 @@ class TestFileService:
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "file_id"
|
||||
upload_file.key = "key"
|
||||
mock_db_session.query().where().first.return_value = upload_file
|
||||
mock_db_session.scalar.return_value = upload_file
|
||||
|
||||
with patch("services.file_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = b"hello world"
|
||||
@ -354,7 +354,7 @@ class TestFileService:
|
||||
assert result == "hello world"
|
||||
|
||||
def test_get_file_content_not_found(self, file_service, mock_db_session):
|
||||
mock_db_session.query().where().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
with pytest.raises(NotFound, match="File not found"):
|
||||
file_service.get_file_content("file_id")
|
||||
|
||||
|
||||
@ -657,7 +657,7 @@ def _app(**kwargs: Any) -> App:
|
||||
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.return_value = _FakeQuery(None)
|
||||
fake_session.scalar.return_value = None
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
@ -671,7 +671,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_foun
|
||||
# Arrange
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(None)]
|
||||
fake_session.scalar.side_effect = [webhook_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
@ -686,7 +686,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_lim
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
@ -701,7 +701,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
@ -714,7 +714,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(m
|
||||
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
|
||||
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(None)]
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act / Assert
|
||||
@ -732,7 +732,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mod
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(workflow)]
|
||||
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
@ -751,7 +751,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(mo
|
||||
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(workflow)]
|
||||
fake_session.scalar.side_effect = [webhook_trigger, workflow]
|
||||
_patch_session(monkeypatch, fake_session)
|
||||
|
||||
# Act
|
||||
|
||||
Reference in New Issue
Block a user