Merge branch 'feat/iteration-node' into deploy/dev

This commit is contained in:
CodingOnStar
2025-10-24 15:31:04 +08:00
20 changed files with 524 additions and 285 deletions

View File

@ -32,6 +32,7 @@ from libs.token import (
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_access_token,
extract_refresh_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
@ -273,7 +274,7 @@ class EmailCodeLoginApi(Resource):
class RefreshTokenApi(Resource):
def post(self):
# Get refresh token from cookie instead of request body
refresh_token = request.cookies.get("refresh_token")
refresh_token = extract_refresh_token(request)
if not refresh_token:
return {"result": "fail", "message": "No refresh token provided"}, 401

View File

@ -193,15 +193,19 @@ class QuestionClassifierNode(Node):
finish_reason = event.finish_reason
break
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
rendered_classes = [
c.model_copy(update={"name": variable_pool.convert_template(c.name).text}) for c in node_data.classes
]
category_name = rendered_classes[0].name
category_id = rendered_classes[0].id
if "<think>" in result_text:
result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
result_text_json = parse_and_check_json_markdown(result_text, [])
# result_text_json = json.loads(result_text.strip('```JSON\n'))
if "category_name" in result_text_json and "category_id" in result_text_json:
category_id_result = result_text_json["category_id"]
classes = node_data.classes
classes = rendered_classes
classes_map = {class_.id: class_.name for class_ in classes}
category_ids = [_class.id for _class in classes]
if category_id_result in category_ids:

View File

@ -5,6 +5,7 @@ import json
from collections.abc import Mapping, Sequence
from collections.abc import Mapping as TypingMapping
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Protocol
from pydantic.json import pydantic_encoder
@ -106,6 +107,23 @@ class GraphProtocol(Protocol):
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
@dataclass(slots=True)
class _GraphRuntimeStateSnapshot:
"""Immutable view of a serialized runtime state snapshot."""
start_at: float
total_tokens: int
node_run_steps: int
llm_usage: LLMUsage
outputs: dict[str, Any]
variable_pool: VariablePool
has_variable_pool: bool
ready_queue_dump: str | None
graph_execution_dump: str | None
response_coordinator_dump: str | None
paused_nodes: tuple[str, ...]
class GraphRuntimeState:
"""Mutable runtime state shared across graph execution components."""
@ -293,69 +311,28 @@ class GraphRuntimeState:
return json.dumps(snapshot, default=pydantic_encoder)
def loads(self, data: str | Mapping[str, Any]) -> None:
@classmethod
def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState:
"""Restore runtime state from a serialized snapshot."""
payload: dict[str, Any]
if isinstance(data, str):
payload = json.loads(data)
else:
payload = dict(data)
snapshot = cls._parse_snapshot_payload(data)
version = payload.get("version")
if version != "1.0":
raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")
state = cls(
variable_pool=snapshot.variable_pool,
start_at=snapshot.start_at,
total_tokens=snapshot.total_tokens,
llm_usage=snapshot.llm_usage,
outputs=snapshot.outputs,
node_run_steps=snapshot.node_run_steps,
)
state._apply_snapshot(snapshot)
return state
self._start_at = float(payload.get("start_at", 0.0))
total_tokens = int(payload.get("total_tokens", 0))
if total_tokens < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = total_tokens
def loads(self, data: str | Mapping[str, Any]) -> None:
"""Restore runtime state from a serialized snapshot (legacy API)."""
node_run_steps = int(payload.get("node_run_steps", 0))
if node_run_steps < 0:
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = node_run_steps
llm_usage_payload = payload.get("llm_usage", {})
self._llm_usage = LLMUsage.model_validate(llm_usage_payload)
self._outputs = deepcopy(payload.get("outputs", {}))
variable_pool_payload = payload.get("variable_pool")
if variable_pool_payload is not None:
self._variable_pool = VariablePool.model_validate(variable_pool_payload)
ready_queue_payload = payload.get("ready_queue")
if ready_queue_payload is not None:
self._ready_queue = self._build_ready_queue()
self._ready_queue.loads(ready_queue_payload)
else:
self._ready_queue = None
graph_execution_payload = payload.get("graph_execution")
self._graph_execution = None
self._pending_graph_execution_workflow_id = None
if graph_execution_payload is not None:
try:
execution_payload = json.loads(graph_execution_payload)
self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
except (json.JSONDecodeError, TypeError, AttributeError):
self._pending_graph_execution_workflow_id = None
self.graph_execution.loads(graph_execution_payload)
response_payload = payload.get("response_coordinator")
if response_payload is not None:
if self._graph is not None:
self.response_coordinator.loads(response_payload)
else:
self._pending_response_coordinator_dump = response_payload
else:
self._pending_response_coordinator_dump = None
self._response_coordinator = None
paused_nodes_payload = payload.get("paused_nodes", [])
self._paused_nodes = set(map(str, paused_nodes_payload))
snapshot = self._parse_snapshot_payload(data)
self._apply_snapshot(snapshot)
def register_paused_node(self, node_id: str) -> None:
"""Record a node that should resume when execution is continued."""
@ -391,3 +368,106 @@ class GraphRuntimeState:
module = importlib.import_module("core.workflow.graph_engine.response_coordinator")
coordinator_cls = module.ResponseStreamCoordinator
return coordinator_cls(variable_pool=self.variable_pool, graph=graph)
# ------------------------------------------------------------------
# Snapshot helpers
# ------------------------------------------------------------------
@classmethod
def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot:
payload: dict[str, Any]
if isinstance(data, str):
payload = json.loads(data)
else:
payload = dict(data)
version = payload.get("version")
if version != "1.0":
raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")
start_at = float(payload.get("start_at", 0.0))
total_tokens = int(payload.get("total_tokens", 0))
if total_tokens < 0:
raise ValueError("total_tokens must be non-negative")
node_run_steps = int(payload.get("node_run_steps", 0))
if node_run_steps < 0:
raise ValueError("node_run_steps must be non-negative")
llm_usage_payload = payload.get("llm_usage", {})
llm_usage = LLMUsage.model_validate(llm_usage_payload)
outputs_payload = deepcopy(payload.get("outputs", {}))
variable_pool_payload = payload.get("variable_pool")
has_variable_pool = variable_pool_payload is not None
variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool()
ready_queue_payload = payload.get("ready_queue")
graph_execution_payload = payload.get("graph_execution")
response_payload = payload.get("response_coordinator")
paused_nodes_payload = payload.get("paused_nodes", [])
return _GraphRuntimeStateSnapshot(
start_at=start_at,
total_tokens=total_tokens,
node_run_steps=node_run_steps,
llm_usage=llm_usage,
outputs=outputs_payload,
variable_pool=variable_pool,
has_variable_pool=has_variable_pool,
ready_queue_dump=ready_queue_payload,
graph_execution_dump=graph_execution_payload,
response_coordinator_dump=response_payload,
paused_nodes=tuple(map(str, paused_nodes_payload)),
)
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
self._start_at = snapshot.start_at
self._total_tokens = snapshot.total_tokens
self._node_run_steps = snapshot.node_run_steps
self._llm_usage = snapshot.llm_usage.model_copy()
self._outputs = deepcopy(snapshot.outputs)
if snapshot.has_variable_pool or self._variable_pool is None:
self._variable_pool = snapshot.variable_pool
self._restore_ready_queue(snapshot.ready_queue_dump)
self._restore_graph_execution(snapshot.graph_execution_dump)
self._restore_response_coordinator(snapshot.response_coordinator_dump)
self._paused_nodes = set(snapshot.paused_nodes)
def _restore_ready_queue(self, payload: str | None) -> None:
if payload is not None:
self._ready_queue = self._build_ready_queue()
self._ready_queue.loads(payload)
else:
self._ready_queue = None
def _restore_graph_execution(self, payload: str | None) -> None:
self._graph_execution = None
self._pending_graph_execution_workflow_id = None
if payload is None:
return
try:
execution_payload = json.loads(payload)
self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
except (json.JSONDecodeError, TypeError, AttributeError):
self._pending_graph_execution_workflow_id = None
self.graph_execution.loads(payload)
def _restore_response_coordinator(self, payload: str | None) -> None:
if payload is None:
self._pending_response_coordinator_dump = None
self._response_coordinator = None
return
if self._graph is not None:
self.response_coordinator.loads(payload)
self._pending_response_coordinator_dump = None
return
self._pending_response_coordinator_dump = payload
self._response_coordinator = None

View File

@ -6,10 +6,11 @@ from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from constants import HEADER_NAME_APP_CODE
from dify_app import DifyApp
from extensions.ext_database import db
from libs.passport import PassportService
from libs.token import extract_access_token
from libs.token import extract_access_token, extract_webapp_passport
from models import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser
from services.account_service import AccountService
@ -61,14 +62,30 @@ def load_user_from_request(request_from_flask_login):
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
return logged_in_account
elif request.blueprint == "web":
decoded = PassportService().verify(auth_token)
end_user_id = decoded.get("end_user_id")
if not end_user_id:
raise Unauthorized("Invalid Authorization token.")
end_user = db.session.query(EndUser).where(EndUser.id == decoded["end_user_id"]).first()
if not end_user:
raise NotFound("End user not found.")
return end_user
app_code = request.headers.get(HEADER_NAME_APP_CODE)
webapp_token = extract_webapp_passport(app_code, request) if app_code else None
if webapp_token:
decoded = PassportService().verify(webapp_token)
end_user_id = decoded.get("end_user_id")
if not end_user_id:
raise Unauthorized("Invalid Authorization token.")
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
if not end_user:
raise NotFound("End user not found.")
return end_user
else:
if not auth_token:
raise Unauthorized("Invalid Authorization token.")
decoded = PassportService().verify(auth_token)
end_user_id = decoded.get("end_user_id")
if end_user_id:
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
if not end_user:
raise NotFound("End user not found.")
return end_user
else:
raise Unauthorized("Invalid Authorization token for web API.")
elif request.blueprint == "mcp":
server_code = request.view_args.get("server_code") if request.view_args else None
if not server_code:

View File

@ -38,9 +38,6 @@ def _real_cookie_name(cookie_name: str) -> str:
def _try_extract_from_header(request: Request) -> str | None:
"""
Try to extract access token from header
"""
auth_header = request.headers.get("Authorization")
if auth_header:
if " " not in auth_header:
@ -55,27 +52,19 @@ def _try_extract_from_header(request: Request) -> str | None:
return None
def extract_refresh_token(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_REFRESH_TOKEN))
def extract_csrf_token(request: Request) -> str | None:
"""
Try to extract CSRF token from header or cookie.
"""
return request.headers.get(HEADER_NAME_CSRF_TOKEN)
def extract_csrf_token_from_cookie(request: Request) -> str | None:
"""
Try to extract CSRF token from cookie.
"""
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
def extract_access_token(request: Request) -> str | None:
"""
Try to extract access token from cookie, header or params.
Access token is either for console session or webapp passport exchange.
"""
def _try_extract_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
@ -83,20 +72,10 @@ def extract_access_token(request: Request) -> str | None:
def extract_webapp_access_token(request: Request) -> str | None:
"""
Try to extract webapp access token from cookie, then header.
"""
return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
"""
Try to extract app token from header or params.
Webapp access token (part of passport) is only used for webapp session.
"""
def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))

