This commit is contained in:
takatost
2024-07-23 00:10:23 +08:00
parent a603e01f5e
commit 2c695ded79
9 changed files with 140 additions and 127 deletions

View File

@ -35,7 +35,7 @@ class AnswerNode(BaseNode):
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = self.graph_runtime_state.variable_pool.get(
variable_selector=value_selector
value_selector
)
if value:

View File

@ -1,4 +1,3 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, cast
@ -98,7 +97,7 @@ class AnswerStreamProcessor:
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping[node_id]:
for edge in self.graph.edge_mapping.get(node_id, []):
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
@ -108,7 +107,7 @@ class AnswerStreamProcessor:
remove target node ids until merge
"""
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping[node_id]:
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
@ -124,8 +123,10 @@ class AnswerStreamProcessor:
"""
for answer_node_id, position in self.route_position.items():
# all depends on answer node id not in rest node ids
if not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
if (event.route_node_state.node_id != answer_node_id
and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
continue
route_position = self.route_position[answer_node_id]
@ -145,53 +146,14 @@ class AnswerStreamProcessor:
if not value_selector:
break
value = self.variable_pool.get_variable_value(
variable_selector=value_selector
value = self.variable_pool.get(
value_selector
)
if value is None:
break
text = ''
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, dict):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
if file_vars:
file_var = file_vars[0]
try:
file_var_obj = FileVar(**file_var)
# convert file to markdown
text = file_var_obj.to_markdown()
except Exception as e:
logger.error(f'Error creating file var: {e}')
if not text:
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, list):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
for file_var in file_vars:
try:
file_var_obj = FileVar(**file_var)
except Exception as e:
logger.error(f'Error creating file var: {e}')
continue
# convert file to markdown
text = file_var_obj.to_markdown() + ' '
text = text.strip()
if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)
text = value.markdown
if text:
yield NodeRunStreamChunkEvent(