Merge remote-tracking branch 'origin/deploy/dev' into feat/memory-orchestration-be-dev-env

# Conflicts:
#	api/models/__init__.py
#	api/uv.lock
This commit is contained in:
Stream
2025-10-11 15:01:26 +08:00
141 changed files with 11908 additions and 3514 deletions

View File

@ -1,3 +1,4 @@
import os
import sys
@ -8,10 +9,16 @@ def is_db_command():
# create app
celery = None
flask_app = None
socketio_app = None
if is_db_command():
from app_factory import create_migrations_app
app = create_migrations_app()
socketio_app = app
flask_app = app
else:
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
@ -33,8 +40,15 @@ else:
from app_factory import create_app
app = create_app()
celery = app.extensions["celery"]
socketio_app, flask_app = create_app()
app = flask_app
celery = flask_app.extensions["celery"]
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
host = os.environ.get("HOST", "0.0.0.0")
port = int(os.environ.get("PORT", 5001))
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
server.serve_forever()

View File

@ -31,14 +31,22 @@ def create_flask_app_with_configs() -> DifyApp:
return dify_app
def create_app() -> DifyApp:
def create_app() -> tuple[any, DifyApp]:
start_time = time.perf_counter()
app = create_flask_app_with_configs()
initialize_extensions(app)
import socketio
from extensions.ext_socketio import sio
sio.app = app
socketio_app = socketio.WSGIApp(sio, app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
return app
return socketio_app, app
def initialize_extensions(app: DifyApp):

View File

@ -836,6 +836,16 @@ class MailConfig(BaseSettings):
default=None,
)
ENABLE_TRIAL_APP: bool = Field(
description="Enable trial app",
default=False,
)
ENABLE_EXPLORE_BANNER: bool = Field(
description="Enable explore banner",
default=False,
)
class RagEtlConfig(BaseSettings):
"""

View File

@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings):
default="",
)
HOSTED_POOL_CREDITS: int = Field(
description="Pool credits for hosted service",
default=200,
)
def get_model_credits(self, model_name: str) -> int:
"""
Get credit value for a specific model name.
@ -70,11 +75,6 @@ class HostedOpenAiConfig(BaseSettings):
"text-davinci-003",
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted OpenAI service usage",
default=200,
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted OpenAI service",
default=False,
@ -98,6 +98,129 @@ class HostedOpenAiConfig(BaseSettings):
)
class HostedGeminiConfig(BaseSettings):
"""
Configuration for fetching Gemini service
"""
HOSTED_GEMINI_API_KEY: str | None = Field(
description="API key for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_API_BASE: str | None = Field(
description="Base URL for hosted Gemini API",
default=None,
)
HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Gemini service",
default=False,
)
HOSTED_GEMINI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted gemini service",
default=False,
)
HOSTED_GEMINI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
class HostedXAIConfig(BaseSettings):
"""
Configuration for fetching XAI service
"""
HOSTED_XAI_API_KEY: str | None = Field(
description="API key for hosted XAI service",
default=None,
)
HOSTED_XAI_API_BASE: str | None = Field(
description="Base URL for hosted XAI API",
default=None,
)
HOSTED_XAI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted XAI service",
default=None,
)
HOSTED_XAI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted XAI service",
default=False,
)
HOSTED_XAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
HOSTED_XAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_XAI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedDeepseekConfig(BaseSettings):
"""
Configuration for fetching Deepseek service
"""
HOSTED_DEEPSEEK_API_KEY: str | None = Field(
description="API key for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_API_BASE: str | None = Field(
description="Base URL for hosted Deepseek API",
default=None,
)
HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Deepseek service",
default=False,
)
HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="deepseek-chat,deepseek-reasoner",
)
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedAzureOpenAiConfig(BaseSettings):
"""
Configuration for hosted Azure OpenAI service
@ -144,16 +267,32 @@ class HostedAnthropicConfig(BaseSettings):
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Anthropic service usage",
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
class HostedMinmaxConfig(BaseSettings):
"""
@ -250,5 +389,8 @@ class HostedServiceConfig(
HostedModerationConfig,
# credit config
HostedCreditConfig,
HostedGeminiConfig,
HostedXAIConfig,
HostedDeepseekConfig,
):
pass

View File

