Merge branch 'feat/mcp-06-18' into deploy/dev

This commit is contained in:
Novice
2025-10-20 10:58:27 +08:00
463 changed files with 10928 additions and 7216 deletions

View File

@ -22,6 +22,7 @@ from libs.helper import RateLimiter, TokenManager
from libs.passport import PassportService
from libs.password import compare_password, hash_password, valid_password
from libs.rsa import generate_key_pair
from libs.token import generate_csrf_token
from models.account import (
Account,
AccountIntegrate,
@ -76,6 +77,7 @@ logger = logging.getLogger(__name__)
class TokenPair(BaseModel):
access_token: str
refresh_token: str
csrf_token: str
REFRESH_TOKEN_PREFIX = "refresh_token:"
@ -403,10 +405,11 @@ class AccountService:
access_token = AccountService.get_account_jwt_token(account=account)
refresh_token = _generate_refresh_token()
csrf_token = generate_csrf_token(account.id)
AccountService._store_refresh_token(refresh_token, account.id)
return TokenPair(access_token=access_token, refresh_token=refresh_token)
return TokenPair(access_token=access_token, refresh_token=refresh_token, csrf_token=csrf_token)
@staticmethod
def logout(*, account: Account):
@ -431,8 +434,9 @@ class AccountService:
AccountService._delete_refresh_token(refresh_token, account.id)
AccountService._store_refresh_token(new_refresh_token, account.id)
csrf_token = generate_csrf_token(account.id)
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token, csrf_token=csrf_token)
@staticmethod
def load_logged_in_account(*, account_id: str):

View File

@ -10,7 +10,7 @@ from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from models import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought

View File

@ -18,7 +18,7 @@ from events.app_event import app_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account
from models import Account
from models.model import App, AppMode, AppModelConfig, Site
from models.tools import ApiToolProvider
from services.billing_service import BillingService

View File

@ -7,7 +7,7 @@ from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fix
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import RateLimiter
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models import Account, TenantAccountJoin, TenantAccountRole
class BillingService:

View File

