Mirror of https://github.com/langgenius/dify.git (synced 2026-05-04 17:38:04 +08:00)

Commit: Merge branch 'main' into feat/mcp
@@ -49,7 +49,7 @@ from services.errors.account import (
     RoleAlreadyAssignedError,
     TenantNotFoundError,
 )
-from services.errors.workspace import WorkSpaceNotAllowedCreateError
+from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
 from services.feature_service import FeatureService
 from tasks.delete_account_task import delete_account_task
 from tasks.mail_account_deletion_task import send_account_deletion_verification_code
@@ -108,17 +108,20 @@ class AccountService:
         if account.status == AccountStatus.BANNED.value:
             raise Unauthorized("Account is banned.")

-        current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
+        current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
         if current_tenant:
-            account.current_tenant_id = current_tenant.tenant_id
+            account.set_tenant_id(current_tenant.tenant_id)
         else:
             available_ta = (
-                TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
+                db.session.query(TenantAccountJoin)
+                .filter_by(account_id=account.id)
+                .order_by(TenantAccountJoin.id.asc())
+                .first()
             )
             if not available_ta:
                 return None

-            account.current_tenant_id = available_ta.tenant_id
+            account.set_tenant_id(available_ta.tenant_id)
             available_ta.current = True
             db.session.commit()

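The dominant pattern in this merge is mechanical: each legacy Flask-SQLAlchemy `Model.query` accessor becomes an explicit `db.session.query(Model)` call, which keeps working once the query property is gone in SQLAlchemy 2.x-style code. A hedged before/after sketch (module paths assume dify's layout; not part of the diff):

from extensions.ext_database import db  # dify's Flask-SQLAlchemy instance
from models.account import TenantAccountJoin


def find_current_join(account_id: str):
    # legacy (removed): TenantAccountJoin.query.filter_by(account_id=account_id, current=True).first()
    # new: the same query issued through the session explicitly
    return db.session.query(TenantAccountJoin).filter_by(account_id=account_id, current=True).first()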
@@ -297,9 +300,9 @@ class AccountService:
         """Link account integrate"""
         try:
             # Query whether there is an existing binding record for the same provider
-            account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
-                account_id=account.id, provider=provider
-            ).first()
+            account_integrate: Optional[AccountIntegrate] = (
+                db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
+            )

             if account_integrate:
                 # If it exists, update the record
@@ -612,7 +615,10 @@ class TenantService:
     ):
         """Check if user have a workspace or not"""
         available_ta = (
-            TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
+            db.session.query(TenantAccountJoin)
+            .filter_by(account_id=account.id)
+            .order_by(TenantAccountJoin.id.asc())
+            .first()
         )

         if available_ta:
@@ -622,6 +628,10 @@ class TenantService:
         if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
             raise WorkSpaceNotAllowedCreateError()

+        workspaces = FeatureService.get_system_features().license.workspaces
+        if not workspaces.is_available():
+            raise WorkspacesLimitExceededError()
+
         if name:
             tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
         else:
@@ -666,7 +676,7 @@ class TenantService:
         if not tenant:
             raise TenantNotFoundError("Tenant not found.")

-        ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
+        ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
         if ta:
             tenant.role = ta.role
         else:
@@ -695,12 +705,12 @@ class TenantService:
         if not tenant_account_join:
             raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
         else:
-            TenantAccountJoin.query.filter(
+            db.session.query(TenantAccountJoin).filter(
                 TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
             ).update({"current": False})
             tenant_account_join.current = True
             # Set the current tenant for the account
-            account.current_tenant_id = tenant_account_join.tenant_id
+            account.set_tenant_id(tenant_account_join.tenant_id)
             db.session.commit()

     @staticmethod
@@ -787,7 +797,7 @@ class TenantService:
         if operator.id == member.id:
             raise CannotOperateSelfError("Cannot operate self.")

-        ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
+        ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first()

         if not ta_operator or ta_operator.role not in perms[action]:
             raise NoPermissionError(f"No permission to {action} member.")
@@ -800,7 +810,7 @@ class TenantService:

         TenantService.check_member_permission(tenant, operator, account, "remove")

-        ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
+        ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
         if not ta:
             raise MemberNotInTenantError("Member not in tenant.")

@@ -812,15 +822,23 @@ class TenantService:
         """Update member role"""
         TenantService.check_member_permission(tenant, operator, member, "update")

-        target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
+        target_member_join = (
+            db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first()
+        )
+
+        if not target_member_join:
+            raise MemberNotInTenantError("Member not in tenant.")

         if target_member_join.role == new_role:
             raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")

         if new_role == "owner":
             # Find the current owner and change their role to 'admin'
-            current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
-            current_owner_join.role = "admin"
+            current_owner_join = (
+                db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
+            )
+            if current_owner_join:
+                current_owner_join.role = "admin"

         # Update the role of the target member
         target_member_join.role = new_role
@@ -837,7 +855,7 @@ class TenantService:

     @staticmethod
     def get_custom_config(tenant_id: str) -> dict:
-        tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404()
+        tenant = db.get_or_404(Tenant, tenant_id)

         return cast(dict, tenant.custom_config_dict)

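db.get_or_404() is Flask-SQLAlchemy 3.x's primary-key shortcut: it replaces the filter(...).one_or_404() chain with a direct identity-map lookup that aborts the request with HTTP 404 when no row exists. A small sketch under the same assumptions as above:

from extensions.ext_database import db
from models.account import Tenant


def load_tenant_or_404(tenant_id: str) -> Tenant:
    # roughly equivalent to db.session.get(Tenant, tenant_id), but raises a 404 response on a miss
    return db.get_or_404(Tenant, tenant_id)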
@@ -914,7 +932,11 @@ class RegisterService:
         if open_id is not None and provider is not None:
             AccountService.link_account_integrate(provider, open_id, account)

-        if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required:
+        if (
+            FeatureService.get_system_features().is_allow_create_workspace
+            and create_workspace_required
+            and FeatureService.get_system_features().license.workspaces.is_available()
+        ):
             tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
             TenantService.create_tenant_member(tenant, account, role="owner")
             account.current_tenant = tenant
@@ -959,7 +981,7 @@ class RegisterService:
             TenantService.switch_tenant(account, tenant.id)
         else:
             TenantService.check_member_permission(tenant, inviter, account, "add")
-            ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
+            ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()

             if not ta:
                 TenantService.create_tenant_member(tenant, account, role)

@@ -4,7 +4,7 @@ from typing import cast

 import pandas as pd
 from flask_login import current_user
-from sqlalchemy import or_
+from sqlalchemy import or_, select
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import NotFound
@@ -124,8 +124,9 @@ class AppAnnotationService:
         if not app:
             raise NotFound("App not found")
         if keyword:
-            annotations = (
-                MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
+            stmt = (
+                select(MessageAnnotation)
+                .filter(MessageAnnotation.app_id == app_id)
                 .filter(
                     or_(
                         MessageAnnotation.question.ilike("%{}%".format(keyword)),
@@ -133,14 +134,14 @@ class AppAnnotationService:
                     )
                 )
                 .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
-                .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
             )
         else:
-            annotations = (
-                MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
+            stmt = (
+                select(MessageAnnotation)
+                .filter(MessageAnnotation.app_id == app_id)
                 .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
-                .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
             )
+        annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
         return annotations.items, annotations.total

     @classmethod
@@ -325,13 +326,16 @@ class AppAnnotationService:
         if not annotation:
             raise NotFound("Annotation not found")

-        annotation_hit_histories = (
-            AppAnnotationHitHistory.query.filter(
+        stmt = (
+            select(AppAnnotationHitHistory)
+            .filter(
                 AppAnnotationHitHistory.app_id == app_id,
                 AppAnnotationHitHistory.annotation_id == annotation_id,
             )
             .order_by(AppAnnotationHitHistory.created_at.desc())
-            .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
         )
+        annotation_hit_histories = db.paginate(
+            select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
+        )
         return annotation_hit_histories.items, annotation_hit_histories.total

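These annotation hunks are the template for the pagination migration repeated throughout this commit: build a SQLAlchemy select() statement, then let Flask-SQLAlchemy 3.x's db.paginate() execute it. A minimal sketch of the resulting shape (model and extension names follow the hunks above; illustrative, not part of the diff):

from sqlalchemy import select

from extensions.ext_database import db
from models.model import MessageAnnotation


def list_annotations(app_id: str, page: int, limit: int):
    stmt = (
        select(MessageAnnotation)
        .filter(MessageAnnotation.app_id == app_id)
        .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
    )
    # db.paginate() executes the statement and returns a Pagination object
    pagination = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
    return pagination.items, pagination.total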
@@ -40,7 +40,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
 CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
 IMPORT_INFO_REDIS_EXPIRY = 10 * 60  # 10 minutes
 DSL_MAX_SIZE = 10 * 1024 * 1024  # 10MB
-CURRENT_DSL_VERSION = "0.2.0"
+CURRENT_DSL_VERSION = "0.3.0"


 class ImportMode(StrEnum):
@@ -18,8 +18,10 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
 from events.app_event import app_was_created
 from extensions.ext_database import db
 from models.account import Account
-from models.model import App, AppMode, AppModelConfig
+from models.model import App, AppMode, AppModelConfig, Site
 from models.tools import ApiToolProvider
+from services.enterprise.enterprise_service import EnterpriseService
+from services.feature_service import FeatureService
 from services.tag_service import TagService
 from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
@@ -155,6 +157,10 @@ class AppService:

         app_was_created.send(app, account=account)

+        if FeatureService.get_system_features().webapp_auth.enabled:
+            # update web app setting as private
+            EnterpriseService.WebAppAuth.update_app_access_mode(app.id, "private")
+
         return app

     def get_app(self, app: App) -> App:
@@ -307,6 +313,10 @@ class AppService:
         db.session.delete(app)
         db.session.commit()

+        # clean up web app settings
+        if FeatureService.get_system_features().webapp_auth.enabled:
+            EnterpriseService.WebAppAuth.cleanup_webapp(app.id)
+
         # Trigger asynchronous deletion of app and related data
         remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
@@ -373,3 +383,15 @@ class AppService:
             meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}

         return meta
+
+    @staticmethod
+    def get_app_code_by_id(app_id: str) -> str:
+        """
+        Get app code by app id
+        :param app_id: app id
+        :return: app code
+        """
+        site = db.session.query(Site).filter(Site.app_id == app_id).first()
+        if not site:
+            raise ValueError(f"App with id {app_id} not found")
+        return str(site.code)
@@ -9,7 +9,7 @@ from collections import Counter
 from typing import Any, Optional

 from flask_login import current_user
-from sqlalchemy import func
+from sqlalchemy import func, select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
@@ -77,11 +77,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
 class DatasetService:
     @staticmethod
     def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
-        query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
+        query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())

         if user:
             # get permitted dataset ids
-            dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all()
+            dataset_permission = (
+                db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
+            )
             permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None

             if user.current_role == TenantAccountRole.DATASET_OPERATOR:
@@ -129,7 +131,7 @@ class DatasetService:
         else:
             return [], 0

-        datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
+        datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)

         return datasets.items, datasets.total
@@ -153,9 +155,10 @@ class DatasetService:

     @staticmethod
     def get_datasets_by_ids(ids, tenant_id):
-        datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate(
-            page=1, per_page=len(ids), max_per_page=len(ids), error_out=False
-        )
+        stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
+
+        datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)

         return datasets.items, datasets.total

     @staticmethod
@@ -174,7 +177,7 @@ class DatasetService:
         retrieval_model: Optional[RetrievalModel] = None,
     ):
         # check if dataset name already exists
-        if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
+        if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
         embedding_model = None
         if indexing_technique == "high_quality":
@@ -235,7 +238,7 @@ class DatasetService:

     @staticmethod
     def get_dataset(dataset_id) -> Optional[Dataset]:
-        dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()
+        dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
         return dataset

     @staticmethod
@@ -436,7 +439,7 @@ class DatasetService:
         # update Retrieval model
         filtered_data["retrieval_model"] = data["retrieval_model"]

-        dataset.query.filter_by(id=dataset_id).update(filtered_data)
+        db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)

         db.session.commit()
         if action:
@@ -460,7 +463,7 @@ class DatasetService:

     @staticmethod
     def dataset_use_check(dataset_id) -> bool:
-        count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count()
+        count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
         if count > 0:
             return True
         return False
@@ -474,15 +477,15 @@
         if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
             logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
             raise NoPermissionError("You do not have permission to access this dataset.")
-        if dataset.permission == "partial_members":
-            user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first()
-            if (
-                not user_permission
-                and dataset.tenant_id != user.current_tenant_id
-                and dataset.created_by != user.id
-            ):
-                logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
-                raise NoPermissionError("You do not have permission to access this dataset.")
+        if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
+            # For partial team permission, user needs explicit permission or be the creator
+            if dataset.created_by != user.id:
+                user_permission = (
+                    db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
+                )
+                if not user_permission:
+                    logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
+                    raise NoPermissionError("You do not have permission to access this dataset.")

     @staticmethod
     def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
@@ -499,23 +502,24 @@

         elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
             if not any(
-                dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
+                dp.dataset_id == dataset.id
+                for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
             ):
                 raise NoPermissionError("You do not have permission to access this dataset.")

     @staticmethod
     def get_dataset_queries(dataset_id: str, page: int, per_page: int):
-        dataset_queries = (
-            DatasetQuery.query.filter_by(dataset_id=dataset_id)
-            .order_by(db.desc(DatasetQuery.created_at))
-            .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
-        )
+        stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
+
+        dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)

         return dataset_queries.items, dataset_queries.total

     @staticmethod
     def get_related_apps(dataset_id: str):
         return (
-            AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id)
+            db.session.query(AppDatasetJoin)
+            .filter(AppDatasetJoin.dataset_id == dataset_id)
             .order_by(db.desc(AppDatasetJoin.created_at))
             .all()
         )
@@ -530,10 +534,14 @@
         }
         # get recent 30 days auto disable logs
         start_date = datetime.datetime.now() - datetime.timedelta(days=30)
-        dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
-            DatasetAutoDisableLog.dataset_id == dataset_id,
-            DatasetAutoDisableLog.created_at >= start_date,
-        ).all()
+        dataset_auto_disable_logs = (
+            db.session.query(DatasetAutoDisableLog)
+            .filter(
+                DatasetAutoDisableLog.dataset_id == dataset_id,
+                DatasetAutoDisableLog.created_at >= start_date,
+            )
+            .all()
+        )
         if dataset_auto_disable_logs:
             return {
                 "document_ids": [log.document_id for log in dataset_auto_disable_logs],
@@ -873,7 +881,9 @@ class DocumentService:

     @staticmethod
     def get_documents_position(dataset_id):
-        document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
+        document = (
+            db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
+        )
         if document:
             return document.position + 1
         else:
@@ -948,11 +958,11 @@ class DocumentService:
                 "score_threshold_enabled": False,
             }

-            dataset.retrieval_model = (
-                knowledge_config.retrieval_model.model_dump()
-                if knowledge_config.retrieval_model
-                else default_retrieval_model
-            )  # type: ignore
+        dataset.retrieval_model = (
+            knowledge_config.retrieval_model.model_dump()
+            if knowledge_config.retrieval_model
+            else default_retrieval_model
+        )  # type: ignore

         documents = []
         if knowledge_config.original_document_id:
@@ -980,7 +990,7 @@ class DocumentService:
                     created_by=account.id,
                 )
             else:
-                logging.warn(
+                logging.warning(
                     f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
                 )
                 return
@@ -1010,13 +1020,17 @@ class DocumentService:
             }
             # check duplicate
             if knowledge_config.duplicate:
-                document = Document.query.filter_by(
-                    dataset_id=dataset.id,
-                    tenant_id=current_user.current_tenant_id,
-                    data_source_type="upload_file",
-                    enabled=True,
-                    name=file_name,
-                ).first()
+                document = (
+                    db.session.query(Document)
+                    .filter_by(
+                        dataset_id=dataset.id,
+                        tenant_id=current_user.current_tenant_id,
+                        data_source_type="upload_file",
+                        enabled=True,
+                        name=file_name,
+                    )
+                    .first()
+                )
                 if document:
                     document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                     document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
@@ -1054,12 +1068,16 @@ class DocumentService:
                 raise ValueError("No notion info list found.")
             exist_page_ids = []
             exist_document = {}
-            documents = Document.query.filter_by(
-                dataset_id=dataset.id,
-                tenant_id=current_user.current_tenant_id,
-                data_source_type="notion_import",
-                enabled=True,
-            ).all()
+            documents = (
+                db.session.query(Document)
+                .filter_by(
+                    dataset_id=dataset.id,
+                    tenant_id=current_user.current_tenant_id,
+                    data_source_type="notion_import",
+                    enabled=True,
+                )
+                .all()
+            )
             if documents:
                 for document in documents:
                     data_source_info = json.loads(document.data_source_info)
@@ -1067,14 +1085,18 @@ class DocumentService:
                         exist_document[data_source_info["notion_page_id"]] = document.id
             for notion_info in notion_info_list:
                 workspace_id = notion_info.workspace_id
-                data_source_binding = DataSourceOauthBinding.query.filter(
-                    db.and_(
-                        DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                        DataSourceOauthBinding.provider == "notion",
-                        DataSourceOauthBinding.disabled == False,
-                        DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
-                    )
-                ).first()
+                data_source_binding = (
+                    db.session.query(DataSourceOauthBinding)
+                    .filter(
+                        db.and_(
+                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceOauthBinding.provider == "notion",
+                            DataSourceOauthBinding.disabled == False,
+                            DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
+                        )
+                    )
+                    .first()
+                )
                 if not data_source_binding:
                     raise ValueError("Data source binding not found.")
                 for page in notion_info.pages:
@@ -1206,12 +1228,16 @@ class DocumentService:

     @staticmethod
     def get_tenant_documents_count():
-        documents_count = Document.query.filter(
-            Document.completed_at.isnot(None),
-            Document.enabled == True,
-            Document.archived == False,
-            Document.tenant_id == current_user.current_tenant_id,
-        ).count()
+        documents_count = (
+            db.session.query(Document)
+            .filter(
+                Document.completed_at.isnot(None),
+                Document.enabled == True,
+                Document.archived == False,
+                Document.tenant_id == current_user.current_tenant_id,
+            )
+            .count()
+        )
         return documents_count

     @staticmethod
@@ -1278,14 +1304,18 @@ class DocumentService:
             notion_info_list = document_data.data_source.info_list.notion_info_list
             for notion_info in notion_info_list:
                 workspace_id = notion_info.workspace_id
-                data_source_binding = DataSourceOauthBinding.query.filter(
-                    db.and_(
-                        DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                        DataSourceOauthBinding.provider == "notion",
-                        DataSourceOauthBinding.disabled == False,
-                        DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
-                    )
-                ).first()
+                data_source_binding = (
+                    db.session.query(DataSourceOauthBinding)
+                    .filter(
+                        db.and_(
+                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceOauthBinding.provider == "notion",
+                            DataSourceOauthBinding.disabled == False,
+                            DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
+                        )
+                    )
+                    .first()
+                )
                 if not data_source_binding:
                     raise ValueError("Data source binding not found.")
                 for page in notion_info.pages:
@@ -1328,7 +1358,7 @@ class DocumentService:
         db.session.commit()
         # update document segment
         update_params = {DocumentSegment.status: "re_segment"}
-        DocumentSegment.query.filter_by(document_id=document.id).update(update_params)
+        db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params)
        db.session.commit()
        # trigger async task
        document_indexing_update_task.delay(document.dataset_id, document.id)
@@ -1918,7 +1948,8 @@ class SegmentService:
     @classmethod
     def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
         index_node_ids = (
-            DocumentSegment.query.with_entities(DocumentSegment.index_node_id)
+            db.session.query(DocumentSegment)
+            .with_entities(DocumentSegment.index_node_id)
             .filter(
                 DocumentSegment.id.in_(segment_ids),
                 DocumentSegment.dataset_id == dataset.id,
@@ -2157,20 +2188,28 @@ class SegmentService:
     def get_child_chunks(
         cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
     ):
-        query = ChildChunk.query.filter_by(
-            tenant_id=current_user.current_tenant_id,
-            dataset_id=dataset_id,
-            document_id=document_id,
-            segment_id=segment_id,
-        ).order_by(ChildChunk.position.asc())
+        query = (
+            select(ChildChunk)
+            .filter_by(
+                tenant_id=current_user.current_tenant_id,
+                dataset_id=dataset_id,
+                document_id=document_id,
+                segment_id=segment_id,
+            )
+            .order_by(ChildChunk.position.asc())
+        )
         if keyword:
             query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
-        return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

     @classmethod
     def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
         """Get a child chunk by its ID."""
-        result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first()
+        result = (
+            db.session.query(ChildChunk)
+            .filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
+            .first()
+        )
         return result if isinstance(result, ChildChunk) else None

     @classmethod
@@ -2184,7 +2223,7 @@ class SegmentService:
         limit: int = 20,
     ):
         """Get segments for a document with optional filtering."""
-        query = DocumentSegment.query.filter(
+        query = select(DocumentSegment).filter(
             DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
         )

@@ -2194,9 +2233,8 @@ class SegmentService:
         if keyword:
             query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))

-        paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
-            page=page, per_page=limit, max_per_page=100, error_out=False
-        )
+        query = query.order_by(DocumentSegment.position.asc())
+        paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

         return paginated_segments.items, paginated_segments.total

@@ -2236,9 +2274,11 @@ class SegmentService:
             raise ValueError(ex.description)

         # check segment
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")

@@ -2251,9 +2291,11 @@ class SegmentService:
     @classmethod
     def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
         """Get a segment by its ID."""
-        result = DocumentSegment.query.filter(
-            DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id
-        ).first()
+        result = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
+            .first()
+        )
         return result if isinstance(result, DocumentSegment) else None

@@ -1,11 +1,90 @@
+from pydantic import BaseModel, Field
+
 from services.enterprise.base import EnterpriseRequest


+class WebAppSettings(BaseModel):
+    access_mode: str = Field(
+        description="Access mode for the web app. Can be 'public' or 'private'",
+        default="private",
+        alias="accessMode",
+    )
+
+
 class EnterpriseService:
     @classmethod
     def get_info(cls):
         return EnterpriseRequest.send_request("GET", "/info")

     @classmethod
-    def get_app_web_sso_enabled(cls, app_code):
-        return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}")
+    def get_workspace_info(cls, tenant_id: str):
+        return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
+
+    class WebAppAuth:
+        @classmethod
+        def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str):
+            params = {"userId": user_id, "appCode": app_code}
+            data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)
+
+            return data.get("result", False)
+
+        @classmethod
+        def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings:
+            if not app_id:
+                raise ValueError("app_id must be provided.")
+            params = {"appId": app_id}
+            data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
+            if not data:
+                raise ValueError("No data found.")
+            return WebAppSettings(**data)
+
+        @classmethod
+        def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]:
+            if not app_ids:
+                return {}
+            body = {"appIds": app_ids}
+            data: dict[str, str] = EnterpriseRequest.send_request("POST", "/webapp/access-mode/batch/id", json=body)
+            if not data:
+                raise ValueError("No data found.")
+
+            if not isinstance(data["accessModes"], dict):
+                raise ValueError("Invalid data format.")
+
+            ret = {}
+            for key, value in data["accessModes"].items():
+                curr = WebAppSettings()
+                curr.access_mode = value
+                ret[key] = curr
+
+            return ret
+
+        @classmethod
+        def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings:
+            if not app_code:
+                raise ValueError("app_code must be provided.")
+            params = {"appCode": app_code}
+            data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
+            if not data:
+                raise ValueError("No data found.")
+            return WebAppSettings(**data)
+
+        @classmethod
+        def update_app_access_mode(cls, app_id: str, access_mode: str):
+            if not app_id:
+                raise ValueError("app_id must be provided.")
+            if access_mode not in ["public", "private", "private_all"]:
+                raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
+
+            data = {"appId": app_id, "accessMode": access_mode}
+
+            response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data)
+
+            return response.get("result", False)
+
+        @classmethod
+        def cleanup_webapp(cls, app_id: str):
+            if not app_id:
+                raise ValueError("app_id must be provided.")
+
+            body = {"appId": app_id}
+            EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
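A hedged sketch of how a caller would use the new nested WebAppAuth client (return shape assumed from the WebAppSettings model above; illustrative only):

from services.enterprise.enterprise_service import EnterpriseService


def is_webapp_private(app_id: str) -> bool:
    settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
    return settings.access_mode in ("private", "private_all")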
api/services/enterprise/mail_service.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+from pydantic import BaseModel
+
+from tasks.mail_enterprise_task import send_enterprise_email_task
+
+
+class DifyMail(BaseModel):
+    to: list[str]
+    subject: str
+    body: str
+    substitutions: dict[str, str] = {}
+
+
+class EnterpriseMailService:
+    @classmethod
+    def send_mail(cls, mail: DifyMail):
+        send_enterprise_email_task.delay(
+            to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions
+        )
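Sending through the new service is a single call; the Celery task performs the actual delivery via .delay(). A usage sketch with illustrative values only:

from services.enterprise.mail_service import DifyMail, EnterpriseMailService

EnterpriseMailService.send_mail(
    DifyMail(
        to=["ops@example.com"],  # hypothetical recipient
        subject="Workspace limit reached",
        body="The licensed workspace quota is exhausted.",
        substitutions={"workspace": "acme"},
    )
)  # enqueues send_enterprise_email_task asynchronously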
@@ -7,3 +7,7 @@ class WorkSpaceNotAllowedCreateError(BaseServiceError):

 class WorkSpaceNotFoundError(BaseServiceError):
     pass
+
+
+class WorkspacesLimitExceededError(BaseServiceError):
+    pass
@@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast
 from urllib.parse import urlparse

 import httpx
+from sqlalchemy import select

 from constants import HIDDEN_VALUE
 from core.helper import ssrf_proxy
@@ -24,14 +25,20 @@ from services.errors.dataset import DatasetNameDuplicateError

 class ExternalDatasetService:
     @staticmethod
-    def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]:
-        query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by(
-            ExternalKnowledgeApis.created_at.desc()
-        )
+    def get_external_knowledge_apis(
+        page, per_page, tenant_id, search=None
+    ) -> tuple[list[ExternalKnowledgeApis], int | None]:
+        query = (
+            select(ExternalKnowledgeApis)
+            .filter(ExternalKnowledgeApis.tenant_id == tenant_id)
+            .order_by(ExternalKnowledgeApis.created_at.desc())
+        )
         if search:
             query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))

-        external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
+        external_knowledge_apis = db.paginate(
+            select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
+        )

         return external_knowledge_apis.items, external_knowledge_apis.total
@@ -92,18 +99,18 @@ class ExternalDatasetService:

     @staticmethod
     def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
-        external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_api_id
-        ).first()
+        external_knowledge_api: Optional[ExternalKnowledgeApis] = (
+            db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")
         return external_knowledge_api

     @staticmethod
     def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
-        external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_api_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_api: Optional[ExternalKnowledgeApis] = (
+            db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")
         if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
@@ -120,9 +127,9 @@ class ExternalDatasetService:

     @staticmethod
     def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
-        external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_api_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_api = (
+            db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")

@@ -131,25 +138,29 @@ class ExternalDatasetService:

     @staticmethod
     def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
-        count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count()
+        count = (
+            db.session.query(ExternalKnowledgeBindings)
+            .filter_by(external_knowledge_api_id=external_knowledge_api_id)
+            .count()
+        )
         if count > 0:
             return True, count
         return False, 0

     @staticmethod
     def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
-        external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by(
-            dataset_id=dataset_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
+            db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
+        )
         if not external_knowledge_binding:
             raise ValueError("external knowledge binding not found")
         return external_knowledge_binding

     @staticmethod
     def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
-        external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_api_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_api = (
+            db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")
         settings = json.loads(external_knowledge_api.settings)
@@ -212,11 +223,13 @@ class ExternalDatasetService:
     @staticmethod
     def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
         # check if dataset name already exists
-        if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first():
+        if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
-        external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
-            id=args.get("external_knowledge_api_id"), tenant_id=tenant_id
-        ).first()
+        external_knowledge_api = (
+            db.session.query(ExternalKnowledgeApis)
+            .filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
+            .first()
+        )

         if external_knowledge_api is None:
             raise ValueError("api template not found")
@@ -254,15 +267,17 @@ class ExternalDatasetService:
         external_retrieval_parameters: dict,
         metadata_condition: Optional[MetadataCondition] = None,
     ) -> list:
-        external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
-            dataset_id=dataset_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_binding = (
+            db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
+        )
         if not external_knowledge_binding:
             raise ValueError("external knowledge binding not found")

-        external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_binding.external_knowledge_api_id
-        ).first()
+        external_knowledge_api = (
+            db.session.query(ExternalKnowledgeApis)
+            .filter_by(id=external_knowledge_binding.external_knowledge_api_id)
+            .first()
+        )
         if not external_knowledge_api:
             raise ValueError("external api template not found")

@@ -1,6 +1,6 @@
 from enum import StrEnum

-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field

 from configs import dify_config
 from services.billing_service import BillingService
@@ -27,6 +27,32 @@ class LimitationModel(BaseModel):
     limit: int = 0


+class LicenseLimitationModel(BaseModel):
+    """
+    - enabled: whether this limit is enforced
+    - size: current usage count
+    - limit: maximum allowed count; 0 means unlimited
+    """
+
+    enabled: bool = Field(False, description="Whether this limit is currently active")
+    size: int = Field(0, description="Number of resources already consumed")
+    limit: int = Field(0, description="Maximum number of resources allowed; 0 means no limit")
+
+    def is_available(self, required: int = 1) -> bool:
+        """
+        Determine whether the requested amount can be allocated.
+
+        Returns True if:
+        - this limit is not active, or
+        - the limit is zero (unlimited), or
+        - there is enough remaining quota.
+        """
+        if not self.enabled or self.limit == 0:
+            return True
+
+        return (self.limit - self.size) >= required
+
+
 class LicenseStatus(StrEnum):
     NONE = "none"
     INACTIVE = "inactive"
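The workspace-creation guard added in TenantService above reduces to the arithmetic in is_available(). A worked example with made-up numbers:

limit = LicenseLimitationModel(enabled=True, size=4, limit=5)
assert limit.is_available()        # 5 - 4 >= 1: one workspace slot remains
assert not limit.is_available(2)   # 5 - 4 < 2: two slots would exceed the cap
assert LicenseLimitationModel(enabled=True, size=999, limit=0).is_available()  # limit 0 means unlimited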
@@ -39,6 +65,27 @@ class LicenseStatus(StrEnum):
 class LicenseModel(BaseModel):
     status: LicenseStatus = LicenseStatus.NONE
     expired_at: str = ""
+    workspaces: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0)
+
+
+class BrandingModel(BaseModel):
+    enabled: bool = False
+    application_title: str = ""
+    login_page_logo: str = ""
+    workspace_logo: str = ""
+    favicon: str = ""
+
+
+class WebAppAuthSSOModel(BaseModel):
+    protocol: str = ""
+
+
+class WebAppAuthModel(BaseModel):
+    enabled: bool = False
+    allow_sso: bool = False
+    sso_config: WebAppAuthSSOModel = WebAppAuthSSOModel()
+    allow_email_code_login: bool = False
+    allow_email_password_login: bool = False


 class FeatureModel(BaseModel):
@@ -54,6 +101,8 @@ class FeatureModel(BaseModel):
     can_replace_logo: bool = False
     model_load_balancing_enabled: bool = False
     dataset_operator_enabled: bool = False
+    webapp_copyright_enabled: bool = False
+    workspace_members: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0)

     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())
@@ -68,9 +117,6 @@ class KnowledgeRateLimitModel(BaseModel):
 class SystemFeatureModel(BaseModel):
     sso_enforced_for_signin: bool = False
     sso_enforced_for_signin_protocol: str = ""
-    sso_enforced_for_web: bool = False
-    sso_enforced_for_web_protocol: str = ""
-    enable_web_sso_switch_component: bool = False
     enable_marketplace: bool = False
     max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE
     enable_email_code_login: bool = False
@@ -80,6 +126,8 @@ class SystemFeatureModel(BaseModel):
     is_allow_create_workspace: bool = False
     is_email_setup: bool = False
     license: LicenseModel = LicenseModel()
+    branding: BrandingModel = BrandingModel()
+    webapp_auth: WebAppAuthModel = WebAppAuthModel()


 class FeatureService:
@@ -92,6 +140,10 @@ class FeatureService:
         if dify_config.BILLING_ENABLED and tenant_id:
             cls._fulfill_params_from_billing_api(features, tenant_id)

+        if dify_config.ENTERPRISE_ENABLED:
+            features.webapp_copyright_enabled = True
+            cls._fulfill_params_from_workspace_info(features, tenant_id)
+
         return features

     @classmethod
@@ -111,8 +163,8 @@ class FeatureService:
         cls._fulfill_system_params_from_env(system_features)

         if dify_config.ENTERPRISE_ENABLED:
-            system_features.enable_web_sso_switch_component = True
-
+            system_features.branding.enabled = True
+            system_features.webapp_auth.enabled = True
             cls._fulfill_params_from_enterprise(system_features)

         if dify_config.MARKETPLACE_ENABLED:
@@ -136,6 +188,14 @@ class FeatureService:
         features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED
         features.education.enabled = dify_config.EDUCATION_ENABLED

+    @classmethod
+    def _fulfill_params_from_workspace_info(cls, features: FeatureModel, tenant_id: str):
+        workspace_info = EnterpriseService.get_workspace_info(tenant_id)
+        if "WorkspaceMembers" in workspace_info:
+            features.workspace_members.size = workspace_info["WorkspaceMembers"]["used"]
+            features.workspace_members.limit = workspace_info["WorkspaceMembers"]["limit"]
+            features.workspace_members.enabled = workspace_info["WorkspaceMembers"]["enabled"]
+
     @classmethod
     def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
         billing_info = BillingService.get_info(tenant_id)
@@ -145,6 +205,9 @@ class FeatureService:
         features.billing.subscription.interval = billing_info["subscription"]["interval"]
         features.education.activated = billing_info["subscription"].get("education", False)

+        if features.billing.subscription.plan != "sandbox":
+            features.webapp_copyright_enabled = True
+
         if "members" in billing_info:
             features.members.size = billing_info["members"]["size"]
             features.members.limit = billing_info["members"]["limit"]
@@ -178,38 +241,53 @@ class FeatureService:
             features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]

     @classmethod
-    def _fulfill_params_from_enterprise(cls, features):
+    def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):
         enterprise_info = EnterpriseService.get_info()

-        if "sso_enforced_for_signin" in enterprise_info:
-            features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"]
+        if "SSOEnforcedForSignin" in enterprise_info:
+            features.sso_enforced_for_signin = enterprise_info["SSOEnforcedForSignin"]

-        if "sso_enforced_for_signin_protocol" in enterprise_info:
-            features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"]
+        if "SSOEnforcedForSigninProtocol" in enterprise_info:
+            features.sso_enforced_for_signin_protocol = enterprise_info["SSOEnforcedForSigninProtocol"]

-        if "sso_enforced_for_web" in enterprise_info:
-            features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
+        if "EnableEmailCodeLogin" in enterprise_info:
+            features.enable_email_code_login = enterprise_info["EnableEmailCodeLogin"]

-        if "sso_enforced_for_web_protocol" in enterprise_info:
-            features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]
+        if "EnableEmailPasswordLogin" in enterprise_info:
+            features.enable_email_password_login = enterprise_info["EnableEmailPasswordLogin"]

-        if "enable_email_code_login" in enterprise_info:
-            features.enable_email_code_login = enterprise_info["enable_email_code_login"]
+        if "IsAllowRegister" in enterprise_info:
+            features.is_allow_register = enterprise_info["IsAllowRegister"]

-        if "enable_email_password_login" in enterprise_info:
-            features.enable_email_password_login = enterprise_info["enable_email_password_login"]
+        if "IsAllowCreateWorkspace" in enterprise_info:
+            features.is_allow_create_workspace = enterprise_info["IsAllowCreateWorkspace"]

-        if "is_allow_register" in enterprise_info:
-            features.is_allow_register = enterprise_info["is_allow_register"]
+        if "Branding" in enterprise_info:
+            features.branding.application_title = enterprise_info["Branding"].get("applicationTitle", "")
+            features.branding.login_page_logo = enterprise_info["Branding"].get("loginPageLogo", "")
+            features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "")
+            features.branding.favicon = enterprise_info["Branding"].get("favicon", "")

-        if "is_allow_create_workspace" in enterprise_info:
-            features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"]
+        if "WebAppAuth" in enterprise_info:
+            features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSso", False)
+            features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get(
+                "allowEmailCodeLogin", False
+            )
+            features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get(
+                "allowEmailPasswordLogin", False
+            )
+            features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")

-        if "license" in enterprise_info:
-            license_info = enterprise_info["license"]
+        if "License" in enterprise_info:
+            license_info = enterprise_info["License"]

             if "status" in license_info:
                 features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))

-            if "expired_at" in license_info:
-                features.license.expired_at = license_info["expired_at"]
+            if "expiredAt" in license_info:
+                features.license.expired_at = license_info["expiredAt"]
+
+            if "workspaces" in license_info:
+                features.license.workspaces.enabled = license_info["workspaces"]["enabled"]
+                features.license.workspaces.limit = license_info["workspaces"]["limit"]
+                features.license.workspaces.size = license_info["workspaces"]["used"]
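_fulfill_params_from_enterprise now reads PascalCase keys from the enterprise info endpoint instead of snake_case. A hedged sample of the payload shape the rewritten method consumes (key names taken from the code above; values are illustrative, not from the real API):

enterprise_info = {
    "SSOEnforcedForSignin": True,
    "SSOEnforcedForSigninProtocol": "saml",
    "EnableEmailCodeLogin": False,
    "EnableEmailPasswordLogin": True,
    "IsAllowRegister": False,
    "IsAllowCreateWorkspace": False,
    "Branding": {"applicationTitle": "Acme Dify", "loginPageLogo": "", "workspaceLogo": "", "favicon": ""},
    "WebAppAuth": {"allowSso": True, "allowEmailCodeLogin": False, "allowEmailPasswordLogin": True},
    "SSOEnforcedForWebProtocol": "oidc",
    "License": {
        "status": "active",
        "expiredAt": "2026-01-01",
        "workspaces": {"enabled": True, "limit": 5, "used": 2},
    },
}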
@@ -19,7 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.account import Account
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
 from models.model import EndUser, UploadFile

 from .errors.file import FileTooLargeError, UnsupportedFileTypeError
@@ -81,7 +81,7 @@ class FileService:
             size=file_size,
             extension=extension,
             mime_type=mimetype,
-            created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER),
+            created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
             created_by=user.id,
             created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
             used=False,
@@ -133,7 +133,7 @@ class FileService:
             extension="txt",
             mime_type="text/plain",
             created_by=current_user.id,
-            created_by_role=CreatedByRole.ACCOUNT,
+            created_by_role=CreatorUserRole.ACCOUNT,
             created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
             used=True,
             used_by=current_user.id,
@@ -69,6 +69,7 @@ class HitTestingService:
         query: str,
         account: Account,
         external_retrieval_model: dict,
+        metadata_filtering_conditions: dict,
     ) -> dict:
         if dataset.provider != "external":
             return {
@@ -82,6 +83,7 @@ class HitTestingService:
             dataset_id=dataset.id,
             query=cls.escape_query_for_search(query),
             external_retrieval_model=external_retrieval_model,
+            metadata_filtering_conditions=metadata_filtering_conditions,
         )

         end = time.perf_counter()
@@ -177,6 +177,21 @@ class MessageService:

         return feedback

+    @classmethod
+    def get_all_messages_feedbacks(cls, app_model: App, page: int, limit: int):
+        """Get all feedbacks of an app"""
+        offset = (page - 1) * limit
+        feedbacks = (
+            db.session.query(MessageFeedback)
+            .filter(MessageFeedback.app_id == app_model.id)
+            .order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc())
+            .limit(limit)
+            .offset(offset)
+            .all()
+        )
+
+        return [record.to_dict() for record in feedbacks]
+
     @classmethod
     def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
         message = (
@@ -20,9 +20,11 @@ class MetadataService:
     @staticmethod
     def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
         # check if metadata name already exists
-        if DatasetMetadata.query.filter_by(
-            tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name
-        ).first():
+        if (
+            db.session.query(DatasetMetadata)
+            .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
+            .first()
+        ):
             raise ValueError("Metadata name already exists.")
         for field in BuiltInField:
             if field.value == metadata_args.name:
@@ -42,16 +44,18 @@ class MetadataService:
     def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata:  # type: ignore
         lock_key = f"dataset_metadata_lock_{dataset_id}"
         # check if metadata name already exists
-        if DatasetMetadata.query.filter_by(
-            tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name
-        ).first():
+        if (
+            db.session.query(DatasetMetadata)
+            .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name)
+            .first()
+        ):
             raise ValueError("Metadata name already exists.")
         for field in BuiltInField:
             if field.value == name:
                 raise ValueError("Metadata name already exists in Built-in fields.")
         try:
             MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
-            metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
+            metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
             if metadata is None:
                 raise ValueError("Metadata not found.")
             old_name = metadata.name
@@ -60,7 +64,9 @@ class MetadataService:
             metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

             # update related documents
-            dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
+            dataset_metadata_bindings = (
+                db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
+            )
             if dataset_metadata_bindings:
                 document_ids = [binding.document_id for binding in dataset_metadata_bindings]
                 documents = DocumentService.get_document_by_ids(document_ids)
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
|
||||
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
|
||||
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
|
||||
if metadata is None:
|
||||
raise ValueError("Metadata not found.")
|
||||
db.session.delete(metadata)
|
||||
|
||||
# deal related documents
|
||||
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
|
||||
dataset_metadata_bindings = (
|
||||
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
|
||||
)
|
||||
if dataset_metadata_bindings:
|
||||
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
|
||||
documents = DocumentService.get_document_by_ids(document_ids)
|
||||
@ -193,7 +201,7 @@ class MetadataService:
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
# deal metadata binding
|
||||
DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
|
||||
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
|
||||
for metadata_value in operation.metadata_list:
|
||||
dataset_metadata_binding = DatasetMetadataBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
@ -230,9 +238,9 @@ class MetadataService:
|
||||
"id": item.get("id"),
|
||||
"name": item.get("name"),
|
||||
"type": item.get("type"),
|
||||
"count": DatasetMetadataBinding.query.filter_by(
|
||||
metadata_id=item.get("id"), dataset_id=dataset.id
|
||||
).count(),
|
||||
"count": db.session.query(DatasetMetadataBinding)
|
||||
.filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
|
||||
.count(),
|
||||
}
|
||||
for item in dataset.doc_metadata or []
|
||||
if item.get("id") != "built-in"
|
||||
|
||||
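Note: throughout this commit, legacy `Model.query` calls are replaced with explicit `db.session.query(Model)` chains; the queries themselves are unchanged. The pattern side by side (variable names are stand-ins, the `DatasetMetadata` filter mirrors the hunks above):

```python
# Legacy Flask-SQLAlchemy style: implicit session via the model class.
duplicate = DatasetMetadata.query.filter_by(
    tenant_id=tenant_id, dataset_id=dataset_id, name=name
).first()

# New style used in this commit: the session is explicit, which keeps the
# query independent of the Model.query shortcut and easier to type-check.
duplicate = (
    db.session.query(DatasetMetadata)
    .filter_by(tenant_id=tenant_id, dataset_id=dataset_id, name=name)
    .first()
)
```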
@ -87,7 +87,9 @@ class OpsService:
        :param tracing_config: tracing config
        :return:
        """
        if tracing_provider not in provider_config_map and tracing_provider:
        try:
            provider_config_map[tracing_provider]
        except KeyError:
            return {"error": f"Invalid tracing provider: {tracing_provider}"}

        config_class, other_keys = (
@ -150,7 +152,9 @@ class OpsService:
        :param tracing_config: tracing config
        :return:
        """
        if tracing_provider not in provider_config_map:
        try:
            provider_config_map[tracing_provider]
        except KeyError:
            raise ValueError(f"Invalid tracing provider: {tracing_provider}")

        # check if trace config already exists

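Note: both hunks swap a membership test for an EAFP-style lookup: indexing the mapping and catching `KeyError` validates the provider in one step, and it also closes the gap where an empty provider string slipped past the old `and tracing_provider` guard in the first hunk. A minimal sketch of the shape (the two-entry map is illustrative only):

```python
# EAFP validation sketch: attempt the lookup, treat a miss as invalid input.
provider_config_map = {"langfuse": object(), "langsmith": object()}  # illustrative stand-in

def validate_provider(tracing_provider: str) -> None:
    try:
        provider_config_map[tracing_provider]
    except KeyError:
        raise ValueError(f"Invalid tracing provider: {tracing_provider}")
```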
@ -17,7 +17,7 @@ from core.plugin.entities.plugin import (
    PluginInstallation,
    PluginInstallationSource,
)
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginListResponse, PluginUploadResponse
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
@ -110,6 +110,15 @@ class PluginService:
        plugins = manager.list_plugins(tenant_id)
        return plugins

    @staticmethod
    def list_with_total(tenant_id: str, page: int, page_size: int) -> PluginListResponse:
        """
        list all plugins of the tenant
        """
        manager = PluginInstaller()
        plugins = manager.list_plugins_with_total(tenant_id, page, page_size)
        return plugins

    @staticmethod
    def list_installations_from_ids(tenant_id: str, ids: Sequence[str]) -> Sequence[PluginInstallation]:
        """

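Note: `list_with_total` mirrors `list` but returns a `PluginListResponse` that carries the page slice together with a total count, which is what a paginated UI needs to render page controls. A hedged caller sketch (the response field names are assumptions; the exact shape lives in `plugin_daemon.py`):

```python
# Hypothetical caller: fetch page 1 of a tenant's plugins, 20 per page.
resp = PluginService.list_with_total(tenant_id="tenant-123", page=1, page_size=20)
# PluginListResponse is assumed to expose the slice and the overall total,
# e.g. resp.list and resp.total; check the entity definition for the real names.
```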
@ -1,3 +1,4 @@
import logging
from typing import Optional

from core.model_manager import ModelInstance, ModelManager
@ -12,21 +13,30 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode

_logger = logging.getLogger(__name__)


class VectorService:
    @classmethod
    def create_segments_vector(
        cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
    ):
        documents = []
        documents: list[Document] = []

        for segment in segments:
            if doc_form == IndexType.PARENT_CHILD_INDEX:
                document = DatasetDocument.query.filter_by(id=segment.document_id).first()
                dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
                if not dataset_document:
                    _logger.warning(
                        "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
                        segment.document_id,
                        segment.id,
                    )
                    continue
                # get the process rule
                processing_rule = (
                    db.session.query(DatasetProcessRule)
                    .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                    .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                    .first()
                )
                if not processing_rule:
@ -50,9 +60,11 @@ class VectorService:
                    )
                else:
                    raise ValueError("The knowledge base index technique is not high quality!")
                cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
                cls.generate_child_chunks(
                    segment, dataset_document, dataset, embedding_model_instance, processing_rule, False
                )
            else:
                document = Document(
                rag_document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
@ -61,7 +73,7 @@ class VectorService:
                        "dataset_id": segment.dataset_id,
                    },
                )
                documents.append(document)
                documents.append(rag_document)
        if len(documents) > 0:
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)

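Note: two things happen in this hunk. The ORM record and the RAG value object no longer share the name `document` (now `dataset_document` vs. `rag_document`), and a missing record is logged and skipped instead of surfacing later as an `AttributeError` on `None`. A minimal sketch of the guard, with `fetch_dataset_document` as a hypothetical stand-in for the `db.session.query(DatasetDocument)...first()` lookup:

```python
import logging

_logger = logging.getLogger(__name__)

def index_segment(segment, fetch_dataset_document) -> bool:
    # .first() may return None; log identifiers and skip this segment
    # instead of letting a later attribute access raise on None.
    dataset_document = fetch_dataset_document(segment.document_id)
    if not dataset_document:
        _logger.warning(
            "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
            segment.document_id,
            segment.id,
        )
        return False
    # ... proceed with chunk generation using dataset_document ...
    return True
```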
api/services/webapp_auth_service.py (new file, 141 lines)
@ -0,0 +1,141 @@
import random
from datetime import UTC, datetime, timedelta
from typing import Any, Optional, cast

from werkzeug.exceptions import NotFound, Unauthorized

from configs import dify_config
from controllers.web.error import WebAppAuthAccessDeniedError
from extensions.ext_database import db
from libs.helper import TokenManager
from libs.passport import PassportService
from libs.password import compare_password
from models.account import Account, AccountStatus
from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from services.feature_service import FeatureService
from tasks.mail_email_code_login import send_email_code_login_mail_task


class WebAppAuthService:
    """Service for web app authentication."""

    @staticmethod
    def authenticate(email: str, password: str) -> Account:
        """authenticate account with email and password"""

        account = Account.query.filter_by(email=email).first()
        if not account:
            raise AccountNotFoundError()

        if account.status == AccountStatus.BANNED.value:
            raise AccountLoginError("Account is banned.")

        if account.password is None or not compare_password(password, account.password, account.password_salt):
            raise AccountPasswordError("Invalid email or password.")

        return cast(Account, account)

    @classmethod
    def login(cls, account: Account, app_code: str, end_user_id: str) -> str:
        site = db.session.query(Site).filter(Site.code == app_code).first()
        if not site:
            raise NotFound("Site not found.")

        access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id)

        return access_token

    @classmethod
    def get_user_through_email(cls, email: str):
        account = db.session.query(Account).filter(Account.email == email).first()
        if not account:
            return None

        if account.status == AccountStatus.BANNED.value:
            raise Unauthorized("Account is banned.")

        return account

    @classmethod
    def send_email_code_login_email(
        cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
    ):
        email = account.email if account else email
        if email is None:
            raise ValueError("Email must be provided.")

        code = "".join([str(random.randint(0, 9)) for _ in range(6)])
        token = TokenManager.generate_token(
            account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code}
        )
        send_email_code_login_mail_task.delay(
            language=language,
            to=account.email if account else email,
            code=code,
        )

        return token

    @classmethod
    def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
        return TokenManager.get_token_data(token, "webapp_email_code_login")

    @classmethod
    def revoke_email_code_login_token(cls, token: str):
        TokenManager.revoke_token(token, "webapp_email_code_login")

    @classmethod
    def create_end_user(cls, app_code, email) -> EndUser:
        site = db.session.query(Site).filter(Site.code == app_code).first()
        if not site:
            raise NotFound("Site not found.")
        app_model = db.session.query(App).filter(App.id == site.app_id).first()
        if not app_model:
            raise NotFound("App not found.")
        end_user = EndUser(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            type="browser",
            is_anonymous=False,
            session_id=email,
            name="enterpriseuser",
            external_user_id="enterpriseuser",
        )
        db.session.add(end_user)
        db.session.commit()

        return end_user

    @classmethod
    def _validate_user_accessibility(cls, account: Account, app_code: str):
        """Check if the user is allowed to access the app."""
        system_features = FeatureService.get_system_features()
        if system_features.webapp_auth.enabled:
            app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)

            if (
                app_settings.access_mode != "public"
                and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code)
            ):
                raise WebAppAuthAccessDeniedError()

    @classmethod
    def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str:
        exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
        exp = int(exp_dt.timestamp())

        payload = {
            "iss": site.id,
            "sub": "Web API Passport",
            "app_id": site.app_id,
            "app_code": site.code,
            "user_id": account.id,
            "end_user_id": end_user_id,
            "token_source": "webapp",
            "exp": exp,
        }

        token: str = PassportService().issue(payload)
        return token

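Note: end to end, the new web-app login is: verify credentials, obtain an `EndUser` for the site, then issue the site-scoped JWT. A hedged glue sketch wiring the service methods together (error handling elided; `app_code` is the Site code; a real controller would likely reuse an existing `EndUser` and call the access check before issuing a token):

```python
# Hypothetical flow for the email/password web-app login.
def webapp_login(email: str, password: str, app_code: str) -> str:
    account = WebAppAuthService.authenticate(email, password)
    # Access-mode enforcement (_validate_user_accessibility) is private here;
    # the controller layer is assumed to perform it at this point.
    end_user = WebAppAuthService.create_end_user(app_code, email)
    return WebAppAuthService.login(account, app_code, end_user.id)
```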
@ -5,7 +5,7 @@ from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session

from models import App, EndUser, WorkflowAppLog, WorkflowRun
from models.enums import CreatedByRole
from models.enums import CreatorUserRole
from models.workflow import WorkflowRunStatus


@ -58,7 +58,7 @@ class WorkflowAppService:

        stmt = stmt.outerjoin(
            EndUser,
            and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER),
            and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatorUserRole.END_USER),
        ).where(or_(*keyword_conditions))

        if status:

@ -1,17 +1,21 @@
import threading
from collections.abc import Sequence
from typing import Optional

import contexts
from core.workflow.repository import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.model import App
from models.workflow import (
from models import (
    Account,
    App,
    EndUser,
    WorkflowNodeExecution,
    WorkflowRun,
    WorkflowRunTriggeredFrom,
)
from models.workflow import WorkflowNodeExecutionTriggeredFrom


class WorkflowRunService:
@ -116,7 +120,12 @@ class WorkflowRunService:

        return workflow_run

    def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]:
    def get_workflow_run_node_executions(
        self,
        app_model: App,
        run_id: str,
        user: Account | EndUser,
    ) -> Sequence[WorkflowNodeExecution]:
        """
        Get workflow run node execution list
        """
@ -128,17 +137,17 @@ class WorkflowRunService:
        if not workflow_run:
            return []

        # Use the repository to get the node executions
        repository = RepositoryFactory.create_workflow_node_execution_repository(
            params={
                "tenant_id": app_model.tenant_id,
                "app_id": app_model.id,
                "session_factory": db.session.get_bind(),
            }
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine,
            user=user,
            app_id=app_model.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Use the repository to get the node executions with ordering
        # Use the repository to get the database models directly
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
        workflow_node_executions = repository.get_db_models_by_workflow_run(
            workflow_run_id=run_id, order_config=order_config
        )

        return list(node_executions)
        return workflow_node_executions

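Note: the stringly-typed factory (`RepositoryFactory` with a `params` dict) gives way to constructing `SQLAlchemyWorkflowNodeExecutionRepository` directly with typed keyword arguments, and the acting user is now threaded through from the caller. A hedged caller sketch for the new signature (`app_model`, `run_id`, and `current_user` are assumed to be in scope):

```python
# Hypothetical caller updated for the new signature: the acting user is
# passed explicitly so the repository can scope creator/tenant fields.
service = WorkflowRunService()
executions = service.get_workflow_run_node_executions(
    app_model=app_model,
    run_id=run_id,
    user=current_user,  # Account | EndUser
)
```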
@ -10,9 +10,10 @@ from sqlalchemy.orm import Session

from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes import NodeType
@ -21,12 +22,10 @@ from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repository import RepositoryFactory
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import (
@ -268,37 +267,37 @@ class WorkflowService:
        # run draft workflow node
        start_at = time.perf_counter()

        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
        node_execution = self._handle_node_run_result(
            invoke_node_fn=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=app_model.tenant_id,
            node_id=node_id,
        )

        workflow_node_execution.app_id = app_model.id
        workflow_node_execution.created_by = account.id
        workflow_node_execution.workflow_id = draft_workflow.id
        # Set workflow_id on the NodeExecution
        node_execution.workflow_id = draft_workflow.id

        # Use the repository to save the workflow node execution
        repository = RepositoryFactory.create_workflow_node_execution_repository(
            params={
                "tenant_id": app_model.tenant_id,
                "app_id": app_model.id,
                "session_factory": db.session.get_bind(),
            }
        # Create repository and save the node execution
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine,
            user=account,
            app_id=app_model.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        repository.save(workflow_node_execution)
        repository.save(node_execution)

        # Convert node_execution to WorkflowNodeExecution after save
        workflow_node_execution = repository.to_db_model(node_execution)

        return workflow_node_execution

    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
    ) -> NodeExecution:
        """
        Run draft workflow node
        """
@ -306,7 +305,7 @@ class WorkflowService:
        start_at = time.perf_counter()

        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
            invoke_node_fn=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
@ -314,7 +313,6 @@ class WorkflowService:
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )

@ -322,21 +320,12 @@ class WorkflowService:

    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        invoke_node_fn: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result

        :param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
    ) -> NodeExecution:
        try:
            node_instance, generator = getter()
            node_instance, generator = invoke_node_fn()

            node_run_result: NodeRunResult | None = None
            for event in generator:
@ -385,20 +374,21 @@ class WorkflowService:
                node_run_result = None
                error = e.error

        workflow_node_execution = WorkflowNodeExecution()
        workflow_node_execution.id = str(uuid4())
        workflow_node_execution.tenant_id = tenant_id
        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
        workflow_node_execution.index = 1
        workflow_node_execution.node_id = node_id
        workflow_node_execution.node_type = node_instance.node_type
        workflow_node_execution.title = node_instance.node_data.title
        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
        workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
        workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
        # Create a NodeExecution domain model
        node_execution = NodeExecution(
            id=str(uuid4()),
            workflow_id="",  # This is a single-step execution, so no workflow ID
            index=1,
            node_id=node_id,
            node_type=node_instance.node_type,
            title=node_instance.node_data.title,
            elapsed_time=time.perf_counter() - start_at,
            created_at=datetime.now(UTC).replace(tzinfo=None),
            finished_at=datetime.now(UTC).replace(tzinfo=None),
        )

        if run_succeeded and node_run_result:
            # create workflow node execution
            # Set inputs, process_data, and outputs as dictionaries (not JSON strings)
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
@ -407,23 +397,23 @@ class WorkflowService:
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None

            workflow_node_execution.inputs = json.dumps(inputs)
            workflow_node_execution.process_data = json.dumps(process_data)
            workflow_node_execution.outputs = json.dumps(outputs)
            workflow_node_execution.execution_metadata = (
                json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
            )
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
                workflow_node_execution.error = node_run_result.error
        else:
            # create workflow node execution
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
            workflow_node_execution.error = error
            node_execution.inputs = inputs
            node_execution.process_data = process_data
            node_execution.outputs = outputs
            node_execution.metadata = node_run_result.metadata

        return workflow_node_execution
            # Map status from WorkflowNodeExecutionStatus to NodeExecutionStatus
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                node_execution.status = NodeExecutionStatus.SUCCEEDED
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                node_execution.status = NodeExecutionStatus.EXCEPTION
                node_execution.error = node_run_result.error
        else:
            # Set failed status and error
            node_execution.status = NodeExecutionStatus.FAILED
            node_execution.error = error

        return node_execution

    def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
        """
@ -518,11 +508,11 @@ class WorkflowService:
            raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")

        # Check if this workflow is currently referenced by an app
        stmt = select(App).where(App.workflow_id == workflow_id)
        app = session.scalar(stmt)
        app_stmt = select(App).where(App.workflow_id == workflow_id)
        app = session.scalar(app_stmt)
        if app:
            # Cannot delete a workflow that's currently in use by an app
            raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
            raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'")

        # Don't use workflow.tool_published as it's not accurate for specific workflow versions
        # Check if there's a tool provider using this specific workflow version

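Note: the net effect for callers of `_handle_node_run_result` is that it now returns a storage-agnostic `NodeExecution` with plain dict fields, while persistence and the conversion to the `WorkflowNodeExecution` DB row become the repository's job (no more `json.dumps` in the service). The call sequence, condensed from the hunks above:

```python
# Condensed from the diff: run the node, persist via the repository,
# then ask the repository for the DB representation when the API needs it.
node_execution = self._handle_node_run_result(
    invoke_node_fn=lambda: WorkflowEntry.single_step_run(
        workflow=draft_workflow, node_id=node_id, user_inputs=user_inputs, user_id=account.id
    ),
    start_at=start_at,
    node_id=node_id,
)
node_execution.workflow_id = draft_workflow.id
repository.save(node_execution)
workflow_node_execution = repository.to_db_model(node_execution)
```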