@ -58,11 +58,13 @@ from .app import (
mcp_server,
message,
model_config,
online_user,
ops_trace,
site,
statistic,
workflow,
workflow_app_log,
workflow_comment,
workflow_draft_variable,
workflow_run,
workflow_statistic,
@ -106,10 +108,12 @@ from .datasets.rag_pipeline import (
# Import explore controllers
from .explore import (
banner,
installed_app,
parameter,
recommended_app,
saved_message,
trial,
)
# Import tag controllers
@ -143,6 +147,7 @@ __all__ = [
"apikey",
"app",
"audio",
"banner",
"billing",
"bp",
"completion",
@ -196,6 +201,7 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
"trial",
"version",
"website",
"workflow",

View File

@ -15,7 +15,7 @@ from constants.languages import supported_language
from controllers.console import api, console_ns
from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
def admin_required(view: Callable[P, R]):
@ -61,6 +61,8 @@ class InsertExploreAppListApi(Resource):
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
"can_trial": fields.Boolean(required=True, description="Can trial"),
"trial_limit": fields.Integer(required=True, description="Trial limit"),
},
)
)
@ -79,6 +81,8 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
parser.add_argument("can_trial", type=bool, required=True, nullable=False, location="json")
parser.add_argument("trial_limit", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
@ -115,6 +119,20 @@ class InsertExploreAppListApi(Resource):
)
db.session.add(recommended_app)
if args["can_trial"]:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == args["app_id"])
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=args["app_id"],
tenant_id=app.tenant_id,
trial_limit=args["trial_limit"],
)
)
else:
trial_app.trial_limit = args["trial_limit"]
app.is_public = True
db.session.commit()
@ -129,6 +147,20 @@ class InsertExploreAppListApi(Resource):
recommended_app.category = args["category"]
recommended_app.position = args["position"]
if args["can_trial"]:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == args["app_id"])
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=args["app_id"],
tenant_id=app.tenant_id,
trial_limit=args["trial_limit"],
)
)
else:
trial_app.trial_limit = args["trial_limit"]
app.is_public = True
db.session.commit()
@ -174,7 +206,67 @@ class InsertExploreAppApi(Resource):
for installed_app in installed_apps:
session.delete(installed_app)
trial_app = session.execute(
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
).scalar_one_or_none()
if trial_app:
session.delete(trial_app)
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
@console_ns.route("/admin/insert-explore-banner")
class InsertExploreBanner(Resource):
@api.doc("insert_explore_banner")
@api.doc(description="Insert an explore banner")
@api.expect(
api.model(
"InsertExploreBannerRequest",
{
"content": fields.String(required=True, description="Banner content"),
"link": fields.String(required=True, description="Banner link"),
"sort": fields.Integer(required=True, description="Banner sort"),
},
)
)
@api.response(200, "Banner inserted successfully")
@admin_required
@only_edition_cloud
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("link", type=str, required=True, nullable=False, location="json")
parser.add_argument("sort", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
banner = ExporleBanner(
content=args["content"],
link=args["link"],
sort=args["sort"],
)
db.session.add(banner)
db.session.commit()
return {"result": "success"}, 200
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
class DeleteExploreBanner(Resource):
@api.doc("delete_explore_banner")
@api.doc(description="Delete an explore banner")
@api.response(204, "Banner deleted successfully")
@admin_required
@only_edition_cloud
def delete(self, banner_id):
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
if not banner:
raise NotFound(f"Banner '{banner_id}' is not found")
db.session.delete(banner)
db.session.commit()
return {"result": "success"}, 204

View File

@ -0,0 +1,291 @@
import json
import time
from extensions.ext_redis import redis_client
from extensions.ext_socketio import sio
from libs.passport import PassportService
from services.account_service import AccountService
@sio.on("connect")
def socket_connect(sid, environ, auth):
"""
WebSocket connect event, do authentication here.
"""
token = None
if auth and isinstance(auth, dict):
token = auth.get("token")
if not token:
return False
try:
decoded = PassportService().verify(token)
user_id = decoded.get("user_id")
if not user_id:
return False
with sio.app.app_context():
user = AccountService.load_logged_in_account(account_id=user_id)
if not user:
return False
sio.save_session(sid, {"user_id": user.id, "username": user.name, "avatar": user.avatar})
return True
except Exception:
return False
@sio.on("user_connect")
def handle_user_connect(sid, data):
"""
Handle user connect event. Each session (tab) is treated as an independent collaborator.
"""
workflow_id = data.get("workflow_id")
if not workflow_id:
return {"msg": "workflow_id is required"}, 400
session = sio.get_session(sid)
user_id = session.get("user_id")
if not user_id:
return {"msg": "unauthorized"}, 401
# Each session is stored independently with sid as key
session_info = {
"user_id": user_id,
"username": session.get("username", "Unknown"),
"avatar": session.get("avatar", None),
"sid": sid,
"connected_at": int(time.time()), # Add timestamp to differentiate tabs
}
# Store session info with sid as key
redis_client.hset(f"workflow_online_users:{workflow_id}", sid, json.dumps(session_info))
redis_client.set(f"ws_sid_map:{sid}", json.dumps({"workflow_id": workflow_id, "user_id": user_id}))
# Leader election: first session becomes the leader
leader_sid = get_or_set_leader(workflow_id, sid)
is_leader = leader_sid == sid
sio.enter_room(sid, workflow_id)
broadcast_online_users(workflow_id)
# Notify this session of their leader status
sio.emit("status", {"isLeader": is_leader}, room=sid)
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
@sio.on("disconnect")
def handle_disconnect(sid):
"""
Handle session disconnect event. Remove the specific session from online users.
"""
mapping = redis_client.get(f"ws_sid_map:{sid}")
if mapping:
data = json.loads(mapping)
workflow_id = data["workflow_id"]
# Remove this specific session
redis_client.hdel(f"workflow_online_users:{workflow_id}", sid)
redis_client.delete(f"ws_sid_map:{sid}")
# Handle leader re-election if the leader session disconnected
handle_leader_disconnect(workflow_id, sid)
broadcast_online_users(workflow_id)
def _clear_session_state(workflow_id: str, sid: str) -> None:
redis_client.hdel(f"workflow_online_users:{workflow_id}", sid)
redis_client.delete(f"ws_sid_map:{sid}")
def _is_session_active(workflow_id: str, sid: str) -> bool:
if not sid:
return False
try:
if not sio.manager.is_connected(sid, "/"):
return False
except AttributeError:
return False
if not redis_client.hexists(f"workflow_online_users:{workflow_id}", sid):
return False
if not redis_client.exists(f"ws_sid_map:{sid}"):
return False
return True
def get_or_set_leader(workflow_id: str, sid: str) -> str:
"""
Get current leader session or set this session as leader if no valid leader exists.
Returns the leader session id (sid).
"""
leader_key = f"workflow_leader:{workflow_id}"
raw_leader = redis_client.get(leader_key)
current_leader = raw_leader.decode("utf-8") if isinstance(raw_leader, bytes) else raw_leader
leader_replaced = False
if current_leader and not _is_session_active(workflow_id, current_leader):
_clear_session_state(workflow_id, current_leader)
redis_client.delete(leader_key)
current_leader = None
leader_replaced = True
if not current_leader:
redis_client.set(leader_key, sid, ex=3600) # Expire in 1 hour
if leader_replaced:
broadcast_leader_change(workflow_id, sid)
return sid
return current_leader
def handle_leader_disconnect(workflow_id, disconnected_sid):
"""
Handle leader re-election when a session disconnects.
If the disconnected session was the leader, elect a new leader from remaining sessions.
"""
leader_key = f"workflow_leader:{workflow_id}"
current_leader = redis_client.get(leader_key)
if current_leader:
current_leader = current_leader.decode("utf-8") if isinstance(current_leader, bytes) else current_leader
if current_leader == disconnected_sid:
# Leader session disconnected, elect a new leader
sessions_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
if sessions_json:
# Get the first remaining session as new leader
new_leader_sid = list(sessions_json.keys())[0]
if isinstance(new_leader_sid, bytes):
new_leader_sid = new_leader_sid.decode("utf-8")
redis_client.set(leader_key, new_leader_sid, ex=3600)
# Notify all sessions about the new leader
broadcast_leader_change(workflow_id, new_leader_sid)
else:
# No sessions left, remove leader
redis_client.delete(leader_key)
def broadcast_leader_change(workflow_id, new_leader_sid):
"""
Broadcast leader change to all sessions in the workflow.
"""
sessions_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
for sid, session_info_json in sessions_json.items():
try:
sid_str = sid.decode("utf-8") if isinstance(sid, bytes) else sid
is_leader = sid_str == new_leader_sid
# Emit to each session whether they are the new leader
sio.emit("status", {"isLeader": is_leader}, room=sid_str)
except Exception:
continue
def get_current_leader(workflow_id):
"""
Get the current leader for a workflow.
"""
leader_key = f"workflow_leader:{workflow_id}"
leader = redis_client.get(leader_key)
return leader.decode("utf-8") if leader and isinstance(leader, bytes) else leader
def broadcast_online_users(workflow_id):
"""
Broadcast online users to the workflow room.
Each session is shown as a separate user (even if same person has multiple tabs).
"""
sessions_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
users = []
for sid, session_info_json in sessions_json.items():
try:
session_info = json.loads(session_info_json)
# Each session appears as a separate "user" in the UI
users.append(
{
"user_id": session_info["user_id"],
"username": session_info["username"],
"avatar": session_info.get("avatar"),
"sid": session_info["sid"],
"connected_at": session_info.get("connected_at"),
}
)
except Exception:
continue
# Sort by connection time to maintain consistent order
users.sort(key=lambda x: x.get("connected_at") or 0)
# Get current leader session
leader_sid = get_current_leader(workflow_id)
sio.emit("online_users", {"workflow_id": workflow_id, "users": users, "leader": leader_sid}, room=workflow_id)
@sio.on("collaboration_event")
def handle_collaboration_event(sid, data):
"""
Handle general collaboration events, include:
1. mouseMove
2. varsAndFeaturesUpdate
3. syncRequest(ask leader to update graph)
4. appStateUpdate
5. mcpServerUpdate
"""
mapping = redis_client.get(f"ws_sid_map:{sid}")
if not mapping:
return {"msg": "unauthorized"}, 401
mapping_data = json.loads(mapping)
workflow_id = mapping_data["workflow_id"]
user_id = mapping_data["user_id"]
event_type = data.get("type")
event_data = data.get("data")
timestamp = data.get("timestamp", int(time.time()))
if not event_type:
return {"msg": "invalid event type"}, 400
sio.emit(
"collaboration_update",
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
room=workflow_id,
skip_sid=sid,
)
return {"msg": "event_broadcasted"}
@sio.on("graph_event")
def handle_graph_event(sid, data):
"""
Handle graph events - simple broadcast relay.
"""
mapping = redis_client.get(f"ws_sid_map:{sid}")
if not mapping:
return {"msg": "unauthorized"}, 401
mapping_data = json.loads(mapping)
workflow_id = mapping_data["workflow_id"]
sio.emit("graph_update", data, room=workflow_id, skip_sid=sid)
return {"msg": "graph_update_broadcasted"}

View File

@ -21,7 +21,9 @@ from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import file_factory, variable_factory
from fields.online_user_fields import online_user_list_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
@ -127,6 +129,7 @@ class DraftWorkflowApi(Resource):
parser.add_argument("hash", type=str, required=False, location="json")
parser.add_argument("environment_variables", type=list, required=True, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json")
parser.add_argument("force_upload", type=bool, required=False, default=False, location="json")
args = parser.parse_args()
elif "text/plain" in content_type:
try:
@ -143,6 +146,7 @@ class DraftWorkflowApi(Resource):
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
"force_upload": data.get("force_upload", False),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
@ -171,6 +175,7 @@ class DraftWorkflowApi(Resource):
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
force_upload=args.get("force_upload", False),
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
@ -796,6 +801,45 @@ class ConvertToWorkflowApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
class WorkflowConfigApi(Resource):
"""Resource for workflow configuration."""
@api.doc("get_workflow_config")
@api.doc(description="Get workflow configuration")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Workflow configuration retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}
class WorkflowFeaturesApi(Resource):
"""Update draft workflow features."""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
parser = reqparse.RequestParser()
parser.add_argument("features", type=dict, required=True, location="json")
args = parser.parse_args()
features = args.get("features")
# Update draft workflow features
workflow_service = WorkflowService()
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@api.doc("get_all_published_workflows")
@ -985,3 +1029,105 @@ class DraftWorkflowNodeLastRunApi(Resource):
if node_exec is None:
raise NotFound("last run not found")
return node_exec
class WorkflowOnlineUsersApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(online_user_list_fields)
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("workflow_ids", type=str, required=True, location="args")
args = parser.parse_args()
workflow_ids = [id.strip() for id in args["workflow_ids"].split(",")]
results = []
for workflow_id in workflow_ids:
users_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
users = []
for _, user_info_json in users_json.items():
try:
users.append(json.loads(user_info_json))
except Exception:
continue
results.append({"workflow_id": workflow_id, "users": users})
return {"data": results}
api.add_resource(
DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft",
)
api.add_resource(
WorkflowConfigApi,
"/apps/<uuid:app_id>/workflows/draft/config",
)
api.add_resource(
WorkflowFeaturesApi,
"/apps/<uuid:app_id>/workflows/draft/features",
)
api.add_resource(
AdvancedChatDraftWorkflowRunApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
)
api.add_resource(
DraftWorkflowRunApi,
"/apps/<uuid:app_id>/workflows/draft/run",
)
api.add_resource(
WorkflowTaskStopApi,
"/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
)
api.add_resource(
DraftWorkflowNodeRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
AdvancedChatDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
WorkflowDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
AdvancedChatDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
WorkflowDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
PublishedWorkflowApi,
"/apps/<uuid:app_id>/workflows/publish",
)
api.add_resource(
PublishedAllWorkflowApi,
"/apps/<uuid:app_id>/workflows",
)
api.add_resource(
DefaultBlockConfigsApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
)
api.add_resource(
DefaultBlockConfigApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
)
api.add_resource(
ConvertToWorkflowApi,
"/apps/<uuid:app_id>/convert-to-workflow",
)
api.add_resource(
WorkflowByIdApi,
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
)
api.add_resource(
DraftWorkflowNodeLastRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run",
)
api.add_resource(WorkflowOnlineUsersApi, "/apps/workflows/online-users")

View File

@ -0,0 +1,240 @@
import logging
from flask_restx import Resource, fields, marshal_with, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.member_fields import account_with_role_fields
from fields.workflow_comment_fields import (
workflow_comment_basic_fields,
workflow_comment_create_fields,
workflow_comment_detail_fields,
workflow_comment_reply_create_fields,
workflow_comment_reply_update_fields,
workflow_comment_resolve_fields,
workflow_comment_update_fields,
)
from libs.login import current_user, login_required
from models import App
from services.account_service import TenantService
from services.workflow_comment_service import WorkflowCommentService
logger = logging.getLogger(__name__)
class WorkflowCommentListApi(Resource):
"""API for listing and creating workflow comments."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_basic_fields, envelope="data")
def get(self, app_model: App):
"""Get all comments for a workflow."""
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
return comments
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_create_fields)
def post(self, app_model: App):
"""Create a new workflow comment."""
parser = reqparse.RequestParser()
parser.add_argument("position_x", type=float, required=True, location="json")
parser.add_argument("position_y", type=float, required=True, location="json")
parser.add_argument("content", type=str, required=True, location="json")
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
args = parser.parse_args()
result = WorkflowCommentService.create_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
created_by=current_user.id,
content=args.content,
position_x=args.position_x,
position_y=args.position_y,
mentioned_user_ids=args.mentioned_user_ids,
)
return result, 201
class WorkflowCommentDetailApi(Resource):
"""API for managing individual workflow comments."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_detail_fields)
def get(self, app_model: App, comment_id: str):
"""Get a specific workflow comment."""
comment = WorkflowCommentService.get_comment(
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
)
return comment
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_update_fields)
def put(self, app_model: App, comment_id: str):
"""Update a workflow comment."""
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, location="json")
parser.add_argument("position_x", type=float, required=False, location="json")
parser.add_argument("position_y", type=float, required=False, location="json")
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
args = parser.parse_args()
result = WorkflowCommentService.update_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
content=args.content,
position_x=args.position_x,
position_y=args.position_y,
mentioned_user_ids=args.mentioned_user_ids,
)
return result
@login_required
@setup_required
@account_initialization_required
@get_app_model
def delete(self, app_model: App, comment_id: str):
"""Delete a workflow comment."""
WorkflowCommentService.delete_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return {"result": "success"}, 204
class WorkflowCommentResolveApi(Resource):
"""API for resolving and reopening workflow comments."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_resolve_fields)
def post(self, app_model: App, comment_id: str):
"""Resolve a workflow comment."""
comment = WorkflowCommentService.resolve_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return comment
class WorkflowCommentReplyApi(Resource):
"""API for managing comment replies."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_reply_create_fields)
def post(self, app_model: App, comment_id: str):
"""Add a reply to a workflow comment."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, location="json")
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
args = parser.parse_args()
result = WorkflowCommentService.create_reply(
comment_id=comment_id,
content=args.content,
created_by=current_user.id,
mentioned_user_ids=args.mentioned_user_ids,
)
return result, 201
class WorkflowCommentReplyDetailApi(Resource):
"""API for managing individual comment replies."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_reply_update_fields)
def put(self, app_model: App, comment_id: str, reply_id: str):
"""Update a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, location="json")
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
args = parser.parse_args()
reply = WorkflowCommentService.update_reply(
reply_id=reply_id, user_id=current_user.id, content=args.content, mentioned_user_ids=args.mentioned_user_ids
)
return reply
@login_required
@setup_required
@account_initialization_required
@get_app_model
def delete(self, app_model: App, comment_id: str, reply_id: str):
"""Delete a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
return {"result": "success"}, 204
class WorkflowCommentMentionUsersApi(Resource):
"""API for getting mentionable users for workflow comments."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with({"users": fields.List(fields.Nested(account_with_role_fields))})
def get(self, app_model: App):
"""Get all users in current tenant for mentions."""
members = TenantService.get_tenant_members(current_user.current_tenant)
return {"users": members}
# Register API routes
api.add_resource(WorkflowCommentListApi, "/apps/<uuid:app_id>/workflow/comments")
api.add_resource(WorkflowCommentDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
api.add_resource(WorkflowCommentResolveApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
api.add_resource(WorkflowCommentReplyApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
api.add_resource(
WorkflowCommentReplyDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>"
)
api.add_resource(WorkflowCommentMentionUsersApi, "/apps/<uuid:app_id>/workflow/comments/mention-users")

View File

@ -19,8 +19,8 @@ from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories import variable_factory
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode
from models.account import Account
@ -353,7 +353,7 @@ class VariableApi(Resource):
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
new_value = build_segment_with_type(variable.value_type, raw_value)
new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@ -446,8 +446,35 @@ class ConversationVariableCollectionApi(Resource):
db.session.commit()
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
def post(self, app_model: App):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("conversation_variables", type=list, required=True, location="json")
args = parser.parse_args()
workflow_service = WorkflowService()
conversation_variables_list = args.get("conversation_variables") or []
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
workflow_service.update_draft_workflow_conversation_variables(
app_model=app_model,
account=current_user,
conversation_variables=conversation_variables,
)
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
class SystemVariableCollectionApi(Resource):
@api.doc("get_system_variables")
@api.doc(description="Get system variables for workflow")
@ -497,3 +524,44 @@ class EnvironmentVariableCollectionApi(Resource):
)
return {"items": env_vars_list}
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("environment_variables", type=list, required=True, location="json")
args = parser.parse_args()
workflow_service = WorkflowService()
environment_variables_list = args.get("environment_variables") or []
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
workflow_service.update_draft_workflow_environment_variables(
app_model=app_model,
account=current_user,
environment_variables=environment_variables,
)
return {"result": "success"}
api.add_resource(
WorkflowVariableCollectionApi,
"/apps/<uuid:app_id>/workflows/draft/variables",
)
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")

View File

@ -0,0 +1,34 @@
from flask_restx import Resource
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
from extensions.ext_database import db
from models.model import ExporleBanner
class BannerApi(Resource):
"""Resource for banner list."""
@explore_banner_enabled
def get(self):
"""Get banner list."""
banners = (
db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled").order_by(ExporleBanner.sort).all()
)
# Convert banners to serializable format
result = []
for banner in banners:
banner_data = {
"content": banner.content, # Already parsed as JSON by SQLAlchemy
"link": banner.link,
"sort": banner.sort,
"status": banner.status,
"created_at": banner.created_at.isoformat() if banner.created_at else None,
}
result.append(banner_data)
return result
api.add_resource(BannerApi, "/explore/banners")

View File

@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403
class TrialAppNotAllowed(BaseHTTPException):
"""*403* `Trial App Not Allowed`
Raise if the user has reached the trial app limit.
"""
error_code = "trial_app_not_allowed"
code = 403
description = "the app is not allowed to be trial."
class TrialAppLimitExceeded(BaseHTTPException):
"""*403* `Trial App Limit Exceeded`
Raise if the user has exceeded the trial app limit.
"""
error_code = "trial_app_limit_exceeded"
code = 403
description = "The user has exceeded the trial app limit."

View File

@ -27,6 +27,7 @@ recommended_app_fields = {
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
"can_trial": fields.Boolean,
}
recommended_app_list_fields = {

View File

@ -0,0 +1,375 @@
import logging
from flask import request
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common import fields
from controllers.common.fields import build_site_model
from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
ConversationCompletedError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import (
AppSuggestedQuestionsAfterAnswerDisabledError,
NotChatAppError,
NotCompletionAppError,
)
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
from controllers.service_api import service_api_ns
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.app_fields import app_detail_fields_with_site
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.account import TenantStatus
from models.model import AppMode, Site
from services.app_generate_service import AppGenerateService
from services.app_service import AppService
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
from services.errors.conversation import ConversationNotExistsError
from services.errors.llm import InvokeRateLimitError
from services.errors.message import (
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService
from services.recommended_app_service import RecommendedAppService
logger = logging.getLogger(__name__)
class TrialChatApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialMessageSuggestedQuestionApi(TrialAppResource):
@trial_feature_enable
def get(self, trial_app, message_id):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id = str(message_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
class TrialChatAudioApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
file = request.files["file"]
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
raise InternalServerError()
class TrialChatTextApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
raise InternalServerError()
class TrialCompletionApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
)
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialSitApi(Resource):
"""Resource for trial app sites."""
@trial_feature_enable
@get_app_model
@service_api_ns.marshal_with(build_site_model(service_api_ns))
def get(self, app_model):
"""Retrieve app site info.
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
assert app_model.tenant
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return site
class TrialAppParameterApi(Resource):
"""Resource for app variables."""
@trial_feature_enable
@get_app_model
@marshal_with(fields.parameters_fields)
def get(self, app_model):
"""Retrieve app parameters."""
if app_model is None:
raise AppUnavailableError()
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
class AppApi(Resource):
@trial_feature_enable
@get_app_model
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
app_model = app_service.get_app(app_model)
return app_model
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
api.add_resource(
TrialMessageSuggestedQuestionApi,
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="trial_app_suggested_question",
)
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")

