Merge branch 'main' into feat/mcp

This commit is contained in:
Novice
2025-05-28 09:37:55 +08:00
799 changed files with 22592 additions and 6640 deletions

View File

@ -49,7 +49,7 @@ from services.errors.account import (
RoleAlreadyAssignedError,
TenantNotFoundError,
)
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
from tasks.delete_account_task import delete_account_task
from tasks.mail_account_deletion_task import send_account_deletion_verification_code
@ -108,17 +108,20 @@ class AccountService:
if account.status == AccountStatus.BANNED.value:
raise Unauthorized("Account is banned.")
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
if current_tenant:
account.current_tenant_id = current_tenant.tenant_id
account.set_tenant_id(current_tenant.tenant_id)
else:
available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
)
if not available_ta:
return None
account.current_tenant_id = available_ta.tenant_id
account.set_tenant_id(available_ta.tenant_id)
available_ta.current = True
db.session.commit()
@ -297,9 +300,9 @@ class AccountService:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
account_id=account.id, provider=provider
).first()
account_integrate: Optional[AccountIntegrate] = (
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
)
if account_integrate:
# If it exists, update the record
@ -612,7 +615,10 @@ class TenantService:
):
"""Check if user have a workspace or not"""
available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
)
if available_ta:
@ -622,6 +628,10 @@ class TenantService:
if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
raise WorkSpaceNotAllowedCreateError()
workspaces = FeatureService.get_system_features().license.workspaces
if not workspaces.is_available():
raise WorkspacesLimitExceededError()
if name:
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
else:
@ -666,7 +676,7 @@ class TenantService:
if not tenant:
raise TenantNotFoundError("Tenant not found.")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
if ta:
tenant.role = ta.role
else:
@ -695,12 +705,12 @@ class TenantService:
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:
TenantAccountJoin.query.filter(
db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
).update({"current": False})
tenant_account_join.current = True
# Set the current tenant for the account
account.current_tenant_id = tenant_account_join.tenant_id
account.set_tenant_id(tenant_account_join.tenant_id)
db.session.commit()
@staticmethod
@ -787,7 +797,7 @@ class TenantService:
if operator.id == member.id:
raise CannotOperateSelfError("Cannot operate self.")
ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first()
if not ta_operator or ta_operator.role not in perms[action]:
raise NoPermissionError(f"No permission to {action} member.")
@ -800,7 +810,7 @@ class TenantService:
TenantService.check_member_permission(tenant, operator, account, "remove")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
if not ta:
raise MemberNotInTenantError("Member not in tenant.")
@ -812,15 +822,23 @@ class TenantService:
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update")
target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
target_member_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first()
)
if not target_member_join:
raise MemberNotInTenantError("Member not in tenant.")
if target_member_join.role == new_role:
raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
if new_role == "owner":
# Find the current owner and change their role to 'admin'
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
current_owner_join.role = "admin"
current_owner_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
)
if current_owner_join:
current_owner_join.role = "admin"
# Update the role of the target member
target_member_join.role = new_role
@ -837,7 +855,7 @@ class TenantService:
@staticmethod
def get_custom_config(tenant_id: str) -> dict:
tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404()
tenant = db.get_or_404(Tenant, tenant_id)
return cast(dict, tenant.custom_config_dict)
@ -914,7 +932,11 @@ class RegisterService:
if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required:
if (
FeatureService.get_system_features().is_allow_create_workspace
and create_workspace_required
and FeatureService.get_system_features().license.workspaces.is_available()
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
@ -959,7 +981,7 @@ class RegisterService:
TenantService.switch_tenant(account, tenant.id)
else:
TenantService.check_member_permission(tenant, inviter, account, "add")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
if not ta:
TenantService.create_tenant_member(tenant, account, role)

View File

@ -4,7 +4,7 @@ from typing import cast
import pandas as pd
from flask_login import current_user
from sqlalchemy import or_
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@ -124,8 +124,9 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
if keyword:
annotations = (
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
stmt = (
select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.filter(
or_(
MessageAnnotation.question.ilike("%{}%".format(keyword)),
@ -133,14 +134,14 @@ class AppAnnotationService:
)
)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
else:
annotations = (
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
stmt = (
select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
return annotations.items, annotations.total
@classmethod
@ -325,13 +326,16 @@ class AppAnnotationService:
if not annotation:
raise NotFound("Annotation not found")
annotation_hit_histories = (
AppAnnotationHitHistory.query.filter(
stmt = (
select(AppAnnotationHitHistory)
.filter(
AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id,
)
.order_by(AppAnnotationHitHistory.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
annotation_hit_histories = db.paginate(
select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
)
return annotation_hit_histories.items, annotation_hit_histories.total

View File

@ -40,7 +40,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.2.0"
CURRENT_DSL_VERSION = "0.3.0"
class ImportMode(StrEnum):

View File

@ -18,8 +18,10 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode, AppModelConfig
from models.model import App, AppMode, AppModelConfig, Site
from models.tools import ApiToolProvider
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.tag_service import TagService
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
@ -155,6 +157,10 @@ class AppService:
app_was_created.send(app, account=account)
if FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(app.id, "private")
return app
def get_app(self, app: App) -> App:
@ -307,6 +313,10 @@ class AppService:
db.session.delete(app)
db.session.commit()
# clean up web app settings
if FeatureService.get_system_features().webapp_auth.enabled:
EnterpriseService.WebAppAuth.cleanup_webapp(app.id)
# Trigger asynchronous deletion of app and related data
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
@ -373,3 +383,15 @@ class AppService:
meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}
return meta
@staticmethod
def get_app_code_by_id(app_id: str) -> str:
"""
Get app code by app id
:param app_id: app id
:return: app code
"""
site = db.session.query(Site).filter(Site.app_id == app_id).first()
if not site:
raise ValueError(f"App with id {app_id} not found")
return str(site.code)

View File

@ -9,7 +9,7 @@ from collections import Counter
from typing import Any, Optional
from flask_login import current_user
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -77,11 +77,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
if user:
# get permitted dataset ids
dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all()
dataset_permission = (
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
)
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
@ -129,7 +131,7 @@ class DatasetService:
else:
return [], 0
datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
return datasets.items, datasets.total
@ -153,9 +155,10 @@ class DatasetService:
@staticmethod
def get_datasets_by_ids(ids, tenant_id):
datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate(
page=1, per_page=len(ids), max_per_page=len(ids), error_out=False
)
stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
return datasets.items, datasets.total
@staticmethod
@ -174,7 +177,7 @@ class DatasetService:
retrieval_model: Optional[RetrievalModel] = None,
):
# check if dataset name already exists
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None
if indexing_technique == "high_quality":
@ -235,7 +238,7 @@ class DatasetService:
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
return dataset
@staticmethod
@ -436,7 +439,7 @@ class DatasetService:
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
dataset.query.filter_by(id=dataset_id).update(filtered_data)
db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)
db.session.commit()
if action:
@ -460,7 +463,7 @@ class DatasetService:
@staticmethod
def dataset_use_check(dataset_id) -> bool:
count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count()
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
if count > 0:
return True
return False
@ -474,15 +477,15 @@ class DatasetService:
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if dataset.permission == "partial_members":
user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first()
if (
not user_permission
and dataset.tenant_id != user.current_tenant_id
and dataset.created_by != user.id
):
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
# For partial team permission, user needs explicit permission or be the creator
if dataset.created_by != user.id:
user_permission = (
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
)
if not user_permission:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
@ -499,23 +502,24 @@ class DatasetService:
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
if not any(
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
dp.dataset_id == dataset.id
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
):
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
dataset_queries = (
DatasetQuery.query.filter_by(dataset_id=dataset_id)
.order_by(db.desc(DatasetQuery.created_at))
.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
)
stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
return dataset_queries.items, dataset_queries.total
@staticmethod
def get_related_apps(dataset_id: str):
return (
AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id)
db.session.query(AppDatasetJoin)
.filter(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at))
.all()
)
@ -530,10 +534,14 @@ class DatasetService:
}
# get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
).all()
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog)
.filter(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
)
.all()
)
if dataset_auto_disable_logs:
return {
"document_ids": [log.document_id for log in dataset_auto_disable_logs],
@ -873,7 +881,9 @@ class DocumentService:
@staticmethod
def get_documents_position(dataset_id):
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
document = (
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
)
if document:
return document.position + 1
else:
@ -948,11 +958,11 @@ class DocumentService:
"score_threshold_enabled": False,
}
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
documents = []
if knowledge_config.original_document_id:
@ -980,7 +990,7 @@ class DocumentService:
created_by=account.id,
)
else:
logging.warn(
logging.warning(
f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
)
return
@ -1010,13 +1020,17 @@ class DocumentService:
}
# check duplicate
if knowledge_config.duplicate:
document = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
).first()
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
.first()
)
if document:
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
@ -1054,12 +1068,16 @@ class DocumentService:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
).all()
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
.all()
)
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
@ -1067,14 +1085,18 @@ class DocumentService:
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).first()
.first()
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
for page in notion_info.pages:
@ -1206,12 +1228,16 @@ class DocumentService:
@staticmethod
def get_tenant_documents_count():
documents_count = Document.query.filter(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
).count()
documents_count = (
db.session.query(Document)
.filter(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
)
.count()
)
return documents_count
@staticmethod
@ -1278,14 +1304,18 @@ class DocumentService:
notion_info_list = document_data.data_source.info_list.notion_info_list
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).first()
.first()
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
for page in notion_info.pages:
@ -1328,7 +1358,7 @@ class DocumentService:
db.session.commit()
# update document segment
update_params = {DocumentSegment.status: "re_segment"}
DocumentSegment.query.filter_by(document_id=document.id).update(update_params)
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params)
db.session.commit()
# trigger async task
document_indexing_update_task.delay(document.dataset_id, document.id)
@ -1918,7 +1948,8 @@ class SegmentService:
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
index_node_ids = (
DocumentSegment.query.with_entities(DocumentSegment.index_node_id)
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
@ -2157,20 +2188,28 @@ class SegmentService:
def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
):
query = ChildChunk.query.filter_by(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id,
document_id=document_id,
segment_id=segment_id,
).order_by(ChildChunk.position.asc())
query = (
select(ChildChunk)
.filter_by(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id,
document_id=document_id,
segment_id=segment_id,
)
.order_by(ChildChunk.position.asc())
)
if keyword:
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
"""Get a child chunk by its ID."""
result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first()
result = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, ChildChunk) else None
@classmethod
@ -2184,7 +2223,7 @@ class SegmentService:
limit: int = 20,
):
"""Get segments for a document with optional filtering."""
query = DocumentSegment.query.filter(
query = select(DocumentSegment).filter(
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
)
@ -2194,9 +2233,8 @@ class SegmentService:
if keyword:
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
page=page, per_page=limit, max_per_page=100, error_out=False
)
query = query.order_by(DocumentSegment.position.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
return paginated_segments.items, paginated_segments.total
@ -2236,9 +2274,11 @@ class SegmentService:
raise ValueError(ex.description)
# check segment
segment = DocumentSegment.query.filter(
DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
@ -2251,9 +2291,11 @@ class SegmentService:
@classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
"""Get a segment by its ID."""
result = DocumentSegment.query.filter(
DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id
).first()
result = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, DocumentSegment) else None

View File

@ -1,11 +1,90 @@
from pydantic import BaseModel, Field
from services.enterprise.base import EnterpriseRequest
class WebAppSettings(BaseModel):
access_mode: str = Field(
description="Access mode for the web app. Can be 'public' or 'private'",
default="private",
alias="accessMode",
)
class EnterpriseService:
@classmethod
def get_info(cls):
return EnterpriseRequest.send_request("GET", "/info")
@classmethod
def get_app_web_sso_enabled(cls, app_code):
return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}")
def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str):
params = {"userId": user_id, "appCode": app_code}
data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)
return data.get("result", False)
@classmethod
def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings:
if not app_id:
raise ValueError("app_id must be provided.")
params = {"appId": app_id}
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
@classmethod
def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]:
if not app_ids:
return {}
body = {"appIds": app_ids}
data: dict[str, str] = EnterpriseRequest.send_request("POST", "/webapp/access-mode/batch/id", json=body)
if not data:
raise ValueError("No data found.")
if not isinstance(data["accessModes"], dict):
raise ValueError("Invalid data format.")
ret = {}
for key, value in data["accessModes"].items():
curr = WebAppSettings()
curr.access_mode = value
ret[key] = curr
return ret
@classmethod
def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings:
if not app_code:
raise ValueError("app_code must be provided.")
params = {"appCode": app_code}
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
@classmethod
def update_app_access_mode(cls, app_id: str, access_mode: str):
if not app_id:
raise ValueError("app_id must be provided.")
if access_mode not in ["public", "private", "private_all"]:
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
data = {"appId": app_id, "accessMode": access_mode}
response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data)
return response.get("result", False)
@classmethod
def cleanup_webapp(cls, app_id: str):
if not app_id:
raise ValueError("app_id must be provided.")
body = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)

