refactor: add node_id to MemoryBlockSpec

This commit is contained in:
Stream
2025-10-28 13:04:01 +08:00
parent 89d53ecf50
commit f0ff2e1f2c
4 changed files with 40 additions and 30 deletions

View File

@ -4,7 +4,7 @@ from enum import StrEnum
from typing import Optional from typing import Optional
from uuid import uuid4 from uuid import uuid4
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from core.app.app_config.entities import ModelConfig from core.app.app_config.entities import ModelConfig
@ -49,6 +49,21 @@ class MemoryBlockSpec(BaseModel):
model: ModelConfig = Field(description="Model configuration for memory updates") model: ModelConfig = Field(description="Model configuration for memory updates")
end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users") end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users")
end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users") end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users")
node_id: str | None = Field(
default=None,
description="Node ID when scope is NODE. Must be None when scope is APP."
)
@field_validator('node_id')
@classmethod
def validate_node_id_with_scope(cls, v: str | None, info) -> str | None:
"""Validate node_id consistency with scope"""
scope = info.data.get('scope')
if scope == MemoryScope.NODE and v is None:
raise ValueError("node_id is required when scope is NODE")
if scope == MemoryScope.APP and v is not None:
raise ValueError("node_id must be None when scope is APP")
return v
class MemoryCreatedBy(BaseModel): class MemoryCreatedBy(BaseModel):

View File

@ -44,13 +44,7 @@ class MemoryConfig(BaseModel):
enabled: bool enabled: bool
size: int | None = None size: int | None = None
mode: Literal["linear", "block"] | None = "linear" mode: Literal["linear", "block"] | None = "linear"
block_id: list[str] | None = None
role_prefix: RolePrefix | None = None role_prefix: RolePrefix | None = None
window: WindowConfig window: WindowConfig
query_prompt_template: str | None = None query_prompt_template: str | None = None
@property
def is_block_mode(self) -> bool:
return self.mode == "block" and bool(self.block_id)

View File

@ -1234,13 +1234,6 @@ class LLMNode(Node):
tenant_id=self.tenant_id tenant_id=self.tenant_id
) )
memory_config = self._node_data.memory
if not memory_config:
return
block_ids = memory_config.block_id
if not block_ids:
return
# FIXME: This is dirty workaround and may cause incorrect resolution for workflow version # FIXME: This is dirty workaround and may cause incorrect resolution for workflow version
with Session(db.engine) as session: with Session(db.engine) as session:
stmt = select(Workflow).where( stmt = select(Workflow).where(
@ -1250,24 +1243,31 @@ class LLMNode(Node):
workflow = session.scalars(stmt).first() workflow = session.scalars(stmt).first()
if not workflow: if not workflow:
raise ValueError("Workflow not found.") raise ValueError("Workflow not found.")
memory_blocks = workflow.memory_blocks
for block_id in block_ids: # Filter memory blocks that belong to this node
memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None) node_memory_blocks = [
block for block in workflow.memory_blocks
if block.scope == MemoryScope.NODE and block.node_id == self.id
]
if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE: if not node_memory_blocks:
is_draft = (self.invoke_from == InvokeFrom.DEBUGGER) return
from services.chatflow_memory_service import ChatflowMemoryService
ChatflowMemoryService.update_node_memory_if_needed( # Update each memory block that belongs to this node
tenant_id=self.tenant_id, is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
app_id=self.app_id, from services.chatflow_memory_service import ChatflowMemoryService
node_id=self.id,
conversation_id=conversation_id, for memory_block_spec in node_memory_blocks:
memory_block_spec=memory_block_spec, ChatflowMemoryService.update_node_memory_if_needed(
variable_pool=variable_pool, tenant_id=self.tenant_id,
is_draft=is_draft, app_id=self.app_id,
created_by=self._get_user_from_context() node_id=self.id,
) conversation_id=conversation_id,
memory_block_spec=memory_block_spec,
variable_pool=variable_pool,
is_draft=is_draft,
created_by=self._get_user_from_context()
)
def _get_user_from_context(self) -> MemoryCreatedBy: def _get_user_from_context(self) -> MemoryCreatedBy:
if self.user_from == UserFrom.ACCOUNT: if self.user_from == UserFrom.ACCOUNT:

View File

@ -71,6 +71,7 @@ memory_block_fields = {
"model": fields.Nested(model_config_fields), "model": fields.Nested(model_config_fields),
"end_user_visible": fields.Boolean, "end_user_visible": fields.Boolean,
"end_user_editable": fields.Boolean, "end_user_editable": fields.Boolean,
"node_id": fields.String,
} }
pipeline_variable_fields = { pipeline_variable_fields = {