View File

@ -82,54 +82,51 @@ class AudioService:
message_id: str | None = None,
is_draft: bool = False,
):
from app import app
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
with app.app_context():
if voice is None:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if is_draft:
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
else:
workflow = app_model.workflow
if (
workflow is None
or "text_to_speech" not in workflow.features_dict
or not workflow.features_dict["text_to_speech"].get("enabled")
):
if voice is None:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if is_draft:
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
else:
workflow = app_model.workflow
if (
workflow is None
or "text_to_speech" not in workflow.features_dict
or not workflow.features_dict["text_to_speech"].get("enabled")
):
raise ValueError("TTS is not enabled")
voice = workflow.features_dict["text_to_speech"].get("voice")
else:
if not is_draft:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
if not text_to_speech_dict.get("enabled"):
raise ValueError("TTS is not enabled")
voice = workflow.features_dict["text_to_speech"].get("voice")
else:
if not is_draft:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
voice = text_to_speech_dict.get("voice")
if not text_to_speech_dict.get("enabled"):
raise ValueError("TTS is not enabled")
voice = text_to_speech_dict.get("voice")
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
)
try:
if not voice:
voices = model_instance.get_tts_voices()
if voices:
voice = voices[0].get("value")
if not voice:
raise ValueError("Sorry, no voice available.")
else:
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
)
try:
if not voice:
voices = model_instance.get_tts_voices()
if voices:
voice = voices[0].get("value")
if not voice:
raise ValueError("Sorry, no voice available.")
else:
raise ValueError("Sorry, no voice available.")
return model_instance.invoke_tts(
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
)
except Exception as e:
raise e
return model_instance.invoke_tts(
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
)
except Exception as e:
raise e
if message_id:
try:

