mirror of
https://github.com/langgenius/dify.git
synced 2026-01-19 11:45:05 +08:00
Compare commits
306 Commits
bbf1247f80
...
feat/memor
| Author | SHA1 | Date | |
|---|---|---|---|
| dd7efb11ff | |||
| bfba8bec2d | |||
| 8aa4db0c77 | |||
| 65a3646ce7 | |||
| 106f0fba0b | |||
| cb73335599 | |||
| f4fa57dac9 | |||
| 06f364f2c8 | |||
| 7ca06931ec | |||
| db83b54a88 | |||
| f4567fbf9e | |||
| 8fd088754a | |||
| e7d63a9fa3 | |||
| a1e3a72274 | |||
| 1a4600ce77 | |||
| 61d9428064 | |||
| f6038a4557 | |||
| c09d205776 | |||
| b4f9698289 | |||
| c9a2dc0b13 | |||
| b6114266af | |||
| 33b576b9d5 | |||
| 63de2cc7a0 | |||
| c9c3aa0810 | |||
| dee4399060 | |||
| d3ea98037e | |||
| 2c408445ff | |||
| 8d2b5c5464 | |||
| 97b5d4bba1 | |||
| 68c7e43d8c | |||
| e0af930acd | |||
| dfc8bc4aec | |||
| 63b4bca7d8 | |||
| 517f8aafdc | |||
| d05ba90779 | |||
| b6620c1f42 | |||
| b60e1f4222 | |||
| 4df606f439 | |||
| 2a5a497f15 | |||
| 3f1da39aee | |||
| a52edc6cc1 | |||
| c367f80ec5 | |||
| da353a42da | |||
| 9b0f172f91 | |||
| 85dfc013ea | |||
| 5a1fae1171 | |||
| 33d4c95470 | |||
| 659cbc05a9 | |||
| 6ce65de2cd | |||
| 93b2eb3ff6 | |||
| bf71300635 | |||
| 37ecd4a0bc | |||
| 827a1b181b | |||
| c4e7cb75cd | |||
| 98e4bfcda8 | |||
| ee48ca7671 | |||
| 4ba6de1116 | |||
| bfbe636555 | |||
| 930fdc8fb4 | |||
| fc90a8fb32 | |||
| 791f33fd0b | |||
| 1e0a3b163e | |||
| bb1f1a56a5 | |||
| 15be85514d | |||
| 80cb85b845 | |||
| 0b131f1a8c | |||
| e51f2d68cb | |||
| a6d4bf3399 | |||
| 113aa4ae08 | |||
| 5b40bf6d4e | |||
| 06ad8efd89 | |||
| 8513afbcf6 | |||
| 54ae43ef47 | |||
| 7a74b5ee3e | |||
| 0641773395 | |||
| d745c2e8e3 | |||
| e974c696f7 | |||
| 2ff280c4bf | |||
| 0e9d43d605 | |||
| cc54363c27 | |||
| c41140d654 | |||
| 89affe3139 | |||
| f7fd065bee | |||
| 96c7c86e9d | |||
| 2c4977dbb1 | |||
| e240175116 | |||
| 2398ed6fe8 | |||
| a8420ac33c | |||
| 8470be6411 | |||
| 3d6295c622 | |||
| ff2f7206f3 | |||
| b937fc8978 | |||
| 86a9a51952 | |||
| 4188c9a1dd | |||
| 8833fee232 | |||
| 5bf642c3f9 | |||
| 8c00f89e36 | |||
| 9e8ac5c96b | |||
| 3d7d4182a6 | |||
| 75c221038d | |||
| b7b5b0b8d0 | |||
| 05a67f4716 | |||
| 6eab6a675c | |||
| f49476a206 | |||
| c1e9c56e25 | |||
| d5dd73cacf | |||
| 21f7a49b4e | |||
| 716ac04e13 | |||
| c28a32fc47 | |||
| 31cba28e8a | |||
| 48cd7e6481 | |||
| 47aba1c9f9 | |||
| d94e598a89 | |||
| 0f3f8bc0d9 | |||
| e0df12c212 | |||
| eb448d9bb8 | |||
| 0ba77f13db | |||
| f0a2eb843c | |||
| 28acb70118 | |||
| 7c35aaa99d | |||
| 82fcf1da64 | |||
| c662b95b80 | |||
| a8c2a300f6 | |||
| d654d9d8b1 | |||
| eedc0ca6ea | |||
| e4c3213978 | |||
| 520bc55da5 | |||
| ebbb82178f | |||
| d60a9ac63a | |||
| f1486224e9 | |||
| 88b02d5a07 | |||
| 394b7d09b8 | |||
| e9313b9c1b | |||
| 5cf3d9e4d9 | |||
| 41958f55cd | |||
| 600ad232e1 | |||
| 7a3825cfce | |||
| 9519653422 | |||
| efa2307c73 | |||
| 068fa3d0e3 | |||
| 13d8dbd542 | |||
| e59cc3311d | |||
| 4f4e94f753 | |||
| 258970f489 | |||
| 781abe5f8f | |||
| b442ba8b2b | |||
| 10e36d2355 | |||
| 13c53fedad | |||
| 4bda1bd884 | |||
| 3abe7850d6 | |||
| b50284d864 | |||
| 81c6e52401 | |||
| 847d257366 | |||
| 687662cf1f | |||
| 6432d98469 | |||
| 088ccf8b8d | |||
| e8683bf957 | |||
| 4653981b6b | |||
| e2547413d3 | |||
| ea17f41b5b | |||
| 29178d8adf | |||
| 7e86ead574 | |||
| 72debcb228 | |||
| 72737dabc7 | |||
| f6e5cb4381 | |||
| ffad3b5fb1 | |||
| cba9fc3020 | |||
| e776accaf3 | |||
| 3eac26929a | |||
| 4d3adec738 | |||
| ac5dd1f45a | |||
| 3005cf3282 | |||
| 54b272206e | |||
| 89bed479e4 | |||
| fdd673a3a9 | |||
| 22f6d285c7 | |||
| 10aa16b471 | |||
| 3d761a3189 | |||
| e3903f34e4 | |||
| f4f055fb36 | |||
| b3838581fd | |||
| affbe7ccdb | |||
| 8563ae5511 | |||
| 2c765ccfae | |||
| 626e7b2211 | |||
| 516b6b0fa8 | |||
| 613d086f1e | |||
| 9e0630f012 | |||
| d6d9554954 | |||
| dd8577f832 | |||
| 2a532ab729 | |||
| 03eef65b25 | |||
| ad07d63994 | |||
| 8685f055ea | |||
| 3b868a1cec | |||
| ab389eaa8e | |||
| 008f778e8f | |||
| d7f5da5df4 | |||
| 9fda130b3a | |||
| 72cdbdba0f | |||
| b92a153902 | |||
| 9f2927979b | |||
| 75257232c3 | |||
| 1721314c62 | |||
| fc230bcc59 | |||
| b4636ddf44 | |||
| b1140301a4 | |||
| 58cd785da6 | |||
| 2035186cd2 | |||
| 53ba6aadff | |||
| f091868b7c | |||
| 89bedae0d3 | |||
| c8acc48976 | |||
| 21fee59b22 | |||
| 957a8253f8 | |||
| d5fc3e7bed | |||
| ab438b42da | |||
| 3867fece4a | |||
| 2b908d4fbe | |||
| 8ff062ec8b | |||
| 294fc41aec | |||
| 684f7df158 | |||
| c3287755e3 | |||
| 9f97f4d79e | |||
| 34eb421649 | |||
| 850b05573e | |||
| 6ec8bfdfee | |||
| 81638c248e | |||
| 2e11b1298e | |||
| 20320f3a27 | |||
| 4019c12d26 | |||
| cf72184ce4 | |||
| ca8d15bc64 | |||
| a91c897fd3 | |||
| 816bdf0320 | |||
| d4a6acbd99 | |||
| e421db4005 | |||
| 6af168cb31 | |||
| 29f56cf0cf | |||
| 11b6ea742d | |||
| 05d231ad33 | |||
| 48f3c69c69 | |||
| 9067c2a9c1 | |||
| 8b68020453 | |||
| 9f7321ca1a | |||
| 5fa01132b9 | |||
| 4d2fc66a8d | |||
| f72ed4898c | |||
| e082b6d599 | |||
| d44be2d835 | |||
| 85a73181cc | |||
| e31e4ab677 | |||
| 0d95c2192e | |||
| 7dc8557033 | |||
| 1fa8b26e55 | |||
| 4b085d46f6 | |||
| 72037a1865 | |||
| 635c4ed4ce | |||
| 7ffcf8dd6f | |||
| 97cd21d3be | |||
| a13cb7e1c5 | |||
| 7b602e9003 | |||
| 5a26ebec8f | |||
| 8341b8b1c1 | |||
| bbb640c9a2 | |||
| 0c97bbf137 | |||
| 45fddc70d5 | |||
| f977dc410a | |||
| d535818505 | |||
| fcf4e1f37d | |||
| 38130c8502 | |||
| f284c91988 | |||
| 584b2cefa3 | |||
| 42091b4a79 | |||
| 2d1621c43d | |||
| d1a5db3310 | |||
| ad8fd8fecc | |||
| be74b76079 | |||
| dd64af728f | |||
| e43b46786d | |||
| 3f3b37b843 | |||
| 2ecf9f6ddf | |||
| 48c069fe68 | |||
| 9c5c597c85 | |||
| c2eec8545d | |||
| 2395d4be26 | |||
| 9455476705 | |||
| 494e223706 | |||
| 348fd18230 | |||
| 7233b4de55 | |||
| af6df05685 | |||
| 965b65db6e | |||
| 4cc01c8aa8 | |||
| 41372168b6 | |||
| f4438b0a08 | |||
| 897c842637 | |||
| ee86ceb906 | |||
| e298732499 | |||
| 4081937e22 | |||
| f9aedb2118 | |||
| 74b4719af8 | |||
| 2f35cc9188 | |||
| 2f966d8c38 | |||
| b0868d9136 | |||
| 37440e9416 | |||
| 0d7d27ec0b |
20
api/app.py
20
api/app.py
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
@ -8,10 +9,16 @@ def is_db_command():
|
||||
|
||||
|
||||
# create app
|
||||
celery = None
|
||||
flask_app = None
|
||||
socketio_app = None
|
||||
|
||||
if is_db_command():
|
||||
from app_factory import create_migrations_app
|
||||
|
||||
app = create_migrations_app()
|
||||
socketio_app = app
|
||||
flask_app = app
|
||||
else:
|
||||
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||
# so we need to disable gevent in debug mode.
|
||||
@ -33,8 +40,15 @@ else:
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
app = create_app()
|
||||
celery = app.extensions["celery"]
|
||||
socketio_app, flask_app = create_app()
|
||||
app = flask_app
|
||||
celery = flask_app.extensions["celery"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=5001)
|
||||
from gevent import pywsgi
|
||||
from geventwebsocket.handler import WebSocketHandler
|
||||
|
||||
host = os.environ.get("HOST", "0.0.0.0")
|
||||
port = int(os.environ.get("PORT", 5001))
|
||||
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
|
||||
server.serve_forever()
|
||||
|
||||
@ -31,14 +31,22 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
return dify_app
|
||||
|
||||
|
||||
def create_app() -> DifyApp:
|
||||
def create_app() -> tuple[any, DifyApp]:
|
||||
start_time = time.perf_counter()
|
||||
app = create_flask_app_with_configs()
|
||||
initialize_extensions(app)
|
||||
|
||||
import socketio
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
|
||||
sio.app = app
|
||||
socketio_app = socketio.WSGIApp(sio, app)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
|
||||
return app
|
||||
return socketio_app, app
|
||||
|
||||
|
||||
def initialize_extensions(app: DifyApp):
|
||||
|
||||
@ -836,6 +836,16 @@ class MailConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
ENABLE_TRIAL_APP: bool = Field(
|
||||
description="Enable trial app",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENABLE_EXPLORE_BANNER: bool = Field(
|
||||
description="Enable explore banner",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class RagEtlConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
HOSTED_POOL_CREDITS: int = Field(
|
||||
description="Pool credits for hosted service",
|
||||
default=200,
|
||||
)
|
||||
|
||||
def get_model_credits(self, model_name: str) -> int:
|
||||
"""
|
||||
Get credit value for a specific model name.
|
||||
@ -70,11 +75,6 @@ class HostedOpenAiConfig(BaseSettings):
|
||||
"text-davinci-003",
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="Quota limit for hosted OpenAI service usage",
|
||||
default=200,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted OpenAI service",
|
||||
default=False,
|
||||
@ -98,6 +98,129 @@ class HostedOpenAiConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class HostedGeminiConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching Gemini service
|
||||
"""
|
||||
|
||||
HOSTED_GEMINI_API_KEY: str | None = Field(
|
||||
description="API key for hosted Gemini service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_API_BASE: str | None = Field(
|
||||
description="Base URL for hosted Gemini API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
|
||||
description="Organization ID for hosted Gemini service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted Gemini service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted gemini service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
|
||||
)
|
||||
|
||||
|
||||
class HostedXAIConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching XAI service
|
||||
"""
|
||||
|
||||
HOSTED_XAI_API_KEY: str | None = Field(
|
||||
description="API key for hosted XAI service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_XAI_API_BASE: str | None = Field(
|
||||
description="Base URL for hosted XAI API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_XAI_API_ORGANIZATION: str | None = Field(
|
||||
description="Organization ID for hosted XAI service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_XAI_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_XAI_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="grok-3,grok-3-mini,grok-3-mini-fast",
|
||||
)
|
||||
|
||||
HOSTED_XAI_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_XAI_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="grok-3,grok-3-mini,grok-3-mini-fast",
|
||||
)
|
||||
|
||||
|
||||
class HostedDeepseekConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching Deepseek service
|
||||
"""
|
||||
|
||||
HOSTED_DEEPSEEK_API_KEY: str | None = Field(
|
||||
description="API key for hosted Deepseek service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_API_BASE: str | None = Field(
|
||||
description="Base URL for hosted Deepseek API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
|
||||
description="Organization ID for hosted Deepseek service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted Deepseek service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="deepseek-chat,deepseek-reasoner",
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="grok-3,grok-3-mini,grok-3-mini-fast",
|
||||
)
|
||||
|
||||
|
||||
class HostedAzureOpenAiConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for hosted Azure OpenAI service
|
||||
@ -144,16 +267,32 @@ class HostedAnthropicConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="Quota limit for hosted Anthropic service usage",
|
||||
default=600000,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted Anthropic service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="claude-opus-4-20250514,"
|
||||
"claude-opus-4-20250514,"
|
||||
"claude-sonnet-4-20250514,"
|
||||
"claude-3-5-haiku-20241022,"
|
||||
"claude-3-opus-20240229,"
|
||||
"claude-3-7-sonnet-20250219,"
|
||||
"claude-3-haiku-20240307",
|
||||
)
|
||||
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="claude-opus-4-20250514,"
|
||||
"claude-opus-4-20250514,"
|
||||
"claude-sonnet-4-20250514,"
|
||||
"claude-3-5-haiku-20241022,"
|
||||
"claude-3-opus-20240229,"
|
||||
"claude-3-7-sonnet-20250219,"
|
||||
"claude-3-haiku-20240307",
|
||||
)
|
||||
|
||||
|
||||
class HostedMinmaxConfig(BaseSettings):
|
||||
"""
|
||||
@ -250,5 +389,8 @@ class HostedServiceConfig(
|
||||
HostedModerationConfig,
|
||||
# credit config
|
||||
HostedCreditConfig,
|
||||
HostedGeminiConfig,
|
||||
HostedXAIConfig,
|
||||
HostedDeepseekConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
@ -58,11 +58,13 @@ from .app import (
|
||||
mcp_server,
|
||||
message,
|
||||
model_config,
|
||||
online_user,
|
||||
ops_trace,
|
||||
site,
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
workflow_comment,
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
@ -106,10 +108,12 @@ from .datasets.rag_pipeline import (
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
banner,
|
||||
installed_app,
|
||||
parameter,
|
||||
recommended_app,
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
|
||||
# Import tag controllers
|
||||
@ -143,6 +147,7 @@ __all__ = [
|
||||
"apikey",
|
||||
"app",
|
||||
"audio",
|
||||
"banner",
|
||||
"billing",
|
||||
"bp",
|
||||
"completion",
|
||||
@ -196,6 +201,7 @@ __all__ = [
|
||||
"statistic",
|
||||
"tags",
|
||||
"tool_providers",
|
||||
"trial",
|
||||
"version",
|
||||
"website",
|
||||
"workflow",
|
||||
|
||||
@ -15,7 +15,7 @@ from constants.languages import supported_language
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
@ -61,6 +61,8 @@ class InsertExploreAppListApi(Resource):
|
||||
"language": fields.String(required=True, description="Language code"),
|
||||
"category": fields.String(required=True, description="App category"),
|
||||
"position": fields.Integer(required=True, description="Display position"),
|
||||
"can_trial": fields.Boolean(required=True, description="Can trial"),
|
||||
"trial_limit": fields.Integer(required=True, description="Trial limit"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@ -79,6 +81,8 @@ class InsertExploreAppListApi(Resource):
|
||||
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
parser.add_argument("can_trial", type=bool, required=True, nullable=False, location="json")
|
||||
parser.add_argument("trial_limit", type=int, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
||||
@ -115,6 +119,20 @@ class InsertExploreAppListApi(Resource):
|
||||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
if args["can_trial"]:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == args["app_id"])
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=args["app_id"],
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=args["trial_limit"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = args["trial_limit"]
|
||||
|
||||
app.is_public = True
|
||||
db.session.commit()
|
||||
@ -129,6 +147,20 @@ class InsertExploreAppListApi(Resource):
|
||||
recommended_app.category = args["category"]
|
||||
recommended_app.position = args["position"]
|
||||
|
||||
if args["can_trial"]:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == args["app_id"])
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=args["app_id"],
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=args["trial_limit"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = args["trial_limit"]
|
||||
app.is_public = True
|
||||
|
||||
db.session.commit()
|
||||
@ -174,7 +206,67 @@ class InsertExploreAppApi(Resource):
|
||||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
|
||||
trial_app = session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
|
||||
).scalar_one_or_none()
|
||||
if trial_app:
|
||||
session.delete(trial_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-banner")
|
||||
class InsertExploreBanner(Resource):
|
||||
@api.doc("insert_explore_banner")
|
||||
@api.doc(description="Insert an explore banner")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InsertExploreBannerRequest",
|
||||
{
|
||||
"content": fields.String(required=True, description="Banner content"),
|
||||
"link": fields.String(required=True, description="Banner link"),
|
||||
"sort": fields.Integer(required=True, description="Banner sort"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Banner inserted successfully")
|
||||
@admin_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("link", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("sort", type=int, required=True, nullable=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
banner = ExporleBanner(
|
||||
content=args["content"],
|
||||
link=args["link"],
|
||||
sort=args["sort"],
|
||||
)
|
||||
db.session.add(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
|
||||
class DeleteExploreBanner(Resource):
|
||||
@api.doc("delete_explore_banner")
|
||||
@api.doc(description="Delete an explore banner")
|
||||
@api.response(204, "Banner deleted successfully")
|
||||
@admin_required
|
||||
@only_edition_cloud
|
||||
def delete(self, banner_id):
|
||||
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
|
||||
if not banner:
|
||||
raise NotFound(f"Banner '{banner_id}' is not found")
|
||||
|
||||
db.session.delete(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
291
api/controllers/console/app/online_user.py
Normal file
291
api/controllers/console/app/online_user.py
Normal file
@ -0,0 +1,291 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@sio.on("connect")
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
token = None
|
||||
if auth and isinstance(auth, dict):
|
||||
token = auth.get("token")
|
||||
if not token:
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
sio.save_session(sid, {"user_id": user.id, "username": user.name, "avatar": user.avatar})
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@sio.on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event. Each session (tab) is treated as an independent collaborator.
|
||||
"""
|
||||
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
session = sio.get_session(sid)
|
||||
user_id = session.get("user_id")
|
||||
|
||||
if not user_id:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
# Each session is stored independently with sid as key
|
||||
session_info = {
|
||||
"user_id": user_id,
|
||||
"username": session.get("username", "Unknown"),
|
||||
"avatar": session.get("avatar", None),
|
||||
"sid": sid,
|
||||
"connected_at": int(time.time()), # Add timestamp to differentiate tabs
|
||||
}
|
||||
|
||||
# Store session info with sid as key
|
||||
redis_client.hset(f"workflow_online_users:{workflow_id}", sid, json.dumps(session_info))
|
||||
redis_client.set(f"ws_sid_map:{sid}", json.dumps({"workflow_id": workflow_id, "user_id": user_id}))
|
||||
|
||||
# Leader election: first session becomes the leader
|
||||
leader_sid = get_or_set_leader(workflow_id, sid)
|
||||
is_leader = leader_sid == sid
|
||||
|
||||
sio.enter_room(sid, workflow_id)
|
||||
broadcast_online_users(workflow_id)
|
||||
|
||||
# Notify this session of their leader status
|
||||
sio.emit("status", {"isLeader": is_leader}, room=sid)
|
||||
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
|
||||
|
||||
|
||||
@sio.on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle session disconnect event. Remove the specific session from online users.
|
||||
"""
|
||||
mapping = redis_client.get(f"ws_sid_map:{sid}")
|
||||
if mapping:
|
||||
data = json.loads(mapping)
|
||||
workflow_id = data["workflow_id"]
|
||||
|
||||
# Remove this specific session
|
||||
redis_client.hdel(f"workflow_online_users:{workflow_id}", sid)
|
||||
redis_client.delete(f"ws_sid_map:{sid}")
|
||||
|
||||
# Handle leader re-election if the leader session disconnected
|
||||
handle_leader_disconnect(workflow_id, sid)
|
||||
|
||||
broadcast_online_users(workflow_id)
|
||||
|
||||
|
||||
def _clear_session_state(workflow_id: str, sid: str) -> None:
|
||||
redis_client.hdel(f"workflow_online_users:{workflow_id}", sid)
|
||||
redis_client.delete(f"ws_sid_map:{sid}")
|
||||
|
||||
|
||||
def _is_session_active(workflow_id: str, sid: str) -> bool:
|
||||
if not sid:
|
||||
return False
|
||||
|
||||
try:
|
||||
if not sio.manager.is_connected(sid, "/"):
|
||||
return False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
if not redis_client.hexists(f"workflow_online_users:{workflow_id}", sid):
|
||||
return False
|
||||
|
||||
if not redis_client.exists(f"ws_sid_map:{sid}"):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_or_set_leader(workflow_id: str, sid: str) -> str:
|
||||
"""
|
||||
Get current leader session or set this session as leader if no valid leader exists.
|
||||
Returns the leader session id (sid).
|
||||
"""
|
||||
leader_key = f"workflow_leader:{workflow_id}"
|
||||
raw_leader = redis_client.get(leader_key)
|
||||
current_leader = raw_leader.decode("utf-8") if isinstance(raw_leader, bytes) else raw_leader
|
||||
leader_replaced = False
|
||||
|
||||
if current_leader and not _is_session_active(workflow_id, current_leader):
|
||||
_clear_session_state(workflow_id, current_leader)
|
||||
redis_client.delete(leader_key)
|
||||
current_leader = None
|
||||
leader_replaced = True
|
||||
|
||||
if not current_leader:
|
||||
redis_client.set(leader_key, sid, ex=3600) # Expire in 1 hour
|
||||
if leader_replaced:
|
||||
broadcast_leader_change(workflow_id, sid)
|
||||
return sid
|
||||
|
||||
return current_leader
|
||||
|
||||
|
||||
def handle_leader_disconnect(workflow_id, disconnected_sid):
|
||||
"""
|
||||
Handle leader re-election when a session disconnects.
|
||||
If the disconnected session was the leader, elect a new leader from remaining sessions.
|
||||
"""
|
||||
leader_key = f"workflow_leader:{workflow_id}"
|
||||
current_leader = redis_client.get(leader_key)
|
||||
|
||||
if current_leader:
|
||||
current_leader = current_leader.decode("utf-8") if isinstance(current_leader, bytes) else current_leader
|
||||
|
||||
if current_leader == disconnected_sid:
|
||||
# Leader session disconnected, elect a new leader
|
||||
sessions_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
|
||||
|
||||
if sessions_json:
|
||||
# Get the first remaining session as new leader
|
||||
new_leader_sid = list(sessions_json.keys())[0]
|
||||
if isinstance(new_leader_sid, bytes):
|
||||
new_leader_sid = new_leader_sid.decode("utf-8")
|
||||
|
||||
redis_client.set(leader_key, new_leader_sid, ex=3600)
|
||||
|
||||
# Notify all sessions about the new leader
|
||||
broadcast_leader_change(workflow_id, new_leader_sid)
|
||||
else:
|
||||
# No sessions left, remove leader
|
||||
redis_client.delete(leader_key)
|
||||
|
||||
|
||||
def broadcast_leader_change(workflow_id, new_leader_sid):
|
||||
"""
|
||||
Broadcast leader change to all sessions in the workflow.
|
||||
"""
|
||||
sessions_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
|
||||
|
||||
for sid, session_info_json in sessions_json.items():
|
||||
try:
|
||||
sid_str = sid.decode("utf-8") if isinstance(sid, bytes) else sid
|
||||
is_leader = sid_str == new_leader_sid
|
||||
# Emit to each session whether they are the new leader
|
||||
sio.emit("status", {"isLeader": is_leader}, room=sid_str)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
def get_current_leader(workflow_id):
|
||||
"""
|
||||
Get the current leader for a workflow.
|
||||
"""
|
||||
leader_key = f"workflow_leader:{workflow_id}"
|
||||
leader = redis_client.get(leader_key)
|
||||
return leader.decode("utf-8") if leader and isinstance(leader, bytes) else leader
|
||||
|
||||
|
||||
def broadcast_online_users(workflow_id):
|
||||
"""
|
||||
Broadcast online users to the workflow room.
|
||||
Each session is shown as a separate user (even if same person has multiple tabs).
|
||||
"""
|
||||
sessions_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
|
||||
users = []
|
||||
|
||||
for sid, session_info_json in sessions_json.items():
|
||||
try:
|
||||
session_info = json.loads(session_info_json)
|
||||
# Each session appears as a separate "user" in the UI
|
||||
users.append(
|
||||
{
|
||||
"user_id": session_info["user_id"],
|
||||
"username": session_info["username"],
|
||||
"avatar": session_info.get("avatar"),
|
||||
"sid": session_info["sid"],
|
||||
"connected_at": session_info.get("connected_at"),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Sort by connection time to maintain consistent order
|
||||
users.sort(key=lambda x: x.get("connected_at") or 0)
|
||||
|
||||
# Get current leader session
|
||||
leader_sid = get_current_leader(workflow_id)
|
||||
|
||||
sio.emit("online_users", {"workflow_id": workflow_id, "users": users, "leader": leader_sid}, room=workflow_id)
|
||||
|
||||
|
||||
@sio.on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, include:
|
||||
1. mouseMove
|
||||
2. varsAndFeaturesUpdate
|
||||
3. syncRequest(ask leader to update graph)
|
||||
4. appStateUpdate
|
||||
5. mcpServerUpdate
|
||||
|
||||
"""
|
||||
mapping = redis_client.get(f"ws_sid_map:{sid}")
|
||||
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
mapping_data = json.loads(mapping)
|
||||
workflow_id = mapping_data["workflow_id"]
|
||||
user_id = mapping_data["user_id"]
|
||||
|
||||
event_type = data.get("type")
|
||||
event_data = data.get("data")
|
||||
timestamp = data.get("timestamp", int(time.time()))
|
||||
|
||||
if not event_type:
|
||||
return {"msg": "invalid event type"}, 400
|
||||
|
||||
sio.emit(
|
||||
"collaboration_update",
|
||||
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
|
||||
room=workflow_id,
|
||||
skip_sid=sid,
|
||||
)
|
||||
|
||||
return {"msg": "event_broadcasted"}
|
||||
|
||||
|
||||
@sio.on("graph_event")
|
||||
def handle_graph_event(sid, data):
|
||||
"""
|
||||
Handle graph events - simple broadcast relay.
|
||||
"""
|
||||
mapping = redis_client.get(f"ws_sid_map:{sid}")
|
||||
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
mapping_data = json.loads(mapping)
|
||||
workflow_id = mapping_data["workflow_id"]
|
||||
|
||||
sio.emit("graph_update", data, room=workflow_id, skip_sid=sid)
|
||||
|
||||
return {"msg": "graph_update_broadcasted"}
|
||||
@ -5,6 +5,7 @@ from typing import cast
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
|
||||
from pydantic_core import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
@ -21,7 +22,9 @@ from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.online_user_fields import online_user_list_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
@ -103,6 +106,7 @@ class DraftWorkflowApi(Resource):
|
||||
"hash": fields.String(description="Workflow hash for validation"),
|
||||
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
|
||||
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
|
||||
"memory_blocks": fields.List(fields.Raw, description="Memory blocks"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@ -127,6 +131,8 @@ class DraftWorkflowApi(Resource):
|
||||
parser.add_argument("hash", type=str, required=False, location="json")
|
||||
parser.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
parser.add_argument("force_upload", type=bool, required=False, default=False, location="json")
|
||||
parser.add_argument("memory_blocks", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
@ -143,6 +149,8 @@ class DraftWorkflowApi(Resource):
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
"memory_blocks": data.get("memory_blocks"),
|
||||
"force_upload": data.get("force_upload", False),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
@ -163,6 +171,11 @@ class DraftWorkflowApi(Resource):
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
memory_blocks_list = args.get("memory_blocks") or []
|
||||
from core.memory.entities import MemoryBlockSpec
|
||||
memory_blocks = [
|
||||
MemoryBlockSpec.model_validate(obj) for obj in memory_blocks_list
|
||||
]
|
||||
workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
graph=args["graph"],
|
||||
@ -171,9 +184,13 @@ class DraftWorkflowApi(Resource):
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
force_upload=args.get("force_upload", False),
|
||||
memory_blocks=memory_blocks,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
except ValidationError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
@ -796,6 +813,45 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
|
||||
class WorkflowConfigApi(Resource):
|
||||
"""Resource for workflow configuration."""
|
||||
|
||||
@api.doc("get_workflow_config")
|
||||
@api.doc(description="Get workflow configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Workflow configuration retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
return {
|
||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowFeaturesApi(Resource):
|
||||
"""Update draft workflow features."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("features", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
features = args.get("features")
|
||||
|
||||
# Update draft workflow features
|
||||
workflow_service = WorkflowService()
|
||||
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@api.doc("get_all_published_workflows")
|
||||
@ -985,3 +1041,105 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
if node_exec is None:
|
||||
raise NotFound("last run not found")
|
||||
return node_exec
|
||||
|
||||
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workflow_ids", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_ids = [id.strip() for id in args["workflow_ids"].split(",")]
|
||||
|
||||
results = []
|
||||
for workflow_id in workflow_ids:
|
||||
users_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
try:
|
||||
users.append(json.loads(user_info_json))
|
||||
except Exception:
|
||||
continue
|
||||
results.append({"workflow_id": workflow_id, "users": users})
|
||||
|
||||
return {"data": results}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DraftWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowConfigApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/config",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowFeaturesApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/features",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftWorkflowRunApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowTaskStopApi,
|
||||
"/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowNodeRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftRunIterationNodeApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowDraftRunIterationNodeApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftRunLoopNodeApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowDraftRunLoopNodeApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows/publish",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedAllWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultBlockConfigsApi,
|
||||
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultBlockConfigApi,
|
||||
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
|
||||
)
|
||||
api.add_resource(
|
||||
ConvertToWorkflowApi,
|
||||
"/apps/<uuid:app_id>/convert-to-workflow",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowByIdApi,
|
||||
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowNodeLastRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run",
|
||||
)
|
||||
api.add_resource(WorkflowOnlineUsersApi, "/apps/workflows/online-users")
|
||||
|
||||
240
api/controllers/console/app/workflow_comment.py
Normal file
240
api/controllers/console/app/workflow_comment.py
Normal file
@ -0,0 +1,240 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.member_fields import account_with_role_fields
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_basic_fields, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_create_fields)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("position_x", type=float, required=True, location="json")
|
||||
parser.add_argument("position_y", type=float, required=True, location="json")
|
||||
parser.add_argument("content", type=str, required=True, location="json")
|
||||
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=args.content,
|
||||
position_x=args.position_x,
|
||||
position_y=args.position_y,
|
||||
mentioned_user_ids=args.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_detail_fields)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_update_fields)
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, location="json")
|
||||
parser.add_argument("position_x", type=float, required=False, location="json")
|
||||
parser.add_argument("position_y", type=float, required=False, location="json")
|
||||
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=args.content,
|
||||
position_x=args.position_x,
|
||||
position_y=args.position_y,
|
||||
mentioned_user_ids=args.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_resolve_fields)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_reply_create_fields)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, location="json")
|
||||
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=args.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=args.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_reply_update_fields)
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, location="json")
|
||||
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
reply_id=reply_id, user_id=current_user.id, content=args.content, mentioned_user_ids=args.mentioned_user_ids
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with({"users": fields.List(fields.Nested(account_with_role_fields))})
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
return {"users": members}
|
||||
|
||||
|
||||
# Register API routes
|
||||
api.add_resource(WorkflowCommentListApi, "/apps/<uuid:app_id>/workflow/comments")
|
||||
api.add_resource(WorkflowCommentDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
api.add_resource(WorkflowCommentResolveApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
api.add_resource(WorkflowCommentReplyApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
api.add_resource(
|
||||
WorkflowCommentReplyDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>"
|
||||
)
|
||||
api.add_resource(WorkflowCommentMentionUsersApi, "/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
@ -19,8 +19,8 @@ from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
@ -353,7 +353,7 @@ class VariableApi(Resource):
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
@ -446,8 +446,35 @@ class ConversationVariableCollectionApi(Resource):
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def post(self, app_model: App):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_variables", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
conversation_variables_list = args.get("conversation_variables") or []
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_conversation_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@api.doc("get_system_variables")
|
||||
@api.doc(description="Get system variables for workflow")
|
||||
@ -497,3 +524,44 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_environment_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
|
||||
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||
|
||||
34
api/controllers/console/explore/banner.py
Normal file
34
api/controllers/console/explore/banner.py
Normal file
@ -0,0 +1,34 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import explore_banner_enabled
|
||||
from extensions.ext_database import db
|
||||
from models.model import ExporleBanner
|
||||
|
||||
|
||||
class BannerApi(Resource):
|
||||
"""Resource for banner list."""
|
||||
|
||||
@explore_banner_enabled
|
||||
def get(self):
|
||||
"""Get banner list."""
|
||||
banners = (
|
||||
db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled").order_by(ExporleBanner.sort).all()
|
||||
)
|
||||
|
||||
# Convert banners to serializable format
|
||||
result = []
|
||||
for banner in banners:
|
||||
banner_data = {
|
||||
"content": banner.content, # Already parsed as JSON by SQLAlchemy
|
||||
"link": banner.link,
|
||||
"sort": banner.sort,
|
||||
"status": banner.status,
|
||||
"created_at": banner.created_at.isoformat() if banner.created_at else None,
|
||||
}
|
||||
result.append(banner_data)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
api.add_resource(BannerApi, "/explore/banners")
|
||||
@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
|
||||
error_code = "access_denied"
|
||||
description = "App access denied."
|
||||
code = 403
|
||||
|
||||
|
||||
class TrialAppNotAllowed(BaseHTTPException):
|
||||
"""*403* `Trial App Not Allowed`
|
||||
|
||||
Raise if the user has reached the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_not_allowed"
|
||||
code = 403
|
||||
description = "the app is not allowed to be trial."
|
||||
|
||||
|
||||
class TrialAppLimitExceeded(BaseHTTPException):
|
||||
"""*403* `Trial App Limit Exceeded`
|
||||
|
||||
Raise if the user has exceeded the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_limit_exceeded"
|
||||
code = 403
|
||||
description = "The user has exceeded the trial app limit."
|
||||
|
||||
@ -27,6 +27,7 @@ recommended_app_fields = {
|
||||
"category": fields.String,
|
||||
"position": fields.Integer,
|
||||
"is_listed": fields.Boolean,
|
||||
"can_trial": fields.Boolean,
|
||||
}
|
||||
|
||||
recommended_app_list_fields = {
|
||||
|
||||
375
api/controllers/console/explore/trial.py
Normal file
375
api/controllers/console/explore/trial.py
Normal file
@ -0,0 +1,375 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common import fields
|
||||
from controllers.common.fields import build_site_model
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.explore.error import (
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields_with_site
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.account import TenantStatus
|
||||
from models.model import AppMode, Site
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_service import AppService
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.errors.message import (
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def get(self, trial_app, message_id):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
|
||||
class TrialChatAudioApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
|
||||
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||
)
|
||||
|
||||
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialSitApi(Resource):
|
||||
"""Resource for trial app sites."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model
|
||||
@service_api_ns.marshal_with(build_site_model(service_api_ns))
|
||||
def get(self, app_model):
|
||||
"""Retrieve app site info.
|
||||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
assert app_model.tenant
|
||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return site
|
||||
|
||||
|
||||
class TrialAppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model):
|
||||
"""Retrieve app parameters."""
|
||||
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.get_app(app_model)
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
|
||||
|
||||
api.add_resource(
|
||||
TrialMessageSuggestedQuestionApi,
|
||||
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="trial_app_suggested_question",
|
||||
)
|
||||
|
||||
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
|
||||
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
|
||||
|
||||
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
|
||||
|
||||
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
|
||||
|
||||
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
|
||||
|
||||
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
|
||||
@ -2,15 +2,16 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import abort
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models import InstalledApp
|
||||
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
@ -74,6 +75,59 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[App, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
|
||||
|
||||
if trial_app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
app = trial_app.app
|
||||
|
||||
if app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
|
||||
account_trial_app_record = (
|
||||
db.session.query(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
|
||||
.first()
|
||||
)
|
||||
if account_trial_app_record:
|
||||
if account_trial_app_record.count >= trial_app.trial_limit:
|
||||
raise TrialAppLimitExceeded()
|
||||
|
||||
return view(app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_trial_app:
|
||||
abort(403, "Trial app feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_explore_banner:
|
||||
abort(403, "Explore banner feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
class InstalledAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
@ -83,3 +137,13 @@ class InstalledAppResource(Resource):
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
||||
|
||||
class TrialAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
method_decorators = [
|
||||
trial_app_required,
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
||||
@ -33,6 +33,7 @@ from controllers.console.wraps import (
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -135,6 +136,17 @@ class AccountNameApi(Resource):
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("avatar", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(args["avatar"])
|
||||
return {"avatar_url": avatar_url}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -51,6 +51,8 @@ tenant_fields = {
|
||||
"in_trial": fields.Boolean,
|
||||
"trial_end_reason": fields.String,
|
||||
"custom_config": fields.Raw(attribute="custom_config"),
|
||||
"trial_credits": fields.Integer,
|
||||
"trial_credits_used": fields.Integer,
|
||||
}
|
||||
|
||||
tenants_fields = {
|
||||
|
||||
@ -19,6 +19,7 @@ from .app import (
|
||||
annotation,
|
||||
app,
|
||||
audio,
|
||||
chatflow_memory,
|
||||
completion,
|
||||
conversation,
|
||||
file,
|
||||
@ -40,6 +41,7 @@ __all__ = [
|
||||
"annotation",
|
||||
"app",
|
||||
"audio",
|
||||
"chatflow_memory",
|
||||
"completion",
|
||||
"conversation",
|
||||
"dataset",
|
||||
|
||||
124
api/controllers/service_api/app/chatflow_memory.py
Normal file
124
api/controllers/service_api/app/chatflow_memory.py
Normal file
@ -0,0 +1,124 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.memory.entities import MemoryBlock, MemoryCreatedBy
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from models import App, EndUser
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class MemoryListApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=False, type=str | None, default=None)
|
||||
parser.add_argument("memory_id", required=False, type=str | None, default=None)
|
||||
parser.add_argument("version", required=False, type=int | None, default=None)
|
||||
args = parser.parse_args()
|
||||
conversation_id: str | None = args.get("conversation_id")
|
||||
memory_id = args.get("memory_id")
|
||||
version = args.get("version")
|
||||
|
||||
if conversation_id:
|
||||
result = ChatflowMemoryService.get_persistent_memories_with_conversation(
|
||||
app_model,
|
||||
MemoryCreatedBy(end_user_id=end_user.id),
|
||||
conversation_id,
|
||||
version
|
||||
)
|
||||
session_memories = ChatflowMemoryService.get_session_memories_with_conversation(
|
||||
app_model,
|
||||
MemoryCreatedBy(end_user_id=end_user.id),
|
||||
conversation_id,
|
||||
version
|
||||
)
|
||||
result = [*result, *session_memories]
|
||||
else:
|
||||
result = ChatflowMemoryService.get_persistent_memories(
|
||||
app_model,
|
||||
MemoryCreatedBy(end_user_id=end_user.id),
|
||||
version
|
||||
)
|
||||
|
||||
if memory_id:
|
||||
result = [it for it in result if it.spec.id == memory_id]
|
||||
return [it for it in result if it.spec.end_user_visible]
|
||||
|
||||
|
||||
class MemoryEditApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def put(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('id', type=str, required=True)
|
||||
parser.add_argument("conversation_id", type=str | None, required=False, default=None)
|
||||
parser.add_argument('node_id', type=str | None, required=False, default=None)
|
||||
parser.add_argument('update', type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
workflow = WorkflowService().get_published_workflow(app_model)
|
||||
update = args.get("update")
|
||||
conversation_id = args.get("conversation_id")
|
||||
node_id = args.get("node_id")
|
||||
if not isinstance(update, str):
|
||||
return {'error': 'Invalid update'}, 400
|
||||
if not workflow:
|
||||
return {'error': 'Workflow not found'}, 404
|
||||
memory_spec = next((it for it in workflow.memory_blocks if it.id == args['id']), None)
|
||||
if not memory_spec:
|
||||
return {'error': 'Memory not found'}, 404
|
||||
|
||||
# First get existing memory
|
||||
existing_memory = ChatflowMemoryService.get_memory_by_spec(
|
||||
spec=memory_spec,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=MemoryCreatedBy(end_user_id=end_user.id),
|
||||
conversation_id=conversation_id,
|
||||
node_id=node_id,
|
||||
is_draft=False
|
||||
)
|
||||
|
||||
# Create updated memory instance with incremented version
|
||||
updated_memory = MemoryBlock(
|
||||
spec=existing_memory.spec,
|
||||
tenant_id=existing_memory.tenant_id,
|
||||
app_id=existing_memory.app_id,
|
||||
conversation_id=existing_memory.conversation_id,
|
||||
node_id=existing_memory.node_id,
|
||||
value=update, # New value
|
||||
version=existing_memory.version + 1, # Increment version for update
|
||||
edited_by_user=True,
|
||||
created_by=existing_memory.created_by,
|
||||
)
|
||||
|
||||
ChatflowMemoryService.save_memory(updated_memory, VariablePool(), False)
|
||||
return '', 204
|
||||
|
||||
|
||||
class MemoryDeleteApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def delete(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('id', type=str, required=False, default=None)
|
||||
args = parser.parse_args()
|
||||
memory_id = args.get('id')
|
||||
|
||||
if memory_id:
|
||||
ChatflowMemoryService.delete_memory(
|
||||
app_model,
|
||||
memory_id,
|
||||
MemoryCreatedBy(end_user_id=end_user.id)
|
||||
)
|
||||
return '', 204
|
||||
else:
|
||||
ChatflowMemoryService.delete_all_user_memories(
|
||||
app_model,
|
||||
MemoryCreatedBy(end_user_id=end_user.id)
|
||||
)
|
||||
return '', 200
|
||||
|
||||
|
||||
api.add_resource(MemoryListApi, '/memories')
|
||||
api.add_resource(MemoryEditApi, '/memory-edit')
|
||||
api.add_resource(MemoryDeleteApi, '/memories')
|
||||
@ -18,6 +18,7 @@ web_ns = Namespace("web", description="Web application API operations", path="/"
|
||||
from . import (
|
||||
app,
|
||||
audio,
|
||||
chatflow_memory,
|
||||
completion,
|
||||
conversation,
|
||||
feature,
|
||||
@ -39,6 +40,7 @@ __all__ = [
|
||||
"app",
|
||||
"audio",
|
||||
"bp",
|
||||
"chatflow_memory",
|
||||
"completion",
|
||||
"conversation",
|
||||
"feature",
|
||||
|
||||
123
api/controllers/web/chatflow_memory.py
Normal file
123
api/controllers/web/chatflow_memory.py
Normal file
@ -0,0 +1,123 @@
|
||||
from flask_restx import reqparse
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.memory.entities import MemoryBlock, MemoryCreatedBy
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from models import App, EndUser
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class MemoryListApi(WebApiResource):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=False, type=str | None, default=None)
|
||||
parser.add_argument("memory_id", required=False, type=str | None, default=None)
|
||||
parser.add_argument("version", required=False, type=int | None, default=None)
|
||||
args = parser.parse_args()
|
||||
conversation_id: str | None = args.get("conversation_id")
|
||||
memory_id = args.get("memory_id")
|
||||
version = args.get("version")
|
||||
|
||||
if conversation_id:
|
||||
result = ChatflowMemoryService.get_persistent_memories_with_conversation(
|
||||
app_model,
|
||||
MemoryCreatedBy(end_user_id=end_user.id),
|
||||
conversation_id,
|
||||
version
|
||||
)
|
||||
session_memories = ChatflowMemoryService.get_session_memories_with_conversation(
|
||||
app_model,
|
||||
MemoryCreatedBy(end_user_id=end_user.id),
|
||||
conversation_id,
|
||||
version
|
||||
)
|
||||
result = [*result, *session_memories]
|
||||
else:
|
||||
result = ChatflowMemoryService.get_persistent_memories(
|
||||
app_model,
|
||||
MemoryCreatedBy(end_user_id=end_user.id),
|
||||
version
|
||||
)
|
||||
|
||||
if memory_id:
|
||||
result = [it for it in result if it.spec.id == memory_id]
|
||||
return [it for it in result if it.spec.end_user_visible]
|
||||
|
||||
|
||||
class MemoryEditApi(WebApiResource):
|
||||
def put(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('id', type=str, required=True)
|
||||
parser.add_argument("conversation_id", type=str | None, required=False, default=None)
|
||||
parser.add_argument('node_id', type=str | None, required=False, default=None)
|
||||
parser.add_argument('update', type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
workflow = WorkflowService().get_published_workflow(app_model)
|
||||
update = args.get("update")
|
||||
conversation_id = args.get("conversation_id")
|
||||
node_id = args.get("node_id")
|
||||
if not isinstance(update, str):
|
||||
return {'error': 'Update must be a string'}, 400
|
||||
if not workflow:
|
||||
return {'error': 'Workflow not found'}, 404
|
||||
memory_spec = next((it for it in workflow.memory_blocks if it.id == args['id']), None)
|
||||
if not memory_spec:
|
||||
return {'error': 'Memory not found'}, 404
|
||||
if not memory_spec.end_user_editable:
|
||||
return {'error': 'Memory not editable'}, 403
|
||||
|
||||
# First get existing memory
|
||||
existing_memory = ChatflowMemoryService.get_memory_by_spec(
|
||||
spec=memory_spec,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=MemoryCreatedBy(end_user_id=end_user.id),
|
||||
conversation_id=conversation_id,
|
||||
node_id=node_id,
|
||||
is_draft=False
|
||||
)
|
||||
|
||||
# Create updated memory instance with incremented version
|
||||
updated_memory = MemoryBlock(
|
||||
spec=existing_memory.spec,
|
||||
tenant_id=existing_memory.tenant_id,
|
||||
app_id=existing_memory.app_id,
|
||||
conversation_id=existing_memory.conversation_id,
|
||||
node_id=existing_memory.node_id,
|
||||
value=update, # New value
|
||||
version=existing_memory.version + 1, # Increment version for update
|
||||
edited_by_user=True,
|
||||
created_by=existing_memory.created_by,
|
||||
)
|
||||
|
||||
ChatflowMemoryService.save_memory(updated_memory, VariablePool(), False)
|
||||
return '', 204
|
||||
|
||||
|
||||
class MemoryDeleteApi(WebApiResource):
|
||||
def delete(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('id', type=str, required=False, default=None)
|
||||
args = parser.parse_args()
|
||||
memory_id = args.get('id')
|
||||
|
||||
if memory_id:
|
||||
ChatflowMemoryService.delete_memory(
|
||||
app_model,
|
||||
memory_id,
|
||||
MemoryCreatedBy(end_user_id=end_user.id)
|
||||
)
|
||||
return '', 204
|
||||
else:
|
||||
ChatflowMemoryService.delete_all_user_memories(
|
||||
app_model,
|
||||
MemoryCreatedBy(end_user_id=end_user.id)
|
||||
)
|
||||
return '', 200
|
||||
|
||||
|
||||
api.add_resource(MemoryListApi, '/memories')
|
||||
api.add_resource(MemoryEditApi, '/memory-edit')
|
||||
api.add_resource(MemoryDeleteApi, '/memories')
|
||||
@ -1,10 +1,11 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -20,11 +21,14 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.memory.entities import MemoryCreatedBy, MemoryScope
|
||||
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_events import GraphRunSucceededEvent
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@ -34,6 +38,8 @@ from models import Workflow
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
from models.workflow import ConversationVariable
|
||||
from services.chatflow_history_service import ChatflowHistoryService
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -70,6 +76,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
self._app = app
|
||||
|
||||
def run(self):
|
||||
ChatflowMemoryService.wait_for_sync_memory_completion(
|
||||
workflow=self._workflow,
|
||||
conversation_id=self.conversation.id
|
||||
)
|
||||
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
@ -133,6 +144,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
conversation_variables=conversation_variables,
|
||||
memory_blocks=self._fetch_memory_blocks(),
|
||||
)
|
||||
|
||||
# init graph
|
||||
@ -177,6 +189,31 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
try:
|
||||
self._check_app_memory_updates(variable_pool)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to check app memory updates", exc_info=e)
|
||||
|
||||
@override
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: Any) -> None:
|
||||
super()._handle_event(workflow_entry, event)
|
||||
if isinstance(event, GraphRunSucceededEvent):
|
||||
workflow_outputs = event.outputs
|
||||
if not workflow_outputs:
|
||||
logger.warning("Chatflow output is empty.")
|
||||
return
|
||||
assistant_message = workflow_outputs.get('answer')
|
||||
if not assistant_message:
|
||||
logger.warning("Chatflow output does not contain 'answer'.")
|
||||
return
|
||||
if not isinstance(assistant_message, str):
|
||||
logger.warning("Chatflow output 'answer' is not a string.")
|
||||
return
|
||||
try:
|
||||
self._sync_conversation_to_chatflow_tables(assistant_message)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to sync conversation to memory tables", exc_info=e)
|
||||
|
||||
def handle_input_moderation(
|
||||
self,
|
||||
app_record: App,
|
||||
@ -374,3 +411,67 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
# Return combined list
|
||||
return existing_variables + new_variables
|
||||
|
||||
def _fetch_memory_blocks(self) -> Mapping[str, str]:
|
||||
"""fetch all memory blocks for current app"""
|
||||
|
||||
memory_blocks_dict: MutableMapping[str, str] = {}
|
||||
is_draft = (self.application_generate_entity.invoke_from == InvokeFrom.DEBUGGER)
|
||||
conversation_id = self.conversation.id
|
||||
memory_block_specs = self._workflow.memory_blocks
|
||||
# Get runtime memory values
|
||||
memories = ChatflowMemoryService.get_memories_by_specs(
|
||||
memory_block_specs=memory_block_specs,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
node_id=None,
|
||||
conversation_id=conversation_id,
|
||||
is_draft=is_draft,
|
||||
created_by=self._get_created_by(),
|
||||
)
|
||||
|
||||
# Build memory_id -> value mapping
|
||||
for memory in memories:
|
||||
if memory.spec.scope == MemoryScope.APP:
|
||||
# App level: use memory_id directly
|
||||
memory_blocks_dict[memory.spec.id] = memory.value
|
||||
else: # NODE scope
|
||||
node_id = memory.node_id
|
||||
if not node_id:
|
||||
logger.warning("Memory block %s has no node_id, skip.", memory.spec.id)
|
||||
continue
|
||||
key = f"{node_id}.{memory.spec.id}"
|
||||
memory_blocks_dict[key] = memory.value
|
||||
|
||||
return memory_blocks_dict
|
||||
|
||||
def _sync_conversation_to_chatflow_tables(self, assistant_message: str):
|
||||
ChatflowHistoryService.save_app_message(
|
||||
prompt_message=UserPromptMessage(content=(self.application_generate_entity.query)),
|
||||
conversation_id=self.conversation.id,
|
||||
app_id=self._workflow.app_id,
|
||||
tenant_id=self._workflow.tenant_id
|
||||
)
|
||||
ChatflowHistoryService.save_app_message(
|
||||
prompt_message=AssistantPromptMessage(content=assistant_message),
|
||||
conversation_id=self.conversation.id,
|
||||
app_id=self._workflow.app_id,
|
||||
tenant_id=self._workflow.tenant_id
|
||||
)
|
||||
|
||||
def _check_app_memory_updates(self, variable_pool: VariablePool):
|
||||
is_draft = (self.application_generate_entity.invoke_from == InvokeFrom.DEBUGGER)
|
||||
|
||||
ChatflowMemoryService.update_app_memory_if_needed(
|
||||
workflow=self._workflow,
|
||||
conversation_id=self.conversation.id,
|
||||
variable_pool=variable_pool,
|
||||
is_draft=is_draft,
|
||||
created_by=self._get_created_by()
|
||||
)
|
||||
|
||||
def _get_created_by(self) -> MemoryCreatedBy:
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
|
||||
return MemoryCreatedBy(account_id=self.application_generate_entity.user_id)
|
||||
else:
|
||||
return MemoryCreatedBy(end_user_id=self.application_generate_entity.user_id)
|
||||
|
||||
@ -56,6 +56,9 @@ class HostingConfiguration:
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek()
|
||||
|
||||
self.moderation_config = self.init_moderation_config()
|
||||
|
||||
@ -128,7 +131,7 @@ class HostingConfiguration:
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
|
||||
hosted_quota_limit = 0
|
||||
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
|
||||
quotas.append(trial_quota)
|
||||
@ -156,18 +159,49 @@ class HostingConfiguration:
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_anthropic() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
def init_gemini(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_GEMINI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_GEMINI_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
|
||||
}
|
||||
|
||||
if dify_config.HOSTED_GEMINI_API_BASE:
|
||||
credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_anthropic(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
|
||||
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
|
||||
paid_quota = PaidHostingQuota()
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
@ -185,6 +219,66 @@ class HostingConfiguration:
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_xai(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_XAI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_XAI_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"api_key": dify_config.HOSTED_XAI_API_KEY,
|
||||
}
|
||||
|
||||
if dify_config.HOSTED_XAI_API_BASE:
|
||||
credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_deepseek(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
|
||||
}
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_API_BASE:
|
||||
credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_minimax() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
|
||||
@ -14,10 +14,12 @@ from core.llm_generator.prompts import (
|
||||
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
LLM_MODIFY_CODE_SYSTEM,
|
||||
LLM_MODIFY_PROMPT_SYSTEM,
|
||||
MEMORY_UPDATE_PROMPT,
|
||||
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
|
||||
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||
)
|
||||
from core.memory.entities import MemoryBlock, MemoryBlockSpec
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
@ -27,6 +29,7 @@ from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
@ -560,3 +563,35 @@ class LLMGenerator:
|
||||
"Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
|
||||
)
|
||||
return {"error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
def update_memory_block(
|
||||
tenant_id: str,
|
||||
visible_history: Sequence[tuple[str, str]],
|
||||
variable_pool: VariablePool,
|
||||
memory_block: MemoryBlock,
|
||||
memory_spec: MemoryBlockSpec
|
||||
) -> str:
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=memory_spec.model.provider,
|
||||
model=memory_spec.model.name,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
formatted_history = ""
|
||||
for sender, message in visible_history:
|
||||
formatted_history += f"{sender}: {message}\n"
|
||||
filled_instruction = variable_pool.convert_template(memory_spec.instruction).text
|
||||
formatted_prompt = PromptTemplateParser(MEMORY_UPDATE_PROMPT).format(
|
||||
inputs={
|
||||
"formatted_history": formatted_history,
|
||||
"current_value": memory_block.value,
|
||||
"instruction": filled_instruction,
|
||||
}
|
||||
)
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=[UserPromptMessage(content=formatted_prompt)],
|
||||
model_parameters=memory_spec.model.completion_params,
|
||||
stream=False,
|
||||
)
|
||||
return llm_result.message.get_text_content()
|
||||
|
||||
@ -422,3 +422,18 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
|
||||
You should edit the prompt according to the IDEAL OUTPUT."""
|
||||
|
||||
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
|
||||
|
||||
MEMORY_UPDATE_PROMPT = """
|
||||
Based on the following conversation history, update the memory content:
|
||||
|
||||
Conversation history:
|
||||
{{formatted_history}}
|
||||
|
||||
Current memory:
|
||||
{{current_value}}
|
||||
|
||||
Update instruction:
|
||||
{{instruction}}
|
||||
|
||||
Please output only the updated memory content, no other text like greeting:
|
||||
"""
|
||||
|
||||
119
api/core/memory/entities.py
Normal file
119
api/core/memory/entities.py
Normal file
@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
|
||||
|
||||
class MemoryScope(StrEnum):
|
||||
"""Memory scope determined by node_id field"""
|
||||
APP = "app" # node_id is None
|
||||
NODE = "node" # node_id is not None
|
||||
|
||||
|
||||
class MemoryTerm(StrEnum):
|
||||
"""Memory term determined by conversation_id field"""
|
||||
SESSION = "session" # conversation_id is not None
|
||||
PERSISTENT = "persistent" # conversation_id is None
|
||||
|
||||
|
||||
class MemoryStrategy(StrEnum):
|
||||
ON_TURNS = "on_turns"
|
||||
|
||||
|
||||
class MemoryScheduleMode(StrEnum):
|
||||
SYNC = "sync"
|
||||
ASYNC = "async"
|
||||
|
||||
|
||||
class MemoryBlockSpec(BaseModel):
|
||||
"""Memory block specification for workflow configuration"""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid4()),
|
||||
description="Unique identifier for the memory block",
|
||||
)
|
||||
name: str = Field(description="Display name of the memory block")
|
||||
description: str = Field(default="", description="Description of the memory block")
|
||||
template: str = Field(description="Initial template content for the memory")
|
||||
instruction: str = Field(description="Instructions for updating the memory")
|
||||
scope: MemoryScope = Field(description="Scope of the memory (app or node level)")
|
||||
term: MemoryTerm = Field(description="Term of the memory (session or persistent)")
|
||||
strategy: MemoryStrategy = Field(description="Update strategy for the memory")
|
||||
update_turns: int = Field(gt=0, description="Number of turns between updates")
|
||||
preserved_turns: int = Field(gt=0, description="Number of conversation turns to preserve")
|
||||
schedule_mode: MemoryScheduleMode = Field(description="Synchronous or asynchronous update mode")
|
||||
model: ModelConfig = Field(description="Model configuration for memory updates")
|
||||
end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users")
|
||||
end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users")
|
||||
|
||||
|
||||
class MemoryCreatedBy(BaseModel):
|
||||
end_user_id: str | None = None
|
||||
account_id: str | None = None
|
||||
|
||||
|
||||
class MemoryBlock(BaseModel):
|
||||
"""Runtime memory block instance
|
||||
|
||||
Design Rules:
|
||||
- app_id = None: Global memory (future feature, not implemented yet)
|
||||
- app_id = str: App-specific memory
|
||||
- conversation_id = None: Persistent memory (cross-conversation)
|
||||
- conversation_id = str: Session memory (conversation-specific)
|
||||
- node_id = None: App-level scope
|
||||
- node_id = str: Node-level scope
|
||||
|
||||
These rules implicitly determine scope and term without redundant storage.
|
||||
"""
|
||||
spec: MemoryBlockSpec
|
||||
tenant_id: str
|
||||
value: str
|
||||
app_id: str
|
||||
conversation_id: Optional[str] = None
|
||||
node_id: Optional[str] = None
|
||||
edited_by_user: bool = False
|
||||
created_by: MemoryCreatedBy
|
||||
version: int = Field(description="Memory block version number")
|
||||
|
||||
|
||||
class MemoryValueData(BaseModel):
|
||||
value: str
|
||||
edited_by_user: bool = False
|
||||
|
||||
|
||||
class ChatflowConversationMetadata(BaseModel):
|
||||
"""Metadata for chatflow conversation with visible message count"""
|
||||
type: str = "mutable_visible_window"
|
||||
visible_count: int = Field(gt=0, description="Number of visible messages to keep")
|
||||
|
||||
|
||||
class MemoryBlockWithConversation(MemoryBlock):
|
||||
"""MemoryBlock with optional conversation metadata for session memories"""
|
||||
conversation_metadata: ChatflowConversationMetadata = Field(
|
||||
description="Conversation metadata, only present for session memories"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_memory_block(
|
||||
cls,
|
||||
memory_block: MemoryBlock,
|
||||
conversation_metadata: ChatflowConversationMetadata
|
||||
) -> MemoryBlockWithConversation:
|
||||
"""Create MemoryBlockWithConversation from MemoryBlock"""
|
||||
return cls(
|
||||
spec=memory_block.spec,
|
||||
tenant_id=memory_block.tenant_id,
|
||||
value=memory_block.value,
|
||||
app_id=memory_block.app_id,
|
||||
conversation_id=memory_block.conversation_id,
|
||||
node_id=memory_block.node_id,
|
||||
edited_by_user=memory_block.edited_by_user,
|
||||
created_by=memory_block.created_by,
|
||||
version=memory_block.version,
|
||||
conversation_metadata=conversation_metadata
|
||||
)
|
||||
6
api/core/memory/errors.py
Normal file
6
api/core/memory/errors.py
Normal file
@ -0,0 +1,6 @@
|
||||
class MemorySyncTimeoutError(Exception):
|
||||
def __init__(self, app_id: str, conversation_id: str):
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.message = "Memory synchronization timeout after 50 seconds"
|
||||
super().__init__(self.message)
|
||||
@ -45,6 +45,12 @@ class MemoryConfig(BaseModel):
|
||||
enabled: bool
|
||||
size: int | None = None
|
||||
|
||||
mode: Literal["linear", "block"] | None = "linear"
|
||||
block_id: list[str] | None = None
|
||||
role_prefix: RolePrefix | None = None
|
||||
window: WindowConfig
|
||||
query_prompt_template: str | None = None
|
||||
|
||||
@property
|
||||
def is_block_mode(self) -> bool:
|
||||
return self.mode == "block" and bool(self.block_id)
|
||||
|
||||
@ -618,9 +618,9 @@ class ProviderManager:
|
||||
)
|
||||
|
||||
for quota in configuration.quotas:
|
||||
if quota.quota_type == ProviderQuotaType.TRIAL:
|
||||
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
|
||||
# Init trial provider records if not exists
|
||||
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
|
||||
if quota.quota_type not in provider_quota_to_provider_record_dict:
|
||||
try:
|
||||
# FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic
|
||||
new_provider_record = Provider(
|
||||
@ -628,8 +628,8 @@ class ProviderManager:
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
provider_name=ModelProviderID(provider_name).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=ProviderQuotaType.TRIAL.value,
|
||||
quota_limit=quota.quota_limit, # type: ignore
|
||||
quota_type=quota.quota_type,
|
||||
quota_limit=0, # type: ignore
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
)
|
||||
@ -642,7 +642,7 @@ class ProviderManager:
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == ModelProviderID(provider_name).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value,
|
||||
Provider.quota_type == quota.quota_type,
|
||||
)
|
||||
existed_provider_record = db.session.scalar(stmt)
|
||||
if not existed_provider_record:
|
||||
@ -652,7 +652,7 @@ class ProviderManager:
|
||||
existed_provider_record.is_valid = True
|
||||
db.session.commit()
|
||||
|
||||
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
|
||||
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
|
||||
|
||||
return provider_name_to_provider_records_dict
|
||||
|
||||
@ -912,6 +912,22 @@ class ProviderManager:
|
||||
provider_record
|
||||
)
|
||||
quota_configurations = []
|
||||
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
trail_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.TRIAL.value,
|
||||
)
|
||||
paid_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.PAID.value,
|
||||
)
|
||||
else:
|
||||
trail_pool = None
|
||||
paid_pool = None
|
||||
|
||||
for provider_quota in provider_hosting_configuration.quotas:
|
||||
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
|
||||
if provider_quota.quota_type == ProviderQuotaType.FREE:
|
||||
@ -932,16 +948,36 @@ class ProviderManager:
|
||||
raise ValueError("quota_used is None")
|
||||
if provider_record.quota_limit is None:
|
||||
raise ValueError("quota_limit is None")
|
||||
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=trail_pool.quota_used,
|
||||
quota_limit=trail_pool.quota_limit,
|
||||
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=paid_pool.quota_used,
|
||||
quota_limit=paid_pool.quota_limit,
|
||||
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=provider_record.quota_used,
|
||||
quota_limit=provider_record.quota_limit,
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used
|
||||
or provider_record.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
else:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=provider_record.quota_used,
|
||||
quota_limit=provider_record.quota_limit,
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used
|
||||
or provider_record.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
quota_configurations.append(quota_configuration)
|
||||
|
||||
|
||||
@ -203,6 +203,48 @@ class ArrayFileSegment(ArraySegment):
|
||||
return ""
|
||||
|
||||
|
||||
class VersionedMemoryValue(BaseModel):
|
||||
current_value: str = None # type: ignore
|
||||
versions: Mapping[str, str] = {}
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
def add_version(
|
||||
self,
|
||||
new_value: str,
|
||||
version_name: str | None = None
|
||||
) -> "VersionedMemoryValue":
|
||||
if version_name is None:
|
||||
version_name = str(len(self.versions) + 1)
|
||||
if version_name in self.versions:
|
||||
raise ValueError(f"Version '{version_name}' already exists.")
|
||||
self.current_value = new_value
|
||||
return VersionedMemoryValue(
|
||||
current_value=new_value,
|
||||
versions={
|
||||
version_name: new_value,
|
||||
**self.versions,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class VersionedMemorySegment(Segment):
|
||||
value_type: SegmentType = SegmentType.VERSIONED_MEMORY
|
||||
value: VersionedMemoryValue = None # type: ignore
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return self.value.current_value
|
||||
|
||||
@property
|
||||
def log(self) -> str:
|
||||
return self.value.current_value
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
return self.value.current_value
|
||||
|
||||
|
||||
class ArrayBooleanSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
|
||||
value: Sequence[bool] = None # type: ignore
|
||||
@ -248,6 +290,7 @@ SegmentUnion: TypeAlias = Annotated[
|
||||
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
|
||||
| Annotated[VersionedMemorySegment, Tag(SegmentType.VERSIONED_MEMORY)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
]
|
||||
|
||||
@ -41,6 +41,8 @@ class SegmentType(StrEnum):
|
||||
ARRAY_FILE = "array[file]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
|
||||
VERSIONED_MEMORY = "versioned_memory"
|
||||
|
||||
NONE = "none"
|
||||
|
||||
GROUP = "group"
|
||||
|
||||
@ -22,6 +22,7 @@ from .segments import (
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment,
|
||||
VersionedMemorySegment,
|
||||
get_segment_discriminator,
|
||||
)
|
||||
from .types import SegmentType
|
||||
@ -106,6 +107,10 @@ class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
||||
class VersionedMemoryVariable(VersionedMemorySegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
@ -161,6 +166,7 @@ VariableUnion: TypeAlias = Annotated[
|
||||
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
|
||||
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
|
||||
| Annotated[VersionedMemoryVariable, Tag(SegmentType.VERSIONED_MEMORY)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
]
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
SYSTEM_VARIABLE_NODE_ID = "sys"
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||
MEMORY_BLOCK_VARIABLE_NODE_ID = "memory_block"
|
||||
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"
|
||||
|
||||
@ -8,11 +8,12 @@ from pydantic import BaseModel, Field
|
||||
from core.file import File, FileAttribute, file_manager
|
||||
from core.variables import Segment, SegmentGroup, Variable
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.variables.segments import FileSegment, ObjectSegment
|
||||
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
|
||||
from core.variables.segments import FileSegment, ObjectSegment, VersionedMemoryValue
|
||||
from core.variables.variables import RAGPipelineVariableInput, VariableUnion, VersionedMemoryVariable
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
MEMORY_BLOCK_VARIABLE_NODE_ID,
|
||||
RAG_PIPELINE_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
)
|
||||
@ -55,6 +56,10 @@ class VariablePool(BaseModel):
|
||||
description="RAG pipeline variables.",
|
||||
default_factory=list,
|
||||
)
|
||||
memory_blocks: Mapping[str, str] = Field(
|
||||
description="Memory blocks.",
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
def model_post_init(self, context: Any, /):
|
||||
# Create a mapping from field names to SystemVariableKey enum values
|
||||
@ -75,6 +80,18 @@ class VariablePool(BaseModel):
|
||||
rag_pipeline_variables_map[node_id][key] = value
|
||||
for key, value in rag_pipeline_variables_map.items():
|
||||
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)
|
||||
# Add memory blocks to the variable pool
|
||||
for memory_id, memory_value in self.memory_blocks.items():
|
||||
self.add(
|
||||
[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_id],
|
||||
VersionedMemoryVariable(
|
||||
value=VersionedMemoryValue(
|
||||
current_value=memory_value,
|
||||
versions={"1": memory_value},
|
||||
),
|
||||
name=memory_id,
|
||||
)
|
||||
)
|
||||
|
||||
def add(self, selector: Sequence[str], value: Any, /):
|
||||
"""
|
||||
|
||||
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.file.models import File
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
@ -136,21 +136,36 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
|
||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@ -6,11 +6,15 @@ import re
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.memory.entities import MemoryCreatedBy, MemoryScope
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities import (
|
||||
@ -71,6 +75,8 @@ from core.workflow.node_events import (
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from models import UserFrom, Workflow
|
||||
from models.engine import db
|
||||
|
||||
from . import llm_utils
|
||||
from .entities import (
|
||||
@ -315,6 +321,11 @@ class LLMNode(Node):
|
||||
if self._file_outputs:
|
||||
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
|
||||
|
||||
try:
|
||||
self._handle_chatflow_memory(result_text, variable_pool)
|
||||
except Exception as e:
|
||||
logger.warning("Memory orchestration failed for node %s: %s", self.node_id, str(e))
|
||||
|
||||
# Send final chunk event to indicate streaming is complete
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
@ -1184,6 +1195,79 @@ class LLMNode(Node):
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
def _handle_chatflow_memory(self, llm_output: str, variable_pool: VariablePool):
|
||||
if not self._node_data.memory or self._node_data.memory.mode != "block":
|
||||
return
|
||||
|
||||
conversation_id_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.CONVERSATION_ID))
|
||||
if not conversation_id_segment:
|
||||
raise ValueError("Conversation ID not found in variable pool.")
|
||||
conversation_id = conversation_id_segment.text
|
||||
|
||||
user_query_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
||||
if not user_query_segment:
|
||||
raise ValueError("User query not found in variable pool.")
|
||||
user_query = user_query_segment.text
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
|
||||
from services.chatflow_history_service import ChatflowHistoryService
|
||||
|
||||
ChatflowHistoryService.save_node_message(
|
||||
prompt_message=(UserPromptMessage(content=user_query)),
|
||||
node_id=self.node_id,
|
||||
conversation_id=conversation_id,
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id
|
||||
)
|
||||
ChatflowHistoryService.save_node_message(
|
||||
prompt_message=(AssistantPromptMessage(content=llm_output)),
|
||||
node_id=self.node_id,
|
||||
conversation_id=conversation_id,
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id
|
||||
)
|
||||
|
||||
memory_config = self._node_data.memory
|
||||
if not memory_config:
|
||||
return
|
||||
block_ids = memory_config.block_id
|
||||
if not block_ids:
|
||||
return
|
||||
|
||||
# FIXME: This is dirty workaround and may cause incorrect resolution for workflow version
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(Workflow).where(
|
||||
Workflow.tenant_id == self.tenant_id,
|
||||
Workflow.app_id == self.app_id
|
||||
)
|
||||
workflow = session.scalars(stmt).first()
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found.")
|
||||
memory_blocks = workflow.memory_blocks
|
||||
|
||||
for block_id in block_ids:
|
||||
memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None)
|
||||
|
||||
if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
|
||||
is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
ChatflowMemoryService.update_node_memory_if_needed(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
node_id=self.id,
|
||||
conversation_id=conversation_id,
|
||||
memory_block_spec=memory_block_spec,
|
||||
variable_pool=variable_pool,
|
||||
is_draft=is_draft,
|
||||
created_by=self._get_user_from_context()
|
||||
)
|
||||
|
||||
def _get_user_from_context(self) -> MemoryCreatedBy:
|
||||
if self.user_from == UserFrom.ACCOUNT:
|
||||
return MemoryCreatedBy(account_id=self.user_id)
|
||||
else:
|
||||
return MemoryCreatedBy(end_user_id=self.user_id)
|
||||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
|
||||
|
||||
@ -38,14 +38,16 @@ elif [[ "${MODE}" == "beat" ]]; then
|
||||
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
|
||||
else
|
||||
if [[ "${DEBUG}" == "true" ]]; then
|
||||
exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug
|
||||
export HOST=${DIFY_BIND_ADDRESS:-0.0.0.0}
|
||||
export PORT=${DIFY_PORT:-5001}
|
||||
exec python -m app
|
||||
else
|
||||
exec gunicorn \
|
||||
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
|
||||
--workers ${SERVER_WORKER_AMOUNT:-1} \
|
||||
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
|
||||
--worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \
|
||||
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
|
||||
--timeout ${GUNICORN_TIMEOUT:-200} \
|
||||
app:app
|
||||
app:socketio_app
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, SystemConfiguration
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
@ -133,22 +133,38 @@ def handle(sender: Message, **kwargs):
|
||||
system_configuration=system_configuration,
|
||||
model_name=model_config.model,
|
||||
)
|
||||
|
||||
if used_quota is not None:
|
||||
quota_update = _ProviderUpdateOperation(
|
||||
filters=_ProviderUpdateFilters(
|
||||
if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=ModelProviderID(model_config.provider).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type.value,
|
||||
),
|
||||
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
|
||||
additional_filters=_ProviderUpdateAdditionalFilters(
|
||||
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
|
||||
),
|
||||
description="quota_deduction_update",
|
||||
)
|
||||
updates_to_perform.append(quota_update)
|
||||
credits_required=used_quota,
|
||||
pool_type="trial",
|
||||
)
|
||||
elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
quota_update = _ProviderUpdateOperation(
|
||||
filters=_ProviderUpdateFilters(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=ModelProviderID(model_config.provider).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type.value,
|
||||
),
|
||||
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
|
||||
additional_filters=_ProviderUpdateAdditionalFilters(
|
||||
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
|
||||
),
|
||||
description="quota_deduction_update",
|
||||
)
|
||||
updates_to_perform.append(quota_update)
|
||||
|
||||
# Execute all updates
|
||||
start_time = time_module.perf_counter()
|
||||
|
||||
3
api/extensions/ext_socketio.py
Normal file
3
api/extensions/ext_socketio.py
Normal file
@ -0,0 +1,3 @@
|
||||
import socketio
|
||||
|
||||
sio = socketio.Server(async_mode="gevent", cors_allowed_origins="*")
|
||||
@ -21,6 +21,8 @@ from core.variables.segments import (
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment,
|
||||
VersionedMemorySegment,
|
||||
VersionedMemoryValue,
|
||||
)
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.variables import (
|
||||
@ -39,6 +41,7 @@ from core.variables.variables import (
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable,
|
||||
VersionedMemoryVariable,
|
||||
)
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
@ -69,6 +72,7 @@ SEGMENT_TO_VARIABLE_MAP = {
|
||||
NoneSegment: NoneVariable,
|
||||
ObjectSegment: ObjectVariable,
|
||||
StringSegment: StringVariable,
|
||||
VersionedMemorySegment: VersionedMemoryVariable
|
||||
}
|
||||
|
||||
|
||||
@ -193,6 +197,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
|
||||
SegmentType.FILE: FileSegment,
|
||||
SegmentType.BOOLEAN: BooleanSegment,
|
||||
SegmentType.OBJECT: ObjectSegment,
|
||||
SegmentType.VERSIONED_MEMORY: VersionedMemorySegment,
|
||||
# Array types
|
||||
SegmentType.ARRAY_ANY: ArrayAnySegment,
|
||||
SegmentType.ARRAY_STRING: ArrayStringSegment,
|
||||
@ -259,6 +264,12 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
||||
else:
|
||||
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
|
||||
|
||||
if segment_type == SegmentType.VERSIONED_MEMORY:
|
||||
return VersionedMemorySegment(
|
||||
value_type=segment_type,
|
||||
value=VersionedMemoryValue.model_validate(value)
|
||||
)
|
||||
|
||||
inferred_type = SegmentType.infer_segment_type(value)
|
||||
# Type compatibility checking
|
||||
if inferred_type is None:
|
||||
|
||||
17
api/fields/online_user_fields.py
Normal file
17
api/fields/online_user_fields.py
Normal file
@ -0,0 +1,17 @@
|
||||
from flask_restx import fields
|
||||
|
||||
online_user_partial_fields = {
|
||||
"user_id": fields.String,
|
||||
"username": fields.String,
|
||||
"avatar": fields.String,
|
||||
"sid": fields.String,
|
||||
}
|
||||
|
||||
workflow_online_users_fields = {
|
||||
"workflow_id": fields.String,
|
||||
"users": fields.List(fields.Nested(online_user_partial_fields)),
|
||||
}
|
||||
|
||||
online_user_list_fields = {
|
||||
"data": fields.List(fields.Nested(workflow_online_users_fields)),
|
||||
}
|
||||
96
api/fields/workflow_comment_fields.py
Normal file
96
api/fields/workflow_comment_fields.py
Normal file
@ -0,0 +1,96 @@
|
||||
from flask_restx import fields
|
||||
|
||||
from libs.helper import AvatarUrlField, TimestampField
|
||||
|
||||
# basic account fields for comments
|
||||
account_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
"avatar_url": AvatarUrlField,
|
||||
}
|
||||
|
||||
# Comment mention fields
|
||||
workflow_comment_mention_fields = {
|
||||
"mentioned_user_id": fields.String,
|
||||
"mentioned_user_account": fields.Nested(account_fields, allow_null=True),
|
||||
"reply_id": fields.String,
|
||||
}
|
||||
|
||||
# Comment reply fields
|
||||
workflow_comment_reply_fields = {
|
||||
"id": fields.String,
|
||||
"content": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
# Basic comment fields (for list views)
|
||||
workflow_comment_basic_fields = {
|
||||
"id": fields.String,
|
||||
"position_x": fields.Float,
|
||||
"position_y": fields.Float,
|
||||
"content": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"resolved": fields.Boolean,
|
||||
"resolved_at": TimestampField,
|
||||
"resolved_by": fields.String,
|
||||
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"reply_count": fields.Integer,
|
||||
"mention_count": fields.Integer,
|
||||
"participants": fields.List(fields.Nested(account_fields)),
|
||||
}
|
||||
|
||||
# Detailed comment fields (for single comment view)
|
||||
workflow_comment_detail_fields = {
|
||||
"id": fields.String,
|
||||
"position_x": fields.Float,
|
||||
"position_y": fields.Float,
|
||||
"content": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"resolved": fields.Boolean,
|
||||
"resolved_at": TimestampField,
|
||||
"resolved_by": fields.String,
|
||||
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"replies": fields.List(fields.Nested(workflow_comment_reply_fields)),
|
||||
"mentions": fields.List(fields.Nested(workflow_comment_mention_fields)),
|
||||
}
|
||||
|
||||
# Comment creation response fields (simplified)
|
||||
workflow_comment_create_fields = {
|
||||
"id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
# Comment update response fields (simplified)
|
||||
workflow_comment_update_fields = {
|
||||
"id": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
# Comment resolve response fields
|
||||
workflow_comment_resolve_fields = {
|
||||
"id": fields.String,
|
||||
"resolved": fields.Boolean,
|
||||
"resolved_at": TimestampField,
|
||||
"resolved_by": fields.String,
|
||||
}
|
||||
|
||||
# Reply creation response fields (simplified)
|
||||
workflow_comment_reply_create_fields = {
|
||||
"id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
# Reply update response fields
|
||||
workflow_comment_reply_update_fields = {
|
||||
"id": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
@ -49,6 +49,23 @@ conversation_variable_fields = {
|
||||
"description": fields.String,
|
||||
}
|
||||
|
||||
memory_block_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"template": fields.String,
|
||||
"instruction": fields.String,
|
||||
"scope": fields.String,
|
||||
"term": fields.String,
|
||||
"strategy": fields.String,
|
||||
"update_turns": fields.Integer,
|
||||
"preserved_turns": fields.Integer,
|
||||
"schedule_mode": fields.String,
|
||||
"model": fields.Raw,
|
||||
"end_user_visible": fields.Boolean,
|
||||
"end_user_editable": fields.Boolean,
|
||||
}
|
||||
|
||||
pipeline_variable_fields = {
|
||||
"label": fields.String,
|
||||
"variable": fields.String,
|
||||
@ -81,6 +98,7 @@ workflow_fields = {
|
||||
"tool_published": fields.Boolean,
|
||||
"environment_variables": fields.List(EnvironmentVariableField()),
|
||||
"conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
|
||||
"memory_blocks": fields.List(fields.Nested(memory_block_fields)),
|
||||
"rag_pipeline_variables": fields.List(fields.Nested(pipeline_variable_fields)),
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,90 @@
|
||||
"""Add workflow comments table
|
||||
|
||||
Revision ID: 227822d22895
|
||||
Revises: 68519ad5cd18
|
||||
Create Date: 2025-08-22 17:26:15.255980
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '227822d22895'
|
||||
down_revision = '68519ad5cd18'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('workflow_comments',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('position_x', sa.Float(), nullable=False),
|
||||
sa.Column('position_y', sa.Float(), nullable=False),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('created_by', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('resolved', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||
sa.Column('resolved_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('resolved_by', models.types.StringUUID(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id', name='workflow_comments_pkey')
|
||||
)
|
||||
with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
|
||||
batch_op.create_index('workflow_comments_app_idx', ['tenant_id', 'app_id'], unique=False)
|
||||
batch_op.create_index('workflow_comments_created_at_idx', ['created_at'], unique=False)
|
||||
|
||||
op.create_table('workflow_comment_replies',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('comment_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('created_by', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_replies_comment_id_fkey'), ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id', name='workflow_comment_replies_pkey')
|
||||
)
|
||||
with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
|
||||
batch_op.create_index('comment_replies_comment_idx', ['comment_id'], unique=False)
|
||||
batch_op.create_index('comment_replies_created_at_idx', ['created_at'], unique=False)
|
||||
|
||||
op.create_table('workflow_comment_mentions',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('comment_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('reply_id', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('mentioned_user_id', models.types.StringUUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_mentions_comment_id_fkey'), ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['reply_id'], ['workflow_comment_replies.id'], name=op.f('workflow_comment_mentions_reply_id_fkey'), ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id', name='workflow_comment_mentions_pkey')
|
||||
)
|
||||
with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
|
||||
batch_op.create_index('comment_mentions_comment_idx', ['comment_id'], unique=False)
|
||||
batch_op.create_index('comment_mentions_reply_idx', ['reply_id'], unique=False)
|
||||
batch_op.create_index('comment_mentions_user_idx', ['mentioned_user_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
|
||||
batch_op.drop_index('comment_mentions_user_idx')
|
||||
batch_op.drop_index('comment_mentions_reply_idx')
|
||||
batch_op.drop_index('comment_mentions_comment_idx')
|
||||
|
||||
op.drop_table('workflow_comment_mentions')
|
||||
with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
|
||||
batch_op.drop_index('comment_replies_created_at_idx')
|
||||
batch_op.drop_index('comment_replies_comment_idx')
|
||||
|
||||
op.drop_table('workflow_comment_replies')
|
||||
with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
|
||||
batch_op.drop_index('workflow_comments_created_at_idx')
|
||||
batch_op.drop_index('workflow_comments_app_idx')
|
||||
|
||||
op.drop_table('workflow_comments')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,79 @@
|
||||
"""add table explore banner and trial
|
||||
|
||||
Revision ID: 1b435d90db42
|
||||
Revises: cf7c38a32b2d
|
||||
Create Date: 2025-09-19 14:42:58.416649
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '1b435d90db42'
|
||||
down_revision = 'cf7c38a32b2d'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('account_trial_app_records',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('account_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('count', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
|
||||
sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
|
||||
)
|
||||
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
|
||||
batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
|
||||
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
|
||||
|
||||
op.create_table('exporle_banners',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('content', sa.JSON(), nullable=False),
|
||||
sa.Column('link', sa.String(length=255), nullable=False),
|
||||
sa.Column('sort', sa.Integer(), nullable=False),
|
||||
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
|
||||
)
|
||||
op.create_table('trial_apps',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('trial_limit', sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
|
||||
sa.UniqueConstraint('app_id', name='unique_trail_app_id')
|
||||
)
|
||||
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
|
||||
batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
|
||||
batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
|
||||
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.drop_column('credential_status')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credential_status', sa.VARCHAR(length=20), server_default=sa.text("'active'::character varying"), autoincrement=False, nullable=True))
|
||||
|
||||
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
|
||||
batch_op.drop_index('trial_app_tenant_id_idx')
|
||||
batch_op.drop_index('trial_app_app_id_idx')
|
||||
|
||||
op.drop_table('trial_apps')
|
||||
op.drop_table('exporle_banners')
|
||||
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
|
||||
batch_op.drop_index('account_trial_app_record_app_id_idx')
|
||||
batch_op.drop_index('account_trial_app_record_account_id_idx')
|
||||
|
||||
op.drop_table('account_trial_app_records')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,104 @@
|
||||
"""add table credit pool
|
||||
|
||||
Revision ID: 58a70d22fdbd
|
||||
Revises: 68519ad5cd18
|
||||
Create Date: 2025-09-25 15:20:40.367078
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '58a70d22fdbd'
|
||||
down_revision = '68519ad5cd18'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tenant_credit_pools',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('pool_type', sa.String(length=40), nullable=False),
|
||||
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
|
||||
sa.Column('quota_used', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
|
||||
)
|
||||
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
|
||||
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
|
||||
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
|
||||
# Data migration: Move quota data from providers to tenant_credit_pools
|
||||
migrate_quota_data()
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
|
||||
batch_op.drop_index('tenant_credit_pool_tenant_id_idx')
|
||||
batch_op.drop_index('tenant_credit_pool_pool_type_idx')
|
||||
|
||||
op.drop_table('tenant_credit_pools')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def migrate_quota_data():
|
||||
"""
|
||||
Migrate quota data from providers table to tenant_credit_pools table
|
||||
for providers with quota_type='trial' or 'paid', provider_name='openai', provider_type='system'
|
||||
"""
|
||||
# Create connection
|
||||
bind = op.get_bind()
|
||||
|
||||
# Define quota type mappings
|
||||
quota_type_mappings = ['trial', 'paid']
|
||||
|
||||
for quota_type in quota_type_mappings:
|
||||
# Query providers that match the criteria
|
||||
select_sql = sa.text("""
|
||||
SELECT tenant_id, quota_limit, quota_used
|
||||
FROM providers
|
||||
WHERE quota_type = :quota_type
|
||||
AND provider_name = 'openai'
|
||||
AND provider_type = 'system'
|
||||
AND quota_limit IS NOT NULL
|
||||
""")
|
||||
|
||||
result = bind.execute(select_sql, {"quota_type": quota_type})
|
||||
providers_data = result.fetchall()
|
||||
|
||||
# Insert data into tenant_credit_pools
|
||||
for provider_data in providers_data:
|
||||
tenant_id, quota_limit, quota_used = provider_data
|
||||
|
||||
# Check if credit pool already exists for this tenant and pool type
|
||||
check_sql = sa.text("""
|
||||
SELECT COUNT(*)
|
||||
FROM tenant_credit_pools
|
||||
WHERE tenant_id = :tenant_id AND pool_type = :pool_type
|
||||
""")
|
||||
|
||||
existing_count = bind.execute(check_sql, {
|
||||
"tenant_id": tenant_id,
|
||||
"pool_type": quota_type
|
||||
}).scalar()
|
||||
|
||||
if existing_count == 0:
|
||||
# Insert new credit pool record
|
||||
insert_sql = sa.text("""
|
||||
INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at)
|
||||
VALUES (:tenant_id, :pool_type, :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
""")
|
||||
|
||||
bind.execute(insert_sql, {
|
||||
"tenant_id": tenant_id,
|
||||
"pool_type": quota_type,
|
||||
"quota_limit": quota_limit or 0,
|
||||
"quota_used": quota_used or 0
|
||||
})
|
||||
@ -0,0 +1,104 @@
|
||||
"""add_chatflow_memory_tables
|
||||
|
||||
Revision ID: d00b2b40ea3e
|
||||
Revises: 68519ad5cd18
|
||||
Create Date: 2025-10-11 15:29:20.244675
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'd00b2b40ea3e'
|
||||
down_revision = '68519ad5cd18'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('chatflow_conversations',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('node_id', sa.Text(), nullable=True),
|
||||
sa.Column('original_conversation_id', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('conversation_metadata', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='chatflow_conversations_pkey')
|
||||
)
|
||||
with op.batch_alter_table('chatflow_conversations', schema=None) as batch_op:
|
||||
batch_op.create_index('chatflow_conversations_original_conversation_id_idx', ['tenant_id', 'app_id', 'node_id', 'original_conversation_id'], unique=False)
|
||||
|
||||
op.create_table('chatflow_memory_variables',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('conversation_id', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('node_id', sa.Text(), nullable=True),
|
||||
sa.Column('memory_id', sa.Text(), nullable=False),
|
||||
sa.Column('value', sa.Text(), nullable=False),
|
||||
sa.Column('name', sa.Text(), nullable=False),
|
||||
sa.Column('scope', sa.String(length=10), nullable=False),
|
||||
sa.Column('term', sa.String(length=20), nullable=False),
|
||||
sa.Column('version', sa.Integer(), nullable=False),
|
||||
sa.Column('created_by_role', sa.String(length=20), nullable=False),
|
||||
sa.Column('created_by', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='chatflow_memory_variables_pkey')
|
||||
)
|
||||
with op.batch_alter_table('chatflow_memory_variables', schema=None) as batch_op:
|
||||
batch_op.create_index('chatflow_memory_variables_memory_id_idx', ['tenant_id', 'app_id', 'node_id', 'memory_id'], unique=False)
|
||||
|
||||
op.create_table('chatflow_messages',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('conversation_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('index', sa.Integer(), nullable=False),
|
||||
sa.Column('version', sa.Integer(), nullable=False),
|
||||
sa.Column('data', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='chatflow_messages_pkey')
|
||||
)
|
||||
with op.batch_alter_table('chatflow_messages', schema=None) as batch_op:
|
||||
batch_op.create_index('chatflow_messages_version_idx', ['conversation_id', 'index', 'version'], unique=False)
|
||||
|
||||
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
|
||||
batch_op.alter_column('avatar_url',
|
||||
existing_type=sa.TEXT(),
|
||||
type_=sa.String(length=255),
|
||||
existing_nullable=True)
|
||||
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.drop_column('credential_status')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credential_status', sa.VARCHAR(length=20), server_default=sa.text("'active'::character varying"), autoincrement=False, nullable=True))
|
||||
|
||||
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
|
||||
batch_op.alter_column('avatar_url',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.TEXT(),
|
||||
existing_nullable=True)
|
||||
|
||||
with op.batch_alter_table('chatflow_messages', schema=None) as batch_op:
|
||||
batch_op.drop_index('chatflow_messages_version_idx')
|
||||
|
||||
op.drop_table('chatflow_messages')
|
||||
with op.batch_alter_table('chatflow_memory_variables', schema=None) as batch_op:
|
||||
batch_op.drop_index('chatflow_memory_variables_memory_id_idx')
|
||||
|
||||
op.drop_table('chatflow_memory_variables')
|
||||
with op.batch_alter_table('chatflow_conversations', schema=None) as batch_op:
|
||||
batch_op.drop_index('chatflow_conversations_original_conversation_id_idx')
|
||||
|
||||
op.drop_table('chatflow_conversations')
|
||||
# ### end Alembic commands ###
|
||||
@ -9,6 +9,12 @@ from .account import (
|
||||
TenantStatus,
|
||||
)
|
||||
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
from .chatflow_memory import ChatflowConversation, ChatflowMemoryVariable, ChatflowMessage
|
||||
from .comment import (
|
||||
WorkflowComment,
|
||||
WorkflowCommentMention,
|
||||
WorkflowCommentReply,
|
||||
)
|
||||
from .dataset import (
|
||||
AppDatasetJoin,
|
||||
Dataset,
|
||||
@ -28,6 +34,7 @@ from .dataset import (
|
||||
)
|
||||
from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
|
||||
from .model import (
|
||||
AccountTrialAppRecord,
|
||||
ApiRequest,
|
||||
ApiToken,
|
||||
App,
|
||||
@ -40,6 +47,7 @@ from .model import (
|
||||
DatasetRetrieverResource,
|
||||
DifySetup,
|
||||
EndUser,
|
||||
ExporleBanner,
|
||||
IconType,
|
||||
InstalledApp,
|
||||
Message,
|
||||
@ -53,7 +61,9 @@ from .model import (
|
||||
Site,
|
||||
Tag,
|
||||
TagBinding,
|
||||
TenantCreditPool,
|
||||
TraceAppConfig,
|
||||
TrialApp,
|
||||
UploadFile,
|
||||
)
|
||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
@ -98,6 +108,7 @@ __all__ = [
|
||||
"Account",
|
||||
"AccountIntegrate",
|
||||
"AccountStatus",
|
||||
"AccountTrialAppRecord",
|
||||
"ApiRequest",
|
||||
"ApiToken",
|
||||
"ApiToolProvider",
|
||||
@ -111,6 +122,9 @@ __all__ = [
|
||||
"BuiltinToolProvider",
|
||||
"CeleryTask",
|
||||
"CeleryTaskSet",
|
||||
"ChatflowConversation",
|
||||
"ChatflowMemoryVariable",
|
||||
"ChatflowMessage",
|
||||
"Conversation",
|
||||
"ConversationVariable",
|
||||
"CreatorUserRole",
|
||||
@ -131,6 +145,7 @@ __all__ = [
|
||||
"DocumentSegment",
|
||||
"Embedding",
|
||||
"EndUser",
|
||||
"ExporleBanner",
|
||||
"ExternalKnowledgeApis",
|
||||
"ExternalKnowledgeBindings",
|
||||
"IconType",
|
||||
@ -159,6 +174,7 @@ __all__ = [
|
||||
"Tenant",
|
||||
"TenantAccountJoin",
|
||||
"TenantAccountRole",
|
||||
"TenantCreditPool",
|
||||
"TenantDefaultModel",
|
||||
"TenantPreferredModelProvider",
|
||||
"TenantStatus",
|
||||
@ -168,12 +184,16 @@ __all__ = [
|
||||
"ToolLabelBinding",
|
||||
"ToolModelInvoke",
|
||||
"TraceAppConfig",
|
||||
"TrialApp",
|
||||
"UploadFile",
|
||||
"UserFrom",
|
||||
"Whitelist",
|
||||
"Workflow",
|
||||
"WorkflowAppLog",
|
||||
"WorkflowAppLogCreatedFrom",
|
||||
"WorkflowComment",
|
||||
"WorkflowCommentMention",
|
||||
"WorkflowCommentReply",
|
||||
"WorkflowNodeExecutionModel",
|
||||
"WorkflowNodeExecutionOffload",
|
||||
"WorkflowNodeExecutionTriggeredFrom",
|
||||
|
||||
76
api/models/chatflow_memory.py
Normal file
76
api/models/chatflow_memory.py
Normal file
@ -0,0 +1,76 @@
|
||||
from datetime import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class ChatflowMemoryVariable(Base):
|
||||
__tablename__ = "chatflow_memory_variables"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="chatflow_memory_variables_pkey"),
|
||||
sa.Index("chatflow_memory_variables_memory_id_idx", "tenant_id", "app_id", "node_id", "memory_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
node_id: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
|
||||
memory_id: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
value: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
name: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
scope: Mapped[str] = mapped_column(sa.String(10), nullable=False) # 'app' or 'node'
|
||||
term: Mapped[str] = mapped_column(sa.String(20), nullable=False) # 'session' or 'persistent'
|
||||
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=1)
|
||||
created_by_role: Mapped[str] = mapped_column(sa.String(20)) # 'end_user' or 'account`
|
||||
created_by: Mapped[str] = mapped_column(StringUUID)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class ChatflowConversation(Base):
|
||||
__tablename__ = "chatflow_conversations"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="chatflow_conversations_pkey"),
|
||||
sa.Index(
|
||||
"chatflow_conversations_original_conversation_id_idx",
|
||||
"tenant_id", "app_id", "node_id", "original_conversation_id"
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
node_id: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
|
||||
original_conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
conversation_metadata: Mapped[str] = mapped_column(sa.Text, nullable=False) # JSON
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
|
||||
class ChatflowMessage(Base):
|
||||
__tablename__ = "chatflow_messages"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="chatflow_messages_pkey"),
|
||||
sa.Index("chatflow_messages_version_idx", "conversation_id", "index", "version"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
index: Mapped[int] = mapped_column(sa.Integer, nullable=False) # This index starts from 0
|
||||
version: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
data: Mapped[str] = mapped_column(sa.Text, nullable=False) # Serialized PromptMessage JSON
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
189
api/models/comment.py
Normal file
189
api/models/comment.py
Normal file
@ -0,0 +1,189 @@
|
||||
"""Workflow comment models."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import Index, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from .account import Account
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowComment(Base):
|
||||
"""Workflow comment model for canvas commenting functionality.
|
||||
|
||||
Comments are associated with apps rather than specific workflow versions,
|
||||
since an app has only one draft workflow at a time and comments should persist
|
||||
across workflow version changes.
|
||||
|
||||
Attributes:
|
||||
id: Comment ID
|
||||
tenant_id: Workspace ID
|
||||
app_id: App ID (primary association, comments belong to apps)
|
||||
position_x: X coordinate on canvas
|
||||
position_y: Y coordinate on canvas
|
||||
content: Comment content
|
||||
created_by: Creator account ID
|
||||
created_at: Creation time
|
||||
updated_at: Last update time
|
||||
resolved: Whether comment is resolved
|
||||
resolved_at: Resolution time
|
||||
resolved_by: Resolver account ID
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_comments"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
|
||||
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
|
||||
Index("workflow_comments_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
position_x: Mapped[float] = mapped_column(db.Float)
|
||||
position_y: Mapped[float] = mapped_column(db.Float)
|
||||
content: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
resolved_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
|
||||
resolved_by: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
|
||||
# Relationships
|
||||
replies: Mapped[list["WorkflowCommentReply"]] = relationship(
|
||||
"WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
|
||||
)
|
||||
mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
|
||||
"WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
"""Get creator account."""
|
||||
return db.session.get(Account, self.created_by)
|
||||
|
||||
@property
|
||||
def resolved_by_account(self):
|
||||
"""Get resolver account."""
|
||||
if self.resolved_by:
|
||||
return db.session.get(Account, self.resolved_by)
|
||||
return None
|
||||
|
||||
@property
|
||||
def reply_count(self):
|
||||
"""Get reply count."""
|
||||
return len(self.replies)
|
||||
|
||||
@property
|
||||
def mention_count(self):
|
||||
"""Get mention count."""
|
||||
return len(self.mentions)
|
||||
|
||||
@property
|
||||
def participants(self):
|
||||
"""Get all participants (creator + repliers + mentioned users)."""
|
||||
participant_ids = set()
|
||||
|
||||
# Add comment creator
|
||||
participant_ids.add(self.created_by)
|
||||
|
||||
# Add reply creators
|
||||
participant_ids.update(reply.created_by for reply in self.replies)
|
||||
|
||||
# Add mentioned users
|
||||
participant_ids.update(mention.mentioned_user_id for mention in self.mentions)
|
||||
|
||||
# Get account objects
|
||||
participants = []
|
||||
for user_id in participant_ids:
|
||||
account = db.session.get(Account, user_id)
|
||||
if account:
|
||||
participants.append(account)
|
||||
|
||||
return participants
|
||||
|
||||
|
||||
class WorkflowCommentReply(Base):
|
||||
"""Workflow comment reply model.
|
||||
|
||||
Attributes:
|
||||
id: Reply ID
|
||||
comment_id: Parent comment ID
|
||||
content: Reply content
|
||||
created_by: Creator account ID
|
||||
created_at: Creation time
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_comment_replies"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
|
||||
Index("comment_replies_comment_idx", "comment_id"),
|
||||
Index("comment_replies_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
content: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
# Relationships
|
||||
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
"""Get creator account."""
|
||||
return db.session.get(Account, self.created_by)
|
||||
|
||||
|
||||
class WorkflowCommentMention(Base):
|
||||
"""Workflow comment mention model.
|
||||
|
||||
Mentions are only for internal accounts since end users
|
||||
cannot access workflow canvas and commenting features.
|
||||
|
||||
Attributes:
|
||||
id: Mention ID
|
||||
comment_id: Parent comment ID
|
||||
mentioned_user_id: Mentioned account ID
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_comment_mentions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
|
||||
Index("comment_mentions_comment_idx", "comment_id"),
|
||||
Index("comment_mentions_reply_idx", "reply_id"),
|
||||
Index("comment_mentions_user_idx", "mentioned_user_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
reply_id: Mapped[Optional[str]] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
# Relationships
|
||||
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
|
||||
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
|
||||
|
||||
@property
|
||||
def mentioned_user_account(self):
|
||||
"""Get mentioned account."""
|
||||
return db.session.get(Account, self.mentioned_user_id)
|
||||
@ -23,6 +23,7 @@ class DraftVariableType(StrEnum):
|
||||
NODE = "node"
|
||||
SYS = "sys"
|
||||
CONVERSATION = "conversation"
|
||||
MEMORY_BLOCK = "memory_block"
|
||||
|
||||
|
||||
class MessageStatus(StrEnum):
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import UserMixin # type: ignore[import-untyped]
|
||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
@ -581,6 +581,63 @@ class InstalledApp(Base):
|
||||
return tenant
|
||||
|
||||
|
||||
class TrialApp(Base):
|
||||
__tablename__ = "trial_apps"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
|
||||
sa.Index("trial_app_app_id_idx", "app_id"),
|
||||
sa.Index("trial_app_tenant_id_idx", "tenant_id"),
|
||||
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
|
||||
class AccountTrialAppRecord(Base):
|
||||
__tablename__ = "account_trial_app_records"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
|
||||
sa.Index("account_trial_app_record_account_id_idx", "account_id"),
|
||||
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
|
||||
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
|
||||
)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
user = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return user
|
||||
|
||||
|
||||
class ExporleBanner(Base):
|
||||
__tablename__ = "exporle_banners"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
content = mapped_column(sa.JSON, nullable=False)
|
||||
link = mapped_column(String(255), nullable=False)
|
||||
sort = mapped_column(sa.Integer, nullable=False)
|
||||
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class OAuthProviderApp(Base):
|
||||
"""
|
||||
Globally shared OAuth provider app information.
|
||||
@ -1944,3 +2001,29 @@ class TraceAppConfig(Base):
|
||||
"created_at": str(self.created_at) if self.created_at else None,
|
||||
"updated_at": str(self.updated_at) if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class TenantCreditPool(Base):
|
||||
__tablename__ = "tenant_credit_pools"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
|
||||
sa.Index("tenant_credit_pool_tenant_id_idx", "tenant_id"),
|
||||
sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
|
||||
quota_limit = mapped_column(BigInteger, nullable=False, default=0)
|
||||
quota_used = mapped_column(BigInteger, nullable=False, default=0)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def remaining_credits(self) -> int:
|
||||
return max(0, self.quota_limit - self.quota_used)
|
||||
|
||||
def has_sufficient_credits(self, required_credits: int) -> bool:
|
||||
return self.remaining_credits >= required_credits
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
@ -11,9 +12,15 @@ from sqlalchemy import DateTime, Select, exists, orm, select
|
||||
|
||||
from core.file.constants import maybe_file_object
|
||||
from core.file.models import File
|
||||
from core.memory.entities import MemoryBlockSpec
|
||||
from core.variables import utils as variable_utils
|
||||
from core.variables.segments import VersionedMemoryValue
|
||||
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
MEMORY_BLOCK_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_storage import Storage
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
@ -149,6 +156,9 @@ class Workflow(Base):
|
||||
_rag_pipeline_variables: Mapped[str] = mapped_column(
|
||||
"rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
_memory_blocks: Mapped[str] = mapped_column(
|
||||
"memory_blocks", sa.Text, nullable=False, server_default="[]"
|
||||
)
|
||||
|
||||
VERSION_DRAFT = "draft"
|
||||
|
||||
@ -166,6 +176,7 @@ class Workflow(Base):
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
rag_pipeline_variables: list[dict],
|
||||
memory_blocks: Sequence[MemoryBlockSpec] | None = None,
|
||||
marked_name: str = "",
|
||||
marked_comment: str = "",
|
||||
) -> "Workflow":
|
||||
@ -181,6 +192,7 @@ class Workflow(Base):
|
||||
workflow.environment_variables = environment_variables or []
|
||||
workflow.conversation_variables = conversation_variables or []
|
||||
workflow.rag_pipeline_variables = rag_pipeline_variables or []
|
||||
workflow.memory_blocks = memory_blocks or []
|
||||
workflow.marked_name = marked_name
|
||||
workflow.marked_comment = marked_comment
|
||||
workflow.created_at = naive_utc_now()
|
||||
@ -335,7 +347,7 @@ class Workflow(Base):
|
||||
|
||||
:return: hash
|
||||
"""
|
||||
entity = {"graph": self.graph_dict, "features": self.features_dict}
|
||||
entity = {"graph": self.graph_dict}
|
||||
|
||||
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
|
||||
|
||||
@ -440,7 +452,7 @@ class Workflow(Base):
|
||||
"features": self.features_dict,
|
||||
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
|
||||
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
|
||||
"rag_pipeline_variables": self.rag_pipeline_variables,
|
||||
"memory_blocks": [block.model_dump(mode="json") for block in self.memory_blocks],
|
||||
}
|
||||
return result
|
||||
|
||||
@ -478,6 +490,27 @@ class Workflow(Base):
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def memory_blocks(self) -> Sequence[MemoryBlockSpec]:
|
||||
"""Memory blocks configuration stored in database"""
|
||||
if self._memory_blocks is None or self._memory_blocks == "":
|
||||
self._memory_blocks = "[]"
|
||||
|
||||
memory_blocks_list: list[dict[str, Any]] = json.loads(self._memory_blocks)
|
||||
results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_list]
|
||||
return results
|
||||
|
||||
@memory_blocks.setter
|
||||
def memory_blocks(self, value: Sequence[MemoryBlockSpec]):
|
||||
if not value:
|
||||
self._memory_blocks = "[]"
|
||||
return
|
||||
|
||||
self._memory_blocks = json.dumps(
|
||||
[block.model_dump() for block in value],
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def version_from_datetime(d: datetime) -> str:
|
||||
return str(d)
|
||||
@ -1489,6 +1522,31 @@ class WorkflowDraftVariable(Base):
|
||||
variable.editable = editable
|
||||
return variable
|
||||
|
||||
@staticmethod
|
||||
def new_memory_block_variable(
|
||||
*,
|
||||
app_id: str,
|
||||
node_id: str | None = None,
|
||||
memory_id: str,
|
||||
name: str,
|
||||
value: VersionedMemoryValue,
|
||||
description: str = "",
|
||||
) -> "WorkflowDraftVariable":
|
||||
"""Create a new memory block draft variable."""
|
||||
return WorkflowDraftVariable(
|
||||
id=str(uuid.uuid4()),
|
||||
app_id=app_id,
|
||||
node_id=MEMORY_BLOCK_VARIABLE_NODE_ID,
|
||||
name=name,
|
||||
value=value.model_dump_json(),
|
||||
description=description,
|
||||
selector=[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_id] if node_id is None else
|
||||
[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_id, node_id],
|
||||
value_type=SegmentType.VERSIONED_MEMORY,
|
||||
visible=True,
|
||||
editable=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def edited(self):
|
||||
return self.last_edited_at is not None
|
||||
|
||||
@ -20,6 +20,7 @@ dependencies = [
|
||||
"flask-orjson~=2.0.0",
|
||||
"flask-sqlalchemy~=3.1.1",
|
||||
"gevent~=25.9.1",
|
||||
"gevent-websocket~=0.10.1",
|
||||
"gmpy2~=2.2.1",
|
||||
"google-api-core==2.18.0",
|
||||
"google-api-python-client==2.90.0",
|
||||
@ -68,6 +69,7 @@ dependencies = [
|
||||
"pypdfium2==4.30.0",
|
||||
"python-docx~=1.1.0",
|
||||
"python-dotenv==1.0.1",
|
||||
"python-socketio~=5.13.0",
|
||||
"pyyaml~=6.0.1",
|
||||
"readabilipy~=0.3.0",
|
||||
"redis[hiredis]~=6.1.0",
|
||||
@ -86,6 +88,7 @@ dependencies = [
|
||||
"sendgrid~=6.12.3",
|
||||
"flask-restx~=1.3.0",
|
||||
"packaging~=23.2",
|
||||
"gevent-websocket>=0.10.1",
|
||||
]
|
||||
# Before adding new dependency, consider place it in
|
||||
# alphabet order (a-z) and suitable group.
|
||||
|
||||
@ -995,6 +995,11 @@ class TenantService:
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
db.session.commit()
|
||||
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.create_default_pool(tenant.id)
|
||||
|
||||
return tenant
|
||||
|
||||
@staticmethod
|
||||
|
||||
237
api/services/chatflow_history_service.py
Normal file
237
api/services/chatflow_history_service.py
Normal file
@ -0,0 +1,237 @@
|
||||
import json
|
||||
from collections.abc import MutableMapping, Sequence
|
||||
from typing import Literal, Optional, overload
|
||||
|
||||
from sqlalchemy import Row, Select, and_, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.memory.entities import ChatflowConversationMetadata
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from models.chatflow_memory import ChatflowConversation, ChatflowMessage
|
||||
|
||||
|
||||
class ChatflowHistoryService:
|
||||
|
||||
@staticmethod
|
||||
def get_visible_chat_history(
|
||||
conversation_id: str,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
node_id: Optional[str] = None,
|
||||
max_visible_count: Optional[int] = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
with Session(db.engine) as session:
|
||||
chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
|
||||
session, conversation_id, app_id, tenant_id, node_id, create_if_missing=False
|
||||
)
|
||||
|
||||
if not chatflow_conv:
|
||||
return []
|
||||
|
||||
metadata = ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata)
|
||||
visible_count: int = max_visible_count or metadata.visible_count
|
||||
|
||||
stmt = select(ChatflowMessage).where(
|
||||
ChatflowMessage.conversation_id == chatflow_conv.id
|
||||
).order_by(ChatflowMessage.index.asc(), ChatflowMessage.version.desc())
|
||||
raw_messages: Sequence[Row[tuple[ChatflowMessage]]] = session.execute(stmt).all()
|
||||
sorted_messages = ChatflowHistoryService._filter_latest_messages(
|
||||
[it[0] for it in raw_messages]
|
||||
)
|
||||
visible_count = min(visible_count, len(sorted_messages))
|
||||
visible_messages = sorted_messages[-visible_count:]
|
||||
return [PromptMessage.model_validate_json(it.data) for it in visible_messages]
|
||||
|
||||
@staticmethod
|
||||
def save_message(
|
||||
prompt_message: PromptMessage,
|
||||
conversation_id: str,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
node_id: Optional[str] = None
|
||||
) -> None:
|
||||
with Session(db.engine) as session:
|
||||
chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
|
||||
session, conversation_id, app_id, tenant_id, node_id, create_if_missing=True
|
||||
)
|
||||
|
||||
# Get next index
|
||||
max_index = session.execute(
|
||||
select(func.max(ChatflowMessage.index)).where(
|
||||
ChatflowMessage.conversation_id == chatflow_conv.id
|
||||
)
|
||||
).scalar() or -1
|
||||
next_index = max_index + 1
|
||||
|
||||
# Save new message to append-only table
|
||||
new_message = ChatflowMessage(
|
||||
conversation_id=chatflow_conv.id,
|
||||
index=next_index,
|
||||
version=1,
|
||||
data=json.dumps(prompt_message)
|
||||
)
|
||||
session.add(new_message)
|
||||
session.commit()
|
||||
|
||||
# 添加:每次保存消息后简单增长visible_count
|
||||
current_metadata = ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata)
|
||||
new_visible_count = current_metadata.visible_count + 1
|
||||
new_metadata = ChatflowConversationMetadata(visible_count=new_visible_count)
|
||||
chatflow_conv.conversation_metadata = new_metadata.model_dump_json()
|
||||
|
||||
@staticmethod
|
||||
def save_app_message(
|
||||
prompt_message: PromptMessage,
|
||||
conversation_id: str,
|
||||
app_id: str,
|
||||
tenant_id: str
|
||||
) -> None:
|
||||
"""Save PromptMessage to app-level chatflow conversation."""
|
||||
ChatflowHistoryService.save_message(
|
||||
prompt_message=prompt_message,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
node_id=None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def save_node_message(
|
||||
prompt_message: PromptMessage,
|
||||
node_id: str,
|
||||
conversation_id: str,
|
||||
app_id: str,
|
||||
tenant_id: str
|
||||
) -> None:
|
||||
ChatflowHistoryService.save_message(
|
||||
prompt_message=prompt_message,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
node_id=node_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_visible_count(
|
||||
conversation_id: str,
|
||||
node_id: Optional[str],
|
||||
new_visible_count: int,
|
||||
app_id: str,
|
||||
tenant_id: str
|
||||
) -> None:
|
||||
with Session(db.engine) as session:
|
||||
chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
|
||||
session, conversation_id, app_id, tenant_id, node_id, create_if_missing=True
|
||||
)
|
||||
|
||||
# Only update visible_count in metadata, do not delete any data
|
||||
new_metadata = ChatflowConversationMetadata(visible_count=new_visible_count)
|
||||
chatflow_conv.conversation_metadata = new_metadata.model_dump_json()
|
||||
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_conversation_metadata(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
node_id: Optional[str]
|
||||
) -> ChatflowConversationMetadata:
|
||||
with Session(db.engine) as session:
|
||||
chatflow_conv = ChatflowHistoryService._get_or_create_chatflow_conversation(
|
||||
session, conversation_id, app_id, tenant_id, node_id, create_if_missing=False
|
||||
)
|
||||
if not chatflow_conv:
|
||||
raise ValueError(f"Conversation not found: {conversation_id}")
|
||||
return ChatflowConversationMetadata.model_validate_json(chatflow_conv.conversation_metadata)
|
||||
|
||||
@staticmethod
|
||||
def _filter_latest_messages(raw_messages: Sequence[ChatflowMessage]) -> Sequence[ChatflowMessage]:
|
||||
index_to_message: MutableMapping[int, ChatflowMessage] = {}
|
||||
for msg in raw_messages:
|
||||
index = msg.index
|
||||
if index not in index_to_message or msg.version > index_to_message[index].version:
|
||||
index_to_message[index] = msg
|
||||
|
||||
sorted_messages = sorted(index_to_message.values(), key=lambda m: m.index)
|
||||
return sorted_messages
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def _get_or_create_chatflow_conversation(
|
||||
session: Session,
|
||||
conversation_id: str,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
node_id: Optional[str] = None,
|
||||
create_if_missing: Literal[True] = True
|
||||
) -> ChatflowConversation: ...
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def _get_or_create_chatflow_conversation(
|
||||
session: Session,
|
||||
conversation_id: str,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
node_id: Optional[str] = None,
|
||||
create_if_missing: Literal[False] = False
|
||||
) -> Optional[ChatflowConversation]: ...
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def _get_or_create_chatflow_conversation(
|
||||
session: Session,
|
||||
conversation_id: str,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
node_id: Optional[str] = None,
|
||||
create_if_missing: bool = False
|
||||
) -> Optional[ChatflowConversation]: ...
|
||||
|
||||
@staticmethod
|
||||
def _get_or_create_chatflow_conversation(
|
||||
session: Session,
|
||||
conversation_id: str,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
node_id: Optional[str] = None,
|
||||
create_if_missing: bool = False
|
||||
) -> Optional[ChatflowConversation]:
|
||||
"""Get existing chatflow conversation or optionally create new one"""
|
||||
stmt: Select[tuple[ChatflowConversation]] = select(ChatflowConversation).where(
|
||||
and_(
|
||||
ChatflowConversation.original_conversation_id == conversation_id,
|
||||
ChatflowConversation.tenant_id == tenant_id,
|
||||
ChatflowConversation.app_id == app_id
|
||||
)
|
||||
)
|
||||
|
||||
if node_id:
|
||||
stmt = stmt.where(ChatflowConversation.node_id == node_id)
|
||||
else:
|
||||
stmt = stmt.where(ChatflowConversation.node_id.is_(None))
|
||||
|
||||
chatflow_conv: Row[tuple[ChatflowConversation]] | None = session.execute(stmt).first()
|
||||
|
||||
if chatflow_conv:
|
||||
result: ChatflowConversation = chatflow_conv[0] # Extract the ChatflowConversation object
|
||||
return result
|
||||
else:
|
||||
if create_if_missing:
|
||||
# Create a new chatflow conversation
|
||||
default_metadata = ChatflowConversationMetadata(visible_count=0)
|
||||
new_chatflow_conv = ChatflowConversation(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
node_id=node_id,
|
||||
original_conversation_id=conversation_id,
|
||||
conversation_metadata=default_metadata.model_dump_json(),
|
||||
)
|
||||
session.add(new_chatflow_conv)
|
||||
session.flush() # Obtain ID
|
||||
return new_chatflow_conv
|
||||
return None
|
||||
680
api/services/chatflow_memory_service.py
Normal file
680
api/services/chatflow_memory_service.py
Normal file
@ -0,0 +1,680 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.memory.entities import (
|
||||
MemoryBlock,
|
||||
MemoryBlockSpec,
|
||||
MemoryBlockWithConversation,
|
||||
MemoryCreatedBy,
|
||||
MemoryScheduleMode,
|
||||
MemoryScope,
|
||||
MemoryTerm,
|
||||
MemoryValueData,
|
||||
)
|
||||
from core.memory.errors import MemorySyncTimeoutError
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.variables.segments import VersionedMemoryValue
|
||||
from core.workflow.constants import MEMORY_BLOCK_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import App, CreatorUserRole
|
||||
from models.chatflow_memory import ChatflowMemoryVariable
|
||||
from models.workflow import Workflow, WorkflowDraftVariable
|
||||
from services.chatflow_history_service import ChatflowHistoryService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatflowMemoryService:
|
||||
@staticmethod
|
||||
def get_persistent_memories(
|
||||
app: App,
|
||||
created_by: MemoryCreatedBy,
|
||||
version: int | None = None
|
||||
) -> Sequence[MemoryBlock]:
|
||||
if created_by.account_id:
|
||||
created_by_role = CreatorUserRole.ACCOUNT
|
||||
created_by_id = created_by.account_id
|
||||
else:
|
||||
created_by_role = CreatorUserRole.END_USER
|
||||
created_by_id = created_by.id
|
||||
if version is None:
|
||||
# If version not specified, get the latest version
|
||||
stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where(
|
||||
and_(
|
||||
ChatflowMemoryVariable.tenant_id == app.tenant_id,
|
||||
ChatflowMemoryVariable.app_id == app.id,
|
||||
ChatflowMemoryVariable.conversation_id == None,
|
||||
ChatflowMemoryVariable.created_by_role == created_by_role,
|
||||
ChatflowMemoryVariable.created_by == created_by_id,
|
||||
)
|
||||
).order_by(ChatflowMemoryVariable.version.desc())
|
||||
else:
|
||||
stmt = select(ChatflowMemoryVariable).where(
|
||||
and_(
|
||||
ChatflowMemoryVariable.tenant_id == app.tenant_id,
|
||||
ChatflowMemoryVariable.app_id == app.id,
|
||||
ChatflowMemoryVariable.conversation_id == None,
|
||||
ChatflowMemoryVariable.created_by_role == created_by_role,
|
||||
ChatflowMemoryVariable.created_by == created_by_id,
|
||||
ChatflowMemoryVariable.version == version
|
||||
)
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
db_results = session.execute(stmt).all()
|
||||
return ChatflowMemoryService._convert_to_memory_blocks(app, created_by, [result[0] for result in db_results])
|
||||
|
||||
@staticmethod
|
||||
def get_session_memories(
|
||||
app: App,
|
||||
created_by: MemoryCreatedBy,
|
||||
conversation_id: str,
|
||||
version: int | None = None
|
||||
) -> Sequence[MemoryBlock]:
|
||||
if version is None:
|
||||
# If version not specified, get the latest version
|
||||
stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where(
|
||||
and_(
|
||||
ChatflowMemoryVariable.tenant_id == app.tenant_id,
|
||||
ChatflowMemoryVariable.app_id == app.id,
|
||||
ChatflowMemoryVariable.conversation_id == conversation_id
|
||||
)
|
||||
).order_by(ChatflowMemoryVariable.version.desc())
|
||||
else:
|
||||
stmt = select(ChatflowMemoryVariable).where(
|
||||
and_(
|
||||
ChatflowMemoryVariable.tenant_id == app.tenant_id,
|
||||
ChatflowMemoryVariable.app_id == app.id,
|
||||
ChatflowMemoryVariable.conversation_id == conversation_id,
|
||||
ChatflowMemoryVariable.version == version
|
||||
)
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
db_results = session.execute(stmt).all()
|
||||
return ChatflowMemoryService._convert_to_memory_blocks(app, created_by, [result[0] for result in db_results])
|
||||
|
||||
@staticmethod
|
||||
def save_memory(memory: MemoryBlock, variable_pool: VariablePool, is_draft: bool) -> None:
|
||||
key = f"{memory.node_id}.{memory.spec.id}" if memory.node_id else memory.spec.id
|
||||
variable_pool.add([MEMORY_BLOCK_VARIABLE_NODE_ID, key], memory.value)
|
||||
if memory.created_by.account_id:
|
||||
created_by_role = CreatorUserRole.ACCOUNT
|
||||
created_by = memory.created_by.account_id
|
||||
else:
|
||||
created_by_role = CreatorUserRole.END_USER
|
||||
created_by = memory.created_by.id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
session.add(
|
||||
ChatflowMemoryVariable(
|
||||
memory_id=memory.spec.id,
|
||||
tenant_id=memory.tenant_id,
|
||||
app_id=memory.app_id,
|
||||
node_id=memory.node_id,
|
||||
conversation_id=memory.conversation_id,
|
||||
name=memory.spec.name,
|
||||
value=MemoryValueData(
|
||||
value=memory.value,
|
||||
edited_by_user=memory.edited_by_user
|
||||
).model_dump_json(),
|
||||
term=memory.spec.term,
|
||||
scope=memory.spec.scope,
|
||||
version=memory.version, # Use version from MemoryBlock directly
|
||||
created_by_role=created_by_role,
|
||||
created_by=created_by,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if is_draft:
|
||||
with Session(bind=db.engine) as session:
|
||||
draft_var_service = WorkflowDraftVariableService(session)
|
||||
memory_selector = memory.spec.id if not memory.node_id else f"{memory.node_id}.{memory.spec.id}"
|
||||
existing_vars = draft_var_service.get_draft_variables_by_selectors(
|
||||
app_id=memory.app_id,
|
||||
selectors=[[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_selector]]
|
||||
)
|
||||
if existing_vars:
|
||||
draft_var = existing_vars[0]
|
||||
draft_var.value = VersionedMemoryValue.model_validate_json(draft_var.value)\
|
||||
.add_version(memory.value)\
|
||||
.model_dump_json()
|
||||
else:
|
||||
draft_var = WorkflowDraftVariable.new_memory_block_variable(
|
||||
app_id=memory.app_id,
|
||||
memory_id=memory.spec.id,
|
||||
name=memory.spec.name,
|
||||
value=VersionedMemoryValue().add_version(memory.value),
|
||||
description=memory.spec.description
|
||||
)
|
||||
session.add(draft_var)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_memories_by_specs(
|
||||
memory_block_specs: Sequence[MemoryBlockSpec],
|
||||
tenant_id: str, app_id: str,
|
||||
created_by: MemoryCreatedBy,
|
||||
conversation_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
is_draft: bool
|
||||
) -> Sequence[MemoryBlock]:
|
||||
return [ChatflowMemoryService.get_memory_by_spec(
|
||||
spec, tenant_id, app_id, created_by, conversation_id, node_id, is_draft
|
||||
) for spec in memory_block_specs]
|
||||
|
||||
@staticmethod
|
||||
def get_memory_by_spec(
|
||||
spec: MemoryBlockSpec,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
created_by: MemoryCreatedBy,
|
||||
conversation_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
is_draft: bool
|
||||
) -> MemoryBlock:
|
||||
with Session(db.engine) as session:
|
||||
if is_draft:
|
||||
draft_var_service = WorkflowDraftVariableService(session)
|
||||
selector = [MEMORY_BLOCK_VARIABLE_NODE_ID, f"{spec.id}.{node_id}"] \
|
||||
if node_id else [MEMORY_BLOCK_VARIABLE_NODE_ID, spec.id]
|
||||
draft_vars = draft_var_service.get_draft_variables_by_selectors(
|
||||
app_id=app_id,
|
||||
selectors=[selector]
|
||||
)
|
||||
if draft_vars:
|
||||
draft_var = draft_vars[0]
|
||||
return MemoryBlock(
|
||||
value=draft_var.get_value().text,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
node_id=node_id,
|
||||
spec=spec,
|
||||
created_by=created_by,
|
||||
version=1,
|
||||
)
|
||||
stmt = select(ChatflowMemoryVariable).where(
|
||||
and_(
|
||||
ChatflowMemoryVariable.memory_id == spec.id,
|
||||
ChatflowMemoryVariable.tenant_id == tenant_id,
|
||||
ChatflowMemoryVariable.app_id == app_id,
|
||||
ChatflowMemoryVariable.node_id ==
|
||||
(node_id if spec.scope == MemoryScope.NODE else None),
|
||||
ChatflowMemoryVariable.conversation_id ==
|
||||
(conversation_id if spec.term == MemoryTerm.SESSION else None),
|
||||
)
|
||||
).order_by(ChatflowMemoryVariable.version.desc()).limit(1)
|
||||
result = session.execute(stmt).scalar()
|
||||
if result:
|
||||
memory_value_data = MemoryValueData.model_validate_json(result.value)
|
||||
return MemoryBlock(
|
||||
value=memory_value_data.value,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
node_id=node_id,
|
||||
spec=spec,
|
||||
edited_by_user=memory_value_data.edited_by_user,
|
||||
created_by=created_by,
|
||||
version=result.version,
|
||||
)
|
||||
return MemoryBlock(
|
||||
tenant_id=tenant_id,
|
||||
value=spec.template,
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
node_id=node_id,
|
||||
spec=spec,
|
||||
created_by=created_by,
|
||||
version=1,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_app_memory_if_needed(
|
||||
workflow: Workflow,
|
||||
conversation_id: str,
|
||||
variable_pool: VariablePool,
|
||||
created_by: MemoryCreatedBy,
|
||||
is_draft: bool
|
||||
):
|
||||
visible_messages = ChatflowHistoryService.get_visible_chat_history(
|
||||
conversation_id=conversation_id,
|
||||
app_id=workflow.app_id,
|
||||
tenant_id=workflow.tenant_id,
|
||||
node_id=None,
|
||||
)
|
||||
sync_blocks: list[MemoryBlock] = []
|
||||
async_blocks: list[MemoryBlock] = []
|
||||
for memory_spec in workflow.memory_blocks:
|
||||
if memory_spec.scope == MemoryScope.APP:
|
||||
memory = ChatflowMemoryService.get_memory_by_spec(
|
||||
spec=memory_spec,
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
conversation_id=conversation_id,
|
||||
node_id=None,
|
||||
is_draft=is_draft,
|
||||
created_by=created_by,
|
||||
)
|
||||
if ChatflowMemoryService._should_update_memory(memory, visible_messages):
|
||||
if memory.spec.schedule_mode == MemoryScheduleMode.SYNC:
|
||||
sync_blocks.append(memory)
|
||||
else:
|
||||
async_blocks.append(memory)
|
||||
|
||||
if not sync_blocks and not async_blocks:
|
||||
return
|
||||
|
||||
# async mode: submit individual async tasks directly
|
||||
for memory_block in async_blocks:
|
||||
ChatflowMemoryService._app_submit_async_memory_update(
|
||||
block=memory_block,
|
||||
is_draft=is_draft,
|
||||
variable_pool=variable_pool,
|
||||
visible_messages=visible_messages,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
# sync mode: submit a batch update task
|
||||
if sync_blocks:
|
||||
ChatflowMemoryService._app_submit_sync_memory_batch_update(
|
||||
sync_blocks=sync_blocks,
|
||||
is_draft=is_draft,
|
||||
conversation_id=conversation_id,
|
||||
app_id=workflow.app_id,
|
||||
visible_messages=visible_messages,
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_node_memory_if_needed(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
created_by: MemoryCreatedBy,
|
||||
conversation_id: str,
|
||||
memory_block_spec: MemoryBlockSpec,
|
||||
variable_pool: VariablePool,
|
||||
is_draft: bool
|
||||
) -> bool:
|
||||
visible_messages = ChatflowHistoryService.get_visible_chat_history(
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
node_id=node_id,
|
||||
)
|
||||
memory_block = ChatflowMemoryService.get_memory_by_spec(
|
||||
spec=memory_block_spec,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
node_id=node_id,
|
||||
is_draft=is_draft,
|
||||
created_by=created_by,
|
||||
)
|
||||
if not ChatflowMemoryService._should_update_memory(
|
||||
memory_block=memory_block,
|
||||
visible_history=visible_messages
|
||||
):
|
||||
return False
|
||||
|
||||
if memory_block_spec.schedule_mode == MemoryScheduleMode.SYNC:
|
||||
# Node-level sync: blocking execution
|
||||
ChatflowMemoryService._update_node_memory_sync(
|
||||
visible_messages=visible_messages,
|
||||
memory_block=memory_block,
|
||||
variable_pool=variable_pool,
|
||||
is_draft=is_draft,
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
else:
|
||||
# Node-level async: execute asynchronously
|
||||
ChatflowMemoryService._update_node_memory_async(
|
||||
memory_block=memory_block,
|
||||
visible_messages=visible_messages,
|
||||
variable_pool=variable_pool,
|
||||
is_draft=is_draft,
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def wait_for_sync_memory_completion(workflow: Workflow, conversation_id: str):
|
||||
"""Wait for sync memory update to complete, maximum 50 seconds"""
|
||||
|
||||
memory_blocks = workflow.memory_blocks
|
||||
sync_memory_blocks = [
|
||||
block for block in memory_blocks
|
||||
if block.scope == MemoryScope.APP and block.schedule_mode == MemoryScheduleMode.SYNC
|
||||
]
|
||||
|
||||
if not sync_memory_blocks:
|
||||
return
|
||||
|
||||
lock_key = _get_memory_sync_lock_key(workflow.app_id, conversation_id)
|
||||
|
||||
# Retry up to 10 times, wait 5 seconds each time, total 50 seconds
|
||||
max_retries = 10
|
||||
retry_interval = 5
|
||||
|
||||
for i in range(max_retries):
|
||||
if not redis_client.exists(lock_key):
|
||||
# Lock doesn't exist, can continue
|
||||
return
|
||||
|
||||
if i < max_retries - 1:
|
||||
# Still have retry attempts, wait
|
||||
time.sleep(retry_interval)
|
||||
else:
|
||||
# Maximum retry attempts reached, raise exception
|
||||
raise MemorySyncTimeoutError(
|
||||
app_id=workflow.app_id,
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_memory_blocks(
|
||||
app: App,
|
||||
created_by: MemoryCreatedBy,
|
||||
raw_results: Sequence[ChatflowMemoryVariable]
|
||||
) -> Sequence[MemoryBlock]:
|
||||
workflow = WorkflowService().get_published_workflow(app)
|
||||
if not workflow:
|
||||
return []
|
||||
results = []
|
||||
for chatflow_memory_variable in raw_results:
|
||||
spec = next(
|
||||
(spec for spec in workflow.memory_blocks if spec.id == chatflow_memory_variable.memory_id),
|
||||
None
|
||||
)
|
||||
if spec and chatflow_memory_variable.app_id:
|
||||
memory_value_data = MemoryValueData.model_validate_json(chatflow_memory_variable.value)
|
||||
results.append(
|
||||
MemoryBlock(
|
||||
spec=spec,
|
||||
tenant_id=chatflow_memory_variable.tenant_id,
|
||||
value=memory_value_data.value,
|
||||
app_id=chatflow_memory_variable.app_id,
|
||||
conversation_id=chatflow_memory_variable.conversation_id,
|
||||
node_id=chatflow_memory_variable.node_id,
|
||||
edited_by_user=memory_value_data.edited_by_user,
|
||||
created_by=created_by,
|
||||
version=chatflow_memory_variable.version,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _should_update_memory(
|
||||
memory_block: MemoryBlock,
|
||||
visible_history: Sequence[PromptMessage]
|
||||
) -> bool:
|
||||
return len(visible_history) >= memory_block.spec.update_turns
|
||||
|
||||
@staticmethod
|
||||
def _app_submit_async_memory_update(
|
||||
block: MemoryBlock,
|
||||
visible_messages: Sequence[PromptMessage],
|
||||
variable_pool: VariablePool,
|
||||
conversation_id: str,
|
||||
is_draft: bool
|
||||
):
|
||||
thread = threading.Thread(
|
||||
target=ChatflowMemoryService._perform_memory_update,
|
||||
kwargs={
|
||||
'memory_block': block,
|
||||
'visible_messages': visible_messages,
|
||||
'variable_pool': variable_pool,
|
||||
'is_draft': is_draft,
|
||||
'conversation_id': conversation_id
|
||||
},
|
||||
)
|
||||
thread.start()
|
||||
|
||||
@staticmethod
|
||||
def _app_submit_sync_memory_batch_update(
|
||||
sync_blocks: Sequence[MemoryBlock],
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
visible_messages: Sequence[PromptMessage],
|
||||
variable_pool: VariablePool,
|
||||
is_draft: bool
|
||||
):
|
||||
"""Submit sync memory batch update task"""
|
||||
thread = threading.Thread(
|
||||
target=ChatflowMemoryService._batch_update_sync_memory,
|
||||
kwargs={
|
||||
'sync_blocks': sync_blocks,
|
||||
'app_id': app_id,
|
||||
'conversation_id': conversation_id,
|
||||
'visible_messages': visible_messages,
|
||||
'variable_pool': variable_pool,
|
||||
'is_draft': is_draft
|
||||
},
|
||||
)
|
||||
thread.start()
|
||||
|
||||
@staticmethod
|
||||
def _batch_update_sync_memory(
|
||||
sync_blocks: Sequence[MemoryBlock],
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
visible_messages: Sequence[PromptMessage],
|
||||
variable_pool: VariablePool,
|
||||
is_draft: bool
|
||||
):
|
||||
try:
|
||||
lock_key = _get_memory_sync_lock_key(app_id, conversation_id)
|
||||
with redis_client.lock(lock_key, timeout=120):
|
||||
threads = []
|
||||
for block in sync_blocks:
|
||||
thread = threading.Thread(
|
||||
target=ChatflowMemoryService._perform_memory_update,
|
||||
kwargs={
|
||||
'memory_block': block,
|
||||
'visible_messages': visible_messages,
|
||||
'variable_pool': variable_pool,
|
||||
'is_draft': is_draft,
|
||||
'conversation_id': conversation_id,
|
||||
},
|
||||
)
|
||||
threads.append(thread)
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
except Exception as e:
|
||||
logger.exception("Error batch updating memory", exc_info=e)
|
||||
|
||||
@staticmethod
|
||||
def _update_node_memory_sync(
|
||||
memory_block: MemoryBlock,
|
||||
visible_messages: Sequence[PromptMessage],
|
||||
variable_pool: VariablePool,
|
||||
conversation_id: str,
|
||||
is_draft: bool
|
||||
):
|
||||
ChatflowMemoryService._perform_memory_update(
|
||||
memory_block=memory_block,
|
||||
visible_messages=visible_messages,
|
||||
variable_pool=variable_pool,
|
||||
is_draft=is_draft,
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _update_node_memory_async(
|
||||
memory_block: MemoryBlock,
|
||||
visible_messages: Sequence[PromptMessage],
|
||||
variable_pool: VariablePool,
|
||||
conversation_id: str,
|
||||
is_draft: bool = False
|
||||
):
|
||||
thread = threading.Thread(
|
||||
target=ChatflowMemoryService._perform_memory_update,
|
||||
kwargs={
|
||||
'memory_block': memory_block,
|
||||
'visible_messages': visible_messages,
|
||||
'variable_pool': variable_pool,
|
||||
'is_draft': is_draft,
|
||||
'conversation_id': conversation_id,
|
||||
},
|
||||
daemon=True
|
||||
)
|
||||
thread.start()
|
||||
|
||||
@staticmethod
|
||||
def _perform_memory_update(
|
||||
memory_block: MemoryBlock,
|
||||
variable_pool: VariablePool,
|
||||
conversation_id: str,
|
||||
visible_messages: Sequence[PromptMessage],
|
||||
is_draft: bool
|
||||
):
|
||||
updated_value = LLMGenerator.update_memory_block(
|
||||
tenant_id=memory_block.tenant_id,
|
||||
visible_history=ChatflowMemoryService._format_chat_history(visible_messages),
|
||||
variable_pool=variable_pool,
|
||||
memory_block=memory_block,
|
||||
memory_spec=memory_block.spec,
|
||||
)
|
||||
updated_memory = MemoryBlock(
|
||||
tenant_id=memory_block.tenant_id,
|
||||
value=updated_value,
|
||||
spec=memory_block.spec,
|
||||
app_id=memory_block.app_id,
|
||||
conversation_id=conversation_id,
|
||||
node_id=memory_block.node_id,
|
||||
edited_by_user=False,
|
||||
created_by=memory_block.created_by,
|
||||
version=memory_block.version + 1, # Increment version for business logic update
|
||||
)
|
||||
ChatflowMemoryService.save_memory(updated_memory, variable_pool, is_draft)
|
||||
|
||||
ChatflowHistoryService.update_visible_count(
|
||||
conversation_id=conversation_id,
|
||||
node_id=memory_block.node_id,
|
||||
new_visible_count=memory_block.spec.preserved_turns,
|
||||
app_id=memory_block.app_id,
|
||||
tenant_id=memory_block.tenant_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete_memory(app: App, memory_id: str, created_by: MemoryCreatedBy):
|
||||
workflow = WorkflowService().get_published_workflow(app)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
memory_spec = next((it for it in workflow.memory_blocks if it.id == memory_id), None)
|
||||
if not memory_spec or not memory_spec.end_user_editable:
|
||||
raise ValueError("Memory not found or not deletable")
|
||||
|
||||
if created_by.account_id:
|
||||
created_by_role = CreatorUserRole.ACCOUNT
|
||||
created_by_id = created_by.account_id
|
||||
else:
|
||||
created_by_role = CreatorUserRole.END_USER
|
||||
created_by_id = created_by.id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = delete(ChatflowMemoryVariable).where(
|
||||
and_(
|
||||
ChatflowMemoryVariable.tenant_id == app.tenant_id,
|
||||
ChatflowMemoryVariable.app_id == app.id,
|
||||
ChatflowMemoryVariable.memory_id == memory_id,
|
||||
ChatflowMemoryVariable.created_by_role == created_by_role,
|
||||
ChatflowMemoryVariable.created_by == created_by_id
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def delete_all_user_memories(app: App, created_by: MemoryCreatedBy):
|
||||
if created_by.account_id:
|
||||
created_by_role = CreatorUserRole.ACCOUNT
|
||||
created_by_id = created_by.account_id
|
||||
else:
|
||||
created_by_role = CreatorUserRole.END_USER
|
||||
created_by_id = created_by.id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = delete(ChatflowMemoryVariable).where(
|
||||
and_(
|
||||
ChatflowMemoryVariable.tenant_id == app.tenant_id,
|
||||
ChatflowMemoryVariable.app_id == app.id,
|
||||
ChatflowMemoryVariable.created_by_role == created_by_role,
|
||||
ChatflowMemoryVariable.created_by == created_by_id
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_persistent_memories_with_conversation(
|
||||
app: App,
|
||||
created_by: MemoryCreatedBy,
|
||||
conversation_id: str,
|
||||
version: int | None = None
|
||||
) -> Sequence[MemoryBlockWithConversation]:
|
||||
"""Get persistent memories with conversation metadata (always None for persistent)"""
|
||||
memory_blocks = ChatflowMemoryService.get_persistent_memories(app, created_by, version)
|
||||
return [
|
||||
MemoryBlockWithConversation.from_memory_block(
|
||||
block,
|
||||
ChatflowHistoryService.get_conversation_metadata(
|
||||
app.tenant_id, app.id, conversation_id, block.node_id
|
||||
)
|
||||
)
|
||||
for block in memory_blocks
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_session_memories_with_conversation(
|
||||
app: App,
|
||||
created_by: MemoryCreatedBy,
|
||||
conversation_id: str,
|
||||
version: int | None = None
|
||||
) -> Sequence[MemoryBlockWithConversation]:
|
||||
"""Get session memories with conversation metadata"""
|
||||
memory_blocks = ChatflowMemoryService.get_session_memories(app, created_by, conversation_id, version)
|
||||
return [
|
||||
MemoryBlockWithConversation.from_memory_block(
|
||||
block,
|
||||
ChatflowHistoryService.get_conversation_metadata(
|
||||
app.tenant_id, app.id, conversation_id, block.node_id
|
||||
)
|
||||
)
|
||||
for block in memory_blocks
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _format_chat_history(messages: Sequence[PromptMessage]) -> Sequence[tuple[str, str]]:
|
||||
result = []
|
||||
for message in messages:
|
||||
result.append((str(message.role.value), message.get_text_content()))
|
||||
return result
|
||||
|
||||
|
||||
def _get_memory_sync_lock_key(app_id: str, conversation_id: str) -> str:
|
||||
"""Generate Redis lock key for memory sync updates
|
||||
|
||||
Args:
|
||||
app_id: Application ID
|
||||
conversation_id: Conversation ID
|
||||
|
||||
Returns:
|
||||
Formatted lock key
|
||||
"""
|
||||
return f"memory_sync_update:{app_id}:{conversation_id}"
|
||||
68
api/services/credit_pool_service.py
Normal file
68
api/services/credit_pool_service.py
Normal file
@ -0,0 +1,68 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.errors.error import QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from models import TenantCreditPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreditPoolService:
|
||||
@classmethod
|
||||
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
|
||||
"""create default credit pool for new tenant"""
|
||||
credit_pool = TenantCreditPool(
|
||||
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
|
||||
)
|
||||
db.session.add(credit_pool)
|
||||
db.session.commit()
|
||||
return credit_pool
|
||||
|
||||
@classmethod
|
||||
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> Optional[TenantCreditPool]:
|
||||
"""get tenant credit pool"""
|
||||
return (
|
||||
db.session.query(TenantCreditPool)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=pool_type,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def check_and_deduct_credits(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
credits_required: int,
|
||||
pool_type: str = "trial",
|
||||
):
|
||||
"""check and deduct credits"""
|
||||
|
||||
pool = cls.get_pool(tenant_id, pool_type)
|
||||
if not pool:
|
||||
raise QuotaExceededError("Credit pool not found")
|
||||
|
||||
if pool.remaining_credits < credits_required:
|
||||
raise QuotaExceededError(
|
||||
f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}"
|
||||
)
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
update_values = {"quota_used": pool.quota_used + credits_required}
|
||||
|
||||
where_conditions = [
|
||||
TenantCreditPool.pool_type == pool_type,
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
|
||||
]
|
||||
stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
except Exception:
|
||||
raise QuotaExceededError("Failed to deduct credits")
|
||||
@ -160,6 +160,8 @@ class SystemFeatureModel(BaseModel):
|
||||
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
|
||||
enable_change_email: bool = True
|
||||
plugin_manager: PluginManagerModel = PluginManagerModel()
|
||||
enable_trial_app: bool = False
|
||||
enable_explore_banner: bool = False
|
||||
|
||||
|
||||
class FeatureService:
|
||||
@ -214,6 +216,8 @@ class FeatureService:
|
||||
system_features.is_allow_register = dify_config.ALLOW_REGISTER
|
||||
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
|
||||
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
|
||||
system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
|
||||
system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
|
||||
|
||||
@classmethod
|
||||
def _fulfill_params_from_env(cls, features: FeatureModel):
|
||||
|
||||
@ -1,4 +1,9 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.model import AccountTrialAppRecord, TrialApp
|
||||
from services.feature_service import FeatureService
|
||||
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
|
||||
|
||||
|
||||
@ -20,6 +25,15 @@ class RecommendedAppService:
|
||||
)
|
||||
)
|
||||
|
||||
if FeatureService.get_system_features().enable_trial_app:
|
||||
apps = result["recommended_apps"]
|
||||
for app in apps:
|
||||
app_id = app["app_id"]
|
||||
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
|
||||
if trial_app_model:
|
||||
app["can_trial"] = True
|
||||
else:
|
||||
app["can_trial"] = False
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@ -32,4 +46,27 @@ class RecommendedAppService:
|
||||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
|
||||
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
|
||||
if FeatureService.get_system_features().enable_trial_app:
|
||||
app_id = result["id"]
|
||||
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
|
||||
if trial_app_model:
|
||||
result["can_trial"] = True
|
||||
else:
|
||||
result["can_trial"] = False
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def add_trial_app_record(cls, app_id: str, account_id: str):
|
||||
"""
|
||||
Add trial app record.
|
||||
:param app_id: app id
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
account_trial_app_record = session.query(AccountTrialAppRecord).where(TrialApp.app_id == app_id).first()
|
||||
if account_trial_app_record:
|
||||
account_trial_app_record.count += 1
|
||||
session.commit()
|
||||
else:
|
||||
session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
|
||||
session.commit()
|
||||
|
||||
311
api/services/workflow_comment_service.py
Normal file
311
api/services/workflow_comment_service.py
Normal file
@ -0,0 +1,311 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import uuid_value
|
||||
from models import WorkflowComment, WorkflowCommentMention, WorkflowCommentReply
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowCommentService:
|
||||
"""Service for managing workflow comments."""
|
||||
|
||||
@staticmethod
|
||||
def _validate_content(content: str) -> None:
|
||||
if len(content.strip()) == 0:
|
||||
raise ValueError("Comment content cannot be empty")
|
||||
|
||||
if len(content) > 1000:
|
||||
raise ValueError("Comment content cannot exceed 1000 characters")
|
||||
|
||||
@staticmethod
|
||||
def get_comments(tenant_id: str, app_id: str) -> list[WorkflowComment]:
|
||||
"""Get all comments for a workflow."""
|
||||
with Session(db.engine) as session:
|
||||
# Get all comments with eager loading
|
||||
stmt = (
|
||||
select(WorkflowComment)
|
||||
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
|
||||
.where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id)
|
||||
.order_by(desc(WorkflowComment.created_at))
|
||||
)
|
||||
|
||||
comments = session.scalars(stmt).all()
|
||||
return comments
|
||||
|
||||
@staticmethod
|
||||
def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session = None) -> WorkflowComment:
|
||||
"""Get a specific comment."""
|
||||
|
||||
def _get_comment(session: Session) -> WorkflowComment:
|
||||
stmt = (
|
||||
select(WorkflowComment)
|
||||
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
|
||||
.where(
|
||||
WorkflowComment.id == comment_id,
|
||||
WorkflowComment.tenant_id == tenant_id,
|
||||
WorkflowComment.app_id == app_id,
|
||||
)
|
||||
)
|
||||
comment = session.scalar(stmt)
|
||||
|
||||
if not comment:
|
||||
raise NotFound("Comment not found")
|
||||
|
||||
return comment
|
||||
|
||||
if session is not None:
|
||||
return _get_comment(session)
|
||||
else:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
return _get_comment(session)
|
||||
|
||||
@staticmethod
|
||||
def create_comment(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
created_by: str,
|
||||
content: str,
|
||||
position_x: float,
|
||||
position_y: float,
|
||||
mentioned_user_ids: Optional[list[str]] = None,
|
||||
) -> WorkflowComment:
|
||||
"""Create a new workflow comment."""
|
||||
WorkflowCommentService._validate_content(content)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
comment = WorkflowComment(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
position_x=position_x,
|
||||
position_y=position_y,
|
||||
content=content,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
session.add(comment)
|
||||
session.flush() # Get the comment ID for mentions
|
||||
|
||||
# Create mentions if specified
|
||||
mentioned_user_ids = mentioned_user_ids or []
|
||||
for user_id in mentioned_user_ids:
|
||||
if isinstance(user_id, str) and uuid_value(user_id):
|
||||
mention = WorkflowCommentMention(
|
||||
comment_id=comment.id,
|
||||
reply_id=None, # This is a comment mention, not reply mention
|
||||
mentioned_user_id=user_id,
|
||||
)
|
||||
session.add(mention)
|
||||
|
||||
session.commit()
|
||||
|
||||
# Return only what we need - id and created_at
|
||||
return {"id": comment.id, "created_at": comment.created_at}
|
||||
|
||||
@staticmethod
|
||||
def update_comment(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
comment_id: str,
|
||||
user_id: str,
|
||||
content: str,
|
||||
position_x: Optional[float] = None,
|
||||
position_y: Optional[float] = None,
|
||||
mentioned_user_ids: Optional[list[str]] = None,
|
||||
) -> dict:
|
||||
"""Update a workflow comment."""
|
||||
WorkflowCommentService._validate_content(content)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get comment with validation
|
||||
stmt = select(WorkflowComment).where(
|
||||
WorkflowComment.id == comment_id,
|
||||
WorkflowComment.tenant_id == tenant_id,
|
||||
WorkflowComment.app_id == app_id,
|
||||
)
|
||||
comment = session.scalar(stmt)
|
||||
|
||||
if not comment:
|
||||
raise NotFound("Comment not found")
|
||||
|
||||
# Only the creator can update the comment
|
||||
if comment.created_by != user_id:
|
||||
raise Forbidden("Only the comment creator can update it")
|
||||
|
||||
# Update comment fields
|
||||
comment.content = content
|
||||
if position_x is not None:
|
||||
comment.position_x = position_x
|
||||
if position_y is not None:
|
||||
comment.position_y = position_y
|
||||
|
||||
# Update mentions - first remove existing mentions for this comment only (not replies)
|
||||
existing_mentions = session.scalars(
|
||||
select(WorkflowCommentMention).where(
|
||||
WorkflowCommentMention.comment_id == comment.id,
|
||||
WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions
|
||||
)
|
||||
).all()
|
||||
for mention in existing_mentions:
|
||||
session.delete(mention)
|
||||
|
||||
# Add new mentions
|
||||
mentioned_user_ids = mentioned_user_ids or []
|
||||
for user_id_str in mentioned_user_ids:
|
||||
if isinstance(user_id_str, str) and uuid_value(user_id_str):
|
||||
mention = WorkflowCommentMention(
|
||||
comment_id=comment.id,
|
||||
reply_id=None, # This is a comment mention
|
||||
mentioned_user_id=user_id_str,
|
||||
)
|
||||
session.add(mention)
|
||||
|
||||
session.commit()
|
||||
|
||||
return {"id": comment.id, "updated_at": comment.updated_at}
|
||||
|
||||
@staticmethod
|
||||
def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None:
|
||||
"""Delete a workflow comment."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
|
||||
|
||||
# Only the creator can delete the comment
|
||||
if comment.created_by != user_id:
|
||||
raise Forbidden("Only the comment creator can delete it")
|
||||
|
||||
# Delete associated mentions (both comment and reply mentions)
|
||||
mentions = session.scalars(
|
||||
select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id)
|
||||
).all()
|
||||
for mention in mentions:
|
||||
session.delete(mention)
|
||||
|
||||
# Delete associated replies
|
||||
replies = session.scalars(
|
||||
select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id)
|
||||
).all()
|
||||
for reply in replies:
|
||||
session.delete(reply)
|
||||
|
||||
session.delete(comment)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment:
|
||||
"""Resolve a workflow comment."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
|
||||
if comment.resolved:
|
||||
return comment
|
||||
|
||||
comment.resolved = True
|
||||
comment.resolved_at = naive_utc_now()
|
||||
comment.resolved_by = user_id
|
||||
session.commit()
|
||||
|
||||
return comment
|
||||
|
||||
@staticmethod
|
||||
def create_reply(
|
||||
comment_id: str, content: str, created_by: str, mentioned_user_ids: Optional[list[str]] = None
|
||||
) -> dict:
|
||||
"""Add a reply to a workflow comment."""
|
||||
WorkflowCommentService._validate_content(content)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Check if comment exists
|
||||
comment = session.get(WorkflowComment, comment_id)
|
||||
if not comment:
|
||||
raise NotFound("Comment not found")
|
||||
|
||||
reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by)
|
||||
|
||||
session.add(reply)
|
||||
session.flush() # Get the reply ID for mentions
|
||||
|
||||
# Create mentions if specified
|
||||
mentioned_user_ids = mentioned_user_ids or []
|
||||
for user_id in mentioned_user_ids:
|
||||
if isinstance(user_id, str) and uuid_value(user_id):
|
||||
# Create mention linking to specific reply
|
||||
mention = WorkflowCommentMention(
|
||||
comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id
|
||||
)
|
||||
session.add(mention)
|
||||
|
||||
session.commit()
|
||||
|
||||
return {"id": reply.id, "created_at": reply.created_at}
|
||||
|
||||
@staticmethod
|
||||
def update_reply(
|
||||
reply_id: str, user_id: str, content: str, mentioned_user_ids: Optional[list[str]] = None
|
||||
) -> WorkflowCommentReply:
|
||||
"""Update a comment reply."""
|
||||
WorkflowCommentService._validate_content(content)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
reply = session.get(WorkflowCommentReply, reply_id)
|
||||
if not reply:
|
||||
raise NotFound("Reply not found")
|
||||
|
||||
# Only the creator can update the reply
|
||||
if reply.created_by != user_id:
|
||||
raise Forbidden("Only the reply creator can update it")
|
||||
|
||||
reply.content = content
|
||||
|
||||
# Update mentions - first remove existing mentions for this reply
|
||||
existing_mentions = session.scalars(
|
||||
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id)
|
||||
).all()
|
||||
for mention in existing_mentions:
|
||||
session.delete(mention)
|
||||
|
||||
# Add mentions
|
||||
mentioned_user_ids = mentioned_user_ids or []
|
||||
for user_id_str in mentioned_user_ids:
|
||||
if isinstance(user_id_str, str) and uuid_value(user_id_str):
|
||||
mention = WorkflowCommentMention(
|
||||
comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str
|
||||
)
|
||||
session.add(mention)
|
||||
|
||||
session.commit()
|
||||
session.refresh(reply) # Refresh to get updated timestamp
|
||||
|
||||
return {"id": reply.id, "updated_at": reply.updated_at}
|
||||
|
||||
@staticmethod
|
||||
def delete_reply(reply_id: str, user_id: str) -> None:
|
||||
"""Delete a comment reply."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
reply = session.get(WorkflowCommentReply, reply_id)
|
||||
if not reply:
|
||||
raise NotFound("Reply not found")
|
||||
|
||||
# Only the creator can delete the reply
|
||||
if reply.created_by != user_id:
|
||||
raise Forbidden("Only the reply creator can delete it")
|
||||
|
||||
# Delete associated mentions first
|
||||
mentions = session.scalars(
|
||||
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id)
|
||||
).all()
|
||||
for mention in mentions:
|
||||
session.delete(mention)
|
||||
|
||||
session.delete(reply)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment:
|
||||
"""Validate that a comment belongs to the specified tenant and app."""
|
||||
return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id)
|
||||
@ -11,6 +11,7 @@ from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.file import File
|
||||
from core.memory.entities import MemoryBlockSpec, MemoryCreatedBy, MemoryScope
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.variables import Variable
|
||||
from core.variables.variables import VariableUnion
|
||||
@ -196,15 +197,18 @@ class WorkflowService:
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
memory_blocks: Sequence[MemoryBlockSpec] | None = None,
|
||||
force_upload: bool = False,
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
:param force_upload: Skip hash validation when True (for restore operations)
|
||||
:raises WorkflowHashNotEqualError
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if workflow and workflow.unique_hash != unique_hash:
|
||||
if workflow and workflow.unique_hash != unique_hash and not force_upload:
|
||||
raise WorkflowHashNotEqualError()
|
||||
|
||||
# validate features structure
|
||||
@ -223,6 +227,7 @@ class WorkflowService:
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
workflow.memory_blocks = memory_blocks or []
|
||||
db.session.add(workflow)
|
||||
# update draft workflow if found
|
||||
else:
|
||||
@ -232,6 +237,7 @@ class WorkflowService:
|
||||
workflow.updated_at = naive_utc_now()
|
||||
workflow.environment_variables = environment_variables
|
||||
workflow.conversation_variables = conversation_variables
|
||||
workflow.memory_blocks = memory_blocks or []
|
||||
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
@ -242,6 +248,78 @@ class WorkflowService:
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def update_draft_workflow_environment_variables(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
environment_variables: Sequence[Variable],
|
||||
account: Account,
|
||||
):
|
||||
"""
|
||||
Update draft workflow environment variables
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError("No draft workflow found.")
|
||||
|
||||
workflow.environment_variables = environment_variables
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = naive_utc_now()
|
||||
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
|
||||
def update_draft_workflow_conversation_variables(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
conversation_variables: Sequence[Variable],
|
||||
account: Account,
|
||||
):
|
||||
"""
|
||||
Update draft workflow conversation variables
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError("No draft workflow found.")
|
||||
|
||||
workflow.conversation_variables = conversation_variables
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = naive_utc_now()
|
||||
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
|
||||
def update_draft_workflow_features(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
features: dict,
|
||||
account: Account,
|
||||
):
|
||||
"""
|
||||
Update draft workflow features
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError("No draft workflow found.")
|
||||
|
||||
# validate features structure
|
||||
self.validate_features_structure(app_model=app_model, features=features)
|
||||
|
||||
workflow.features = json.dumps(features)
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = naive_utc_now()
|
||||
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
|
||||
def publish_workflow(
|
||||
self,
|
||||
*,
|
||||
@ -279,6 +357,7 @@ class WorkflowService:
|
||||
marked_name=marked_name,
|
||||
marked_comment=marked_comment,
|
||||
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
|
||||
memory_blocks=draft_workflow.memory_blocks,
|
||||
features=draft_workflow.features,
|
||||
)
|
||||
|
||||
@ -635,17 +714,10 @@ class WorkflowService:
|
||||
tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
|
||||
)
|
||||
# init variable pool
|
||||
variable_pool = _setup_variable_pool(
|
||||
query=query,
|
||||
files=files or [],
|
||||
user_id=account.id,
|
||||
user_inputs=user_inputs,
|
||||
workflow=draft_workflow,
|
||||
# NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation variables.
|
||||
conversation_variables=[],
|
||||
node_type=node_type,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
variable_pool = _setup_variable_pool(query=query, files=files or [], user_id=account.id,
|
||||
user_inputs=user_inputs, workflow=draft_workflow,
|
||||
node_type=node_type, conversation_id=conversation_id,
|
||||
conversation_variables=[], is_draft=True)
|
||||
|
||||
else:
|
||||
variable_pool = VariablePool(
|
||||
@ -994,6 +1066,7 @@ def _setup_variable_pool(
|
||||
node_type: NodeType,
|
||||
conversation_id: str,
|
||||
conversation_variables: list[Variable],
|
||||
is_draft: bool
|
||||
):
|
||||
# Only inject system variables for START node type.
|
||||
if node_type == NodeType.START:
|
||||
@ -1012,7 +1085,6 @@ def _setup_variable_pool(
|
||||
system_variable.dialogue_count = 1
|
||||
else:
|
||||
system_variable = SystemVariable.empty()
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_variable,
|
||||
@ -1021,6 +1093,12 @@ def _setup_variable_pool(
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
conversation_variables=cast(list[VariableUnion], conversation_variables), #
|
||||
memory_blocks=_fetch_memory_blocks(
|
||||
workflow,
|
||||
MemoryCreatedBy(account_id=user_id),
|
||||
conversation_id,
|
||||
is_draft=is_draft
|
||||
),
|
||||
)
|
||||
|
||||
return variable_pool
|
||||
@ -1057,3 +1135,30 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia
|
||||
return build_from_mappings(mappings=value, tenant_id=tenant_id)
|
||||
else:
|
||||
raise Exception("unreachable")
|
||||
|
||||
|
||||
def _fetch_memory_blocks(
|
||||
workflow: Workflow,
|
||||
created_by: MemoryCreatedBy,
|
||||
conversation_id: str,
|
||||
is_draft: bool
|
||||
) -> Mapping[str, str]:
|
||||
memory_blocks = {}
|
||||
memory_block_specs = workflow.memory_blocks
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
memories = ChatflowMemoryService.get_memories_by_specs(
|
||||
memory_block_specs=memory_block_specs,
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
node_id=None,
|
||||
conversation_id=conversation_id,
|
||||
is_draft=is_draft,
|
||||
created_by=created_by,
|
||||
)
|
||||
for memory in memories:
|
||||
if memory.spec.scope == MemoryScope.APP:
|
||||
memory_blocks[memory.spec.id] = memory.value
|
||||
else: # NODE scope
|
||||
memory_blocks[f"{memory.node_id}.{memory.spec.id}"] = memory.value
|
||||
|
||||
return memory_blocks
|
||||
|
||||
@ -46,5 +46,17 @@ class WorkspaceService:
|
||||
"remove_webapp_brand": remove_webapp_brand,
|
||||
"replace_webapp_logo": replace_webapp_logo,
|
||||
}
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
|
||||
if paid_pool:
|
||||
tenant_info["trial_credits"] = paid_pool.quota_limit
|
||||
tenant_info["trial_credits_used"] = paid_pool.quota_used
|
||||
else:
|
||||
trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
|
||||
if trial_pool:
|
||||
tenant_info["trial_credits"] = trial_pool.quota_limit
|
||||
tenant_info["trial_credits_used"] = trial_pool.quota_used
|
||||
|
||||
return tenant_info
|
||||
|
||||
4803
api/uv.lock
generated
4803
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@ -14,6 +14,14 @@ server {
|
||||
include proxy.conf;
|
||||
}
|
||||
|
||||
location /socket.io/ {
|
||||
proxy_pass http://api:5001;
|
||||
include proxy.conf;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
proxy_cache_bypass $http_upgrade;
|
||||
}
|
||||
|
||||
location /v1 {
|
||||
proxy_pass http://api:5001;
|
||||
include proxy.conf;
|
||||
|
||||
@ -5,7 +5,7 @@ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Connection "";
|
||||
# proxy_set_header Connection "";
|
||||
proxy_buffering off;
|
||||
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT};
|
||||
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT};
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import React, { useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import AppCard from '@/app/components/app/overview/app-card'
|
||||
@ -19,6 +19,8 @@ import { asyncRunSafe } from '@/utils'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import type { IAppCardProps } from '@/app/components/app/overview/app-card'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager'
|
||||
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
|
||||
|
||||
export type ICardViewProps = {
|
||||
appId: string
|
||||
@ -47,15 +49,44 @@ const CardView: FC<ICardViewProps> = ({ appId, isInPanel, className }) => {
|
||||
|
||||
message ||= (type === 'success' ? 'modifiedSuccessfully' : 'modifiedUnsuccessfully')
|
||||
|
||||
if (type === 'success')
|
||||
if (type === 'success') {
|
||||
updateAppDetail()
|
||||
|
||||
// Emit collaboration event to notify other clients of app state changes
|
||||
const socket = webSocketClient.getSocket(appId)
|
||||
if (socket) {
|
||||
socket.emit('collaboration_event', {
|
||||
type: 'appStateUpdate',
|
||||
data: { timestamp: Date.now() },
|
||||
timestamp: Date.now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
notify({
|
||||
type,
|
||||
message: t(`common.actionMsg.${message}`),
|
||||
})
|
||||
}
|
||||
|
||||
// Listen for collaborative app state updates from other clients
|
||||
useEffect(() => {
|
||||
if (!appId) return
|
||||
|
||||
const unsubscribe = collaborationManager.onAppStateUpdate(async (update: any) => {
|
||||
try {
|
||||
console.log('Received app state update from collaboration:', update)
|
||||
// Update app detail when other clients modify app state
|
||||
await updateAppDetail()
|
||||
}
|
||||
catch (error) {
|
||||
console.error('app state update failed:', error)
|
||||
}
|
||||
})
|
||||
|
||||
return unsubscribe
|
||||
}, [appId])
|
||||
|
||||
const onChangeSiteStatus = async (value: boolean) => {
|
||||
const [err] = await asyncRunSafe<App>(
|
||||
updateAppSiteStatus({
|
||||
|
||||
@ -32,6 +32,8 @@ import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { formatTime } from '@/utils/time'
|
||||
import { useGetUserCanAccessApp } from '@/service/access-control'
|
||||
import dynamic from 'next/dynamic'
|
||||
import { UserAvatarList } from '@/app/components/base/user-avatar-list'
|
||||
import type { WorkflowOnlineUser } from '@/models/app'
|
||||
|
||||
const EditAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), {
|
||||
ssr: false,
|
||||
@ -55,9 +57,10 @@ const AccessControl = dynamic(() => import('@/app/components/app/app-access-cont
|
||||
export type AppCardProps = {
|
||||
app: App
|
||||
onRefresh?: () => void
|
||||
onlineUsers?: WorkflowOnlineUser[]
|
||||
}
|
||||
|
||||
const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
||||
const AppCard = ({ app, onRefresh, onlineUsers = [] }: AppCardProps) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
@ -333,6 +336,19 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
||||
return `${t('datasetDocuments.segment.editedAt')} ${timeText}`
|
||||
}, [app.updated_at, app.created_at])
|
||||
|
||||
const onlineUserAvatars = useMemo(() => {
|
||||
if (!onlineUsers.length)
|
||||
return []
|
||||
|
||||
return onlineUsers
|
||||
.map(user => ({
|
||||
id: user.user_id || user.sid || '',
|
||||
name: user.username || 'User',
|
||||
avatar_url: user.avatar || undefined,
|
||||
}))
|
||||
.filter(user => !!user.id)
|
||||
}, [onlineUsers])
|
||||
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
@ -377,6 +393,11 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
||||
<RiVerifiedBadgeLine className='h-4 w-4 text-text-quaternary' />
|
||||
</Tooltip>}
|
||||
</div>
|
||||
<div>
|
||||
{onlineUserAvatars.length > 0 && (
|
||||
<UserAvatarList users={onlineUserAvatars} maxVisible={3} size={20} />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className='title-wrapper h-[90px] px-[14px] text-xs leading-normal text-text-tertiary'>
|
||||
<div
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
'use client'
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import {
|
||||
useRouter,
|
||||
} from 'next/navigation'
|
||||
import useSWRInfinite from 'swr/infinite'
|
||||
import useSWR from 'swr'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useDebounceFn } from 'ahooks'
|
||||
import {
|
||||
@ -19,8 +20,8 @@ import AppCard from './app-card'
|
||||
import NewAppCard from './new-app-card'
|
||||
import useAppsQueryState from './hooks/use-apps-query-state'
|
||||
import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
|
||||
import type { AppListResponse } from '@/models/app'
|
||||
import { fetchAppList } from '@/service/apps'
|
||||
import type { AppListResponse, WorkflowOnlineUser } from '@/models/app'
|
||||
import { fetchAppList, fetchWorkflowOnlineUsers } from '@/service/apps'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import { CheckModal } from '@/hooks/use-pay'
|
||||
@ -112,6 +113,37 @@ const List = () => {
|
||||
},
|
||||
)
|
||||
|
||||
const apps = useMemo(() => data?.flatMap(page => page.data) ?? [], [data])
|
||||
|
||||
const workflowIds = useMemo(() => {
|
||||
const ids = new Set<string>()
|
||||
apps.forEach((appItem) => {
|
||||
const workflowId = appItem.id
|
||||
if (!workflowId)
|
||||
return
|
||||
|
||||
if (appItem.mode === 'workflow' || appItem.mode === 'advanced-chat')
|
||||
ids.add(workflowId)
|
||||
})
|
||||
return Array.from(ids)
|
||||
}, [apps])
|
||||
|
||||
const { data: onlineUsersByWorkflow, mutate: refreshOnlineUsers } = useSWR<Record<string, WorkflowOnlineUser[]>>(
|
||||
workflowIds.length ? { workflowIds } : null,
|
||||
fetchWorkflowOnlineUsers,
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
if (!workflowIds.length)
|
||||
return
|
||||
|
||||
const timer = window.setInterval(() => {
|
||||
refreshOnlineUsers()
|
||||
}, 10000)
|
||||
|
||||
return () => window.clearInterval(timer)
|
||||
}, [workflowIds.join(','), refreshOnlineUsers])
|
||||
|
||||
const anchorRef = useRef<HTMLDivElement>(null)
|
||||
const options = [
|
||||
{ value: 'all', text: t('app.types.all'), icon: <RiApps2Line className='mr-1 h-[14px] w-[14px]' /> },
|
||||
@ -213,7 +245,12 @@ const List = () => {
|
||||
{isCurrentWorkspaceEditor
|
||||
&& <NewAppCard ref={newAppCardRef} onSuccess={mutate} selectedAppType={activeTab} />}
|
||||
{data.map(({ data: apps }) => apps.map(app => (
|
||||
<AppCard key={app.id} app={app} onRefresh={mutate} />
|
||||
<AppCard
|
||||
key={app.id}
|
||||
app={app}
|
||||
onRefresh={mutate}
|
||||
onlineUsers={onlineUsersByWorkflow?.[app.id] ?? []}
|
||||
/>
|
||||
)))}
|
||||
</div>
|
||||
: <div className='relative grid grow grid-cols-1 content-start gap-4 overflow-hidden px-12 pt-2 sm:grid-cols-1 md:grid-cols-2 xl:grid-cols-4 2xl:grid-cols-5 2k:grid-cols-6'>
|
||||
|
||||
@ -9,6 +9,7 @@ export type AvatarProps = {
|
||||
className?: string
|
||||
textClassName?: string
|
||||
onError?: (x: boolean) => void
|
||||
backgroundColor?: string
|
||||
}
|
||||
const Avatar = ({
|
||||
name,
|
||||
@ -17,9 +18,18 @@ const Avatar = ({
|
||||
className,
|
||||
textClassName,
|
||||
onError,
|
||||
backgroundColor,
|
||||
}: AvatarProps) => {
|
||||
const avatarClassName = 'shrink-0 flex items-center rounded-full bg-primary-600'
|
||||
const style = { width: `${size}px`, height: `${size}px`, fontSize: `${size}px`, lineHeight: `${size}px` }
|
||||
const avatarClassName = backgroundColor
|
||||
? 'shrink-0 flex items-center rounded-full'
|
||||
: 'shrink-0 flex items-center rounded-full bg-primary-600'
|
||||
const style = {
|
||||
width: `${size}px`,
|
||||
height: `${size}px`,
|
||||
fontSize: `${size}px`,
|
||||
lineHeight: `${size}px`,
|
||||
...(backgroundColor && !avatar ? { backgroundColor } : {}),
|
||||
}
|
||||
const [imgError, setImgError] = useState(false)
|
||||
|
||||
const handleError = () => {
|
||||
@ -35,14 +45,18 @@ const Avatar = ({
|
||||
|
||||
if (avatar && !imgError) {
|
||||
return (
|
||||
<img
|
||||
<span
|
||||
className={cn(avatarClassName, className)}
|
||||
style={style}
|
||||
alt={name}
|
||||
src={avatar}
|
||||
onError={handleError}
|
||||
onLoad={() => onError?.(false)}
|
||||
/>
|
||||
>
|
||||
<img
|
||||
className='h-full w-full rounded-full object-cover'
|
||||
alt={name}
|
||||
src={avatar}
|
||||
onError={handleError}
|
||||
onLoad={() => onError?.(false)}
|
||||
/>
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="12" viewBox="0 0 14 12" fill="none">
|
||||
<path d="M12.3334 4C12.3334 2.52725 11.1395 1.33333 9.66671 1.33333H4.33337C2.86062 1.33333 1.66671 2.52724 1.66671 4V10.6667H9.66671C11.1395 10.6667 12.3334 9.47274 12.3334 8V4ZM7.66671 6.66667V8H4.33337V6.66667H7.66671ZM9.66671 4V5.33333H4.33337V4H9.66671ZM13.6667 8C13.6667 10.2091 11.8758 12 9.66671 12H0.333374V4C0.333374 1.79086 2.12424 0 4.33337 0H9.66671C11.8758 0 13.6667 1.79086 13.6667 4V8Z" fill="currentColor"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 527 B |
26
web/app/components/base/icons/src/public/other/Comment.json
Normal file
26
web/app/components/base/icons/src/public/other/Comment.json
Normal file
@ -0,0 +1,26 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"xmlns": "http://www.w3.org/2000/svg",
|
||||
"width": "14",
|
||||
"height": "12",
|
||||
"viewBox": "0 0 14 12",
|
||||
"fill": "none"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M12.3334 4C12.3334 2.52725 11.1395 1.33333 9.66671 1.33333H4.33337C2.86062 1.33333 1.66671 2.52724 1.66671 4V10.6667H9.66671C11.1395 10.6667 12.3334 9.47274 12.3334 8V4ZM7.66671 6.66667V8H4.33337V6.66667H7.66671ZM9.66671 4V5.33333H4.33337V4H9.66671ZM13.6667 8C13.6667 10.2091 11.8758 12 9.66671 12H0.333374V4C0.333374 1.79086 2.12424 0 4.33337 0H9.66671C11.8758 0 13.6667 1.79086 13.6667 4V8Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "Comment"
|
||||
}
|
||||
20
web/app/components/base/icons/src/public/other/Comment.tsx
Normal file
20
web/app/components/base/icons/src/public/other/Comment.tsx
Normal file
@ -0,0 +1,20 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import * as React from 'react'
|
||||
import data from './Comment.json'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>;
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'Comment'
|
||||
|
||||
export default Icon
|
||||
@ -1,4 +1,5 @@
|
||||
export { default as Icon3Dots } from './Icon3Dots'
|
||||
export { default as Comment } from './Comment'
|
||||
export { default as DefaultToolIcon } from './DefaultToolIcon'
|
||||
export { default as Message3Fill } from './Message3Fill'
|
||||
export { default as RowStruct } from './RowStruct'
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import type { FC } from 'react'
|
||||
import React, { useEffect } from 'react'
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import type {
|
||||
EditorState,
|
||||
} from 'lexical'
|
||||
@ -80,6 +81,29 @@ import {
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
const ValueSyncPlugin: FC<{ value?: string }> = ({ value }) => {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
|
||||
useEffect(() => {
|
||||
if (value === undefined)
|
||||
return
|
||||
|
||||
const incomingValue = value ?? ''
|
||||
const shouldUpdate = editor.getEditorState().read(() => {
|
||||
const currentText = $getRoot().getChildren().map(node => node.getTextContent()).join('\n')
|
||||
return currentText !== incomingValue
|
||||
})
|
||||
|
||||
if (!shouldUpdate)
|
||||
return
|
||||
|
||||
const editorState = editor.parseEditorState(textToEditorState(incomingValue))
|
||||
editor.setEditorState(editorState)
|
||||
}, [editor, value])
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
export type PromptEditorProps = {
|
||||
instanceId?: string
|
||||
compact?: boolean
|
||||
@ -293,6 +317,7 @@ const PromptEditor: FC<PromptEditorProps> = ({
|
||||
<VariableValueBlock />
|
||||
)
|
||||
}
|
||||
<ValueSyncPlugin value={value} />
|
||||
<OnChangePlugin onChange={handleEditorChange} />
|
||||
<OnBlurBlock onBlur={onBlur} onFocus={onFocus} />
|
||||
<UpdateBlock instanceId={instanceId} />
|
||||
|
||||
77
web/app/components/base/user-avatar-list/index.tsx
Normal file
77
web/app/components/base/user-avatar-list/index.tsx
Normal file
@ -0,0 +1,77 @@
|
||||
import type { FC } from 'react'
|
||||
import { memo } from 'react'
|
||||
import { getUserColor } from '@/app/components/workflow/collaboration/utils/user-color'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
|
||||
type User = {
|
||||
id: string
|
||||
name: string
|
||||
avatar_url?: string | null
|
||||
}
|
||||
|
||||
type UserAvatarListProps = {
|
||||
users: User[]
|
||||
maxVisible?: number
|
||||
size?: number
|
||||
className?: string
|
||||
showCount?: boolean
|
||||
}
|
||||
|
||||
export const UserAvatarList: FC<UserAvatarListProps> = memo(({
|
||||
users,
|
||||
maxVisible = 3,
|
||||
size = 24,
|
||||
className = '',
|
||||
showCount = true,
|
||||
}) => {
|
||||
const { userProfile } = useAppContext()
|
||||
if (!users.length) return null
|
||||
|
||||
const shouldShowCount = showCount && users.length > maxVisible
|
||||
const actualMaxVisible = shouldShowCount ? Math.max(1, maxVisible - 1) : maxVisible
|
||||
const visibleUsers = users.slice(0, actualMaxVisible)
|
||||
const remainingCount = users.length - actualMaxVisible
|
||||
|
||||
const currentUserId = userProfile?.id
|
||||
|
||||
return (
|
||||
<div className={`flex items-center -space-x-1 ${className}`}>
|
||||
{visibleUsers.map((user, index) => {
|
||||
const isCurrentUser = user.id === currentUserId
|
||||
const userColor = isCurrentUser ? undefined : getUserColor(user.id)
|
||||
return (
|
||||
<div
|
||||
key={`${user.id}-${index}`}
|
||||
className='relative'
|
||||
style={{ zIndex: visibleUsers.length - index }}
|
||||
>
|
||||
<Avatar
|
||||
name={user.name}
|
||||
avatar={user.avatar_url || null}
|
||||
size={size}
|
||||
className='ring-2 ring-components-panel-bg'
|
||||
backgroundColor={userColor}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
|
||||
)}
|
||||
{shouldShowCount && remainingCount > 0 && (
|
||||
<div
|
||||
className={'flex items-center justify-center rounded-full bg-gray-500 text-[10px] leading-none text-white ring-2 ring-components-panel-bg'}
|
||||
style={{
|
||||
zIndex: 0,
|
||||
width: size,
|
||||
height: size,
|
||||
}}
|
||||
>
|
||||
+{remainingCount}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
UserAvatarList.displayName = 'UserAvatarList'
|
||||
@ -26,6 +26,8 @@ import {
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import cn from '@/utils/classnames'
|
||||
import { fetchAppDetail } from '@/service/apps'
|
||||
import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager'
|
||||
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
|
||||
|
||||
export type IAppCardProps = {
|
||||
appInfo: AppDetailResponse & Partial<AppSSO>
|
||||
@ -90,6 +92,19 @@ function MCPServiceCard({
|
||||
const onGenCode = async () => {
|
||||
await refreshMCPServerCode(detail?.id || '')
|
||||
invalidateMCPServerDetail(appId)
|
||||
|
||||
// Emit collaboration event to notify other clients of MCP server changes
|
||||
const socket = webSocketClient.getSocket(appId)
|
||||
if (socket) {
|
||||
socket.emit('collaboration_event', {
|
||||
type: 'mcpServerUpdate',
|
||||
data: {
|
||||
action: 'codeRegenerated',
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
timestamp: Date.now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const onChangeStatus = async (state: boolean) => {
|
||||
@ -119,6 +134,20 @@ function MCPServiceCard({
|
||||
})
|
||||
invalidateMCPServerDetail(appId)
|
||||
}
|
||||
|
||||
// Emit collaboration event to notify other clients of MCP server status change
|
||||
const socket = webSocketClient.getSocket(appId)
|
||||
if (socket) {
|
||||
socket.emit('collaboration_event', {
|
||||
type: 'mcpServerUpdate',
|
||||
data: {
|
||||
action: 'statusChanged',
|
||||
status: state ? 'active' : 'inactive',
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
timestamp: Date.now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const handleServerModalHide = () => {
|
||||
@ -131,6 +160,23 @@ function MCPServiceCard({
|
||||
setActivated(serverActivated)
|
||||
}, [serverActivated])
|
||||
|
||||
// Listen for collaborative MCP server updates from other clients
|
||||
useEffect(() => {
|
||||
if (!appId) return
|
||||
|
||||
const unsubscribe = collaborationManager.onMcpServerUpdate(async (update: any) => {
|
||||
try {
|
||||
console.log('Received MCP server update from collaboration:', update)
|
||||
invalidateMCPServerDetail(appId)
|
||||
}
|
||||
catch (error) {
|
||||
console.error('MCP server update failed:', error)
|
||||
}
|
||||
})
|
||||
|
||||
return unsubscribe
|
||||
}, [appId, invalidateMCPServerDetail])
|
||||
|
||||
if (!currentWorkflow && isAdvancedApp)
|
||||
return null
|
||||
|
||||
|
||||
@ -1,11 +1,18 @@
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react'
|
||||
import { useFeaturesStore } from '@/app/components/base/features/hooks'
|
||||
import type { Features as FeaturesData } from '@/app/components/base/features/types'
|
||||
import { SupportUploadFileTypes } from '@/app/components/workflow/types'
|
||||
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
|
||||
import { WorkflowWithInnerContext } from '@/app/components/workflow'
|
||||
import type { WorkflowProps } from '@/app/components/workflow'
|
||||
import WorkflowChildren from './workflow-children'
|
||||
|
||||
import {
|
||||
useAvailableNodesMetaData,
|
||||
useConfigsMap,
|
||||
@ -18,7 +25,12 @@ import {
|
||||
useWorkflowRun,
|
||||
useWorkflowStartRun,
|
||||
} from '../hooks'
|
||||
import { useWorkflowStore } from '@/app/components/workflow/store'
|
||||
import { useWorkflowUpdate } from '@/app/components/workflow/hooks/use-workflow-interactions'
|
||||
import { useStore, useWorkflowStore } from '@/app/components/workflow/store'
|
||||
import { useCollaboration } from '@/app/components/workflow/collaboration'
|
||||
import { collaborationManager } from '@/app/components/workflow/collaboration'
|
||||
import { fetchWorkflowDraft } from '@/service/workflow'
|
||||
import { useReactFlow, useStoreApi } from 'reactflow'
|
||||
|
||||
type WorkflowMainProps = Pick<WorkflowProps, 'nodes' | 'edges' | 'viewport'>
|
||||
const WorkflowMain = ({
|
||||
@ -28,6 +40,31 @@ const WorkflowMain = ({
|
||||
}: WorkflowMainProps) => {
|
||||
const featuresStore = useFeaturesStore()
|
||||
const workflowStore = useWorkflowStore()
|
||||
const appId = useStore(s => s.appId)
|
||||
const containerRef = useRef<HTMLDivElement>(null)
|
||||
const reactFlow = useReactFlow()
|
||||
|
||||
const store = useStoreApi()
|
||||
const { startCursorTracking, stopCursorTracking, onlineUsers, cursors, isConnected } = useCollaboration(appId || '', store)
|
||||
const [myUserId, setMyUserId] = useState<string | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (isConnected)
|
||||
setMyUserId('current-user')
|
||||
}, [isConnected])
|
||||
|
||||
const filteredCursors = Object.fromEntries(
|
||||
Object.entries(cursors).filter(([userId]) => userId !== myUserId),
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
if (containerRef.current)
|
||||
startCursorTracking(containerRef as React.RefObject<HTMLElement>, reactFlow)
|
||||
|
||||
return () => {
|
||||
stopCursorTracking()
|
||||
}
|
||||
}, [startCursorTracking, stopCursorTracking, reactFlow])
|
||||
|
||||
const handleWorkflowDataUpdate = useCallback((payload: any) => {
|
||||
const {
|
||||
@ -38,7 +75,33 @@ const WorkflowMain = ({
|
||||
if (features && featuresStore) {
|
||||
const { setFeatures } = featuresStore.getState()
|
||||
|
||||
setFeatures(features)
|
||||
const transformedFeatures: FeaturesData = {
|
||||
file: {
|
||||
image: {
|
||||
enabled: !!features.file_upload?.image?.enabled,
|
||||
number_limits: features.file_upload?.image?.number_limits || 3,
|
||||
transfer_methods: features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
|
||||
},
|
||||
enabled: !!(features.file_upload?.enabled || features.file_upload?.image?.enabled),
|
||||
allowed_file_types: features.file_upload?.allowed_file_types || [SupportUploadFileTypes.image],
|
||||
allowed_file_extensions: features.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`),
|
||||
allowed_file_upload_methods: features.file_upload?.allowed_file_upload_methods || features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
|
||||
number_limits: features.file_upload?.number_limits || features.file_upload?.image?.number_limits || 3,
|
||||
},
|
||||
opening: {
|
||||
enabled: !!features.opening_statement,
|
||||
opening_statement: features.opening_statement,
|
||||
suggested_questions: features.suggested_questions,
|
||||
},
|
||||
suggested: features.suggested_questions_after_answer || { enabled: false },
|
||||
speech2text: features.speech_to_text || { enabled: false },
|
||||
text2speech: features.text_to_speech || { enabled: false },
|
||||
citation: features.retriever_resource || { enabled: false },
|
||||
moderation: features.sensitive_word_avoidance || { enabled: false },
|
||||
annotationReply: features.annotation_reply || { enabled: false },
|
||||
}
|
||||
|
||||
setFeatures(transformedFeatures)
|
||||
}
|
||||
if (conversation_variables) {
|
||||
const { setConversationVariables } = workflowStore.getState()
|
||||
@ -55,6 +118,7 @@ const WorkflowMain = ({
|
||||
syncWorkflowDraftWhenPageClose,
|
||||
} = useNodesSyncDraft()
|
||||
const { handleRefreshWorkflowDraft } = useWorkflowRefreshDraft()
|
||||
const { handleUpdateWorkflowCanvas } = useWorkflowUpdate()
|
||||
const {
|
||||
handleBackupDraft,
|
||||
handleLoadBackupDraft,
|
||||
@ -62,6 +126,63 @@ const WorkflowMain = ({
|
||||
handleRun,
|
||||
handleStopRun,
|
||||
} = useWorkflowRun()
|
||||
|
||||
useEffect(() => {
|
||||
if (!appId) return
|
||||
|
||||
const unsubscribe = collaborationManager.onVarsAndFeaturesUpdate(async (update: any) => {
|
||||
try {
|
||||
const response = await fetchWorkflowDraft(`/apps/${appId}/workflows/draft`)
|
||||
handleWorkflowDataUpdate(response)
|
||||
}
|
||||
catch (error) {
|
||||
console.error('workflow vars and features update failed:', error)
|
||||
}
|
||||
})
|
||||
|
||||
return unsubscribe
|
||||
}, [appId, handleWorkflowDataUpdate])
|
||||
|
||||
// Listen for workflow updates from other users
|
||||
useEffect(() => {
|
||||
if (!appId) return
|
||||
|
||||
const unsubscribe = collaborationManager.onWorkflowUpdate(async () => {
|
||||
console.log('Received workflow update from collaborator, fetching latest workflow data')
|
||||
try {
|
||||
const response = await fetchWorkflowDraft(`/apps/${appId}/workflows/draft`)
|
||||
|
||||
// Handle features, variables etc.
|
||||
handleWorkflowDataUpdate(response)
|
||||
|
||||
// Update workflow canvas (nodes, edges, viewport)
|
||||
if (response.graph) {
|
||||
handleUpdateWorkflowCanvas({
|
||||
nodes: response.graph.nodes || [],
|
||||
edges: response.graph.edges || [],
|
||||
viewport: response.graph.viewport || { x: 0, y: 0, zoom: 1 },
|
||||
})
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Failed to fetch updated workflow:', error)
|
||||
}
|
||||
})
|
||||
|
||||
return unsubscribe
|
||||
}, [appId, handleWorkflowDataUpdate, handleUpdateWorkflowCanvas])
|
||||
|
||||
// Listen for sync requests from other users (only processed by leader)
|
||||
useEffect(() => {
|
||||
if (!appId) return
|
||||
|
||||
const unsubscribe = collaborationManager.onSyncRequest(() => {
|
||||
console.log('Leader received sync request, performing sync')
|
||||
doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
return unsubscribe
|
||||
}, [appId, doSyncWorkflowDraft])
|
||||
const {
|
||||
handleStartWorkflowRun,
|
||||
handleWorkflowStartRunInChatflow,
|
||||
@ -75,6 +196,7 @@ const WorkflowMain = ({
|
||||
} = useDSL()
|
||||
|
||||
const configsMap = useConfigsMap()
|
||||
|
||||
const { fetchInspectVars } = useSetWorkflowVarsWithValue({
|
||||
...configsMap,
|
||||
})
|
||||
@ -164,15 +286,23 @@ const WorkflowMain = ({
|
||||
])
|
||||
|
||||
return (
|
||||
<WorkflowWithInnerContext
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
viewport={viewport}
|
||||
onWorkflowDataUpdate={handleWorkflowDataUpdate}
|
||||
hooksStore={hooksStore as any}
|
||||
<div
|
||||
ref={containerRef}
|
||||
className="relative h-full w-full"
|
||||
>
|
||||
<WorkflowChildren />
|
||||
</WorkflowWithInnerContext>
|
||||
<WorkflowWithInnerContext
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
viewport={viewport}
|
||||
onWorkflowDataUpdate={handleWorkflowDataUpdate}
|
||||
hooksStore={hooksStore as any}
|
||||
cursors={filteredCursors}
|
||||
myUserId={myUserId}
|
||||
onlineUsers={onlineUsers}
|
||||
>
|
||||
<WorkflowChildren />
|
||||
</WorkflowWithInnerContext>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ import { useStore } from '@/app/components/workflow/store'
|
||||
import {
|
||||
useIsChatMode,
|
||||
} from '../hooks'
|
||||
import CommentsPanel from '@/app/components/workflow/panel/comments-panel'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import type { PanelProps } from '@/app/components/workflow/panel'
|
||||
import Panel from '@/app/components/workflow/panel'
|
||||
@ -67,6 +68,7 @@ const WorkflowPanelOnRight = () => {
|
||||
const showDebugAndPreviewPanel = useStore(s => s.showDebugAndPreviewPanel)
|
||||
const showChatVariablePanel = useStore(s => s.showChatVariablePanel)
|
||||
const showGlobalVariablePanel = useStore(s => s.showGlobalVariablePanel)
|
||||
const controlMode = useStore(s => s.controlMode)
|
||||
|
||||
return (
|
||||
<>
|
||||
@ -100,6 +102,7 @@ const WorkflowPanelOnRight = () => {
|
||||
<GlobalVariablePanel />
|
||||
)
|
||||
}
|
||||
{controlMode === 'comment' && <CommentsPanel />}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
@ -13,6 +13,7 @@ import { syncWorkflowDraft } from '@/service/workflow'
|
||||
import { useFeaturesStore } from '@/app/components/base/features/hooks'
|
||||
import { API_PREFIX } from '@/config'
|
||||
import { useWorkflowRefreshDraft } from '.'
|
||||
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
|
||||
|
||||
export const useNodesSyncDraft = () => {
|
||||
const store = useStoreApi()
|
||||
@ -85,6 +86,7 @@ export const useNodesSyncDraft = () => {
|
||||
environment_variables: environmentVariables,
|
||||
conversation_variables: conversationVariables,
|
||||
hash: syncWorkflowDraftHash,
|
||||
_is_collaborative: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -93,9 +95,20 @@ export const useNodesSyncDraft = () => {
|
||||
const syncWorkflowDraftWhenPageClose = useCallback(() => {
|
||||
if (getNodesReadOnly())
|
||||
return
|
||||
|
||||
// Check leader status at sync time
|
||||
const currentIsLeader = collaborationManager.getIsLeader()
|
||||
|
||||
// Only allow leader to sync data
|
||||
if (!currentIsLeader) {
|
||||
console.log('Not leader, skipping sync on page close')
|
||||
return
|
||||
}
|
||||
|
||||
const postParams = getPostParams()
|
||||
|
||||
if (postParams) {
|
||||
console.log('Leader syncing workflow draft on page close')
|
||||
navigator.sendBeacon(
|
||||
`${API_PREFIX}/apps/${params.appId}/workflows/draft?_token=${localStorage.getItem('console_token')}`,
|
||||
JSON.stringify(postParams.params),
|
||||
@ -110,9 +123,23 @@ export const useNodesSyncDraft = () => {
|
||||
onError?: () => void
|
||||
onSettled?: () => void
|
||||
},
|
||||
forceUpload?: boolean,
|
||||
) => {
|
||||
if (getNodesReadOnly())
|
||||
return
|
||||
|
||||
// Check leader status at sync time
|
||||
const currentIsLeader = collaborationManager.getIsLeader()
|
||||
|
||||
// If not leader and not forcing upload, request the leader to sync
|
||||
if (!currentIsLeader && !forceUpload) {
|
||||
console.log('Not leader, requesting leader to sync workflow draft')
|
||||
collaborationManager.emitSyncRequest()
|
||||
callback?.onSettled?.()
|
||||
return
|
||||
}
|
||||
|
||||
console.log(forceUpload ? 'Force uploading workflow draft' : 'Leader performing workflow draft sync')
|
||||
const postParams = getPostParams()
|
||||
|
||||
if (postParams) {
|
||||
@ -120,17 +147,30 @@ export const useNodesSyncDraft = () => {
|
||||
setSyncWorkflowDraftHash,
|
||||
setDraftUpdatedAt,
|
||||
} = workflowStore.getState()
|
||||
|
||||
// Add force_upload parameter if needed
|
||||
const finalParams = {
|
||||
...postParams.params,
|
||||
...(forceUpload && { force_upload: true }),
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await syncWorkflowDraft(postParams)
|
||||
const res = await syncWorkflowDraft({
|
||||
url: postParams.url,
|
||||
params: finalParams,
|
||||
})
|
||||
setSyncWorkflowDraftHash(res.hash)
|
||||
setDraftUpdatedAt(res.updated_at)
|
||||
callback?.onSuccess?.()
|
||||
}
|
||||
catch (error: any) {
|
||||
console.error('Leader failed to sync workflow draft:', error)
|
||||
if (error && error.json && !error.bodyUsed) {
|
||||
error.json().then((err: any) => {
|
||||
if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError)
|
||||
if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) {
|
||||
console.error('draft_workflow_not_sync', err)
|
||||
handleRefreshWorkflowDraft()
|
||||
}
|
||||
})
|
||||
}
|
||||
callback?.onError?.()
|
||||
|
||||
@ -25,6 +25,7 @@ import {
|
||||
import type { InjectWorkflowStoreSliceFn } from '@/app/components/workflow/store'
|
||||
import { createWorkflowSlice } from './store/workflow/workflow-slice'
|
||||
import WorkflowAppMain from './components/workflow-main'
|
||||
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
|
||||
|
||||
const WorkflowAppWithAdditionalContext = () => {
|
||||
const {
|
||||
@ -35,15 +36,20 @@ const WorkflowAppWithAdditionalContext = () => {
|
||||
const { isLoadingCurrentWorkspace, currentWorkspace } = useAppContext()
|
||||
|
||||
const nodesData = useMemo(() => {
|
||||
if (data)
|
||||
return initialNodes(data.graph.nodes, data.graph.edges)
|
||||
|
||||
if (data) {
|
||||
const processedNodes = initialNodes(data.graph.nodes, data.graph.edges)
|
||||
collaborationManager.setNodes([], processedNodes)
|
||||
return processedNodes
|
||||
}
|
||||
return []
|
||||
}, [data])
|
||||
const edgesData = useMemo(() => {
|
||||
if (data)
|
||||
return initialEdges(data.graph.edges, data.graph.nodes)
|
||||
|
||||
const edgesData = useMemo(() => {
|
||||
if (data) {
|
||||
const processedEdges = initialEdges(data.graph.edges, data.graph.nodes)
|
||||
collaborationManager.setEdges([], processedEdges)
|
||||
return processedEdges
|
||||
}
|
||||
return []
|
||||
}, [data])
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ import {
|
||||
import produce from 'immer'
|
||||
import {
|
||||
useReactFlow,
|
||||
useStoreApi,
|
||||
useViewport,
|
||||
} from 'reactflow'
|
||||
import { useEventListener } from 'ahooks'
|
||||
@ -19,9 +18,9 @@ import CustomNode from './nodes'
|
||||
import CustomNoteNode from './note-node'
|
||||
import { CUSTOM_NOTE_NODE } from './note-node/constants'
|
||||
import { BlockEnum } from './types'
|
||||
import { useCollaborativeWorkflow } from '@/app/components/workflow/hooks/use-collaborative-workflow'
|
||||
|
||||
const CandidateNode = () => {
|
||||
const store = useStoreApi()
|
||||
const reactflow = useReactFlow()
|
||||
const workflowStore = useWorkflowStore()
|
||||
const candidateNode = useStore(s => s.candidateNode)
|
||||
@ -29,18 +28,15 @@ const CandidateNode = () => {
|
||||
const { zoom } = useViewport()
|
||||
const { handleNodeSelect } = useNodesInteractions()
|
||||
const { saveStateToHistory } = useWorkflowHistory()
|
||||
const collaborativeWorkflow = useCollaborativeWorkflow()
|
||||
|
||||
useEventListener('click', (e) => {
|
||||
const { candidateNode, mousePosition } = workflowStore.getState()
|
||||
|
||||
if (candidateNode) {
|
||||
e.preventDefault()
|
||||
const {
|
||||
getNodes,
|
||||
setNodes,
|
||||
} = store.getState()
|
||||
const { nodes, setNodes } = collaborativeWorkflow.getState()
|
||||
const { screenToFlowPosition } = reactflow
|
||||
const nodes = getNodes()
|
||||
const { x, y } = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY })
|
||||
const newNodes = produce(nodes, (draft) => {
|
||||
draft.push({
|
||||
|
||||
@ -0,0 +1,78 @@
|
||||
import type { FC } from 'react'
|
||||
import { useViewport } from 'reactflow'
|
||||
import type { CursorPosition, OnlineUser } from '@/app/components/workflow/collaboration/types'
|
||||
import { getUserColor } from '../utils/user-color'
|
||||
|
||||
type UserCursorsProps = {
|
||||
cursors: Record<string, CursorPosition>
|
||||
myUserId: string | null
|
||||
onlineUsers: OnlineUser[]
|
||||
}
|
||||
|
||||
const UserCursors: FC<UserCursorsProps> = ({
|
||||
cursors,
|
||||
myUserId,
|
||||
onlineUsers,
|
||||
}) => {
|
||||
const viewport = useViewport()
|
||||
|
||||
const convertToScreenCoordinates = (cursor: CursorPosition) => {
|
||||
// Convert world coordinates to screen coordinates using current viewport
|
||||
const screenX = cursor.x * viewport.zoom + viewport.x
|
||||
const screenY = cursor.y * viewport.zoom + viewport.y
|
||||
|
||||
return { x: screenX, y: screenY }
|
||||
}
|
||||
return (
|
||||
<>
|
||||
{Object.entries(cursors || {}).map(([userId, cursor]) => {
|
||||
if (userId === myUserId)
|
||||
return null
|
||||
|
||||
const userInfo = onlineUsers.find(user => user.user_id === userId)
|
||||
const userName = userInfo?.username || `User ${userId.slice(-4)}`
|
||||
const userColor = getUserColor(userId)
|
||||
const screenPos = convertToScreenCoordinates(cursor)
|
||||
|
||||
return (
|
||||
<div
|
||||
key={userId}
|
||||
className="pointer-events-none absolute z-[8] transition-all duration-150 ease-out"
|
||||
style={{
|
||||
left: screenPos.x,
|
||||
top: screenPos.y,
|
||||
}}
|
||||
>
|
||||
<svg
|
||||
width="20"
|
||||
height="20"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="drop-shadow-md"
|
||||
>
|
||||
<path
|
||||
d="M5 3L5 15L8 11.5L11 16L13 15L10 10.5L14 10.5L5 3Z"
|
||||
fill={userColor}
|
||||
stroke="white"
|
||||
strokeWidth="1.5"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
|
||||
<div
|
||||
className="absolute left-4 top-4 max-w-[120px] overflow-hidden text-ellipsis whitespace-nowrap rounded px-1.5 py-0.5 text-[11px] font-medium text-white shadow-sm"
|
||||
style={{
|
||||
backgroundColor: userColor,
|
||||
}}
|
||||
>
|
||||
{userName}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default UserCursors
|
||||
@ -0,0 +1,924 @@
|
||||
import { LoroDoc, UndoManager } from 'loro-crdt'
|
||||
import { isEqual } from 'lodash-es'
|
||||
import { webSocketClient } from './websocket-manager'
|
||||
import { CRDTProvider } from './crdt-provider'
|
||||
import { EventEmitter } from './event-emitter'
|
||||
import type { Edge, Node } from '../../types'
|
||||
import type {
|
||||
CollaborationState,
|
||||
CursorPosition,
|
||||
NodePanelPresenceMap,
|
||||
NodePanelPresenceUser,
|
||||
OnlineUser,
|
||||
} from '../types/collaboration'
|
||||
|
||||
type NodePanelPresenceEventData = {
|
||||
nodeId: string
|
||||
action: 'open' | 'close'
|
||||
user: NodePanelPresenceUser
|
||||
clientId: string
|
||||
timestamp?: number
|
||||
}
|
||||
|
||||
export class CollaborationManager {
|
||||
private doc: LoroDoc | null = null
|
||||
private undoManager: UndoManager | null = null
|
||||
private provider: CRDTProvider | null = null
|
||||
private nodesMap: any = null
|
||||
private edgesMap: any = null
|
||||
private eventEmitter = new EventEmitter()
|
||||
private currentAppId: string | null = null
|
||||
private reactFlowStore: any = null
|
||||
private isLeader = false
|
||||
private leaderId: string | null = null
|
||||
private cursors: Record<string, CursorPosition> = {}
|
||||
private nodePanelPresence: NodePanelPresenceMap = {}
|
||||
private activeConnections = new Set<string>()
|
||||
private isUndoRedoInProgress = false
|
||||
|
||||
private getNodePanelPresenceSnapshot(): NodePanelPresenceMap {
|
||||
const snapshot: NodePanelPresenceMap = {}
|
||||
Object.entries(this.nodePanelPresence).forEach(([nodeId, viewers]) => {
|
||||
snapshot[nodeId] = { ...viewers }
|
||||
})
|
||||
return snapshot
|
||||
}
|
||||
|
||||
private applyNodePanelPresenceUpdate(update: NodePanelPresenceEventData): void {
|
||||
const { nodeId, action, clientId, user, timestamp } = update
|
||||
|
||||
if (action === 'open') {
|
||||
// ensure a client only appears on a single node at a time
|
||||
Object.entries(this.nodePanelPresence).forEach(([id, viewers]) => {
|
||||
if (viewers[clientId]) {
|
||||
delete viewers[clientId]
|
||||
if (Object.keys(viewers).length === 0)
|
||||
delete this.nodePanelPresence[id]
|
||||
}
|
||||
})
|
||||
|
||||
if (!this.nodePanelPresence[nodeId])
|
||||
this.nodePanelPresence[nodeId] = {}
|
||||
|
||||
this.nodePanelPresence[nodeId][clientId] = {
|
||||
...user,
|
||||
clientId,
|
||||
timestamp: timestamp || Date.now(),
|
||||
}
|
||||
}
|
||||
else {
|
||||
const viewers = this.nodePanelPresence[nodeId]
|
||||
if (viewers) {
|
||||
delete viewers[clientId]
|
||||
if (Object.keys(viewers).length === 0)
|
||||
delete this.nodePanelPresence[nodeId]
|
||||
}
|
||||
}
|
||||
|
||||
this.eventEmitter.emit('nodePanelPresence', this.getNodePanelPresenceSnapshot())
|
||||
}
|
||||
|
||||
private cleanupNodePanelPresence(activeClientIds: Set<string>, activeUserIds: Set<string>): void {
|
||||
let hasChanges = false
|
||||
|
||||
Object.entries(this.nodePanelPresence).forEach(([nodeId, viewers]) => {
|
||||
Object.keys(viewers).forEach((clientId) => {
|
||||
const viewer = viewers[clientId]
|
||||
const clientActive = activeClientIds.has(clientId)
|
||||
const userActive = viewer?.userId ? activeUserIds.has(viewer.userId) : false
|
||||
|
||||
if (!clientActive && !userActive) {
|
||||
delete viewers[clientId]
|
||||
hasChanges = true
|
||||
}
|
||||
})
|
||||
|
||||
if (Object.keys(viewers).length === 0)
|
||||
delete this.nodePanelPresence[nodeId]
|
||||
})
|
||||
|
||||
if (hasChanges)
|
||||
this.eventEmitter.emit('nodePanelPresence', this.getNodePanelPresenceSnapshot())
|
||||
}
|
||||
|
||||
init = (appId: string, reactFlowStore: any): void => {
|
||||
if (!reactFlowStore) {
|
||||
console.warn('CollaborationManager.init called without reactFlowStore, deferring to connect()')
|
||||
return
|
||||
}
|
||||
this.connect(appId, reactFlowStore)
|
||||
}
|
||||
|
||||
setNodes = (oldNodes: Node[], newNodes: Node[]): void => {
|
||||
if (!this.doc) return
|
||||
|
||||
// Don't track operations during undo/redo to prevent loops
|
||||
if (this.isUndoRedoInProgress) {
|
||||
console.log('Skipping setNodes during undo/redo')
|
||||
return
|
||||
}
|
||||
|
||||
console.log('Setting nodes with tracking')
|
||||
this.syncNodes(oldNodes, newNodes)
|
||||
this.doc.commit()
|
||||
}
|
||||
|
||||
setEdges = (oldEdges: Edge[], newEdges: Edge[]): void => {
|
||||
if (!this.doc) return
|
||||
|
||||
// Don't track operations during undo/redo to prevent loops
|
||||
if (this.isUndoRedoInProgress) {
|
||||
console.log('Skipping setEdges during undo/redo')
|
||||
return
|
||||
}
|
||||
|
||||
console.log('Setting edges with tracking')
|
||||
this.syncEdges(oldEdges, newEdges)
|
||||
this.doc.commit()
|
||||
}
|
||||
|
||||
destroy = (): void => {
|
||||
this.disconnect()
|
||||
}
|
||||
|
||||
async connect(appId: string, reactFlowStore?: any): Promise<string> {
|
||||
const connectionId = Math.random().toString(36).substring(2, 11)
|
||||
|
||||
this.activeConnections.add(connectionId)
|
||||
|
||||
if (this.currentAppId === appId && this.doc) {
|
||||
// Already connected to the same app, only update store if provided and we don't have one
|
||||
if (reactFlowStore && !this.reactFlowStore)
|
||||
this.reactFlowStore = reactFlowStore
|
||||
|
||||
return connectionId
|
||||
}
|
||||
|
||||
// Only disconnect if switching to a different app
|
||||
if (this.currentAppId && this.currentAppId !== appId)
|
||||
this.forceDisconnect()
|
||||
|
||||
this.currentAppId = appId
|
||||
// Only set store if provided
|
||||
if (reactFlowStore)
|
||||
this.reactFlowStore = reactFlowStore
|
||||
|
||||
const socket = webSocketClient.connect(appId)
|
||||
|
||||
// Setup event listeners BEFORE any other operations
|
||||
this.setupSocketEventListeners(socket)
|
||||
|
||||
this.doc = new LoroDoc()
|
||||
this.nodesMap = this.doc.getMap('nodes')
|
||||
this.edgesMap = this.doc.getMap('edges')
|
||||
|
||||
// Initialize UndoManager for collaborative undo/redo
|
||||
this.undoManager = new UndoManager(this.doc, {
|
||||
maxUndoSteps: 100,
|
||||
mergeInterval: 500, // Merge operations within 500ms
|
||||
excludeOriginPrefixes: [], // Don't exclude anything - let UndoManager track all local operations
|
||||
onPush: (isUndo, range, event) => {
|
||||
console.log('UndoManager onPush:', { isUndo, range, event })
|
||||
// Store current selection state when an operation is pushed
|
||||
const selectedNode = this.reactFlowStore?.getState().getNodes().find((n: Node) => n.data?.selected)
|
||||
|
||||
// Emit event to update UI button states when new operation is pushed
|
||||
setTimeout(() => {
|
||||
this.eventEmitter.emit('undoRedoStateChange', {
|
||||
canUndo: this.undoManager?.canUndo() || false,
|
||||
canRedo: this.undoManager?.canRedo() || false,
|
||||
})
|
||||
}, 0)
|
||||
|
||||
return {
|
||||
value: {
|
||||
selectedNodeId: selectedNode?.id || null,
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
cursors: [],
|
||||
}
|
||||
},
|
||||
onPop: (isUndo, value, counterRange) => {
|
||||
console.log('UndoManager onPop:', { isUndo, value, counterRange })
|
||||
// Restore selection state when undoing/redoing
|
||||
if (value?.value && typeof value.value === 'object' && 'selectedNodeId' in value.value && this.reactFlowStore) {
|
||||
const selectedNodeId = (value.value as any).selectedNodeId
|
||||
if (selectedNodeId) {
|
||||
const { setNodes } = this.reactFlowStore.getState()
|
||||
const nodes = this.reactFlowStore.getState().getNodes()
|
||||
const newNodes = nodes.map((n: Node) => ({
|
||||
...n,
|
||||
data: {
|
||||
...n.data,
|
||||
selected: n.id === selectedNodeId,
|
||||
},
|
||||
}))
|
||||
setNodes(newNodes)
|
||||
}
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
this.provider = new CRDTProvider(socket, this.doc)
|
||||
|
||||
this.setupSubscriptions()
|
||||
|
||||
// Force user_connect if already connected
|
||||
if (socket.connected)
|
||||
socket.emit('user_connect', { workflow_id: appId })
|
||||
|
||||
return connectionId
|
||||
}
|
||||
|
||||
disconnect = (connectionId?: string): void => {
|
||||
if (connectionId)
|
||||
this.activeConnections.delete(connectionId)
|
||||
|
||||
// Only disconnect when no more connections
|
||||
if (this.activeConnections.size === 0)
|
||||
this.forceDisconnect()
|
||||
}
|
||||
|
||||
private forceDisconnect = (): void => {
|
||||
if (this.currentAppId)
|
||||
webSocketClient.disconnect(this.currentAppId)
|
||||
|
||||
this.provider?.destroy()
|
||||
this.undoManager = null
|
||||
this.doc = null
|
||||
this.provider = null
|
||||
this.nodesMap = null
|
||||
this.edgesMap = null
|
||||
this.currentAppId = null
|
||||
this.reactFlowStore = null
|
||||
this.cursors = {}
|
||||
this.nodePanelPresence = {}
|
||||
this.isUndoRedoInProgress = false
|
||||
|
||||
// Only reset leader status when actually disconnecting
|
||||
const wasLeader = this.isLeader
|
||||
this.isLeader = false
|
||||
this.leaderId = null
|
||||
|
||||
if (wasLeader)
|
||||
this.eventEmitter.emit('leaderChange', false)
|
||||
|
||||
this.activeConnections.clear()
|
||||
this.eventEmitter.removeAllListeners()
|
||||
}
|
||||
|
||||
isConnected(): boolean {
|
||||
return this.currentAppId ? webSocketClient.isConnected(this.currentAppId) : false
|
||||
}
|
||||
|
||||
getNodes(): Node[] {
|
||||
return this.nodesMap ? Array.from(this.nodesMap.values()) : []
|
||||
}
|
||||
|
||||
getEdges(): Edge[] {
|
||||
return this.edgesMap ? Array.from(this.edgesMap.values()) : []
|
||||
}
|
||||
|
||||
emitCursorMove(position: CursorPosition): void {
|
||||
if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return
|
||||
|
||||
const socket = webSocketClient.getSocket(this.currentAppId)
|
||||
if (socket) {
|
||||
socket.emit('collaboration_event', {
|
||||
type: 'mouseMove',
|
||||
userId: socket.id,
|
||||
data: { x: position.x, y: position.y },
|
||||
timestamp: Date.now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
emitSyncRequest(): void {
|
||||
if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return
|
||||
|
||||
const socket = webSocketClient.getSocket(this.currentAppId)
|
||||
if (socket) {
|
||||
console.log('Emitting sync request to leader')
|
||||
socket.emit('collaboration_event', {
|
||||
type: 'syncRequest',
|
||||
data: { timestamp: Date.now() },
|
||||
timestamp: Date.now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
emitWorkflowUpdate(appId: string): void {
|
||||
if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return
|
||||
|
||||
const socket = webSocketClient.getSocket(this.currentAppId)
|
||||
if (socket) {
|
||||
console.log('Emitting Workflow update event')
|
||||
socket.emit('collaboration_event', {
|
||||
type: 'workflowUpdate',
|
||||
data: { appId, timestamp: Date.now() },
|
||||
timestamp: Date.now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
emitNodePanelPresence(nodeId: string, isOpen: boolean, user: NodePanelPresenceUser): void {
|
||||
if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return
|
||||
|
||||
const socket = webSocketClient.getSocket(this.currentAppId)
|
||||
if (!socket || !nodeId || !user?.userId) return
|
||||
|
||||
const payload: NodePanelPresenceEventData = {
|
||||
nodeId,
|
||||
action: isOpen ? 'open' : 'close',
|
||||
user,
|
||||
clientId: socket.id as string,
|
||||
timestamp: Date.now(),
|
||||
}
|
||||
|
||||
socket.emit('collaboration_event', {
|
||||
type: 'nodePanelPresence',
|
||||
data: payload,
|
||||
timestamp: payload.timestamp,
|
||||
})
|
||||
|
||||
this.applyNodePanelPresenceUpdate(payload)
|
||||
}
|
||||
|
||||
onSyncRequest(callback: () => void): () => void {
|
||||
return this.eventEmitter.on('syncRequest', callback)
|
||||
}
|
||||
|
||||
onStateChange(callback: (state: Partial<CollaborationState>) => void): () => void {
|
||||
return this.eventEmitter.on('stateChange', callback)
|
||||
}
|
||||
|
||||
onCursorUpdate(callback: (cursors: Record<string, CursorPosition>) => void): () => void {
|
||||
return this.eventEmitter.on('cursors', callback)
|
||||
}
|
||||
|
||||
onOnlineUsersUpdate(callback: (users: OnlineUser[]) => void): () => void {
|
||||
return this.eventEmitter.on('onlineUsers', callback)
|
||||
}
|
||||
|
||||
onWorkflowUpdate(callback: (update: { appId: string; timestamp: number }) => void): () => void {
|
||||
return this.eventEmitter.on('workflowUpdate', callback)
|
||||
}
|
||||
|
||||
onVarsAndFeaturesUpdate(callback: (update: any) => void): () => void {
|
||||
return this.eventEmitter.on('varsAndFeaturesUpdate', callback)
|
||||
}
|
||||
|
||||
onAppStateUpdate(callback: (update: any) => void): () => void {
|
||||
return this.eventEmitter.on('appStateUpdate', callback)
|
||||
}
|
||||
|
||||
onMcpServerUpdate(callback: (update: any) => void): () => void {
|
||||
return this.eventEmitter.on('mcpServerUpdate', callback)
|
||||
}
|
||||
|
||||
onNodePanelPresenceUpdate(callback: (presence: NodePanelPresenceMap) => void): () => void {
|
||||
const off = this.eventEmitter.on('nodePanelPresence', callback)
|
||||
callback(this.getNodePanelPresenceSnapshot())
|
||||
return off
|
||||
}
|
||||
|
||||
onLeaderChange(callback: (isLeader: boolean) => void): () => void {
|
||||
return this.eventEmitter.on('leaderChange', callback)
|
||||
}
|
||||
|
||||
onCommentsUpdate(callback: (update: { appId: string; timestamp: number }) => void): () => void {
|
||||
return this.eventEmitter.on('commentsUpdate', callback)
|
||||
}
|
||||
|
||||
emitCommentsUpdate(appId: string): void {
|
||||
if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return
|
||||
|
||||
const socket = webSocketClient.getSocket(this.currentAppId)
|
||||
if (socket) {
|
||||
console.log('Emitting Comments update event')
|
||||
socket.emit('collaboration_event', {
|
||||
type: 'commentsUpdate',
|
||||
data: { appId, timestamp: Date.now() },
|
||||
timestamp: Date.now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
onUndoRedoStateChange(callback: (state: { canUndo: boolean; canRedo: boolean }) => void): () => void {
|
||||
return this.eventEmitter.on('undoRedoStateChange', callback)
|
||||
}
|
||||
|
||||
getLeaderId(): string | null {
|
||||
return this.leaderId
|
||||
}
|
||||
|
||||
getIsLeader(): boolean {
|
||||
return this.isLeader
|
||||
}
|
||||
|
||||
// Collaborative undo/redo methods
|
||||
undo(): boolean {
|
||||
if (!this.undoManager) {
|
||||
console.log('UndoManager not initialized')
|
||||
return false
|
||||
}
|
||||
|
||||
const canUndo = this.undoManager.canUndo()
|
||||
console.log('Can undo:', canUndo)
|
||||
|
||||
if (canUndo) {
|
||||
this.isUndoRedoInProgress = true
|
||||
const result = this.undoManager.undo()
|
||||
|
||||
// After undo, manually update React state from CRDT without triggering collaboration
|
||||
if (result && this.reactFlowStore) {
|
||||
requestAnimationFrame(() => {
|
||||
// Get ReactFlow's native setters, not the collaborative ones
|
||||
const state = this.reactFlowStore.getState()
|
||||
const updatedNodes = Array.from(this.nodesMap.values())
|
||||
const updatedEdges = Array.from(this.edgesMap.values())
|
||||
console.log('Manually updating React state after undo')
|
||||
|
||||
// Call ReactFlow's native setters directly to avoid triggering collaboration
|
||||
state.setNodes(updatedNodes)
|
||||
state.setEdges(updatedEdges)
|
||||
|
||||
this.isUndoRedoInProgress = false
|
||||
|
||||
// Emit event to update UI button states
|
||||
this.eventEmitter.emit('undoRedoStateChange', {
|
||||
canUndo: this.undoManager?.canUndo() || false,
|
||||
canRedo: this.undoManager?.canRedo() || false,
|
||||
})
|
||||
})
|
||||
}
|
||||
else {
|
||||
this.isUndoRedoInProgress = false
|
||||
}
|
||||
|
||||
console.log('Undo result:', result)
|
||||
return result
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
redo(): boolean {
|
||||
if (!this.undoManager) {
|
||||
console.log('RedoManager not initialized')
|
||||
return false
|
||||
}
|
||||
|
||||
const canRedo = this.undoManager.canRedo()
|
||||
console.log('Can redo:', canRedo)
|
||||
|
||||
if (canRedo) {
|
||||
this.isUndoRedoInProgress = true
|
||||
const result = this.undoManager.redo()
|
||||
|
||||
// After redo, manually update React state from CRDT without triggering collaboration
|
||||
if (result && this.reactFlowStore) {
|
||||
requestAnimationFrame(() => {
|
||||
// Get ReactFlow's native setters, not the collaborative ones
|
||||
const state = this.reactFlowStore.getState()
|
||||
const updatedNodes = Array.from(this.nodesMap.values())
|
||||
const updatedEdges = Array.from(this.edgesMap.values())
|
||||
console.log('Manually updating React state after redo')
|
||||
|
||||
// Call ReactFlow's native setters directly to avoid triggering collaboration
|
||||
state.setNodes(updatedNodes)
|
||||
state.setEdges(updatedEdges)
|
||||
|
||||
this.isUndoRedoInProgress = false
|
||||
|
||||
// Emit event to update UI button states
|
||||
this.eventEmitter.emit('undoRedoStateChange', {
|
||||
canUndo: this.undoManager?.canUndo() || false,
|
||||
canRedo: this.undoManager?.canRedo() || false,
|
||||
})
|
||||
})
|
||||
}
|
||||
else {
|
||||
this.isUndoRedoInProgress = false
|
||||
}
|
||||
|
||||
console.log('Redo result:', result)
|
||||
return result
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
canUndo(): boolean {
|
||||
if (!this.undoManager) return false
|
||||
return this.undoManager.canUndo()
|
||||
}
|
||||
|
||||
canRedo(): boolean {
|
||||
if (!this.undoManager) return false
|
||||
return this.undoManager.canRedo()
|
||||
}
|
||||
|
||||
clearUndoStack(): void {
|
||||
if (!this.undoManager) return
|
||||
this.undoManager.clear()
|
||||
}
|
||||
|
||||
debugLeaderStatus(): void {
|
||||
console.log('=== Leader Status Debug ===')
|
||||
console.log('Current leader status:', this.isLeader)
|
||||
console.log('Current leader ID:', this.leaderId)
|
||||
console.log('Active connections:', this.activeConnections.size)
|
||||
console.log('Connected:', this.isConnected())
|
||||
console.log('Current app ID:', this.currentAppId)
|
||||
console.log('Has ReactFlow store:', !!this.reactFlowStore)
|
||||
console.log('========================')
|
||||
}
|
||||
|
||||
private syncNodes(oldNodes: Node[], newNodes: Node[]): void {
|
||||
if (!this.nodesMap || !this.doc) return
|
||||
|
||||
const oldNodesMap = new Map(oldNodes.map(node => [node.id, node]))
|
||||
const newNodesMap = new Map(newNodes.map(node => [node.id, node]))
|
||||
const syncDataAllowList = new Set(['_children'])
|
||||
const shouldSyncDataKey = (key: string) => (syncDataAllowList.has(key) || !key.startsWith('_')) && key !== 'selected'
|
||||
|
||||
// Delete removed nodes
|
||||
oldNodes.forEach((oldNode) => {
|
||||
if (!newNodesMap.has(oldNode.id))
|
||||
this.nodesMap.delete(oldNode.id)
|
||||
})
|
||||
|
||||
// Add or update nodes with fine-grained sync for data properties
|
||||
const copyOptionalNodeProps = (source: Node, target: any) => {
|
||||
const optionalProps: Array<keyof Node | keyof any> = [
|
||||
'parentId',
|
||||
'positionAbsolute',
|
||||
'extent',
|
||||
'zIndex',
|
||||
'draggable',
|
||||
'selectable',
|
||||
'dragHandle',
|
||||
'dragging',
|
||||
'connectable',
|
||||
'expandParent',
|
||||
'focusable',
|
||||
'hidden',
|
||||
'style',
|
||||
'className',
|
||||
'ariaLabel',
|
||||
'markerStart',
|
||||
'markerEnd',
|
||||
'resizing',
|
||||
'deletable',
|
||||
]
|
||||
|
||||
optionalProps.forEach((prop) => {
|
||||
const value = (source as any)[prop]
|
||||
if (value === undefined) {
|
||||
if (prop in target)
|
||||
delete target[prop]
|
||||
return
|
||||
}
|
||||
|
||||
if (value !== null && typeof value === 'object')
|
||||
target[prop as string] = JSON.parse(JSON.stringify(value))
|
||||
else
|
||||
target[prop as string] = value
|
||||
})
|
||||
}
|
||||
|
||||
newNodes.forEach((newNode) => {
|
||||
const oldNode = oldNodesMap.get(newNode.id)
|
||||
|
||||
if (!oldNode) {
|
||||
// New node - create as nested structure
|
||||
const nodeData: any = {
|
||||
id: newNode.id,
|
||||
type: newNode.type,
|
||||
position: { ...newNode.position },
|
||||
width: newNode.width,
|
||||
height: newNode.height,
|
||||
sourcePosition: newNode.sourcePosition,
|
||||
targetPosition: newNode.targetPosition,
|
||||
data: {},
|
||||
}
|
||||
|
||||
copyOptionalNodeProps(newNode, nodeData)
|
||||
|
||||
// Clone data properties, excluding private ones
|
||||
Object.entries(newNode.data).forEach(([key, value]) => {
|
||||
if (shouldSyncDataKey(key) && value !== undefined)
|
||||
nodeData.data[key] = JSON.parse(JSON.stringify(value))
|
||||
})
|
||||
|
||||
this.nodesMap.set(newNode.id, nodeData)
|
||||
}
|
||||
else {
|
||||
// Get existing node from CRDT
|
||||
const existingNode = this.nodesMap.get(newNode.id)
|
||||
|
||||
if (existingNode) {
|
||||
// Create a deep copy to modify
|
||||
const updatedNode = JSON.parse(JSON.stringify(existingNode))
|
||||
|
||||
// Update position only if changed
|
||||
if (oldNode.position.x !== newNode.position.x || oldNode.position.y !== newNode.position.y)
|
||||
updatedNode.position = { ...newNode.position }
|
||||
|
||||
// Update dimensions only if changed
|
||||
if (oldNode.width !== newNode.width)
|
||||
updatedNode.width = newNode.width
|
||||
|
||||
if (oldNode.height !== newNode.height)
|
||||
updatedNode.height = newNode.height
|
||||
|
||||
// Ensure optional node props stay in sync
|
||||
copyOptionalNodeProps(newNode, updatedNode)
|
||||
|
||||
// Ensure data object exists
|
||||
if (!updatedNode.data)
|
||||
updatedNode.data = {}
|
||||
|
||||
// Fine-grained update of data properties
|
||||
const oldData = oldNode.data || {}
|
||||
const newData = newNode.data || {}
|
||||
|
||||
// Only update changed properties in data
|
||||
Object.entries(newData).forEach(([key, value]) => {
|
||||
if (shouldSyncDataKey(key)) {
|
||||
const oldValue = (oldData as any)[key]
|
||||
if (!isEqual(oldValue, value))
|
||||
updatedNode.data[key] = JSON.parse(JSON.stringify(value))
|
||||
}
|
||||
})
|
||||
|
||||
// Remove deleted properties from data
|
||||
Object.keys(oldData).forEach((key) => {
|
||||
if (shouldSyncDataKey(key) && !(key in newData))
|
||||
delete updatedNode.data[key]
|
||||
})
|
||||
|
||||
// Only update in CRDT if something actually changed
|
||||
if (!isEqual(existingNode, updatedNode))
|
||||
this.nodesMap.set(newNode.id, updatedNode)
|
||||
}
|
||||
else {
|
||||
// Node exists locally but not in CRDT yet
|
||||
const nodeData: any = {
|
||||
id: newNode.id,
|
||||
type: newNode.type,
|
||||
position: { ...newNode.position },
|
||||
width: newNode.width,
|
||||
height: newNode.height,
|
||||
sourcePosition: newNode.sourcePosition,
|
||||
targetPosition: newNode.targetPosition,
|
||||
data: {},
|
||||
}
|
||||
|
||||
copyOptionalNodeProps(newNode, nodeData)
|
||||
|
||||
Object.entries(newNode.data).forEach(([key, value]) => {
|
||||
if (shouldSyncDataKey(key) && value !== undefined)
|
||||
nodeData.data[key] = JSON.parse(JSON.stringify(value))
|
||||
})
|
||||
|
||||
this.nodesMap.set(newNode.id, nodeData)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private syncEdges(oldEdges: Edge[], newEdges: Edge[]): void {
|
||||
if (!this.edgesMap) return
|
||||
|
||||
const oldEdgesMap = new Map(oldEdges.map(edge => [edge.id, edge]))
|
||||
const newEdgesMap = new Map(newEdges.map(edge => [edge.id, edge]))
|
||||
|
||||
oldEdges.forEach((oldEdge) => {
|
||||
if (!newEdgesMap.has(oldEdge.id))
|
||||
this.edgesMap.delete(oldEdge.id)
|
||||
})
|
||||
|
||||
newEdges.forEach((newEdge) => {
|
||||
const oldEdge = oldEdgesMap.get(newEdge.id)
|
||||
if (!oldEdge) {
|
||||
const clonedEdge = JSON.parse(JSON.stringify(newEdge))
|
||||
this.edgesMap.set(newEdge.id, clonedEdge)
|
||||
}
|
||||
else if (!isEqual(oldEdge, newEdge)) {
|
||||
const clonedEdge = JSON.parse(JSON.stringify(newEdge))
|
||||
this.edgesMap.set(newEdge.id, clonedEdge)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private setupSubscriptions(): void {
|
||||
this.nodesMap?.subscribe((event: any) => {
|
||||
console.log('nodesMap subscription event:', event)
|
||||
if (event.by === 'import' && this.reactFlowStore) {
|
||||
// Don't update React nodes during undo/redo to prevent loops
|
||||
if (this.isUndoRedoInProgress) {
|
||||
console.log('Skipping nodes subscription update during undo/redo')
|
||||
return
|
||||
}
|
||||
|
||||
requestAnimationFrame(() => {
|
||||
// Get ReactFlow's native setters, not the collaborative ones
|
||||
const state = this.reactFlowStore.getState()
|
||||
const previousNodes: Node[] = state.getNodes()
|
||||
const selectedIds = new Set(
|
||||
previousNodes
|
||||
.filter(node => node.data?.selected)
|
||||
.map(node => node.id),
|
||||
)
|
||||
|
||||
const updatedNodes = Array
|
||||
.from(this.nodesMap.values())
|
||||
.map((node: Node) => {
|
||||
const clonedNode: Node = {
|
||||
...node,
|
||||
data: {
|
||||
...(node.data || {}),
|
||||
},
|
||||
}
|
||||
|
||||
if (selectedIds.has(clonedNode.id))
|
||||
clonedNode.data.selected = true
|
||||
|
||||
return clonedNode
|
||||
})
|
||||
|
||||
console.log('Updating React nodes from subscription')
|
||||
|
||||
// Call ReactFlow's native setter directly to avoid triggering collaboration
|
||||
state.setNodes(updatedNodes)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
this.edgesMap?.subscribe((event: any) => {
|
||||
console.log('edgesMap subscription event:', event)
|
||||
if (event.by === 'import' && this.reactFlowStore) {
|
||||
// Don't update React edges during undo/redo to prevent loops
|
||||
if (this.isUndoRedoInProgress) {
|
||||
console.log('Skipping edges subscription update during undo/redo')
|
||||
return
|
||||
}
|
||||
|
||||
requestAnimationFrame(() => {
|
||||
// Get ReactFlow's native setters, not the collaborative ones
|
||||
const state = this.reactFlowStore.getState()
|
||||
const updatedEdges = Array.from(this.edgesMap.values())
|
||||
console.log('Updating React edges from subscription')
|
||||
|
||||
// Call ReactFlow's native setter directly to avoid triggering collaboration
|
||||
state.setEdges(updatedEdges)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private setupSocketEventListeners(socket: any): void {
|
||||
console.log('Setting up socket event listeners for collaboration')
|
||||
|
||||
socket.on('collaboration_update', (update: any) => {
|
||||
if (update.type === 'mouseMove') {
|
||||
// Update cursor state for this user
|
||||
this.cursors[update.userId] = {
|
||||
x: update.data.x,
|
||||
y: update.data.y,
|
||||
userId: update.userId,
|
||||
timestamp: update.timestamp,
|
||||
}
|
||||
|
||||
this.eventEmitter.emit('cursors', { ...this.cursors })
|
||||
}
|
||||
else if (update.type === 'varsAndFeaturesUpdate') {
|
||||
console.log('Processing varsAndFeaturesUpdate event:', update)
|
||||
this.eventEmitter.emit('varsAndFeaturesUpdate', update)
|
||||
}
|
||||
else if (update.type === 'appStateUpdate') {
|
||||
console.log('Processing appStateUpdate event:', update)
|
||||
this.eventEmitter.emit('appStateUpdate', update)
|
||||
}
|
||||
else if (update.type === 'mcpServerUpdate') {
|
||||
console.log('Processing mcpServerUpdate event:', update)
|
||||
this.eventEmitter.emit('mcpServerUpdate', update)
|
||||
}
|
||||
else if (update.type === 'workflowUpdate') {
|
||||
console.log('Processing workflowUpdate event:', update)
|
||||
this.eventEmitter.emit('workflowUpdate', update.data)
|
||||
}
|
||||
else if (update.type === 'commentsUpdate') {
|
||||
console.log('Processing commentsUpdate event:', update)
|
||||
this.eventEmitter.emit('commentsUpdate', update.data)
|
||||
}
|
||||
else if (update.type === 'nodePanelPresence') {
|
||||
console.log('Processing nodePanelPresence event:', update)
|
||||
this.applyNodePanelPresenceUpdate(update.data as NodePanelPresenceEventData)
|
||||
}
|
||||
else if (update.type === 'syncRequest') {
|
||||
console.log('Received sync request from another user')
|
||||
// Only process if we are the leader
|
||||
if (this.isLeader) {
|
||||
console.log('Leader received sync request, triggering sync')
|
||||
this.eventEmitter.emit('syncRequest', {})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
socket.on('online_users', (data: { users: OnlineUser[]; leader?: string }) => {
|
||||
try {
|
||||
if (!data || !Array.isArray(data.users)) {
|
||||
console.warn('Invalid online_users data structure:', data)
|
||||
return
|
||||
}
|
||||
|
||||
const onlineUserIds = new Set(data.users.map((user: OnlineUser) => user.user_id))
|
||||
const onlineClientIds = new Set(
|
||||
data.users
|
||||
.map((user: OnlineUser) => user.sid)
|
||||
.filter((sid): sid is string => typeof sid === 'string' && sid.length > 0),
|
||||
)
|
||||
|
||||
// Remove cursors for offline users
|
||||
Object.keys(this.cursors).forEach((userId) => {
|
||||
if (!onlineUserIds.has(userId))
|
||||
delete this.cursors[userId]
|
||||
})
|
||||
|
||||
this.cleanupNodePanelPresence(onlineClientIds, onlineUserIds)
|
||||
|
||||
// Update leader information
|
||||
if (data.leader && typeof data.leader === 'string')
|
||||
this.leaderId = data.leader
|
||||
|
||||
this.eventEmitter.emit('onlineUsers', data.users)
|
||||
this.eventEmitter.emit('cursors', { ...this.cursors })
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Error processing online_users update:', error)
|
||||
}
|
||||
})
|
||||
|
||||
socket.on('status', (data: any) => {
|
||||
try {
|
||||
if (!data || typeof data.isLeader !== 'boolean') {
|
||||
console.warn('Invalid status data:', data)
|
||||
return
|
||||
}
|
||||
|
||||
const wasLeader = this.isLeader
|
||||
this.isLeader = data.isLeader
|
||||
|
||||
if (wasLeader !== this.isLeader)
|
||||
this.eventEmitter.emit('leaderChange', this.isLeader)
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Error processing status update:', error)
|
||||
}
|
||||
})
|
||||
|
||||
socket.on('status', (data: { isLeader: boolean }) => {
|
||||
if (this.isLeader !== data.isLeader) {
|
||||
this.isLeader = data.isLeader
|
||||
console.log(`Collaboration: I am now the ${this.isLeader ? 'Leader' : 'Follower'}.`)
|
||||
this.eventEmitter.emit('leaderChange', this.isLeader)
|
||||
}
|
||||
})
|
||||
|
||||
socket.on('status', (data: { isLeader: boolean }) => {
|
||||
if (this.isLeader !== data.isLeader) {
|
||||
this.isLeader = data.isLeader
|
||||
console.log(`Collaboration: I am now the ${this.isLeader ? 'Leader' : 'Follower'}.`)
|
||||
this.eventEmitter.emit('leaderChange', this.isLeader)
|
||||
}
|
||||
})
|
||||
|
||||
socket.on('connect', () => {
|
||||
console.log('WebSocket connected successfully')
|
||||
this.eventEmitter.emit('stateChange', { isConnected: true })
|
||||
})
|
||||
|
||||
socket.on('disconnect', (reason: string) => {
|
||||
console.log('WebSocket disconnected:', reason)
|
||||
this.cursors = {}
|
||||
this.isLeader = false
|
||||
this.leaderId = null
|
||||
this.eventEmitter.emit('stateChange', { isConnected: false })
|
||||
this.eventEmitter.emit('cursors', {})
|
||||
})
|
||||
|
||||
socket.on('connect_error', (error: any) => {
|
||||
console.error('WebSocket connection error:', error)
|
||||
this.eventEmitter.emit('stateChange', { isConnected: false, error: error.message })
|
||||
})
|
||||
|
||||
socket.on('error', (error: any) => {
|
||||
console.error('WebSocket error:', error)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export const collaborationManager = new CollaborationManager()
|
||||
@ -0,0 +1,36 @@
|
||||
import type { LoroDoc } from 'loro-crdt'
|
||||
import type { Socket } from 'socket.io-client'
|
||||
|
||||
export class CRDTProvider {
|
||||
private doc: LoroDoc
|
||||
private socket: Socket
|
||||
|
||||
constructor(socket: Socket, doc: LoroDoc) {
|
||||
this.socket = socket
|
||||
this.doc = doc
|
||||
this.setupEventListeners()
|
||||
}
|
||||
|
||||
private setupEventListeners(): void {
|
||||
this.doc.subscribe((event: any) => {
|
||||
if (event.by === 'local') {
|
||||
const update = this.doc.export({ mode: 'update' })
|
||||
this.socket.emit('graph_event', update)
|
||||
}
|
||||
})
|
||||
|
||||
this.socket.on('graph_update', (updateData: Uint8Array) => {
|
||||
try {
|
||||
const data = new Uint8Array(updateData)
|
||||
this.doc.import(data)
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Error importing graph update:', error)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
destroy(): void {
|
||||
this.socket.off('graph_update')
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,49 @@
|
||||
export type EventHandler<T = any> = (data: T) => void
|
||||
|
||||
export class EventEmitter {
|
||||
private events: Map<string, Set<EventHandler>> = new Map()
|
||||
|
||||
on<T = any>(event: string, handler: EventHandler<T>): () => void {
|
||||
if (!this.events.has(event))
|
||||
this.events.set(event, new Set())
|
||||
|
||||
this.events.get(event)!.add(handler)
|
||||
|
||||
return () => this.off(event, handler)
|
||||
}
|
||||
|
||||
off<T = any>(event: string, handler?: EventHandler<T>): void {
|
||||
if (!this.events.has(event)) return
|
||||
|
||||
const handlers = this.events.get(event)!
|
||||
if (handler)
|
||||
handlers.delete(handler)
|
||||
else
|
||||
handlers.clear()
|
||||
|
||||
if (handlers.size === 0)
|
||||
this.events.delete(event)
|
||||
}
|
||||
|
||||
emit<T = any>(event: string, data: T): void {
|
||||
if (!this.events.has(event)) return
|
||||
|
||||
const handlers = this.events.get(event)!
|
||||
handlers.forEach((handler) => {
|
||||
try {
|
||||
handler(data)
|
||||
}
|
||||
catch (error) {
|
||||
console.error(`Error in event handler for ${event}:`, error)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
removeAllListeners(): void {
|
||||
this.events.clear()
|
||||
}
|
||||
|
||||
getListenerCount(event: string): number {
|
||||
return this.events.get(event)?.size || 0
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,125 @@
|
||||
import type { Socket } from 'socket.io-client'
|
||||
import { io } from 'socket.io-client'
|
||||
import type { DebugInfo, WebSocketConfig } from '../types/websocket'
|
||||
|
||||
export class WebSocketClient {
|
||||
private connections: Map<string, Socket> = new Map()
|
||||
private connecting: Set<string> = new Set()
|
||||
private config: WebSocketConfig
|
||||
|
||||
constructor(config: WebSocketConfig = {}) {
|
||||
const inferUrl = () => {
|
||||
if (typeof window === 'undefined')
|
||||
return 'ws://localhost:5001'
|
||||
const scheme = window.location.protocol === 'https:' ? 'wss:' : 'ws:'
|
||||
return `${scheme}//${window.location.host}`
|
||||
}
|
||||
this.config = {
|
||||
url: config.url || process.env.NEXT_PUBLIC_SOCKET_URL || inferUrl(),
|
||||
transports: config.transports || ['websocket'],
|
||||
withCredentials: config.withCredentials !== false,
|
||||
...config,
|
||||
}
|
||||
}
|
||||
|
||||
connect(appId: string): Socket {
|
||||
const existingSocket = this.connections.get(appId)
|
||||
if (existingSocket?.connected)
|
||||
return existingSocket
|
||||
|
||||
if (this.connecting.has(appId)) {
|
||||
const pendingSocket = this.connections.get(appId)
|
||||
if (pendingSocket)
|
||||
return pendingSocket
|
||||
}
|
||||
|
||||
if (existingSocket && !existingSocket.connected) {
|
||||
existingSocket.disconnect()
|
||||
this.connections.delete(appId)
|
||||
}
|
||||
|
||||
this.connecting.add(appId)
|
||||
|
||||
const authToken = localStorage.getItem('console_token')
|
||||
const socket = io(this.config.url!, {
|
||||
path: '/socket.io',
|
||||
transports: this.config.transports,
|
||||
auth: { token: authToken },
|
||||
withCredentials: this.config.withCredentials,
|
||||
})
|
||||
|
||||
this.connections.set(appId, socket)
|
||||
this.setupBaseEventListeners(socket, appId)
|
||||
|
||||
return socket
|
||||
}
|
||||
|
||||
disconnect(appId?: string): void {
|
||||
if (appId) {
|
||||
const socket = this.connections.get(appId)
|
||||
if (socket) {
|
||||
socket.disconnect()
|
||||
this.connections.delete(appId)
|
||||
this.connecting.delete(appId)
|
||||
}
|
||||
}
|
||||
else {
|
||||
this.connections.forEach(socket => socket.disconnect())
|
||||
this.connections.clear()
|
||||
this.connecting.clear()
|
||||
}
|
||||
}
|
||||
|
||||
getSocket(appId: string): Socket | null {
|
||||
return this.connections.get(appId) || null
|
||||
}
|
||||
|
||||
isConnected(appId: string): boolean {
|
||||
return this.connections.get(appId)?.connected || false
|
||||
}
|
||||
|
||||
getConnectedApps(): string[] {
|
||||
const connectedApps: string[] = []
|
||||
this.connections.forEach((socket, appId) => {
|
||||
if (socket.connected)
|
||||
connectedApps.push(appId)
|
||||
})
|
||||
return connectedApps
|
||||
}
|
||||
|
||||
getDebugInfo(): DebugInfo {
|
||||
const info: DebugInfo = {}
|
||||
this.connections.forEach((socket, appId) => {
|
||||
info[appId] = {
|
||||
connected: socket.connected,
|
||||
connecting: this.connecting.has(appId),
|
||||
socketId: socket.id,
|
||||
}
|
||||
})
|
||||
return info
|
||||
}
|
||||
|
||||
private setupBaseEventListeners(socket: Socket, appId: string): void {
|
||||
socket.on('connect', () => {
|
||||
this.connecting.delete(appId)
|
||||
socket.emit('user_connect', { workflow_id: appId })
|
||||
})
|
||||
|
||||
socket.on('disconnect', () => {
|
||||
this.connecting.delete(appId)
|
||||
})
|
||||
|
||||
socket.on('connect_error', () => {
|
||||
this.connecting.delete(appId)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export const webSocketClient = new WebSocketClient()
|
||||
|
||||
export const fetchAppsOnlineUsers = async (appIds: string[]) => {
|
||||
const response = await fetch(`/api/online-users?${new URLSearchParams({
|
||||
app_ids: appIds.join(','),
|
||||
})}`)
|
||||
return response.json()
|
||||
}
|
||||
@ -0,0 +1,92 @@
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import type { ReactFlowInstance } from 'reactflow'
|
||||
import { collaborationManager } from '../core/collaboration-manager'
|
||||
import { CursorService } from '../services/cursor-service'
|
||||
import type { CollaborationState } from '../types/collaboration'
|
||||
|
||||
export function useCollaboration(appId: string, reactFlowStore?: any) {
|
||||
const [state, setState] = useState<Partial<CollaborationState & { isLeader: boolean }>>({
|
||||
isConnected: false,
|
||||
onlineUsers: [],
|
||||
cursors: {},
|
||||
nodePanelPresence: {},
|
||||
isLeader: false,
|
||||
})
|
||||
|
||||
const cursorServiceRef = useRef<CursorService | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (!appId) return
|
||||
|
||||
let connectionId: string | null = null
|
||||
|
||||
if (!cursorServiceRef.current)
|
||||
cursorServiceRef.current = new CursorService()
|
||||
|
||||
const initCollaboration = async () => {
|
||||
connectionId = await collaborationManager.connect(appId, reactFlowStore)
|
||||
setState((prev: any) => ({ ...prev, appId, isConnected: collaborationManager.isConnected() }))
|
||||
}
|
||||
|
||||
initCollaboration()
|
||||
|
||||
const unsubscribeStateChange = collaborationManager.onStateChange((newState: any) => {
|
||||
console.log('Collaboration state change:', newState)
|
||||
setState((prev: any) => ({ ...prev, ...newState }))
|
||||
})
|
||||
|
||||
const unsubscribeCursors = collaborationManager.onCursorUpdate((cursors: any) => {
|
||||
setState((prev: any) => ({ ...prev, cursors }))
|
||||
})
|
||||
|
||||
const unsubscribeUsers = collaborationManager.onOnlineUsersUpdate((users: any) => {
|
||||
console.log('Online users update:', users)
|
||||
setState((prev: any) => ({ ...prev, onlineUsers: users }))
|
||||
})
|
||||
|
||||
const unsubscribeNodePanelPresence = collaborationManager.onNodePanelPresenceUpdate((presence) => {
|
||||
setState((prev: any) => ({ ...prev, nodePanelPresence: presence }))
|
||||
})
|
||||
|
||||
const unsubscribeLeaderChange = collaborationManager.onLeaderChange((isLeader: boolean) => {
|
||||
console.log('Leader status changed:', isLeader)
|
||||
setState((prev: any) => ({ ...prev, isLeader }))
|
||||
})
|
||||
|
||||
return () => {
|
||||
unsubscribeStateChange()
|
||||
unsubscribeCursors()
|
||||
unsubscribeUsers()
|
||||
unsubscribeNodePanelPresence()
|
||||
unsubscribeLeaderChange()
|
||||
cursorServiceRef.current?.stopTracking()
|
||||
if (connectionId)
|
||||
collaborationManager.disconnect(connectionId)
|
||||
}
|
||||
}, [appId, reactFlowStore])
|
||||
|
||||
const startCursorTracking = (containerRef: React.RefObject<HTMLElement>, reactFlowInstance?: ReactFlowInstance) => {
|
||||
if (cursorServiceRef.current) {
|
||||
cursorServiceRef.current.startTracking(containerRef, (position) => {
|
||||
collaborationManager.emitCursorMove(position)
|
||||
}, reactFlowInstance)
|
||||
}
|
||||
}
|
||||
|
||||
const stopCursorTracking = () => {
|
||||
cursorServiceRef.current?.stopTracking()
|
||||
}
|
||||
|
||||
const result = {
|
||||
isConnected: state.isConnected || false,
|
||||
onlineUsers: state.onlineUsers || [],
|
||||
cursors: state.cursors || {},
|
||||
nodePanelPresence: state.nodePanelPresence || {},
|
||||
isLeader: state.isLeader || false,
|
||||
leaderId: collaborationManager.getLeaderId(),
|
||||
startCursorTracking,
|
||||
stopCursorTracking,
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
5
web/app/components/workflow/collaboration/index.ts
Normal file
5
web/app/components/workflow/collaboration/index.ts
Normal file
@ -0,0 +1,5 @@
|
||||
export { collaborationManager } from './core/collaboration-manager'
|
||||
export { webSocketClient, fetchAppsOnlineUsers } from './core/websocket-manager'
|
||||
export { CursorService } from './services/cursor-service'
|
||||
export { useCollaboration } from './hooks/use-collaboration'
|
||||
export * from './types'
|
||||
@ -0,0 +1,88 @@
|
||||
import type { RefObject } from 'react'
|
||||
import type { CursorPosition } from '../types/collaboration'
|
||||
import type { ReactFlowInstance } from 'reactflow'
|
||||
|
||||
const CURSOR_MIN_MOVE_DISTANCE = 10
|
||||
const CURSOR_THROTTLE_MS = 500
|
||||
|
||||
export class CursorService {
|
||||
private containerRef: RefObject<HTMLElement> | null = null
|
||||
private reactFlowInstance: ReactFlowInstance | null = null
|
||||
private isTracking = false
|
||||
private onCursorUpdate: ((cursors: Record<string, CursorPosition>) => void) | null = null
|
||||
private onEmitPosition: ((position: CursorPosition) => void) | null = null
|
||||
private lastEmitTime = 0
|
||||
private lastPosition: { x: number; y: number } | null = null
|
||||
|
||||
startTracking(
|
||||
containerRef: RefObject<HTMLElement>,
|
||||
onEmitPosition: (position: CursorPosition) => void,
|
||||
reactFlowInstance?: ReactFlowInstance,
|
||||
): void {
|
||||
if (this.isTracking) this.stopTracking()
|
||||
|
||||
this.containerRef = containerRef
|
||||
this.onEmitPosition = onEmitPosition
|
||||
this.reactFlowInstance = reactFlowInstance || null
|
||||
this.isTracking = true
|
||||
|
||||
if (containerRef.current)
|
||||
containerRef.current.addEventListener('mousemove', this.handleMouseMove)
|
||||
}
|
||||
|
||||
stopTracking(): void {
|
||||
if (this.containerRef?.current)
|
||||
this.containerRef.current.removeEventListener('mousemove', this.handleMouseMove)
|
||||
|
||||
this.containerRef = null
|
||||
this.reactFlowInstance = null
|
||||
this.onEmitPosition = null
|
||||
this.isTracking = false
|
||||
this.lastPosition = null
|
||||
}
|
||||
|
||||
setCursorUpdateHandler(handler: (cursors: Record<string, CursorPosition>) => void): void {
|
||||
this.onCursorUpdate = handler
|
||||
}
|
||||
|
||||
updateCursors(cursors: Record<string, CursorPosition>): void {
|
||||
if (this.onCursorUpdate)
|
||||
this.onCursorUpdate(cursors)
|
||||
}
|
||||
|
||||
private handleMouseMove = (event: MouseEvent): void => {
|
||||
if (!this.containerRef?.current || !this.onEmitPosition) return
|
||||
|
||||
const rect = this.containerRef.current.getBoundingClientRect()
|
||||
let x = event.clientX - rect.left
|
||||
let y = event.clientY - rect.top
|
||||
|
||||
// Transform coordinates to ReactFlow world coordinates if ReactFlow instance is available
|
||||
if (this.reactFlowInstance) {
|
||||
const viewport = this.reactFlowInstance.getViewport()
|
||||
// Convert screen coordinates to world coordinates
|
||||
// World coordinates = (screen coordinates - viewport translation) / zoom
|
||||
x = (x - viewport.x) / viewport.zoom
|
||||
y = (y - viewport.y) / viewport.zoom
|
||||
}
|
||||
|
||||
// Always emit cursor position (remove boundary check since world coordinates can be negative)
|
||||
const now = Date.now()
|
||||
const timeThrottled = now - this.lastEmitTime > CURSOR_THROTTLE_MS
|
||||
const minDistance = CURSOR_MIN_MOVE_DISTANCE / (this.reactFlowInstance?.getZoom() || 1)
|
||||
const distanceThrottled = !this.lastPosition
|
||||
|| (Math.abs(x - this.lastPosition.x) > minDistance)
|
||||
|| (Math.abs(y - this.lastPosition.y) > minDistance)
|
||||
|
||||
if (timeThrottled && distanceThrottled) {
|
||||
this.lastPosition = { x, y }
|
||||
this.lastEmitTime = now
|
||||
this.onEmitPosition({
|
||||
x,
|
||||
y,
|
||||
userId: '',
|
||||
timestamp: now,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,57 @@
|
||||
import type { Edge, Node } from '../../types'
|
||||
|
||||
export type OnlineUser = {
|
||||
user_id: string
|
||||
username: string
|
||||
avatar: string
|
||||
sid: string
|
||||
}
|
||||
|
||||
export type WorkflowOnlineUsers = {
|
||||
workflow_id: string
|
||||
users: OnlineUser[]
|
||||
}
|
||||
|
||||
export type OnlineUserListResponse = {
|
||||
data: WorkflowOnlineUsers[]
|
||||
}
|
||||
|
||||
export type CursorPosition = {
|
||||
x: number
|
||||
y: number
|
||||
userId: string
|
||||
timestamp: number
|
||||
}
|
||||
|
||||
export type NodePanelPresenceUser = {
|
||||
userId: string
|
||||
username: string
|
||||
avatar?: string | null
|
||||
}
|
||||
|
||||
export type NodePanelPresenceInfo = NodePanelPresenceUser & {
|
||||
clientId: string
|
||||
timestamp: number
|
||||
}
|
||||
|
||||
export type NodePanelPresenceMap = Record<string, Record<string, NodePanelPresenceInfo>>
|
||||
|
||||
export type CollaborationState = {
|
||||
appId: string
|
||||
isConnected: boolean
|
||||
onlineUsers: OnlineUser[]
|
||||
cursors: Record<string, CursorPosition>
|
||||
nodePanelPresence: NodePanelPresenceMap
|
||||
}
|
||||
|
||||
export type GraphSyncData = {
|
||||
nodes: Node[]
|
||||
edges: Edge[]
|
||||
}
|
||||
|
||||
export type CollaborationUpdate = {
|
||||
type: 'mouseMove' | 'graphUpdate' | 'userJoin' | 'userLeave'
|
||||
userId: string
|
||||
data: any
|
||||
timestamp: number
|
||||
}
|
||||
38
web/app/components/workflow/collaboration/types/events.ts
Normal file
38
web/app/components/workflow/collaboration/types/events.ts
Normal file
@ -0,0 +1,38 @@
|
||||
export type CollaborationEvent = {
|
||||
type: string
|
||||
data: any
|
||||
timestamp: number
|
||||
}
|
||||
|
||||
export type GraphUpdateEvent = {
|
||||
type: 'graph_update'
|
||||
data: Uint8Array
|
||||
} & CollaborationEvent
|
||||
|
||||
export type CursorMoveEvent = {
|
||||
type: 'cursor_move'
|
||||
data: {
|
||||
x: number
|
||||
y: number
|
||||
userId: string
|
||||
}
|
||||
} & CollaborationEvent
|
||||
|
||||
export type UserConnectEvent = {
|
||||
type: 'user_connect'
|
||||
data: {
|
||||
workflow_id: string
|
||||
}
|
||||
} & CollaborationEvent
|
||||
|
||||
export type OnlineUsersEvent = {
|
||||
type: 'online_users'
|
||||
data: {
|
||||
users: Array<{
|
||||
user_id: string
|
||||
username: string
|
||||
avatar: string
|
||||
sid: string
|
||||
}>
|
||||
}
|
||||
} & CollaborationEvent
|
||||
3
web/app/components/workflow/collaboration/types/index.ts
Normal file
3
web/app/components/workflow/collaboration/types/index.ts
Normal file
@ -0,0 +1,3 @@
|
||||
export * from './websocket'
|
||||
export * from './collaboration'
|
||||
export * from './events'
|
||||
16
web/app/components/workflow/collaboration/types/websocket.ts
Normal file
16
web/app/components/workflow/collaboration/types/websocket.ts
Normal file
@ -0,0 +1,16 @@
|
||||
export type WebSocketConfig = {
|
||||
url?: string
|
||||
token?: string
|
||||
transports?: string[]
|
||||
withCredentials?: boolean
|
||||
}
|
||||
|
||||
export type ConnectionInfo = {
|
||||
connected: boolean
|
||||
connecting: boolean
|
||||
socketId?: string
|
||||
}
|
||||
|
||||
export type DebugInfo = {
|
||||
[appId: string]: ConnectionInfo
|
||||
}
|
||||
@ -0,0 +1,12 @@
|
||||
/**
|
||||
* Generate a consistent color for a user based on their ID
|
||||
* Used for cursor colors and avatar backgrounds
|
||||
*/
|
||||
export const getUserColor = (id: string): string => {
|
||||
const colors = ['#155AEF', '#0BA5EC', '#444CE7', '#7839EE', '#4CA30D', '#0E9384', '#DD2590', '#FF4405', '#D92D20', '#F79009', '#828DAD']
|
||||
const hash = id.split('').reduce((a, b) => {
|
||||
a = ((a << 5) - a) + b.charCodeAt(0)
|
||||
return a & a
|
||||
}, 0)
|
||||
return colors[Math.abs(hash) % colors.length]
|
||||
}
|
||||
31
web/app/components/workflow/comment-manager.tsx
Normal file
31
web/app/components/workflow/comment-manager.tsx
Normal file
@ -0,0 +1,31 @@
|
||||
import { useEventListener } from 'ahooks'
|
||||
import { useWorkflowStore } from './store'
|
||||
import { useWorkflowComment } from './hooks/use-workflow-comment'
|
||||
|
||||
const CommentManager = () => {
|
||||
const workflowStore = useWorkflowStore()
|
||||
const { handleCreateComment } = useWorkflowComment()
|
||||
|
||||
useEventListener('click', (e) => {
|
||||
const { controlMode, mousePosition } = workflowStore.getState()
|
||||
|
||||
if (controlMode === 'comment') {
|
||||
const target = e.target as HTMLElement
|
||||
const isInDropdown = target.closest('[data-mention-dropdown]')
|
||||
const isInCommentInput = target.closest('[data-comment-input]')
|
||||
const isOnCanvasPane = target.closest('.react-flow__pane')
|
||||
|
||||
// Only when clicking on the React Flow canvas pane (background),
|
||||
// and not inside comment input or its dropdown
|
||||
if (!isInDropdown && !isInCommentInput && isOnCanvasPane) {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
handleCreateComment(mousePosition)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
export default CommentManager
|
||||
239
web/app/components/workflow/comment/comment-icon.tsx
Normal file
239
web/app/components/workflow/comment/comment-icon.tsx
Normal file
@ -0,0 +1,239 @@
|
||||
'use client'
|
||||
|
||||
import type { FC, PointerEvent as ReactPointerEvent } from 'react'
|
||||
import { memo, useCallback, useMemo, useRef, useState } from 'react'
|
||||
import { useReactFlow, useViewport } from 'reactflow'
|
||||
import { UserAvatarList } from '@/app/components/base/user-avatar-list'
|
||||
import CommentPreview from './comment-preview'
|
||||
import type { WorkflowCommentList } from '@/service/workflow-comment'
|
||||
|
||||
type CommentIconProps = {
|
||||
comment: WorkflowCommentList
|
||||
onClick: () => void
|
||||
isActive?: boolean
|
||||
onPositionUpdate?: (position: { x: number; y: number }) => void
|
||||
}
|
||||
|
||||
export const CommentIcon: FC<CommentIconProps> = memo(({ comment, onClick, isActive = false, onPositionUpdate }) => {
|
||||
const { flowToScreenPosition, screenToFlowPosition } = useReactFlow()
|
||||
const viewport = useViewport()
|
||||
const [showPreview, setShowPreview] = useState(false)
|
||||
const [dragPosition, setDragPosition] = useState<{ x: number; y: number } | null>(null)
|
||||
const [isDragging, setIsDragging] = useState(false)
|
||||
const dragStateRef = useRef<{
|
||||
offsetX: number
|
||||
offsetY: number
|
||||
startX: number
|
||||
startY: number
|
||||
hasMoved: boolean
|
||||
} | null>(null)
|
||||
|
||||
const screenPosition = useMemo(() => {
|
||||
return flowToScreenPosition({
|
||||
x: comment.position_x,
|
||||
y: comment.position_y,
|
||||
})
|
||||
}, [comment.position_x, comment.position_y, viewport.x, viewport.y, viewport.zoom, flowToScreenPosition])
|
||||
|
||||
const effectivePosition = dragPosition ?? screenPosition
|
||||
|
||||
const handlePointerDown = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
|
||||
if (event.button !== 0)
|
||||
return
|
||||
|
||||
event.stopPropagation()
|
||||
event.preventDefault()
|
||||
|
||||
dragStateRef.current = {
|
||||
offsetX: event.clientX - screenPosition.x,
|
||||
offsetY: event.clientY - screenPosition.y,
|
||||
startX: event.clientX,
|
||||
startY: event.clientY,
|
||||
hasMoved: false,
|
||||
}
|
||||
|
||||
setDragPosition(screenPosition)
|
||||
setIsDragging(false)
|
||||
|
||||
if (event.currentTarget.dataset.role !== 'comment-preview')
|
||||
setShowPreview(false)
|
||||
|
||||
if (event.currentTarget.setPointerCapture)
|
||||
event.currentTarget.setPointerCapture(event.pointerId)
|
||||
}, [screenPosition])
|
||||
|
||||
const handlePointerMove = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
|
||||
const dragState = dragStateRef.current
|
||||
if (!dragState)
|
||||
return
|
||||
|
||||
event.stopPropagation()
|
||||
event.preventDefault()
|
||||
|
||||
const nextX = event.clientX - dragState.offsetX
|
||||
const nextY = event.clientY - dragState.offsetY
|
||||
|
||||
if (!dragState.hasMoved) {
|
||||
const distance = Math.hypot(event.clientX - dragState.startX, event.clientY - dragState.startY)
|
||||
if (distance > 4) {
|
||||
dragState.hasMoved = true
|
||||
setIsDragging(true)
|
||||
}
|
||||
}
|
||||
|
||||
setDragPosition({ x: nextX, y: nextY })
|
||||
}, [])
|
||||
|
||||
const finishDrag = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
|
||||
const dragState = dragStateRef.current
|
||||
if (!dragState)
|
||||
return false
|
||||
|
||||
if (event.currentTarget.hasPointerCapture?.(event.pointerId))
|
||||
event.currentTarget.releasePointerCapture(event.pointerId)
|
||||
|
||||
dragStateRef.current = null
|
||||
setDragPosition(null)
|
||||
setIsDragging(false)
|
||||
return dragState.hasMoved
|
||||
}, [])
|
||||
|
||||
const handlePointerUp = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
|
||||
event.stopPropagation()
|
||||
event.preventDefault()
|
||||
|
||||
const finalScreenPosition = dragPosition ?? screenPosition
|
||||
const didDrag = finishDrag(event)
|
||||
|
||||
setShowPreview(false)
|
||||
|
||||
if (didDrag) {
|
||||
if (onPositionUpdate) {
|
||||
const flowPosition = screenToFlowPosition({
|
||||
x: finalScreenPosition.x,
|
||||
y: finalScreenPosition.y,
|
||||
})
|
||||
onPositionUpdate(flowPosition)
|
||||
}
|
||||
}
|
||||
else if (!isActive) {
|
||||
onClick()
|
||||
}
|
||||
}, [dragPosition, finishDrag, isActive, onClick, onPositionUpdate, screenPosition, screenToFlowPosition])
|
||||
|
||||
const handlePointerCancel = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
|
||||
event.stopPropagation()
|
||||
event.preventDefault()
|
||||
finishDrag(event)
|
||||
}, [finishDrag])
|
||||
|
||||
const handleMouseEnter = useCallback(() => {
|
||||
if (isActive || isDragging)
|
||||
return
|
||||
setShowPreview(true)
|
||||
}, [isActive, isDragging])
|
||||
|
||||
const handleMouseLeave = useCallback(() => {
|
||||
setShowPreview(false)
|
||||
}, [])
|
||||
|
||||
const participants = useMemo(() => {
|
||||
const list = comment.participants ?? []
|
||||
const author = comment.created_by_account
|
||||
if (!author)
|
||||
return [...list]
|
||||
const rest = list.filter(user => user.id !== author.id)
|
||||
return [author, ...rest]
|
||||
}, [comment.created_by_account, comment.participants])
|
||||
|
||||
// Calculate dynamic width based on number of participants
|
||||
const participantCount = participants.length
|
||||
const maxVisible = Math.min(3, participantCount)
|
||||
const showCount = participantCount > 3
|
||||
const avatarSize = 24
|
||||
const avatarSpacing = 4 // -space-x-1 is about 4px overlap
|
||||
|
||||
// Width calculation: first avatar + (additional avatars * (size - spacing)) + padding
|
||||
const dynamicWidth = Math.max(40, // minimum width
|
||||
8 + avatarSize + Math.max(0, (showCount ? 2 : maxVisible - 1)) * (avatarSize - avatarSpacing) + 8,
|
||||
)
|
||||
|
||||
const pointerEventHandlers = useMemo(() => ({
|
||||
onPointerDown: handlePointerDown,
|
||||
onPointerMove: handlePointerMove,
|
||||
onPointerUp: handlePointerUp,
|
||||
onPointerCancel: handlePointerCancel,
|
||||
}), [handlePointerCancel, handlePointerDown, handlePointerMove, handlePointerUp])
|
||||
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
className="absolute z-10"
|
||||
style={{
|
||||
left: effectivePosition.x,
|
||||
top: effectivePosition.y,
|
||||
transform: 'translate(-50%, -50%)',
|
||||
}}
|
||||
data-role='comment-marker'
|
||||
{...pointerEventHandlers}
|
||||
>
|
||||
<div
|
||||
className={isActive ? (isDragging ? 'cursor-grabbing' : '') : isDragging ? 'cursor-grabbing' : 'cursor-pointer'}
|
||||
onMouseEnter={handleMouseEnter}
|
||||
onMouseLeave={handleMouseLeave}
|
||||
>
|
||||
<div
|
||||
className={'relative h-10 overflow-hidden rounded-br-full rounded-tl-full rounded-tr-full'}
|
||||
style={{ width: dynamicWidth }}
|
||||
>
|
||||
<div className={`absolute inset-[6px] overflow-hidden rounded-br-full rounded-tl-full rounded-tr-full border ${
|
||||
isActive
|
||||
? 'border-2 border-primary-500 bg-components-panel-bg'
|
||||
: 'border-components-panel-border bg-components-panel-bg'
|
||||
}`}>
|
||||
<div className="flex h-full w-full items-center justify-center px-1">
|
||||
<UserAvatarList
|
||||
users={participants}
|
||||
maxVisible={3}
|
||||
size={24}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Preview panel */}
|
||||
{showPreview && !isActive && (
|
||||
<div
|
||||
className="absolute z-20"
|
||||
style={{
|
||||
left: (dragPosition ?? screenPosition).x - dynamicWidth / 2,
|
||||
top: (dragPosition ?? screenPosition).y + 20,
|
||||
transform: 'translateY(-100%)',
|
||||
}}
|
||||
data-role='comment-preview'
|
||||
{...pointerEventHandlers}
|
||||
onMouseEnter={() => setShowPreview(true)}
|
||||
onMouseLeave={() => setShowPreview(false)}
|
||||
>
|
||||
<CommentPreview comment={comment} onClick={() => {
|
||||
setShowPreview(false)
|
||||
onClick()
|
||||
}} />
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}, (prevProps, nextProps) => {
|
||||
return (
|
||||
prevProps.comment.id === nextProps.comment.id
|
||||
&& prevProps.comment.position_x === nextProps.comment.position_x
|
||||
&& prevProps.comment.position_y === nextProps.comment.position_y
|
||||
&& prevProps.onClick === nextProps.onClick
|
||||
&& prevProps.isActive === nextProps.isActive
|
||||
&& prevProps.onPositionUpdate === nextProps.onPositionUpdate
|
||||
)
|
||||
})
|
||||
|
||||
CommentIcon.displayName = 'CommentIcon'
|
||||
87
web/app/components/workflow/comment/comment-input.tsx
Normal file
87
web/app/components/workflow/comment/comment-input.tsx
Normal file
@ -0,0 +1,87 @@
|
||||
import type { FC } from 'react'
|
||||
import { memo, useCallback, useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { MentionInput } from './mention-input'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
type CommentInputProps = {
|
||||
position: { x: number; y: number }
|
||||
onSubmit: (content: string, mentionedUserIds: string[]) => void
|
||||
onCancel: () => void
|
||||
}
|
||||
|
||||
export const CommentInput: FC<CommentInputProps> = memo(({ position, onSubmit, onCancel }) => {
|
||||
const [content, setContent] = useState('')
|
||||
const { t } = useTranslation()
|
||||
const { userProfile } = useAppContext()
|
||||
|
||||
useEffect(() => {
|
||||
const handleGlobalKeyDown = (e: KeyboardEvent) => {
|
||||
if (e.key === 'Escape') {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
onCancel()
|
||||
}
|
||||
}
|
||||
|
||||
document.addEventListener('keydown', handleGlobalKeyDown, true)
|
||||
return () => {
|
||||
document.removeEventListener('keydown', handleGlobalKeyDown, true)
|
||||
}
|
||||
}, [onCancel])
|
||||
|
||||
const handleMentionSubmit = useCallback((content: string, mentionedUserIds: string[]) => {
|
||||
onSubmit(content, mentionedUserIds)
|
||||
setContent('')
|
||||
}, [onSubmit])
|
||||
|
||||
return (
|
||||
<div
|
||||
className="absolute z-50 w-96"
|
||||
style={{
|
||||
left: position.x,
|
||||
top: position.y,
|
||||
}}
|
||||
data-comment-input
|
||||
>
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="relative shrink-0">
|
||||
<div className="relative h-8 w-8 overflow-hidden rounded-br-full rounded-tl-full rounded-tr-full bg-primary-500">
|
||||
<div className="absolute inset-[2px] overflow-hidden rounded-br-full rounded-tl-full rounded-tr-full bg-white">
|
||||
<div className="flex h-full w-full items-center justify-center">
|
||||
<div className="h-6 w-6 overflow-hidden rounded-full">
|
||||
<Avatar
|
||||
avatar={userProfile.avatar_url}
|
||||
name={userProfile.name}
|
||||
size={24}
|
||||
className="h-full w-full"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
'relative z-10 flex-1 rounded-xl border border-components-chat-input-border bg-components-panel-bg-blur pb-[4px] shadow-md',
|
||||
)}
|
||||
>
|
||||
<div className='relative px-[9px] pt-[4px]'>
|
||||
<MentionInput
|
||||
value={content}
|
||||
onChange={setContent}
|
||||
onSubmit={handleMentionSubmit}
|
||||
placeholder={t('workflow.comments.placeholder.add')}
|
||||
autoFocus
|
||||
className="relative"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
CommentInput.displayName = 'CommentInput'
|
||||
52
web/app/components/workflow/comment/comment-preview.tsx
Normal file
52
web/app/components/workflow/comment/comment-preview.tsx
Normal file
@ -0,0 +1,52 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import { memo, useMemo } from 'react'
|
||||
import { UserAvatarList } from '@/app/components/base/user-avatar-list'
|
||||
import type { WorkflowCommentList } from '@/service/workflow-comment'
|
||||
import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now'
|
||||
|
||||
type CommentPreviewProps = {
|
||||
comment: WorkflowCommentList
|
||||
onClick?: () => void
|
||||
}
|
||||
|
||||
const CommentPreview: FC<CommentPreviewProps> = ({ comment, onClick }) => {
|
||||
const { formatTimeFromNow } = useFormatTimeFromNow()
|
||||
const participants = useMemo(() => {
|
||||
const list = comment.participants ?? []
|
||||
const author = comment.created_by_account
|
||||
if (!author)
|
||||
return [...list]
|
||||
const rest = list.filter(user => user.id !== author.id)
|
||||
return [author, ...rest]
|
||||
}, [comment.created_by_account, comment.participants])
|
||||
|
||||
return (
|
||||
<div
|
||||
className="w-80 cursor-pointer rounded-br-xl rounded-tl-xl rounded-tr-xl border border-components-panel-border bg-components-panel-bg p-4 shadow-lg transition-colors hover:bg-components-panel-on-panel-item-bg-hover"
|
||||
onClick={onClick}
|
||||
>
|
||||
<div className="mb-3 flex items-center justify-between">
|
||||
<UserAvatarList
|
||||
users={participants}
|
||||
maxVisible={3}
|
||||
size={24}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="mb-2 flex items-start">
|
||||
<div className="flex min-w-0 items-center gap-2">
|
||||
<div className="system-sm-medium truncate text-text-primary">{comment.created_by_account.name}</div>
|
||||
<div className="system-2xs-regular shrink-0 text-text-tertiary">
|
||||
{formatTimeFromNow(comment.updated_at * 1000)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="system-sm-regular break-words text-text-secondary">{comment.content}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(CommentPreview)
|
||||
28
web/app/components/workflow/comment/cursor.tsx
Normal file
28
web/app/components/workflow/comment/cursor.tsx
Normal file
@ -0,0 +1,28 @@
|
||||
import type { FC } from 'react'
|
||||
import { memo } from 'react'
|
||||
import { useStore } from '../store'
|
||||
import { ControlMode } from '../types'
|
||||
import { Comment } from '@/app/components/base/icons/src/public/other'
|
||||
|
||||
export const CommentCursor: FC = memo(() => {
|
||||
const controlMode = useStore(s => s.controlMode)
|
||||
const mousePosition = useStore(s => s.mousePosition)
|
||||
|
||||
if (controlMode !== ControlMode.Comment)
|
||||
return null
|
||||
|
||||
return (
|
||||
<div
|
||||
className="pointer-events-none absolute z-50 flex h-6 w-6 items-center justify-center"
|
||||
style={{
|
||||
left: mousePosition.elementX,
|
||||
top: mousePosition.elementY,
|
||||
transform: 'translate(-50%, -50%)',
|
||||
}}
|
||||
>
|
||||
<Comment className="text-text-primary" />
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
CommentCursor.displayName = 'CommentCursor'
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user