View File

@ -0,0 +1,18 @@
from pydantic import BaseModel
from tasks.mail_enterprise_task import send_enterprise_email_task
class DifyMail(BaseModel):
to: list[str]
subject: str
body: str
substitutions: dict[str, str] = {}
class EnterpriseMailService:
@classmethod
def send_mail(cls, mail: DifyMail):
send_enterprise_email_task.delay(
to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions
)

View File

@ -7,3 +7,7 @@ class WorkSpaceNotAllowedCreateError(BaseServiceError):
class WorkSpaceNotFoundError(BaseServiceError):
pass
class WorkspacesLimitExceededError(BaseServiceError):
pass

View File

@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast
from urllib.parse import urlparse
import httpx
from sqlalchemy import select
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
@ -24,14 +25,20 @@ from services.errors.dataset import DatasetNameDuplicateError
class ExternalDatasetService:
@staticmethod
def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]:
query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by(
ExternalKnowledgeApis.created_at.desc()
def get_external_knowledge_apis(
page, per_page, tenant_id, search=None
) -> tuple[list[ExternalKnowledgeApis], int | None]:
query = (
select(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.tenant_id == tenant_id)
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
)
return external_knowledge_apis.items, external_knowledge_apis.total
@ -92,18 +99,18 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id
).first()
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
return external_knowledge_api
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
@ -120,9 +127,9 @@ class ExternalDatasetService:
@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@ -131,25 +138,29 @@ class ExternalDatasetService:
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count()
count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
)
if count > 0:
return True, count
return False, 0
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
return external_knowledge_binding
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
settings = json.loads(external_knowledge_api.settings)
@ -212,11 +223,13 @@ class ExternalDatasetService:
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first():
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=args.get("external_knowledge_api_id"), tenant_id=tenant_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@ -254,15 +267,17 @@ class ExternalDatasetService:
external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None,
) -> list:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_binding.external_knowledge_api_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
)
if not external_knowledge_api:
raise ValueError("external api template not found")

