mirror of
https://github.com/langgenius/dify.git
synced 2026-04-30 07:28:05 +08:00
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts: # api/core/workflow/entities/variable_pool.py # api/core/workflow/nodes/iteration/iteration_node.py # api/core/workflow/workflow_engine_manager.py
This commit is contained in:
@ -5,7 +5,7 @@ from typing import Any, Union
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.app.segments import Variable, factory
|
||||
from core.app.segments import Segment, Variable, factory
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
|
||||
@ -21,7 +21,7 @@ class VariablePool(BaseModel):
|
||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
variable_dictionary: dict[str, dict[int, Variable]] = Field(
|
||||
variable_dictionary: dict[str, dict[int, Segment]] = Field(
|
||||
description='Variables mapping',
|
||||
default=defaultdict(dict)
|
||||
)
|
||||
@ -76,15 +76,15 @@ class VariablePool(BaseModel):
|
||||
if value is None:
|
||||
return
|
||||
|
||||
if not isinstance(value, Variable):
|
||||
v = factory.build_anonymous_variable(value)
|
||||
else:
|
||||
if isinstance(value, Segment):
|
||||
v = value
|
||||
else:
|
||||
v = factory.build_segment(value)
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self.variable_dictionary[selector[0]][hash_key] = v
|
||||
|
||||
def get(self, selector: Sequence[str], /) -> Variable | None:
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""
|
||||
Retrieves the value from the variable pool based on the given selector.
|
||||
|
||||
|
||||
@ -65,7 +65,7 @@ class BaseNode(ABC):
|
||||
yield from result
|
||||
|
||||
@classmethod
|
||||
def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]:
|
||||
def extract_variable_selector_to_variable_mapping(cls, config: dict):
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param config: node config
|
||||
@ -75,14 +75,13 @@ class BaseNode(ABC):
|
||||
return cls._extract_variable_selector_to_variable_mapping(node_data)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
|
||||
@ -260,6 +260,7 @@ class IterationNode(BaseNode):
|
||||
},
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
@ -271,6 +272,61 @@ class IterationNode(BaseNode):
|
||||
# remove iteration variable (item, index) from variable pool after iteration run completed
|
||||
variable_pool.remove([self.node_id, 'index'])
|
||||
variable_pool.remove([self.node_id, 'item'])
|
||||
|
||||
def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Set current iteration variable.
|
||||
:variable_pool: variable pool
|
||||
"""
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
|
||||
variable_pool.add((self.node_id, 'index'), state.index)
|
||||
# get the iterator value
|
||||
iterator = variable_pool.get_any(node_data.iterator_selector)
|
||||
|
||||
if iterator is None or not isinstance(iterator, list):
|
||||
return
|
||||
|
||||
if state.index < len(iterator):
|
||||
variable_pool.add((self.node_id, 'item'), iterator[state.index])
|
||||
|
||||
def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Move to next iteration.
|
||||
:param variable_pool: variable pool
|
||||
"""
|
||||
state.index += 1
|
||||
self._set_current_iteration_variable(variable_pool, state)
|
||||
|
||||
def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Check if iteration limit is reached.
|
||||
:return: True if iteration limit is reached, False otherwise
|
||||
"""
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
iterator = variable_pool.get_any(node_data.iterator_selector)
|
||||
|
||||
if iterator is None or not isinstance(iterator, list):
|
||||
return True
|
||||
|
||||
return state.index >= len(iterator)
|
||||
|
||||
def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Resolve current output.
|
||||
:param variable_pool: variable pool
|
||||
"""
|
||||
output_selector = cast(IterationNodeData, self.node_data).output_selector
|
||||
output = variable_pool.get_any(output_selector)
|
||||
# clear the output for this iteration
|
||||
variable_pool.remove([self.node_id] + output_selector[1:])
|
||||
state.current_output = output
|
||||
if output is not None:
|
||||
# NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration).
|
||||
if isinstance(output, list):
|
||||
state.outputs.extend(output)
|
||||
else:
|
||||
state.outputs.append(output)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
|
||||
|
||||
@ -46,8 +46,8 @@ class MultipleRetrievalConfig(BaseModel):
|
||||
score_threshold: Optional[float] = None
|
||||
reranking_mode: str = 'reranking_model'
|
||||
reranking_enable: bool = True
|
||||
reranking_model: RerankingModelConfig
|
||||
weights: WeightedScoreConfig
|
||||
reranking_model: Optional[RerankingModelConfig] = None
|
||||
weights: Optional[WeightedScoreConfig] = None
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
|
||||
@ -139,8 +139,8 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model':
|
||||
reranking_model = {
|
||||
'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model['provider'],
|
||||
'reranking_model_name': node_data.multiple_retrieval_config.reranking_model['name']
|
||||
'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
'reranking_model_name': node_data.multiple_retrieval_config.reranking_model.model
|
||||
}
|
||||
weights = None
|
||||
elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score':
|
||||
|
||||
@ -125,11 +125,16 @@ class ToolNode(BaseNode):
|
||||
]
|
||||
else:
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
segment_group = parser.convert_template(
|
||||
template=str(tool_input.value),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
result[parameter_name] = segment_group.log if for_log else segment_group.text
|
||||
if tool_input.type == 'variable':
|
||||
# TODO: check if the variable exists in the variable pool
|
||||
parameter_value = variable_pool.get(tool_input.value).value
|
||||
else:
|
||||
segment_group = parser.convert_template(
|
||||
template=str(tool_input.value),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
|
||||
|
||||
|
||||
0
api/core/workflow/workflow_engine_manager.py
Normal file
0
api/core/workflow/workflow_engine_manager.py
Normal file
Reference in New Issue
Block a user