View File

@ -2,15 +2,16 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_login import current_user
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import login_required
from models import InstalledApp
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -74,6 +75,59 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, P], R]):
@wraps(view)
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
if trial_app is None:
raise TrialAppNotAllowed()
app = trial_app.app
if app is None:
raise TrialAppNotAllowed()
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
.first()
)
if account_trial_app_record:
if account_trial_app_record.count >= trial_app.trial_limit:
raise TrialAppLimitExceeded()
return view(app, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_trial_app:
abort(403, "Trial app feature is not enabled.")
return view(*args, **kwargs)
return decorated
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_explore_banner:
abort(403, "Explore banner feature is not enabled.")
return view(*args, **kwargs)
return decorated
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
@ -83,3 +137,13 @@ class InstalledAppResource(Resource):
account_initialization_required,
login_required,
]
class TrialAppResource(Resource):
# must be reversed if there are multiple decorators
method_decorators = [
trial_app_required,
account_initialization_required,
login_required,
]

View File

@ -33,6 +33,7 @@ from controllers.console.wraps import (
only_edition_cloud,
setup_required,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
@ -135,6 +136,17 @@ class AccountNameApi(Resource):
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("avatar", type=str, required=True, location="args")
args = parser.parse_args()
avatar_url = file_helpers.get_signed_file_url(args["avatar"])
return {"avatar_url": avatar_url}
@setup_required
@login_required
@account_initialization_required

View File

@ -51,6 +51,8 @@ tenant_fields = {
"in_trial": fields.Boolean,
"trial_end_reason": fields.String,
"custom_config": fields.Raw(attribute="custom_config"),
"trial_credits": fields.Integer,
"trial_credits_used": fields.Integer,
}
tenants_fields = {

View File

@ -56,6 +56,9 @@ class HostingConfiguration:
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek()
self.moderation_config = self.init_moderation_config()
@ -128,7 +131,7 @@ class HostingConfiguration:
quotas: list[HostingQuota] = []
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
hosted_quota_limit = 0
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
@ -156,18 +159,49 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
@staticmethod
def init_anthropic() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
def init_gemini(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_GEMINI_TRIAL_ENABLED:
hosted_quota_limit = 0
trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
if dify_config.HOSTED_GEMINI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
}
if dify_config.HOSTED_GEMINI_API_BASE:
credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_anthropic(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
paid_quota = PaidHostingQuota()
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
@ -185,6 +219,66 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
def init_xai(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_XAI_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_XAI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"api_key": dify_config.HOSTED_XAI_API_KEY,
}
if dify_config.HOSTED_XAI_API_BASE:
credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_deepseek(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
}
if dify_config.HOSTED_DEEPSEEK_API_BASE:
credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
@staticmethod
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS

View File

@ -618,9 +618,9 @@ class ProviderManager:
)
for quota in configuration.quotas:
if quota.quota_type == ProviderQuotaType.TRIAL:
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
# Init trial provider records if not exists
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
if quota.quota_type not in provider_quota_to_provider_record_dict:
try:
# FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic
new_provider_record = Provider(
@ -628,8 +628,8 @@ class ProviderManager:
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=quota.quota_limit, # type: ignore
quota_type=quota.quota_type,
quota_limit=0, # type: ignore
quota_used=0,
is_valid=True,
)
@ -642,7 +642,7 @@ class ProviderManager:
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
Provider.quota_type == quota.quota_type,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
@ -652,7 +652,7 @@ class ProviderManager:
existed_provider_record.is_valid = True
db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
return provider_name_to_provider_records_dict
@ -912,6 +912,22 @@ class ProviderManager:
provider_record
)
quota_configurations = []
if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService
trail_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.TRIAL.value,
)
paid_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.PAID.value,
)
else:
trail_pool = None
paid_pool = None
for provider_quota in provider_hosting_configuration.quotas:
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
if provider_quota.quota_type == ProviderQuotaType.FREE:
@ -932,16 +948,36 @@ class ProviderManager:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=trail_pool.quota_used,
quota_limit=trail_pool.quota_limit,
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
else:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configurations.append(quota_configuration)

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import QuotaUnit
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@ -136,21 +136,36 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
session.execute(stmt)
session.commit()
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()

