Compare commits

..

2 Commits

Author  SHA1        Message                                               Date
        225238b4b2  Update dev/ast-grep/rules/remove-nullable-arg.yaml    2025-11-05 03:09:34 +08:00
                    Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
        c4ea3e47fd  refactor: enforce typed String mapped columns         2025-11-05 03:09:34 +08:00
527 changed files with 2290 additions and 27884 deletions

View File

@ -53,6 +53,8 @@ jobs:
# Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
# Rewrite SQLAlchemy with Type Annotations
uvx --from ast-grep-cli sg scan -r dev/ast-grep/rules/remove-nullable-arg.yaml api/models -U
- name: mdformat
run: |
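This workflow step rewrites the SQLAlchemy models: the remove-nullable-arg rule drops the redundant nullable= argument once nullability is expressed by the type annotation, and the sed command repairs forward references, since a quoted name such as "Type" | None is not valid at runtime and has to become Optional["Type"]. A minimal before/after sketch of the target column style; the model and column names are illustrative, not taken from the repository:

from typing import Optional

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Document(Base):  # hypothetical model, for illustration only
    __tablename__ = "documents"

    id: Mapped[str] = mapped_column(String(36), primary_key=True)

    # Legacy style, before the rewrite:
    #     name = mapped_column(String(255), nullable=True)
    # Typed style, after the rewrite: the Optional[...] annotation already
    # makes the column nullable, so the nullable= argument is removed.
    name: Mapped[Optional[str]] = mapped_column(String(255))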

View File

@ -30,9 +30,6 @@ INTERNAL_FILES_URL=http://127.0.0.1:5001
# The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300
# Collaboration mode toggle
ENABLE_COLLABORATION_MODE=false
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60

View File

@ -1,4 +1,3 @@
import os
import sys
@ -9,16 +8,10 @@ def is_db_command():
# create app
celery = None
flask_app = None
socketio_app = None
if is_db_command():
from app_factory import create_migrations_app
app = create_migrations_app()
socketio_app = app
flask_app = app
else:
# Gunicorn and Celery handle monkey patching automatically in production by
# specifying the `gevent` worker class. Manual monkey patching is not required here.
@ -29,15 +22,8 @@ else:
from app_factory import create_app
socketio_app, flask_app = create_app()
app = flask_app
celery = flask_app.extensions["celery"]
app = create_app()
celery = app.extensions["celery"]
if __name__ == "__main__":
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
host = os.environ.get("HOST", "0.0.0.0")
port = int(os.environ.get("PORT", 5001))
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
server.serve_forever()
app.run(host="0.0.0.0", port=5001)
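The comment kept in this file explains the production setup this simplification relies on: Gunicorn with the gevent worker class performs the monkey patching, and app.run(...) is only the development fallback. A minimal Gunicorn configuration sketch for that setup; the file name and values are illustrative, not taken from this repository:

# gunicorn.conf.py -- illustrative values only
bind = "0.0.0.0:5001"
workers = 1
worker_class = "gevent"        # gevent workers apply monkey patching on startup
worker_connections = 1000      # greenlets handled concurrently per worker
timeout = 360

It would then be launched with something like gunicorn -c gunicorn.conf.py app:app.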

View File

@ -31,22 +31,14 @@ def create_flask_app_with_configs() -> DifyApp:
return dify_app
def create_app() -> tuple[any, DifyApp]:
def create_app() -> DifyApp:
start_time = time.perf_counter()
app = create_flask_app_with_configs()
initialize_extensions(app)
import socketio
from extensions.ext_socketio import sio
sio.app = app
socketio_app = socketio.WSGIApp(sio, app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
return socketio_app, app
return app
def initialize_extensions(app: DifyApp):

View File

@ -1601,7 +1601,7 @@ def transform_datasource_credentials():
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jinareader",
provider="jina",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,

View File

@ -871,16 +871,6 @@ class MailConfig(BaseSettings):
default=None,
)
ENABLE_TRIAL_APP: bool = Field(
description="Enable trial app",
default=False,
)
ENABLE_EXPLORE_BANNER: bool = Field(
description="Enable explore banner",
default=False,
)
class RagEtlConfig(BaseSettings):
"""
@ -1093,13 +1083,6 @@ class PositionConfig(BaseSettings):
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class CollaborationConfig(BaseSettings):
ENABLE_COLLABORATION_MODE: bool = Field(
description="Whether to enable collaboration mode features across the workspace",
default=False,
)
class LoginConfig(BaseSettings):
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
description="whether to enable email code login",
@ -1159,13 +1142,6 @@ class SwaggerUIConfig(BaseSettings):
)
class TenantSelfTaskQueueConfig(BaseSettings):
TENANT_SELF_TASK_QUEUE_PULL_SIZE: int = Field(
description="Default batch size for tenant self task queue pull operations",
default=1,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -1190,13 +1166,11 @@ class FeatureConfig(
RagEtlConfig,
RepositoryConfig,
SecurityConfig,
TenantSelfTaskQueueConfig,
ToolConfig,
UpdateConfig,
WorkflowConfig,
WorkflowNodeExecutionConfig,
WorkspaceConfig,
CollaborationConfig,
LoginConfig,
AccountConfig,
SwaggerUIConfig,
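FeatureConfig assembles its option groups through multiple inheritance of pydantic-settings classes, so deleting a mixin such as CollaborationConfig or TenantSelfTaskQueueConfig removes its environment-backed fields from the merged settings object. A minimal sketch of that pattern, reusing two of the groups shown above with only the fields visible in this diff:

from pydantic import Field
from pydantic_settings import BaseSettings


class CollaborationConfig(BaseSettings):
    ENABLE_COLLABORATION_MODE: bool = Field(
        description="Whether to enable collaboration mode features across the workspace",
        default=False,
    )


class LoginConfig(BaseSettings):
    ENABLE_EMAIL_CODE_LOGIN: bool = Field(
        description="whether to enable email code login",
        default=False,
    )


class FeatureConfig(CollaborationConfig, LoginConfig):
    """Merged settings; each field can be overridden by an env var of the same name."""


config = FeatureConfig()  # e.g. ENABLE_COLLABORATION_MODE=true in the environment flips the flag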

View File

@ -8,11 +8,6 @@ class HostedCreditConfig(BaseSettings):
default="",
)
HOSTED_POOL_CREDITS: int = Field(
description="Pool credits for hosted service",
default=200,
)
def get_model_credits(self, model_name: str) -> int:
"""
Get credit value for a specific model name.
@ -65,46 +60,19 @@ class HostedOpenAiConfig(BaseSettings):
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="gpt-4,"
"gpt-4-turbo-preview,"
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
"gpt-4-turbo,"
"gpt-4.1,"
"gpt-4.1-2025-04-14,"
"gpt-4.1-mini,"
"gpt-4.1-mini-2025-04-14,"
"gpt-4.1-nano,"
"gpt-4.1-nano-2025-04-14,"
"gpt-3.5-turbo,"
default="gpt-3.5-turbo,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-instruct,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
"text-davinci-003,"
"chatgpt-4o-latest,"
"gpt-4o,"
"gpt-4o-2024-05-13,"
"gpt-4o-2024-08-06,"
"gpt-4o-2024-11-20,"
"gpt-4o-audio-preview,"
"gpt-4o-audio-preview-2025-06-03,"
"gpt-4o-mini,"
"gpt-4o-mini-2024-07-18,"
"o3-mini,"
"o3-mini-2025-01-31,"
"gpt-5-mini-2025-08-07,"
"gpt-5-mini,"
"o4-mini,"
"o4-mini-2025-04-16,"
"gpt-5-chat-latest,"
"gpt-5,"
"gpt-5-2025-08-07,"
"gpt-5-nano,"
"gpt-5-nano-2025-08-07",
"text-davinci-003",
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted OpenAI service usage",
default=200,
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
@ -119,13 +87,6 @@ class HostedOpenAiConfig(BaseSettings):
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
"gpt-4-turbo,"
"gpt-4.1,"
"gpt-4.1-2025-04-14,"
"gpt-4.1-mini,"
"gpt-4.1-mini-2025-04-14,"
"gpt-4.1-nano,"
"gpt-4.1-nano-2025-04-14,"
"gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
@ -133,150 +94,7 @@ class HostedOpenAiConfig(BaseSettings):
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
"text-davinci-003,"
"chatgpt-4o-latest,"
"gpt-4o,"
"gpt-4o-2024-05-13,"
"gpt-4o-2024-08-06,"
"gpt-4o-2024-11-20,"
"gpt-4o-audio-preview,"
"gpt-4o-audio-preview-2025-06-03,"
"gpt-4o-mini,"
"gpt-4o-mini-2024-07-18,"
"o3-mini,"
"o3-mini-2025-01-31,"
"gpt-5-mini-2025-08-07,"
"gpt-5-mini,"
"o4-mini,"
"o4-mini-2025-04-16,"
"gpt-5-chat-latest,"
"gpt-5,"
"gpt-5-2025-08-07,"
"gpt-5-nano,"
"gpt-5-nano-2025-08-07",
)
class HostedGeminiConfig(BaseSettings):
"""
Configuration for fetching Gemini service
"""
HOSTED_GEMINI_API_KEY: str | None = Field(
description="API key for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_API_BASE: str | None = Field(
description="Base URL for hosted Gemini API",
default=None,
)
HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Gemini service",
default=False,
)
HOSTED_GEMINI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted gemini service",
default=False,
)
HOSTED_GEMINI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
class HostedXAIConfig(BaseSettings):
"""
Configuration for fetching XAI service
"""
HOSTED_XAI_API_KEY: str | None = Field(
description="API key for hosted XAI service",
default=None,
)
HOSTED_XAI_API_BASE: str | None = Field(
description="Base URL for hosted XAI API",
default=None,
)
HOSTED_XAI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted XAI service",
default=None,
)
HOSTED_XAI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted XAI service",
default=False,
)
HOSTED_XAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
HOSTED_XAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_XAI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedDeepseekConfig(BaseSettings):
"""
Configuration for fetching Deepseek service
"""
HOSTED_DEEPSEEK_API_KEY: str | None = Field(
description="API key for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_API_BASE: str | None = Field(
description="Base URL for hosted Deepseek API",
default=None,
)
HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Deepseek service",
default=False,
)
HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="deepseek-chat,deepseek-reasoner",
)
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="deepseek-chat,deepseek-reasoner",
"text-davinci-003",
)
@ -326,30 +144,16 @@ class HostedAnthropicConfig(BaseSettings):
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Anthropic service usage",
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
class HostedMinmaxConfig(BaseSettings):
"""
@ -446,8 +250,5 @@ class HostedServiceConfig(
HostedModerationConfig,
# credit config
HostedCreditConfig,
HostedGeminiConfig,
HostedXAIConfig,
HostedDeepseekConfig,
):
pass

View File

@ -22,11 +22,6 @@ class WeaviateConfig(BaseSettings):
default=True,
)
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
default=None,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Number of objects to be processed in a single batch operation (default is 100)",
default=100,

View File

@ -58,13 +58,11 @@ from .app import (
mcp_server,
message,
model_config,
online_user,
ops_trace,
site,
statistic,
workflow,
workflow_app_log,
workflow_comment,
workflow_draft_variable,
workflow_run,
workflow_statistic,
@ -108,12 +106,10 @@ from .datasets.rag_pipeline import (
# Import explore controllers
from .explore import (
banner,
installed_app,
parameter,
recommended_app,
saved_message,
trial,
)
# Import tag controllers
@ -147,7 +143,6 @@ __all__ = [
"apikey",
"app",
"audio",
"banner",
"billing",
"bp",
"completion",
@ -201,7 +196,6 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
"trial",
"version",
"website",
"workflow",

View File

@ -16,7 +16,7 @@ from controllers.console import api, console_ns
from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from models.model import App, InstalledApp, RecommendedApp
def admin_required(view: Callable[P, R]):
@ -52,8 +52,6 @@ class InsertExploreAppListApi(Resource):
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
"can_trial": fields.Boolean(required=True, description="Can trial"),
"trial_limit": fields.Integer(required=True, description="Trial limit"),
},
)
)
@ -73,8 +71,6 @@ class InsertExploreAppListApi(Resource):
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("position", type=int, required=True, nullable=False, location="json")
.add_argument("can_trial", type=bool, required=True, nullable=False, location="json")
.add_argument("trial_limit", type=int, required=True, nullable=False, location="json")
)
args = parser.parse_args()
@ -112,20 +108,6 @@ class InsertExploreAppListApi(Resource):
)
db.session.add(recommended_app)
if args["can_trial"]:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == args["app_id"])
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=args["app_id"],
tenant_id=app.tenant_id,
trial_limit=args["trial_limit"],
)
)
else:
trial_app.trial_limit = args["trial_limit"]
app.is_public = True
db.session.commit()
@ -140,20 +122,6 @@ class InsertExploreAppListApi(Resource):
recommended_app.category = args["category"]
recommended_app.position = args["position"]
if args["can_trial"]:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == args["app_id"])
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=args["app_id"],
tenant_id=app.tenant_id,
trial_limit=args["trial_limit"],
)
)
else:
trial_app.trial_limit = args["trial_limit"]
app.is_public = True
db.session.commit()
@ -199,83 +167,7 @@ class InsertExploreAppApi(Resource):
for installed_app in installed_apps:
session.delete(installed_app)
trial_app = session.execute(
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
).scalar_one_or_none()
if trial_app:
session.delete(trial_app)
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
@console_ns.route("/admin/insert-explore-banner")
class InsertExploreBanner(Resource):
@api.doc("insert_explore_banner")
@api.doc(description="Insert an explore banner")
@api.expect(
api.model(
"InsertExploreBannerRequest",
{
"content": fields.String(required=True, description="Banner content"),
"link": fields.String(required=True, description="Banner link"),
"sort": fields.Integer(required=True, description="Banner sort"),
},
)
)
@api.response(200, "Banner inserted successfully")
@admin_required
@only_edition_cloud
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("title", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=str, required=True, nullable=False, location="json")
parser.add_argument("img-src", type=str, required=True, nullable=False, location="json")
parser.add_argument("language", type=str, required=True, nullable=False, location="json")
parser.add_argument("link", type=str, required=True, nullable=False, location="json")
parser.add_argument("sort", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
content = {
"category": args["category"],
"title": args["title"],
"description": args["description"],
"img-src": args["img-src"],
}
if not args["language"]:
args["language"] = "en-US"
banner = ExporleBanner(
content=content,
link=args["link"],
sort=args["sort"],
language=args["language"],
)
db.session.add(banner)
db.session.commit()
return {"result": "success"}, 200
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
class DeleteExploreBanner(Resource):
@api.doc("delete_explore_banner")
@api.doc(description="Delete an explore banner")
@api.response(204, "Banner deleted successfully")
@admin_required
@only_edition_cloud
def delete(self, banner_id):
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
if not banner:
raise NotFound(f"Banner '{banner_id}' is not found")
db.session.delete(banner)
db.session.commit()
return {"result": "success"}, 204

View File

@ -115,9 +115,3 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
class NeedAddIdsError(BaseHTTPException):
error_code = "need_add_ids"
description = "Need to add ids."
code = 400

View File

@ -1,339 +0,0 @@
import json
import time
from werkzeug.wrappers import Request as WerkzeugRequest
from extensions.ext_redis import redis_client
from extensions.ext_socketio import sio
from libs.passport import PassportService
from libs.token import extract_access_token
from services.account_service import AccountService
SESSION_STATE_TTL_SECONDS = 3600
WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:"
WORKFLOW_LEADER_PREFIX = "workflow_leader:"
WS_SID_MAP_PREFIX = "ws_sid_map:"
def _workflow_key(workflow_id: str) -> str:
return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}"
def _leader_key(workflow_id: str) -> str:
return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}"
def _sid_key(sid: str) -> str:
return f"{WS_SID_MAP_PREFIX}{sid}"
def _refresh_session_state(workflow_id: str, sid: str) -> None:
"""
Refresh TTLs for workflow + session keys so healthy sessions do not linger forever after crashes.
"""
workflow_key = _workflow_key(workflow_id)
sid_key = _sid_key(sid)
if redis_client.exists(workflow_key):
redis_client.expire(workflow_key, SESSION_STATE_TTL_SECONDS)
if redis_client.exists(sid_key):
redis_client.expire(sid_key, SESSION_STATE_TTL_SECONDS)
@sio.on("connect")
def socket_connect(sid, environ, auth):
"""
WebSocket connect event, do authentication here.
"""
token = None
if auth and isinstance(auth, dict):
token = auth.get("token")
if not token:
try:
request_environ = WerkzeugRequest(environ)
token = extract_access_token(request_environ)
except Exception:
token = None
if not token:
return False
try:
decoded = PassportService().verify(token)
user_id = decoded.get("user_id")
if not user_id:
return False
with sio.app.app_context():
user = AccountService.load_logged_in_account(account_id=user_id)
if not user:
return False
sio.save_session(sid, {"user_id": user.id, "username": user.name, "avatar": user.avatar})
return True
except Exception:
return False
@sio.on("user_connect")
def handle_user_connect(sid, data):
"""
Handle user connect event. Each session (tab) is treated as an independent collaborator.
"""
workflow_id = data.get("workflow_id")
if not workflow_id:
return {"msg": "workflow_id is required"}, 400
session = sio.get_session(sid)
user_id = session.get("user_id")
if not user_id:
return {"msg": "unauthorized"}, 401
# Each session is stored independently with sid as key
session_info = {
"user_id": user_id,
"username": session.get("username", "Unknown"),
"avatar": session.get("avatar", None),
"sid": sid,
"connected_at": int(time.time()), # Add timestamp to differentiate tabs
}
workflow_key = _workflow_key(workflow_id)
# Store session info with sid as key
redis_client.hset(workflow_key, sid, json.dumps(session_info))
redis_client.set(
_sid_key(sid),
json.dumps({"workflow_id": workflow_id, "user_id": user_id}),
ex=SESSION_STATE_TTL_SECONDS,
)
_refresh_session_state(workflow_id, sid)
# Leader election: first session becomes the leader
leader_sid = get_or_set_leader(workflow_id, sid)
is_leader = leader_sid == sid
sio.enter_room(sid, workflow_id)
broadcast_online_users(workflow_id)
# Notify this session of their leader status
sio.emit("status", {"isLeader": is_leader}, room=sid)
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
@sio.on("disconnect")
def handle_disconnect(sid):
"""
Handle session disconnect event. Remove the specific session from online users.
"""
mapping = redis_client.get(_sid_key(sid))
if mapping:
data = json.loads(mapping)
workflow_id = data["workflow_id"]
# Remove this specific session
redis_client.hdel(_workflow_key(workflow_id), sid)
redis_client.delete(_sid_key(sid))
# Handle leader re-election if the leader session disconnected
handle_leader_disconnect(workflow_id, sid)
broadcast_online_users(workflow_id)
def _clear_session_state(workflow_id: str, sid: str) -> None:
redis_client.hdel(_workflow_key(workflow_id), sid)
redis_client.delete(_sid_key(sid))
def _is_session_active(workflow_id: str, sid: str) -> bool:
if not sid:
return False
try:
if not sio.manager.is_connected(sid, "/"):
return False
except AttributeError:
return False
if not redis_client.hexists(_workflow_key(workflow_id), sid):
return False
if not redis_client.exists(_sid_key(sid)):
return False
return True
def get_or_set_leader(workflow_id: str, sid: str) -> str:
"""
Get current leader session or set this session as leader if no valid leader exists.
Returns the leader session id (sid).
"""
raw_leader = redis_client.get(_leader_key(workflow_id))
current_leader = raw_leader.decode("utf-8") if isinstance(raw_leader, bytes) else raw_leader
leader_replaced = False
if current_leader and not _is_session_active(workflow_id, current_leader):
_clear_session_state(workflow_id, current_leader)
redis_client.delete(_leader_key(workflow_id))
current_leader = None
leader_replaced = True
if not current_leader:
redis_client.set(_leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS) # Expire in 1 hour
if leader_replaced:
broadcast_leader_change(workflow_id, sid)
return sid
return current_leader
def handle_leader_disconnect(workflow_id, disconnected_sid):
"""
Handle leader re-election when a session disconnects.
If the disconnected session was the leader, elect a new leader from remaining sessions.
"""
current_leader = redis_client.get(_leader_key(workflow_id))
if current_leader:
current_leader = current_leader.decode("utf-8") if isinstance(current_leader, bytes) else current_leader
if current_leader == disconnected_sid:
# Leader session disconnected, elect a new leader
sessions_json = redis_client.hgetall(_workflow_key(workflow_id))
if sessions_json:
# Get the first remaining session as new leader
new_leader_sid = list(sessions_json.keys())[0]
if isinstance(new_leader_sid, bytes):
new_leader_sid = new_leader_sid.decode("utf-8")
redis_client.set(_leader_key(workflow_id), new_leader_sid, ex=SESSION_STATE_TTL_SECONDS)
# Notify all sessions about the new leader
broadcast_leader_change(workflow_id, new_leader_sid)
else:
# No sessions left, remove leader
redis_client.delete(_leader_key(workflow_id))
def broadcast_leader_change(workflow_id, new_leader_sid):
"""
Broadcast leader change to all sessions in the workflow.
"""
sessions_json = redis_client.hgetall(_workflow_key(workflow_id))
for sid, session_info_json in sessions_json.items():
try:
sid_str = sid.decode("utf-8") if isinstance(sid, bytes) else sid
is_leader = sid_str == new_leader_sid
# Emit to each session whether they are the new leader
sio.emit("status", {"isLeader": is_leader}, room=sid_str)
except Exception:
continue
def get_current_leader(workflow_id):
"""
Get the current leader for a workflow.
"""
leader = redis_client.get(_leader_key(workflow_id))
return leader.decode("utf-8") if leader and isinstance(leader, bytes) else leader
def broadcast_online_users(workflow_id):
"""
Broadcast online users to the workflow room.
Each session is shown as a separate user (even if same person has multiple tabs).
"""
sessions_json = redis_client.hgetall(_workflow_key(workflow_id))
users = []
for sid, session_info_json in sessions_json.items():
try:
session_info = json.loads(session_info_json)
# Each session appears as a separate "user" in the UI
users.append(
{
"user_id": session_info["user_id"],
"username": session_info["username"],
"avatar": session_info.get("avatar"),
"sid": session_info["sid"],
"connected_at": session_info.get("connected_at"),
}
)
except Exception:
continue
# Sort by connection time to maintain consistent order
users.sort(key=lambda x: x.get("connected_at") or 0)
# Get current leader session
leader_sid = get_current_leader(workflow_id)
sio.emit("online_users", {"workflow_id": workflow_id, "users": users, "leader": leader_sid}, room=workflow_id)
@sio.on("collaboration_event")
def handle_collaboration_event(sid, data):
"""
Handle general collaboration events, include:
1. mouse_move
2. vars_and_features_update
3. sync_request (ask leader to update graph)
4. app_state_update
5. mcp_server_update
6. workflow_update
7. comments_update
8. node_panel_presence
"""
mapping = redis_client.get(_sid_key(sid))
if not mapping:
return {"msg": "unauthorized"}, 401
mapping_data = json.loads(mapping)
workflow_id = mapping_data["workflow_id"]
user_id = mapping_data["user_id"]
_refresh_session_state(workflow_id, sid)
event_type = data.get("type")
event_data = data.get("data")
timestamp = data.get("timestamp", int(time.time()))
if not event_type:
return {"msg": "invalid event type"}, 400
sio.emit(
"collaboration_update",
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
room=workflow_id,
skip_sid=sid,
)
return {"msg": "event_broadcasted"}
@sio.on("graph_event")
def handle_graph_event(sid, data):
"""
Handle graph events - simple broadcast relay.
"""
mapping = redis_client.get(_sid_key(sid))
if not mapping:
return {"msg": "unauthorized"}, 401
mapping_data = json.loads(mapping)
workflow_id = mapping_data["workflow_id"]
_refresh_session_state(workflow_id, sid)
sio.emit("graph_update", data, room=workflow_id, skip_sid=sid)
return {"msg": "graph_update_broadcasted"}
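The deleted module implemented a small Socket.IO protocol: authenticate during connect, register a session for a workflow via user_connect (the first session is elected leader), and relay collaboration_event / graph_event payloads to everyone else in the workflow room. A minimal client-side sketch of that protocol using python-socketio; the server URL, token, and workflow id are placeholders:

import socketio

sio = socketio.Client()


@sio.on("status")
def on_status(data):
    print("is this session the leader?", data["isLeader"])


@sio.on("online_users")
def on_online_users(data):
    print("online in", data["workflow_id"], ":", [u["username"] for u in data["users"]])


@sio.on("collaboration_update")
def on_collaboration_update(data):
    print("collaboration event", data["type"], "from", data["userId"])


# Placeholders: a real deployment supplies its own URL and access token.
sio.connect("http://localhost:5001", auth={"token": "<access-token>"})
sio.emit("user_connect", {"workflow_id": "<workflow-id>"})
sio.emit("collaboration_event", {"type": "mouse_move", "data": {"x": 10, "y": 20}})
sio.wait()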

View File

@ -5,12 +5,10 @@ from typing import cast
from flask import abort, request
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
from pydantic_core import ValidationError
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from configs import dify_config
from controllers.console import api, console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
@ -23,9 +21,7 @@ from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import file_factory, variable_factory
from fields.online_user_fields import online_user_list_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
@ -103,22 +99,10 @@ class DraftWorkflowApi(Resource):
"hash": fields.String(description="Workflow hash for validation"),
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
"memory_blocks": fields.List(fields.Raw, description="Memory blocks"),
},
)
)
@api.response(
200,
"Draft workflow synced successfully",
api.model(
"SyncDraftWorkflowResponse",
{
"result": fields.String,
"hash": fields.String,
"updated_at": fields.String,
},
),
)
@api.response(200, "Draft workflow synced successfully", workflow_fields)
@api.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied")
@edit_permission_required
@ -138,8 +122,6 @@ class DraftWorkflowApi(Resource):
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=True, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
.add_argument("force_upload", type=bool, required=False, default=False, location="json")
.add_argument("memory_blocks", type=list, required=False, location="json")
)
args = parser.parse_args()
elif "text/plain" in content_type:
@ -157,8 +139,6 @@ class DraftWorkflowApi(Resource):
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
"memory_blocks": data.get("memory_blocks"),
"force_upload": data.get("force_upload", False),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
@ -175,10 +155,6 @@ class DraftWorkflowApi(Resource):
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
memory_blocks_list = args.get("memory_blocks") or []
from core.memory.entities import MemoryBlockSpec
memory_blocks = [MemoryBlockSpec.model_validate(obj) for obj in memory_blocks_list]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args["graph"],
@ -187,13 +163,9 @@ class DraftWorkflowApi(Resource):
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
force_upload=args.get("force_upload", False),
memory_blocks=memory_blocks,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
except ValidationError as e:
return {"message": str(e)}, 400
return {
"result": "success",
@ -762,45 +734,6 @@ class ConvertToWorkflowApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
class WorkflowConfigApi(Resource):
"""Resource for workflow configuration."""
@api.doc("get_workflow_config")
@api.doc(description="Get workflow configuration")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Workflow configuration retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}
class WorkflowFeaturesApi(Resource):
"""Update draft workflow features."""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
parser = reqparse.RequestParser()
parser.add_argument("features", type=dict, required=True, location="json")
args = parser.parse_args()
features = args.get("features")
# Update draft workflow features
workflow_service = WorkflowService()
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@api.doc("get_all_published_workflows")
@ -982,105 +915,3 @@ class DraftWorkflowNodeLastRunApi(Resource):
if node_exec is None:
raise NotFound("last run not found")
return node_exec
class WorkflowOnlineUsersApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(online_user_list_fields)
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("workflow_ids", type=str, required=True, location="args")
args = parser.parse_args()
workflow_ids = [id.strip() for id in args["workflow_ids"].split(",")]
results = []
for workflow_id in workflow_ids:
users_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
users = []
for _, user_info_json in users_json.items():
try:
users.append(json.loads(user_info_json))
except Exception:
continue
results.append({"workflow_id": workflow_id, "users": users})
return {"data": results}
api.add_resource(
DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft",
)
api.add_resource(
WorkflowConfigApi,
"/apps/<uuid:app_id>/workflows/draft/config",
)
api.add_resource(
WorkflowFeaturesApi,
"/apps/<uuid:app_id>/workflows/draft/features",
)
api.add_resource(
AdvancedChatDraftWorkflowRunApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
)
api.add_resource(
DraftWorkflowRunApi,
"/apps/<uuid:app_id>/workflows/draft/run",
)
api.add_resource(
WorkflowTaskStopApi,
"/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
)
api.add_resource(
DraftWorkflowNodeRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
AdvancedChatDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
WorkflowDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
AdvancedChatDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
WorkflowDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
PublishedWorkflowApi,
"/apps/<uuid:app_id>/workflows/publish",
)
api.add_resource(
PublishedAllWorkflowApi,
"/apps/<uuid:app_id>/workflows",
)
api.add_resource(
DefaultBlockConfigsApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
)
api.add_resource(
DefaultBlockConfigApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
)
api.add_resource(
ConvertToWorkflowApi,
"/apps/<uuid:app_id>/convert-to-workflow",
)
api.add_resource(
WorkflowByIdApi,
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
)
api.add_resource(
DraftWorkflowNodeLastRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run",
)
api.add_resource(WorkflowOnlineUsersApi, "/apps/workflows/online-users")

View File

@ -1,240 +0,0 @@
import logging
from flask_restx import Resource, fields, marshal_with, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.member_fields import account_with_role_fields
from fields.workflow_comment_fields import (
workflow_comment_basic_fields,
workflow_comment_create_fields,
workflow_comment_detail_fields,
workflow_comment_reply_create_fields,
workflow_comment_reply_update_fields,
workflow_comment_resolve_fields,
workflow_comment_update_fields,
)
from libs.login import current_user, login_required
from models import App
from services.account_service import TenantService
from services.workflow_comment_service import WorkflowCommentService
logger = logging.getLogger(__name__)
class WorkflowCommentListApi(Resource):
"""API for listing and creating workflow comments."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_basic_fields, envelope="data")
def get(self, app_model: App):
"""Get all comments for a workflow."""
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
return comments
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_create_fields)
def post(self, app_model: App):
"""Create a new workflow comment."""
parser = reqparse.RequestParser()
parser.add_argument("position_x", type=float, required=True, location="json")
parser.add_argument("position_y", type=float, required=True, location="json")
parser.add_argument("content", type=str, required=True, location="json")
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
args = parser.parse_args()
result = WorkflowCommentService.create_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
created_by=current_user.id,
content=args.content,
position_x=args.position_x,
position_y=args.position_y,
mentioned_user_ids=args.mentioned_user_ids,
)
return result, 201
class WorkflowCommentDetailApi(Resource):
"""API for managing individual workflow comments."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_detail_fields)
def get(self, app_model: App, comment_id: str):
"""Get a specific workflow comment."""
comment = WorkflowCommentService.get_comment(
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
)
return comment
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_update_fields)
def put(self, app_model: App, comment_id: str):
"""Update a workflow comment."""
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, location="json")
parser.add_argument("position_x", type=float, required=False, location="json")
parser.add_argument("position_y", type=float, required=False, location="json")
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
args = parser.parse_args()
result = WorkflowCommentService.update_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
content=args.content,
position_x=args.position_x,
position_y=args.position_y,
mentioned_user_ids=args.mentioned_user_ids,
)
return result
@login_required
@setup_required
@account_initialization_required
@get_app_model
def delete(self, app_model: App, comment_id: str):
"""Delete a workflow comment."""
WorkflowCommentService.delete_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return {"result": "success"}, 204
class WorkflowCommentResolveApi(Resource):
"""API for resolving and reopening workflow comments."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_resolve_fields)
def post(self, app_model: App, comment_id: str):
"""Resolve a workflow comment."""
comment = WorkflowCommentService.resolve_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return comment
class WorkflowCommentReplyApi(Resource):
"""API for managing comment replies."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_reply_create_fields)
def post(self, app_model: App, comment_id: str):
"""Add a reply to a workflow comment."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, location="json")
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
args = parser.parse_args()
result = WorkflowCommentService.create_reply(
comment_id=comment_id,
content=args.content,
created_by=current_user.id,
mentioned_user_ids=args.mentioned_user_ids,
)
return result, 201
class WorkflowCommentReplyDetailApi(Resource):
"""API for managing individual comment replies."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_reply_update_fields)
def put(self, app_model: App, comment_id: str, reply_id: str):
"""Update a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, location="json")
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
args = parser.parse_args()
reply = WorkflowCommentService.update_reply(
reply_id=reply_id, user_id=current_user.id, content=args.content, mentioned_user_ids=args.mentioned_user_ids
)
return reply
@login_required
@setup_required
@account_initialization_required
@get_app_model
def delete(self, app_model: App, comment_id: str, reply_id: str):
"""Delete a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
return {"result": "success"}, 204
class WorkflowCommentMentionUsersApi(Resource):
"""API for getting mentionable users for workflow comments."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with({"users": fields.List(fields.Nested(account_with_role_fields))})
def get(self, app_model: App):
"""Get all users in current tenant for mentions."""
members = TenantService.get_tenant_members(current_user.current_tenant)
return {"users": members}
# Register API routes
api.add_resource(WorkflowCommentListApi, "/apps/<uuid:app_id>/workflow/comments")
api.add_resource(WorkflowCommentDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
api.add_resource(WorkflowCommentResolveApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
api.add_resource(WorkflowCommentReplyApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
api.add_resource(
WorkflowCommentReplyDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>"
)
api.add_resource(WorkflowCommentMentionUsersApi, "/apps/<uuid:app_id>/workflow/comments/mention-users")
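The removed endpoints formed a conventional CRUD surface for canvas comments. A hedged usage sketch against those routes written with requests: the base URL and console prefix are assumptions, the bearer token is a placeholder, and it assumes the create response echoes the new comment's id; the payload fields mirror the parsers above:

import requests

BASE = "http://localhost:5001/console/api"            # assumed console API prefix
HEADERS = {"Authorization": "Bearer <access-token>"}   # placeholder token
app_id = "<app-uuid>"

# Create a comment pinned to a canvas position.
resp = requests.post(
    f"{BASE}/apps/{app_id}/workflow/comments",
    headers=HEADERS,
    json={
        "position_x": 120.0,
        "position_y": 80.0,
        "content": "Should this branch handle empty input?",
        "mentioned_user_ids": [],
    },
)
comment = resp.json()

# Reply to the comment, then resolve the thread.
requests.post(
    f"{BASE}/apps/{app_id}/workflow/comments/{comment['id']}/replies",
    headers=HEADERS,
    json={"content": "Good catch, fixed.", "mentioned_user_ids": []},
)
requests.post(
    f"{BASE}/apps/{app_id}/workflow/comments/{comment['id']}/resolve",
    headers=HEADERS,
)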

View File

@ -19,8 +19,8 @@ from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories import variable_factory
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import Account, App, AppMode
from models.workflow import WorkflowDraftVariable
@ -355,7 +355,7 @@ class VariableApi(Resource):
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@ -448,35 +448,8 @@ class ConversationVariableCollectionApi(Resource):
db.session.commit()
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
def post(self, app_model: App):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("conversation_variables", type=list, required=True, location="json")
args = parser.parse_args()
workflow_service = WorkflowService()
conversation_variables_list = args.get("conversation_variables") or []
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
workflow_service.update_draft_workflow_conversation_variables(
app_model=app_model,
account=current_user,
conversation_variables=conversation_variables,
)
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
class SystemVariableCollectionApi(Resource):
@api.doc("get_system_variables")
@api.doc(description="Get system variables for workflow")
@ -526,44 +499,3 @@ class EnvironmentVariableCollectionApi(Resource):
)
return {"items": env_vars_list}
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("environment_variables", type=list, required=True, location="json")
args = parser.parse_args()
workflow_service = WorkflowService()
environment_variables_list = args.get("environment_variables") or []
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
workflow_service.update_draft_workflow_environment_variables(
app_model=app_model,
account=current_user,
environment_variables=environment_variables,
)
return {"result": "success"}
api.add_resource(
WorkflowVariableCollectionApi,
"/apps/<uuid:app_id>/workflows/draft/variables",
)
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")

View File

@ -23,15 +23,6 @@ def _load_app_model(app_id: str) -> App | None:
return app_model
def _load_app_model_with_trial(app_id: str) -> App | None:
app_model = (
db.session.query(App)
.where(App.id == app_id, App.status == "normal")
.first()
)
return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
@ -71,44 +62,3 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
return decorator
else:
return decorator(view)
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
app_id = kwargs.get("app_id")
app_id = str(app_id)
del kwargs["app_id"]
app_model = _load_app_model_with_trial(app_id)
if not app_model:
raise AppNotFoundError()
app_mode = AppMode.value_of(app_model.mode)
if mode is not None:
if isinstance(mode, list):
modes = mode
else:
modes = [mode]
if app_mode not in modes:
mode_values = {m.value for m in modes}
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
kwargs["app_model"] = app_model
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
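Both decorators share one pattern: pop app_id out of the path kwargs, load and validate the App (optionally checking its mode), and inject it into the view as app_model; only the trial variant is removed. A short usage sketch of the remaining get_app_model, with a hypothetical resource and the other view decorators (setup_required, login_required, and so on) omitted:

from flask_restx import Resource

from controllers.console.app.wraps import get_app_model
from models import App
from models.model import AppMode


class ExampleWorkflowStatusApi(Resource):  # hypothetical resource, for illustration only
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def get(self, app_model: App):
        # app_id from the URL has been resolved to a validated App instance
        return {"app_id": app_model.id, "mode": app_model.mode}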

View File

@ -25,13 +25,10 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from libs.login import current_account_with_tenant
from libs.passport import PassportService
from libs.token import (
check_csrf_token,
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_access_token,
extract_refresh_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
@ -292,18 +289,3 @@ class RefreshTokenApi(Resource):
return response
except Exception as e:
return {"result": "fail", "message": str(e)}, 401
# this api helps frontend to check whether user is authenticated
# TODO: remove in the future. frontend should redirect to login page by catching 401 status
@console_ns.route("/login/status")
class LoginStatus(Resource):
def get(self):
token = extract_access_token(request)
res = True
try:
validated = PassportService().verify(token=token)
check_csrf_token(request=request, user_id=validated.get("user_id"))
except:
res = False
return {"logged_in": res}

View File

@ -1,43 +0,0 @@
from flask import request
from flask_restx import Resource
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
from extensions.ext_database import db
from models.model import ExporleBanner
class BannerApi(Resource):
"""Resource for banner list."""
@explore_banner_enabled
def get(self):
"""Get banner list."""
language = request.args.get("language", "en-US")
# Build base query for enabled banners
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
# Try to get banners in the requested language
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
# Fallback to en-US if no banners found and language is not en-US
if not banners and language != "en-US":
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
# Convert banners to serializable format
result = []
for banner in banners:
banner_data = {
"id": banner.id,
"content": banner.content, # Already parsed as JSON by SQLAlchemy
"link": banner.link,
"sort": banner.sort,
"status": banner.status,
"created_at": banner.created_at.isoformat() if banner.created_at else None,
}
result.append(banner_data)
return result
api.add_resource(BannerApi, "/explore/banners")

View File

@ -29,25 +29,3 @@ class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403
class TrialAppNotAllowed(BaseHTTPException):
"""*403* `Trial App Not Allowed`
Raise if the user has reached the trial app limit.
"""
error_code = "trial_app_not_allowed"
code = 403
description = "the app is not allowed to be trial."
class TrialAppLimitExceeded(BaseHTTPException):
"""*403* `Trial App Limit Exceeded`
Raise if the user has exceeded the trial app limit.
"""
error_code = "trial_app_limit_exceeded"
code = 403
description = "The user has exceeded the trial app limit."

View File

@ -27,7 +27,6 @@ recommended_app_fields = {
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
"can_trial": fields.Boolean,
}
recommended_app_list_fields = {

View File

@ -1,514 +0,0 @@
import logging
from typing import Any, cast
from flask import request
from flask_restx import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common import fields
from controllers.common.fields import build_site_model
from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
ConversationCompletedError,
NeedAddIdsError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.console.app.wraps import get_app_model_with_trial
from controllers.console.explore.error import (
AppSuggestedQuestionsAfterAnswerDisabledError,
NotChatAppError,
NotCompletionAppError,
NotWorkflowAppError,
)
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
from controllers.service_api import service_api_ns
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from fields.app_fields import app_detail_fields_with_site
from fields.dataset_fields import dataset_fields
from fields.workflow_fields import workflow_fields
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.account import TenantStatus
from models.model import AppMode, Site
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.app_service import AppService
from services.audio_service import AudioService
from services.dataset_service import DatasetService
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
from services.errors.conversation import ConversationNotExistsError
from services.errors.llm import InvokeRateLimitError
from services.errors.message import (
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService
from services.recommended_app_service import RecommendedAppService
logger = logging.getLogger(__name__)
class TrialAppWorkflowRunApi(TrialAppResource):
def post(self, trial_app):
"""
Run workflow
"""
app_model = trial_app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
assert current_user is not None
try:
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialAppWorkflowTaskStopApi(TrialAppResource):
def post(self, trial_app, task_id: str):
"""
Stop workflow task
"""
app_model = trial_app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
assert current_user is not None
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}
class TrialChatApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialMessageSuggestedQuestionApi(TrialAppResource):
@trial_feature_enable
def get(self, trial_app, message_id):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id = str(message_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
class TrialChatAudioApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
file = request.files["file"]
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
raise InternalServerError()
class TrialChatTextApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
raise InternalServerError()
class TrialCompletionApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialSitApi(Resource):
"""Resource for trial app sites."""
@trial_feature_enable
@get_app_model_with_trial
@service_api_ns.marshal_with(build_site_model(service_api_ns))
def get(self, app_model):
"""Retrieve app site info.
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
assert app_model.tenant
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return site
class TrialAppParameterApi(Resource):
"""Resource for app variables."""
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(fields.parameters_fields)
def get(self, app_model):
"""Retrieve app parameters."""
if app_model is None:
raise AppUnavailableError()
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
class AppApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
app_model = app_service.get_app(app_model)
return app_model
class AppWorkflowApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(workflow_fields)
def get(self, app_model):
"""Get workflow detail"""
if not app_model.workflow_id:
raise AppUnavailableError()
workflow = (
db.session.query(Workflow)
.where(
Workflow.id == app_model.workflow_id,
)
.first()
)
return workflow
class DatasetListApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
tenant_id = app_model.tenant_id
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
else:
raise NeedAddIdsError()
data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
api.add_resource(
TrialMessageSuggestedQuestionApi,
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="trial_app_suggested_question",
)
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")

View File

@ -2,16 +2,14 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_login import current_user
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from models import InstalledApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -73,59 +71,6 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, P], R]):
@wraps(view)
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
if trial_app is None:
raise TrialAppNotAllowed()
app = trial_app.app
if app is None:
raise TrialAppNotAllowed()
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
.first()
)
if account_trial_app_record:
if account_trial_app_record.count >= trial_app.trial_limit:
raise TrialAppLimitExceeded()
return view(app, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_trial_app:
abort(403, "Trial app feature is not enabled.")
return view(*args, **kwargs)
return decorated
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_explore_banner:
abort(403, "Explore banner feature is not enabled.")
return view(*args, **kwargs)
return decorated
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
@ -135,13 +80,3 @@ class InstalledAppResource(Resource):
account_initialization_required,
login_required,
]
class TrialAppResource(Resource):
# must be reversed if there are multiple decorators
method_decorators = [
trial_app_required,
account_initialization_required,
login_required,
]

View File

@ -32,7 +32,6 @@ from controllers.console.wraps import (
only_edition_cloud,
setup_required,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
@ -129,17 +128,6 @@ class AccountNameApi(Resource):
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("avatar", type=str, required=True, location="args")
args = parser.parse_args()
avatar_url = file_helpers.get_signed_file_url(args["avatar"])
return {"avatar_url": avatar_url}
@setup_required
@login_required
@account_initialization_required

View File

@ -51,9 +51,6 @@ tenant_fields = {
"in_trial": fields.Boolean,
"trial_end_reason": fields.String,
"custom_config": fields.Raw(attribute="custom_config"),
"trial_credits": fields.Integer,
"trial_credits_used": fields.Integer,
"next_credit_reset_date": fields.Integer,
}
tenants_fields = {

View File

@ -19,7 +19,6 @@ from .app import (
annotation,
app,
audio,
chatflow_memory,
completion,
conversation,
file,
@ -41,7 +40,6 @@ __all__ = [
"annotation",
"app",
"audio",
"chatflow_memory",
"completion",
"conversation",
"dataset",

View File

@ -1,109 +0,0 @@
from flask_restx import Resource, reqparse
from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.memory.entities import MemoryBlock, MemoryCreatedBy
from core.workflow.runtime.variable_pool import VariablePool
from models import App, EndUser
from services.chatflow_memory_service import ChatflowMemoryService
from services.workflow_service import WorkflowService
class MemoryListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def get(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=False, type=str | None, default=None)
parser.add_argument("memory_id", required=False, type=str | None, default=None)
parser.add_argument("version", required=False, type=int | None, default=None)
args = parser.parse_args()
conversation_id: str | None = args.get("conversation_id")
memory_id = args.get("memory_id")
version = args.get("version")
if conversation_id:
result = ChatflowMemoryService.get_persistent_memories_with_conversation(
app_model, MemoryCreatedBy(end_user_id=end_user.id), conversation_id, version
)
session_memories = ChatflowMemoryService.get_session_memories_with_conversation(
app_model, MemoryCreatedBy(end_user_id=end_user.id), conversation_id, version
)
result = [*result, *session_memories]
else:
result = ChatflowMemoryService.get_persistent_memories(
app_model, MemoryCreatedBy(end_user_id=end_user.id), version
)
if memory_id:
result = [it for it in result if it.spec.id == memory_id]
return [it for it in result if it.spec.end_user_visible]
class MemoryEditApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def put(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True)
parser.add_argument("conversation_id", type=str | None, required=False, default=None)
parser.add_argument("node_id", type=str | None, required=False, default=None)
parser.add_argument("update", type=str, required=True)
args = parser.parse_args()
workflow = WorkflowService().get_published_workflow(app_model)
update = args.get("update")
conversation_id = args.get("conversation_id")
node_id = args.get("node_id")
if not isinstance(update, str):
return {"error": "Invalid update"}, 400
if not workflow:
return {"error": "Workflow not found"}, 404
memory_spec = next((it for it in workflow.memory_blocks if it.id == args["id"]), None)
if not memory_spec:
return {"error": "Memory not found"}, 404
# First get existing memory
existing_memory = ChatflowMemoryService.get_memory_by_spec(
spec=memory_spec,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
created_by=MemoryCreatedBy(end_user_id=end_user.id),
conversation_id=conversation_id,
node_id=node_id,
is_draft=False,
)
# Create updated memory instance with incremented version
updated_memory = MemoryBlock(
spec=existing_memory.spec,
tenant_id=existing_memory.tenant_id,
app_id=existing_memory.app_id,
conversation_id=existing_memory.conversation_id,
node_id=existing_memory.node_id,
value=update, # New value
version=existing_memory.version + 1, # Increment version for update
edited_by_user=True,
created_by=existing_memory.created_by,
)
ChatflowMemoryService.save_memory(updated_memory, VariablePool(), False)
return "", 204
class MemoryDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def delete(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=False, default=None)
args = parser.parse_args()
memory_id = args.get("id")
if memory_id:
ChatflowMemoryService.delete_memory(app_model, memory_id, MemoryCreatedBy(end_user_id=end_user.id))
return "", 204
else:
ChatflowMemoryService.delete_all_user_memories(app_model, MemoryCreatedBy(end_user_id=end_user.id))
return "", 200
api.add_resource(MemoryListApi, "/memories")
api.add_resource(MemoryEditApi, "/memory-edit")
api.add_resource(MemoryDeleteApi, "/memories")

View File

@ -67,7 +67,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
kwargs["app_model"] = app_model
# If caller needs end-user context, attach EndUser to current_user
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get("user")
@ -76,6 +75,7 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get("user")
else:
# use default-user
user_id = None
if not user_id and fetch_user_arg.required:
@ -90,28 +90,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
# Set EndUser as current logged-in user for flask_login.current_user
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore
else:
# For service API without end-user context, ensure an Account is logged in
# so services relying on current_account_with_tenant() work correctly.
tenant_owner_info = (
db.session.query(Tenant, Account)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.join(Account, TenantAccountJoin.account_id == Account.id)
.where(
Tenant.id == app_model.tenant_id,
TenantAccountJoin.role == "owner",
Tenant.status == TenantStatus.NORMAL,
)
.one_or_none()
)
if tenant_owner_info:
tenant_model, account = tenant_owner_info
account.current_tenant = tenant_model
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant owner account not found or tenant is not active.")
return view_func(*args, **kwargs)

View File

@ -18,7 +18,6 @@ web_ns = Namespace("web", description="Web application API operations", path="/"
from . import (
app,
audio,
chatflow_memory,
completion,
conversation,
feature,
@ -40,7 +39,6 @@ __all__ = [
"app",
"audio",
"bp",
"chatflow_memory",
"completion",
"conversation",
"feature",

View File

@ -1,108 +0,0 @@
from flask_restx import reqparse
from controllers.web import api
from controllers.web.wraps import WebApiResource
from core.memory.entities import MemoryBlock, MemoryCreatedBy
from core.workflow.runtime.variable_pool import VariablePool
from models import App, EndUser
from services.chatflow_memory_service import ChatflowMemoryService
from services.workflow_service import WorkflowService
class MemoryListApi(WebApiResource):
def get(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=False, type=str | None, default=None)
parser.add_argument("memory_id", required=False, type=str | None, default=None)
parser.add_argument("version", required=False, type=int | None, default=None)
args = parser.parse_args()
conversation_id: str | None = args.get("conversation_id")
memory_id = args.get("memory_id")
version = args.get("version")
if conversation_id:
result = ChatflowMemoryService.get_persistent_memories_with_conversation(
app_model, MemoryCreatedBy(end_user_id=end_user.id), conversation_id, version
)
session_memories = ChatflowMemoryService.get_session_memories_with_conversation(
app_model, MemoryCreatedBy(end_user_id=end_user.id), conversation_id, version
)
result = [*result, *session_memories]
else:
result = ChatflowMemoryService.get_persistent_memories(
app_model, MemoryCreatedBy(end_user_id=end_user.id), version
)
if memory_id:
result = [it for it in result if it.spec.id == memory_id]
return [it for it in result if it.spec.end_user_visible]
class MemoryEditApi(WebApiResource):
def put(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True)
parser.add_argument("conversation_id", type=str | None, required=False, default=None)
parser.add_argument("node_id", type=str | None, required=False, default=None)
parser.add_argument("update", type=str, required=True)
args = parser.parse_args()
workflow = WorkflowService().get_published_workflow(app_model)
update = args.get("update")
conversation_id = args.get("conversation_id")
node_id = args.get("node_id")
if not isinstance(update, str):
return {"error": "Update must be a string"}, 400
if not workflow:
return {"error": "Workflow not found"}, 404
memory_spec = next((it for it in workflow.memory_blocks if it.id == args["id"]), None)
if not memory_spec:
return {"error": "Memory not found"}, 404
if not memory_spec.end_user_editable:
return {"error": "Memory not editable"}, 403
# First get existing memory
existing_memory = ChatflowMemoryService.get_memory_by_spec(
spec=memory_spec,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
created_by=MemoryCreatedBy(end_user_id=end_user.id),
conversation_id=conversation_id,
node_id=node_id,
is_draft=False,
)
# Create updated memory instance with incremented version
updated_memory = MemoryBlock(
spec=existing_memory.spec,
tenant_id=existing_memory.tenant_id,
app_id=existing_memory.app_id,
conversation_id=existing_memory.conversation_id,
node_id=existing_memory.node_id,
value=update, # New value
version=existing_memory.version + 1, # Increment version for update
edited_by_user=True,
created_by=existing_memory.created_by,
)
ChatflowMemoryService.save_memory(updated_memory, VariablePool(), False)
return "", 204
class MemoryDeleteApi(WebApiResource):
def delete(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=False, default=None)
args = parser.parse_args()
memory_id = args.get("id")
if memory_id:
ChatflowMemoryService.delete_memory(app_model, memory_id, MemoryCreatedBy(end_user_id=end_user.id))
return "", 204
else:
ChatflowMemoryService.delete_all_user_memories(app_model, MemoryCreatedBy(end_user_id=end_user.id))
return "", 200
api.add_resource(MemoryListApi, "/memories")
api.add_resource(MemoryEditApi, "/memory-edit")
api.add_resource(MemoryDeleteApi, "/memories")

View File

@ -1,11 +1,10 @@
import logging
import time
from collections.abc import Mapping, MutableMapping, Sequence
from collections.abc import Mapping, Sequence
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -21,8 +20,6 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.memory.entities import MemoryCreatedBy, MemoryScope
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
@ -30,7 +27,6 @@ from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.graph_events import GraphRunSucceededEvent
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
@ -43,8 +39,6 @@ from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from services.chatflow_history_service import ChatflowHistoryService
from services.chatflow_memory_service import ChatflowMemoryService
logger = logging.getLogger(__name__)
@ -87,10 +81,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._workflow_node_execution_repository = workflow_node_execution_repository
def run(self):
ChatflowMemoryService.wait_for_sync_memory_completion(
workflow=self._workflow, conversation_id=self.conversation.id
)
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
@ -153,7 +143,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=conversation_variables,
memory_blocks=self._fetch_memory_blocks(),
)
# init graph
@ -217,31 +206,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
for event in generator:
self._handle_event(workflow_entry, event)
try:
self._check_app_memory_updates(variable_pool)
except Exception as e:
logger.exception("Failed to check app memory updates", exc_info=e)
@override
def _handle_event(self, workflow_entry: WorkflowEntry, event: Any) -> None:
super()._handle_event(workflow_entry, event)
if isinstance(event, GraphRunSucceededEvent):
workflow_outputs = event.outputs
if not workflow_outputs:
logger.warning("Chatflow output is empty.")
return
assistant_message = workflow_outputs.get("answer")
if not assistant_message:
logger.warning("Chatflow output does not contain 'answer'.")
return
if not isinstance(assistant_message, str):
logger.warning("Chatflow output 'answer' is not a string.")
return
try:
self._sync_conversation_to_chatflow_tables(assistant_message)
except Exception as e:
logger.exception("Failed to sync conversation to memory tables", exc_info=e)
def handle_input_moderation(
self,
app_record: App,
@ -439,67 +403,3 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# Return combined list
return existing_variables + new_variables
def _fetch_memory_blocks(self) -> Mapping[str, str]:
"""fetch all memory blocks for current app"""
memory_blocks_dict: MutableMapping[str, str] = {}
is_draft = self.application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
conversation_id = self.conversation.id
memory_block_specs = self._workflow.memory_blocks
# Get runtime memory values
memories = ChatflowMemoryService.get_memories_by_specs(
memory_block_specs=memory_block_specs,
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
node_id=None,
conversation_id=conversation_id,
is_draft=is_draft,
created_by=self._get_created_by(),
)
# Build memory_id -> value mapping
for memory in memories:
if memory.spec.scope == MemoryScope.APP:
# App level: use memory_id directly
memory_blocks_dict[memory.spec.id] = memory.value
else: # NODE scope
node_id = memory.node_id
if not node_id:
logger.warning("Memory block %s has no node_id, skip.", memory.spec.id)
continue
key = f"{node_id}.{memory.spec.id}"
memory_blocks_dict[key] = memory.value
return memory_blocks_dict
def _sync_conversation_to_chatflow_tables(self, assistant_message: str):
ChatflowHistoryService.save_app_message(
prompt_message=UserPromptMessage(content=(self.application_generate_entity.query)),
conversation_id=self.conversation.id,
app_id=self._workflow.app_id,
tenant_id=self._workflow.tenant_id,
)
ChatflowHistoryService.save_app_message(
prompt_message=AssistantPromptMessage(content=assistant_message),
conversation_id=self.conversation.id,
app_id=self._workflow.app_id,
tenant_id=self._workflow.tenant_id,
)
def _check_app_memory_updates(self, variable_pool: VariablePool):
is_draft = self.application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
ChatflowMemoryService.update_app_memory_if_needed(
workflow=self._workflow,
conversation_id=self.conversation.id,
variable_pool=variable_pool,
is_draft=is_draft,
created_by=self._get_created_by(),
)
def _get_created_by(self) -> MemoryCreatedBy:
if self.application_generate_entity.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
return MemoryCreatedBy(account_id=self.application_generate_entity.user_id)
else:
return MemoryCreatedBy(end_user_id=self.application_generate_entity.user_id)

View File

@ -42,13 +42,18 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.datasource_provider_service import DatasourceProviderService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)

View File

@ -1,13 +0,0 @@
from dataclasses import dataclass
@dataclass
class DocumentTask:
"""Document task entity for document indexing operations.
This class represents a document indexing task that can be queued
and processed by the document indexing system.
"""
tenant_id: str
dataset_id: str
document_ids: list[str]

View File

@ -14,7 +14,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.encryption import create_provider_encrypter
if TYPE_CHECKING:
from models.tools import MCPToolProvider
@ -272,6 +271,7 @@ class MCPProviderEntity(BaseModel):
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
"""Generic method to decrypt dictionary fields"""
from core.tools.utils.encryption import create_provider_encrypter
if not data:
return {}

View File

@ -6,7 +6,10 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class NodeJsTemplateTransformer(TemplateTransformer):
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f""" {cls._code_placeholder}
runner_script = dedent(
f"""
// declare main function
{cls._code_placeholder}
// decode and prepare input object
var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
@ -18,5 +21,6 @@ class NodeJsTemplateTransformer(TemplateTransformer):
var output_json = JSON.stringify(output_obj)
var result = `<<RESULT>>${{output_json}}<<RESULT>>`
console.log(result)
""")
"""
)
return runner_script

View File

@ -6,7 +6,9 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Python3TemplateTransformer(TemplateTransformer):
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f""" {cls._code_placeholder}
runner_script = dedent(f"""
# declare main function
{cls._code_placeholder}
import json
from base64 import b64decode
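# For illustration only: a hypothetical, stand-alone sketch of the protocol these
# runner scripts implement. User code is injected at cls._code_placeholder, inputs
# arrive as base64-encoded JSON at cls._inputs_placeholder, and the output is
# printed between <<RESULT>> markers so the executor can extract it. The names
# user_main and inputs_b64 below are made up.
import json
from base64 import b64decode, b64encode

def user_main(name: str) -> dict:
    # stands in for the injected user-defined main()
    return {"greeting": f"hello {name}"}

inputs_b64 = b64encode(json.dumps({"name": "dify"}).encode()).decode()

inputs_obj = json.loads(b64decode(inputs_b64).decode("utf-8"))
output_obj = user_main(**inputs_obj)
print(f"<<RESULT>>{json.dumps(output_obj)}<<RESULT>>")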

View File

@ -56,9 +56,6 @@ class HostingConfiguration:
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek()
self.moderation_config = self.init_moderation_config()
@ -131,7 +128,7 @@ class HostingConfiguration:
quotas: list[HostingQuota] = []
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
hosted_quota_limit = 0
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
@ -159,49 +156,18 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
def init_gemini(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_GEMINI_TRIAL_ENABLED:
hosted_quota_limit = 0
trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
if dify_config.HOSTED_GEMINI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
}
if dify_config.HOSTED_GEMINI_API_BASE:
credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_anthropic(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
@staticmethod
def init_anthropic() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
quotas.append(trial_quota)
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
paid_quota = PaidHostingQuota()
quotas.append(paid_quota)
if len(quotas) > 0:
@ -219,66 +185,6 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
def init_xai(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_XAI_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_XAI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"api_key": dify_config.HOSTED_XAI_API_KEY,
}
if dify_config.HOSTED_XAI_API_BASE:
credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_deepseek(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
}
if dify_config.HOSTED_DEEPSEEK_API_BASE:
credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
@staticmethod
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS

View File

@ -14,12 +14,10 @@ from core.llm_generator.prompts import (
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
LLM_MODIFY_PROMPT_SYSTEM,
MEMORY_UPDATE_PROMPT,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.memory.entities import MemoryBlock, MemoryBlockSpec
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
@ -30,7 +28,6 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.runtime.variable_pool import VariablePool
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import App, Message, WorkflowNodeExecutionModel
@ -565,35 +562,3 @@ class LLMGenerator:
"Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
)
return {"error": f"An unexpected error occurred: {str(e)}"}
@staticmethod
def update_memory_block(
tenant_id: str,
visible_history: Sequence[tuple[str, str]],
variable_pool: VariablePool,
memory_block: MemoryBlock,
memory_spec: MemoryBlockSpec,
) -> str:
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
provider=memory_spec.model.provider,
model=memory_spec.model.name,
model_type=ModelType.LLM,
)
formatted_history = ""
for sender, message in visible_history:
formatted_history += f"{sender}: {message}\n"
filled_instruction = variable_pool.convert_template(memory_spec.instruction).text
formatted_prompt = PromptTemplateParser(MEMORY_UPDATE_PROMPT).format(
inputs={
"formatted_history": formatted_history,
"current_value": memory_block.value,
"instruction": filled_instruction,
}
)
llm_result = model_instance.invoke_llm(
prompt_messages=[UserPromptMessage(content=formatted_prompt)],
model_parameters=memory_spec.model.completion_params,
stream=False,
)
return llm_result.message.get_text_content()

View File

@ -422,18 +422,3 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
You should edit the prompt according to the IDEAL OUTPUT."""
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
MEMORY_UPDATE_PROMPT = """
Based on the following conversation history, update the memory content:
Conversation history:
{{formatted_history}}
Current memory:
{{current_value}}
Update instruction:
{{instruction}}
Please output only the updated memory content, with no other text such as greetings:
"""

View File

@ -1,119 +0,0 @@
from __future__ import annotations
from enum import StrEnum
from typing import TYPE_CHECKING, Optional
from uuid import uuid4
from pydantic import BaseModel, Field
if TYPE_CHECKING:
from core.app.app_config.entities import ModelConfig
class MemoryScope(StrEnum):
"""Memory scope determined by node_id field"""
APP = "app" # node_id is None
NODE = "node" # node_id is not None
class MemoryTerm(StrEnum):
"""Memory term determined by conversation_id field"""
SESSION = "session" # conversation_id is not None
PERSISTENT = "persistent" # conversation_id is None
class MemoryStrategy(StrEnum):
ON_TURNS = "on_turns"
class MemoryScheduleMode(StrEnum):
SYNC = "sync"
ASYNC = "async"
class MemoryBlockSpec(BaseModel):
"""Memory block specification for workflow configuration"""
id: str = Field(
default_factory=lambda: str(uuid4()),
description="Unique identifier for the memory block",
)
name: str = Field(description="Display name of the memory block")
description: str = Field(default="", description="Description of the memory block")
template: str = Field(description="Initial template content for the memory")
instruction: str = Field(description="Instructions for updating the memory")
scope: MemoryScope = Field(description="Scope of the memory (app or node level)")
term: MemoryTerm = Field(description="Term of the memory (session or persistent)")
strategy: MemoryStrategy = Field(description="Update strategy for the memory")
update_turns: int = Field(gt=0, description="Number of turns between updates")
preserved_turns: int = Field(gt=0, description="Number of conversation turns to preserve")
schedule_mode: MemoryScheduleMode = Field(description="Synchronous or asynchronous update mode")
model: ModelConfig = Field(description="Model configuration for memory updates")
end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users")
end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users")
class MemoryCreatedBy(BaseModel):
end_user_id: str | None = None
account_id: str | None = None
class MemoryBlock(BaseModel):
"""Runtime memory block instance
Design Rules:
- app_id = None: Global memory (future feature, not implemented yet)
- app_id = str: App-specific memory
- conversation_id = None: Persistent memory (cross-conversation)
- conversation_id = str: Session memory (conversation-specific)
- node_id = None: App-level scope
- node_id = str: Node-level scope
These rules implicitly determine scope and term without redundant storage.
"""
spec: MemoryBlockSpec
tenant_id: str
value: str
app_id: str
conversation_id: Optional[str] = None
node_id: Optional[str] = None
edited_by_user: bool = False
created_by: MemoryCreatedBy
version: int = Field(description="Memory block version number")
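# For illustration only: a tiny, self-contained sketch of how scope and term fall
# out of node_id / conversation_id, following the design rules in the docstring
# above. This helper is hypothetical and not part of the module being removed.
def derive_scope_and_term(node_id: str | None, conversation_id: str | None) -> tuple[str, str]:
    scope = MemoryScope.NODE if node_id else MemoryScope.APP
    term = MemoryTerm.SESSION if conversation_id else MemoryTerm.PERSISTENT
    return (scope, term)

# e.g. a node-level memory tied to a single conversation:
assert derive_scope_and_term("llm-node-1", "conv-1") == (MemoryScope.NODE, MemoryTerm.SESSION)
# and an app-level memory that persists across conversations:
assert derive_scope_and_term(None, None) == (MemoryScope.APP, MemoryTerm.PERSISTENT)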
class MemoryValueData(BaseModel):
value: str
edited_by_user: bool = False
class ChatflowConversationMetadata(BaseModel):
"""Metadata for chatflow conversation with visible message count"""
type: str = "mutable_visible_window"
visible_count: int = Field(gt=0, description="Number of visible messages to keep")
class MemoryBlockWithConversation(MemoryBlock):
"""MemoryBlock with optional conversation metadata for session memories"""
conversation_metadata: ChatflowConversationMetadata = Field(
description="Conversation metadata, only present for session memories"
)
@classmethod
def from_memory_block(
cls,
memory_block: MemoryBlock,
conversation_metadata: ChatflowConversationMetadata
) -> MemoryBlockWithConversation:
"""Create MemoryBlockWithConversation from MemoryBlock"""
return cls(
spec=memory_block.spec,
tenant_id=memory_block.tenant_id,
value=memory_block.value,
app_id=memory_block.app_id,
conversation_id=memory_block.conversation_id,
node_id=memory_block.node_id,
edited_by_user=memory_block.edited_by_user,
created_by=memory_block.created_by,
version=memory_block.version,
conversation_metadata=conversation_metadata
)

View File

@ -1,6 +0,0 @@
class MemorySyncTimeoutError(Exception):
def __init__(self, app_id: str, conversation_id: str):
self.app_id = app_id
self.conversation_id = conversation_id
self.message = "Memory synchronization timeout after 50 seconds"
super().__init__(self.message)

View File

@ -45,12 +45,6 @@ class MemoryConfig(BaseModel):
enabled: bool
size: int | None = None
mode: Literal["linear", "block"] | None = "linear"
block_id: list[str] | None = None
role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
@property
def is_block_mode(self) -> bool:
return self.mode == "block" and bool(self.block_id)

View File

@ -618,18 +618,18 @@ class ProviderManager:
)
for quota in configuration.quotas:
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
if quota.quota_type == ProviderQuotaType.TRIAL:
# Init trial provider records if not exists
if quota.quota_type not in provider_quota_to_provider_record_dict:
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try:
# FIXME: ignore the type error; only TrialHostingQuota has a quota limit, so this logic needs to change
new_provider_record = Provider(
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=quota.quota_type,
quota_limit=0, # type: ignore
provider_type=ProviderType.SYSTEM,
quota_type=ProviderQuotaType.TRIAL,
quota_limit=quota.quota_limit, # type: ignore
quota_used=0,
is_valid=True,
)
@ -641,8 +641,8 @@ class ProviderManager:
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == quota.quota_type,
Provider.provider_type == ProviderType.SYSTEM,
Provider.quota_type == ProviderQuotaType.TRIAL,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
@ -652,7 +652,7 @@ class ProviderManager:
existed_provider_record.is_valid = True
db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
return provider_name_to_provider_records_dict
@ -912,22 +912,6 @@ class ProviderManager:
provider_record
)
quota_configurations = []
if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService
trial_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.TRIAL.value,
)
paid_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.PAID.value,
)
else:
trial_pool = None
paid_pool = None
for provider_quota in provider_hosting_configuration.quotas:
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
if provider_quota.quota_type == ProviderQuotaType.FREE:
@ -948,36 +932,16 @@ class ProviderManager:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trial_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=trial_pool.quota_used,
quota_limit=trial_pool.quota_limit,
is_valid=trial_pool.quota_limit > trial_pool.quota_used or trial_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
else:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configurations.append(quota_configuration)

View File

@ -289,8 +289,7 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
# nr: person name, ns: place name, nt: organization name
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名ns: 地名nt: 机构名
current_entity += word
else:
if current_entity:

View File

@ -213,7 +213,7 @@ class VastbaseVector(BaseVector):
with self._get_cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# Vastbase supports vector dimensions in range [1, 16000]
# Vastbase supports vector dimensions in the range [1, 16000]
if dimension <= 16000:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)

View File

@ -39,13 +39,11 @@ class WeaviateConfig(BaseModel):
Attributes:
endpoint: Weaviate server endpoint URL
grpc_endpoint: Optional Weaviate gRPC server endpoint URL
api_key: Optional API key for authentication
batch_size: Number of objects to batch per insert operation
"""
endpoint: str
grpc_endpoint: str | None = None
api_key: str | None = None
batch_size: int = 100
@ -90,22 +88,9 @@ class WeaviateVector(BaseVector):
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
# Parse gRPC configuration
if config.grpc_endpoint:
# URLs without a scheme won't be parsed correctly in some Python versions,
# see https://bugs.python.org/issue27657
grpc_endpoint_with_scheme = (
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
)
grpc_p = urlparse(grpc_endpoint_with_scheme)
grpc_host = grpc_p.hostname or "localhost"
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
grpc_secure = grpc_p.scheme == "grpcs"
else:
# Infer from HTTP endpoint as fallback
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
client = weaviate.connect_to_custom(
http_host=host,
@ -447,7 +432,6 @@ class WeaviateVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT or "",
grpc_endpoint=dify_config.WEAVIATE_GRPC_ENDPOINT or "",
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),
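# For illustration only: why the gRPC endpoint gets a scheme prepended before
# urlparse() in the connection logic above. The endpoint value is hypothetical.
from urllib.parse import urlparse

bare = urlparse("weaviate-grpc.internal:50051")
# Without "://" nothing lands in netloc, so hostname/port are unusable
# (and scheme detection itself varies by Python version, see bpo-27657).
print(bare.hostname, bare.port)  # None None

with_scheme = urlparse("grpc://weaviate-grpc.internal:50051")
print(with_scheme.hostname, with_scheme.port)  # weaviate-grpc.internal 50051
print(with_scheme.scheme == "grpcs")  # False, so grpc_secure stays False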

View File

@ -1,89 +0,0 @@
import json
from dataclasses import dataclass
from typing import Any
from extensions.ext_redis import redis_client
TASK_WRAPPER_PREFIX = "__WRAPPER__:"
@dataclass
class TaskWrapper:
data: Any
def serialize(self) -> str:
return json.dumps(self.data, ensure_ascii=False)
@classmethod
def deserialize(cls, serialized_data: str) -> 'TaskWrapper':
data = json.loads(serialized_data)
return cls(data)
class TenantSelfTaskQueue:
"""
Simple queue for tenant-scoped tasks, used to isolate each tenant's task processing.
It uses a Redis list to store tasks and a Redis key as the task-waiting flag.
Supports any task that can be serialized as JSON.
"""
DEFAULT_TASK_TTL = 60 * 60
def __init__(self, tenant_id: str, unique_key: str):
self.tenant_id = tenant_id
self.unique_key = unique_key
self.queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
self.task_key = f"tenant_{unique_key}_task:{tenant_id}"
def get_task_key(self):
return redis_client.get(self.task_key)
def set_task_waiting_time(self, ttl: int | None = None):
ttl = ttl or self.DEFAULT_TASK_TTL
redis_client.setex(self.task_key, ttl, 1)
def delete_task_key(self):
redis_client.delete(self.task_key)
def push_tasks(self, tasks: list):
serialized_tasks = []
for task in tasks:
# Store plain strings directly, keeping full compatibility with pipeline scenarios
if isinstance(task, str):
serialized_tasks.append(task)
else:
# Use TaskWrapper to do JSON serialization, add prefix for identification
wrapper = TaskWrapper(task)
serialized_data = wrapper.serialize()
serialized_tasks.append(f"{TASK_WRAPPER_PREFIX}{serialized_data}")
redis_client.lpush(self.queue, *serialized_tasks)
def pull_tasks(self, count: int = 1) -> list:
if count <= 0:
return []
tasks = []
for _ in range(count):
serialized_task = redis_client.rpop(self.queue)
if not serialized_task:
break
if isinstance(serialized_task, bytes):
serialized_task = serialized_task.decode('utf-8')
# Check whether the task was wrapped with TaskWrapper
if serialized_task.startswith(TASK_WRAPPER_PREFIX):
try:
wrapper_data = serialized_task[len(TASK_WRAPPER_PREFIX):]
wrapper = TaskWrapper.deserialize(wrapper_data)
tasks.append(wrapper.data)
except (json.JSONDecodeError, TypeError, ValueError):
tasks.append(serialized_task)
else:
tasks.append(serialized_task)
return tasks
def get_next_task(self):
tasks = self.pull_tasks(1)
return tasks[0] if tasks else None
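# For illustration only: a minimal usage sketch of the queue above, assuming a
# configured redis_client. The tenant id, unique_key, and payloads are made up.
queue = TenantSelfTaskQueue(tenant_id="tenant-1", unique_key="pipeline")

# Plain strings pass through unchanged; other JSON-serializable values are wrapped
# with TASK_WRAPPER_PREFIX so pull_tasks() knows to decode them on the way out.
queue.push_tasks(["document-42", {"dataset_id": "ds-1", "document_ids": ["doc-1"]}])
queue.set_task_waiting_time()  # mark that a task is pending (default TTL: 1 hour)

while (task := queue.get_next_task()) is not None:
    print(task)  # "document-42" first, then the decoded dict
queue.delete_task_key()  # clear the waiting flag once processing is done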

View File

@ -210,13 +210,12 @@ class Tool(ABC):
meta=meta,
)
def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
def create_json_message(self, object: dict) -> ToolInvokeMessage:
"""
create a json message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.JSON,
message=ToolInvokeMessage.JsonMessage(json_object=object, suppress_output=suppress_output),
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
)
def create_variable_message(

View File

@ -129,7 +129,6 @@ class ToolInvokeMessage(BaseModel):
class JsonMessage(BaseModel):
json_object: dict
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
class BlobMessage(BaseModel):
blob: bytes

View File

@ -245,9 +245,6 @@ class ToolEngine:
+ "you do not need to create it, just tell the user to check it now."
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
json_message = cast(ToolInvokeMessage.JsonMessage, response.message)
if json_message.suppress_output:
continue
json_parts.append(
json.dumps(
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),

View File

@ -117,7 +117,7 @@ class WorkflowTool(Tool):
self._latest_usage = self._derive_usage_from_result(data)
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs, suppress_output=True)
yield self.create_json_message(outputs)
@property
def latest_usage(self) -> LLMUsage:

View File

@ -203,48 +203,6 @@ class ArrayFileSegment(ArraySegment):
return ""
class VersionedMemoryValue(BaseModel):
current_value: str = None # type: ignore
versions: Mapping[str, str] = {}
model_config = ConfigDict(frozen=True)
def add_version(
self,
new_value: str,
version_name: str | None = None
) -> "VersionedMemoryValue":
if version_name is None:
version_name = str(len(self.versions) + 1)
if version_name in self.versions:
raise ValueError(f"Version '{version_name}' already exists.")
self.current_value = new_value
return VersionedMemoryValue(
current_value=new_value,
versions={
version_name: new_value,
**self.versions,
}
)
class VersionedMemorySegment(Segment):
value_type: SegmentType = SegmentType.VERSIONED_MEMORY
value: VersionedMemoryValue = None # type: ignore
@property
def text(self) -> str:
return self.value.current_value
@property
def log(self) -> str:
return self.value.current_value
@property
def markdown(self) -> str:
return self.value.current_value
class ArrayBooleanSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
value: Sequence[bool]
@ -290,7 +248,6 @@ SegmentUnion: TypeAlias = Annotated[
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[VersionedMemorySegment, Tag(SegmentType.VERSIONED_MEMORY)]
),
Discriminator(get_segment_discriminator),
]

View File

@ -41,8 +41,6 @@ class SegmentType(StrEnum):
ARRAY_FILE = "array[file]"
ARRAY_BOOLEAN = "array[boolean]"
VERSIONED_MEMORY = "versioned_memory"
NONE = "none"
GROUP = "group"

View File

@ -22,7 +22,6 @@ from .segments import (
ObjectSegment,
Segment,
StringSegment,
VersionedMemorySegment,
get_segment_discriminator,
)
from .types import SegmentType
@ -107,10 +106,6 @@ class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
class VersionedMemoryVariable(VersionedMemorySegment, Variable):
pass
class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
pass
@ -166,7 +161,6 @@ VariableUnion: TypeAlias = Annotated[
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
| Annotated[VersionedMemoryVariable, Tag(SegmentType.VERSIONED_MEMORY)]
),
Discriminator(get_segment_discriminator),
]

View File

@ -1,5 +1,4 @@
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
MEMORY_BLOCK_VARIABLE_NODE_ID = "memory_block"
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.entities.provider_entities import QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@ -136,36 +136,21 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
session.execute(stmt)
session.commit()
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()
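The refactored deduction above leans on a single conditional UPDATE, so the quota check (quota_limit > quota_used) and the increment happen atomically inside the database rather than as a read-then-write in Python. A self-contained sketch of the same pattern against a toy table (names and schema here are illustrative, not Dify's):
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
class Base(DeclarativeBase):
    pass
class Quota(Base):
    # Toy stand-in for a provider quota row.
    __tablename__ = "quotas"
    id: Mapped[int] = mapped_column(primary_key=True)
    quota_limit: Mapped[int] = mapped_column(default=2)
    quota_used: Mapped[int] = mapped_column(default=0)
engine = sa.create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Quota(id=1))
    session.commit()
    for _ in range(3):
        # A row that is already exhausted simply matches nothing, so there is no
        # race between checking the remaining quota and consuming it.
        stmt = (
            sa.update(Quota)
            .where(Quota.id == 1, Quota.quota_limit > Quota.quota_used)
            .values(quota_used=Quota.quota_used + 1)
        )
        result = session.execute(stmt)
        session.commit()
        print("rows updated:", result.rowcount)  # 1, 1, then 0 once the limit is hit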

View File

@ -7,15 +7,11 @@ import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.entities import MemoryCreatedBy, MemoryScope
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
@ -77,8 +73,6 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, Variabl
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from models import UserFrom, Workflow
from models.engine import db
from . import llm_utils
from .entities import (
@ -323,11 +317,6 @@ class LLMNode(Node):
if self._file_outputs:
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
try:
self._handle_chatflow_memory(result_text, variable_pool)
except Exception as e:
logger.warning("Memory orchestration failed for node %s: %s", self.node_id, str(e))
# Send final chunk event to indicate streaming is complete
yield StreamChunkEvent(
selector=[self._node_id, "text"],
@ -1239,77 +1228,6 @@ class LLMNode(Node):
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
def _handle_chatflow_memory(self, llm_output: str, variable_pool: VariablePool):
if not self._node_data.memory or self._node_data.memory.mode != "block":
return
conversation_id_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.CONVERSATION_ID))
if not conversation_id_segment:
raise ValueError("Conversation ID not found in variable pool.")
conversation_id = conversation_id_segment.text
user_query_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
if not user_query_segment:
raise ValueError("User query not found in variable pool.")
user_query = user_query_segment.text
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from services.chatflow_history_service import ChatflowHistoryService
ChatflowHistoryService.save_node_message(
prompt_message=(UserPromptMessage(content=user_query)),
node_id=self.node_id,
conversation_id=conversation_id,
app_id=self.app_id,
tenant_id=self.tenant_id,
)
ChatflowHistoryService.save_node_message(
prompt_message=(AssistantPromptMessage(content=llm_output)),
node_id=self.node_id,
conversation_id=conversation_id,
app_id=self.app_id,
tenant_id=self.tenant_id,
)
memory_config = self._node_data.memory
if not memory_config:
return
block_ids = memory_config.block_id
if not block_ids:
return
# FIXME: This is a dirty workaround and may resolve the wrong workflow version
with Session(db.engine) as session:
stmt = select(Workflow).where(Workflow.tenant_id == self.tenant_id, Workflow.app_id == self.app_id)
workflow = session.scalars(stmt).first()
if not workflow:
raise ValueError("Workflow not found.")
memory_blocks = workflow.memory_blocks
for block_id in block_ids:
memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None)
if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
is_draft = self.invoke_from == InvokeFrom.DEBUGGER
from services.chatflow_memory_service import ChatflowMemoryService
ChatflowMemoryService.update_node_memory_if_needed(
tenant_id=self.tenant_id,
app_id=self.app_id,
node_id=self.id,
conversation_id=conversation_id,
memory_block_spec=memory_block_spec,
variable_pool=variable_pool,
is_draft=is_draft,
created_by=self._get_user_from_context(),
)
def _get_user_from_context(self) -> MemoryCreatedBy:
if self.user_from == UserFrom.ACCOUNT:
return MemoryCreatedBy(account_id=self.user_id)
else:
return MemoryCreatedBy(end_user_id=self.user_id)
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole

View File

@ -9,12 +9,11 @@ from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment, VersionedMemoryValue
from core.variables.variables import RAGPipelineVariableInput, VariableUnion, VersionedMemoryVariable
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
MEMORY_BLOCK_VARIABLE_NODE_ID,
RAG_PIPELINE_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
@ -57,10 +56,6 @@ class VariablePool(BaseModel):
description="RAG pipeline variables.",
default_factory=list,
)
memory_blocks: Mapping[str, str] = Field(
description="Memory blocks.",
default_factory=dict,
)
def model_post_init(self, context: Any, /):
# Create a mapping from field names to SystemVariableKey enum values
@ -81,18 +76,6 @@ class VariablePool(BaseModel):
rag_pipeline_variables_map[node_id][key] = value
for key, value in rag_pipeline_variables_map.items():
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)
# Add memory blocks to the variable pool
for memory_id, memory_value in self.memory_blocks.items():
self.add(
[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_id],
VersionedMemoryVariable(
value=VersionedMemoryValue(
current_value=memory_value,
versions={"1": memory_value},
),
name=memory_id,
)
)
def add(self, selector: Sequence[str], value: Any, /):
"""

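A rough sketch of the selector convention used by the pool's add/get API (this assumes VariablePool() can be constructed with its default field values; in the real code the workflow runtime builds it with system variables already populated, and the node id "llm_node_1" and the values below are made up):
pool = VariablePool()
# Selectors are (node_id, variable_name) pairs; SYSTEM_VARIABLE_NODE_ID ("sys")
# namespaces the built-in group, while ordinary node ids namespace node outputs.
pool.add((SYSTEM_VARIABLE_NODE_ID, "query"), "What is RAG?")
pool.add(("llm_node_1", "text"), "Retrieval-augmented generation is ...")
segment = pool.get(("llm_node_1", "text"))
if segment is not None:
    print(segment.text)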
View File

@ -39,16 +39,14 @@ elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
else
if [[ "${DEBUG}" == "true" ]]; then
export HOST=${DIFY_BIND_ADDRESS:-0.0.0.0}
export PORT=${DIFY_PORT:-5001}
exec python -m app
exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug
else
exec gunicorn \
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \
app:socketio_app
app:app
fi
fi

View File

@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, SystemConfiguration
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
@ -134,38 +134,22 @@ def handle(sender: Message, **kwargs):
system_configuration=system_configuration,
model_name=model_config.model,
)
if used_quota is not None:
if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="trial",
)
elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
# Execute all updates
start_time = time_module.perf_counter()

View File

@ -31,13 +31,7 @@ def init_app(app: DifyApp):
web_bp,
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=[
"Content-Type",
"Authorization",
HEADER_NAME_APP_CODE,
HEADER_NAME_CSRF_TOKEN,
HEADER_NAME_PASSPORT,
],
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
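For reference, the headers removed from the inline list suggest what AUTHENTICATED_HEADERS consolidates; a hedged reconstruction of such a constant (the actual definition lives elsewhere in the codebase and may contain different or additional entries):
# Hypothetical reconstruction based on the previously inlined header list.
AUTHENTICATED_HEADERS = (
    "Content-Type",
    "Authorization",
    HEADER_NAME_APP_CODE,
    HEADER_NAME_CSRF_TOKEN,
    HEADER_NAME_PASSPORT,
)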

View File

@ -1,3 +0,0 @@
import socketio
sio = socketio.Server(async_mode="gevent", cors_allowed_origins="*")

View File

@ -21,8 +21,6 @@ from core.variables.segments import (
ObjectSegment,
Segment,
StringSegment,
VersionedMemorySegment,
VersionedMemoryValue,
)
from core.variables.types import SegmentType
from core.variables.variables import (
@ -41,7 +39,6 @@ from core.variables.variables import (
SecretVariable,
StringVariable,
Variable,
VersionedMemoryVariable,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
@ -72,7 +69,6 @@ SEGMENT_TO_VARIABLE_MAP = {
NoneSegment: NoneVariable,
ObjectSegment: ObjectVariable,
StringSegment: StringVariable,
VersionedMemorySegment: VersionedMemoryVariable
}
@ -197,7 +193,6 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.FILE: FileSegment,
SegmentType.BOOLEAN: BooleanSegment,
SegmentType.OBJECT: ObjectSegment,
SegmentType.VERSIONED_MEMORY: VersionedMemorySegment,
# Array types
SegmentType.ARRAY_ANY: ArrayAnySegment,
SegmentType.ARRAY_STRING: ArrayStringSegment,
@ -264,12 +259,6 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
else:
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
if segment_type == SegmentType.VERSIONED_MEMORY:
return VersionedMemorySegment(
value_type=segment_type,
value=VersionedMemoryValue.model_validate(value)
)
inferred_type = SegmentType.infer_segment_type(value)
# Type compatibility checking
if inferred_type is None:

View File

@ -1,17 +0,0 @@
from flask_restx import fields
online_user_partial_fields = {
"user_id": fields.String,
"username": fields.String,
"avatar": fields.String,
"sid": fields.String,
}
workflow_online_users_fields = {
"workflow_id": fields.String,
"users": fields.List(fields.Nested(online_user_partial_fields)),
}
online_user_list_fields = {
"data": fields.List(fields.Nested(workflow_online_users_fields)),
}

View File

@ -1,96 +0,0 @@
from flask_restx import fields
from libs.helper import AvatarUrlField, TimestampField
# basic account fields for comments
account_fields = {
"id": fields.String,
"name": fields.String,
"email": fields.String,
"avatar_url": AvatarUrlField,
}
# Comment mention fields
workflow_comment_mention_fields = {
"mentioned_user_id": fields.String,
"mentioned_user_account": fields.Nested(account_fields, allow_null=True),
"reply_id": fields.String,
}
# Comment reply fields
workflow_comment_reply_fields = {
"id": fields.String,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
}
# Basic comment fields (for list views)
workflow_comment_basic_fields = {
"id": fields.String,
"position_x": fields.Float,
"position_y": fields.Float,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
"updated_at": TimestampField,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
"reply_count": fields.Integer,
"mention_count": fields.Integer,
"participants": fields.List(fields.Nested(account_fields)),
}
# Detailed comment fields (for single comment view)
workflow_comment_detail_fields = {
"id": fields.String,
"position_x": fields.Float,
"position_y": fields.Float,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
"updated_at": TimestampField,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
"replies": fields.List(fields.Nested(workflow_comment_reply_fields)),
"mentions": fields.List(fields.Nested(workflow_comment_mention_fields)),
}
# Comment creation response fields (simplified)
workflow_comment_create_fields = {
"id": fields.String,
"created_at": TimestampField,
}
# Comment update response fields (simplified)
workflow_comment_update_fields = {
"id": fields.String,
"updated_at": TimestampField,
}
# Comment resolve response fields
workflow_comment_resolve_fields = {
"id": fields.String,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
}
# Reply creation response fields (simplified)
workflow_comment_reply_create_fields = {
"id": fields.String,
"created_at": TimestampField,
}
# Reply update response fields
workflow_comment_reply_update_fields = {
"id": fields.String,
"updated_at": TimestampField,
}

View File

@ -1,90 +0,0 @@
"""Add workflow comments table
Revision ID: 227822d22895
Revises: 68519ad5cd18
Create Date: 2025-08-22 17:26:15.255980
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '227822d22895'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('workflow_comments',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('position_x', sa.Float(), nullable=False),
sa.Column('position_y', sa.Float(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('resolved', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('resolved_at', sa.DateTime(), nullable=True),
sa.Column('resolved_by', models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint('id', name='workflow_comments_pkey')
)
with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
batch_op.create_index('workflow_comments_app_idx', ['tenant_id', 'app_id'], unique=False)
batch_op.create_index('workflow_comments_created_at_idx', ['created_at'], unique=False)
op.create_table('workflow_comment_replies',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('comment_id', models.types.StringUUID(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_replies_comment_id_fkey'), ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name='workflow_comment_replies_pkey')
)
with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
batch_op.create_index('comment_replies_comment_idx', ['comment_id'], unique=False)
batch_op.create_index('comment_replies_created_at_idx', ['created_at'], unique=False)
op.create_table('workflow_comment_mentions',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('comment_id', models.types.StringUUID(), nullable=False),
sa.Column('reply_id', models.types.StringUUID(), nullable=True),
sa.Column('mentioned_user_id', models.types.StringUUID(), nullable=False),
sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_mentions_comment_id_fkey'), ondelete='CASCADE'),
sa.ForeignKeyConstraint(['reply_id'], ['workflow_comment_replies.id'], name=op.f('workflow_comment_mentions_reply_id_fkey'), ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name='workflow_comment_mentions_pkey')
)
with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
batch_op.create_index('comment_mentions_comment_idx', ['comment_id'], unique=False)
batch_op.create_index('comment_mentions_reply_idx', ['reply_id'], unique=False)
batch_op.create_index('comment_mentions_user_idx', ['mentioned_user_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
batch_op.drop_index('comment_mentions_user_idx')
batch_op.drop_index('comment_mentions_reply_idx')
batch_op.drop_index('comment_mentions_comment_idx')
op.drop_table('workflow_comment_mentions')
with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
batch_op.drop_index('comment_replies_created_at_idx')
batch_op.drop_index('comment_replies_comment_idx')
op.drop_table('workflow_comment_replies')
with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
batch_op.drop_index('workflow_comments_created_at_idx')
batch_op.drop_index('workflow_comments_app_idx')
op.drop_table('workflow_comments')
# ### end Alembic commands ###

View File

@ -1,104 +0,0 @@
"""add table credit pool
Revision ID: 58a70d22fdbd
Revises: 68519ad5cd18
Create Date: 2025-09-25 15:20:40.367078
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '58a70d22fdbd'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tenant_credit_pools',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('pool_type', sa.String(length=40), nullable=False),
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
sa.Column('quota_used', sa.BigInteger(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
)
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
# Data migration: Move quota data from providers to tenant_credit_pools
migrate_quota_data()
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.drop_index('tenant_credit_pool_tenant_id_idx')
batch_op.drop_index('tenant_credit_pool_pool_type_idx')
op.drop_table('tenant_credit_pools')
# ### end Alembic commands ###
def migrate_quota_data():
"""
Migrate quota data from the providers table to the tenant_credit_pools table
for providers with quota_type='trial' or 'paid', provider_name='openai', and provider_type='system'.
"""
# Create connection
bind = op.get_bind()
# Define quota type mappings
quota_type_mappings = ['trial', 'paid']
for quota_type in quota_type_mappings:
# Query providers that match the criteria
select_sql = sa.text("""
SELECT tenant_id, quota_limit, quota_used
FROM providers
WHERE quota_type = :quota_type
AND provider_name = 'openai'
AND provider_type = 'system'
AND quota_limit IS NOT NULL
""")
result = bind.execute(select_sql, {"quota_type": quota_type})
providers_data = result.fetchall()
# Insert data into tenant_credit_pools
for provider_data in providers_data:
tenant_id, quota_limit, quota_used = provider_data
# Check if credit pool already exists for this tenant and pool type
check_sql = sa.text("""
SELECT COUNT(*)
FROM tenant_credit_pools
WHERE tenant_id = :tenant_id AND pool_type = :pool_type
""")
existing_count = bind.execute(check_sql, {
"tenant_id": tenant_id,
"pool_type": quota_type
}).scalar()
if existing_count == 0:
# Insert new credit pool record
insert_sql = sa.text("""
INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at)
VALUES (:tenant_id, :pool_type, :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
""")
bind.execute(insert_sql, {
"tenant_id": tenant_id,
"pool_type": quota_type,
"quota_limit": quota_limit or 0,
"quota_used": quota_used or 0
})

View File

@ -1,92 +0,0 @@
"""add table explore banner and trial
Revision ID: 3993fd9e9c2f
Revises: 68519ad5cd18
Create Date: 2025-10-11 14:42:01.954865
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '3993fd9e9c2f'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('account_trial_app_records',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('account_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('count', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
)
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
op.create_table('exporle_banners',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('content', sa.JSON(), nullable=False),
sa.Column('link', sa.String(length=255), nullable=False),
sa.Column('sort', sa.Integer(), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
)
op.create_table('trial_apps',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('trial_limit', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
sa.UniqueConstraint('app_id', name='unique_trail_app_id')
)
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.alter_column('avatar_url',
existing_type=sa.TEXT(),
type_=sa.String(length=255),
existing_nullable=True)
with op.batch_alter_table('providers', schema=None) as batch_op:
batch_op.drop_column('credential_status')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('credential_status', sa.VARCHAR(length=20), server_default=sa.text("'active'::character varying"), autoincrement=False, nullable=True))
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.alter_column('avatar_url',
existing_type=sa.String(length=255),
type_=sa.TEXT(),
existing_nullable=True)
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
batch_op.drop_index('trial_app_tenant_id_idx')
batch_op.drop_index('trial_app_app_id_idx')
op.drop_table('trial_apps')
op.drop_table('exporle_banners')
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
batch_op.drop_index('account_trial_app_record_app_id_idx')
batch_op.drop_index('account_trial_app_record_account_id_idx')
op.drop_table('account_trial_app_records')
# ### end Alembic commands ###

View File

@ -1,104 +0,0 @@
"""add_chatflow_memory_tables
Revision ID: d00b2b40ea3e
Revises: 68519ad5cd18
Create Date: 2025-10-11 15:29:20.244675
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd00b2b40ea3e'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('chatflow_conversations',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('node_id', sa.Text(), nullable=True),
sa.Column('original_conversation_id', models.types.StringUUID(), nullable=True),
sa.Column('conversation_metadata', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='chatflow_conversations_pkey')
)
with op.batch_alter_table('chatflow_conversations', schema=None) as batch_op:
batch_op.create_index('chatflow_conversations_original_conversation_id_idx', ['tenant_id', 'app_id', 'node_id', 'original_conversation_id'], unique=False)
op.create_table('chatflow_memory_variables',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=True),
sa.Column('conversation_id', models.types.StringUUID(), nullable=True),
sa.Column('node_id', sa.Text(), nullable=True),
sa.Column('memory_id', sa.Text(), nullable=False),
sa.Column('value', sa.Text(), nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.Column('scope', sa.String(length=10), nullable=False),
sa.Column('term', sa.String(length=20), nullable=False),
sa.Column('version', sa.Integer(), nullable=False),
sa.Column('created_by_role', sa.String(length=20), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='chatflow_memory_variables_pkey')
)
with op.batch_alter_table('chatflow_memory_variables', schema=None) as batch_op:
batch_op.create_index('chatflow_memory_variables_memory_id_idx', ['tenant_id', 'app_id', 'node_id', 'memory_id'], unique=False)
op.create_table('chatflow_messages',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('conversation_id', models.types.StringUUID(), nullable=False),
sa.Column('index', sa.Integer(), nullable=False),
sa.Column('version', sa.Integer(), nullable=False),
sa.Column('data', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='chatflow_messages_pkey')
)
with op.batch_alter_table('chatflow_messages', schema=None) as batch_op:
batch_op.create_index('chatflow_messages_version_idx', ['conversation_id', 'index', 'version'], unique=False)
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.alter_column('avatar_url',
existing_type=sa.TEXT(),
type_=sa.String(length=255),
existing_nullable=True)
with op.batch_alter_table('providers', schema=None) as batch_op:
batch_op.drop_column('credential_status')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('credential_status', sa.VARCHAR(length=20), server_default=sa.text("'active'::character varying"), autoincrement=False, nullable=True))
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.alter_column('avatar_url',
existing_type=sa.String(length=255),
type_=sa.TEXT(),
existing_nullable=True)
with op.batch_alter_table('chatflow_messages', schema=None) as batch_op:
batch_op.drop_index('chatflow_messages_version_idx')
op.drop_table('chatflow_messages')
with op.batch_alter_table('chatflow_memory_variables', schema=None) as batch_op:
batch_op.drop_index('chatflow_memory_variables_memory_id_idx')
op.drop_table('chatflow_memory_variables')
with op.batch_alter_table('chatflow_conversations', schema=None) as batch_op:
batch_op.drop_index('chatflow_conversations_original_conversation_id_idx')
op.drop_table('chatflow_conversations')
# ### end Alembic commands ###

View File

@ -22,55 +22,6 @@ def upgrade():
batch_op.add_column(sa.Column('app_mode', sa.String(length=255), nullable=True))
batch_op.create_index('message_app_mode_idx', ['app_mode'], unique=False)
conn = op.get_bind()
# Strategy: Update in batches to minimize lock time
# For large tables (millions of rows), this prevents long-running transactions
batch_size = 10000
print("Starting backfill of app_mode from conversations...")
# Use a more efficient UPDATE with JOIN
# This query updates messages.app_mode from conversations.mode
# Using string formatting for LIMIT since it's a constant
update_query = f"""
UPDATE messages m
SET app_mode = c.mode
FROM conversations c
WHERE m.conversation_id = c.id
AND m.app_mode IS NULL
AND m.id IN (
SELECT id FROM messages
WHERE app_mode IS NULL
LIMIT {batch_size}
)
"""
# Execute batched updates
total_updated = 0
iteration = 0
while True:
iteration += 1
result = conn.execute(sa.text(update_query))
# Check if result is None or has no rowcount
if result is None:
print("Warning: Query returned None, stopping backfill")
break
rows_updated = result.rowcount if hasattr(result, 'rowcount') else 0
total_updated += rows_updated
if rows_updated == 0:
break
print(f"Iteration {iteration}: Updated {rows_updated} messages (total: {total_updated})")
# For very large tables, add a small delay to reduce load
# Uncomment if needed: import time; time.sleep(0.1)
print(f"Backfill completed. Total messages updated: {total_updated}")
# ### end Alembic commands ###

View File

@ -9,12 +9,6 @@ from .account import (
TenantStatus,
)
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from .chatflow_memory import ChatflowConversation, ChatflowMemoryVariable, ChatflowMessage
from .comment import (
WorkflowComment,
WorkflowCommentMention,
WorkflowCommentReply,
)
from .dataset import (
AppDatasetJoin,
Dataset,
@ -34,7 +28,6 @@ from .dataset import (
)
from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
from .model import (
AccountTrialAppRecord,
ApiRequest,
ApiToken,
App,
@ -47,7 +40,6 @@ from .model import (
DatasetRetrieverResource,
DifySetup,
EndUser,
ExporleBanner,
IconType,
InstalledApp,
Message,
@ -61,9 +53,7 @@ from .model import (
Site,
Tag,
TagBinding,
TenantCreditPool,
TraceAppConfig,
TrialApp,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
@ -109,7 +99,6 @@ __all__ = [
"Account",
"AccountIntegrate",
"AccountStatus",
"AccountTrialAppRecord",
"ApiRequest",
"ApiToken",
"ApiToolProvider",
@ -123,9 +112,6 @@ __all__ = [
"BuiltinToolProvider",
"CeleryTask",
"CeleryTaskSet",
"ChatflowConversation",
"ChatflowMemoryVariable",
"ChatflowMessage",
"Conversation",
"ConversationVariable",
"CreatorUserRole",
@ -146,7 +132,6 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
"ExporleBanner",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"IconType",
@ -175,7 +160,6 @@ __all__ = [
"Tenant",
"TenantAccountJoin",
"TenantAccountRole",
"TenantCreditPool",
"TenantDefaultModel",
"TenantPreferredModelProvider",
"TenantStatus",
@ -185,16 +169,12 @@ __all__ = [
"ToolLabelBinding",
"ToolModelInvoke",
"TraceAppConfig",
"TrialApp",
"UploadFile",
"UserFrom",
"Whitelist",
"Workflow",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
"WorkflowComment",
"WorkflowCommentMention",
"WorkflowCommentReply",
"WorkflowNodeExecutionModel",
"WorkflowNodeExecutionOffload",
"WorkflowNodeExecutionTriggeredFrom",

View File

@ -1,76 +0,0 @@
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy import DateTime, func
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .types import StringUUID
class ChatflowMemoryVariable(Base):
__tablename__ = "chatflow_memory_variables"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="chatflow_memory_variables_pkey"),
sa.Index("chatflow_memory_variables_memory_id_idx", "tenant_id", "app_id", "node_id", "memory_id"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
node_id: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
memory_id: Mapped[str] = mapped_column(sa.Text, nullable=False)
value: Mapped[str] = mapped_column(sa.Text, nullable=False)
name: Mapped[str] = mapped_column(sa.Text, nullable=False)
scope: Mapped[str] = mapped_column(sa.String(10), nullable=False) # 'app' or 'node'
term: Mapped[str] = mapped_column(sa.String(20), nullable=False) # 'session' or 'persistent'
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=1)
created_by_role: Mapped[str] = mapped_column(sa.String(20))  # 'end_user' or 'account'
created_by: Mapped[str] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class ChatflowConversation(Base):
__tablename__ = "chatflow_conversations"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="chatflow_conversations_pkey"),
sa.Index(
"chatflow_conversations_original_conversation_id_idx",
"tenant_id", "app_id", "node_id", "original_conversation_id"
),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
original_conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
conversation_metadata: Mapped[str] = mapped_column(sa.Text, nullable=False) # JSON
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class ChatflowMessage(Base):
__tablename__ = "chatflow_messages"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="chatflow_messages_pkey"),
sa.Index("chatflow_messages_version_idx", "conversation_id", "index", "version"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
index: Mapped[int] = mapped_column(sa.Integer, nullable=False) # This index starts from 0
version: Mapped[int] = mapped_column(sa.Integer, nullable=False)
data: Mapped[str] = mapped_column(sa.Text, nullable=False) # Serialized PromptMessage JSON
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)

View File

@ -1,189 +0,0 @@
"""Workflow comment models."""
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .account import Account
from .base import Base
from .engine import db
from .types import StringUUID
if TYPE_CHECKING:
pass
class WorkflowComment(Base):
"""Workflow comment model for canvas commenting functionality.
Comments are associated with apps rather than specific workflow versions,
since an app has only one draft workflow at a time and comments should persist
across workflow version changes.
Attributes:
id: Comment ID
tenant_id: Workspace ID
app_id: App ID (primary association, comments belong to apps)
position_x: X coordinate on canvas
position_y: Y coordinate on canvas
content: Comment content
created_by: Creator account ID
created_at: Creation time
updated_at: Last update time
resolved: Whether comment is resolved
resolved_at: Resolution time
resolved_by: Resolver account ID
"""
__tablename__ = "workflow_comments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
Index("workflow_comments_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position_x: Mapped[float] = mapped_column(db.Float)
position_y: Mapped[float] = mapped_column(db.Float)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
resolved_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
resolved_by: Mapped[Optional[str]] = mapped_column(StringUUID)
# Relationships
replies: Mapped[list["WorkflowCommentReply"]] = relationship(
"WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
)
mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
"WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
)
@property
def created_by_account(self):
"""Get creator account."""
return db.session.get(Account, self.created_by)
@property
def resolved_by_account(self):
"""Get resolver account."""
if self.resolved_by:
return db.session.get(Account, self.resolved_by)
return None
@property
def reply_count(self):
"""Get reply count."""
return len(self.replies)
@property
def mention_count(self):
"""Get mention count."""
return len(self.mentions)
@property
def participants(self):
"""Get all participants (creator + repliers + mentioned users)."""
participant_ids = set()
# Add comment creator
participant_ids.add(self.created_by)
# Add reply creators
participant_ids.update(reply.created_by for reply in self.replies)
# Add mentioned users
participant_ids.update(mention.mentioned_user_id for mention in self.mentions)
# Get account objects
participants = []
for user_id in participant_ids:
account = db.session.get(Account, user_id)
if account:
participants.append(account)
return participants
class WorkflowCommentReply(Base):
"""Workflow comment reply model.
Attributes:
id: Reply ID
comment_id: Parent comment ID
content: Reply content
created_by: Creator account ID
created_at: Creation time
"""
__tablename__ = "workflow_comment_replies"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
Index("comment_replies_comment_idx", "comment_id"),
Index("comment_replies_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
@property
def created_by_account(self):
"""Get creator account."""
return db.session.get(Account, self.created_by)
class WorkflowCommentMention(Base):
"""Workflow comment mention model.
Mentions are only for internal accounts since end users
cannot access workflow canvas and commenting features.
Attributes:
id: Mention ID
comment_id: Parent comment ID
mentioned_user_id: Mentioned account ID
"""
__tablename__ = "workflow_comment_mentions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
Index("comment_mentions_comment_idx", "comment_id"),
Index("comment_mentions_reply_idx", "reply_id"),
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
reply_id: Mapped[Optional[str]] = mapped_column(
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
)
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
@property
def mentioned_user_account(self):
"""Get mentioned account."""
return db.session.get(Account, self.mentioned_user_id)

View File

@ -23,7 +23,6 @@ class DraftVariableType(StrEnum):
NODE = "node"
SYS = "sys"
CONVERSATION = "conversation"
MEMORY_BLOCK = "memory_block"
class MessageStatus(StrEnum):

View File

@ -9,8 +9,8 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
@ -582,64 +582,6 @@ class InstalledApp(Base):
return tenant
class TrialApp(Base):
__tablename__ = "trial_apps"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
sa.Index("trial_app_app_id_idx", "app_id"),
sa.Index("trial_app_tenant_id_idx", "tenant_id"),
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
class AccountTrialAppRecord(Base):
__tablename__ = "account_trial_app_records"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
sa.Index("account_trial_app_record_account_id_idx", "account_id"),
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
account_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
count = mapped_column(sa.Integer, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def user(self) -> Account | None:
user = db.session.query(Account).where(Account.id == self.account_id).first()
return user
class ExporleBanner(Base):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
content = mapped_column(sa.JSON, nullable=False)
link = mapped_column(String(255), nullable=False)
sort = mapped_column(sa.Integer, nullable=False)
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
class OAuthProviderApp(Base):
"""
Globally shared OAuth provider app information.
@ -1836,7 +1778,7 @@ class MessageAgentThought(Base):
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
total_price = mapped_column(sa.Numeric, nullable=True)
currency = mapped_column(String, nullable=True)
currency: Mapped[str | None] = mapped_column()
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
@ -2015,29 +1957,3 @@ class TraceAppConfig(Base):
"created_at": str(self.created_at) if self.created_at else None,
"updated_at": str(self.updated_at) if self.updated_at else None,
}
class TenantCreditPool(Base):
__tablename__ = "tenant_credit_pools"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
sa.Index("tenant_credit_pool_tenant_id_idx", "tenant_id"),
sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
quota_limit = mapped_column(BigInteger, nullable=False, default=0)
quota_used = mapped_column(BigInteger, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def remaining_credits(self) -> int:
return max(0, self.quota_limit - self.quota_used)
def has_sufficient_credits(self, required_credits: int) -> bool:
return self.remaining_credits >= required_credits
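The currency column change above is a small example of the typed-column style: with a Mapped[str | None] annotation, SQLAlchemy 2.0 derives both the String type and the nullability, so the explicit arguments become redundant. A self-contained sketch with a toy model (not Dify's schema):
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
    pass
class Thought(Base):
    # Toy model: the annotations alone drive the column definitions.
    __tablename__ = "thoughts"
    id: Mapped[int] = mapped_column(primary_key=True)
    currency: Mapped[str | None] = mapped_column()  # VARCHAR, NULL allowed
    created_by: Mapped[str] = mapped_column()       # VARCHAR, NOT NULL
print(Thought.__table__.c.currency.nullable)    # True
print(Thought.__table__.c.created_by.nullable)  # False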

View File

@ -1,6 +1,5 @@
import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
@ -12,13 +11,10 @@ from sqlalchemy import DateTime, Select, exists, orm, select
from core.file.constants import maybe_file_object
from core.file.models import File
from core.memory.entities import MemoryBlockSpec
from core.variables import utils as variable_utils
from core.variables.segments import VersionedMemoryValue
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
MEMORY_BLOCK_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.enums import NodeType, WorkflowExecutionStatus
@ -156,9 +152,6 @@ class Workflow(Base):
_rag_pipeline_variables: Mapped[str] = mapped_column(
"rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
)
_memory_blocks: Mapped[str] = mapped_column(
"memory_blocks", sa.Text, nullable=False, server_default="[]"
)
VERSION_DRAFT = "draft"
@ -176,7 +169,6 @@ class Workflow(Base):
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
rag_pipeline_variables: list[dict],
memory_blocks: Sequence[MemoryBlockSpec] | None = None,
marked_name: str = "",
marked_comment: str = "",
) -> "Workflow":
@ -192,7 +184,6 @@ class Workflow(Base):
workflow.environment_variables = environment_variables or []
workflow.conversation_variables = conversation_variables or []
workflow.rag_pipeline_variables = rag_pipeline_variables or []
workflow.memory_blocks = memory_blocks or []
workflow.marked_name = marked_name
workflow.marked_comment = marked_comment
workflow.created_at = naive_utc_now()
@ -352,7 +343,7 @@ class Workflow(Base):
:return: hash
"""
entity = {"graph": self.graph_dict}
entity = {"graph": self.graph_dict, "features": self.features_dict}
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
@ -461,7 +452,7 @@ class Workflow(Base):
"features": self.features_dict,
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
"memory_blocks": [block.model_dump(mode="json") for block in self.memory_blocks],
"rag_pipeline_variables": self.rag_pipeline_variables,
}
return result
@ -499,27 +490,6 @@ class Workflow(Base):
ensure_ascii=False,
)
@property
def memory_blocks(self) -> Sequence[MemoryBlockSpec]:
"""Memory blocks configuration stored in database"""
if self._memory_blocks is None or self._memory_blocks == "":
self._memory_blocks = "[]"
memory_blocks_list: list[dict[str, Any]] = json.loads(self._memory_blocks)
results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_list]
return results
@memory_blocks.setter
def memory_blocks(self, value: Sequence[MemoryBlockSpec]):
if not value:
self._memory_blocks = "[]"
return
self._memory_blocks = json.dumps(
[block.model_dump() for block in value],
ensure_ascii=False,
)
@staticmethod
def version_from_datetime(d: datetime) -> str:
return str(d)
@ -1551,31 +1521,6 @@ class WorkflowDraftVariable(Base):
variable.editable = editable
return variable
@staticmethod
def new_memory_block_variable(
*,
app_id: str,
node_id: str | None = None,
memory_id: str,
name: str,
value: VersionedMemoryValue,
description: str = "",
) -> "WorkflowDraftVariable":
"""Create a new memory block draft variable."""
return WorkflowDraftVariable(
id=str(uuid.uuid4()),
app_id=app_id,
node_id=MEMORY_BLOCK_VARIABLE_NODE_ID,
name=name,
value=value.model_dump_json(),
description=description,
selector=[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_id] if node_id is None else
[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_id, node_id],
value_type=SegmentType.VERSIONED_MEMORY,
visible=True,
editable=True,
)
@property
def edited(self):
return self.last_edited_at is not None

View File

@ -20,7 +20,6 @@ dependencies = [
"flask-orjson~=2.0.0",
"flask-sqlalchemy~=3.1.1",
"gevent~=25.9.1",
"gevent-websocket~=0.10.1",
"gmpy2~=2.2.1",
"google-api-core==2.18.0",
"google-api-python-client==2.90.0",
@ -70,7 +69,6 @@ dependencies = [
"pypdfium2==4.30.0",
"python-docx~=1.1.0",
"python-dotenv==1.0.1",
"python-socketio~=5.13.0",
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=6.1.0",

View File

@ -999,11 +999,6 @@ class TenantService:
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
from services.credit_pool_service import CreditPoolService
CreditPoolService.create_default_pool(tenant.id)
return tenant
@staticmethod

View File

@ -1,237 +0,0 @@
import json
from collections.abc import MutableMapping, Sequence
from typing import Literal, Optional, overload
from sqlalchemy import Row, Select, and_, func, select
from sqlalchemy.orm import Session
from core.memory.entities import ChatflowConversationMetadata
from core.model_runtime.entities.message_entities import (
PromptMessage,
)
from extensions.ext_database import db
from models.chatflow_memory import ChatflowConversation, ChatflowMessage
class ChatflowHistoryService:
@staticmethod
def get_visible_chat_history(
conversation_id: str,
app_id: str,
tenant_id: str,
node_id: Optional[str] = None,
max_visible_count: Optional[int] = None
) -> Sequence[PromptMessage]:
with Session(db.engine) as session:
chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
session, conversation_id, app_id, tenant_id, node_id, create_if_missing=False
)
if not chatflow_conv:
return []
metadata = ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata)
visible_count: int = max_visible_count or metadata.visible_count
stmt = select(ChatflowMessage).where(
ChatflowMessage.conversation_id == chatflow_conv.id
).order_by(ChatflowMessage.index.asc(), ChatflowMessage.version.desc())
raw_messages: Sequence[Row[tuple[ChatflowMessage]]] = session.execute(stmt).all()
sorted_messages = ChatflowHistoryService._filter_latest_messages(
[it[0] for it in raw_messages]
)
visible_count = min(visible_count, len(sorted_messages))
visible_messages = sorted_messages[-visible_count:]
return [PromptMessage.model_validate_json(it.data) for it in visible_messages]
@staticmethod
def save_message(
prompt_message: PromptMessage,
conversation_id: str,
app_id: str,
tenant_id: str,
node_id: Optional[str] = None
) -> None:
with Session(db.engine) as session:
chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
session, conversation_id, app_id, tenant_id, node_id, create_if_missing=True
)
# Get next index
max_index = session.execute(
select(func.max(ChatflowMessage.index)).where(
ChatflowMessage.conversation_id == chatflow_conv.id
)
).scalar() or -1
next_index = max_index + 1
# Save new message to append-only table
new_message = ChatflowMessage(
conversation_id=chatflow_conv.id,
index=next_index,
version=1,
data=prompt_message.model_dump_json()
)
session.add(new_message)
# Simply increment visible_count after each saved message
current_metadata = ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata)
new_visible_count = current_metadata.visible_count + 1
new_metadata = ChatflowConversationMetadata(visible_count=new_visible_count)
chatflow_conv.conversation_metadata = new_metadata.model_dump_json()
session.commit()
@staticmethod
def save_app_message(
prompt_message: PromptMessage,
conversation_id: str,
app_id: str,
tenant_id: str
) -> None:
"""Save PromptMessage to app-level chatflow conversation."""
ChatflowHistoryService.save_message(
prompt_message=prompt_message,
conversation_id=conversation_id,
app_id=app_id,
tenant_id=tenant_id,
node_id=None
)
@staticmethod
def save_node_message(
prompt_message: PromptMessage,
node_id: str,
conversation_id: str,
app_id: str,
tenant_id: str
) -> None:
ChatflowHistoryService.save_message(
prompt_message=prompt_message,
conversation_id=conversation_id,
app_id=app_id,
tenant_id=tenant_id,
node_id=node_id
)
@staticmethod
def update_visible_count(
conversation_id: str,
node_id: Optional[str],
new_visible_count: int,
app_id: str,
tenant_id: str
) -> None:
with Session(db.engine) as session:
chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
session, conversation_id, app_id, tenant_id, node_id, create_if_missing=True
)
# Only update visible_count in metadata, do not delete any data
new_metadata = ChatflowConversationMetadata(visible_count=new_visible_count)
chatflow_conv.conversation_metadata = new_metadata.model_dump_json()
session.commit()
@staticmethod
def get_conversation_metadata(
tenant_id: str,
app_id: str,
conversation_id: str,
node_id: Optional[str]
) -> ChatflowConversationMetadata:
with Session(db.engine) as session:
chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
session, conversation_id, app_id, tenant_id, node_id, create_if_missing=False
)
if not chatflow_conv:
raise ValueError(f"Conversation not found: {conversation_id}")
return ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata)
@staticmethod
def _filter_latest_messages(raw_messages: Sequence[ChatflowMessage]) -> Sequence[ChatflowMessage]:
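# For each message index, keep only the highest version, then return the survivors ordered by index.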
index_to_message: MutableMapping[int, ChatflowMessage] = {}
for msg in raw_messages:
index = msg.index
if index not in index_to_message or msg.version > index_to_message[index].version:
index_to_message[index] = msg
sorted_messages = sorted(index_to_message.values(), key=lambda m: m.index)
return sorted_messages
@overload
@staticmethod
def _get_or_create_chatflow_conversation(
session: Session,
conversation_id: str,
app_id: str,
tenant_id: str,
node_id: Optional[str] = None,
create_if_missing: Literal[True] = True
) -> ChatflowConversation: ...
@overload
@staticmethod
def _get_or_create_chatflow_conversation(
session: Session,
conversation_id: str,
app_id: str,
tenant_id: str,
node_id: Optional[str] = None,
create_if_missing: Literal[False] = False
) -> Optional[ChatflowConversation]: ...
@overload
@staticmethod
def _get_or_create_chatflow_conversation(
session: Session,
conversation_id: str,
app_id: str,
tenant_id: str,
node_id: Optional[str] = None,
create_if_missing: bool = False
) -> Optional[ChatflowConversation]: ...
@staticmethod
def _get_or_create_chatflow_conversation(
session: Session,
conversation_id: str,
app_id: str,
tenant_id: str,
node_id: Optional[str] = None,
create_if_missing: bool = False
) -> Optional[ChatflowConversation]:
"""Get existing chatflow conversation or optionally create new one"""
stmt: Select[tuple[ChatflowConversation]] = select(ChatflowConversation).where(
and_(
ChatflowConversation.original_conversation_id == conversation_id,
ChatflowConversation.tenant_id == tenant_id,
ChatflowConversation.app_id == app_id
)
)
if node_id:
stmt = stmt.where(ChatflowConversation.node_id == node_id)
else:
stmt = stmt.where(ChatflowConversation.node_id.is_(None))
chatflow_conv: Row[tuple[ChatflowConversation]] | None = session.execute(stmt).first()
if chatflow_conv:
result: ChatflowConversation = chatflow_conv[0] # Extract the ChatflowConversation object
return result
else:
if create_if_missing:
# Create a new chatflow conversation
default_metadata = ChatflowConversationMetadata(visible_count=0)
new_chatflow_conv = ChatflowConversation(
tenant_id=tenant_id,
app_id=app_id,
node_id=node_id,
original_conversation_id=conversation_id,
conversation_metadata=default_metadata.model_dump_json(),
)
session.add(new_chatflow_conv)
session.flush() # Obtain ID
return new_chatflow_conv
return None

View File

@ -1,670 +0,0 @@
import logging
import threading
import time
from collections.abc import Sequence
from typing import Optional
from sqlalchemy import and_, delete, select
from sqlalchemy.orm import Session
from core.llm_generator.llm_generator import LLMGenerator
from core.memory.entities import (
MemoryBlock,
MemoryBlockSpec,
MemoryBlockWithConversation,
MemoryCreatedBy,
MemoryScheduleMode,
MemoryScope,
MemoryTerm,
MemoryValueData,
)
from core.memory.errors import MemorySyncTimeoutError
from core.model_runtime.entities.message_entities import PromptMessage
from core.variables.segments import VersionedMemoryValue
from core.workflow.constants import MEMORY_BLOCK_VARIABLE_NODE_ID
from core.workflow.runtime.variable_pool import VariablePool
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import App, CreatorUserRole
from models.chatflow_memory import ChatflowMemoryVariable
from models.workflow import Workflow, WorkflowDraftVariable
from services.chatflow_history_service import ChatflowHistoryService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
class ChatflowMemoryService:
@staticmethod
def get_persistent_memories(
app: App, created_by: MemoryCreatedBy, version: int | None = None
) -> Sequence[MemoryBlock]:
if created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
created_by_id = created_by.account_id
else:
created_by_role = CreatorUserRole.END_USER
created_by_id = created_by.id
if version is None:
# If version not specified, get the latest version
stmt = (
select(ChatflowMemoryVariable)
.distinct(ChatflowMemoryVariable.memory_id)
.where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == None,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id,
)
)
.order_by(ChatflowMemoryVariable.version.desc())
)
else:
stmt = select(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == None,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id,
ChatflowMemoryVariable.version == version,
)
)
with Session(db.engine) as session:
db_results = session.execute(stmt).all()
return ChatflowMemoryService._convert_to_memory_blocks(app, created_by, [result[0] for result in db_results])
@staticmethod
def get_session_memories(
app: App, created_by: MemoryCreatedBy, conversation_id: str, version: int | None = None
) -> Sequence[MemoryBlock]:
if version is None:
# If version not specified, get the latest version
stmt = (
select(ChatflowMemoryVariable)
.distinct(ChatflowMemoryVariable.memory_id)
.where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == conversation_id,
)
)
.order_by(ChatflowMemoryVariable.version.desc())
)
else:
stmt = select(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == conversation_id,
ChatflowMemoryVariable.version == version,
)
)
with Session(db.engine) as session:
db_results = session.execute(stmt).all()
return ChatflowMemoryService._convert_to_memory_blocks(app, created_by, [result[0] for result in db_results])
@staticmethod
def save_memory(memory: MemoryBlock, variable_pool: VariablePool, is_draft: bool) -> None:
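# Expose the latest value to the variable pool for the running workflow, then persist it as a new version row.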
key = f"{memory.node_id}.{memory.spec.id}" if memory.node_id else memory.spec.id
variable_pool.add([MEMORY_BLOCK_VARIABLE_NODE_ID, key], memory.value)
if memory.created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
created_by = memory.created_by.account_id
else:
created_by_role = CreatorUserRole.END_USER
created_by = memory.created_by.id
with Session(db.engine) as session:
session.add(
ChatflowMemoryVariable(
memory_id=memory.spec.id,
tenant_id=memory.tenant_id,
app_id=memory.app_id,
node_id=memory.node_id,
conversation_id=memory.conversation_id,
name=memory.spec.name,
value=MemoryValueData(value=memory.value, edited_by_user=memory.edited_by_user).model_dump_json(),
term=memory.spec.term,
scope=memory.spec.scope,
version=memory.version, # Use version from MemoryBlock directly
created_by_role=created_by_role,
created_by=created_by,
)
)
session.commit()
if is_draft:
with Session(bind=db.engine) as session:
draft_var_service = WorkflowDraftVariableService(session)
memory_selector = memory.spec.id if not memory.node_id else f"{memory.node_id}.{memory.spec.id}"
existing_vars = draft_var_service.get_draft_variables_by_selectors(
app_id=memory.app_id, selectors=[[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_selector]]
)
if existing_vars:
draft_var = existing_vars[0]
draft_var.value = (
VersionedMemoryValue.model_validate_json(draft_var.value)
.add_version(memory.value)
.model_dump_json()
)
else:
draft_var = WorkflowDraftVariable.new_memory_block_variable(
app_id=memory.app_id,
memory_id=memory.spec.id,
name=memory.spec.name,
value=VersionedMemoryValue().add_version(memory.value),
description=memory.spec.description,
)
session.add(draft_var)
session.commit()
@staticmethod
def get_memories_by_specs(
memory_block_specs: Sequence[MemoryBlockSpec],
tenant_id: str,
app_id: str,
created_by: MemoryCreatedBy,
conversation_id: Optional[str],
node_id: Optional[str],
is_draft: bool,
) -> Sequence[MemoryBlock]:
return [
ChatflowMemoryService.get_memory_by_spec(
spec, tenant_id, app_id, created_by, conversation_id, node_id, is_draft
)
for spec in memory_block_specs
]
@staticmethod
def get_memory_by_spec(
spec: MemoryBlockSpec,
tenant_id: str,
app_id: str,
created_by: MemoryCreatedBy,
conversation_id: Optional[str],
node_id: Optional[str],
is_draft: bool,
) -> MemoryBlock:
with Session(db.engine) as session:
if is_draft:
draft_var_service = WorkflowDraftVariableService(session)
selector = (
[MEMORY_BLOCK_VARIABLE_NODE_ID, f"{spec.id}.{node_id}"]
if node_id
else [MEMORY_BLOCK_VARIABLE_NODE_ID, spec.id]
)
draft_vars = draft_var_service.get_draft_variables_by_selectors(app_id=app_id, selectors=[selector])
if draft_vars:
draft_var = draft_vars[0]
return MemoryBlock(
value=draft_var.get_value().text,
tenant_id=tenant_id,
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
spec=spec,
created_by=created_by,
version=1,
)
stmt = (
select(ChatflowMemoryVariable)
.where(
and_(
ChatflowMemoryVariable.memory_id == spec.id,
ChatflowMemoryVariable.tenant_id == tenant_id,
ChatflowMemoryVariable.app_id == app_id,
ChatflowMemoryVariable.node_id == (node_id if spec.scope == MemoryScope.NODE else None),
ChatflowMemoryVariable.conversation_id
== (conversation_id if spec.term == MemoryTerm.SESSION else None),
)
)
.order_by(ChatflowMemoryVariable.version.desc())
.limit(1)
)
result = session.execute(stmt).scalar()
if result:
memory_value_data = MemoryValueData.model_validate_json(result.value)
return MemoryBlock(
value=memory_value_data.value,
tenant_id=tenant_id,
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
spec=spec,
edited_by_user=memory_value_data.edited_by_user,
created_by=created_by,
version=result.version,
)
return MemoryBlock(
tenant_id=tenant_id,
value=spec.template,
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
spec=spec,
created_by=created_by,
version=1,
)
@staticmethod
def update_app_memory_if_needed(
workflow: Workflow,
conversation_id: str,
variable_pool: VariablePool,
created_by: MemoryCreatedBy,
is_draft: bool,
):
visible_messages = ChatflowHistoryService.get_visible_chat_history(
conversation_id=conversation_id,
app_id=workflow.app_id,
tenant_id=workflow.tenant_id,
node_id=None,
)
sync_blocks: list[MemoryBlock] = []
async_blocks: list[MemoryBlock] = []
for memory_spec in workflow.memory_blocks:
if memory_spec.scope == MemoryScope.APP:
memory = ChatflowMemoryService.get_memory_by_spec(
spec=memory_spec,
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
conversation_id=conversation_id,
node_id=None,
is_draft=is_draft,
created_by=created_by,
)
if ChatflowMemoryService._should_update_memory(memory, visible_messages):
if memory.spec.schedule_mode == MemoryScheduleMode.SYNC:
sync_blocks.append(memory)
else:
async_blocks.append(memory)
if not sync_blocks and not async_blocks:
return
# async mode: submit individual async tasks directly
for memory_block in async_blocks:
ChatflowMemoryService._app_submit_async_memory_update(
block=memory_block,
is_draft=is_draft,
variable_pool=variable_pool,
visible_messages=visible_messages,
conversation_id=conversation_id,
)
# sync mode: submit a batch update task
if sync_blocks:
ChatflowMemoryService._app_submit_sync_memory_batch_update(
sync_blocks=sync_blocks,
is_draft=is_draft,
conversation_id=conversation_id,
app_id=workflow.app_id,
visible_messages=visible_messages,
variable_pool=variable_pool,
)
@staticmethod
def update_node_memory_if_needed(
tenant_id: str,
app_id: str,
node_id: str,
created_by: MemoryCreatedBy,
conversation_id: str,
memory_block_spec: MemoryBlockSpec,
variable_pool: VariablePool,
is_draft: bool,
) -> bool:
visible_messages = ChatflowHistoryService.get_visible_chat_history(
conversation_id=conversation_id,
app_id=app_id,
tenant_id=tenant_id,
node_id=node_id,
)
memory_block = ChatflowMemoryService.get_memory_by_spec(
spec=memory_block_spec,
tenant_id=tenant_id,
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
is_draft=is_draft,
created_by=created_by,
)
if not ChatflowMemoryService._should_update_memory(memory_block=memory_block, visible_history=visible_messages):
return False
if memory_block_spec.schedule_mode == MemoryScheduleMode.SYNC:
# Node-level sync: blocking execution
ChatflowMemoryService._update_node_memory_sync(
visible_messages=visible_messages,
memory_block=memory_block,
variable_pool=variable_pool,
is_draft=is_draft,
conversation_id=conversation_id,
)
else:
# Node-level async: execute asynchronously
ChatflowMemoryService._update_node_memory_async(
memory_block=memory_block,
visible_messages=visible_messages,
variable_pool=variable_pool,
is_draft=is_draft,
conversation_id=conversation_id,
)
return True
@staticmethod
def wait_for_sync_memory_completion(workflow: Workflow, conversation_id: str):
"""Wait for sync memory update to complete, maximum 50 seconds"""
memory_blocks = workflow.memory_blocks
sync_memory_blocks = [
block
for block in memory_blocks
if block.scope == MemoryScope.APP and block.schedule_mode == MemoryScheduleMode.SYNC
]
if not sync_memory_blocks:
return
lock_key = _get_memory_sync_lock_key(workflow.app_id, conversation_id)
# Retry up to 10 times, wait 5 seconds each time, total 50 seconds
max_retries = 10
retry_interval = 5
for i in range(max_retries):
if not redis_client.exists(lock_key):
# Lock doesn't exist, can continue
return
if i < max_retries - 1:
# Still have retry attempts, wait
time.sleep(retry_interval)
else:
# Maximum retry attempts reached, raise exception
raise MemorySyncTimeoutError(app_id=workflow.app_id, conversation_id=conversation_id)
@staticmethod
def _convert_to_memory_blocks(
app: App, created_by: MemoryCreatedBy, raw_results: Sequence[ChatflowMemoryVariable]
) -> Sequence[MemoryBlock]:
workflow = WorkflowService().get_published_workflow(app)
if not workflow:
return []
results = []
for chatflow_memory_variable in raw_results:
spec = next(
(spec for spec in workflow.memory_blocks if spec.id == chatflow_memory_variable.memory_id), None
)
if spec and chatflow_memory_variable.app_id:
memory_value_data = MemoryValueData.model_validate_json(chatflow_memory_variable.value)
results.append(
MemoryBlock(
spec=spec,
tenant_id=chatflow_memory_variable.tenant_id,
value=memory_value_data.value,
app_id=chatflow_memory_variable.app_id,
conversation_id=chatflow_memory_variable.conversation_id,
node_id=chatflow_memory_variable.node_id,
edited_by_user=memory_value_data.edited_by_user,
created_by=created_by,
version=chatflow_memory_variable.version,
)
)
return results
@staticmethod
def _should_update_memory(memory_block: MemoryBlock, visible_history: Sequence[PromptMessage]) -> bool:
return len(visible_history) >= memory_block.spec.update_turns
@staticmethod
def _app_submit_async_memory_update(
block: MemoryBlock,
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
conversation_id: str,
is_draft: bool,
):
thread = threading.Thread(
target=ChatflowMemoryService._perform_memory_update,
kwargs={
"memory_block": block,
"visible_messages": visible_messages,
"variable_pool": variable_pool,
"is_draft": is_draft,
"conversation_id": conversation_id,
},
)
thread.start()
@staticmethod
def _app_submit_sync_memory_batch_update(
sync_blocks: Sequence[MemoryBlock],
app_id: str,
conversation_id: str,
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
is_draft: bool,
):
"""Submit sync memory batch update task"""
thread = threading.Thread(
target=ChatflowMemoryService._batch_update_sync_memory,
kwargs={
"sync_blocks": sync_blocks,
"app_id": app_id,
"conversation_id": conversation_id,
"visible_messages": visible_messages,
"variable_pool": variable_pool,
"is_draft": is_draft,
},
)
thread.start()
@staticmethod
def _batch_update_sync_memory(
sync_blocks: Sequence[MemoryBlock],
app_id: str,
conversation_id: str,
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
is_draft: bool,
):
try:
lock_key = _get_memory_sync_lock_key(app_id, conversation_id)
with redis_client.lock(lock_key, timeout=120):
threads = []
for block in sync_blocks:
thread = threading.Thread(
target=ChatflowMemoryService._perform_memory_update,
kwargs={
"memory_block": block,
"visible_messages": visible_messages,
"variable_pool": variable_pool,
"is_draft": is_draft,
"conversation_id": conversation_id,
},
)
threads.append(thread)
for thread in threads:
thread.start()
for thread in threads:
thread.join()
except Exception as e:
logger.exception("Error batch updating memory", exc_info=e)
@staticmethod
def _update_node_memory_sync(
memory_block: MemoryBlock,
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
conversation_id: str,
is_draft: bool,
):
ChatflowMemoryService._perform_memory_update(
memory_block=memory_block,
visible_messages=visible_messages,
variable_pool=variable_pool,
is_draft=is_draft,
conversation_id=conversation_id,
)
@staticmethod
def _update_node_memory_async(
memory_block: MemoryBlock,
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
conversation_id: str,
is_draft: bool = False,
):
thread = threading.Thread(
target=ChatflowMemoryService._perform_memory_update,
kwargs={
"memory_block": memory_block,
"visible_messages": visible_messages,
"variable_pool": variable_pool,
"is_draft": is_draft,
"conversation_id": conversation_id,
},
daemon=True,
)
thread.start()
@staticmethod
def _perform_memory_update(
memory_block: MemoryBlock,
variable_pool: VariablePool,
conversation_id: str,
visible_messages: Sequence[PromptMessage],
is_draft: bool,
):
updated_value = LLMGenerator.update_memory_block(
tenant_id=memory_block.tenant_id,
visible_history=ChatflowMemoryService._format_chat_history(visible_messages),
variable_pool=variable_pool,
memory_block=memory_block,
memory_spec=memory_block.spec,
)
updated_memory = MemoryBlock(
tenant_id=memory_block.tenant_id,
value=updated_value,
spec=memory_block.spec,
app_id=memory_block.app_id,
conversation_id=conversation_id,
node_id=memory_block.node_id,
edited_by_user=False,
created_by=memory_block.created_by,
version=memory_block.version + 1, # Increment version for business logic update
)
ChatflowMemoryService.save_memory(updated_memory, variable_pool, is_draft)
ChatflowHistoryService.update_visible_count(
conversation_id=conversation_id,
node_id=memory_block.node_id,
new_visible_count=memory_block.spec.preserved_turns,
app_id=memory_block.app_id,
tenant_id=memory_block.tenant_id,
)
@staticmethod
def delete_memory(app: App, memory_id: str, created_by: MemoryCreatedBy):
workflow = WorkflowService().get_published_workflow(app)
if not workflow:
raise ValueError("Workflow not found")
memory_spec = next((it for it in workflow.memory_blocks if it.id == memory_id), None)
if not memory_spec or not memory_spec.end_user_editable:
raise ValueError("Memory not found or not deletable")
if created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
created_by_id = created_by.account_id
else:
created_by_role = CreatorUserRole.END_USER
created_by_id = created_by.id
with Session(db.engine) as session:
stmt = delete(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.memory_id == memory_id,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id,
)
)
session.execute(stmt)
session.commit()
@staticmethod
def delete_all_user_memories(app: App, created_by: MemoryCreatedBy):
if created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
created_by_id = created_by.account_id
else:
created_by_role = CreatorUserRole.END_USER
created_by_id = created_by.id
with Session(db.engine) as session:
stmt = delete(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id,
)
)
session.execute(stmt)
session.commit()
@staticmethod
def get_persistent_memories_with_conversation(
app: App, created_by: MemoryCreatedBy, conversation_id: str, version: int | None = None
) -> Sequence[MemoryBlockWithConversation]:
"""Get persistent memories with conversation metadata (always None for persistent)"""
memory_blocks = ChatflowMemoryService.get_persistent_memories(app, created_by, version)
return [
MemoryBlockWithConversation.from_memory_block(
block,
ChatflowHistoryService.get_conversation_metadata(app.tenant_id, app.id, conversation_id, block.node_id),
)
for block in memory_blocks
]
@staticmethod
def get_session_memories_with_conversation(
app: App, created_by: MemoryCreatedBy, conversation_id: str, version: int | None = None
) -> Sequence[MemoryBlockWithConversation]:
"""Get session memories with conversation metadata"""
memory_blocks = ChatflowMemoryService.get_session_memories(app, created_by, conversation_id, version)
return [
MemoryBlockWithConversation.from_memory_block(
block,
ChatflowHistoryService.get_conversation_metadata(app.tenant_id, app.id, conversation_id, block.node_id),
)
for block in memory_blocks
]
@staticmethod
def _format_chat_history(messages: Sequence[PromptMessage]) -> Sequence[tuple[str, str]]:
result = []
for message in messages:
result.append((str(message.role.value), message.get_text_content()))
return result
def _get_memory_sync_lock_key(app_id: str, conversation_id: str) -> str:
"""Generate Redis lock key for memory sync updates
Args:
app_id: Application ID
conversation_id: Conversation ID
Returns:
Formatted lock key
"""
return f"memory_sync_update:{app_id}:{conversation_id}"

View File

@ -1,67 +0,0 @@
import logging
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.errors.error import QuotaExceededError
from extensions.ext_database import db
from models import TenantCreditPool
logger = logging.getLogger(__name__)
class CreditPoolService:
@classmethod
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
"""create default credit pool for new tenant"""
credit_pool = TenantCreditPool(
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
)
db.session.add(credit_pool)
db.session.commit()
return credit_pool
@classmethod
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
"""get tenant credit pool"""
return (
db.session.query(TenantCreditPool)
.filter_by(
tenant_id=tenant_id,
pool_type=pool_type,
)
.first()
)
@classmethod
def check_and_deduct_credits(
cls,
tenant_id: str,
credits_required: int,
pool_type: str = "trial",
):
"""check and deduct credits"""
pool = cls.get_pool(tenant_id, pool_type)
if not pool:
raise QuotaExceededError("Credit pool not found")
if pool.remaining_credits < credits_required:
raise QuotaExceededError(
f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}"
)
try:
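# Deduct atomically: the extra WHERE condition keeps quota_used from exceeding quota_limit under concurrent requests.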
with Session(db.engine) as session:
update_values = {"quota_used": pool.quota_used + credits_required}
where_conditions = [
TenantCreditPool.pool_type == pool_type,
TenantCreditPool.tenant_id == tenant_id,
TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
]
stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values)
session.execute(stmt)
session.commit()
except Exception:
raise QuotaExceededError("Failed to deduct credits")

View File

@ -50,7 +50,6 @@ from models.model import UploadFile
from models.provider_ids import ModelProviderID
from models.source import DataSourceOauthBinding
from models.workflow import Workflow
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.entities.knowledge_entities.knowledge_entities import (
ChildChunkUpdateArgs,
KnowledgeConfig,
@ -80,6 +79,7 @@ from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
@ -1694,7 +1694,7 @@ class DocumentService:
# trigger async task
if document_ids:
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

View File

@ -1,82 +0,0 @@
import logging
from collections.abc import Callable
from dataclasses import asdict
from functools import cached_property
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantSelfTaskQueue
from services.feature_service import FeatureService
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
logger = logging.getLogger(__name__)
class DocumentIndexingTaskProxy:
def __init__(self, tenant_id: str, dataset_id: str, document_ids: list[str]):
self.tenant_id = tenant_id
self.dataset_id = dataset_id
self.document_ids = document_ids
self.tenant_self_task_queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
@cached_property
def features(self):
return FeatureService.get_features(self.tenant_id)
def _send_to_direct_queue(self, task_func: Callable):
logger.info("send dataset %s to direct queue", self.dataset_id)
task_func.delay( # type: ignore
tenant_id=self.tenant_id,
dataset_id=self.dataset_id,
document_ids=self.document_ids
)
def _send_to_tenant_queue(self, task_func: Callable):
logger.info("send dataset %s to tenant queue", self.dataset_id)
if self.tenant_self_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self.tenant_self_task_queue.push_tasks([
asdict(
DocumentTask(tenant_id=self.tenant_id, dataset_id=self.dataset_id, document_ids=self.document_ids)
)
])
logger.info("push tasks: %s - %s", self.dataset_id, self.document_ids)
else:
# Set flag and execute task
self.tenant_self_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=self.tenant_id,
dataset_id=self.dataset_id,
document_ids=self.document_ids
)
logger.info("init tasks: %s - %s", self.dataset_id, self.document_ids)
def _send_to_default_tenant_queue(self):
self._send_to_tenant_queue(normal_document_indexing_task)
def _send_to_priority_tenant_queue(self):
self._send_to_tenant_queue(priority_document_indexing_task)
def _send_to_priority_direct_queue(self):
self._send_to_direct_queue(priority_document_indexing_task)
def _dispatch(self):
logger.info(
"dispatch args: %s - %s - %s",
self.tenant_id,
self.features.billing.enabled,
self.features.billing.subscription.plan
)
# dispatch to different indexing queue with tenant isolation when billing enabled
if self.features.billing.enabled:
if self.features.billing.subscription.plan == "sandbox":
# dispatch to normal pipeline queue with tenant self sub queue for sandbox plan
self._send_to_default_tenant_queue()
else:
# dispatch to priority pipeline queue with tenant self sub queue for other plans
self._send_to_priority_tenant_queue()
else:
# dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise
self._send_to_priority_direct_queue()
def delay(self):
self._dispatch()

View File

@ -132,7 +132,6 @@ class FeatureModel(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
knowledge_pipeline: KnowledgePipeline = KnowledgePipeline()
next_credit_reset_date: int = 0
class KnowledgeRateLimitModel(BaseModel):
@ -153,7 +152,6 @@ class SystemFeatureModel(BaseModel):
enable_email_code_login: bool = False
enable_email_password_login: bool = True
enable_social_oauth_login: bool = False
enable_collaboration_mode: bool = False
is_allow_register: bool = False
is_allow_create_workspace: bool = False
is_email_setup: bool = False
@ -163,8 +161,6 @@ class SystemFeatureModel(BaseModel):
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
enable_trial_app: bool = False
enable_explore_banner: bool = False
class FeatureService:
@ -217,12 +213,9 @@ class FeatureService:
system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN
system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN
system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN
system_features.enable_collaboration_mode = dify_config.ENABLE_COLLABORATION_MODE
system_features.is_allow_register = dify_config.ALLOW_REGISTER
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
@classmethod
def _fulfill_params_from_env(cls, features: FeatureModel):
@ -287,9 +280,6 @@ class FeatureService:
if "knowledge_pipeline_publish_enabled" in billing_info:
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]
if "next_credit_reset_date" in billing_info:
features.next_credit_reset_date = billing_info["next_credit_reset_date"]
@classmethod
def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):

View File

@ -1,101 +0,0 @@
import json
import logging
from collections.abc import Callable
from functools import cached_property
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantSelfTaskQueue
from extensions.ext_database import db
from services.feature_service import FeatureService
from services.file_service import FileService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
class RagPipelineTaskProxy:
def __init__(
self, dataset_tenant_id: str, user_id: str, rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity]
):
self.dataset_tenant_id = dataset_tenant_id
self.user_id = user_id
self.rag_pipeline_invoke_entities = rag_pipeline_invoke_entities
self.tenant_self_pipeline_task_queue = TenantSelfTaskQueue(dataset_tenant_id, "pipeline")
@cached_property
def features(self):
return FeatureService.get_features(self.dataset_tenant_id)
def _upload_invoke_entities(self) -> str:
text = [item.model_dump() for item in self.rag_pipeline_invoke_entities]
name = "rag_pipeline_invoke_entities.json"
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, self.user_id, self.dataset_tenant_id)
return upload_file.id
def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable):
logger.info("send file %s to direct queue", upload_file_id)
task_func.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id=self.dataset_tenant_id,
)
def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable):
logger.info("send file %s to tenant queue", upload_file_id)
if self.tenant_self_pipeline_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self.tenant_self_pipeline_task_queue.push_tasks([upload_file_id])
logger.info("push tasks: %s", upload_file_id)
else:
# Set flag and execute task
self.tenant_self_pipeline_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id=self.dataset_tenant_id,
)
logger.info("init tasks: %s", upload_file_id)
def _send_to_default_tenant_queue(self, upload_file_id: str):
self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)
def _send_to_priority_tenant_queue(self, upload_file_id: str):
self._send_to_tenant_queue(upload_file_id, priority_rag_pipeline_run_task)
def _send_to_priority_direct_queue(self, upload_file_id: str):
self._send_to_direct_queue(upload_file_id, priority_rag_pipeline_run_task)
def _dispatch(self):
upload_file_id = self._upload_invoke_entities()
if not upload_file_id:
raise ValueError("upload_file_id is empty")
logger.info(
"dispatch args: %s - %s - %s",
self.dataset_tenant_id,
self.features.billing.enabled,
self.features.billing.subscription.plan
)
# dispatch to different pipeline queue with tenant isolation when billing enabled
if self.features.billing.enabled:
if self.features.billing.subscription.plan == "sandbox":
# dispatch to normal pipeline queue with tenant isolation for sandbox plan
self._send_to_default_tenant_queue(upload_file_id)
else:
# dispatch to priority pipeline queue with tenant isolation for other plans
self._send_to_priority_tenant_queue(upload_file_id)
else:
# dispatch to priority pipeline queue without tenant isolation for others, e.g.: self-hosted or enterprise
self._send_to_priority_direct_queue(upload_file_id)
def delay(self):
if not self.rag_pipeline_invoke_entities:
logger.warning(
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s",
self.dataset_tenant_id,
self.user_id
)
return
self._dispatch()

View File

@ -1,7 +1,4 @@
from configs import dify_config
from extensions.ext_database import db
from models.model import AccountTrialAppRecord, TrialApp
from services.feature_service import FeatureService
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
@ -23,15 +20,6 @@ class RecommendedAppService:
)
)
if FeatureService.get_system_features().enable_trial_app:
apps = result["recommended_apps"]
for app in apps:
app_id = app["app_id"]
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
if trial_app_model:
app["can_trial"] = True
else:
app["can_trial"] = False
return result
@classmethod
@ -44,29 +32,4 @@ class RecommendedAppService:
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
if FeatureService.get_system_features().enable_trial_app:
app_id = result["id"]
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
if trial_app_model:
result["can_trial"] = True
else:
result["can_trial"] = False
return result
@classmethod
def add_trial_app_record(cls, app_id: str, account_id: str):
"""
Add trial app record.
:param app_id: app id
:return:
"""
account_trial_app_record = db.session.query(AccountTrialAppRecord).where(
AccountTrialAppRecord.app_id == app_id,
AccountTrialAppRecord.account_id == account_id
).first()
if account_trial_app_record:
account_trial_app_record.count += 1
db.session.commit()
else:
db.session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
db.session.commit()

View File

@ -7,7 +7,6 @@ from pydantic import ValidationError
from yarl import URL
from configs import dify_config
from core.entities.mcp_provider import MCPConfiguration
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
@ -240,6 +239,7 @@ class ToolTransformService:
user_name: str | None = None,
include_sensitive: bool = True,
) -> ToolProviderApiEntity:
from core.entities.mcp_provider import MCPConfiguration
# Use provided user_name to avoid N+1 query, fallback to load_user() if not provided
if user_name is None:

View File

@ -1,311 +0,0 @@
import logging
from typing import Optional
from sqlalchemy import desc, select
from sqlalchemy.orm import Session, selectinload
from werkzeug.exceptions import Forbidden, NotFound
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from models import WorkflowComment, WorkflowCommentMention, WorkflowCommentReply
logger = logging.getLogger(__name__)
class WorkflowCommentService:
"""Service for managing workflow comments."""
@staticmethod
def _validate_content(content: str) -> None:
if len(content.strip()) == 0:
raise ValueError("Comment content cannot be empty")
if len(content) > 1000:
raise ValueError("Comment content cannot exceed 1000 characters")
@staticmethod
def get_comments(tenant_id: str, app_id: str) -> list[WorkflowComment]:
"""Get all comments for a workflow."""
with Session(db.engine) as session:
# Get all comments with eager loading
stmt = (
select(WorkflowComment)
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
.where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id)
.order_by(desc(WorkflowComment.created_at))
)
comments = session.scalars(stmt).all()
return comments
@staticmethod
def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Optional[Session] = None) -> WorkflowComment:
"""Get a specific comment."""
def _get_comment(session: Session) -> WorkflowComment:
stmt = (
select(WorkflowComment)
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
.where(
WorkflowComment.id == comment_id,
WorkflowComment.tenant_id == tenant_id,
WorkflowComment.app_id == app_id,
)
)
comment = session.scalar(stmt)
if not comment:
raise NotFound("Comment not found")
return comment
if session is not None:
return _get_comment(session)
else:
with Session(db.engine, expire_on_commit=False) as session:
return _get_comment(session)
@staticmethod
def create_comment(
tenant_id: str,
app_id: str,
created_by: str,
content: str,
position_x: float,
position_y: float,
mentioned_user_ids: Optional[list[str]] = None,
) -> dict:
"""Create a new workflow comment."""
WorkflowCommentService._validate_content(content)
with Session(db.engine) as session:
comment = WorkflowComment(
tenant_id=tenant_id,
app_id=app_id,
position_x=position_x,
position_y=position_y,
content=content,
created_by=created_by,
)
session.add(comment)
session.flush() # Get the comment ID for mentions
# Create mentions if specified
mentioned_user_ids = mentioned_user_ids or []
for user_id in mentioned_user_ids:
if isinstance(user_id, str) and uuid_value(user_id):
mention = WorkflowCommentMention(
comment_id=comment.id,
reply_id=None, # This is a comment mention, not reply mention
mentioned_user_id=user_id,
)
session.add(mention)
session.commit()
# Return only what we need - id and created_at
return {"id": comment.id, "created_at": comment.created_at}
@staticmethod
def update_comment(
tenant_id: str,
app_id: str,
comment_id: str,
user_id: str,
content: str,
position_x: Optional[float] = None,
position_y: Optional[float] = None,
mentioned_user_ids: Optional[list[str]] = None,
) -> dict:
"""Update a workflow comment."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
# Get comment with validation
stmt = select(WorkflowComment).where(
WorkflowComment.id == comment_id,
WorkflowComment.tenant_id == tenant_id,
WorkflowComment.app_id == app_id,
)
comment = session.scalar(stmt)
if not comment:
raise NotFound("Comment not found")
# Only the creator can update the comment
if comment.created_by != user_id:
raise Forbidden("Only the comment creator can update it")
# Update comment fields
comment.content = content
if position_x is not None:
comment.position_x = position_x
if position_y is not None:
comment.position_y = position_y
# Update mentions - first remove existing mentions for this comment only (not replies)
existing_mentions = session.scalars(
select(WorkflowCommentMention).where(
WorkflowCommentMention.comment_id == comment.id,
WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions
)
).all()
for mention in existing_mentions:
session.delete(mention)
# Add new mentions
mentioned_user_ids = mentioned_user_ids or []
for user_id_str in mentioned_user_ids:
if isinstance(user_id_str, str) and uuid_value(user_id_str):
mention = WorkflowCommentMention(
comment_id=comment.id,
reply_id=None, # This is a comment mention
mentioned_user_id=user_id_str,
)
session.add(mention)
session.commit()
return {"id": comment.id, "updated_at": comment.updated_at}
@staticmethod
def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None:
"""Delete a workflow comment."""
with Session(db.engine, expire_on_commit=False) as session:
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
# Only the creator can delete the comment
if comment.created_by != user_id:
raise Forbidden("Only the comment creator can delete it")
# Delete associated mentions (both comment and reply mentions)
mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id)
).all()
for mention in mentions:
session.delete(mention)
# Delete associated replies
replies = session.scalars(
select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id)
).all()
for reply in replies:
session.delete(reply)
session.delete(comment)
session.commit()
@staticmethod
def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment:
"""Resolve a workflow comment."""
with Session(db.engine, expire_on_commit=False) as session:
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
if comment.resolved:
return comment
comment.resolved = True
comment.resolved_at = naive_utc_now()
comment.resolved_by = user_id
session.commit()
return comment
@staticmethod
def create_reply(
comment_id: str, content: str, created_by: str, mentioned_user_ids: Optional[list[str]] = None
) -> dict:
"""Add a reply to a workflow comment."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
# Check if comment exists
comment = session.get(WorkflowComment, comment_id)
if not comment:
raise NotFound("Comment not found")
reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by)
session.add(reply)
session.flush() # Get the reply ID for mentions
# Create mentions if specified
mentioned_user_ids = mentioned_user_ids or []
for user_id in mentioned_user_ids:
if isinstance(user_id, str) and uuid_value(user_id):
# Create mention linking to specific reply
mention = WorkflowCommentMention(
comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id
)
session.add(mention)
session.commit()
return {"id": reply.id, "created_at": reply.created_at}
@staticmethod
def update_reply(
reply_id: str, user_id: str, content: str, mentioned_user_ids: Optional[list[str]] = None
) -> dict:
"""Update a comment reply."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
reply = session.get(WorkflowCommentReply, reply_id)
if not reply:
raise NotFound("Reply not found")
# Only the creator can update the reply
if reply.created_by != user_id:
raise Forbidden("Only the reply creator can update it")
reply.content = content
# Update mentions - first remove existing mentions for this reply
existing_mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id)
).all()
for mention in existing_mentions:
session.delete(mention)
# Add mentions
mentioned_user_ids = mentioned_user_ids or []
for user_id_str in mentioned_user_ids:
if isinstance(user_id_str, str) and uuid_value(user_id_str):
mention = WorkflowCommentMention(
comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str
)
session.add(mention)
session.commit()
session.refresh(reply) # Refresh to get updated timestamp
return {"id": reply.id, "updated_at": reply.updated_at}
@staticmethod
def delete_reply(reply_id: str, user_id: str) -> None:
"""Delete a comment reply."""
with Session(db.engine, expire_on_commit=False) as session:
reply = session.get(WorkflowCommentReply, reply_id)
if not reply:
raise NotFound("Reply not found")
# Only the creator can delete the reply
if reply.created_by != user_id:
raise Forbidden("Only the reply creator can delete it")
# Delete associated mentions first
mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id)
).all()
for mention in mentions:
session.delete(mention)
session.delete(reply)
session.commit()
@staticmethod
def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment:
"""Validate that a comment belongs to the specified tenant and app."""
return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id)

View File

@ -11,7 +11,6 @@ from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.memory.entities import MemoryBlockSpec, MemoryCreatedBy, MemoryScope
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
@ -198,18 +197,15 @@ class WorkflowService:
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
memory_blocks: Sequence[MemoryBlockSpec] | None = None,
force_upload: bool = False,
) -> Workflow:
"""
Sync draft workflow
:param force_upload: Skip hash validation when True (for restore operations)
:raises WorkflowHashNotEqualError
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if workflow and workflow.unique_hash != unique_hash and not force_upload:
if workflow and workflow.unique_hash != unique_hash:
raise WorkflowHashNotEqualError()
# validate features structure
@ -228,7 +224,6 @@ class WorkflowService:
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
workflow.memory_blocks = memory_blocks or []
db.session.add(workflow)
# update draft workflow if found
else:
@ -238,7 +233,6 @@ class WorkflowService:
workflow.updated_at = naive_utc_now()
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
workflow.memory_blocks = memory_blocks or []
# commit db session changes
db.session.commit()
@ -249,78 +243,6 @@ class WorkflowService:
# return draft workflow
return workflow
def update_draft_workflow_environment_variables(
self,
*,
app_model: App,
environment_variables: Sequence[Variable],
account: Account,
):
"""
Update draft workflow environment variables
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("No draft workflow found.")
workflow.environment_variables = environment_variables
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
# commit db session changes
db.session.commit()
def update_draft_workflow_conversation_variables(
self,
*,
app_model: App,
conversation_variables: Sequence[Variable],
account: Account,
):
"""
Update draft workflow conversation variables
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("No draft workflow found.")
workflow.conversation_variables = conversation_variables
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
# commit db session changes
db.session.commit()
def update_draft_workflow_features(
self,
*,
app_model: App,
features: dict,
account: Account,
):
"""
Update draft workflow features
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("No draft workflow found.")
# validate features structure
self.validate_features_structure(app_model=app_model, features=features)
workflow.features = json.dumps(features)
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
# commit db session changes
db.session.commit()
def publish_workflow(
self,
*,
@ -358,7 +280,6 @@ class WorkflowService:
marked_name=marked_name,
marked_comment=marked_comment,
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
memory_blocks=draft_workflow.memory_blocks,
features=draft_workflow.features,
)
@ -715,10 +636,17 @@ class WorkflowService:
tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
)
# init variable pool
variable_pool = _setup_variable_pool(query=query, files=files or [], user_id=account.id,
user_inputs=user_inputs, workflow=draft_workflow,
node_type=node_type, conversation_id=conversation_id,
conversation_variables=[], is_draft=True)
variable_pool = _setup_variable_pool(
query=query,
files=files or [],
user_id=account.id,
user_inputs=user_inputs,
workflow=draft_workflow,
# NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation variables.
conversation_variables=[],
node_type=node_type,
conversation_id=conversation_id,
)
else:
variable_pool = VariablePool(
@ -1067,7 +995,6 @@ def _setup_variable_pool(
node_type: NodeType,
conversation_id: str,
conversation_variables: list[Variable],
is_draft: bool
):
# Only inject system variables for START node type.
if node_type == NodeType.START:
@ -1086,6 +1013,7 @@ def _setup_variable_pool(
system_variable.dialogue_count = 1
else:
system_variable = SystemVariable.empty()
# init variable pool
variable_pool = VariablePool(
system_variables=system_variable,
@ -1094,12 +1022,6 @@ def _setup_variable_pool(
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), #
memory_blocks=_fetch_memory_blocks(
workflow,
MemoryCreatedBy(account_id=user_id),
conversation_id,
is_draft=is_draft
),
)
return variable_pool
@ -1136,30 +1058,3 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia
return build_from_mappings(mappings=value, tenant_id=tenant_id)
else:
raise Exception("unreachable")
def _fetch_memory_blocks(
workflow: Workflow,
created_by: MemoryCreatedBy,
conversation_id: str,
is_draft: bool
) -> Mapping[str, str]:
memory_blocks = {}
memory_block_specs = workflow.memory_blocks
from services.chatflow_memory_service import ChatflowMemoryService
memories = ChatflowMemoryService.get_memories_by_specs(
memory_block_specs=memory_block_specs,
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
node_id=None,
conversation_id=conversation_id,
is_draft=is_draft,
created_by=created_by,
)
for memory in memories:
if memory.spec.scope == MemoryScope.APP:
memory_blocks[memory.spec.id] = memory.value
else: # NODE scope
memory_blocks[f"{memory.node_id}.{memory.spec.id}"] = memory.value
return memory_blocks
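For reference, the removed helper above keys app-scoped memories by their bare spec id and node-scoped memories by a node-qualified id. A rough sketch of the mapping it produced (ids and values below are illustrative only, not taken from the codebase):

    # Illustrative shape of the mapping returned by _fetch_memory_blocks:
    # APP-scoped blocks use the bare spec id, NODE-scoped blocks are
    # prefixed with the owning node id.
    memory_blocks = {
        "user_profile": "Prefers concise answers.",            # MemoryScope.APP
        "node_42.conversation_summary": "Summary so far ...",  # NODE scope
    }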

View File

@ -31,8 +31,7 @@ class WorkspaceService:
assert tenant_account_join is not None, "TenantAccountJoin not found"
tenant_info["role"] = tenant_account_join.role
feature = FeatureService.get_features(tenant.id)
can_replace_logo = feature.can_replace_logo
can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
base_url = dify_config.FILES_URL
@ -47,20 +46,5 @@ class WorkspaceService:
"remove_webapp_brand": remove_webapp_brand,
"replace_webapp_logo": replace_webapp_logo,
}
if dify_config.EDITION == "CLOUD":
tenant_info["next_credit_reset_date"] = feature.next_credit_reset_date
from services.credit_pool_service import CreditPoolService
paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
if paid_pool:
tenant_info["trial_credits"] = paid_pool.quota_limit
tenant_info["trial_credits_used"] = paid_pool.quota_used
else:
trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
if trial_pool:
tenant_info["trial_credits"] = trial_pool.quota_limit
tenant_info["trial_credits_used"] = trial_pool.quota_used
return tenant_info
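For orientation, the block removed above attached credit-pool details to the CLOUD-edition payload; roughly, it extended the returned dict as sketched below (values are placeholders, and no fields beyond those visible in this diff are implied):

    # Illustrative only; built from the fields shown in the removed lines above.
    tenant_info.update(
        {
            "next_credit_reset_date": "2025-12-01",
            "trial_credits": 200,      # quota_limit of the paid or trial pool
            "trial_credits_used": 42,  # quota_used of the same pool
        }
    )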

View File

@ -1,14 +1,11 @@
import logging
import time
from collections.abc import Callable
import click
from celery import shared_task
from configs import dify_config
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantSelfTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@ -25,24 +22,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
:param dataset_id:
:param document_ids:
.. warning:: TO BE DEPRECATED
This function will be deprecated and removed in a future version.
Use normal_document_indexing_task or priority_document_indexing_task instead.
Usage: document_indexing_task.delay(dataset_id, document_ids)
"""
logger.warning("document indexing legacy mode received: %s - %s", dataset_id, document_ids)
_document_indexing(dataset_id, document_ids)
def _document_indexing(dataset_id: str, document_ids: list):
"""
Process document for tasks
:param dataset_id:
:param document_ids:
Usage: _document_indexing(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()
@ -106,61 +87,3 @@ def _document_indexing(dataset_id: str, document_ids: list):
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
def _document_indexing_with_tenant_queue(tenant_id: str, dataset_id: str, document_ids: list, task_func: Callable):
try:
_document_indexing(dataset_id, document_ids)
except Exception:
logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id)
finally:
tenant_self_task_queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_tasks = tenant_self_task_queue.pull_tasks(count=dify_config.TENANT_SELF_TASK_QUEUE_PULL_SIZE)
logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks)
if next_tasks:
for next_task in next_tasks:
document_task = DocumentTask(**next_task)
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_self_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=document_task.tenant_id,
dataset_id=document_task.dataset_id,
document_ids=document_task.document_ids,
)
else:
# No more waiting tasks, clear the flag
tenant_self_task_queue.delete_task_key()
@shared_task(queue="dataset")
def normal_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: list):
"""
Async process document
:param tenant_id:
:param dataset_id:
:param document_ids:
Usage: normal_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
"""
logger.info("normal document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, normal_document_indexing_task)
@shared_task(queue="priority_dataset")
def priority_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: list):
"""
Priority async process document
:param tenant_id:
:param dataset_id:
:param document_ids:
Usage: priority_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
"""
logger.info("priority document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, priority_document_indexing_task)
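The removed wrapper above only drains the tenant queue after a run finishes; the enqueue side is not part of this diff. A minimal sketch of how a dispatcher might use the same TenantSelfTaskQueue API to either run immediately or park a task, assuming the producer checks the task key first (the dispatch function itself is hypothetical):

    from dataclasses import asdict

    from core.entities.document_task import DocumentTask
    from core.rag.pipeline.queue import TenantSelfTaskQueue


    def dispatch_document_indexing(tenant_id: str, dataset_id: str, document_ids: list) -> None:
        # Hypothetical producer: run now if the tenant is idle, otherwise queue (FIFO).
        queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
        if queue.get_task_key():
            # Another indexing run is in flight for this tenant; park the work.
            task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
            queue.push_tasks([asdict(task)])
        else:
            # Mark the tenant as busy and dispatch directly.
            queue.set_task_waiting_time()
            normal_document_indexing_task.delay(
                tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids
            )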

View File

@ -12,10 +12,8 @@ from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantSelfTaskQueue
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models import Account, Tenant
@ -24,8 +22,6 @@ from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
logger = logging.getLogger(__name__)
@shared_task(queue="priority_pipeline")
def priority_rag_pipeline_run_task(
@ -73,27 +69,6 @@ def priority_rag_pipeline_run_task(
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
tenant_self_pipeline_task_queue = TenantSelfTaskQueue(tenant_id, "pipeline")
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_ids = tenant_self_pipeline_task_queue.pull_tasks(count=dify_config.TENANT_SELF_TASK_QUEUE_PULL_SIZE)
logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_self_pipeline_task_queue.set_task_waiting_time()
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
else:
# No more waiting tasks, clear the flag
tenant_self_pipeline_task_queue.delete_task_key()
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()

View File

@ -12,20 +12,17 @@ from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantSelfTaskQueue
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
logger = logging.getLogger(__name__)
@shared_task(queue="pipeline")
def rag_pipeline_run_task(
@ -73,27 +70,26 @@ def rag_pipeline_run_task(
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
tenant_self_pipeline_task_queue = TenantSelfTaskQueue(tenant_id, "pipeline")
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_ids = tenant_self_pipeline_task_queue.pull_tasks(count=dify_config.TENANT_SELF_TASK_QUEUE_PULL_SIZE)
logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids)
next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_self_pipeline_task_queue.set_task_waiting_time()
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
if next_file_id:
# Process the next waiting task
# Keep the flag set to indicate a task is running
redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
else:
# No more waiting tasks, clear the flag
tenant_self_pipeline_task_queue.delete_task_key()
redis_client.delete(tenant_pipeline_task_key)
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()

View File

@ -1 +0,0 @@
# Test containers integration tests for RAG pipeline

View File

@ -1 +0,0 @@
# Test containers integration tests for RAG pipeline queue

View File

@ -1,663 +0,0 @@
"""
Integration tests for TenantSelfTaskQueue using testcontainers.
These tests verify the Redis-based task queue functionality with real Redis instances,
testing tenant isolation, task serialization, and queue operations in a realistic environment.
Includes compatibility tests for migrating from legacy string-only queues.
All tests use generic naming to avoid coupling to specific business implementations.
"""
import time
from dataclasses import dataclass
from typing import Any
from uuid import uuid4
import pytest
from faker import Faker
from core.rag.pipeline.queue import TASK_WRAPPER_PREFIX, TaskWrapper, TenantSelfTaskQueue
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@dataclass
class TestTask:
"""Test task data structure for testing complex object serialization."""
task_id: str
tenant_id: str
data: dict[str, Any]
metadata: dict[str, Any]
class TestTenantSelfTaskQueueIntegration:
"""Integration tests for TenantSelfTaskQueue using testcontainers."""
@pytest.fixture
def fake(self):
"""Faker instance for generating test data."""
return Faker()
@pytest.fixture
def test_tenant_and_account(self, db_session_with_containers, fake):
"""Create test tenant and account for testing."""
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return tenant, account
@pytest.fixture
def test_queue(self, test_tenant_and_account):
"""Create a generic test queue for testing."""
tenant, _ = test_tenant_and_account
return TenantSelfTaskQueue(tenant.id, "test_queue")
@pytest.fixture
def secondary_queue(self, test_tenant_and_account):
"""Create a secondary test queue for testing isolation."""
tenant, _ = test_tenant_and_account
return TenantSelfTaskQueue(tenant.id, "secondary_queue")
def test_queue_initialization(self, test_tenant_and_account):
"""Test queue initialization with correct key generation."""
tenant, _ = test_tenant_and_account
queue = TenantSelfTaskQueue(tenant.id, "test-key")
assert queue.tenant_id == tenant.id
assert queue.unique_key == "test-key"
assert queue.queue == f"tenant_self_test-key_task_queue:{tenant.id}"
assert queue.task_key == f"tenant_test-key_task:{tenant.id}"
assert queue.DEFAULT_TASK_TTL == 60 * 60
def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake):
"""Test that different tenants have isolated queues."""
tenant1, _ = test_tenant_and_account
# Create second tenant
tenant2 = Tenant(
name=fake.company(),
status="normal",
)
db_session_with_containers.add(tenant2)
db_session_with_containers.commit()
queue1 = TenantSelfTaskQueue(tenant1.id, "same-key")
queue2 = TenantSelfTaskQueue(tenant2.id, "same-key")
assert queue1.queue != queue2.queue
assert queue1.task_key != queue2.task_key
assert queue1.queue == f"tenant_self_same-key_task_queue:{tenant1.id}"
assert queue2.queue == f"tenant_self_same-key_task_queue:{tenant2.id}"
def test_key_isolation(self, test_tenant_and_account):
"""Test that different keys have isolated queues."""
tenant, _ = test_tenant_and_account
queue1 = TenantSelfTaskQueue(tenant.id, "key1")
queue2 = TenantSelfTaskQueue(tenant.id, "key2")
assert queue1.queue != queue2.queue
assert queue1.task_key != queue2.task_key
assert queue1.queue == f"tenant_self_key1_task_queue:{tenant.id}"
assert queue2.queue == f"tenant_self_key2_task_queue:{tenant.id}"
def test_task_key_operations(self, test_queue):
"""Test task key operations (get, set, delete)."""
# Initially no task key should exist
assert test_queue.get_task_key() is None
# Set task waiting time with default TTL
test_queue.set_task_waiting_time()
task_key = test_queue.get_task_key()
# Redis returns bytes, convert to string for comparison
assert task_key in (b"1", "1")
# Set task waiting time with custom TTL
custom_ttl = 30
test_queue.set_task_waiting_time(custom_ttl)
task_key = test_queue.get_task_key()
assert task_key in (b"1", "1")
# Delete task key
test_queue.delete_task_key()
assert test_queue.get_task_key() is None
def test_push_and_pull_string_tasks(self, test_queue):
"""Test pushing and pulling string tasks."""
tasks = ["task1", "task2", "task3"]
# Push tasks
test_queue.push_tasks(tasks)
# Pull tasks one by one (FIFO order)
pulled_tasks = []
for _ in range(3):
task = test_queue.get_next_task()
if task:
pulled_tasks.append(task)
# Should get tasks in FIFO order (lpush + rpop = FIFO)
assert pulled_tasks == ["task1", "task2", "task3"]
def test_push_and_pull_multiple_tasks(self, test_queue):
"""Test pushing and pulling multiple tasks at once."""
tasks = ["task1", "task2", "task3", "task4", "task5"]
# Push tasks
test_queue.push_tasks(tasks)
# Pull multiple tasks
pulled_tasks = test_queue.pull_tasks(3)
assert len(pulled_tasks) == 3
assert pulled_tasks == ["task1", "task2", "task3"]
# Pull remaining tasks
remaining_tasks = test_queue.pull_tasks(5)
assert len(remaining_tasks) == 2
assert remaining_tasks == ["task4", "task5"]
def test_push_and_pull_complex_objects(self, test_queue, fake):
"""Test pushing and pulling complex object tasks."""
# Create complex task objects as dictionaries (not dataclass instances)
tasks = [
{
"task_id": str(uuid4()),
"tenant_id": test_queue.tenant_id,
"data": {
"file_id": str(uuid4()),
"content": fake.text(),
"metadata": {"size": fake.random_int(1000, 10000)}
},
"metadata": {
"created_at": fake.iso8601(),
"tags": fake.words(3)
}
},
{
"task_id": str(uuid4()),
"tenant_id": test_queue.tenant_id,
"data": {
"file_id": str(uuid4()),
"content": "测试中文内容",
"metadata": {"size": fake.random_int(1000, 10000)}
},
"metadata": {
"created_at": fake.iso8601(),
"tags": ["中文", "测试", "emoji🚀"]
}
}
]
# Push complex tasks
test_queue.push_tasks(tasks)
# Pull tasks
pulled_tasks = test_queue.pull_tasks(2)
assert len(pulled_tasks) == 2
# Verify deserialized tasks match original (FIFO order)
for i, pulled_task in enumerate(pulled_tasks):
original_task = tasks[i] # FIFO order
assert isinstance(pulled_task, dict)
assert pulled_task["task_id"] == original_task["task_id"]
assert pulled_task["tenant_id"] == original_task["tenant_id"]
assert pulled_task["data"] == original_task["data"]
assert pulled_task["metadata"] == original_task["metadata"]
def test_mixed_task_types(self, test_queue, fake):
"""Test pushing and pulling mixed string and object tasks."""
string_task = "simple_string_task"
object_task = {
"task_id": str(uuid4()),
"dataset_id": str(uuid4()),
"document_ids": [str(uuid4()) for _ in range(3)]
}
tasks = [string_task, object_task, "another_string"]
# Push mixed tasks
test_queue.push_tasks(tasks)
# Pull all tasks
pulled_tasks = test_queue.pull_tasks(3)
assert len(pulled_tasks) == 3
# Verify types and content
assert pulled_tasks[0] == string_task
assert isinstance(pulled_tasks[1], dict)
assert pulled_tasks[1] == object_task
assert pulled_tasks[2] == "another_string"
def test_empty_queue_operations(self, test_queue):
"""Test operations on empty queue."""
# Pull from empty queue
tasks = test_queue.pull_tasks(5)
assert tasks == []
# Get next task from empty queue
task = test_queue.get_next_task()
assert task is None
# Pull zero or negative count
assert test_queue.pull_tasks(0) == []
assert test_queue.pull_tasks(-1) == []
def test_task_ttl_expiration(self, test_queue):
"""Test task key TTL expiration."""
# Set task with short TTL
short_ttl = 2
test_queue.set_task_waiting_time(short_ttl)
# Verify task key exists
assert test_queue.get_task_key() == b"1" or test_queue.get_task_key() == "1"
# Wait for TTL to expire
time.sleep(short_ttl + 1)
# Verify task key has expired
assert test_queue.get_task_key() is None
def test_large_task_batch(self, test_queue, fake):
"""Test handling large batches of tasks."""
# Create large batch of tasks
large_batch = []
for i in range(100):
task = {
"task_id": str(uuid4()),
"index": i,
"data": fake.text(max_nb_chars=100),
"metadata": {"batch_id": str(uuid4())}
}
large_batch.append(task)
# Push large batch
test_queue.push_tasks(large_batch)
# Pull all tasks
pulled_tasks = test_queue.pull_tasks(100)
assert len(pulled_tasks) == 100
# Verify all tasks were retrieved correctly (FIFO order)
for i, task in enumerate(pulled_tasks):
assert isinstance(task, dict)
assert task["index"] == i # FIFO order
def test_queue_operations_isolation(self, test_tenant_and_account, fake):
"""Test concurrent operations on different queues."""
tenant, _ = test_tenant_and_account
# Create multiple queues for the same tenant
queue1 = TenantSelfTaskQueue(tenant.id, "queue1")
queue2 = TenantSelfTaskQueue(tenant.id, "queue2")
# Push tasks to different queues
queue1.push_tasks(["task1_queue1", "task2_queue1"])
queue2.push_tasks(["task1_queue2", "task2_queue2"])
# Verify queues are isolated
tasks1 = queue1.pull_tasks(2)
tasks2 = queue2.pull_tasks(2)
assert tasks1 == ["task1_queue1", "task2_queue1"]
assert tasks2 == ["task1_queue2", "task2_queue2"]
assert tasks1 != tasks2
def test_task_wrapper_serialization_roundtrip(self, test_queue, fake):
"""Test TaskWrapper serialization and deserialization roundtrip."""
# Create complex nested data
complex_data = {
"id": str(uuid4()),
"nested": {
"deep": {
"value": "test",
"numbers": [1, 2, 3, 4, 5],
"unicode": "测试中文",
"emoji": "🚀"
}
},
"metadata": {
"created_at": fake.iso8601(),
"tags": ["tag1", "tag2", "tag3"]
}
}
# Create wrapper and serialize
wrapper = TaskWrapper(complex_data)
serialized = wrapper.serialize()
# Verify serialization
assert isinstance(serialized, str)
assert "测试中文" in serialized
assert "🚀" in serialized
# Deserialize and verify
deserialized_wrapper = TaskWrapper.deserialize(serialized)
assert deserialized_wrapper.data == complex_data
def test_error_handling_invalid_json(self, test_queue):
"""Test error handling for invalid JSON in wrapped tasks."""
# Manually create invalid wrapped task
invalid_wrapped_task = f"{TASK_WRAPPER_PREFIX}invalid json data"
# Push invalid task directly to Redis
redis_client.lpush(test_queue.queue, invalid_wrapped_task)
# Pull task - should fall back to string
task = test_queue.get_next_task()
assert task == invalid_wrapped_task
def test_real_world_processing_scenario(self, test_queue, fake):
"""Test realistic task processing scenario."""
# Simulate various task types
tasks = []
for i in range(5):
task = {
"tenant_id": test_queue.tenant_id,
"resource_id": str(uuid4()),
"resource_ids": [str(uuid4()) for _ in range(fake.random_int(1, 5))]
}
tasks.append(task)
# Push all tasks
test_queue.push_tasks(tasks)
# Simulate processing tasks one by one
processed_tasks = []
while True:
task = test_queue.get_next_task()
if task is None:
break
processed_tasks.append(task)
# Simulate task processing time
time.sleep(0.01)
# Verify all tasks were processed
assert len(processed_tasks) == 5
# Verify task content
for task in processed_tasks:
assert isinstance(task, dict)
assert "tenant_id" in task
assert "resource_id" in task
assert "resource_ids" in task
assert task["tenant_id"] == test_queue.tenant_id
def test_real_world_batch_processing_scenario(self, test_queue, fake):
"""Test realistic batch processing scenario."""
# Simulate batch processing tasks
batch_tasks = []
for i in range(3):
task = {
"file_id": str(uuid4()),
"tenant_id": test_queue.tenant_id,
"user_id": str(uuid4()),
"processing_config": {
"model": fake.random_element(["model_a", "model_b", "model_c"]),
"temperature": fake.random.uniform(0.1, 1.0),
"max_tokens": fake.random_int(1000, 4000)
},
"metadata": {
"source": fake.random_element(["upload", "api", "webhook"]),
"priority": fake.random_element(["low", "normal", "high"])
}
}
batch_tasks.append(task)
# Push tasks
test_queue.push_tasks(batch_tasks)
# Process tasks in batches
batch_size = 2
processed_tasks = []
while True:
batch = test_queue.pull_tasks(batch_size)
if not batch:
break
processed_tasks.extend(batch)
# Verify all tasks were processed
assert len(processed_tasks) == 3
# Verify task structure
for task in processed_tasks:
assert isinstance(task, dict)
assert "file_id" in task
assert "tenant_id" in task
assert "processing_config" in task
assert "metadata" in task
assert task["tenant_id"] == test_queue.tenant_id
class TestTenantSelfTaskQueueCompatibility:
"""Compatibility tests for migrating from legacy string-only queues."""
@pytest.fixture
def fake(self):
"""Faker instance for generating test data."""
return Faker()
@pytest.fixture
def test_tenant_and_account(self, db_session_with_containers, fake):
"""Create test tenant and account for testing."""
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return tenant, account
def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake):
"""
Test compatibility with legacy queues containing only string data.
This simulates the scenario where Redis queues already contain string data
from the old architecture, and we need to ensure the new code can read them.
"""
tenant, _ = test_tenant_and_account
queue = TenantSelfTaskQueue(tenant.id, "legacy_queue")
# Simulate legacy string data in Redis queue (using old format)
legacy_strings = [
"legacy_task_1",
"legacy_task_2",
"legacy_task_3",
"legacy_task_4",
"legacy_task_5"
]
# Manually push legacy strings directly to Redis (simulating old system)
for legacy_string in legacy_strings:
redis_client.lpush(queue.queue, legacy_string)
# Verify new code can read legacy string data
pulled_tasks = queue.pull_tasks(5)
assert len(pulled_tasks) == 5
# Verify all tasks are strings (not wrapped)
for task in pulled_tasks:
assert isinstance(task, str)
assert task.startswith("legacy_task_")
# Verify order (FIFO from Redis list)
expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
assert pulled_tasks == expected_order
def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake):
"""
Test complete migration scenario from legacy to new system.
This simulates the real-world scenario where:
1. Legacy system has string data in Redis
2. New system starts processing the same queue
3. Both legacy and new tasks coexist during migration
4. New system can handle both formats seamlessly
"""
tenant, _ = test_tenant_and_account
queue = TenantSelfTaskQueue(tenant.id, "migration_queue")
# Phase 1: Legacy system has data
legacy_tasks = [f"legacy_resource_{i}" for i in range(1, 6)]
redis_client.lpush(queue.queue, *legacy_tasks)
# Phase 2: New system starts processing legacy data
processed_legacy = []
while True:
task = queue.get_next_task()
if task is None:
break
processed_legacy.append(task)
# Verify legacy data was processed correctly
assert len(processed_legacy) == 5
for task in processed_legacy:
assert isinstance(task, str)
assert task.startswith("legacy_resource_")
# Phase 3: New system adds new tasks (mixed types)
new_string_tasks = ["new_resource_1", "new_resource_2"]
new_object_tasks = [
{
"resource_id": str(uuid4()),
"tenant_id": tenant.id,
"processing_type": "new_system",
"metadata": {"version": "2.0", "features": ["ai", "ml"]}
},
{
"resource_id": str(uuid4()),
"tenant_id": tenant.id,
"processing_type": "new_system",
"metadata": {"version": "2.0", "features": ["ai", "ml"]}
}
]
# Push new tasks using new system
queue.push_tasks(new_string_tasks)
queue.push_tasks(new_object_tasks)
# Phase 4: Process all new tasks
processed_new = []
while True:
task = queue.get_next_task()
if task is None:
break
processed_new.append(task)
# Verify new tasks were processed correctly
assert len(processed_new) == 4
string_tasks = [task for task in processed_new if isinstance(task, str)]
object_tasks = [task for task in processed_new if isinstance(task, dict)]
assert len(string_tasks) == 2
assert len(object_tasks) == 2
# Verify string tasks
for task in string_tasks:
assert task.startswith("new_resource_")
# Verify object tasks
for task in object_tasks:
assert isinstance(task, dict)
assert "resource_id" in task
assert "tenant_id" in task
assert task["tenant_id"] == tenant.id
assert task["processing_type"] == "new_system"
def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake):
"""
Test error recovery when legacy queue contains malformed data.
This ensures the new system can gracefully handle corrupted or
malformed legacy data without crashing.
"""
tenant, _ = test_tenant_and_account
queue = TenantSelfTaskQueue(tenant.id, "error_recovery_queue")
# Create mix of valid and malformed legacy data
mixed_legacy_data = [
"valid_legacy_task_1",
"valid_legacy_task_2",
"malformed_data_without_prefix", # This should be treated as string
"valid_legacy_task_3",
f"{TASK_WRAPPER_PREFIX}invalid_json_data", # This should fall back to string
"valid_legacy_task_4"
]
# Manually push mixed data directly to Redis
redis_client.lpush(queue.queue, *mixed_legacy_data)
# Process all tasks
processed_tasks = []
while True:
task = queue.get_next_task()
if task is None:
break
processed_tasks.append(task)
# Verify all tasks were processed (no crashes)
assert len(processed_tasks) == 6
# Verify all tasks are strings (malformed data falls back to string)
for task in processed_tasks:
assert isinstance(task, str)
# Verify valid tasks are preserved
valid_tasks = [task for task in processed_tasks if task.startswith("valid_legacy_task_")]
assert len(valid_tasks) == 4
# Verify malformed data is handled gracefully
malformed_tasks = [task for task in processed_tasks if not task.startswith("valid_legacy_task_")]
assert len(malformed_tasks) == 2
assert "malformed_data_without_prefix" in malformed_tasks
assert f"{TASK_WRAPPER_PREFIX}invalid_json_data" in malformed_tasks

View File

@ -268,7 +268,6 @@ class TestFeatureService:
mock_config.ENABLE_EMAIL_CODE_LOGIN = True
mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True
mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False
mock_config.ENABLE_COLLABORATION_MODE = True
mock_config.ALLOW_REGISTER = False
mock_config.ALLOW_CREATE_WORKSPACE = False
mock_config.MAIL_TYPE = "smtp"
@ -293,7 +292,6 @@ class TestFeatureService:
# Verify authentication settings
assert result.enable_email_code_login is True
assert result.enable_email_password_login is False
assert result.enable_collaboration_mode is True
assert result.is_allow_register is False
assert result.is_allow_create_workspace is False
@ -343,7 +341,6 @@ class TestFeatureService:
mock_config.ENABLE_EMAIL_CODE_LOGIN = True
mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True
mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False
mock_config.ENABLE_COLLABORATION_MODE = False
mock_config.ALLOW_REGISTER = True
mock_config.ALLOW_CREATE_WORKSPACE = True
mock_config.MAIL_TYPE = "smtp"
@ -365,7 +362,6 @@ class TestFeatureService:
assert result.enable_email_code_login is True
assert result.enable_email_password_login is True
assert result.enable_social_oauth_login is False
assert result.enable_collaboration_mode is False
assert result.is_allow_register is True
assert result.is_allow_create_workspace is True
assert result.is_email_setup is True

View File

@ -1,33 +1,17 @@
from dataclasses import asdict
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.entities.document_task import DocumentTask
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from tasks.document_indexing_task import (
_document_indexing, # Core function
_document_indexing_with_tenant_queue, # Tenant queue wrapper function
document_indexing_task, # Deprecated old interface
normal_document_indexing_task, # New normal task
priority_document_indexing_task, # New priority task
)
from tasks.document_indexing_task import document_indexing_task
class TestDocumentIndexingTasks:
"""Integration tests for document indexing tasks using testcontainers.
This test class covers:
- Core _document_indexing function
- Deprecated document_indexing_task function
- New normal_document_indexing_task function
- New priority_document_indexing_task function
- Tenant queue wrapper _document_indexing_with_tenant_queue function
"""
class TestDocumentIndexingTask:
"""Integration tests for document_indexing_task using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
@ -240,7 +224,7 @@ class TestDocumentIndexingTasks:
document_ids = [doc.id for doc in documents]
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
@ -248,11 +232,10 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
# Verify the run method was called with correct documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
@ -278,7 +261,7 @@ class TestDocumentIndexingTasks:
document_ids = [fake.uuid4() for _ in range(3)]
# Act: Execute the task with non-existent dataset
_document_indexing(non_existent_dataset_id, document_ids)
document_indexing_task(non_existent_dataset_id, document_ids)
# Assert: Verify no processing occurred
mock_external_service_dependencies["indexing_runner"].assert_not_called()
@ -308,18 +291,17 @@ class TestDocumentIndexingTasks:
all_document_ids = existing_document_ids + non_existent_document_ids
# Act: Execute the task with mixed document IDs
_document_indexing(dataset.id, all_document_ids)
document_indexing_task(dataset.id, all_document_ids)
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify only existing documents were updated
# Re-query documents from database since _document_indexing uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
# Verify the run method was called with only existing documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
@ -351,7 +333,7 @@ class TestDocumentIndexingTasks:
)
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
@ -359,11 +341,10 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing close the session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
def test_document_indexing_task_mixed_document_states(
self, db_session_with_containers, mock_external_service_dependencies
@ -426,18 +407,17 @@ class TestDocumentIndexingTasks:
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with mixed document states
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify all documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in all_documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
# Verify the run method was called with all documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
@ -490,16 +470,15 @@ class TestDocumentIndexingTasks:
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify error handling
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error
assert updated_document.stopped_at is not None
for document in all_documents:
db.session.refresh(document)
assert document.indexing_status == "error"
assert document.error is not None
assert "batch upload" in document.error
assert document.stopped_at is not None
# Verify no indexing runner was called
mock_external_service_dependencies["indexing_runner"].assert_not_called()
@ -524,18 +503,17 @@ class TestDocumentIndexingTasks:
document_ids = [doc.id for doc in documents]
# Act: Execute the task with billing disabled
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify successful processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
def test_document_indexing_task_document_is_paused_error(
self, db_session_with_containers, mock_external_service_dependencies
@ -563,7 +541,7 @@ class TestDocumentIndexingTasks:
)
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
@ -571,314 +549,7 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
# ==================== NEW TESTS FOR REFACTORED FUNCTIONS ====================
def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test document_indexing_task basic functionality.
This test verifies:
- Task function calls the wrapper correctly
- Basic parameter passing works
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
# Act: Execute the deprecated task (it only takes 2 parameters)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_normal_document_indexing_task_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test normal_document_indexing_task basic functionality.
This test verifies:
- Task function calls the wrapper correctly
- Basic parameter passing works
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
# Act: Execute the new normal task
normal_document_indexing_task(tenant_id, dataset.id, document_ids)
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_priority_document_indexing_task_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test priority_document_indexing_task basic functionality.
This test verifies:
- Task function calls the wrapper correctly
- Basic parameter passing works
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
# Act: Execute the new priority task
priority_document_indexing_task(tenant_id, dataset.id, document_ids)
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_document_indexing_with_tenant_queue_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test _document_indexing_with_tenant_queue function with no waiting tasks.
This test verifies:
- Core indexing logic execution (same as _document_indexing)
- Tenant queue cleanup when no waiting tasks
- Task function parameter passing
- Queue management after processing
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
# Mock the task function
from unittest.mock import MagicMock
mock_task_func = MagicMock()
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Assert: Verify core processing occurred (same as _document_indexing)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were updated (same as _document_indexing)
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
# Verify the run method was called with correct documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
assert call_args is not None
processed_documents = call_args[0][0]
assert len(processed_documents) == 2
# Verify task function was not called (no waiting tasks)
mock_task_func.delay.assert_not_called()
def test_document_indexing_with_tenant_queue_with_waiting_tasks(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis.
This test verifies:
- Core indexing logic execution
- Real Redis-based tenant queue processing of waiting tasks
- Task function calls for waiting tasks
- Queue management with multiple tasks using actual Redis operations
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
dateset_id = dataset.id
# Mock the task function
from unittest.mock import MagicMock
mock_task_func = MagicMock()
# Use real Redis for TenantSelfTaskQueue
from core.rag.pipeline.queue import TenantSelfTaskQueue
# Create real queue instance
queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
# Add waiting tasks to the real Redis queue
waiting_tasks = [
DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"]),
DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-2"])
]
# Convert DocumentTask objects to dictionaries for serialization
waiting_task_dicts = [asdict(task) for task in waiting_tasks]
queue.push_tasks(waiting_task_dicts)
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Assert: Verify core processing occurred
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify task function was called for each waiting task
assert mock_task_func.delay.call_count == 1
# Verify correct parameters for each call
calls = mock_task_func.delay.call_args_list
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dateset_id, "document_ids": ["waiting-doc-1"]}
# Verify queue is empty after processing (tasks were pulled)
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
assert len(remaining_tasks) == 1
def test_document_indexing_with_tenant_queue_error_handling(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error handling in _document_indexing_with_tenant_queue using real Redis.
This test verifies:
- Exception handling during core processing
- Tenant queue cleanup even on errors using real Redis
- Proper error logging
- Function completes without raising exceptions
- Queue management continues despite core processing errors
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
dateset_id = dataset.id
# Mock IndexingRunner to raise an exception
mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception("Test error")
# Mock the task function
from unittest.mock import MagicMock
mock_task_func = MagicMock()
# Use real Redis for TenantSelfTaskQueue
from core.rag.pipeline.queue import TenantSelfTaskQueue
# Create real queue instance
queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
# Add waiting task to the real Redis queue
waiting_task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"])
queue.push_tasks([asdict(waiting_task)])
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
# Verify waiting task was still processed despite core processing error
mock_task_func.delay.assert_called_once()
# Verify correct parameters for the call
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dateset_id, "document_ids": ["waiting-doc-1"]}
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
def test_document_indexing_with_tenant_queue_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant isolation in _document_indexing_with_tenant_queue using real Redis.
This test verifies:
- Different tenants have isolated queues
- Tasks from one tenant don't affect another tenant's queue
- Queue operations are properly scoped to tenant
"""
# Arrange: Create test data for two different tenants
dataset1, documents1 = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
dataset2, documents2 = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
tenant1_id = dataset1.tenant_id
tenant2_id = dataset2.tenant_id
dataset1_id = dataset1.id
dataset2_id = dataset2.id
document_ids1 = [doc.id for doc in documents1]
document_ids2 = [doc.id for doc in documents2]
# Mock the task function
from unittest.mock import MagicMock
mock_task_func = MagicMock()
# Use real Redis for TenantSelfTaskQueue
from core.rag.pipeline.queue import TenantSelfTaskQueue
# Create queue instances for both tenants
queue1 = TenantSelfTaskQueue(tenant1_id, "document_indexing")
queue2 = TenantSelfTaskQueue(tenant2_id, "document_indexing")
# Add waiting tasks to both queues
waiting_task1 = DocumentTask(tenant_id=tenant1_id, dataset_id=dataset1.id, document_ids=["tenant1-doc-1"])
waiting_task2 = DocumentTask(tenant_id=tenant2_id, dataset_id=dataset2.id, document_ids=["tenant2-doc-1"])
queue1.push_tasks([asdict(waiting_task1)])
queue2.push_tasks([asdict(waiting_task2)])
# Act: Execute the wrapper function for tenant1 only
_document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)
# Assert: Verify core processing occurred for tenant1
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify only tenant1's waiting task was processed
mock_task_func.delay.assert_called_once()
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
assert len(remaining_tasks1) == 0
# Verify tenant2's queue still has its task (isolation)
remaining_tasks2 = queue2.pull_tasks(count=10)
assert len(remaining_tasks2) == 1
# Verify queue keys are different
assert queue1.queue != queue2.queue
assert queue1.task_key != queue2.task_key
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None

View File

@ -1,939 +0,0 @@
import json
import uuid
from unittest.mock import patch
import pytest
from faker import Faker
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantSelfTaskQueue
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Pipeline
from models.workflow import Workflow
from tasks.rag_pipeline.priority_rag_pipeline_run_task import (
priority_rag_pipeline_run_task,
run_single_rag_pipeline_task,
)
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
class TestRagPipelineRunTasks:
"""Integration tests for RAG pipeline run tasks using testcontainers.
This test class covers:
- priority_rag_pipeline_run_task function
- rag_pipeline_run_task function
- run_single_rag_pipeline_task function
- Real Redis-based TenantSelfTaskQueue operations
- PipelineGenerator._generate method mocking and parameter validation
- File operations and cleanup
- Error handling and queue management
"""
@pytest.fixture
def mock_pipeline_generator(self):
"""Mock PipelineGenerator._generate method."""
with patch("core.app.apps.pipeline.pipeline_generator.PipelineGenerator._generate") as mock_generate:
# Mock the _generate method to return a simple response
mock_generate.return_value = {
"answer": "Test response",
"metadata": {"test": "data"}
}
yield mock_generate
@pytest.fixture
def mock_file_service(self):
"""Mock FileService for file operations."""
with (
patch("services.file_service.FileService.get_file_content") as mock_get_content,
patch("services.file_service.FileService.delete_file") as mock_delete_file,
):
yield {
"get_content": mock_get_content,
"delete_file": mock_delete_file,
}
def _create_test_pipeline_and_workflow(self, db_session_with_containers):
"""
Helper method to create test pipeline and workflow for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
Returns:
tuple: (account, tenant, pipeline, workflow) - Created entities
"""
fake = Faker()
# Create account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
# Create workflow
workflow = Workflow(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
app_id=str(uuid.uuid4()),
type="workflow",
version="draft",
graph="{}",
features="{}",
marked_name=fake.company(),
marked_comment=fake.text(max_nb_chars=100),
created_by=account.id,
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
db.session.add(workflow)
db.session.commit()
# Create pipeline
pipeline = Pipeline(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
workflow_id=workflow.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
created_by=account.id,
)
db.session.add(pipeline)
db.session.commit()
# Refresh entities to ensure they're properly loaded
db.session.refresh(account)
db.session.refresh(tenant)
db.session.refresh(workflow)
db.session.refresh(pipeline)
return account, tenant, pipeline, workflow
def _create_rag_pipeline_invoke_entities(self, account, tenant, pipeline, workflow, count=2):
"""
Helper method to create RAG pipeline invoke entities for testing.
Args:
account: Account instance
tenant: Tenant instance
pipeline: Pipeline instance
workflow: Workflow instance
count: Number of entities to create
Returns:
list: List of RagPipelineInvokeEntity instances
"""
fake = Faker()
entities = []
for i in range(count):
# Create application generate entity
app_config = {
"app_id": str(uuid.uuid4()),
"app_name": fake.company(),
"mode": "workflow",
"workflow_id": workflow.id,
"tenant_id": tenant.id,
"app_mode": "workflow",
}
application_generate_entity = {
"task_id": str(uuid.uuid4()),
"app_config": app_config,
"inputs": {"query": f"Test query {i}"},
"files": [],
"user_id": account.id,
"stream": False,
"invoke_from": "published",
"workflow_execution_id": str(uuid.uuid4()),
"pipeline_config": {
"app_id": str(uuid.uuid4()),
"app_name": fake.company(),
"mode": "workflow",
"workflow_id": workflow.id,
"tenant_id": tenant.id,
"app_mode": "workflow",
},
"datasource_type": "upload_file",
"datasource_info": {},
"dataset_id": str(uuid.uuid4()),
"batch": "test_batch",
}
entity = RagPipelineInvokeEntity(
pipeline_id=pipeline.id,
application_generate_entity=application_generate_entity,
user_id=account.id,
tenant_id=tenant.id,
workflow_id=workflow.id,
streaming=False,
workflow_execution_id=str(uuid.uuid4()),
workflow_thread_pool_id=str(uuid.uuid4()),
)
entities.append(entity)
return entities
def _create_file_content_for_entities(self, entities):
"""
Helper method to create file content for RAG pipeline invoke entities.
Args:
entities: List of RagPipelineInvokeEntity instances
Returns:
str: JSON string containing serialized entities
"""
entities_data = [entity.model_dump() for entity in entities]
return json.dumps(entities_data)
def test_priority_rag_pipeline_run_task_success(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test successful priority RAG pipeline run task execution.
This test verifies:
- Task execution with multiple RAG pipeline invoke entities
- File content retrieval and parsing
- PipelineGenerator._generate method calls with correct parameters
- Thread pool execution
- File cleanup after execution
- Queue management with no waiting tasks
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=2)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Act: Execute the priority task
priority_rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify expected outcomes
# Verify file operations
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
# Verify PipelineGenerator._generate was called for each entity
assert mock_pipeline_generator.call_count == 2
# Verify call parameters for each entity
calls = mock_pipeline_generator.call_args_list
for call in calls:
call_kwargs = call[1] # Get keyword arguments
assert call_kwargs["pipeline"].id == pipeline.id
assert call_kwargs["workflow_id"] == workflow.id
assert call_kwargs["user"].id == account.id
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
assert call_kwargs["streaming"] == False
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_rag_pipeline_run_task_success(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test successful regular RAG pipeline run task execution.
This test verifies:
- Task execution with multiple RAG pipeline invoke entities
- File content retrieval and parsing
- PipelineGenerator._generate method calls with correct parameters
- Thread pool execution
- File cleanup after execution
- Queue management with no waiting tasks
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=3)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Act: Execute the regular task
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify expected outcomes
# Verify file operations
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
# Verify PipelineGenerator._generate was called for each entity
assert mock_pipeline_generator.call_count == 3
# Verify call parameters for each entity
calls = mock_pipeline_generator.call_args_list
for call in calls:
call_kwargs = call[1] # Get keyword arguments
assert call_kwargs["pipeline"].id == pipeline.id
assert call_kwargs["workflow_id"] == workflow.id
assert call_kwargs["user"].id == account.id
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
assert call_kwargs["streaming"] == False
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_priority_rag_pipeline_run_task_with_waiting_tasks(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test priority RAG pipeline run task with waiting tasks in queue using real Redis.
This test verifies:
- Core task execution
- Real Redis-based tenant queue processing of waiting tasks
- Task function calls for waiting tasks
- Queue management with multiple tasks using actual Redis operations
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Use real Redis for TenantSelfTaskQueue
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Add waiting tasks to the real Redis queue
waiting_file_ids = [str(uuid.uuid4()) for _ in range(2)]
queue.push_tasks(waiting_file_ids)
# Mock the task function calls
with patch(
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
) as mock_delay:
# Act: Execute the priority task
priority_rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify core processing occurred
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed; the queue pulls 1 task at a time by default
assert mock_delay.call_count == 1
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining
def test_rag_pipeline_run_task_legacy_compatibility(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility.
This test simulates the scenario where:
- Old code writes file IDs directly to Redis list using lpush
- New worker processes these legacy queue entries
- Ensures backward compatibility during deployment transition
Legacy format: redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
New format: TenantSelfTaskQueue.push_tasks([file_id])
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Simulate legacy Redis queue format - direct file IDs in Redis list
from extensions.ext_redis import redis_client
# Legacy queue key format (old code)
legacy_queue_key = f"tenant_self_pipeline_task_queue:{tenant.id}"
legacy_task_key = f"tenant_pipeline_task:{tenant.id}"
# Add legacy format data to Redis (simulating old code behavior)
legacy_file_ids = [str(uuid.uuid4()) for _ in range(3)]
for file_id_legacy in legacy_file_ids:
redis_client.lpush(legacy_queue_key, file_id_legacy)
# Set the task key to indicate there are waiting tasks (legacy behavior)
redis_client.set(legacy_task_key, 1, ex=60 * 60)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act: Execute the priority task with new code but legacy queue data
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify core processing occurred
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed; the queue pulls 1 task at a time by default
assert mock_delay.call_count == 1
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify that new code can process legacy queue entries
# The new TenantSelfTaskQueue should be able to read from the legacy format
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
# Cleanup: Remove legacy test data
redis_client.delete(legacy_queue_key)
redis_client.delete(legacy_task_key)
def test_rag_pipeline_run_task_with_waiting_tasks(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test regular RAG pipeline run task with waiting tasks in queue using real Redis.
This test verifies:
- Core task execution
- Real Redis-based tenant queue processing of waiting tasks
- Task function calls for waiting tasks
- Queue management with multiple tasks using actual Redis operations
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Use real Redis for TenantSelfTaskQueue
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Add waiting tasks to the real Redis queue
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
queue.push_tasks(waiting_file_ids)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act: Execute the regular task
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify core processing occurred
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed; the queue pulls 1 task at a time by default
assert mock_delay.call_count == 1
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
def test_priority_rag_pipeline_run_task_error_handling(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test error handling in priority RAG pipeline run task using real Redis.
This test verifies:
- Exception handling during core processing
- Tenant queue cleanup even on errors using real Redis
- Proper error logging
- Function completes without raising exceptions
- Queue management continues despite core processing errors
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Mock PipelineGenerator to raise an exception
mock_pipeline_generator.side_effect = Exception("Pipeline generation failed")
# Use real Redis for TenantSelfTaskQueue
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Add waiting task to the real Redis queue
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch(
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
) as mock_delay:
# Act: Execute the priority task (should not raise exception)
priority_rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting task was still processed despite core processing error
mock_delay.assert_called_once()
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
def test_rag_pipeline_run_task_error_handling(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test error handling in regular RAG pipeline run task using real Redis.
This test verifies:
- Exception handling during core processing
- Tenant queue cleanup even on errors using real Redis
- Proper error logging
- Function completes without raising exceptions
- Queue management continues despite core processing errors
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Mock PipelineGenerator to raise an exception
mock_pipeline_generator.side_effect = Exception("Pipeline generation failed")
# Use real Redis for TenantSelfTaskQueue
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Add waiting task to the real Redis queue
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act: Execute the regular task (should not raise exception)
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting task was still processed despite core processing error
mock_delay.assert_called_once()
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == waiting_file_id
assert call_kwargs.get('tenant_id') == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
def test_priority_rag_pipeline_run_task_tenant_isolation(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test tenant isolation in priority RAG pipeline run task using real Redis.
This test verifies:
- Different tenants have isolated queues
- Tasks from one tenant don't affect another tenant's queue
- Queue operations are properly scoped to tenant
"""
# Arrange: Create test data for two different tenants
account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers)
account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1)
entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1)
file_content1 = self._create_file_content_for_entities(entities1)
file_content2 = self._create_file_content_for_entities(entities2)
# Mock file service
file_id1 = str(uuid.uuid4())
file_id2 = str(uuid.uuid4())
mock_file_service["get_content"].side_effect = [file_content1, file_content2]
# Use real Redis for TenantSelfTaskQueue
queue1 = TenantSelfTaskQueue(tenant1.id, "pipeline")
queue2 = TenantSelfTaskQueue(tenant2.id, "pipeline")
# Add waiting tasks to both queues
waiting_file_id1 = str(uuid.uuid4())
waiting_file_id2 = str(uuid.uuid4())
queue1.push_tasks([waiting_file_id1])
queue2.push_tasks([waiting_file_id2])
# Mock the task function calls
with patch(
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
) as mock_delay:
# Act: Execute the priority task for tenant1 only
priority_rag_pipeline_run_task(file_id1, tenant1.id)
# Assert: Verify core processing occurred for tenant1
assert mock_file_service["get_content"].call_count == 1
assert mock_file_service["delete_file"].call_count == 1
assert mock_pipeline_generator.call_count == 1
# Verify only tenant1's waiting task was processed
mock_delay.assert_called_once()
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
assert call_kwargs.get("tenant_id") == tenant1.id
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
assert len(remaining_tasks1) == 0
# Verify tenant2's queue still has its task (isolation)
remaining_tasks2 = queue2.pull_tasks(count=10)
assert len(remaining_tasks2) == 1
# Verify queue keys are different
assert queue1.queue != queue2.queue
assert queue1.task_key != queue2.task_key
def test_rag_pipeline_run_task_tenant_isolation(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test tenant isolation in regular RAG pipeline run task using real Redis.
This test verifies:
- Different tenants have isolated queues
- Tasks from one tenant don't affect another tenant's queue
- Queue operations are properly scoped to tenant
"""
# Arrange: Create test data for two different tenants
account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers)
account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1)
entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1)
file_content1 = self._create_file_content_for_entities(entities1)
file_content2 = self._create_file_content_for_entities(entities2)
# Mock file service
file_id1 = str(uuid.uuid4())
file_id2 = str(uuid.uuid4())
mock_file_service["get_content"].side_effect = [file_content1, file_content2]
# Use real Redis for TenantSelfTaskQueue
queue1 = TenantSelfTaskQueue(tenant1.id, "pipeline")
queue2 = TenantSelfTaskQueue(tenant2.id, "pipeline")
# Add waiting tasks to both queues
waiting_file_id1 = str(uuid.uuid4())
waiting_file_id2 = str(uuid.uuid4())
queue1.push_tasks([waiting_file_id1])
queue2.push_tasks([waiting_file_id2])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act: Execute the regular task for tenant1 only
rag_pipeline_run_task(file_id1, tenant1.id)
# Assert: Verify core processing occurred for tenant1
assert mock_file_service["get_content"].call_count == 1
assert mock_file_service["delete_file"].call_count == 1
assert mock_pipeline_generator.call_count == 1
# Verify only tenant1's waiting task was processed
mock_delay.assert_called_once()
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == waiting_file_id1
assert call_kwargs.get('tenant_id') == tenant1.id
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
assert len(remaining_tasks1) == 0
# Verify tenant2's queue still has its task (isolation)
remaining_tasks2 = queue2.pull_tasks(count=10)
assert len(remaining_tasks2) == 1
# Verify queue keys are different
assert queue1.queue != queue2.queue
assert queue1.task_key != queue2.task_key
def test_run_single_rag_pipeline_task_success(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
):
"""
Test successful run_single_rag_pipeline_task execution.
This test verifies:
- Single RAG pipeline task execution within Flask app context
- Entity validation and database queries
- PipelineGenerator._generate method call with correct parameters
- Proper Flask context handling
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
entity_data = entities[0].model_dump()
# Act: Execute the single task
with flask_app_with_containers.app_context():
run_single_rag_pipeline_task(entity_data, flask_app_with_containers)
# Assert: Verify expected outcomes
# Verify PipelineGenerator._generate was called
assert mock_pipeline_generator.call_count == 1
# Verify call parameters
call = mock_pipeline_generator.call_args
call_kwargs = call[1] # Get keyword arguments
assert call_kwargs["pipeline"].id == pipeline.id
assert call_kwargs["workflow_id"] == workflow.id
assert call_kwargs["user"].id == account.id
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
assert call_kwargs["streaming"] == False
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_run_single_rag_pipeline_task_entity_validation_error(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
):
"""
Test run_single_rag_pipeline_task with invalid entity data.
This test verifies:
- Proper error handling for invalid entity data
- Exception logging
- Function raises ValueError for missing entities
"""
# Arrange: Create entity data with valid UUIDs but non-existent entities
fake = Faker()
invalid_entity_data = {
"pipeline_id": str(uuid.uuid4()),
"application_generate_entity": {
"app_config": {
"app_id": str(uuid.uuid4()),
"app_name": "Test App",
"mode": "workflow",
"workflow_id": str(uuid.uuid4()),
},
"inputs": {"query": "Test query"},
"query": "Test query",
"response_mode": "blocking",
"user": str(uuid.uuid4()),
"files": [],
"conversation_id": str(uuid.uuid4()),
},
"user_id": str(uuid.uuid4()),
"tenant_id": str(uuid.uuid4()),
"workflow_id": str(uuid.uuid4()),
"streaming": False,
"workflow_execution_id": str(uuid.uuid4()),
"workflow_thread_pool_id": str(uuid.uuid4()),
}
# Act & Assert: Execute the single task with non-existent entities (should raise ValueError)
with flask_app_with_containers.app_context():
with pytest.raises(ValueError, match="Account .* not found"):
run_single_rag_pipeline_task(invalid_entity_data, flask_app_with_containers)
# Assert: Pipeline generator should not be called
mock_pipeline_generator.assert_not_called()
def test_run_single_rag_pipeline_task_database_entity_not_found(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
):
"""
Test run_single_rag_pipeline_task with non-existent database entities.
This test verifies:
- Proper error handling for missing database entities
- Exception logging
- Function raises ValueError for missing entities
"""
# Arrange: Create test data with non-existent IDs
fake = Faker()
entity_data = {
"pipeline_id": str(uuid.uuid4()),
"application_generate_entity": {
"app_config": {
"app_id": str(uuid.uuid4()),
"app_name": "Test App",
"mode": "workflow",
"workflow_id": str(uuid.uuid4()),
},
"inputs": {"query": "Test query"},
"query": "Test query",
"response_mode": "blocking",
"user": str(uuid.uuid4()),
"files": [],
"conversation_id": str(uuid.uuid4()),
},
"user_id": str(uuid.uuid4()),
"tenant_id": str(uuid.uuid4()),
"workflow_id": str(uuid.uuid4()),
"streaming": False,
"workflow_execution_id": str(uuid.uuid4()),
"workflow_thread_pool_id": str(uuid.uuid4()),
}
# Act & Assert: Execute the single task with non-existent entities (should raise ValueError)
with flask_app_with_containers.app_context():
with pytest.raises(ValueError, match="Account .* not found"):
run_single_rag_pipeline_task(entity_data, flask_app_with_containers)
# Assert: Pipeline generator should not be called
mock_pipeline_generator.assert_not_called()
def test_priority_rag_pipeline_run_task_file_not_found(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test priority RAG pipeline run task with non-existent file.
This test verifies:
- Proper error handling for missing files
- Exception logging
- Function raises Exception for file errors
- Queue management continues despite file errors
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
# Mock file service to raise exception
file_id = str(uuid.uuid4())
mock_file_service["get_content"].side_effect = Exception("File not found")
# Use real Redis for TenantSelfTaskQueue
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Add waiting task to the real Redis queue
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch(
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
) as mock_delay:
# Act & Assert: Execute the priority task (should raise Exception)
with pytest.raises(Exception, match="File not found"):
priority_rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify error was handled gracefully
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_pipeline_generator.assert_not_called()
# Verify waiting task was still processed despite file error
mock_delay.assert_called_once()
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == waiting_file_id
assert call_kwargs.get('tenant_id') == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
def test_rag_pipeline_run_task_file_not_found(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test regular RAG pipeline run task with non-existent file.
This test verifies:
- Proper error handling for missing files
- Exception logging
- Function raises Exception for file errors
- Queue management continues despite file errors
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
# Mock file service to raise exception
file_id = str(uuid.uuid4())
mock_file_service["get_content"].side_effect = Exception("File not found")
# Use real Redis for TenantSelfTaskQueue
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Add waiting task to the real Redis queue
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act & Assert: Execute the regular task (should raise Exception)
with pytest.raises(Exception, match="File not found"):
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify error was handled gracefully
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_pipeline_generator.assert_not_called()
# Verify waiting task was still processed despite file error
mock_delay.assert_called_once()
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == waiting_file_id
assert call_kwargs.get('tenant_id') == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0

Some files were not shown because too many files have changed in this diff