View File

@ -1,6 +1,6 @@
from enum import StrEnum
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from configs import dify_config
from services.billing_service import BillingService
@ -27,6 +27,32 @@ class LimitationModel(BaseModel):
limit: int = 0
class LicenseLimitationModel(BaseModel):
"""
- enabled: whether this limit is enforced
- size: current usage count
- limit: maximum allowed count; 0 means unlimited
"""
enabled: bool = Field(False, description="Whether this limit is currently active")
size: int = Field(0, description="Number of resources already consumed")
limit: int = Field(0, description="Maximum number of resources allowed; 0 means no limit")
def is_available(self, required: int = 1) -> bool:
"""
Determine whether the requested amount can be allocated.
Returns True if:
- this limit is not active, or
- the limit is zero (unlimited), or
- there is enough remaining quota.
"""
if not self.enabled or self.limit == 0:
return True
return (self.limit - self.size) >= required
class LicenseStatus(StrEnum):
NONE = "none"
INACTIVE = "inactive"
@ -39,6 +65,27 @@ class LicenseStatus(StrEnum):
class LicenseModel(BaseModel):
status: LicenseStatus = LicenseStatus.NONE
expired_at: str = ""
workspaces: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0)
class BrandingModel(BaseModel):
enabled: bool = False
application_title: str = ""
login_page_logo: str = ""
workspace_logo: str = ""
favicon: str = ""
class WebAppAuthSSOModel(BaseModel):
protocol: str = ""
class WebAppAuthModel(BaseModel):
enabled: bool = False
allow_sso: bool = False
sso_config: WebAppAuthSSOModel = WebAppAuthSSOModel()
allow_email_code_login: bool = False
allow_email_password_login: bool = False
class FeatureModel(BaseModel):
@ -54,6 +101,8 @@ class FeatureModel(BaseModel):
can_replace_logo: bool = False
model_load_balancing_enabled: bool = False
dataset_operator_enabled: bool = False
webapp_copyright_enabled: bool = False
workspace_members: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0)
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@ -68,9 +117,6 @@ class KnowledgeRateLimitModel(BaseModel):
class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ""
sso_enforced_for_web: bool = False
sso_enforced_for_web_protocol: str = ""
enable_web_sso_switch_component: bool = False
enable_marketplace: bool = False
max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE
enable_email_code_login: bool = False
@ -80,6 +126,8 @@ class SystemFeatureModel(BaseModel):
is_allow_create_workspace: bool = False
is_email_setup: bool = False
license: LicenseModel = LicenseModel()
branding: BrandingModel = BrandingModel()
webapp_auth: WebAppAuthModel = WebAppAuthModel()
class FeatureService:
@ -92,6 +140,10 @@ class FeatureService:
if dify_config.BILLING_ENABLED and tenant_id:
cls._fulfill_params_from_billing_api(features, tenant_id)
if dify_config.ENTERPRISE_ENABLED:
features.webapp_copyright_enabled = True
cls._fulfill_params_from_workspace_info(features, tenant_id)
return features
@classmethod
@ -111,8 +163,8 @@ class FeatureService:
cls._fulfill_system_params_from_env(system_features)
if dify_config.ENTERPRISE_ENABLED:
system_features.enable_web_sso_switch_component = True
system_features.branding.enabled = True
system_features.webapp_auth.enabled = True
cls._fulfill_params_from_enterprise(system_features)
if dify_config.MARKETPLACE_ENABLED:
@ -136,6 +188,14 @@ class FeatureService:
features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED
features.education.enabled = dify_config.EDUCATION_ENABLED
@classmethod
def _fulfill_params_from_workspace_info(cls, features: FeatureModel, tenant_id: str):
workspace_info = EnterpriseService.get_workspace_info(tenant_id)
if "WorkspaceMembers" in workspace_info:
features.workspace_members.size = workspace_info["WorkspaceMembers"]["used"]
features.workspace_members.limit = workspace_info["WorkspaceMembers"]["limit"]
features.workspace_members.enabled = workspace_info["WorkspaceMembers"]["enabled"]
@classmethod
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
@ -145,6 +205,9 @@ class FeatureService:
features.billing.subscription.interval = billing_info["subscription"]["interval"]
features.education.activated = billing_info["subscription"].get("education", False)
if features.billing.subscription.plan != "sandbox":
features.webapp_copyright_enabled = True
if "members" in billing_info:
features.members.size = billing_info["members"]["size"]
features.members.limit = billing_info["members"]["limit"]
@ -178,38 +241,53 @@ class FeatureService:
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
@classmethod
def _fulfill_params_from_enterprise(cls, features):
def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):
enterprise_info = EnterpriseService.get_info()
if "sso_enforced_for_signin" in enterprise_info:
features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"]
if "SSOEnforcedForSignin" in enterprise_info:
features.sso_enforced_for_signin = enterprise_info["SSOEnforcedForSignin"]
if "sso_enforced_for_signin_protocol" in enterprise_info:
features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"]
if "SSOEnforcedForSigninProtocol" in enterprise_info:
features.sso_enforced_for_signin_protocol = enterprise_info["SSOEnforcedForSigninProtocol"]
if "sso_enforced_for_web" in enterprise_info:
features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
if "EnableEmailCodeLogin" in enterprise_info:
features.enable_email_code_login = enterprise_info["EnableEmailCodeLogin"]
if "sso_enforced_for_web_protocol" in enterprise_info:
features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]
if "EnableEmailPasswordLogin" in enterprise_info:
features.enable_email_password_login = enterprise_info["EnableEmailPasswordLogin"]
if "enable_email_code_login" in enterprise_info:
features.enable_email_code_login = enterprise_info["enable_email_code_login"]
if "IsAllowRegister" in enterprise_info:
features.is_allow_register = enterprise_info["IsAllowRegister"]
if "enable_email_password_login" in enterprise_info:
features.enable_email_password_login = enterprise_info["enable_email_password_login"]
if "IsAllowCreateWorkspace" in enterprise_info:
features.is_allow_create_workspace = enterprise_info["IsAllowCreateWorkspace"]
if "is_allow_register" in enterprise_info:
features.is_allow_register = enterprise_info["is_allow_register"]
if "Branding" in enterprise_info:
features.branding.application_title = enterprise_info["Branding"].get("applicationTitle", "")
features.branding.login_page_logo = enterprise_info["Branding"].get("loginPageLogo", "")
features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "")
features.branding.favicon = enterprise_info["Branding"].get("favicon", "")
if "is_allow_create_workspace" in enterprise_info:
features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"]
if "WebAppAuth" in enterprise_info:
features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSso", False)
features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get(
"allowEmailCodeLogin", False
)
features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get(
"allowEmailPasswordLogin", False
)
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
if "license" in enterprise_info:
license_info = enterprise_info["license"]
if "License" in enterprise_info:
license_info = enterprise_info["License"]
if "status" in license_info:
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
if "expired_at" in license_info:
features.license.expired_at = license_info["expired_at"]
if "expiredAt" in license_info:
features.license.expired_at = license_info["expiredAt"]
if "workspaces" in license_info:
features.license.workspaces.enabled = license_info["workspaces"]["enabled"]
features.license.workspaces.limit = license_info["workspaces"]["limit"]
features.license.workspaces.size = license_info["workspaces"]["used"]