View File

@ -283,7 +283,7 @@ class VariableTruncator:
break
remaining_budget = target_size - used_size
if item is None or isinstance(item, (str, list, dict, bool, int, float)):
if item is None or isinstance(item, (str, list, dict, bool, int, float, UpdatedVariable)):
part_result = self._truncate_json_primitives(item, remaining_budget)
else:
raise UnknownTypeError(f"got unknown type {type(item)} in array truncation")
@ -373,6 +373,11 @@ class VariableTruncator:
return _PartResult(truncated_obj, used_size, truncated)
@overload
def _truncate_json_primitives(
self, val: UpdatedVariable, target_size: int
) -> _PartResult[Mapping[str, object]]: ...
@overload
def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ...

View File

@ -8,6 +8,18 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
class StubCoordinator:
def __init__(self) -> None:
self.state = "initial"
def dumps(self) -> str:
return json.dumps({"state": self.state})
def loads(self, data: str) -> None:
payload = json.loads(data)
self.state = payload["state"]
class TestGraphRuntimeState:
def test_property_getters_and_setters(self):
# FIXME(-LAN-): Mock VariablePool if needed
@ -191,17 +203,6 @@ class TestGraphRuntimeState:
graph_execution.exceptions_count = 4
graph_execution.started = True
class StubCoordinator:
def __init__(self) -> None:
self.state = "initial"
def dumps(self) -> str:
return json.dumps({"state": self.state})
def loads(self, data: str) -> None:
payload = json.loads(data)
self.state = payload["state"]
mock_graph = MagicMock()
stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
@ -211,8 +212,7 @@ class TestGraphRuntimeState:
snapshot = state.dumps()
restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
restored.loads(snapshot)
restored = GraphRuntimeState.from_snapshot(snapshot)
assert restored.total_tokens == 10
assert restored.node_run_steps == 3
@ -235,3 +235,47 @@ class TestGraphRuntimeState:
restored.attach_graph(mock_graph)
assert new_stub.state == "configured"
def test_loads_rehydrates_existing_instance(self):
variable_pool = VariablePool()
variable_pool.add(("node", "key"), "value")
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
state.total_tokens = 7
state.node_run_steps = 2
state.set_output("foo", "bar")
state.ready_queue.put("node-1")
execution = state.graph_execution
execution.workflow_id = "wf-456"
execution.started = True
mock_graph = MagicMock()
original_stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub):
state.attach_graph(mock_graph)
original_stub.state = "configured"
snapshot = state.dumps()
new_stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
restored.attach_graph(mock_graph)
restored.loads(snapshot)
assert restored.total_tokens == 7
assert restored.node_run_steps == 2
assert restored.get_output("foo") == "bar"
assert restored.ready_queue.qsize() == 1
assert restored.ready_queue.get(timeout=0.01) == "node-1"
restored_segment = restored.variable_pool.get(("node", "key"))
assert restored_segment is not None
assert restored_segment.value == "value"
restored_execution = restored.graph_execution
assert restored_execution.workflow_id == "wf-456"
assert restored_execution.started is True
assert new_stub.state == "configured"