View File

@ -38,14 +38,16 @@ elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
else
if [[ "${DEBUG}" == "true" ]]; then
exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug
export HOST=${DIFY_BIND_ADDRESS:-0.0.0.0}
export PORT=${DIFY_PORT:-5001}
exec python -m app
else
exec gunicorn \
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \
app:app
app:socketio_app
fi
fi

View File

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, SystemConfiguration
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
@ -133,22 +133,38 @@ def handle(sender: Message, **kwargs):
system_configuration=system_configuration,
model_name=model_config.model,
)
if used_quota is not None:
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
credits_required=used_quota,
pool_type="trial",
)
elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
# Execute all updates
start_time = time_module.perf_counter()

View File

@ -0,0 +1,3 @@
import socketio
sio = socketio.Server(async_mode="gevent", cors_allowed_origins="*")

View File

@ -0,0 +1,17 @@
from flask_restx import fields
online_user_partial_fields = {
"user_id": fields.String,
"username": fields.String,
"avatar": fields.String,
"sid": fields.String,
}
workflow_online_users_fields = {
"workflow_id": fields.String,
"users": fields.List(fields.Nested(online_user_partial_fields)),
}
online_user_list_fields = {
"data": fields.List(fields.Nested(workflow_online_users_fields)),
}

View File