View File

@ -19,7 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Account
from models.enums import CreatedByRole
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
from .errors.file import FileTooLargeError, UnsupportedFileTypeError
@ -81,7 +81,7 @@ class FileService:
size=file_size,
extension=extension,
mime_type=mimetype,
created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER),
created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
created_by=user.id,
created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
used=False,
@ -133,7 +133,7 @@ class FileService:
extension="txt",
mime_type="text/plain",
created_by=current_user.id,
created_by_role=CreatedByRole.ACCOUNT,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
used=True,
used_by=current_user.id,

View File

@ -69,6 +69,7 @@ class HitTestingService:
query: str,
account: Account,
external_retrieval_model: dict,
metadata_filtering_conditions: dict,
) -> dict:
if dataset.provider != "external":
return {
@ -82,6 +83,7 @@ class HitTestingService:
dataset_id=dataset.id,
query=cls.escape_query_for_search(query),
external_retrieval_model=external_retrieval_model,
metadata_filtering_conditions=metadata_filtering_conditions,
)
end = time.perf_counter()

View File

@ -177,6 +177,21 @@ class MessageService:
return feedback
@classmethod
def get_all_messages_feedbacks(cls, app_model: App, page: int, limit: int):
"""Get all feedbacks of an app"""
offset = (page - 1) * limit
feedbacks = (
db.session.query(MessageFeedback)
.filter(MessageFeedback.app_id == app_model.id)
.order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc())
.limit(limit)
.offset(offset)
.all()
)
return [record.to_dict() for record in feedbacks]
@classmethod
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
message = (

View File

@ -20,9 +20,11 @@ class MetadataService:
@staticmethod
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
# check if metadata name already exists
if DatasetMetadata.query.filter_by(
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name
).first():
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
if field.value == metadata_args.name:
@ -42,16 +44,18 @@ class MetadataService:
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
if DatasetMetadata.query.filter_by(
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name
).first():
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name)
.first()
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
if field.value == name:
raise ValueError("Metadata name already exists in Built-in fields.")
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
@ -60,7 +64,9 @@ class MetadataService:
metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# update related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@ -82,13 +88,15 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)
# deal related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@ -193,7 +201,7 @@ class MetadataService:
db.session.add(document)
db.session.commit()
# deal metadata binding
DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
for metadata_value in operation.metadata_list:
dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=current_user.current_tenant_id,
@ -230,9 +238,9 @@ class MetadataService:
"id": item.get("id"),
"name": item.get("name"),
"type": item.get("type"),
"count": DatasetMetadataBinding.query.filter_by(
metadata_id=item.get("id"), dataset_id=dataset.id
).count(),
"count": db.session.query(DatasetMetadataBinding)
.filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
.count(),
}
for item in dataset.doc_metadata or []
if item.get("id") != "built-in"

