chore: change draft var to user scoped (#33066)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
非法操作
2026-03-16 14:04:41 +08:00
committed by GitHub
parent df570df238
commit 98e72521f4
19 changed files with 452 additions and 130 deletions

View File

@ -77,6 +77,7 @@ class DraftVarLoader(VariableLoader):
_engine: Engine
# Application ID for which variables are being loaded.
_app_id: str
_user_id: str
_tenant_id: str
_fallback_variables: Sequence[VariableBase]
@ -85,10 +86,12 @@ class DraftVarLoader(VariableLoader):
engine: Engine,
app_id: str,
tenant_id: str,
user_id: str,
fallback_variables: Sequence[VariableBase] | None = None,
):
self._engine = engine
self._app_id = app_id
self._user_id = user_id
self._tenant_id = tenant_id
self._fallback_variables = fallback_variables or []
@ -104,7 +107,7 @@ class DraftVarLoader(VariableLoader):
with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors)
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors, user_id=self._user_id)
# Important:
files: list[File] = []
@ -218,6 +221,7 @@ class WorkflowDraftVariableService:
self,
app_id: str,
selectors: Sequence[list[str]],
user_id: str,
) -> list[WorkflowDraftVariable]:
"""
Retrieve WorkflowDraftVariable instances based on app_id and selectors.
@ -238,22 +242,30 @@ class WorkflowDraftVariableService:
# Alternatively, a `SELECT` statement could be constructed for each selector and
# combined using `UNION` to fetch all rows.
# Benchmarking indicates that both approaches yield comparable performance.
variables = (
query = (
self._session.query(WorkflowDraftVariable)
.options(
orm.selectinload(WorkflowDraftVariable.variable_file).selectinload(
WorkflowDraftVariableFile.upload_file
)
)
.where(WorkflowDraftVariable.app_id == app_id, or_(*ors))
.all()
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
or_(*ors),
)
)
return variables
return query.all()
def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList:
criteria = WorkflowDraftVariable.app_id == app_id
def list_variables_without_values(
self, app_id: str, page: int, limit: int, user_id: str
) -> WorkflowDraftVariableList:
criteria = [
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
]
total = None
query = self._session.query(WorkflowDraftVariable).where(criteria)
query = self._session.query(WorkflowDraftVariable).where(*criteria)
if page == 1:
total = query.count()
variables = (
@ -269,11 +281,12 @@ class WorkflowDraftVariableService:
return WorkflowDraftVariableList(variables=variables, total=total)
def _list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
criteria = (
def _list_node_variables(self, app_id: str, node_id: str, user_id: str) -> WorkflowDraftVariableList:
criteria = [
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
)
WorkflowDraftVariable.user_id == user_id,
]
query = self._session.query(WorkflowDraftVariable).where(*criteria)
variables = (
query.options(orm.selectinload(WorkflowDraftVariable.variable_file))
@ -282,36 +295,36 @@ class WorkflowDraftVariableService:
)
return WorkflowDraftVariableList(variables=variables)
def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, node_id)
def list_node_variables(self, app_id: str, node_id: str, user_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, node_id, user_id=user_id)
def list_conversation_variables(self, app_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID)
def list_conversation_variables(self, app_id: str, user_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID, user_id=user_id)
def list_system_variables(self, app_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID)
def list_system_variables(self, app_id: str, user_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID, user_id=user_id)
def get_conversation_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name)
def get_conversation_variable(self, app_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name, user_id=user_id)
def get_system_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name)
def get_system_variable(self, app_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, user_id=user_id)
def get_node_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id, node_id, name)
def get_node_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id, node_id, name, user_id=user_id)
def _get_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
variable = (
def _get_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
return (
self._session.query(WorkflowDraftVariable)
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
WorkflowDraftVariable.name == name,
WorkflowDraftVariable.user_id == user_id,
)
.first()
)
return variable
def update_variable(
self,
@ -462,7 +475,17 @@ class WorkflowDraftVariableService:
self._session.delete(upload_file)
self._session.delete(variable)
def delete_workflow_variables(self, app_id: str):
def delete_user_workflow_variables(self, app_id: str, user_id: str):
(
self._session.query(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
)
.delete(synchronize_session=False)
)
def delete_app_workflow_variables(self, app_id: str):
(
self._session.query(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id == app_id)
@ -501,28 +524,35 @@ class WorkflowDraftVariableService:
self._session.delete(upload_file)
self._session.delete(variable_file)
def delete_node_variables(self, app_id: str, node_id: str):
return self._delete_node_variables(app_id, node_id)
def delete_node_variables(self, app_id: str, node_id: str, user_id: str):
return self._delete_node_variables(app_id, node_id, user_id=user_id)
def _delete_node_variables(self, app_id: str, node_id: str):
self._session.query(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
).delete()
def _delete_node_variables(self, app_id: str, node_id: str, user_id: str):
(
self._session.query(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
WorkflowDraftVariable.user_id == user_id,
)
.delete(synchronize_session=False)
)
def _get_conversation_id_from_draft_variable(self, app_id: str) -> str | None:
def _get_conversation_id_from_draft_variable(self, app_id: str, user_id: str) -> str | None:
draft_var = self._get_variable(
app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
name=str(SystemVariableKey.CONVERSATION_ID),
user_id=user_id,
)
if draft_var is None:
return None
segment = draft_var.get_value()
if not isinstance(segment, StringSegment):
logger.warning(
"sys.conversation_id variable is not a string: app_id=%s, id=%s",
"sys.conversation_id variable is not a string: app_id=%s, user_id=%s, id=%s",
app_id,
user_id,
draft_var.id,
)
return None
@ -543,7 +573,7 @@ class WorkflowDraftVariableService:
If no such conversation exists, a new conversation is created and its ID is returned.
"""
conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id)
conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id, account_id)
if conv_id is not None:
conversation = (
@ -580,12 +610,13 @@ class WorkflowDraftVariableService:
self._session.flush()
return conversation.id
def prefill_conversation_variable_default_values(self, workflow: Workflow):
def prefill_conversation_variable_default_values(self, workflow: Workflow, user_id: str):
""""""
draft_conv_vars: list[WorkflowDraftVariable] = []
for conv_var in workflow.conversation_variables:
draft_var = WorkflowDraftVariable.new_conversation_variable(
app_id=workflow.app_id,
user_id=user_id,
name=conv_var.name,
value=conv_var,
description=conv_var.description,
@ -635,7 +666,7 @@ def _batch_upsert_draft_variable(
stmt = pg_insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars])
if policy == _UpsertPolicy.OVERWRITE:
stmt = stmt.on_conflict_do_update(
index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
index_elements=WorkflowDraftVariable.unique_app_id_user_id_node_id_name(),
set_={
# Refresh creation timestamp to ensure updated variables
# appear first in chronologically sorted result sets.
@ -652,7 +683,9 @@ def _batch_upsert_draft_variable(
},
)
elif policy == _UpsertPolicy.IGNORE:
stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name())
stmt = stmt.on_conflict_do_nothing(
index_elements=WorkflowDraftVariable.unique_app_id_user_id_node_id_name()
)
else:
stmt = mysql_insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars]) # type: ignore[assignment]
if policy == _UpsertPolicy.OVERWRITE:
@ -682,6 +715,7 @@ def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
d: dict[str, Any] = {
"id": model.id,
"app_id": model.app_id,
"user_id": model.user_id,
"last_edited_at": None,
"node_id": model.node_id,
"name": model.name,
@ -807,6 +841,7 @@ class DraftVariableSaver:
def _create_dummy_output_variable(self):
return WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
user_id=self._user.id,
node_id=self._node_id,
name=self._DUMMY_OUTPUT_IDENTITY,
node_execution_id=self._node_execution_id,
@ -842,6 +877,7 @@ class DraftVariableSaver:
draft_vars.append(
WorkflowDraftVariable.new_conversation_variable(
app_id=self._app_id,
user_id=self._user.id,
name=item.name,
value=segment,
)
@ -862,6 +898,7 @@ class DraftVariableSaver:
draft_vars.append(
WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
user_id=self._user.id,
node_id=self._node_id,
name=name,
node_execution_id=self._node_execution_id,
@ -884,6 +921,7 @@ class DraftVariableSaver:
draft_vars.append(
WorkflowDraftVariable.new_sys_variable(
app_id=self._app_id,
user_id=self._user.id,
name=name,
node_execution_id=self._node_execution_id,
value=value_seg,
@ -1019,6 +1057,7 @@ class DraftVariableSaver:
# Create the draft variable
draft_var = WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
user_id=self._user.id,
node_id=self._node_id,
name=name,
node_execution_id=self._node_execution_id,
@ -1032,6 +1071,7 @@ class DraftVariableSaver:
# Create the draft variable
draft_var = WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
user_id=self._user.id,
node_id=self._node_id,
name=name,
node_execution_id=self._node_execution_id,