Mirror of https://github.com/langgenius/dify.git (synced 2026-04-29 23:18:05 +08:00)
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
# Conflicts:
#   api/core/memory/token_buffer_memory.py
#   api/core/rag/extractor/notion_extractor.py
#   api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
#   api/core/variables/variables.py
#   api/core/workflow/graph/graph.py
#   api/core/workflow/graph_engine/entities/event.py
#   api/services/dataset_service.py
#   web/app/components/app-sidebar/index.tsx
#   web/app/components/base/tag-management/selector.tsx
#   web/app/components/base/toast/index.tsx
#   web/app/components/datasets/create/website/index.tsx
#   web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx
#   web/app/components/workflow/header/version-history-button.tsx
#   web/app/components/workflow/hooks/use-inspect-vars-crud-common.ts
#   web/app/components/workflow/hooks/use-workflow-interactions.ts
#   web/app/components/workflow/panel/version-history-panel/index.tsx
#   web/service/base.ts
This commit is contained in:
@ -7,9 +7,8 @@ from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import Float, and_, or_, text
|
||||
from sqlalchemy import Float, and_, or_, select, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@ -65,7 +64,7 @@ default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
@ -135,7 +134,8 @@ class DatasetRetrieval:
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
@ -240,15 +240,12 @@ class DatasetRetrieval:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.first()
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
document = db.session.scalar(dataset_document_stmt)
|
||||
if dataset and document:
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
@ -327,7 +324,8 @@ class DatasetRetrieval:
|
||||
|
||||
if dataset_id:
|
||||
# get retrieval model config
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
if dataset:
|
||||
results = []
|
||||
if dataset.provider == "external":
|
||||
@ -514,24 +512,20 @@ class DatasetRetrieval:
|
||||
dify_documents = [document for document in documents if document.provider == "dify"]
|
||||
for document in dify_documents:
|
||||
if document.metadata is not None:
|
||||
dataset_document = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(DatasetDocument.id == document.metadata["document_id"])
|
||||
.first()
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == document.metadata["document_id"]
|
||||
)
|
||||
dataset_document = db.session.scalar(dataset_document_stmt)
|
||||
if dataset_document:
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
)
|
||||
.first()
|
||||
child_chunk_stmt = select(ChildChunk).where(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
)
|
||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
||||
if child_chunk:
|
||||
segment = (
|
||||
_ = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == child_chunk.segment_id)
|
||||
.update(
|
||||
@ -539,7 +533,6 @@ class DatasetRetrieval:
|
||||
synchronize_session=False,
|
||||
)
|
||||
)
|
||||
db.session.commit()
|
||||
else:
|
||||
query = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
@ -599,8 +592,8 @@ class DatasetRetrieval:
|
||||
metadata_condition: Optional[MetadataCondition] = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
with Session(db.engine) as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
@ -647,7 +640,7 @@ class DatasetRetrieval:
|
||||
retrieval_method=retrieval_model["search_method"],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
top_k=retrieval_model.get("top_k") or 4,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
@ -685,7 +678,8 @@ class DatasetRetrieval:
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
@ -743,7 +737,7 @@ class DatasetRetrieval:
|
||||
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||
dataset_ids=[dataset.id for dataset in available_datasets],
|
||||
tenant_id=tenant_id,
|
||||
top_k=retrieve_config.top_k or 2,
|
||||
top_k=retrieve_config.top_k or 4,
|
||||
score_threshold=retrieve_config.score_threshold,
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
@ -958,7 +952,8 @@ class DatasetRetrieval:
|
||||
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
# get all metadata field
|
||||
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||
metadata_fields = db.session.scalars(metadata_stmt).all()
|
||||
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
|
||||
# get metadata model config
|
||||
if metadata_model_config is None:
|
||||
@ -990,7 +985,7 @@ class DatasetRetrieval:
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
result_text, _ = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
automatic_metadata_filters = []
|
||||
@ -1005,7 +1000,7 @@ class DatasetRetrieval:
|
||||
"condition": item.get("comparison_operator"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return None
|
||||
return automatic_metadata_filters
|
||||
|
||||
|
||||
@ -19,5 +19,5 @@ class StructuredChatOutputParser:
|
||||
return ReactAction(response["action"], response.get("action_input", {}), text)
|
||||
else:
|
||||
return ReactFinish({"output": text}, text)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ValueError(f"Could not parse LLM output: {text}")
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Union, cast
|
||||
from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.model_manager import ModelInstance
|
||||
@ -28,18 +28,15 @@ class FunctionCallMultiDatasetRouter:
|
||||
SystemPromptMessage(content="You are a helpful AI assistant."),
|
||||
UserPromptMessage(content=query),
|
||||
]
|
||||
result = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
tools=dataset_tools,
|
||||
stream=False,
|
||||
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
|
||||
),
|
||||
result: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
tools=dataset_tools,
|
||||
stream=False,
|
||||
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
|
||||
)
|
||||
if result.message.tool_calls:
|
||||
# get retrieval model config
|
||||
return result.message.tool_calls[0].function.name
|
||||
return None
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Union, cast
|
||||
from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.model_manager import ModelInstance
|
||||
@ -77,7 +77,7 @@ class ReactMultiDatasetRouter:
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _react_invoke(
|
||||
@ -120,7 +120,7 @@ class ReactMultiDatasetRouter:
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
result_text, usage = self._invoke_llm(
|
||||
result_text, _ = self._invoke_llm(
|
||||
completion_param=model_config.parameters,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
@ -150,15 +150,12 @@ class ReactMultiDatasetRouter:
|
||||
:param stop: stop
|
||||
:return:
|
||||
"""
|
||||
invoke_result = cast(
|
||||
Generator[LLMResult, None, None],
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=completion_param,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user_id,
|
||||
),
|
||||
invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=completion_param,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
|
||||
Reference in New Issue
Block a user