View File

@ -87,7 +87,9 @@ class OpsService:
:param tracing_config: tracing config
:return:
"""
if tracing_provider not in provider_config_map and tracing_provider:
try:
provider_config_map[tracing_provider]
except KeyError:
return {"error": f"Invalid tracing provider: {tracing_provider}"}
config_class, other_keys = (
@ -150,7 +152,9 @@ class OpsService:
:param tracing_config: tracing config
:return:
"""
if tracing_provider not in provider_config_map:
try:
provider_config_map[tracing_provider]
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
# check if trace config already exists

View File

@ -17,7 +17,7 @@ from core.plugin.entities.plugin import (
PluginInstallation,
PluginInstallationSource,
)
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginListResponse, PluginUploadResponse
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
@ -110,6 +110,15 @@ class PluginService:
plugins = manager.list_plugins(tenant_id)
return plugins
@staticmethod
def list_with_total(tenant_id: str, page: int, page_size: int) -> PluginListResponse:
"""
list all plugins of the tenant
"""
manager = PluginInstaller()
plugins = manager.list_plugins_with_total(tenant_id, page, page_size)
return plugins
@staticmethod
def list_installations_from_ids(tenant_id: str, ids: Sequence[str]) -> Sequence[PluginInstallation]:
"""

View File

@ -1,3 +1,4 @@
import logging
from typing import Optional
from core.model_manager import ModelInstance, ModelManager
@ -12,21 +13,30 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode
_logger = logging.getLogger(__name__)
class VectorService:
@classmethod
def create_segments_vector(
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents = []
documents: list[Document] = []
for segment in segments:
if doc_form == IndexType.PARENT_CHILD_INDEX:
document = DatasetDocument.query.filter_by(id=segment.document_id).first()
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
segment.document_id,
segment.id,
)
continue
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@ -50,9 +60,11 @@ class VectorService:
)
else:
raise ValueError("The knowledge base index technique is not high quality!")
cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
cls.generate_child_chunks(
segment, dataset_document, dataset, embedding_model_instance, processing_rule, False
)
else:
document = Document(
rag_document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
@ -61,7 +73,7 @@ class VectorService:
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
documents.append(rag_document)
if len(documents) > 0:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)

View File

@ -0,0 +1,141 @@
import random
from datetime import UTC, datetime, timedelta
from typing import Any, Optional, cast
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from controllers.web.error import WebAppAuthAccessDeniedError
from extensions.ext_database import db
from libs.helper import TokenManager
from libs.passport import PassportService
from libs.password import compare_password
from models.account import Account, AccountStatus
from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from services.feature_service import FeatureService
from tasks.mail_email_code_login import send_email_code_login_mail_task
class WebAppAuthService:
"""Service for web app authentication."""
@staticmethod
def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password"""
account = Account.query.filter_by(email=email).first()
if not account:
raise AccountNotFoundError()
if account.status == AccountStatus.BANNED.value:
raise AccountLoginError("Account is banned.")
if account.password is None or not compare_password(password, account.password, account.password_salt):
raise AccountPasswordError("Invalid email or password.")
return cast(Account, account)
@classmethod
def login(cls, account: Account, app_code: str, end_user_id: str) -> str:
site = db.session.query(Site).filter(Site.code == app_code).first()
if not site:
raise NotFound("Site not found.")
access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id)
return access_token
@classmethod
def get_user_through_email(cls, email: str):
account = db.session.query(Account).filter(Account.email == email).first()
if not account:
return None
if account.status == AccountStatus.BANNED.value:
raise Unauthorized("Account is banned.")
return account
@classmethod
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
):
email = account.email if account else email
if email is None:
raise ValueError("Email must be provided.")
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code}
)
send_email_code_login_mail_task.delay(
language=language,
to=account.email if account else email,
code=code,
)
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "webapp_email_code_login")
@classmethod
def revoke_email_code_login_token(cls, token: str):
TokenManager.revoke_token(token, "webapp_email_code_login")
@classmethod
def create_end_user(cls, app_code, email) -> EndUser:
site = db.session.query(Site).filter(Site.code == app_code).first()
if not site:
raise NotFound("Site not found.")
app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model:
raise NotFound("App not found.")
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=False,
session_id=email,
name="enterpriseuser",
external_user_id="enterpriseuser",
)
db.session.add(end_user)
db.session.commit()
return end_user
@classmethod
def _validate_user_accessibility(cls, account: Account, app_code: str):
"""Check if the user is allowed to access the app."""
system_features = FeatureService.get_system_features()
if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
if (
app_settings.access_mode != "public"
and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code)
):
raise WebAppAuthAccessDeniedError()
@classmethod
def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str:
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
exp = int(exp_dt.timestamp())
payload = {
"iss": site.id,
"sub": "Web API Passport",
"app_id": site.app_id,
"app_code": site.code,
"user_id": account.id,
"end_user_id": end_user_id,
"token_source": "webapp",
"exp": exp,
}
token: str = PassportService().issue(payload)
return token