@ -22,7 +22,7 @@ from core.memory.errors import MemorySyncTimeoutError
from core.model_runtime.entities.message_entities import PromptMessage
from core.variables.segments import VersionedMemoryValue
from core.workflow.constants import MEMORY_BLOCK_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.runtime.variable_pool import VariablePool
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import App, CreatorUserRole
@ -38,9 +38,7 @@ logger = logging.getLogger(__name__)
class ChatflowMemoryService:
@staticmethod
def get_persistent_memories(
app: App,
created_by: MemoryCreatedBy,
version: int | None = None
app: App, created_by: MemoryCreatedBy, version: int | None = None
) -> Sequence[MemoryBlock]:
if created_by.account_id:
created_by_role = CreatorUserRole.ACCOUNT
@ -50,15 +48,20 @@ class ChatflowMemoryService:
created_by_id = created_by.id
if version is None:
# If version not specified, get the latest version
stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == None,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id,
stmt = (
select(ChatflowMemoryVariable)
.distinct(ChatflowMemoryVariable.memory_id)
.where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == None,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id,
)
)
).order_by(ChatflowMemoryVariable.version.desc())
.order_by(ChatflowMemoryVariable.version.desc())
)
else:
stmt = select(ChatflowMemoryVariable).where(
and_(
@ -67,7 +70,7 @@ class ChatflowMemoryService:
ChatflowMemoryVariable.conversation_id == None,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id,
ChatflowMemoryVariable.version == version
ChatflowMemoryVariable.version == version,
)
)
with Session(db.engine) as session:
@ -76,27 +79,29 @@ class ChatflowMemoryService:
@staticmethod
def get_session_memories(
app: App,
created_by: MemoryCreatedBy,
conversation_id: str,
version: int | None = None
app: App, created_by: MemoryCreatedBy, conversation_id: str, version: int | None = None
) -> Sequence[MemoryBlock]:
if version is None:
# If version not specified, get the latest version
stmt = select(ChatflowMemoryVariable).distinct(ChatflowMemoryVariable.memory_id).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == conversation_id
stmt = (
select(ChatflowMemoryVariable)
.distinct(ChatflowMemoryVariable.memory_id)
.where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == conversation_id,
)
)
).order_by(ChatflowMemoryVariable.version.desc())
.order_by(ChatflowMemoryVariable.version.desc())
)
else:
stmt = select(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.conversation_id == conversation_id,
ChatflowMemoryVariable.version == version
ChatflowMemoryVariable.version == version,
)
)
with Session(db.engine) as session:
@ -123,10 +128,7 @@ class ChatflowMemoryService:
node_id=memory.node_id,
conversation_id=memory.conversation_id,
name=memory.spec.name,
value=MemoryValueData(
value=memory.value,
edited_by_user=memory.edited_by_user
).model_dump_json(),
value=MemoryValueData(value=memory.value, edited_by_user=memory.edited_by_user).model_dump_json(),
term=memory.spec.term,
scope=memory.spec.scope,
version=memory.version, # Use version from MemoryBlock directly
@ -141,21 +143,22 @@ class ChatflowMemoryService:
draft_var_service = WorkflowDraftVariableService(session)
memory_selector = memory.spec.id if not memory.node_id else f"{memory.node_id}.{memory.spec.id}"
existing_vars = draft_var_service.get_draft_variables_by_selectors(
app_id=memory.app_id,
selectors=[[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_selector]]
app_id=memory.app_id, selectors=[[MEMORY_BLOCK_VARIABLE_NODE_ID, memory_selector]]
)
if existing_vars:
draft_var = existing_vars[0]
draft_var.value = VersionedMemoryValue.model_validate_json(draft_var.value)\
.add_version(memory.value)\
draft_var.value = (
VersionedMemoryValue.model_validate_json(draft_var.value)
.add_version(memory.value)
.model_dump_json()
)
else:
draft_var = WorkflowDraftVariable.new_memory_block_variable(
app_id=memory.app_id,
memory_id=memory.spec.id,
name=memory.spec.name,
value=VersionedMemoryValue().add_version(memory.value),
description=memory.spec.description
description=memory.spec.description,
)
session.add(draft_var)
session.commit()
@ -163,15 +166,19 @@ class ChatflowMemoryService:
@staticmethod
def get_memories_by_specs(
memory_block_specs: Sequence[MemoryBlockSpec],
tenant_id: str, app_id: str,
tenant_id: str,
app_id: str,
created_by: MemoryCreatedBy,
conversation_id: Optional[str],
node_id: Optional[str],
is_draft: bool
is_draft: bool,
) -> Sequence[MemoryBlock]:
return [ChatflowMemoryService.get_memory_by_spec(
spec, tenant_id, app_id, created_by, conversation_id, node_id, is_draft
) for spec in memory_block_specs]
return [
ChatflowMemoryService.get_memory_by_spec(
spec, tenant_id, app_id, created_by, conversation_id, node_id, is_draft
)
for spec in memory_block_specs
]
@staticmethod
def get_memory_by_spec(
@ -181,17 +188,17 @@ class ChatflowMemoryService:
created_by: MemoryCreatedBy,
conversation_id: Optional[str],
node_id: Optional[str],
is_draft: bool
is_draft: bool,
) -> MemoryBlock:
with Session(db.engine) as session:
if is_draft:
draft_var_service = WorkflowDraftVariableService(session)
selector = [MEMORY_BLOCK_VARIABLE_NODE_ID, f"{spec.id}.{node_id}"] \
if node_id else [MEMORY_BLOCK_VARIABLE_NODE_ID, spec.id]
draft_vars = draft_var_service.get_draft_variables_by_selectors(
app_id=app_id,
selectors=[selector]
selector = (
[MEMORY_BLOCK_VARIABLE_NODE_ID, f"{spec.id}.{node_id}"]
if node_id
else [MEMORY_BLOCK_VARIABLE_NODE_ID, spec.id]
)
draft_vars = draft_var_service.get_draft_variables_by_selectors(app_id=app_id, selectors=[selector])
if draft_vars:
draft_var = draft_vars[0]
return MemoryBlock(
@ -204,17 +211,21 @@ class ChatflowMemoryService:
created_by=created_by,
version=1,
)
stmt = select(ChatflowMemoryVariable).where(
and_(
ChatflowMemoryVariable.memory_id == spec.id,
ChatflowMemoryVariable.tenant_id == tenant_id,
ChatflowMemoryVariable.app_id == app_id,
ChatflowMemoryVariable.node_id ==
(node_id if spec.scope == MemoryScope.NODE else None),
ChatflowMemoryVariable.conversation_id ==
(conversation_id if spec.term == MemoryTerm.SESSION else None),
stmt = (
select(ChatflowMemoryVariable)
.where(
and_(
ChatflowMemoryVariable.memory_id == spec.id,
ChatflowMemoryVariable.tenant_id == tenant_id,
ChatflowMemoryVariable.app_id == app_id,
ChatflowMemoryVariable.node_id == (node_id if spec.scope == MemoryScope.NODE else None),
ChatflowMemoryVariable.conversation_id
== (conversation_id if spec.term == MemoryTerm.SESSION else None),
)
)
).order_by(ChatflowMemoryVariable.version.desc()).limit(1)
.order_by(ChatflowMemoryVariable.version.desc())
.limit(1)
)
result = session.execute(stmt).scalar()
if result:
memory_value_data = MemoryValueData.model_validate_json(result.value)
@ -246,7 +257,7 @@ class ChatflowMemoryService:
conversation_id: str,
variable_pool: VariablePool,
created_by: MemoryCreatedBy,
is_draft: bool
is_draft: bool,
):
visible_messages = ChatflowHistoryService.get_visible_chat_history(
conversation_id=conversation_id,
@ -294,7 +305,7 @@ class ChatflowMemoryService:
conversation_id=conversation_id,
app_id=workflow.app_id,
visible_messages=visible_messages,
variable_pool=variable_pool
variable_pool=variable_pool,
)
@staticmethod
@ -306,7 +317,7 @@ class ChatflowMemoryService:
conversation_id: str,
memory_block_spec: MemoryBlockSpec,
variable_pool: VariablePool,
is_draft: bool
is_draft: bool,
) -> bool:
visible_messages = ChatflowHistoryService.get_visible_chat_history(
conversation_id=conversation_id,
@ -323,10 +334,7 @@ class ChatflowMemoryService:
is_draft=is_draft,
created_by=created_by,
)
if not ChatflowMemoryService._should_update_memory(
memory_block=memory_block,
visible_history=visible_messages
):
if not ChatflowMemoryService._should_update_memory(memory_block=memory_block, visible_history=visible_messages):
return False
if memory_block_spec.schedule_mode == MemoryScheduleMode.SYNC:
@ -336,7 +344,7 @@ class ChatflowMemoryService:
memory_block=memory_block,
variable_pool=variable_pool,
is_draft=is_draft,
conversation_id=conversation_id
conversation_id=conversation_id,
)
else:
# Node-level async: execute asynchronously
@ -345,7 +353,7 @@ class ChatflowMemoryService:
visible_messages=visible_messages,
variable_pool=variable_pool,
is_draft=is_draft,
conversation_id=conversation_id
conversation_id=conversation_id,
)
return True
@ -355,7 +363,8 @@ class ChatflowMemoryService:
memory_blocks = workflow.memory_blocks
sync_memory_blocks = [
block for block in memory_blocks
block
for block in memory_blocks
if block.scope == MemoryScope.APP and block.schedule_mode == MemoryScheduleMode.SYNC
]
@ -378,16 +387,11 @@ class ChatflowMemoryService:
time.sleep(retry_interval)
else:
# Maximum retry attempts reached, raise exception
raise MemorySyncTimeoutError(
app_id=workflow.app_id,
conversation_id=conversation_id
)
raise MemorySyncTimeoutError(app_id=workflow.app_id, conversation_id=conversation_id)
@staticmethod
def _convert_to_memory_blocks(
app: App,
created_by: MemoryCreatedBy,
raw_results: Sequence[ChatflowMemoryVariable]
app: App, created_by: MemoryCreatedBy, raw_results: Sequence[ChatflowMemoryVariable]
) -> Sequence[MemoryBlock]:
workflow = WorkflowService().get_published_workflow(app)
if not workflow:
@ -395,8 +399,7 @@ class ChatflowMemoryService:
results = []
for chatflow_memory_variable in raw_results:
spec = next(
(spec for spec in workflow.memory_blocks if spec.id == chatflow_memory_variable.memory_id),
None
(spec for spec in workflow.memory_blocks if spec.id == chatflow_memory_variable.memory_id), None
)
if spec and chatflow_memory_variable.app_id:
memory_value_data = MemoryValueData.model_validate_json(chatflow_memory_variable.value)
@ -416,10 +419,7 @@ class ChatflowMemoryService:
return results
@staticmethod
def _should_update_memory(
memory_block: MemoryBlock,
visible_history: Sequence[PromptMessage]
) -> bool:
def _should_update_memory(memory_block: MemoryBlock, visible_history: Sequence[PromptMessage]) -> bool:
return len(visible_history) >= memory_block.spec.update_turns
@staticmethod
@ -428,16 +428,16 @@ class ChatflowMemoryService:
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
conversation_id: str,
is_draft: bool
is_draft: bool,
):
thread = threading.Thread(
target=ChatflowMemoryService._perform_memory_update,
kwargs={
'memory_block': block,
'visible_messages': visible_messages,
'variable_pool': variable_pool,
'is_draft': is_draft,
'conversation_id': conversation_id
"memory_block": block,
"visible_messages": visible_messages,
"variable_pool": variable_pool,
"is_draft": is_draft,
"conversation_id": conversation_id,
},
)
thread.start()
@ -449,18 +449,18 @@ class ChatflowMemoryService:
conversation_id: str,
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
is_draft: bool
is_draft: bool,
):
"""Submit sync memory batch update task"""
thread = threading.Thread(
target=ChatflowMemoryService._batch_update_sync_memory,
kwargs={
'sync_blocks': sync_blocks,
'app_id': app_id,
'conversation_id': conversation_id,
'visible_messages': visible_messages,
'variable_pool': variable_pool,
'is_draft': is_draft
"sync_blocks": sync_blocks,
"app_id": app_id,
"conversation_id": conversation_id,
"visible_messages": visible_messages,
"variable_pool": variable_pool,
"is_draft": is_draft,
},
)
thread.start()
@ -472,7 +472,7 @@ class ChatflowMemoryService:
conversation_id: str,
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
is_draft: bool
is_draft: bool,
):
try:
lock_key = _get_memory_sync_lock_key(app_id, conversation_id)
@ -482,11 +482,11 @@ class ChatflowMemoryService:
thread = threading.Thread(
target=ChatflowMemoryService._perform_memory_update,
kwargs={
'memory_block': block,
'visible_messages': visible_messages,
'variable_pool': variable_pool,
'is_draft': is_draft,
'conversation_id': conversation_id,
"memory_block": block,
"visible_messages": visible_messages,
"variable_pool": variable_pool,
"is_draft": is_draft,
"conversation_id": conversation_id,
},
)
threads.append(thread)
@ -503,14 +503,14 @@ class ChatflowMemoryService:
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
conversation_id: str,
is_draft: bool
is_draft: bool,
):
ChatflowMemoryService._perform_memory_update(
memory_block=memory_block,
visible_messages=visible_messages,
variable_pool=variable_pool,
is_draft=is_draft,
conversation_id=conversation_id
conversation_id=conversation_id,
)
@staticmethod
@ -519,18 +519,18 @@ class ChatflowMemoryService:
visible_messages: Sequence[PromptMessage],
variable_pool: VariablePool,
conversation_id: str,
is_draft: bool = False
is_draft: bool = False,
):
thread = threading.Thread(
target=ChatflowMemoryService._perform_memory_update,
kwargs={
'memory_block': memory_block,
'visible_messages': visible_messages,
'variable_pool': variable_pool,
'is_draft': is_draft,
'conversation_id': conversation_id,
"memory_block": memory_block,
"visible_messages": visible_messages,
"variable_pool": variable_pool,
"is_draft": is_draft,
"conversation_id": conversation_id,
},
daemon=True
daemon=True,
)
thread.start()
@ -540,7 +540,7 @@ class ChatflowMemoryService:
variable_pool: VariablePool,
conversation_id: str,
visible_messages: Sequence[PromptMessage],
is_draft: bool
is_draft: bool,
):
updated_value = LLMGenerator.update_memory_block(
tenant_id=memory_block.tenant_id,
@ -567,7 +567,7 @@ class ChatflowMemoryService:
node_id=memory_block.node_id,
new_visible_count=memory_block.spec.preserved_turns,
app_id=memory_block.app_id,
tenant_id=memory_block.tenant_id
tenant_id=memory_block.tenant_id,
)
@staticmethod
@ -594,7 +594,7 @@ class ChatflowMemoryService:
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.memory_id == memory_id,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id
ChatflowMemoryVariable.created_by == created_by_id,
)
)
session.execute(stmt)
@ -615,7 +615,7 @@ class ChatflowMemoryService:
ChatflowMemoryVariable.tenant_id == app.tenant_id,
ChatflowMemoryVariable.app_id == app.id,
ChatflowMemoryVariable.created_by_role == created_by_role,
ChatflowMemoryVariable.created_by == created_by_id
ChatflowMemoryVariable.created_by == created_by_id,
)
)
session.execute(stmt)
@ -623,38 +623,28 @@ class ChatflowMemoryService:
@staticmethod
def get_persistent_memories_with_conversation(
app: App,
created_by: MemoryCreatedBy,
conversation_id: str,
version: int | None = None
app: App, created_by: MemoryCreatedBy, conversation_id: str, version: int | None = None
) -> Sequence[MemoryBlockWithConversation]:
"""Get persistent memories with conversation metadata (always None for persistent)"""
memory_blocks = ChatflowMemoryService.get_persistent_memories(app, created_by, version)
return [
MemoryBlockWithConversation.from_memory_block(
block,
ChatflowHistoryService.get_conversation_metadata(
app.tenant_id, app.id, conversation_id, block.node_id
)
ChatflowHistoryService.get_conversation_metadata(app.tenant_id, app.id, conversation_id, block.node_id),
)
for block in memory_blocks
]
@staticmethod
def get_session_memories_with_conversation(
app: App,
created_by: MemoryCreatedBy,
conversation_id: str,
version: int | None = None
app: App, created_by: MemoryCreatedBy, conversation_id: str, version: int | None = None
) -> Sequence[MemoryBlockWithConversation]:
"""Get session memories with conversation metadata"""
memory_blocks = ChatflowMemoryService.get_session_memories(app, created_by, conversation_id, version)
return [
MemoryBlockWithConversation.from_memory_block(
block,
ChatflowHistoryService.get_conversation_metadata(
app.tenant_id, app.id, conversation_id, block.node_id
)
ChatflowHistoryService.get_conversation_metadata(app.tenant_id, app.id, conversation_id, block.node_id),
)
for block in memory_blocks
]

