mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
refactor: add node_id to MemoryBlockSpec
This commit is contained in:
@ -4,7 +4,7 @@ from enum import StrEnum
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from core.app.app_config.entities import ModelConfig
|
from core.app.app_config.entities import ModelConfig
|
||||||
|
|
||||||
@ -49,6 +49,21 @@ class MemoryBlockSpec(BaseModel):
|
|||||||
model: ModelConfig = Field(description="Model configuration for memory updates")
|
model: ModelConfig = Field(description="Model configuration for memory updates")
|
||||||
end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users")
|
end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users")
|
||||||
end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users")
|
end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users")
|
||||||
|
node_id: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Node ID when scope is NODE. Must be None when scope is APP."
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator('node_id')
|
||||||
|
@classmethod
|
||||||
|
def validate_node_id_with_scope(cls, v: str | None, info) -> str | None:
|
||||||
|
"""Validate node_id consistency with scope"""
|
||||||
|
scope = info.data.get('scope')
|
||||||
|
if scope == MemoryScope.NODE and v is None:
|
||||||
|
raise ValueError("node_id is required when scope is NODE")
|
||||||
|
if scope == MemoryScope.APP and v is not None:
|
||||||
|
raise ValueError("node_id must be None when scope is APP")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class MemoryCreatedBy(BaseModel):
|
class MemoryCreatedBy(BaseModel):
|
||||||
|
|||||||
@ -44,13 +44,7 @@ class MemoryConfig(BaseModel):
|
|||||||
|
|
||||||
enabled: bool
|
enabled: bool
|
||||||
size: int | None = None
|
size: int | None = None
|
||||||
|
|
||||||
mode: Literal["linear", "block"] | None = "linear"
|
mode: Literal["linear", "block"] | None = "linear"
|
||||||
block_id: list[str] | None = None
|
|
||||||
role_prefix: RolePrefix | None = None
|
role_prefix: RolePrefix | None = None
|
||||||
window: WindowConfig
|
window: WindowConfig
|
||||||
query_prompt_template: str | None = None
|
query_prompt_template: str | None = None
|
||||||
|
|
||||||
@property
|
|
||||||
def is_block_mode(self) -> bool:
|
|
||||||
return self.mode == "block" and bool(self.block_id)
|
|
||||||
|
|||||||
@ -1234,13 +1234,6 @@ class LLMNode(Node):
|
|||||||
tenant_id=self.tenant_id
|
tenant_id=self.tenant_id
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_config = self._node_data.memory
|
|
||||||
if not memory_config:
|
|
||||||
return
|
|
||||||
block_ids = memory_config.block_id
|
|
||||||
if not block_ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
# FIXME: This is dirty workaround and may cause incorrect resolution for workflow version
|
# FIXME: This is dirty workaround and may cause incorrect resolution for workflow version
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
stmt = select(Workflow).where(
|
stmt = select(Workflow).where(
|
||||||
@ -1250,24 +1243,31 @@ class LLMNode(Node):
|
|||||||
workflow = session.scalars(stmt).first()
|
workflow = session.scalars(stmt).first()
|
||||||
if not workflow:
|
if not workflow:
|
||||||
raise ValueError("Workflow not found.")
|
raise ValueError("Workflow not found.")
|
||||||
memory_blocks = workflow.memory_blocks
|
|
||||||
|
|
||||||
for block_id in block_ids:
|
# Filter memory blocks that belong to this node
|
||||||
memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None)
|
node_memory_blocks = [
|
||||||
|
block for block in workflow.memory_blocks
|
||||||
|
if block.scope == MemoryScope.NODE and block.node_id == self.id
|
||||||
|
]
|
||||||
|
|
||||||
if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
|
if not node_memory_blocks:
|
||||||
is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
|
return
|
||||||
from services.chatflow_memory_service import ChatflowMemoryService
|
|
||||||
ChatflowMemoryService.update_node_memory_if_needed(
|
# Update each memory block that belongs to this node
|
||||||
tenant_id=self.tenant_id,
|
is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
|
||||||
app_id=self.app_id,
|
from services.chatflow_memory_service import ChatflowMemoryService
|
||||||
node_id=self.id,
|
|
||||||
conversation_id=conversation_id,
|
for memory_block_spec in node_memory_blocks:
|
||||||
memory_block_spec=memory_block_spec,
|
ChatflowMemoryService.update_node_memory_if_needed(
|
||||||
variable_pool=variable_pool,
|
tenant_id=self.tenant_id,
|
||||||
is_draft=is_draft,
|
app_id=self.app_id,
|
||||||
created_by=self._get_user_from_context()
|
node_id=self.id,
|
||||||
)
|
conversation_id=conversation_id,
|
||||||
|
memory_block_spec=memory_block_spec,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
is_draft=is_draft,
|
||||||
|
created_by=self._get_user_from_context()
|
||||||
|
)
|
||||||
|
|
||||||
def _get_user_from_context(self) -> MemoryCreatedBy:
|
def _get_user_from_context(self) -> MemoryCreatedBy:
|
||||||
if self.user_from == UserFrom.ACCOUNT:
|
if self.user_from == UserFrom.ACCOUNT:
|
||||||
|
|||||||
@ -71,6 +71,7 @@ memory_block_fields = {
|
|||||||
"model": fields.Nested(model_config_fields),
|
"model": fields.Nested(model_config_fields),
|
||||||
"end_user_visible": fields.Boolean,
|
"end_user_visible": fields.Boolean,
|
||||||
"end_user_editable": fields.Boolean,
|
"end_user_editable": fields.Boolean,
|
||||||
|
"node_id": fields.String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pipeline_variable_fields = {
|
pipeline_variable_fields = {
|
||||||
|
|||||||
Reference in New Issue
Block a user