Mirror of https://github.com/langgenius/dify.git (synced 2026-05-01 16:08:04 +08:00)
r2
@@ -1,4 +1,3 @@
from calendar import day_abbr
import copy
import datetime
import json
@@ -7,7 +6,7 @@ import random
import time
import uuid
from collections import Counter
from typing import Any, Optional, cast
from typing import Any, Optional

from flask_login import current_user
from sqlalchemy import func, select
@@ -282,7 +281,6 @@ class DatasetService:
db.session.commit()
return dataset


@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
@@ -494,10 +492,9 @@ class DatasetService:
return dataset

@staticmethod
def update_rag_pipeline_dataset_settings(session: Session,
dataset: Dataset,
knowledge_configuration: KnowledgeConfiguration,
has_published: bool = False):
def update_rag_pipeline_dataset_settings(
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
):
dataset = session.merge(dataset)
if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure
@@ -616,7 +613,6 @@ class DatasetService:
if action:
deal_dataset_index_update_task.delay(dataset.id, action)


@staticmethod
def delete_dataset(dataset_id, user):
dataset = DatasetService.get_dataset(dataset_id)
@@ -1,5 +1,4 @@
import logging
from typing import Optional

from flask_login import current_user

@@ -22,11 +21,9 @@ class DatasourceProviderService:
def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager()

def datasource_provider_credentials_validate(self,
tenant_id: str,
provider: str,
plugin_id: str,
credentials: dict) -> None:
def datasource_provider_credentials_validate(
self, tenant_id: str, provider: str, plugin_id: str, credentials: dict
) -> None:
"""
validate datasource provider credentials.

@@ -34,29 +31,30 @@ class DatasourceProviderService:
:param provider:
:param credentials:
"""
credential_valid = self.provider_manager.validate_provider_credentials(tenant_id=tenant_id,
user_id=current_user.id,
provider=provider,
credentials=credentials)
credential_valid = self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials
)
if credential_valid:
# Get all provider configurations of the current workspace
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.first()
)

provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
provider=provider
)
provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
if not datasource_provider:
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials,
)
db.session.add(datasource_provider)
db.session.commit()
else:
@@ -101,11 +99,15 @@ class DatasourceProviderService:
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id
).all()
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.filter(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []
@@ -128,10 +130,7 @@ class DatasourceProviderService:

return copy_credentials_list

def remove_datasource_credentials(self,
tenant_id: str,
provider: str,
plugin_id: str) -> None:
def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None:
"""
remove datasource credentials.

@@ -140,9 +139,11 @@ class DatasourceProviderService:
:param plugin_id: plugin id
:return:
"""
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.first()
)
if datasource_provider:
db.session.delete(datasource_provider)
db.session.commit()
@@ -107,6 +107,7 @@ class KnowledgeConfiguration(BaseModel):
"""
Knowledge Base Configuration.
"""

chunk_structure: str
indexing_technique: Literal["high_quality", "economy"]
embedding_model_provider: Optional[str] = ""
@@ -3,7 +3,6 @@ from typing import Any, Union

from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from models.dataset import Pipeline
from models.model import Account, App, EndUser
@@ -1,13 +1,12 @@
from typing import Optional

from flask_login import current_user
import yaml
from flask_login import current_user

from extensions.ext_database import db
from models.dataset import PipelineCustomizedTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService


class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
@@ -43,7 +42,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
)
recommended_pipelines_results = []
for pipeline_customized_template in pipeline_customized_templates:

recommended_pipeline_result = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,
@@ -56,7 +54,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):

return {"pipeline_templates": recommended_pipelines_results}


@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
"""
@@ -38,7 +38,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):

recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:

recommended_pipeline_result = {
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,
@@ -35,7 +35,7 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
@@ -57,9 +57,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi

class RagPipelineService:
@classmethod
def get_pipeline_templates(
cls, type: str = "built-in", language: str = "en-US"
) -> dict:
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
@@ -308,7 +306,7 @@ class RagPipelineService:
session=session,
dataset=dataset,
knowledge_configuration=knowledge_configuration,
has_published=pipeline.is_published
has_published=pipeline.is_published,
)
# return new workflow
return workflow
@@ -444,12 +442,10 @@ class RagPipelineService:
)
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: GetOnlineDocumentPagesResponse = (
datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
return {
"result": [page.model_dump() for page in online_document_result.result],
@@ -470,7 +466,6 @@ class RagPipelineService:
else:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")


def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
@@ -689,8 +684,8 @@ class RagPipelineService:
WorkflowRun.app_id == pipeline.id,
or_(
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value
)
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
),
)

if args.get("last_id"):
@@ -763,18 +758,17 @@ class RagPipelineService:

# Use the repository to get the node execution
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine,
app_id=pipeline.id,
user=user,
triggered_from=None
session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
)

# Use the repository to get the node executions with ordering
order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
order_config=order_config,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
# Convert domain models to database models
node_executions = repository.get_by_workflow_run(
workflow_run_id=run_id,
order_config=order_config,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
# Convert domain models to database models
workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]

return workflow_node_executions
@@ -279,7 +279,11 @@ class RagPipelineDslService:
if node.get("data", {}).get("type") == "knowledge_index":
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure:
if (
dataset
and pipeline.is_published
and dataset.chunk_structure != knowledge_configuration.chunk_structure
):
raise ValueError("Chunk structure is not compatible with the published pipeline")
else:
dataset = Dataset(
@@ -304,8 +308,7 @@ class RagPipelineDslService:
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name
== knowledge_configuration.embedding_model,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
@@ -323,12 +326,8 @@ class RagPipelineDslService:
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
knowledge_configuration.embedding_model
)
dataset.embedding_model_provider = (
knowledge_configuration.embedding_model_provider
)
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id
@@ -443,8 +442,7 @@ class RagPipelineDslService:
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name
== knowledge_configuration.embedding_model,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
@@ -462,12 +460,8 @@ class RagPipelineDslService:
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
knowledge_configuration.embedding_model
)
dataset.embedding_model_provider = (
knowledge_configuration.embedding_model_provider
)
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id
@@ -538,7 +532,6 @@ class RagPipelineDslService:
icon_type = "emoji"
icon = str(pipeline_data.get("icon", ""))


# Initialize pipeline based on mode
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
@@ -554,7 +547,6 @@ class RagPipelineDslService:
]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])


graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
@@ -576,7 +568,6 @@ class RagPipelineDslService:
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.updated_by = account.id


else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")
@@ -636,7 +627,6 @@ class RagPipelineDslService:
# commit db session changes
db.session.commit()


return pipeline

@classmethod
@@ -874,7 +864,6 @@ class RagPipelineDslService:
except Exception:
return None


@staticmethod
def create_rag_pipeline_dataset(
tenant_id: str,
@@ -886,9 +875,7 @@ class RagPipelineDslService:
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise ValueError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")

with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
@@ -12,12 +12,12 @@ class RagPipelineManageService:

# get all builtin providers
manager = PluginDatasourceManager()
datasources = manager.fetch_datasource_providers(tenant_id)
datasources = manager.fetch_datasource_providers(tenant_id)
for datasource in datasources:
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id,
provider=datasource.provider,
plugin_id=datasource.plugin_id)
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
if credentials:
datasource.is_authorized = True
return datasources