@ -0,0 +1,96 @@
from flask_restx import fields
from libs.helper import AvatarUrlField, TimestampField
# basic account fields for comments
account_fields = {
"id": fields.String,
"name": fields.String,
"email": fields.String,
"avatar_url": AvatarUrlField,
}
# Comment mention fields
workflow_comment_mention_fields = {
"mentioned_user_id": fields.String,
"mentioned_user_account": fields.Nested(account_fields, allow_null=True),
"reply_id": fields.String,
}
# Comment reply fields
workflow_comment_reply_fields = {
"id": fields.String,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
}
# Basic comment fields (for list views)
workflow_comment_basic_fields = {
"id": fields.String,
"position_x": fields.Float,
"position_y": fields.Float,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
"updated_at": TimestampField,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
"reply_count": fields.Integer,
"mention_count": fields.Integer,
"participants": fields.List(fields.Nested(account_fields)),
}
# Detailed comment fields (for single comment view)
workflow_comment_detail_fields = {
"id": fields.String,
"position_x": fields.Float,
"position_y": fields.Float,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
"updated_at": TimestampField,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
"replies": fields.List(fields.Nested(workflow_comment_reply_fields)),
"mentions": fields.List(fields.Nested(workflow_comment_mention_fields)),
}
# Comment creation response fields (simplified)
workflow_comment_create_fields = {
"id": fields.String,
"created_at": TimestampField,
}
# Comment update response fields (simplified)
workflow_comment_update_fields = {
"id": fields.String,
"updated_at": TimestampField,
}
# Comment resolve response fields
workflow_comment_resolve_fields = {
"id": fields.String,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
}
# Reply creation response fields (simplified)
workflow_comment_reply_create_fields = {
"id": fields.String,
"created_at": TimestampField,
}
# Reply update response fields
workflow_comment_reply_update_fields = {
"id": fields.String,
"updated_at": TimestampField,
}

View File