View File

@ -14,8 +14,7 @@ from extensions.ext_database import db
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import ConversationVariable
from models.account import Account
from models import Account, ConversationVariable
from models.model import App, Conversation, EndUser, Message
from services.errors.conversation import (
ConversationNotExistsError,

View File

@ -29,7 +29,7 @@ from extensions.ext_redis import redis_client
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account, TenantAccountRole
from models import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
ChildChunk,

View File

@ -17,7 +17,6 @@ from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
from models.provider_ids import DatasourceProviderID
from services.plugin.plugin_service import PluginService
@ -25,6 +24,16 @@ from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
def get_current_user():
from libs.login import current_user
from models.account import Account
from models.model import EndUser
if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore
raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}")
return current_user
class DatasourceProviderService:
"""
Model Provider Service
@ -93,8 +102,6 @@ class DatasourceProviderService:
"""
get credential by id
"""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
if credential_id:
datasource_provider = (
@ -111,6 +118,7 @@ class DatasourceProviderService:
return {}
# refresh the credentials
if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()):
current_user = get_current_user()
decrypted_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
@ -159,8 +167,6 @@ class DatasourceProviderService:
"""
get all datasource credentials by provider
"""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
datasource_providers = (
session.query(DatasourceProvider)
@ -170,6 +176,7 @@ class DatasourceProviderService:
)
if not datasource_providers:
return []
current_user = get_current_user()
# refresh the credentials
real_credentials_list = []
for datasource_provider in datasource_providers:
@ -608,7 +615,6 @@ class DatasourceProviderService:
"""
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
@ -630,6 +636,7 @@ class DatasourceProviderService:
raise ValueError("Authorization name is already exists")
try:
current_user = get_current_user()
self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
@ -907,7 +914,6 @@ class DatasourceProviderService:
"""
update datasource credentials.
"""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
datasource_provider = (
@ -944,6 +950,7 @@ class DatasourceProviderService:
for key, value in credentials.items()
}
try:
current_user = get_current_user()
self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,

