Merge branch main into feat/rag-2

This commit is contained in:
twwu
2025-07-24 17:40:04 +08:00
608 changed files with 8175 additions and 3026 deletions

View File

@ -309,7 +309,7 @@ class AgentNode(BaseNode):
}
)
value = tool_value
if parameter.type == "model-selector":
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
model_instance, model_schema = self._fetch_model(value)
# memory config

View File

@ -228,7 +228,7 @@ class KnowledgeRetrievalNode(BaseNode):
# Subquery: Count the number of available documents for each dataset
subquery = (
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
.filter(
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
@ -242,8 +242,8 @@ class KnowledgeRetrievalNode(BaseNode):
results = (
db.session.query(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
@ -370,7 +370,7 @@ class KnowledgeRetrievalNode(BaseNode):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = (
db.session.query(Document)
.filter(
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
@ -415,7 +415,7 @@ class KnowledgeRetrievalNode(BaseNode):
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
document_query = db.session.query(Document).filter(
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
Document.enabled == True,
@ -462,7 +462,7 @@ class KnowledgeRetrievalNode(BaseNode):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type == "number": # type: ignore
if expected_value.value_type in {"number", "integer", "float"}: # type: ignore
expected_value = expected_value.value # type: ignore
elif expected_value.value_type == "string": # type: ignore
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
@ -493,9 +493,9 @@ class KnowledgeRetrievalNode(BaseNode):
node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and"
): # type: ignore
document_query = document_query.filter(and_(*filters))
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.filter(or_(*filters))
document_query = document_query.where(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
@ -507,7 +507,7 @@ class KnowledgeRetrievalNode(BaseNode):
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")

View File

@ -184,11 +184,10 @@ class ListOperatorNode(BaseNode):
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
if value > len(variable.value):
raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {value}")
value -= 1
if len(variable.value) > int(value):
result = variable.value[value]
else:
result = ""
result = variable.value[value]
return variable.model_copy(update={"value": [result]})

View File

@ -565,7 +565,7 @@ class LLMNode(BaseNode):
retriever_resources=original_retriever_resource, context=context_str.strip()
)
def _convert_to_original_retriever_resource(self, context_dict: dict):
def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]

View File

@ -54,7 +54,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool):
elif typ == "constant" and not isinstance(value, str | int | float | bool | dict):
raise ValueError("value must be a string, int, float, or bool")
return typ

View File

@ -316,7 +316,14 @@ class ToolNode(BaseNode):
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, File)
assert isinstance(message.meta, dict)
# Validate that meta contains a 'file' key
if "file" not in message.meta:
raise ToolNodeError("File message is missing 'file' key in meta")
# Validate that the file is an instance of File
if not isinstance(message.meta["file"], File):
raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)