mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
[Enhancement] Allow modify conversation variable via api (#23112)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,7 +1,9 @@
|
||||
import json
|
||||
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
@ -15,6 +17,7 @@ from fields.conversation_fields import (
|
||||
simple_conversation_fields,
|
||||
)
|
||||
from fields.conversation_variable_fields import (
|
||||
conversation_variable_fields,
|
||||
conversation_variable_infinite_scroll_pagination_fields,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
@ -120,7 +123,41 @@ class ConversationVariablesApi(Resource):
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
class ConversationVariableDetailApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@marshal_with(conversation_variable_fields)
|
||||
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
|
||||
"""Update a conversation variable's value"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
variable_id = str(variable_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("value", required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.update_conversation_variable(
|
||||
app_model, conversation_id, variable_id, end_user, json.loads(args["value"])
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationVariableNotExistsError:
|
||||
raise NotFound("Conversation Variable Not Exists.")
|
||||
except services.errors.conversation.ConversationVariableTypeMismatchError as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
|
||||
api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name")
|
||||
api.add_resource(ConversationApi, "/conversations")
|
||||
api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail")
|
||||
api.add_resource(ConversationVariablesApi, "/conversations/<uuid:c_id>/variables", endpoint="conversation_variables")
|
||||
api.add_resource(
|
||||
ConversationVariableDetailApi,
|
||||
"/conversations/<uuid:c_id>/variables/<uuid:variable_id>",
|
||||
endpoint="conversation_variable_detail",
|
||||
methods=["PUT"],
|
||||
)
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from sqlalchemy import asc, desc, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import ConversationVariable
|
||||
@ -15,6 +18,7 @@ from models.model import App, Conversation, EndUser, Message
|
||||
from services.errors.conversation import (
|
||||
ConversationNotExistsError,
|
||||
ConversationVariableNotExistsError,
|
||||
ConversationVariableTypeMismatchError,
|
||||
LastConversationNotExistsError,
|
||||
)
|
||||
from services.errors.message import MessageNotExistsError
|
||||
@ -220,3 +224,82 @@ class ConversationService:
|
||||
]
|
||||
|
||||
return InfiniteScrollPagination(variables, limit, has_more)
|
||||
|
||||
@classmethod
|
||||
def update_conversation_variable(
|
||||
cls,
|
||||
app_model: App,
|
||||
conversation_id: str,
|
||||
variable_id: str,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
new_value: Any,
|
||||
) -> dict:
|
||||
"""
|
||||
Update a conversation variable's value.
|
||||
|
||||
Args:
|
||||
app_model: The app model
|
||||
conversation_id: The conversation ID
|
||||
variable_id: The variable ID to update
|
||||
user: The user (Account or EndUser)
|
||||
new_value: The new value for the variable
|
||||
|
||||
Returns:
|
||||
Dictionary containing the updated variable information
|
||||
|
||||
Raises:
|
||||
ConversationNotExistsError: If the conversation doesn't exist
|
||||
ConversationVariableNotExistsError: If the variable doesn't exist
|
||||
ConversationVariableTypeMismatchError: If the new value type doesn't match the variable's expected type
|
||||
"""
|
||||
# Verify conversation exists and user has access
|
||||
conversation = cls.get_conversation(app_model, conversation_id, user)
|
||||
|
||||
# Get the existing conversation variable
|
||||
stmt = (
|
||||
select(ConversationVariable)
|
||||
.where(ConversationVariable.app_id == app_model.id)
|
||||
.where(ConversationVariable.conversation_id == conversation.id)
|
||||
.where(ConversationVariable.id == variable_id)
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
existing_variable = session.scalar(stmt)
|
||||
if not existing_variable:
|
||||
raise ConversationVariableNotExistsError()
|
||||
|
||||
# Convert existing variable to Variable object
|
||||
current_variable = existing_variable.to_variable()
|
||||
|
||||
# Validate that the new value type matches the expected variable type
|
||||
expected_type = SegmentType(current_variable.value_type)
|
||||
if not expected_type.is_valid(new_value):
|
||||
inferred_type = SegmentType.infer_segment_type(new_value)
|
||||
raise ConversationVariableTypeMismatchError(
|
||||
f"Type mismatch: variable '{current_variable.name}' expects {expected_type.value}, "
|
||||
f"but got {inferred_type.value if inferred_type else 'unknown'} type"
|
||||
)
|
||||
|
||||
# Create updated variable with new value only, preserving everything else
|
||||
updated_variable_dict = {
|
||||
"id": current_variable.id,
|
||||
"name": current_variable.name,
|
||||
"description": current_variable.description,
|
||||
"value_type": current_variable.value_type,
|
||||
"value": new_value,
|
||||
"selector": current_variable.selector,
|
||||
}
|
||||
|
||||
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
|
||||
|
||||
# Use the conversation variable updater to persist the changes
|
||||
updater = conversation_variable_updater_factory()
|
||||
updater.update(conversation_id, updated_variable)
|
||||
updater.flush()
|
||||
|
||||
# Return the updated variable data
|
||||
return {
|
||||
"created_at": existing_variable.created_at,
|
||||
"updated_at": naive_utc_now(), # Update timestamp
|
||||
**updated_variable.model_dump(),
|
||||
}
|
||||
|
||||
@ -15,3 +15,7 @@ class ConversationCompletedError(Exception):
|
||||
|
||||
class ConversationVariableNotExistsError(BaseServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class ConversationVariableTypeMismatchError(BaseServiceError):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user