@ -0,0 +1,90 @@
"""Add workflow comments table
Revision ID: 227822d22895
Revises: 68519ad5cd18
Create Date: 2025-08-22 17:26:15.255980
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '227822d22895'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('workflow_comments',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('position_x', sa.Float(), nullable=False),
sa.Column('position_y', sa.Float(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('resolved', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('resolved_at', sa.DateTime(), nullable=True),
sa.Column('resolved_by', models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint('id', name='workflow_comments_pkey')
)
with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
batch_op.create_index('workflow_comments_app_idx', ['tenant_id', 'app_id'], unique=False)
batch_op.create_index('workflow_comments_created_at_idx', ['created_at'], unique=False)
op.create_table('workflow_comment_replies',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('comment_id', models.types.StringUUID(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_replies_comment_id_fkey'), ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name='workflow_comment_replies_pkey')
)
with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
batch_op.create_index('comment_replies_comment_idx', ['comment_id'], unique=False)
batch_op.create_index('comment_replies_created_at_idx', ['created_at'], unique=False)
op.create_table('workflow_comment_mentions',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('comment_id', models.types.StringUUID(), nullable=False),
sa.Column('reply_id', models.types.StringUUID(), nullable=True),
sa.Column('mentioned_user_id', models.types.StringUUID(), nullable=False),
sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_mentions_comment_id_fkey'), ondelete='CASCADE'),
sa.ForeignKeyConstraint(['reply_id'], ['workflow_comment_replies.id'], name=op.f('workflow_comment_mentions_reply_id_fkey'), ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name='workflow_comment_mentions_pkey')
)
with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
batch_op.create_index('comment_mentions_comment_idx', ['comment_id'], unique=False)
batch_op.create_index('comment_mentions_reply_idx', ['reply_id'], unique=False)
batch_op.create_index('comment_mentions_user_idx', ['mentioned_user_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
batch_op.drop_index('comment_mentions_user_idx')
batch_op.drop_index('comment_mentions_reply_idx')
batch_op.drop_index('comment_mentions_comment_idx')
op.drop_table('workflow_comment_mentions')
with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
batch_op.drop_index('comment_replies_created_at_idx')
batch_op.drop_index('comment_replies_comment_idx')
op.drop_table('workflow_comment_replies')
with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
batch_op.drop_index('workflow_comments_created_at_idx')
batch_op.drop_index('workflow_comments_app_idx')
op.drop_table('workflow_comments')
# ### end Alembic commands ###

View File

@ -0,0 +1,79 @@
"""add table explore banner and trial
Revision ID: 1b435d90db42
Revises: cf7c38a32b2d
Create Date: 2025-09-19 14:42:58.416649
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '1b435d90db42'
down_revision = 'cf7c38a32b2d'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('account_trial_app_records',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('account_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('count', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
)
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
op.create_table('exporle_banners',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('content', sa.JSON(), nullable=False),
sa.Column('link', sa.String(length=255), nullable=False),
sa.Column('sort', sa.Integer(), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
)
op.create_table('trial_apps',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('trial_limit', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
sa.UniqueConstraint('app_id', name='unique_trail_app_id')
)
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
with op.batch_alter_table('providers', schema=None) as batch_op:
batch_op.drop_column('credential_status')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('credential_status', sa.VARCHAR(length=20), server_default=sa.text("'active'::character varying"), autoincrement=False, nullable=True))
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
batch_op.drop_index('trial_app_tenant_id_idx')
batch_op.drop_index('trial_app_app_id_idx')
op.drop_table('trial_apps')
op.drop_table('exporle_banners')
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
batch_op.drop_index('account_trial_app_record_app_id_idx')
batch_op.drop_index('account_trial_app_record_account_id_idx')
op.drop_table('account_trial_app_records')
# ### end Alembic commands ###

View File

@ -0,0 +1,104 @@
"""add table credit pool
Revision ID: 58a70d22fdbd
Revises: 68519ad5cd18
Create Date: 2025-09-25 15:20:40.367078
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '58a70d22fdbd'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tenant_credit_pools',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('pool_type', sa.String(length=40), nullable=False),
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
sa.Column('quota_used', sa.BigInteger(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
)
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
# Data migration: Move quota data from providers to tenant_credit_pools
migrate_quota_data()
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.drop_index('tenant_credit_pool_tenant_id_idx')
batch_op.drop_index('tenant_credit_pool_pool_type_idx')
op.drop_table('tenant_credit_pools')
# ### end Alembic commands ###
def migrate_quota_data():
"""
Migrate quota data from providers table to tenant_credit_pools table
for providers with quota_type='trial' or 'paid', provider_name='openai', provider_type='system'
"""
# Create connection
bind = op.get_bind()
# Define quota type mappings
quota_type_mappings = ['trial', 'paid']
for quota_type in quota_type_mappings:
# Query providers that match the criteria
select_sql = sa.text("""
SELECT tenant_id, quota_limit, quota_used
FROM providers
WHERE quota_type = :quota_type
AND provider_name = 'openai'
AND provider_type = 'system'
AND quota_limit IS NOT NULL
""")
result = bind.execute(select_sql, {"quota_type": quota_type})
providers_data = result.fetchall()
# Insert data into tenant_credit_pools
for provider_data in providers_data:
tenant_id, quota_limit, quota_used = provider_data
# Check if credit pool already exists for this tenant and pool type
check_sql = sa.text("""
SELECT COUNT(*)
FROM tenant_credit_pools
WHERE tenant_id = :tenant_id AND pool_type = :pool_type
""")
existing_count = bind.execute(check_sql, {
"tenant_id": tenant_id,
"pool_type": quota_type
}).scalar()
if existing_count == 0:
# Insert new credit pool record
insert_sql = sa.text("""
INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at)
VALUES (:tenant_id, :pool_type, :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
""")
bind.execute(insert_sql, {
"tenant_id": tenant_id,
"pool_type": quota_type,
"quota_limit": quota_limit or 0,
"quota_used": quota_used or 0
})

View File

@ -10,6 +10,11 @@ from .account import (
)
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from .chatflow_memory import ChatflowConversation, ChatflowMemoryVariable, ChatflowMessage
from .comment import (
WorkflowComment,
WorkflowCommentMention,
WorkflowCommentReply,
)
from .dataset import (
AppDatasetJoin,
Dataset,
@ -29,6 +34,7 @@ from .dataset import (
)
from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
from .model import (
AccountTrialAppRecord,
ApiRequest,
ApiToken,
App,
@ -41,6 +47,7 @@ from .model import (
DatasetRetrieverResource,
DifySetup,
EndUser,
ExporleBanner,
IconType,
InstalledApp,
Message,
@ -54,7 +61,9 @@ from .model import (
Site,
Tag,
TagBinding,
TenantCreditPool,
TraceAppConfig,
TrialApp,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
@ -99,6 +108,7 @@ __all__ = [
"Account",
"AccountIntegrate",
"AccountStatus",
"AccountTrialAppRecord",
"ApiRequest",
"ApiToken",
"ApiToolProvider",
@ -135,6 +145,7 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
"ExporleBanner",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"IconType",
@ -163,6 +174,7 @@ __all__ = [
"Tenant",
"TenantAccountJoin",
"TenantAccountRole",
"TenantCreditPool",
"TenantDefaultModel",
"TenantPreferredModelProvider",
"TenantStatus",
@ -172,12 +184,16 @@ __all__ = [
"ToolLabelBinding",
"ToolModelInvoke",
"TraceAppConfig",
"TrialApp",
"UploadFile",
"UserFrom",
"Whitelist",
"Workflow",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
"WorkflowComment",
"WorkflowCommentMention",
"WorkflowCommentReply",
"WorkflowNodeExecutionModel",
"WorkflowNodeExecutionOffload",
"WorkflowNodeExecutionTriggeredFrom",

189
api/models/comment.py Normal file
View File

@ -0,0 +1,189 @@
"""Workflow comment models."""
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .account import Account
from .base import Base
from .engine import db
from .types import StringUUID
if TYPE_CHECKING:
pass
class WorkflowComment(Base):
"""Workflow comment model for canvas commenting functionality.
Comments are associated with apps rather than specific workflow versions,
since an app has only one draft workflow at a time and comments should persist
across workflow version changes.
Attributes:
id: Comment ID
tenant_id: Workspace ID
app_id: App ID (primary association, comments belong to apps)
position_x: X coordinate on canvas
position_y: Y coordinate on canvas
content: Comment content
created_by: Creator account ID
created_at: Creation time
updated_at: Last update time
resolved: Whether comment is resolved
resolved_at: Resolution time
resolved_by: Resolver account ID
"""
__tablename__ = "workflow_comments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
Index("workflow_comments_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position_x: Mapped[float] = mapped_column(db.Float)
position_y: Mapped[float] = mapped_column(db.Float)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
resolved_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
resolved_by: Mapped[Optional[str]] = mapped_column(StringUUID)
# Relationships
replies: Mapped[list["WorkflowCommentReply"]] = relationship(
"WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
)
mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
"WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
)
@property
def created_by_account(self):
"""Get creator account."""
return db.session.get(Account, self.created_by)
@property
def resolved_by_account(self):
"""Get resolver account."""
if self.resolved_by:
return db.session.get(Account, self.resolved_by)
return None
@property
def reply_count(self):
"""Get reply count."""
return len(self.replies)
@property
def mention_count(self):
"""Get mention count."""
return len(self.mentions)
@property
def participants(self):
"""Get all participants (creator + repliers + mentioned users)."""
participant_ids = set()
# Add comment creator
participant_ids.add(self.created_by)
# Add reply creators
participant_ids.update(reply.created_by for reply in self.replies)
# Add mentioned users
participant_ids.update(mention.mentioned_user_id for mention in self.mentions)
# Get account objects
participants = []
for user_id in participant_ids:
account = db.session.get(Account, user_id)
if account:
participants.append(account)
return participants
class WorkflowCommentReply(Base):
"""Workflow comment reply model.
Attributes:
id: Reply ID
comment_id: Parent comment ID
content: Reply content
created_by: Creator account ID
created_at: Creation time
"""
__tablename__ = "workflow_comment_replies"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
Index("comment_replies_comment_idx", "comment_id"),
Index("comment_replies_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
@property
def created_by_account(self):
"""Get creator account."""
return db.session.get(Account, self.created_by)
class WorkflowCommentMention(Base):
"""Workflow comment mention model.
Mentions are only for internal accounts since end users
cannot access workflow canvas and commenting features.
Attributes:
id: Mention ID
comment_id: Parent comment ID
mentioned_user_id: Mentioned account ID
"""
__tablename__ = "workflow_comment_mentions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
Index("comment_mentions_comment_idx", "comment_id"),
Index("comment_mentions_reply_idx", "reply_id"),
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
reply_id: Mapped[Optional[str]] = mapped_column(
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
)
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
@property
def mentioned_user_account(self):
"""Get mentioned account."""
return db.session.get(Account, self.mentioned_user_id)

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
@ -581,6 +581,63 @@ class InstalledApp(Base):
return tenant
class TrialApp(Base):
__tablename__ = "trial_apps"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
sa.Index("trial_app_app_id_idx", "app_id"),
sa.Index("trial_app_tenant_id_idx", "tenant_id"),
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
class AccountTrialAppRecord(Base):
__tablename__ = "account_trial_app_records"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
sa.Index("account_trial_app_record_account_id_idx", "account_id"),
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
account_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
count = mapped_column(sa.Integer, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def user(self) -> Account | None:
user = db.session.query(Account).where(Account.id == self.account_id).first()
return user
class ExporleBanner(Base):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
content = mapped_column(sa.JSON, nullable=False)
link = mapped_column(String(255), nullable=False)
sort = mapped_column(sa.Integer, nullable=False)
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class OAuthProviderApp(Base):
"""
Globally shared OAuth provider app information.
@ -1944,3 +2001,29 @@ class TraceAppConfig(Base):
"created_at": str(self.created_at) if self.created_at else None,
"updated_at": str(self.updated_at) if self.updated_at else None,
}
class TenantCreditPool(Base):
__tablename__ = "tenant_credit_pools"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
sa.Index("tenant_credit_pool_tenant_id_idx", "tenant_id"),
sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
quota_limit = mapped_column(BigInteger, nullable=False, default=0)
quota_used = mapped_column(BigInteger, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def remaining_credits(self) -> int:
return max(0, self.quota_limit - self.quota_used)
def has_sufficient_credits(self, required_credits: int) -> bool:
return self.remaining_credits >= required_credits

View File

@ -342,7 +342,7 @@ class Workflow(Base):
:return: hash
"""
entity = {"graph": self.graph_dict, "features": self.features_dict}
entity = {"graph": self.graph_dict}
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))

View File

@ -20,6 +20,7 @@ dependencies = [
"flask-orjson~=2.0.0",
"flask-sqlalchemy~=3.1.1",
"gevent~=25.9.1",
"gevent-websocket~=0.10.1",
"gmpy2~=2.2.1",
"google-api-core==2.18.0",
"google-api-python-client==2.90.0",
@ -68,6 +69,7 @@ dependencies = [
"pypdfium2==4.30.0",
"python-docx~=1.1.0",
"python-dotenv==1.0.1",
"python-socketio~=5.13.0",
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=6.1.0",
@ -86,6 +88,7 @@ dependencies = [
"sendgrid~=6.12.3",
"flask-restx~=1.3.0",
"packaging~=23.2",
"gevent-websocket>=0.10.1",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.

View File

@ -995,6 +995,11 @@ class TenantService:
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
from services.credit_pool_service import CreditPoolService
CreditPoolService.create_default_pool(tenant.id)
return tenant
@staticmethod

View File

@ -0,0 +1,68 @@
import logging
from typing import Optional
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.errors.error import QuotaExceededError
from extensions.ext_database import db
from models import TenantCreditPool
logger = logging.getLogger(__name__)
class CreditPoolService:
@classmethod
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
"""create default credit pool for new tenant"""
credit_pool = TenantCreditPool(
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
)
db.session.add(credit_pool)
db.session.commit()
return credit_pool
@classmethod
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> Optional[TenantCreditPool]:
"""get tenant credit pool"""
return (
db.session.query(TenantCreditPool)
.filter_by(
tenant_id=tenant_id,
pool_type=pool_type,
)
.first()
)
@classmethod
def check_and_deduct_credits(
cls,
tenant_id: str,
credits_required: int,
pool_type: str = "trial",
):
"""check and deduct credits"""
pool = cls.get_pool(tenant_id, pool_type)
if not pool:
raise QuotaExceededError("Credit pool not found")
if pool.remaining_credits < credits_required:
raise QuotaExceededError(
f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}"
)
try:
with Session(db.engine) as session:
update_values = {"quota_used": pool.quota_used + credits_required}
where_conditions = [
TenantCreditPool.pool_type == pool_type,
TenantCreditPool.tenant_id == tenant_id,
TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
]
stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values)
session.execute(stmt)
session.commit()
except Exception:
raise QuotaExceededError("Failed to deduct credits")

View File

@ -160,6 +160,8 @@ class SystemFeatureModel(BaseModel):
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
enable_trial_app: bool = False
enable_explore_banner: bool = False
class FeatureService:
@ -214,6 +216,8 @@ class FeatureService:
system_features.is_allow_register = dify_config.ALLOW_REGISTER
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
@classmethod
def _fulfill_params_from_env(cls, features: FeatureModel):

View File

@ -1,4 +1,9 @@
from sqlalchemy.orm import Session
from configs import dify_config
from extensions.ext_database import db
from models.model import AccountTrialAppRecord, TrialApp
from services.feature_service import FeatureService
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
@ -20,6 +25,15 @@ class RecommendedAppService:
)
)
if FeatureService.get_system_features().enable_trial_app:
apps = result["recommended_apps"]
for app in apps:
app_id = app["app_id"]
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
if trial_app_model:
app["can_trial"] = True
else:
app["can_trial"] = False
return result
@classmethod
@ -32,4 +46,27 @@ class RecommendedAppService:
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
if FeatureService.get_system_features().enable_trial_app:
app_id = result["id"]
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
if trial_app_model:
result["can_trial"] = True
else:
result["can_trial"] = False
return result
@classmethod
def add_trial_app_record(cls, app_id: str, account_id: str):
"""
Add trial app record.
:param app_id: app id
:return:
"""
with Session(db.engine) as session:
account_trial_app_record = session.query(AccountTrialAppRecord).where(TrialApp.app_id == app_id).first()
if account_trial_app_record:
account_trial_app_record.count += 1
session.commit()
else:
session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
session.commit()

View File

@ -0,0 +1,311 @@
import logging
from typing import Optional
from sqlalchemy import desc, select
from sqlalchemy.orm import Session, selectinload
from werkzeug.exceptions import Forbidden, NotFound
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from models import WorkflowComment, WorkflowCommentMention, WorkflowCommentReply
logger = logging.getLogger(__name__)
class WorkflowCommentService:
"""Service for managing workflow comments."""
@staticmethod
def _validate_content(content: str) -> None:
if len(content.strip()) == 0:
raise ValueError("Comment content cannot be empty")
if len(content) > 1000:
raise ValueError("Comment content cannot exceed 1000 characters")
@staticmethod
def get_comments(tenant_id: str, app_id: str) -> list[WorkflowComment]:
"""Get all comments for a workflow."""
with Session(db.engine) as session:
# Get all comments with eager loading
stmt = (
select(WorkflowComment)
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
.where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id)
.order_by(desc(WorkflowComment.created_at))
)
comments = session.scalars(stmt).all()
return comments
@staticmethod
def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session = None) -> WorkflowComment:
"""Get a specific comment."""
def _get_comment(session: Session) -> WorkflowComment:
stmt = (
select(WorkflowComment)
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
.where(
WorkflowComment.id == comment_id,
WorkflowComment.tenant_id == tenant_id,
WorkflowComment.app_id == app_id,
)
)
comment = session.scalar(stmt)
if not comment:
raise NotFound("Comment not found")
return comment
if session is not None:
return _get_comment(session)
else:
with Session(db.engine, expire_on_commit=False) as session:
return _get_comment(session)
@staticmethod
def create_comment(
tenant_id: str,
app_id: str,
created_by: str,
content: str,
position_x: float,
position_y: float,
mentioned_user_ids: Optional[list[str]] = None,
) -> WorkflowComment:
"""Create a new workflow comment."""
WorkflowCommentService._validate_content(content)
with Session(db.engine) as session:
comment = WorkflowComment(
tenant_id=tenant_id,
app_id=app_id,
position_x=position_x,
position_y=position_y,
content=content,
created_by=created_by,
)
session.add(comment)
session.flush() # Get the comment ID for mentions
# Create mentions if specified
mentioned_user_ids = mentioned_user_ids or []
for user_id in mentioned_user_ids:
if isinstance(user_id, str) and uuid_value(user_id):
mention = WorkflowCommentMention(
comment_id=comment.id,
reply_id=None, # This is a comment mention, not reply mention
mentioned_user_id=user_id,
)
session.add(mention)
session.commit()
# Return only what we need - id and created_at
return {"id": comment.id, "created_at": comment.created_at}
@staticmethod
def update_comment(
tenant_id: str,
app_id: str,
comment_id: str,
user_id: str,
content: str,
position_x: Optional[float] = None,
position_y: Optional[float] = None,
mentioned_user_ids: Optional[list[str]] = None,
) -> dict:
"""Update a workflow comment."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
# Get comment with validation
stmt = select(WorkflowComment).where(
WorkflowComment.id == comment_id,
WorkflowComment.tenant_id == tenant_id,
WorkflowComment.app_id == app_id,
)
comment = session.scalar(stmt)
if not comment:
raise NotFound("Comment not found")
# Only the creator can update the comment
if comment.created_by != user_id:
raise Forbidden("Only the comment creator can update it")
# Update comment fields
comment.content = content
if position_x is not None:
comment.position_x = position_x
if position_y is not None:
comment.position_y = position_y
# Update mentions - first remove existing mentions for this comment only (not replies)
existing_mentions = session.scalars(
select(WorkflowCommentMention).where(
WorkflowCommentMention.comment_id == comment.id,
WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions
)
).all()
for mention in existing_mentions:
session.delete(mention)
# Add new mentions
mentioned_user_ids = mentioned_user_ids or []
for user_id_str in mentioned_user_ids:
if isinstance(user_id_str, str) and uuid_value(user_id_str):
mention = WorkflowCommentMention(
comment_id=comment.id,
reply_id=None, # This is a comment mention
mentioned_user_id=user_id_str,
)
session.add(mention)
session.commit()
return {"id": comment.id, "updated_at": comment.updated_at}
@staticmethod
def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None:
"""Delete a workflow comment."""
with Session(db.engine, expire_on_commit=False) as session:
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
# Only the creator can delete the comment
if comment.created_by != user_id:
raise Forbidden("Only the comment creator can delete it")
# Delete associated mentions (both comment and reply mentions)
mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id)
).all()
for mention in mentions:
session.delete(mention)
# Delete associated replies
replies = session.scalars(
select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id)
).all()
for reply in replies:
session.delete(reply)
session.delete(comment)
session.commit()
@staticmethod
def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment:
"""Resolve a workflow comment."""
with Session(db.engine, expire_on_commit=False) as session:
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
if comment.resolved:
return comment
comment.resolved = True
comment.resolved_at = naive_utc_now()
comment.resolved_by = user_id
session.commit()
return comment
@staticmethod
def create_reply(
comment_id: str, content: str, created_by: str, mentioned_user_ids: Optional[list[str]] = None
) -> dict:
"""Add a reply to a workflow comment."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
# Check if comment exists
comment = session.get(WorkflowComment, comment_id)
if not comment:
raise NotFound("Comment not found")
reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by)
session.add(reply)
session.flush() # Get the reply ID for mentions
# Create mentions if specified
mentioned_user_ids = mentioned_user_ids or []
for user_id in mentioned_user_ids:
if isinstance(user_id, str) and uuid_value(user_id):
# Create mention linking to specific reply
mention = WorkflowCommentMention(
comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id
)
session.add(mention)
session.commit()
return {"id": reply.id, "created_at": reply.created_at}
@staticmethod
def update_reply(
reply_id: str, user_id: str, content: str, mentioned_user_ids: Optional[list[str]] = None
) -> WorkflowCommentReply:
"""Update a comment reply."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
reply = session.get(WorkflowCommentReply, reply_id)
if not reply:
raise NotFound("Reply not found")
# Only the creator can update the reply
if reply.created_by != user_id:
raise Forbidden("Only the reply creator can update it")
reply.content = content
# Update mentions - first remove existing mentions for this reply
existing_mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id)
).all()
for mention in existing_mentions:
session.delete(mention)
# Add mentions
mentioned_user_ids = mentioned_user_ids or []
for user_id_str in mentioned_user_ids:
if isinstance(user_id_str, str) and uuid_value(user_id_str):
mention = WorkflowCommentMention(
comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str
)
session.add(mention)
session.commit()
session.refresh(reply) # Refresh to get updated timestamp
return {"id": reply.id, "updated_at": reply.updated_at}
@staticmethod
def delete_reply(reply_id: str, user_id: str) -> None:
"""Delete a comment reply."""
with Session(db.engine, expire_on_commit=False) as session:
reply = session.get(WorkflowCommentReply, reply_id)
if not reply:
raise NotFound("Reply not found")
# Only the creator can delete the reply
if reply.created_by != user_id:
raise Forbidden("Only the reply creator can delete it")
# Delete associated mentions first
mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id)
).all()
for mention in mentions:
session.delete(mention)
session.delete(reply)
session.commit()
@staticmethod
def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment:
"""Validate that a comment belongs to the specified tenant and app."""
return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id)