View File

@ -5,7 +5,7 @@ from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from models import App, EndUser, WorkflowAppLog, WorkflowRun
from models.enums import CreatedByRole
from models.enums import CreatorUserRole
from models.workflow import WorkflowRunStatus
@ -58,7 +58,7 @@ class WorkflowAppService:
stmt = stmt.outerjoin(
EndUser,
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER),
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatorUserRole.END_USER),
).where(or_(*keyword_conditions))
if status:

View File

@ -1,17 +1,21 @@
import threading
from collections.abc import Sequence
from typing import Optional
import contexts
from core.workflow.repository import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.model import App
from models.workflow import (
from models import (
Account,
App,
EndUser,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunTriggeredFrom,
)
from models.workflow import WorkflowNodeExecutionTriggeredFrom
class WorkflowRunService:
@ -116,7 +120,12 @@ class WorkflowRunService:
return workflow_run
def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]:
def get_workflow_run_node_executions(
self,
app_model: App,
run_id: str,
user: Account | EndUser,
) -> Sequence[WorkflowNodeExecution]:
"""
Get workflow run node execution list
"""
@ -128,17 +137,17 @@ class WorkflowRunService:
if not workflow_run:
return []
# Use the repository to get the node executions
repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": app_model.tenant_id,
"app_id": app_model.id,
"session_factory": db.session.get_bind(),
}
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine,
user=user,
app_id=app_model.id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Use the repository to get the node executions with ordering
# Use the repository to get the database models directly
order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
workflow_node_executions = repository.get_db_models_by_workflow_run(
workflow_run_id=run_id, order_config=order_config
)
return list(node_executions)
return workflow_node_executions

