mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
r2
This commit is contained in:
@ -3,7 +3,6 @@ from typing import Any, Union
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.dataset import Pipeline
|
||||
from models.model import Account, App, EndUser
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
from typing import Optional
|
||||
|
||||
from flask_login import current_user
|
||||
import yaml
|
||||
from flask_login import current_user
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
@ -43,7 +42,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
)
|
||||
recommended_pipelines_results = []
|
||||
for pipeline_customized_template in pipeline_customized_templates:
|
||||
|
||||
recommended_pipeline_result = {
|
||||
"id": pipeline_customized_template.id,
|
||||
"name": pipeline_customized_template.name,
|
||||
@ -56,7 +54,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
|
||||
return {"pipeline_templates": recommended_pipelines_results}
|
||||
|
||||
|
||||
@classmethod
|
||||
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
|
||||
"""
|
||||
|
||||
@ -38,7 +38,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
|
||||
recommended_pipelines_results = []
|
||||
for pipeline_built_in_template in pipeline_built_in_templates:
|
||||
|
||||
recommended_pipeline_result = {
|
||||
"id": pipeline_built_in_template.id,
|
||||
"name": pipeline_built_in_template.name,
|
||||
|
||||
@ -35,7 +35,7 @@ from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
|
||||
from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
@ -57,9 +57,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi
|
||||
|
||||
class RagPipelineService:
|
||||
@classmethod
|
||||
def get_pipeline_templates(
|
||||
cls, type: str = "built-in", language: str = "en-US"
|
||||
) -> dict:
|
||||
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
|
||||
if type == "built-in":
|
||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
@ -308,7 +306,7 @@ class RagPipelineService:
|
||||
session=session,
|
||||
dataset=dataset,
|
||||
knowledge_configuration=knowledge_configuration,
|
||||
has_published=pipeline.is_published
|
||||
has_published=pipeline.is_published,
|
||||
)
|
||||
# return new workflow
|
||||
return workflow
|
||||
@ -444,12 +442,10 @@ class RagPipelineService:
|
||||
)
|
||||
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
online_document_result: GetOnlineDocumentPagesResponse = (
|
||||
datasource_runtime._get_online_document_pages(
|
||||
user_id=account.id,
|
||||
datasource_parameters=user_inputs,
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
|
||||
user_id=account.id,
|
||||
datasource_parameters=user_inputs,
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
return {
|
||||
"result": [page.model_dump() for page in online_document_result.result],
|
||||
@ -470,7 +466,6 @@ class RagPipelineService:
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
|
||||
|
||||
|
||||
def run_free_workflow_node(
|
||||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
) -> WorkflowNodeExecution:
|
||||
@ -689,8 +684,8 @@ class RagPipelineService:
|
||||
WorkflowRun.app_id == pipeline.id,
|
||||
or_(
|
||||
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
|
||||
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value
|
||||
)
|
||||
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
|
||||
),
|
||||
)
|
||||
|
||||
if args.get("last_id"):
|
||||
@ -763,18 +758,17 @@ class RagPipelineService:
|
||||
|
||||
# Use the repository to get the node execution
|
||||
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=db.engine,
|
||||
app_id=pipeline.id,
|
||||
user=user,
|
||||
triggered_from=None
|
||||
session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
|
||||
)
|
||||
|
||||
# Use the repository to get the node executions with ordering
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
|
||||
order_config=order_config,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
|
||||
# Convert domain models to database models
|
||||
node_executions = repository.get_by_workflow_run(
|
||||
workflow_run_id=run_id,
|
||||
order_config=order_config,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
|
||||
)
|
||||
# Convert domain models to database models
|
||||
workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
|
||||
|
||||
return workflow_node_executions
|
||||
|
||||
@ -279,7 +279,11 @@ class RagPipelineDslService:
|
||||
if node.get("data", {}).get("type") == "knowledge_index":
|
||||
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
|
||||
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
|
||||
if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure:
|
||||
if (
|
||||
dataset
|
||||
and pipeline.is_published
|
||||
and dataset.chunk_structure != knowledge_configuration.chunk_structure
|
||||
):
|
||||
raise ValueError("Chunk structure is not compatible with the published pipeline")
|
||||
else:
|
||||
dataset = Dataset(
|
||||
@ -304,8 +308,7 @@ class RagPipelineDslService:
|
||||
.filter(
|
||||
DatasetCollectionBinding.provider_name
|
||||
== knowledge_configuration.embedding_model_provider,
|
||||
DatasetCollectionBinding.model_name
|
||||
== knowledge_configuration.embedding_model,
|
||||
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
|
||||
DatasetCollectionBinding.type == "dataset",
|
||||
)
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
@ -323,12 +326,8 @@ class RagPipelineDslService:
|
||||
db.session.commit()
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
dataset.collection_binding_id = dataset_collection_binding_id
|
||||
dataset.embedding_model = (
|
||||
knowledge_configuration.embedding_model
|
||||
)
|
||||
dataset.embedding_model_provider = (
|
||||
knowledge_configuration.embedding_model_provider
|
||||
)
|
||||
dataset.embedding_model = knowledge_configuration.embedding_model
|
||||
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
|
||||
elif knowledge_configuration.indexing_technique == "economy":
|
||||
dataset.keyword_number = knowledge_configuration.keyword_number
|
||||
dataset.pipeline_id = pipeline.id
|
||||
@ -443,8 +442,7 @@ class RagPipelineDslService:
|
||||
.filter(
|
||||
DatasetCollectionBinding.provider_name
|
||||
== knowledge_configuration.embedding_model_provider,
|
||||
DatasetCollectionBinding.model_name
|
||||
== knowledge_configuration.embedding_model,
|
||||
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
|
||||
DatasetCollectionBinding.type == "dataset",
|
||||
)
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
@ -462,12 +460,8 @@ class RagPipelineDslService:
|
||||
db.session.commit()
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
dataset.collection_binding_id = dataset_collection_binding_id
|
||||
dataset.embedding_model = (
|
||||
knowledge_configuration.embedding_model
|
||||
)
|
||||
dataset.embedding_model_provider = (
|
||||
knowledge_configuration.embedding_model_provider
|
||||
)
|
||||
dataset.embedding_model = knowledge_configuration.embedding_model
|
||||
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
|
||||
elif knowledge_configuration.indexing_technique == "economy":
|
||||
dataset.keyword_number = knowledge_configuration.keyword_number
|
||||
dataset.pipeline_id = pipeline.id
|
||||
@ -538,7 +532,6 @@ class RagPipelineDslService:
|
||||
icon_type = "emoji"
|
||||
icon = str(pipeline_data.get("icon", ""))
|
||||
|
||||
|
||||
# Initialize pipeline based on mode
|
||||
workflow_data = data.get("workflow")
|
||||
if not workflow_data or not isinstance(workflow_data, dict):
|
||||
@ -554,7 +547,6 @@ class RagPipelineDslService:
|
||||
]
|
||||
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
|
||||
|
||||
|
||||
graph = workflow_data.get("graph", {})
|
||||
for node in graph.get("nodes", []):
|
||||
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
@ -576,7 +568,6 @@ class RagPipelineDslService:
|
||||
pipeline.description = pipeline_data.get("description", pipeline.description)
|
||||
pipeline.updated_by = account.id
|
||||
|
||||
|
||||
else:
|
||||
if account.current_tenant_id is None:
|
||||
raise ValueError("Current tenant is not set")
|
||||
@ -636,7 +627,6 @@ class RagPipelineDslService:
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
@ -874,7 +864,6 @@ class RagPipelineDslService:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create_rag_pipeline_dataset(
|
||||
tenant_id: str,
|
||||
@ -886,9 +875,7 @@ class RagPipelineDslService:
|
||||
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
|
||||
.first()
|
||||
):
|
||||
raise ValueError(
|
||||
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
||||
)
|
||||
raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
|
||||
@ -12,12 +12,12 @@ class RagPipelineManageService:
|
||||
|
||||
# get all builtin providers
|
||||
manager = PluginDatasourceManager()
|
||||
datasources = manager.fetch_datasource_providers(tenant_id)
|
||||
datasources = manager.fetch_datasource_providers(tenant_id)
|
||||
for datasource in datasources:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id,
|
||||
provider=datasource.provider,
|
||||
plugin_id=datasource.plugin_id)
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
||||
)
|
||||
if credentials:
|
||||
datasource.is_authorized = True
|
||||
return datasources
|
||||
|
||||
Reference in New Issue
Block a user