View File

@ -198,15 +198,17 @@ class WorkflowService:
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
force_upload: bool = False,
) -> Workflow:
"""
Sync draft workflow
:param force_upload: Skip hash validation when True (for restore operations)
:raises WorkflowHashNotEqualError
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if workflow and workflow.unique_hash != unique_hash:
if workflow and workflow.unique_hash != unique_hash and not force_upload:
raise WorkflowHashNotEqualError()
# validate features structure
@ -244,6 +246,78 @@ class WorkflowService:
# return draft workflow
return workflow
def update_draft_workflow_environment_variables(
self,
*,
app_model: App,
environment_variables: Sequence[Variable],
account: Account,
):
"""
Update draft workflow environment variables
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("No draft workflow found.")
workflow.environment_variables = environment_variables
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
# commit db session changes
db.session.commit()
def update_draft_workflow_conversation_variables(
self,
*,
app_model: App,
conversation_variables: Sequence[Variable],
account: Account,
):
"""
Update draft workflow conversation variables
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("No draft workflow found.")
workflow.conversation_variables = conversation_variables
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
# commit db session changes
db.session.commit()
def update_draft_workflow_features(
self,
*,
app_model: App,
features: dict,
account: Account,
):
"""
Update draft workflow features
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("No draft workflow found.")
# validate features structure
self.validate_features_structure(app_model=app_model, features=features)
workflow.features = json.dumps(features)
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
# commit db session changes
db.session.commit()
def publish_workflow(
self,
*,

View File

@ -46,5 +46,17 @@ class WorkspaceService:
"remove_webapp_brand": remove_webapp_brand,
"replace_webapp_logo": replace_webapp_logo,
}
if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService
paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
if paid_pool:
tenant_info["trial_credits"] = paid_pool.quota_limit
tenant_info["trial_credits_used"] = paid_pool.quota_used
else:
trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
if trial_pool:
tenant_info["trial_credits"] = trial_pool.quota_limit
tenant_info["trial_credits_used"] = trial_pool.quota_used
return tenant_info

4803
api/uv.lock generated

File diff suppressed because it is too large Load Diff