mirror of
https://github.com/langgenius/dify.git
synced 2026-02-12 22:35:46 +08:00
Compare commits
9 Commits
review-mys
...
deploy/dev
| Author | SHA1 | Date | |
|---|---|---|---|
| 4a7edd2ae0 | |||
| b3faa34a03 | |||
| 5c05f09e23 | |||
| ff8d7855b1 | |||
| 02206cc64b | |||
| cba157038f | |||
| d606de26f1 | |||
| 49f46a05e7 | |||
| 30b73f2765 |
@ -136,7 +136,6 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
|
||||
core.workflow.nodes.llm.node -> core.tools.signature
|
||||
core.workflow.nodes.template_transform.template_transform_node -> configs
|
||||
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_manager
|
||||
|
||||
@ -38,6 +38,7 @@ from . import (
|
||||
extension,
|
||||
feature,
|
||||
init_validate,
|
||||
notification,
|
||||
ping,
|
||||
setup,
|
||||
spec,
|
||||
@ -182,6 +183,7 @@ __all__ = [
|
||||
"model_config",
|
||||
"model_providers",
|
||||
"models",
|
||||
"notification",
|
||||
"oauth",
|
||||
"oauth_server",
|
||||
"ops_trace",
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
@ -6,7 +8,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@ -16,6 +18,7 @@ from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@ -277,3 +280,113 @@ class DeleteExploreBannerApi(Resource):
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class SaveNotificationContentPayload(BaseModel):
|
||||
content: str = Field(...)
|
||||
|
||||
|
||||
class SaveNotificationUserPayload(BaseModel):
|
||||
user_email: list[str] = Field(...)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
SaveNotificationContentPayload.__name__,
|
||||
SaveNotificationContentPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
SaveNotificationUserPayload.__name__,
|
||||
SaveNotificationUserPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/admin/save_notification_content")
|
||||
class SaveNotificationContentApi(Resource):
|
||||
@console_ns.doc("save_notification_content")
|
||||
@console_ns.doc(description="Save a notification content")
|
||||
@console_ns.expect(console_ns.models[SaveNotificationContentPayload.__name__])
|
||||
@console_ns.response(200, "Notification content saved successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = SaveNotificationContentPayload.model_validate(console_ns.payload)
|
||||
BillingService.save_notification_content(payload.content)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/save_notification_user")
|
||||
class SaveNotificationUserApi(Resource):
|
||||
@console_ns.doc("save_notification_user")
|
||||
@console_ns.doc(description="Save notification users via JSON body or file upload. "
|
||||
"JSON: {\"user_email\": [\"a@example.com\", ...]}. "
|
||||
"File: multipart/form-data with a 'file' field (CSV or TXT, one email per line).")
|
||||
@console_ns.response(200, "Notification users saved successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
# Determine input mode: file upload or JSON body
|
||||
if "file" in request.files:
|
||||
emails = self._parse_emails_from_file()
|
||||
else:
|
||||
payload = SaveNotificationUserPayload.model_validate(console_ns.payload)
|
||||
emails = payload.user_email
|
||||
|
||||
if not emails:
|
||||
raise BadRequest("No valid email addresses provided.")
|
||||
|
||||
# Use batch API for bulk insert (chunks of 1000 per request to billing service)
|
||||
result = BillingService.save_notification_users_batch(emails)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"total": len(emails),
|
||||
"succeeded": result["succeeded"],
|
||||
"failed_chunks": result["failed_chunks"],
|
||||
}, 200
|
||||
|
||||
@staticmethod
|
||||
def _parse_emails_from_file() -> list[str]:
|
||||
"""Parse email addresses from an uploaded CSV or TXT file."""
|
||||
file = request.files["file"]
|
||||
|
||||
if not file.filename:
|
||||
raise BadRequest("Uploaded file has no filename.")
|
||||
|
||||
filename_lower = file.filename.lower()
|
||||
if not filename_lower.endswith((".csv", ".txt")):
|
||||
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
content = file.read().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
file.seek(0)
|
||||
content = file.read().decode("gbk")
|
||||
except UnicodeDecodeError:
|
||||
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
|
||||
|
||||
emails: list[str] = []
|
||||
if filename_lower.endswith(".csv"):
|
||||
reader = csv.reader(io.StringIO(content))
|
||||
for row in reader:
|
||||
for cell in row:
|
||||
cell = cell.strip()
|
||||
emails.append(cell)
|
||||
else:
|
||||
# TXT file: one email per line
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
emails.append(line)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
unique_emails: list[str] = []
|
||||
for email in emails:
|
||||
email_lower = email.lower()
|
||||
if email_lower not in seen:
|
||||
seen.add(email_lower)
|
||||
unique_emails.append(email)
|
||||
|
||||
return unique_emails
|
||||
26
api/controllers/console/notification.py
Normal file
26
api/controllers/console/notification.py
Normal file
@ -0,0 +1,26 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@console_ns.route("/notification")
|
||||
class NotificationApi(Resource):
|
||||
@console_ns.doc("get_notification")
|
||||
@console_ns.doc(description="Get notification for the current user")
|
||||
@console_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
401: "Unauthorized",
|
||||
}
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
notification = BillingService.read_notification(current_user.email)
|
||||
return notification
|
||||
@ -47,6 +47,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
template_transform_max_output_length: int | None = None,
|
||||
http_request_http_client: HttpClientProtocol | None = None,
|
||||
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
http_request_file_manager: FileManagerProtocol | None = None,
|
||||
@ -68,6 +69,9 @@ class DifyNodeFactory(NodeFactory):
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = (
|
||||
template_transform_max_output_length or dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
)
|
||||
self._http_request_http_client = http_request_http_client or ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
|
||||
self._http_request_file_manager = http_request_file_manager or file_manager
|
||||
@ -122,6 +126,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
max_output_length=self._template_transform_max_output_length,
|
||||
)
|
||||
|
||||
if node_type == NodeType.HTTP_REQUEST:
|
||||
|
||||
@ -33,18 +33,6 @@ class SortOrder(StrEnum):
|
||||
|
||||
|
||||
class MyScaleVector(BaseVector):
|
||||
_METADATA_KEY_WHITELIST = {
|
||||
"annotation_id",
|
||||
"app_id",
|
||||
"batch",
|
||||
"dataset_id",
|
||||
"doc_hash",
|
||||
"doc_id",
|
||||
"document_id",
|
||||
"lang",
|
||||
"source",
|
||||
}
|
||||
|
||||
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._config = config
|
||||
@ -57,17 +45,10 @@ class MyScaleVector(BaseVector):
|
||||
password=config.password,
|
||||
)
|
||||
self._client.command("SET allow_experimental_object_type=1")
|
||||
self._qualified_table = f"{self._config.database}.{self._collection_name}"
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.MYSCALE
|
||||
|
||||
@classmethod
|
||||
def _validate_metadata_key(cls, key: str) -> str:
|
||||
if key not in cls._METADATA_KEY_WHITELIST:
|
||||
raise ValueError(f"Unsupported metadata key: {key!r}")
|
||||
return key
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
@ -78,7 +59,7 @@ class MyScaleVector(BaseVector):
|
||||
self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
|
||||
fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
|
||||
sql = f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._qualified_table}(
|
||||
CREATE TABLE IF NOT EXISTS {self._config.database}.{self._collection_name}(
|
||||
id String,
|
||||
text String,
|
||||
vector Array(Float32),
|
||||
@ -93,98 +74,73 @@ class MyScaleVector(BaseVector):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
ids = []
|
||||
columns = ["id", "text", "vector", "metadata"]
|
||||
rows = []
|
||||
values = []
|
||||
for i, doc in enumerate(documents):
|
||||
if doc.metadata is not None:
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
rows.append(
|
||||
(
|
||||
doc_id,
|
||||
doc.page_content,
|
||||
embeddings[i],
|
||||
json.dumps(doc.metadata or {}),
|
||||
)
|
||||
row = (
|
||||
doc_id,
|
||||
self.escape_str(doc.page_content),
|
||||
embeddings[i],
|
||||
json.dumps(doc.metadata) if doc.metadata else {},
|
||||
)
|
||||
values.append(str(row))
|
||||
ids.append(doc_id)
|
||||
if rows:
|
||||
self._client.insert(self._qualified_table, rows, column_names=columns)
|
||||
sql = f"""
|
||||
INSERT INTO {self._config.database}.{self._collection_name}
|
||||
({",".join(columns)}) VALUES {",".join(values)}
|
||||
"""
|
||||
self._client.command(sql)
|
||||
return ids
|
||||
|
||||
@staticmethod
|
||||
def escape_str(value: Any) -> str:
|
||||
return "".join(" " if c in {"\\", "'"} else c for c in str(value))
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
results = self._client.query(
|
||||
f"SELECT id FROM {self._qualified_table} WHERE id = %(id)s LIMIT 1",
|
||||
parameters={"id": id},
|
||||
)
|
||||
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
|
||||
return results.row_count > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
placeholders, params = self._build_in_params("id", ids)
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._qualified_table} WHERE id IN ({placeholders})",
|
||||
parameters=params,
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
||||
)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
safe_key = self._validate_metadata_key(key)
|
||||
rows = self._client.query(
|
||||
f"SELECT DISTINCT id FROM {self._qualified_table} WHERE metadata.{safe_key} = %(value)s",
|
||||
parameters={"value": value},
|
||||
f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
|
||||
).result_rows
|
||||
return [row[0] for row in rows]
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
safe_key = self._validate_metadata_key(key)
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._qualified_table} WHERE metadata.{safe_key} = %(value)s",
|
||||
parameters={"value": value},
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
|
||||
)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self._search(
|
||||
"TextSearch('enable_nlq=false')(text, %(query)s)",
|
||||
SortOrder.DESC,
|
||||
parameters={"query": query},
|
||||
**kwargs,
|
||||
)
|
||||
return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _build_in_params(prefix: str, values: list[str]) -> tuple[str, dict[str, str]]:
|
||||
params: dict[str, str] = {}
|
||||
placeholders = []
|
||||
for i, value in enumerate(values):
|
||||
name = f"{prefix}_{i}"
|
||||
placeholders.append(f"%({name})s")
|
||||
params[name] = value
|
||||
return ", ".join(placeholders), params
|
||||
|
||||
def _search(
|
||||
self,
|
||||
dist: str,
|
||||
order: SortOrder,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
where_clauses = []
|
||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0:
|
||||
where_clauses.append(f"dist < {1 - score_threshold}")
|
||||
where_str = (
|
||||
f"WHERE dist < {1 - score_threshold}"
|
||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
||||
else ""
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
query_params = dict(parameters or {})
|
||||
if document_ids_filter:
|
||||
placeholders, params = self._build_in_params("document_id", document_ids_filter)
|
||||
where_clauses.append(f"metadata['document_id'] IN ({placeholders})")
|
||||
query_params.update(params)
|
||||
where_str = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
|
||||
sql = f"""
|
||||
SELECT text, vector, metadata, {dist} as dist FROM {self._qualified_table}
|
||||
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||
"""
|
||||
try:
|
||||
@ -194,14 +150,14 @@ class MyScaleVector(BaseVector):
|
||||
vector=r["vector"],
|
||||
metadata=r["metadata"],
|
||||
)
|
||||
for r in self._client.query(sql, parameters=query_params).named_results()
|
||||
for r in self._client.query(sql).named_results()
|
||||
]
|
||||
except Exception:
|
||||
logger.exception("Vector search operation failed")
|
||||
return []
|
||||
|
||||
def delete(self):
|
||||
self._client.command(f"DROP TABLE IF EXISTS {self._qualified_table}")
|
||||
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")
|
||||
|
||||
|
||||
class MyScaleVectorFactory(AbstractVectorFactory):
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@ -16,12 +15,13 @@ if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000
|
||||
|
||||
|
||||
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
_template_renderer: Jinja2TemplateRenderer
|
||||
_max_output_length: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -31,6 +31,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
max_output_length: int | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@ -40,6 +41,10 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
if max_output_length is not None and max_output_length <= 0:
|
||||
raise ValueError("max_output_length must be a positive integer")
|
||||
self._max_output_length = max_output_length or DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
@ -69,11 +74,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
except TemplateRenderError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
if len(rendered) > self._max_output_length:
|
||||
return NodeRunResult(
|
||||
inputs=variables,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters",
|
||||
error=f"Output length exceeds {self._max_output_length} characters",
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
|
||||
@ -393,3 +393,35 @@ class BillingService:
|
||||
for item in data:
|
||||
tenant_whitelist.append(item["tenant_id"])
|
||||
return tenant_whitelist
|
||||
|
||||
@classmethod
|
||||
def read_notification(cls, user_email: str):
|
||||
params = {"user_email": user_email}
|
||||
return cls._send_request("GET", "/notification/read", params=params)
|
||||
|
||||
@classmethod
|
||||
def save_notification_user(cls, user_email: str):
|
||||
json = {"user_email": user_email}
|
||||
return cls._send_request("POST", "/notification/new-notification-user", json=json)
|
||||
|
||||
@classmethod
|
||||
def save_notification_users_batch(cls, user_emails: list[str]) -> dict:
|
||||
"""Batch save notification users in chunks of 1000."""
|
||||
chunk_size = 1000
|
||||
total_succeeded = 0
|
||||
failed_chunks: list[dict] = []
|
||||
|
||||
for i in range(0, len(user_emails), chunk_size):
|
||||
chunk = user_emails[i : i + chunk_size]
|
||||
try:
|
||||
resp = cls._send_request("POST", "/notification/batch-notification-users", json={"user_emails": chunk})
|
||||
total_succeeded += resp.get("count", len(chunk))
|
||||
except Exception as e:
|
||||
failed_chunks.append({"offset": i, "count": len(chunk), "error": str(e)})
|
||||
|
||||
return {"succeeded": total_succeeded, "failed_chunks": failed_chunks}
|
||||
|
||||
@classmethod
|
||||
def save_notification_content(cls, content: str):
|
||||
json = {"content": content}
|
||||
return cls._send_request("POST", "/notification/new-notification", json=json)
|
||||
@ -1,6 +1,7 @@
|
||||
from flask_login import current_user
|
||||
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from services.account_service import TenantService
|
||||
@ -53,7 +54,12 @@ class WorkspaceService:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
|
||||
if paid_pool:
|
||||
# if the tenant is not on the sandbox plan and the paid pool is not full, use the paid pool
|
||||
if (
|
||||
feature.billing.subscription.plan != CloudPlan.SANDBOX
|
||||
and paid_pool is not None
|
||||
and paid_pool.quota_limit > paid_pool.quota_used
|
||||
):
|
||||
tenant_info["trial_credits"] = paid_pool.quota_limit
|
||||
tenant_info["trial_credits_used"] = paid_pool.quota_used
|
||||
else:
|
||||
|
||||
@ -217,7 +217,6 @@ class TestTemplateTransformNode:
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
|
||||
def test_run_output_length_exceeds_limit(
|
||||
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
):
|
||||
@ -231,6 +230,7 @@ class TestTemplateTransformNode:
|
||||
graph_init_params=graph_init_params,
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
max_output_length=10,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
@ -2031,9 +2031,6 @@ describe('CommonCreateModal', () => {
|
||||
expect(mockCreateBuilder).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// Flush pending state updates from createBuilder promise resolution
|
||||
await act(async () => {})
|
||||
|
||||
const input = screen.getByTestId('form-field-webhook_url')
|
||||
fireEvent.change(input, { target: { value: 'test' } })
|
||||
|
||||
|
||||
@ -613,11 +613,6 @@ describe('UpdateDSLModal', () => {
|
||||
expect(importButton).not.toBeDisabled()
|
||||
})
|
||||
|
||||
// Flush the FileReader microtask to ensure fileContent is set
|
||||
await act(async () => {
|
||||
await new Promise<void>(resolve => queueMicrotask(resolve))
|
||||
})
|
||||
|
||||
const importButton = screen.getByText('common.overwriteAndImport')
|
||||
fireEvent.click(importButton)
|
||||
|
||||
@ -766,8 +761,6 @@ describe('UpdateDSLModal', () => {
|
||||
})
|
||||
|
||||
it('should call importDSLConfirm when confirm button is clicked in error modal', async () => {
|
||||
vi.useFakeTimers({ shouldAdvanceTime: true })
|
||||
|
||||
mockImportDSL.mockResolvedValue({
|
||||
id: 'import-id',
|
||||
status: DSLImportStatus.PENDING,
|
||||
@ -785,27 +778,20 @@ describe('UpdateDSLModal', () => {
|
||||
|
||||
const fileInput = screen.getByTestId('file-input')
|
||||
const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' })
|
||||
fireEvent.change(fileInput, { target: { files: [file] } })
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.change(fileInput, { target: { files: [file] } })
|
||||
// Flush microtasks scheduled by the FileReader mock (which uses queueMicrotask)
|
||||
await new Promise<void>(resolve => queueMicrotask(resolve))
|
||||
await waitFor(() => {
|
||||
const importButton = screen.getByText('common.overwriteAndImport')
|
||||
expect(importButton).not.toBeDisabled()
|
||||
})
|
||||
|
||||
const importButton = screen.getByText('common.overwriteAndImport')
|
||||
expect(importButton).not.toBeDisabled()
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(importButton)
|
||||
// Flush the promise resolution from mockImportDSL
|
||||
await Promise.resolve()
|
||||
// Advance past the 300ms setTimeout in the component
|
||||
await vi.advanceTimersByTimeAsync(350)
|
||||
})
|
||||
fireEvent.click(importButton)
|
||||
|
||||
// Wait for error modal
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
|
||||
})
|
||||
}, { timeout: 500 })
|
||||
|
||||
// Click confirm button
|
||||
const confirmButton = screen.getByText('newApp.Confirm')
|
||||
@ -814,8 +800,6 @@ describe('UpdateDSLModal', () => {
|
||||
await waitFor(() => {
|
||||
expect(mockImportDSLConfirm).toHaveBeenCalledWith('import-id')
|
||||
})
|
||||
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should show success notification after confirm completes', async () => {
|
||||
@ -1024,8 +1008,6 @@ describe('UpdateDSLModal', () => {
|
||||
})
|
||||
|
||||
it('should call handleCheckPluginDependencies after confirm', async () => {
|
||||
vi.useFakeTimers({ shouldAdvanceTime: true })
|
||||
|
||||
mockImportDSL.mockResolvedValue({
|
||||
id: 'import-id',
|
||||
status: DSLImportStatus.PENDING,
|
||||
@ -1043,27 +1025,19 @@ describe('UpdateDSLModal', () => {
|
||||
|
||||
const fileInput = screen.getByTestId('file-input')
|
||||
const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' })
|
||||
fireEvent.change(fileInput, { target: { files: [file] } })
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.change(fileInput, { target: { files: [file] } })
|
||||
// Flush microtasks scheduled by the FileReader mock (which uses queueMicrotask)
|
||||
await new Promise<void>(resolve => queueMicrotask(resolve))
|
||||
await waitFor(() => {
|
||||
const importButton = screen.getByText('common.overwriteAndImport')
|
||||
expect(importButton).not.toBeDisabled()
|
||||
})
|
||||
|
||||
const importButton = screen.getByText('common.overwriteAndImport')
|
||||
expect(importButton).not.toBeDisabled()
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(importButton)
|
||||
// Flush the promise resolution from mockImportDSL
|
||||
await Promise.resolve()
|
||||
// Advance past the 300ms setTimeout in the component
|
||||
await vi.advanceTimersByTimeAsync(350)
|
||||
})
|
||||
fireEvent.click(importButton)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
|
||||
})
|
||||
}, { timeout: 500 })
|
||||
|
||||
const confirmButton = screen.getByText('newApp.Confirm')
|
||||
fireEvent.click(confirmButton)
|
||||
@ -1071,8 +1045,6 @@ describe('UpdateDSLModal', () => {
|
||||
await waitFor(() => {
|
||||
expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('test-pipeline-id', true)
|
||||
})
|
||||
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should handle undefined imported_dsl_version and current_dsl_version', async () => {
|
||||
|
||||
@ -103,22 +103,15 @@ const MCPDetailContent: FC<Props> = ({
|
||||
return
|
||||
if (!detail)
|
||||
return
|
||||
try {
|
||||
const res = await authorizeMcp({
|
||||
provider_id: detail.id,
|
||||
})
|
||||
if (res.result === 'success')
|
||||
handleUpdateTools()
|
||||
const res = await authorizeMcp({
|
||||
provider_id: detail.id,
|
||||
})
|
||||
if (res.result === 'success')
|
||||
handleUpdateTools()
|
||||
|
||||
else if (res.authorization_url)
|
||||
openOAuthPopup(res.authorization_url, handleOAuthCallback)
|
||||
}
|
||||
catch {
|
||||
// On authorization error, refresh the parent component state
|
||||
// to update the connection status indicator
|
||||
onUpdate()
|
||||
}
|
||||
}, [onFirstCreate, isCurrentWorkspaceManager, detail, authorizeMcp, handleUpdateTools, handleOAuthCallback, onUpdate])
|
||||
else if (res.authorization_url)
|
||||
openOAuthPopup(res.authorization_url, handleOAuthCallback)
|
||||
}, [onFirstCreate, isCurrentWorkspaceManager, detail, authorizeMcp, handleUpdateTools, handleOAuthCallback])
|
||||
|
||||
const handleUpdate = useCallback(async (data: any) => {
|
||||
if (!detail)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import type { ToolNodeType, ToolVarInputs } from './types'
|
||||
import type { InputVar } from '@/app/components/workflow/types'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import { capitalize } from 'es-toolkit/string'
|
||||
import { produce } from 'immer'
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
@ -26,12 +25,6 @@ import {
|
||||
} from '@/service/use-tools'
|
||||
import { canFindTool } from '@/utils'
|
||||
import { useWorkflowStore } from '../../store'
|
||||
import { normalizeJsonSchemaType } from './output-schema-utils'
|
||||
|
||||
const formatDisplayType = (output: Record<string, unknown>): string => {
|
||||
const normalizedType = normalizeJsonSchemaType(output) || 'Unknown'
|
||||
return capitalize(normalizedType)
|
||||
}
|
||||
|
||||
const useConfig = (id: string, payload: ToolNodeType) => {
|
||||
const workflowStore = useWorkflowStore()
|
||||
@ -254,13 +247,20 @@ const useConfig = (id: string, payload: ToolNodeType) => {
|
||||
})
|
||||
}
|
||||
else {
|
||||
const normalizedType = normalizeJsonSchemaType(output)
|
||||
res.push({
|
||||
name: outputKey,
|
||||
type:
|
||||
normalizedType === 'array'
|
||||
? `Array[${output.items ? formatDisplayType(output.items) : 'Unknown'}]`
|
||||
: formatDisplayType(output),
|
||||
output.type === 'array'
|
||||
? `Array[${output.items?.type
|
||||
? output.items.type.slice(0, 1).toLocaleUpperCase()
|
||||
+ output.items.type.slice(1)
|
||||
: 'Unknown'
|
||||
}]`
|
||||
: `${output.type
|
||||
? output.type.slice(0, 1).toLocaleUpperCase()
|
||||
+ output.type.slice(1)
|
||||
: 'Unknown'
|
||||
}`,
|
||||
description: output.description,
|
||||
})
|
||||
}
|
||||
|
||||
@ -46,7 +46,7 @@
|
||||
"uglify-embed": "node ./bin/uglify-embed",
|
||||
"i18n:check": "tsx ./scripts/check-i18n.js",
|
||||
"test": "vitest run",
|
||||
"test:coverage": "vitest run --coverage --reporter=dot --silent=passed-only",
|
||||
"test:coverage": "vitest run --coverage",
|
||||
"test:watch": "vitest --watch",
|
||||
"analyze-component": "node ./scripts/analyze-component.js",
|
||||
"refactor-component": "node ./scripts/refactor-component.js",
|
||||
@ -233,8 +233,7 @@
|
||||
"uglify-js": "3.19.3",
|
||||
"vite": "7.3.1",
|
||||
"vite-tsconfig-paths": "6.0.4",
|
||||
"vitest": "4.0.17",
|
||||
"vitest-canvas-mock": "1.1.3"
|
||||
"vitest": "4.0.17"
|
||||
},
|
||||
"pnpm": {
|
||||
"overrides": {
|
||||
|
||||
26
web/pnpm-lock.yaml
generated
26
web/pnpm-lock.yaml
generated
@ -579,9 +579,6 @@ importers:
|
||||
vitest:
|
||||
specifier: 4.0.17
|
||||
version: 4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)
|
||||
vitest-canvas-mock:
|
||||
specifier: 1.1.3
|
||||
version: 1.1.3(vitest@4.0.17)
|
||||
|
||||
packages:
|
||||
|
||||
@ -4005,9 +4002,6 @@ packages:
|
||||
engines: {node: '>=4'}
|
||||
hasBin: true
|
||||
|
||||
cssfontparser@1.2.1:
|
||||
resolution: {integrity: sha512-6tun4LoZnj7VN6YeegOVb67KBX/7JJsqvj+pv3ZA7F878/eN33AbGa5b/S/wXxS/tcp8nc40xRUrsPlxIyNUPg==}
|
||||
|
||||
cssstyle@5.3.7:
|
||||
resolution: {integrity: sha512-7D2EPVltRrsTkhpQmksIu+LxeWAIEk6wRDMJ1qljlv+CKHJM+cJLlfhWIzNA44eAsHXSNe3+vO6DW1yCYx8SuQ==}
|
||||
engines: {node: '>=20'}
|
||||
@ -5757,9 +5751,6 @@ packages:
|
||||
monaco-editor@0.55.1:
|
||||
resolution: {integrity: sha512-jz4x+TJNFHwHtwuV9vA9rMujcZRb0CEilTEwG2rRSpe/A7Jdkuj8xPKttCgOh+v/lkHy7HsZ64oj+q3xoAFl9A==}
|
||||
|
||||
moo-color@1.0.3:
|
||||
resolution: {integrity: sha512-i/+ZKXMDf6aqYtBhuOcej71YSlbjT3wCO/4H1j8rPvxDJEifdwgg5MaFyu6iYAT8GBZJg2z0dkgK4YMzvURALQ==}
|
||||
|
||||
mri@1.2.0:
|
||||
resolution: {integrity: sha512-tzzskb3bG8LvYGFF/mDTpq3jpI6Q9wc3LEmBaghu+DdCssd1FakN7Bc0hVNmEyGq1bq3RgfkCb3cmQLpNPOroA==}
|
||||
engines: {node: '>=4'}
|
||||
@ -7225,11 +7216,6 @@ packages:
|
||||
yaml:
|
||||
optional: true
|
||||
|
||||
vitest-canvas-mock@1.1.3:
|
||||
resolution: {integrity: sha512-zlKJR776Qgd+bcACPh0Pq5MG3xWq+CdkACKY/wX4Jyija0BSz8LH3aCCgwFKYFwtm565+050YFEGG9Ki0gE/Hw==}
|
||||
peerDependencies:
|
||||
vitest: ^3.0.0 || ^4.0.0
|
||||
|
||||
vitest@4.0.17:
|
||||
resolution: {integrity: sha512-FQMeF0DJdWY0iOnbv466n/0BudNdKj1l5jYgl5JVTwjSsZSlqyXFt/9+1sEyhR6CLowbZpV7O1sCHrzBhucKKg==}
|
||||
engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0}
|
||||
@ -11288,8 +11274,6 @@ snapshots:
|
||||
|
||||
cssesc@3.0.0: {}
|
||||
|
||||
cssfontparser@1.2.1: {}
|
||||
|
||||
cssstyle@5.3.7:
|
||||
dependencies:
|
||||
'@asamuzakjp/css-color': 4.1.1
|
||||
@ -13589,10 +13573,6 @@ snapshots:
|
||||
dompurify: 3.2.7
|
||||
marked: 14.0.0
|
||||
|
||||
moo-color@1.0.3:
|
||||
dependencies:
|
||||
color-name: 1.1.4
|
||||
|
||||
mri@1.2.0: {}
|
||||
|
||||
mrmime@2.0.1: {}
|
||||
@ -15222,12 +15202,6 @@ snapshots:
|
||||
tsx: 4.21.0
|
||||
yaml: 2.8.2
|
||||
|
||||
vitest-canvas-mock@1.1.3(vitest@4.0.17):
|
||||
dependencies:
|
||||
cssfontparser: 1.2.1
|
||||
moo-color: 1.0.3
|
||||
vitest: 4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)
|
||||
|
||||
vitest@4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2):
|
||||
dependencies:
|
||||
'@vitest/expect': 4.0.17
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import type { App, AppCategory } from '@/models/explore'
|
||||
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { fetchAppList, fetchBanners, fetchInstalledAppList, getAppAccessModeByAppId, uninstallApp, updatePinStatus } from './explore'
|
||||
import { AppSourceType, fetchAppMeta, fetchAppParams } from './share'
|
||||
@ -13,8 +14,9 @@ type ExploreAppListData = {
|
||||
}
|
||||
|
||||
export const useExploreAppList = () => {
|
||||
const locale = useLocale()
|
||||
return useQuery<ExploreAppListData>({
|
||||
queryKey: [NAME_SPACE, 'appList'],
|
||||
queryKey: [NAME_SPACE, 'appList', locale],
|
||||
queryFn: async () => {
|
||||
const { categories, recommended_apps } = await fetchAppList()
|
||||
return {
|
||||
|
||||
@ -8,7 +8,7 @@ export default mergeConfig(viteConfig, defineConfig({
|
||||
setupFiles: ['./vitest.setup.ts'],
|
||||
coverage: {
|
||||
provider: 'v8',
|
||||
reporter: ['json', 'json-summary'],
|
||||
reporter: ['text', 'json', 'json-summary'],
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import { act, cleanup } from '@testing-library/react'
|
||||
import { mockAnimationsApi, mockResizeObserver } from 'jsdom-testing-mocks'
|
||||
import '@testing-library/jest-dom/vitest'
|
||||
import 'vitest-canvas-mock'
|
||||
|
||||
mockResizeObserver()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user