Merge branch 'refs/heads/main' into feat/workflow-parallel-support

# Conflicts:
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/iteration/iteration_node.py
#	api/core/workflow/workflow_engine_manager.py
This commit is contained in:
takatost
2024-07-31 02:25:31 +08:00
184 changed files with 3427 additions and 930 deletions

View File

@ -5,7 +5,7 @@ from typing import Any, Union
from pydantic import BaseModel, Field, model_validator
from typing_extensions import deprecated
from core.app.segments import Variable, factory
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
@ -21,7 +21,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: dict[str, dict[int, Variable]] = Field(
variable_dictionary: dict[str, dict[int, Segment]] = Field(
description='Variables mapping',
default=defaultdict(dict)
)
@ -76,15 +76,15 @@ class VariablePool(BaseModel):
if value is None:
return
if not isinstance(value, Variable):
v = factory.build_anonymous_variable(value)
else:
if isinstance(value, Segment):
v = value
else:
v = factory.build_segment(value)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = v
def get(self, selector: Sequence[str], /) -> Variable | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
Retrieves the value from the variable pool based on the given selector.

View File

@ -65,7 +65,7 @@ class BaseNode(ABC):
yield from result
@classmethod
def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]:
def extract_variable_selector_to_variable_mapping(cls, config: dict):
"""
Extract variable selector to variable mapping
:param config: node config
@ -75,14 +75,13 @@ class BaseNode(ABC):
return cls._extract_variable_selector_to_variable_mapping(node_data)
@classmethod
@abstractmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
raise NotImplementedError
return {}
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:

View File

@ -260,6 +260,7 @@ class IterationNode(BaseNode):
},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
@ -271,6 +272,61 @@ class IterationNode(BaseNode):
# remove iteration variable (item, index) from variable pool after iteration run completed
variable_pool.remove([self.node_id, 'index'])
variable_pool.remove([self.node_id, 'item'])
def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
"""
Set current iteration variable.
:variable_pool: variable pool
"""
node_data = cast(IterationNodeData, self.node_data)
variable_pool.add((self.node_id, 'index'), state.index)
# get the iterator value
iterator = variable_pool.get_any(node_data.iterator_selector)
if iterator is None or not isinstance(iterator, list):
return
if state.index < len(iterator):
variable_pool.add((self.node_id, 'item'), iterator[state.index])
def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
"""
Move to next iteration.
:param variable_pool: variable pool
"""
state.index += 1
self._set_current_iteration_variable(variable_pool, state)
def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
"""
Check if iteration limit is reached.
:return: True if iteration limit is reached, False otherwise
"""
node_data = cast(IterationNodeData, self.node_data)
iterator = variable_pool.get_any(node_data.iterator_selector)
if iterator is None or not isinstance(iterator, list):
return True
return state.index >= len(iterator)
def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
"""
Resolve current output.
:param variable_pool: variable pool
"""
output_selector = cast(IterationNodeData, self.node_data).output_selector
output = variable_pool.get_any(output_selector)
# clear the output for this iteration
variable_pool.remove([self.node_id] + output_selector[1:])
state.current_output = output
if output is not None:
# NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration).
if isinstance(output, list):
state.outputs.extend(output)
else:
state.outputs.append(output)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:

View File

@ -46,8 +46,8 @@ class MultipleRetrievalConfig(BaseModel):
score_threshold: Optional[float] = None
reranking_mode: str = 'reranking_model'
reranking_enable: bool = True
reranking_model: RerankingModelConfig
weights: WeightedScoreConfig
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
class ModelConfig(BaseModel):

View File

@ -139,8 +139,8 @@ class KnowledgeRetrievalNode(BaseNode):
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model':
reranking_model = {
'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model['provider'],
'reranking_model_name': node_data.multiple_retrieval_config.reranking_model['name']
'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model.provider,
'reranking_model_name': node_data.multiple_retrieval_config.reranking_model.model
}
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score':

View File

@ -125,11 +125,16 @@ class ToolNode(BaseNode):
]
else:
tool_input = node_data.tool_parameters[parameter_name]
segment_group = parser.convert_template(
template=str(tool_input.value),
variable_pool=variable_pool,
)
result[parameter_name] = segment_group.log if for_log else segment_group.text
if tool_input.type == 'variable':
# TODO: check if the variable exists in the variable pool
parameter_value = variable_pool.get(tool_input.value).value
else:
segment_group = parser.convert_template(
template=str(tool_input.value),
variable_pool=variable_pool,
)
parameter_value = segment_group.log if for_log else segment_group.text
result[parameter_name] = parameter_value
return result