feat(workflow): integrate workflow entry with advanced chat app

takatost
2024-08-13 16:21:10 +08:00
parent 8d27ec364f
commit 8401a11109
38 changed files with 976 additions and 2328 deletions

View File

@@ -16,11 +16,13 @@ class BaseNode(ABC):
    _node_type: NodeType

    def __init__(self,
                 id: str,
                 config: Mapping[str, Any],
                 graph_init_params: GraphInitParams,
                 graph: Graph,
                 graph_runtime_state: GraphRuntimeState,
                 previous_node_id: Optional[str] = None) -> None:
        self.id = id
        self.tenant_id = graph_init_params.tenant_id
        self.app_id = graph_init_params.app_id
        self.workflow_type = graph_init_params.workflow_type
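
To make the new signature concrete, here is a small runnable stand-in (not Dify's actual code) showing how a node now receives a per-execution id plus the shared GraphInitParams and caches the tenant/app/workflow context on itself. Only the attribute names visible in the diff (tenant_id, app_id, workflow_type) are taken from the commit; the classes below are simplified assumptions.

from dataclasses import dataclass
from typing import Any, Mapping, Optional

@dataclass
class GraphInitParams:
    # only these field names are taken from the diff above
    tenant_id: str
    app_id: str
    workflow_type: str

class DemoNode:
    def __init__(self,
                 id: str,
                 config: Mapping[str, Any],
                 graph_init_params: GraphInitParams,
                 previous_node_id: Optional[str] = None) -> None:
        # mirrors the diff: graph-level context is copied onto the node instance
        self.id = id
        self.tenant_id = graph_init_params.tenant_id
        self.app_id = graph_init_params.app_id
        self.workflow_type = graph_init_params.workflow_type

params = GraphInitParams(tenant_id="t-1", app_id="a-1", workflow_type="chat")
node = DemoNode(id="node-exec-1", config={}, graph_init_params=params)
print(node.tenant_id)  # -> t-1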

View File

@@ -7,7 +7,8 @@ class EndStreamGeneratorRouter:
    @classmethod
    def init(cls,
             node_id_config_mapping: dict[str, dict],
             reverse_edge_mapping: dict[str, list["GraphEdge"]]  # type: ignore[name-defined]
             reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
             node_parallel_mapping: dict[str, str]
             ) -> EndStreamParam:
        """
        Get stream generate routes.
@@ -19,6 +20,10 @@ class EndStreamGeneratorRouter:
            if not node_config.get('data', {}).get('type') == NodeType.END.value:
                continue

            # skip end node in parallel
            if end_node_id in node_parallel_mapping:
                continue

            # get generate route for stream output
            stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
            end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors
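
An illustrative sketch of the new filter (node ids and config payloads invented for the example): end nodes that appear in node_parallel_mapping sit inside a parallel branch, so they are excluded when stream routes are built. The mapping shape, end node id to parallel id, is inferred from the dict[str, str] annotation above.

node_id_config_mapping = {
    "end-1": {"data": {"type": "end"}},
    "end-2": {"data": {"type": "end"}},
}
node_parallel_mapping = {"end-2": "parallel-1"}  # end-2 lives inside a parallel branch

streamable_end_nodes = [
    node_id
    for node_id, config in node_id_config_mapping.items()
    if config.get("data", {}).get("type") == "end"  # same type check as the diff
    and node_id not in node_parallel_mapping        # new: skip parallel end nodes
]
print(streamable_end_nodes)  # -> ['end-1']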

View File

@@ -1,5 +1,6 @@
import logging
from collections.abc import Generator
from datetime import datetime, timezone
from typing import Any, cast

from configs import dify_config
@@ -123,10 +124,14 @@ class IterationNode(BaseNode):
            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
        )

        start_at = datetime.now(timezone.utc).replace(tzinfo=None)

        yield IterationRunStartedEvent(
            iteration_id=self.id,
            iteration_node_id=self.node_id,
            iteration_node_type=self.node_type,
            iteration_node_data=self.node_data,
            start_at=start_at,
            inputs=inputs,
            metadata={
                "iterator_length": 1
@@ -135,6 +140,7 @@ class IterationNode(BaseNode):
        )

        yield IterationRunNextEvent(
            iteration_id=self.id,
            iteration_node_id=self.node_id,
            iteration_node_type=self.node_type,
            iteration_node_data=self.node_data,
@@ -186,6 +192,7 @@ class IterationNode(BaseNode):
                )

                yield IterationRunNextEvent(
                    iteration_id=self.id,
                    iteration_node_id=self.node_id,
                    iteration_node_type=self.node_type,
                    iteration_node_data=self.node_data,
@@ -197,9 +204,11 @@ class IterationNode(BaseNode):
                if isinstance(event, GraphRunFailedEvent):
                    # iteration run failed
                    yield IterationRunFailedEvent(
                        iteration_id=self.id,
                        iteration_node_id=self.node_id,
                        iteration_node_type=self.node_type,
                        iteration_node_data=self.node_data,
                        start_at=start_at,
                        inputs=inputs,
                        outputs={
                            "output": jsonable_encoder(outputs)
@@ -222,9 +231,11 @@ class IterationNode(BaseNode):
                    yield event

            yield IterationRunSucceededEvent(
                iteration_id=self.id,
                iteration_node_id=self.node_id,
                iteration_node_type=self.node_type,
                iteration_node_data=self.node_data,
                start_at=start_at,
                inputs=inputs,
                outputs={
                    "output": jsonable_encoder(outputs)
@@ -247,9 +258,11 @@ class IterationNode(BaseNode):
            # iteration run failed
            logger.exception("Iteration run failed")
            yield IterationRunFailedEvent(
                iteration_id=self.id,
                iteration_node_id=self.node_id,
                iteration_node_type=self.node_type,
                iteration_node_data=self.node_data,
                start_at=start_at,
                inputs=inputs,
                outputs={
                    "output": jsonable_encoder(outputs)

View File

@@ -1,3 +1,4 @@
import logging
from typing import Any, cast

from sqlalchemy import func
@@ -20,6 +21,8 @@ from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus

logger = logging.getLogger(__name__)

default_retrieval_model = {
    'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
    'reranking_enable': False,
@@ -67,7 +70,7 @@ class KnowledgeRetrievalNode(BaseNode):
            )
        except Exception as e:
            logger.exception("Error when running knowledge retrieval node")
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
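
The change here is the logging pattern itself: a module-level logger plus logger.exception() in the except block, which records the message at ERROR level together with the active traceback before the node returns a FAILED result. A minimal self-contained sketch (the failing call and the return value are stand-ins):

import logging

logger = logging.getLogger(__name__)

def run_retrieval() -> str:
    try:
        raise RuntimeError("retrieval backend unavailable")  # stand-in failure
    except Exception:
        # logs at ERROR level and appends the current traceback automatically
        logger.exception("Error when running knowledge retrieval node")
        return "FAILED"

logging.basicConfig(level=logging.INFO)
print(run_retrieval())  # -> FAILED (traceback goes to the log)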

View File

@@ -168,7 +168,8 @@ class LLMNode(BaseNode):
                        NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                        NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                        NodeRunMetadataKey.CURRENCY: usage.currency
                    }
                    },
                    llm_usage=usage
                )
            )
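
The same llm_usage=usage addition repeats in the parameter extractor and question classifier hunks below: besides flattening token counts into the metadata dict, the run result now also carries the usage object itself, so downstream accounting can read typed fields instead of re-parsing metadata. A simplified stand-in (class shapes assumed; only the mirrored field names come from the diff):

from dataclasses import dataclass
from decimal import Decimal
from typing import Any, Optional

@dataclass
class LLMUsage:
    total_tokens: int
    total_price: Decimal
    currency: str

@dataclass
class NodeRunResult:
    metadata: dict[str, Any]
    llm_usage: Optional[LLMUsage] = None  # new: structured usage rides along

usage = LLMUsage(total_tokens=1234, total_price=Decimal("0.0021"), currency="USD")
result = NodeRunResult(
    metadata={
        "total_tokens": usage.total_tokens,
        "total_price": usage.total_price,
        "currency": usage.currency,
    },
    llm_usage=usage,
)
print(result.llm_usage.total_tokens)  # -> 1234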

View File

@@ -175,7 +175,8 @@ class ParameterExtractorNode(LLMNode):
                    NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                    NodeRunMetadataKey.CURRENCY: usage.currency
                }
                },
                llm_usage=usage
            )

    def _invoke_llm(self, node_data_model: ModelConfig,

View File

@@ -119,7 +119,8 @@ class QuestionClassifierNode(LLMNode):
                    NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                    NodeRunMetadataKey.CURRENCY: usage.currency
                }
                },
                llm_usage=usage
            )
        except ValueError as e:
@@ -131,7 +132,8 @@ class QuestionClassifierNode(LLMNode):
                    NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                    NodeRunMetadataKey.CURRENCY: usage.currency
                }
                },
                llm_usage=usage
            )

    @classmethod