View File

@ -46,17 +46,17 @@ class EnterpriseService:
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str):
params = {"userId": user_id, "appCode": app_code}
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
params = {"userId": user_id, "appId": app_id}
data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)
return data.get("result", False)
@classmethod
def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_codes: list[str]):
if not app_codes:
def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_ids: list[str]):
if not app_ids:
return {}
body = {"userId": user_id, "appCodes": app_codes}
body = {"userId": user_id, "appIds": app_ids}
data = EnterpriseRequest.send_request("POST", "/webapp/permission/batch", json=body)
if not data:
raise ValueError("No data found.")

View File

@ -19,7 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from models.account import Account
from models import Account
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile

View File

@ -9,7 +9,7 @@ from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.account import Account
from models import Account
from models.dataset import Dataset, DatasetQuery
logger = logging.getLogger(__name__)

View File

@ -12,7 +12,7 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models import Account
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
from services.conversation_service import ConversationService
from services.errors.message import (

View File

@ -1,12 +1,11 @@
import copy
import logging
from flask_login import current_user
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
@ -23,11 +22,11 @@ class MetadataService:
# check if metadata name is too long
if len(metadata_args.name) > 255:
raise ValueError("Metadata name cannot exceed 255 characters.")
current_user, current_tenant_id = current_account_with_tenant()
# check if metadata name already exists
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
):
raise ValueError("Metadata name already exists.")
@ -35,7 +34,7 @@ class MetadataService:
if field.value == metadata_args.name:
raise ValueError("Metadata name already exists in Built-in fields.")
metadata = DatasetMetadata(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
dataset_id=dataset_id,
type=metadata_args.type,
name=metadata_args.name,
@ -53,9 +52,10 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
current_user, current_tenant_id = current_account_with_tenant()
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=name)
.first()
):
raise ValueError("Metadata name already exists.")
@ -220,9 +220,10 @@ class MetadataService:
db.session.commit()
# deal metadata binding
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
current_user, current_tenant_id = current_account_with_tenant()
for metadata_value in operation.metadata_list:
dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
dataset_id=dataset.id,
document_id=operation.document_id,
metadata_id=metadata_value.id,

