Compare commits

..

3 Commits

17 changed files with 113 additions and 488 deletions

View File

@@ -10,7 +10,6 @@ from flask import current_app
 from pydantic import TypeAdapter
 from sqlalchemy import select
 from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.orm import sessionmaker

 from configs import dify_config
 from constants.languages import languages
@@ -62,30 +61,31 @@ def reset_password(email, new_password, password_confirm):
     if str(new_password).strip() != str(password_confirm).strip():
         click.echo(click.style("Passwords do not match.", fg="red"))
         return
-    with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
-        account = session.query(Account).where(Account.email == email).one_or_none()
-        if not account:
-            click.echo(click.style(f"Account not found for email: {email}", fg="red"))
-            return
+    account = db.session.query(Account).where(Account.email == email).one_or_none()
-        try:
-            valid_password(new_password)
-        except:
-            click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
-            return
+    if not account:
+        click.echo(click.style(f"Account not found for email: {email}", fg="red"))
+        return
-        # generate password salt
-        salt = secrets.token_bytes(16)
-        base64_salt = base64.b64encode(salt).decode()
+    try:
+        valid_password(new_password)
+    except:
+        click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
+        return
-        # encrypt password with salt
-        password_hashed = hash_password(new_password, salt)
-        base64_password_hashed = base64.b64encode(password_hashed).decode()
-        account.password = base64_password_hashed
-        account.password_salt = base64_salt
-    AccountService.reset_login_error_rate_limit(email)
-    click.echo(click.style("Password reset successfully.", fg="green"))
+    # generate password salt
+    salt = secrets.token_bytes(16)
+    base64_salt = base64.b64encode(salt).decode()
+    # encrypt password with salt
+    password_hashed = hash_password(new_password, salt)
+    base64_password_hashed = base64.b64encode(password_hashed).decode()
+    account.password = base64_password_hashed
+    account.password_salt = base64_salt
+    db.session.commit()
+    AccountService.reset_login_error_rate_limit(email)
+    click.echo(click.style("Password reset successfully.", fg="green"))


 @click.command("reset-email", help="Reset the account email.")
@@ -100,21 +100,22 @@ def reset_email(email, new_email, email_confirm):
     if str(new_email).strip() != str(email_confirm).strip():
         click.echo(click.style("New emails do not match.", fg="red"))
         return
-    with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
-        account = session.query(Account).where(Account.email == email).one_or_none()
-        if not account:
-            click.echo(click.style(f"Account not found for email: {email}", fg="red"))
-            return
+    account = db.session.query(Account).where(Account.email == email).one_or_none()
-        try:
-            email_validate(new_email)
-        except:
-            click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
-            return
+    if not account:
+        click.echo(click.style(f"Account not found for email: {email}", fg="red"))
+        return
-        account.email = new_email
-        click.echo(click.style("Email updated successfully.", fg="green"))
+    try:
+        email_validate(new_email)
+    except:
+        click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
+        return
+    account.email = new_email
+    db.session.commit()
+    click.echo(click.style("Email updated successfully.", fg="green"))


 @click.command(
@@ -138,24 +139,25 @@ def reset_encrypt_key_pair():
     if dify_config.EDITION != "SELF_HOSTED":
         click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
         return
-    with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
-        tenants = session.query(Tenant).all()
-        for tenant in tenants:
-            if not tenant:
-                click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
-                return
-            tenant.encrypt_public_key = generate_key_pair(tenant.id)
+    tenants = db.session.query(Tenant).all()
+    for tenant in tenants:
+        if not tenant:
+            click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
+            return
-            session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
-            session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
+        tenant.encrypt_public_key = generate_key_pair(tenant.id)
-            click.echo(
-                click.style(
-                    f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
-                    fg="green",
-                )
+        db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
+        db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
+        db.session.commit()
+        click.echo(
+            click.style(
+                f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
+                fg="green",
+            )
+        )


 @click.command("vdb-migrate", help="Migrate vector db.")
@@ -180,15 +182,14 @@ def migrate_annotation_vector_database():
         try:
             # get apps info
             per_page = 50
-            with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
-                apps = (
-                    session.query(App)
-                    .where(App.status == "normal")
-                    .order_by(App.created_at.desc())
-                    .limit(per_page)
-                    .offset((page - 1) * per_page)
-                    .all()
-                )
+            apps = (
+                db.session.query(App)
+                .where(App.status == "normal")
+                .order_by(App.created_at.desc())
+                .limit(per_page)
+                .offset((page - 1) * per_page)
+                .all()
+            )
             if not apps:
                 break
         except SQLAlchemyError:
@@ -202,27 +203,26 @@ def migrate_annotation_vector_database():
             )
             try:
                 click.echo(f"Creating app annotation index: {app.id}")
-                with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
-                    app_annotation_setting = (
-                        session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
-                    )
+                app_annotation_setting = (
+                    db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
+                )
-                    if not app_annotation_setting:
-                        skipped_count = skipped_count + 1
-                        click.echo(f"App annotation setting disabled: {app.id}")
-                        continue
-                    # get dataset_collection_binding info
-                    dataset_collection_binding = (
-                        session.query(DatasetCollectionBinding)
-                        .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
-                        .first()
-                    )
-                    if not dataset_collection_binding:
-                        click.echo(f"App annotation collection binding not found: {app.id}")
-                        continue
-                    annotations = session.scalars(
-                        select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
-                    ).all()
+                if not app_annotation_setting:
+                    skipped_count = skipped_count + 1
+                    click.echo(f"App annotation setting disabled: {app.id}")
+                    continue
+                # get dataset_collection_binding info
+                dataset_collection_binding = (
+                    db.session.query(DatasetCollectionBinding)
+                    .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
+                    .first()
+                )
+                if not dataset_collection_binding:
+                    click.echo(f"App annotation collection binding not found: {app.id}")
+                    continue
+                annotations = db.session.scalars(
+                    select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
+                ).all()
                 dataset = Dataset(
                     id=app.id,
                     tenant_id=app.tenant_id,
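Note on the recurring pattern in this file: the `with sessionmaker(db.engine, expire_on_commit=False).begin() as session:` blocks are rolled back to the shared `db.session`. The two styles differ in commit semantics, which is why every restored code path carries an explicit `db.session.commit()`. A minimal, self-contained sketch of the difference in plain SQLAlchemy; the `Account` model here is an illustrative stand-in, not the repository's:

# Minimal sketch of the two transaction styles; `Account` is a stand-in model.
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class Account(Base):
    __tablename__ = "accounts"
    id: Mapped[int] = mapped_column(primary_key=True)
    email: Mapped[str] = mapped_column(String(255))
    password: Mapped[str] = mapped_column(String(255), default="")


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

# Removed style: begin() opens a transaction that commits automatically on
# clean exit and rolls back on exception -- no explicit commit() anywhere.
with sessionmaker(engine, expire_on_commit=False).begin() as session:
    session.add(Account(email="a@example.com"))

# Restored style: a long-lived session mutated in place; without the explicit
# commit() the UPDATE would never reach the database.
session = sessionmaker(engine)()
account = session.execute(select(Account).where(Account.email == "a@example.com")).scalar_one()
account.password = "hashed+salted"
session.commit()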

View File

@@ -18,18 +18,3 @@ class EnterpriseFeatureConfig(BaseSettings):
         description="Allow customization of the enterprise logo.",
         default=False,
     )
-    UPLOAD_KNOWLEDGE_PIPELINE_TEMPLATE_TOKEN: str = Field(
-        description="Token for uploading knowledge pipeline template.",
-        default="",
-    )
-    KNOWLEDGE_PIPELINE_TEMPLATE_COPYRIGHT: str = Field(
-        description="Knowledge pipeline template copyright.",
-        default="Copyright 2023 Dify",
-    )
-    KNOWLEDGE_PIPELINE_TEMPLATE_PRIVACY_POLICY: str = Field(
-        description="Knowledge pipeline template privacy policy.",
-        default="https://dify.ai",
-    )

View File

@@ -1,7 +1,6 @@
 from datetime import datetime

 import pytz  # pip install pytz
-import sqlalchemy as sa
 from flask_login import current_user
 from flask_restx import Resource, marshal_with, reqparse
 from flask_restx.inputs import int_range
@@ -71,7 +70,7 @@ class CompletionConversationApi(Resource):
         parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
         args = parser.parse_args()

-        query = sa.select(Conversation).where(
+        query = db.select(Conversation).where(
             Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
         )
@@ -237,7 +236,7 @@ class ChatConversationApi(Resource):
             .subquery()
         )

-        query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
+        query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))

         if args["keyword"]:
             keyword_filter = f"%{args['keyword']}%"

View File

@@ -4,7 +4,6 @@ from argparse import ArgumentTypeError
 from collections.abc import Sequence
 from typing import Literal, cast

-import sqlalchemy as sa
 from flask import request
 from flask_login import current_user
 from flask_restx import Resource, fields, marshal, marshal_with, reqparse
@@ -212,13 +211,13 @@ class DatasetDocumentListApi(Resource):
         if sort == "hit_count":
             sub_query = (
-                sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
+                db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
                 .group_by(DocumentSegment.document_id)
                 .subquery()
             )
             query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
-                sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
+                sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
                 sort_logic(Document.position),
             )
         elif sort == "created_at":
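The `sa.*` to `db.*` swaps in this and the previous file are equivalent rewrites rather than behavior changes: Flask-SQLAlchemy exposes the `sqlalchemy` namespace on the extension object, so `db.select`, `db.func`, `db.and_`, and `db.or_` build the same expressions as their `sqlalchemy` counterparts. A quick sketch, assuming Flask and Flask-SQLAlchemy are installed; the app below is a throwaway, not the repository's:

# Sketch: db.select/db.func compile to the same SQL as sa.select/sa.func.
import sqlalchemy as sa
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
db = SQLAlchemy(app)

stmt_sa = sa.select(sa.func.coalesce(sa.literal(None), 0))
stmt_db = db.select(db.func.coalesce(sa.literal(None), 0))
print(str(stmt_sa) == str(stmt_db))  # True -- identical generated SQL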

View File

@@ -14,10 +14,7 @@ from controllers.console.wraps import (
 from extensions.ext_database import db
 from libs.login import login_required
 from models.dataset import PipelineCustomizedTemplate
-from services.entities.knowledge_entities.rag_pipeline_entities import (
-    PipelineBuiltInTemplateEntity,
-    PipelineTemplateInfoEntity,
-)
+from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
 from services.rag_pipeline.rag_pipeline import RagPipelineService

 logger = logging.getLogger(__name__)
@@ -29,6 +26,12 @@ def _validate_name(name):
     return name


+def _validate_description_length(description):
+    if len(description) > 400:
+        raise ValueError("Description cannot exceed 400 characters.")
+    return description
+
+
 class PipelineTemplateListApi(Resource):
     @setup_required
     @login_required
@@ -143,186 +146,6 @@ class PublishCustomizedPipelineTemplateApi(Resource):
         return {"result": "success"}


-class PipelineTemplateInstallApi(Resource):
-    """API endpoint for installing built-in pipeline templates"""
-
-    def post(self):
-        """
-        Install a built-in pipeline template
-
-        Args:
-            template_id: The template ID from URL parameter
-
-        Returns:
-            Success response or error with appropriate HTTP status
-        """
-        try:
-            # Extract and validate Bearer token
-            auth_token = self._extract_bearer_token()
-
-            # Parse and validate request parameters
-            template_args = self._parse_template_args()
-
-            # Process uploaded template file
-            file_content = self._process_template_file()
-
-            # Create template entity
-            pipeline_built_in_template_entity = PipelineBuiltInTemplateEntity(**template_args)
-
-            # Install the template
-            rag_pipeline_service = RagPipelineService()
-            rag_pipeline_service.install_built_in_pipeline_template(
-                pipeline_built_in_template_entity, file_content, auth_token
-            )
-
-            return {"result": "success", "message": "Template installed successfully"}, 200
-        except ValueError as e:
-            logger.exception("Validation error in template installation")
-            return {"error": str(e)}, 400
-        except Exception as e:
-            logger.exception("Unexpected error in template installation")
-            return {"error": "An unexpected error occurred during template installation"}, 500
-
-    def _extract_bearer_token(self) -> str:
-        """
-        Extract and validate Bearer token from Authorization header
-
-        Returns:
-            The extracted token string
-
-        Raises:
-            ValueError: If token is missing or invalid
-        """
-        auth_header = request.headers.get("Authorization", "").strip()
-        if not auth_header:
-            raise ValueError("Authorization header is required")
-        if not auth_header.startswith("Bearer "):
-            raise ValueError("Authorization header must start with 'Bearer '")
-        token_parts = auth_header.split(" ", 1)
-        if len(token_parts) != 2:
-            raise ValueError("Invalid Authorization header format")
-        auth_token = token_parts[1].strip()
-        if not auth_token:
-            raise ValueError("Bearer token cannot be empty")
-        return auth_token
-
-    def _parse_template_args(self) -> dict:
-        """
-        Parse and validate template arguments from form data
-
-        Args:
-            template_id: The template ID from URL
-
-        Returns:
-            Dictionary of validated template arguments
-        """
-        # Use reqparse for consistent parameter parsing
-        parser = reqparse.RequestParser()
-        parser.add_argument(
-            "template_id",
-            type=str,
-            location="form",
-            required=False,
-            help="Template ID for updating existing template"
-        )
-        parser.add_argument(
-            "language",
-            type=str,
-            location="form",
-            required=True,
-            default="en-US",
-            choices=["en-US", "zh-CN", "ja-JP"],
-            help="Template language code"
-        )
-        parser.add_argument(
-            "name",
-            type=str,
-            location="form",
-            required=True,
-            default="New Pipeline Template",
-            help="Template name (1-200 characters)"
-        )
-        parser.add_argument(
-            "description",
-            type=str,
-            location="form",
-            required=False,
-            default="",
-            help="Template description (max 1000 characters)"
-        )
-        args = parser.parse_args()
-
-        # Additional validation
-        if args.get("name"):
-            args["name"] = self._validate_name(args["name"])
-        if args.get("description") and len(args["description"]) > 1000:
-            raise ValueError("Description must not exceed 1000 characters")
-
-        # Filter out None values
-        return {k: v for k, v in args.items() if v is not None}
-
-    def _validate_name(self, name: str) -> str:
-        """
-        Validate template name
-
-        Args:
-            name: Template name to validate
-
-        Returns:
-            Validated and trimmed name
-
-        Raises:
-            ValueError: If name is invalid
-        """
-        name = name.strip()
-        if not name or len(name) < 1 or len(name) > 200:
-            raise ValueError("Template name must be between 1 and 200 characters")
-        return name
-
-    def _process_template_file(self) -> str:
-        """
-        Process and validate uploaded template file
-
-        Returns:
-            File content as string
-
-        Raises:
-            ValueError: If file is missing or invalid
-        """
-        if "file" not in request.files:
-            raise ValueError("Template file is required")
-        file = request.files["file"]
-
-        # Validate file
-        if not file or not file.filename:
-            raise ValueError("No file selected")
-        filename = file.filename.strip()
-        if not filename:
-            raise ValueError("File name cannot be empty")
-
-        # Check file extension
-        if not filename.lower().endswith(".pipeline"):
-            raise ValueError("Template file must be a pipeline file (.pipeline)")
-
-        try:
-            file_content = file.read().decode("utf-8")
-        except UnicodeDecodeError:
-            raise ValueError("Template file must be valid UTF-8 text")
-        return file_content
-
-
 api.add_resource(
     PipelineTemplateListApi,
     "/rag/pipeline/templates",
@@ -339,7 +162,3 @@ api.add_resource(
     PublishCustomizedPipelineTemplateApi,
     "/rag/pipelines/<string:pipeline_id>/customized/publish",
 )
-api.add_resource(
-    PipelineTemplateInstallApi,
-    "/rag/pipeline/built-in/templates/install",
-)
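With the resource deregistered above, the install endpoint is gone entirely. For context, the deleted handler accepted a multipart form authorized by a static Bearer token. A hedged sketch of how a client would have called it, reconstructed from the removed code; the base URL prefix and token value are hypothetical placeholders, not documented API:

# Hypothetical client call based on the removed handler's contract:
# Bearer token auth, form fields, and a UTF-8 ".pipeline" file upload.
import requests

resp = requests.post(
    "http://localhost:5001/console/api/rag/pipeline/built-in/templates/install",
    headers={"Authorization": "Bearer <UPLOAD_KNOWLEDGE_PIPELINE_TEMPLATE_TOKEN>"},
    data={
        "language": "en-US",              # required: en-US, zh-CN, or ja-JP
        "name": "New Pipeline Template",  # required: 1-200 characters
        "description": "",                # optional: max 1000 characters
        # "template_id": "...",           # optional: update an existing template
    },
    files={"file": ("template.pipeline", open("template.pipeline", "rb"))},
    timeout=30,
)
print(resp.status_code, resp.json())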

View File

@@ -1,37 +0,0 @@
-"""remove-builtin-template-user
-
-Revision ID: bf0bcbf45396
-Revises: 68519ad5cd18
-Create Date: 2025-09-25 16:50:32.245503
-
-"""
-from alembic import op
-import models as models
-import sqlalchemy as sa
-
-
-# revision identifiers, used by Alembic.
-revision = 'bf0bcbf45396'
-down_revision = '68519ad5cd18'
-branch_labels = None
-depends_on = None
-
-
-def upgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
-        batch_op.drop_column('updated_by')
-        batch_op.drop_column('created_by')
-
-    # ### end Alembic commands ###
-
-
-def downgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
-        batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
-        batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))
-
-    # ### end Alembic commands ###

View File

@@ -910,7 +910,7 @@ class AppDatasetJoin(Base):
     id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
     app_id = mapped_column(StringUUID, nullable=False)
     dataset_id = mapped_column(StringUUID, nullable=False)
-    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())

     @property
     def app(self):
@@ -931,7 +931,7 @@ class DatasetQuery(Base):
     source_app_id = mapped_column(StringUUID, nullable=True)
     created_by_role = mapped_column(String, nullable=False)
     created_by = mapped_column(StringUUID, nullable=False)
-    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())


 class DatasetKeywordTable(Base):
@@ -1239,6 +1239,15 @@ class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
     language = db.Column(db.String(255), nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_by = db.Column(StringUUID, nullable=False)
+    updated_by = db.Column(StringUUID, nullable=True)
+
+    @property
+    def created_user_name(self):
+        account = db.session.query(Account).where(Account.id == self.created_by).first()
+        if account:
+            return account.name
+        return ""


 class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]

View File

@@ -1731,7 +1731,7 @@ class MessageChain(Base):
     type: Mapped[str] = mapped_column(String(255), nullable=False)
     input = mapped_column(sa.Text, nullable=True)
     output = mapped_column(sa.Text, nullable=True)
-    created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
+    created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())


 class MessageAgentThought(Base):
@@ -1769,7 +1769,7 @@ class MessageAgentThought(Base):
     latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
     created_by_role = mapped_column(String, nullable=False)
     created_by = mapped_column(StringUUID, nullable=False)
-    created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
+    created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())

     @property
     def files(self) -> list[Any]:
@@ -1872,7 +1872,7 @@ class DatasetRetrieverResource(Base):
     index_node_hash = mapped_column(sa.Text, nullable=True)
     retriever_from = mapped_column(sa.Text, nullable=False)
     created_by = mapped_column(StringUUID, nullable=False)
-    created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
+    created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())


 class Tag(Base):

View File

@@ -2,7 +2,6 @@ import json
 import logging
 from typing import TypedDict, cast

-import sqlalchemy as sa
 from flask_sqlalchemy.pagination import Pagination

 from configs import dify_config
@@ -66,7 +65,7 @@ class AppService:
             return None

         app_models = db.paginate(
-            sa.select(App).where(*filters).order_by(App.created_at.desc()),
+            db.select(App).where(*filters).order_by(App.created_at.desc()),
            page=args["page"],
            per_page=args["limit"],
            error_out=False,

View File

@@ -115,12 +115,12 @@ class DatasetService:
             # Check if permitted_dataset_ids is not empty to avoid WHERE false condition
             if permitted_dataset_ids and len(permitted_dataset_ids) > 0:
                 query = query.where(
-                    sa.or_(
+                    db.or_(
                         Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
-                        sa.and_(
+                        db.and_(
                             Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
                         ),
-                        sa.and_(
+                        db.and_(
                             Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM,
                             Dataset.id.in_(permitted_dataset_ids),
                         ),
@@ -128,9 +128,9 @@ class DatasetService:
                 )
             else:
                 query = query.where(
-                    sa.or_(
+                    db.or_(
                         Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
-                        sa.and_(
+                        db.and_(
                             Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
                         ),
                     )
@@ -1879,7 +1879,7 @@ class DocumentService:
         # for notion_info in notion_info_list:
         #     workspace_id = notion_info.workspace_id
         #     data_source_binding = DataSourceOauthBinding.query.filter(
-        #         sa.and_(
+        #         db.and_(
         #             DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
         #             DataSourceOauthBinding.provider == "notion",
         #             DataSourceOauthBinding.disabled == False,
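The guard in the first hunk ("avoid WHERE false condition") exists because an IN over an empty sequence compiles to a predicate that can never match, which would silently hide every dataset. A self-contained demonstration; the table and column names here are illustrative:

# Why the empty-list branch matters: IN () over an empty list renders an
# always-false predicate (the exact SQL varies by SQLAlchemy version).
from sqlalchemy import Column, Integer, MetaData, String, Table, select

datasets = Table(
    "datasets",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("permission", String),
)

print(select(datasets).where(datasets.c.id.in_([])))
# e.g. SELECT ... FROM datasets WHERE datasets.id IN (NULL) AND (1 != 1)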

View File

@@ -128,10 +128,3 @@ class KnowledgeConfiguration(BaseModel):
         if v is None:
             return ""
         return v
-
-
-class PipelineBuiltInTemplateEntity(BaseModel):
-    template_id: str | None = None
-    name: str
-    description: str
-    language: str

View File

@@ -471,7 +471,7 @@ class PluginMigration:
         total_failed_tenant = 0

         while True:
             # paginate
-            tenants = db.paginate(sa.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
+            tenants = db.paginate(db.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
             if tenants.items is None or len(tenants.items) == 0:
                 break

View File

@@ -74,4 +74,5 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
             "chunk_structure": pipeline_template.chunk_structure,
             "export_data": pipeline_template.yaml_content,
             "graph": graph_data,
+            "created_by": pipeline_template.created_user_name,
         }

View File

@@ -8,7 +8,6 @@ from datetime import UTC, datetime
 from typing import Any, Union, cast
 from uuid import uuid4

-import yaml
 from flask_login import current_user
 from sqlalchemy import func, or_, select
 from sqlalchemy.orm import Session, sessionmaker
@@ -61,7 +60,6 @@ from models.dataset import (  # type: ignore
     Document,
     DocumentPipelineExecutionLog,
     Pipeline,
-    PipelineBuiltInTemplate,
     PipelineCustomizedTemplate,
     PipelineRecommendedPlugin,
 )
@@ -78,7 +76,6 @@ from repositories.factory import DifyAPIRepositoryFactory
 from services.datasource_provider_service import DatasourceProviderService
 from services.entities.knowledge_entities.rag_pipeline_entities import (
     KnowledgeConfiguration,
-    PipelineBuiltInTemplateEntity,
     PipelineTemplateInfoEntity,
 )
 from services.errors.app import WorkflowHashNotEqualError
@@ -1457,140 +1454,3 @@ class RagPipelineService:
         if not pipeline:
             raise ValueError("Pipeline not found")
         return pipeline
-
-    def install_built_in_pipeline_template(
-        self, args: PipelineBuiltInTemplateEntity, file_content: str, auth_token: str
-    ) -> None:
-        """
-        Install built-in pipeline template
-
-        Args:
-            args: Pipeline built-in template entity with template metadata
-            file_content: YAML content of the pipeline template
-            auth_token: Authentication token for authorization
-
-        Raises:
-            ValueError: If validation fails or template processing errors occur
-        """
-        # Validate authentication
-        self._validate_auth_token(auth_token)
-
-        # Parse and validate template content
-        pipeline_template_dsl = self._parse_template_content(file_content)
-
-        # Extract template metadata
-        icon = self._extract_icon_metadata(pipeline_template_dsl)
-        chunk_structure = self._extract_chunk_structure(pipeline_template_dsl)
-
-        # Prepare template data
-        template_data = {
-            "name": args.name,
-            "description": args.description,
-            "chunk_structure": chunk_structure,
-            "icon": icon,
-            "language": args.language,
-            "yaml_content": file_content,
-        }
-
-        # Use transaction for database operations
-        try:
-            if args.template_id:
-                self._update_existing_template(args.template_id, template_data)
-            else:
-                self._create_new_template(template_data)
-            db.session.commit()
-        except Exception as e:
-            db.session.rollback()
-            raise ValueError(f"Failed to install pipeline template: {str(e)}")
-
-    def _validate_auth_token(self, auth_token: str) -> None:
-        """Validate the authentication token"""
-        config_auth_token = dify_config.UPLOAD_KNOWLEDGE_PIPELINE_TEMPLATE_TOKEN
-        if not config_auth_token:
-            raise ValueError("Auth token configuration is required")
-        if config_auth_token != auth_token:
-            raise ValueError("Auth token is incorrect")
-
-    def _parse_template_content(self, file_content: str) -> dict:
-        """Parse and validate YAML template content"""
-        try:
-            pipeline_template_dsl = yaml.safe_load(file_content)
-        except yaml.YAMLError as e:
-            raise ValueError(f"Invalid YAML content: {str(e)}")
-        if not pipeline_template_dsl:
-            raise ValueError("Pipeline template DSL is required")
-        return pipeline_template_dsl
-
-    def _extract_icon_metadata(self, pipeline_template_dsl: dict) -> dict:
-        """Extract icon metadata from template DSL"""
-        rag_pipeline_info = pipeline_template_dsl.get("rag_pipeline", {})
-        return {
-            "icon": rag_pipeline_info.get("icon", "📙"),
-            "icon_type": rag_pipeline_info.get("icon_type", "emoji"),
-            "icon_background": rag_pipeline_info.get("icon_background", "#FFEAD5"),
-            "icon_url": rag_pipeline_info.get("icon_url"),
-        }
-
-    def _extract_chunk_structure(self, pipeline_template_dsl: dict) -> str:
-        """Extract chunk structure from template DSL"""
-        nodes = pipeline_template_dsl.get("workflow", {}).get("graph", {}).get("nodes", [])
-        # Use generator expression for efficiency
-        chunk_structure = next(
-            (
-                node.get("data", {}).get("chunk_structure")
-                for node in nodes
-                if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_INDEX.value
-            ),
-            None,
-        )
-        if not chunk_structure:
-            raise ValueError("Chunk structure is required in template")
-        return chunk_structure
-
-    def _update_existing_template(self, template_id: str, template_data: dict) -> None:
-        """Update an existing pipeline template"""
-        pipeline_built_in_template = (
-            db.session.query(PipelineBuiltInTemplate)
-            .filter(PipelineBuiltInTemplate.id == template_id)
-            .first()
-        )
-        if not pipeline_built_in_template:
-            raise ValueError(f"Pipeline built-in template not found: {template_id}")
-        # Update template fields
-        for key, value in template_data.items():
-            setattr(pipeline_built_in_template, key, value)
-        db.session.add(pipeline_built_in_template)
-
-    def _create_new_template(self, template_data: dict) -> None:
-        """Create a new pipeline template"""
-        # Get the next available position
-        position = self._get_next_position(template_data["language"])
-        # Add additional fields for new template
-        template_data.update({
-            "position": position,
-            "install_count": 0,
-            "copyright": dify_config.KNOWLEDGE_PIPELINE_TEMPLATE_COPYRIGHT,
-            "privacy_policy": dify_config.KNOWLEDGE_PIPELINE_TEMPLATE_PRIVACY_POLICY,
-        })
-        new_template = PipelineBuiltInTemplate(**template_data)
-        db.session.add(new_template)
-
-    def _get_next_position(self, language: str) -> int:
-        """Get the next available position for a template in the specified language"""
-        max_position = (
-            db.session.query(func.max(PipelineBuiltInTemplate.position))
-            .filter(PipelineBuiltInTemplate.language == language)
-            .scalar()
-        )
-        return (max_position or 0) + 1
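The removed _extract_chunk_structure walked the workflow graph with a generator expression and next() to pull the first knowledge-index node's chunk_structure out of the DSL. A standalone sketch of that pattern; the DSL snippet and the "knowledge-index" type string are illustrative stand-ins for NodeType.KNOWLEDGE_INDEX.value:

# First-match extraction from a nested DSL dict, as the removed helper did.
dsl = {
    "workflow": {
        "graph": {
            "nodes": [
                {"data": {"type": "start"}},
                {"data": {"type": "knowledge-index", "chunk_structure": "text_model"}},
            ]
        }
    }
}

nodes = dsl.get("workflow", {}).get("graph", {}).get("nodes", [])
chunk_structure = next(
    (
        node.get("data", {}).get("chunk_structure")
        for node in nodes
        if node.get("data", {}).get("type") == "knowledge-index"
    ),
    None,  # default when no matching node exists
)
print(chunk_structure)  # -> text_model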

View File

@@ -1,6 +1,5 @@
 import uuid

-import sqlalchemy as sa
 from flask_login import current_user
 from sqlalchemy import func, select
 from werkzeug.exceptions import NotFound
@@ -19,7 +18,7 @@ class TagService:
             .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
         )
         if keyword:
-            query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
+            query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%")))
         query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
         results: list = query.order_by(Tag.created_at.desc()).all()
         return results

View File

@@ -2,7 +2,6 @@ import logging
 import time

 import click
-import sqlalchemy as sa
 from celery import shared_task
 from sqlalchemy import select
@@ -52,7 +51,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
         data_source_binding = (
             db.session.query(DataSourceOauthBinding)
             .where(
-                sa.and_(
+                db.and_(
                     DataSourceOauthBinding.tenant_id == document.tenant_id,
                     DataSourceOauthBinding.provider == "notion",
                     DataSourceOauthBinding.disabled == False,

View File

@@ -53,7 +53,7 @@ const SearchInput: FC<SearchInputProps> = ({
        }}
        onCompositionEnd={(e) => {
          isComposing.current = false
-         onChange(e.currentTarget.value)
+         onChange(e.data)
        }}
        onFocus={() => setFocus(true)}
        onBlur={() => setFocus(false)}