View File

@ -10,9 +10,10 @@ from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes import NodeType
@ -21,12 +22,10 @@ from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repository import RepositoryFactory
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import (
@ -268,37 +267,37 @@ class WorkflowService:
# run draft workflow node
start_at = time.perf_counter()
workflow_node_execution = self._handle_node_run_result(
getter=lambda: WorkflowEntry.single_step_run(
node_execution = self._handle_node_run_result(
invoke_node_fn=lambda: WorkflowEntry.single_step_run(
workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
),
start_at=start_at,
tenant_id=app_model.tenant_id,
node_id=node_id,
)
workflow_node_execution.app_id = app_model.id
workflow_node_execution.created_by = account.id
workflow_node_execution.workflow_id = draft_workflow.id
# Set workflow_id on the NodeExecution
node_execution.workflow_id = draft_workflow.id
# Use the repository to save the workflow node execution
repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": app_model.tenant_id,
"app_id": app_model.id,
"session_factory": db.session.get_bind(),
}
# Create repository and save the node execution
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine,
user=account,
app_id=app_model.id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
repository.save(workflow_node_execution)
repository.save(node_execution)
# Convert node_execution to WorkflowNodeExecution after save
workflow_node_execution = repository.to_db_model(node_execution)
return workflow_node_execution
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
) -> NodeExecution:
"""
Run draft workflow node
"""
@ -306,7 +305,7 @@ class WorkflowService:
start_at = time.perf_counter()
workflow_node_execution = self._handle_node_run_result(
getter=lambda: WorkflowEntry.run_free_node(
invoke_node_fn=lambda: WorkflowEntry.run_free_node(
node_id=node_id,
node_data=node_data,
tenant_id=tenant_id,
@ -314,7 +313,6 @@ class WorkflowService:
user_inputs=user_inputs,
),
start_at=start_at,
tenant_id=tenant_id,
node_id=node_id,
)
@ -322,21 +320,12 @@ class WorkflowService:
def _handle_node_run_result(
self,
getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
invoke_node_fn: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
start_at: float,
tenant_id: str,
node_id: str,
) -> WorkflowNodeExecution:
"""
Handle node run result
:param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
:param start_at: float
:param tenant_id: str
:param node_id: str
"""
) -> NodeExecution:
try:
node_instance, generator = getter()
node_instance, generator = invoke_node_fn()
node_run_result: NodeRunResult | None = None
for event in generator:
@ -385,20 +374,21 @@ class WorkflowService:
node_run_result = None
error = e.error
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4())
workflow_node_execution.tenant_id = tenant_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
workflow_node_execution.index = 1
workflow_node_execution.node_id = node_id
workflow_node_execution.node_type = node_instance.node_type
workflow_node_execution.title = node_instance.node_data.title
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
# Create a NodeExecution domain model
node_execution = NodeExecution(
id=str(uuid4()),
workflow_id="", # This is a single-step execution, so no workflow ID
index=1,
node_id=node_id,
node_type=node_instance.node_type,
title=node_instance.node_data.title,
elapsed_time=time.perf_counter() - start_at,
created_at=datetime.now(UTC).replace(tzinfo=None),
finished_at=datetime.now(UTC).replace(tzinfo=None),
)
if run_succeeded and node_run_result:
# create workflow node execution
# Set inputs, process_data, and outputs as dictionaries (not JSON strings)
inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
process_data = (
WorkflowEntry.handle_special_values(node_run_result.process_data)
@ -407,23 +397,23 @@ class WorkflowService:
)
outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
workflow_node_execution.inputs = json.dumps(inputs)
workflow_node_execution.process_data = json.dumps(process_data)
workflow_node_execution.outputs = json.dumps(outputs)
workflow_node_execution.execution_metadata = (
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
)
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
workflow_node_execution.error = node_run_result.error
else:
# create workflow node execution
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
node_execution.inputs = inputs
node_execution.process_data = process_data
node_execution.outputs = outputs
node_execution.metadata = node_run_result.metadata
return workflow_node_execution
# Map status from WorkflowNodeExecutionStatus to NodeExecutionStatus
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
node_execution.status = NodeExecutionStatus.SUCCEEDED
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
node_execution.status = NodeExecutionStatus.EXCEPTION
node_execution.error = node_run_result.error
else:
# Set failed status and error
node_execution.status = NodeExecutionStatus.FAILED
node_execution.error = error
return node_execution
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
"""
@ -518,11 +508,11 @@ class WorkflowService:
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
# Check if this workflow is currently referenced by an app
stmt = select(App).where(App.workflow_id == workflow_id)
app = session.scalar(stmt)
app_stmt = select(App).where(App.workflow_id == workflow_id)
app = session.scalar(app_stmt)
if app:
# Cannot delete a workflow that's currently in use by an app
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'")
# Don't use workflow.tool_published as it's not accurate for specific workflow versions
# Check if there's a tool provider using this specific workflow version