View File

@ -7,7 +7,7 @@ from werkzeug.exceptions import BadRequest
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account
from models import Account
from models.model import OAuthProviderApp
from services.account_service import AccountService

View File

@ -336,6 +336,8 @@ class PluginService:
pkg,
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
)
PluginService._check_plugin_installation_scope(response.verification)
return response
@staticmethod
@ -358,6 +360,8 @@ class PluginService:
pkg,
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
)
PluginService._check_plugin_installation_scope(response.verification)
return response
@staticmethod
@ -377,6 +381,10 @@ class PluginService:
manager = PluginInstaller()
for plugin_unique_identifier in plugin_unique_identifiers:
resp = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
PluginService._check_plugin_installation_scope(resp.verification)
return manager.install_from_identifiers(
tenant_id,
plugin_unique_identifiers,
@ -393,6 +401,9 @@ class PluginService:
PluginService._check_marketplace_only_permission()
manager = PluginInstaller()
plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
PluginService._check_plugin_installation_scope(plugin_decode_response.verification)
return manager.install_from_identifiers(
tenant_id,
[plugin_unique_identifier],

View File

@ -1,7 +1,7 @@
import yaml
from flask_login import current_user
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.dataset import PipelineCustomizedTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
@ -13,9 +13,8 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
def get_pipeline_templates(self, language: str) -> dict:
result = self.fetch_pipeline_templates_from_customized(
tenant_id=current_user.current_tenant_id, language=language
)
_, current_tenant_id = current_account_with_tenant()
result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
return result
def get_pipeline_template_detail(self, template_id: str):

View File

@ -37,7 +37,6 @@ from core.rag.entities.event import (
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
@ -50,11 +49,12 @@ from core.workflow.node_events.base import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models import Account
from models.dataset import ( # type: ignore
Dataset,
Document,

View File

@ -2,7 +2,7 @@ from typing import Union
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models import Account
from models.model import App, EndUser
from models.web import SavedMessage
from services.message_service import MessageService

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models import Account
from models.model import App, EndUser
from models.web import PinnedConversation
from services.conversation_service import ConversationService

View File

@ -10,7 +10,7 @@ from extensions.ext_database import db
from libs.helper import TokenManager
from libs.passport import PassportService
from libs.password import compare_password
from models.account import Account, AccountStatus
from models import Account, AccountStatus
from models.model import App, EndUser, Site
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@ -172,7 +172,8 @@ class WebAppAuthService:
return WebAppAuthType.EXTERNAL
if app_code:
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code)
app_id = AppService.get_app_id_by_code(app_code)
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
return cls.get_app_auth_type(access_mode=webapp_settings.access_mode)
raise ValueError("Could not determine app authentication type.")

View File

@ -22,7 +22,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.nodes import NodeType
from events.app_event import app_was_created
from extensions.ext_database import db
from models.account import Account
from models import Account
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from models.model import App, AppMode, AppModelConfig
from models.workflow import Workflow, WorkflowType

View File

@ -32,8 +32,7 @@ from factories.file_factory import StorageKeyLoader
from factories.variable_factory import build_segment, segment_to_variable
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models import App, Conversation
from models.account import Account
from models import Account, App, Conversation
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable
from repositories.factory import DifyAPIRepositoryFactory

View File

@ -26,13 +26,15 @@ class WorkflowRunService:
)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
def get_paginate_advanced_chat_workflow_runs(
self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
) -> InfiniteScrollPagination:
"""
Get advanced chat app workflow run list
Only return triggered_from == advanced_chat
:param app_model: app model
:param args: request args
:param triggered_from: workflow run triggered from (default: DEBUGGING for preview runs)
"""
class WorkflowWithMessage:
@ -45,7 +47,7 @@ class WorkflowRunService:
def __getattr__(self, item):
return getattr(self._workflow_run, item)
pagination = self.get_paginate_workflow_runs(app_model, args)
pagination = self.get_paginate_workflow_runs(app_model, args, triggered_from)
with_message_workflow_runs = []
for workflow_run in pagination.data:
@ -60,23 +62,27 @@ class WorkflowRunService:
pagination.data = with_message_workflow_runs
return pagination
def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
def get_paginate_workflow_runs(
self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING
) -> InfiniteScrollPagination:
"""
Get debug workflow run list
Only return triggered_from == debugging
Get workflow run list
:param app_model: app model
:param args: request args
:param triggered_from: workflow run triggered from (default: DEBUGGING)
"""
limit = int(args.get("limit", 20))
last_id = args.get("last_id")
status = args.get("status")
return self._workflow_run_repo.get_paginated_workflow_runs(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
triggered_from=triggered_from,
limit=limit,
last_id=last_id,
status=status,
)
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None:
@ -92,6 +98,30 @@ class WorkflowRunService:
run_id=run_id,
)
def get_workflow_runs_count(
self,
app_model: App,
status: str | None = None,
time_range: str | None = None,
triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING,
) -> dict[str, int]:
"""
Get workflow runs count statistics
:param app_model: app model
:param status: optional status filter
:param time_range: optional time range filter (e.g., "7d", "4h", "30m", "30s")
:param triggered_from: workflow run triggered from (default: DEBUGGING)
:return: dict with total and status counts
"""
return self._workflow_run_repo.get_workflow_runs_count(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=triggered_from,
status=status,
time_range=time_range,
)
def get_workflow_run_node_executions(
self,
app_model: App,

View File

@ -15,7 +15,7 @@ from core.memory.entities import MemoryBlockSpec, MemoryCreatedBy, MemoryScope
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities import VariablePool, WorkflowNodeExecution
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
@ -24,6 +24,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@ -31,7 +32,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from models.account import Account
from models import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType