Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Novice Lee
2025-01-09 08:48:00 +08:00
231 changed files with 3564 additions and 2382 deletions

View File

@@ -67,7 +67,7 @@ class TokenPair(BaseModel):
REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=30)
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
@@ -921,6 +921,9 @@ class RegisterService:
def invite_new_member(
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
) -> str:
if not inviter:
raise ValueError("Inviter is required")
"""Invite new member"""
with Session(db.engine) as session:
account = session.query(Account).filter_by(email=email).first()

View File

@@ -2,6 +2,7 @@ import logging
import uuid
from enum import StrEnum
from typing import Optional, cast
from urllib.parse import urlparse
from uuid import uuid4
import yaml # type: ignore
@@ -124,7 +125,7 @@ class AppDslService:
raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content
content: bytes | str = b""
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return Import(
@@ -133,13 +134,17 @@ class AppDslService:
error="yaml_url is required when import_mode is yaml-url",
)
try:
# tricky way to handle url from github to github raw url
if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")):
parsed_url = urlparse(yaml_url)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status()
content = response.content
content = response.content.decode()
if len(content) > DSL_MAX_SIZE:
return Import(

View File

@@ -26,9 +26,10 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t
class AppService:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
"""
Get app list with pagination
:param user_id: user id
:param tenant_id: tenant id
:param args: request args
:return:
@@ -44,6 +45,8 @@ class AppService:
elif args["mode"] == "channel":
filters.append(App.mode == AppMode.CHANNEL.value)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):
name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%"))

View File

@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Literal, Optional
import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
@@ -17,7 +17,6 @@ class BillingService:
params = {"tenant_id": tenant_id}
billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info
@classmethod
@@ -47,12 +46,13 @@ class BillingService:
retry=retry_if_exception_type(httpx.RequestError),
reraise=True,
)
def _send_request(cls, method, endpoint, json=None, params=None):
def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers)
if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
return response.json()
@staticmethod

View File

@@ -86,7 +86,7 @@ class DatasetService:
else:
return [], 0
else:
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
# show all datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(
@@ -382,7 +382,7 @@ class DatasetService:
if dataset.tenant_id != user.current_tenant_id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
@@ -404,7 +404,7 @@ class DatasetService:
if not user:
raise ValueError("User not found")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id:
raise NoPermissionError("You do not have permission to access this dataset.")
@@ -434,6 +434,12 @@ class DatasetService:
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {
"document_ids": [],
"count": 0,
}
# get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
@@ -786,13 +792,19 @@ class DocumentService:
dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
@@ -804,7 +816,11 @@ class DocumentService:
"score_threshold_enabled": False,
}
dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
documents = []
if knowledge_config.original_document_id:

View File

@@ -27,7 +27,7 @@ class WorkflowAppService:
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
if keyword:
keyword_like_val = f"%{args['keyword'][:30]}%"
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
keyword_conditions = [
WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val),

View File

@@ -298,7 +298,7 @@ class WorkflowService:
start_at: float,
tenant_id: str,
node_id: str,
):
) -> WorkflowNodeExecution:
"""
Handle node run result