mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
Merge branch 'main' into feat/rag-2
This commit is contained in:
@ -225,14 +225,15 @@ class AnnotationBatchImportApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
# check file type
|
||||
if not file.filename or not file.filename.lower().endswith(".csv"):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
@ -58,21 +58,38 @@ class InstalledAppsListApi(Resource):
|
||||
# filter out apps that user doesn't have access to
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
user_id = current_user.id
|
||||
res = []
|
||||
app_ids = [installed_app["app"].id for installed_app in installed_app_list]
|
||||
webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids)
|
||||
|
||||
# Pre-filter out apps without setting or with sso_verified
|
||||
filtered_installed_apps = []
|
||||
app_id_to_app_code = {}
|
||||
|
||||
for installed_app in installed_app_list:
|
||||
webapp_setting = webapp_settings.get(installed_app["app"].id)
|
||||
if not webapp_setting:
|
||||
app_id = installed_app["app"].id
|
||||
webapp_setting = webapp_settings.get(app_id)
|
||||
if not webapp_setting or webapp_setting.access_mode == "sso_verified":
|
||||
continue
|
||||
if webapp_setting.access_mode == "sso_verified":
|
||||
continue
|
||||
app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
|
||||
if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
app_code=app_code,
|
||||
):
|
||||
app_code = AppService.get_app_code_by_id(str(app_id))
|
||||
app_id_to_app_code[app_id] = app_code
|
||||
filtered_installed_apps.append(installed_app)
|
||||
|
||||
app_codes = list(app_id_to_app_code.values())
|
||||
|
||||
# Batch permission check
|
||||
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
|
||||
user_id=user_id,
|
||||
app_codes=app_codes,
|
||||
)
|
||||
|
||||
# Keep only allowed apps
|
||||
res = []
|
||||
for installed_app in filtered_installed_apps:
|
||||
app_id = installed_app["app"].id
|
||||
app_code = app_id_to_app_code[app_id]
|
||||
if permissions.get(app_code):
|
||||
res.append(installed_app)
|
||||
|
||||
installed_app_list = res
|
||||
logger.debug("installed_app_list: %s, user_id: %s", installed_app_list, user_id)
|
||||
|
||||
|
||||
@ -208,6 +208,7 @@ class BasePluginClient:
|
||||
except Exception:
|
||||
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
|
||||
|
||||
logger.error("Error in stream reponse for plugin %s", rep.__dict__)
|
||||
self._handle_plugin_daemon_error(error.error_type, error.message)
|
||||
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
|
||||
if rep.data is None:
|
||||
|
||||
@ -2,6 +2,8 @@ from collections.abc import Mapping
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from extensions.ext_logging import get_request_id
|
||||
|
||||
|
||||
class PluginDaemonError(Exception):
|
||||
"""Base class for all plugin daemon errors."""
|
||||
@ -11,7 +13,7 @@ class PluginDaemonError(Exception):
|
||||
|
||||
def __str__(self) -> str:
|
||||
# returns the class name and description
|
||||
return f"{self.__class__.__name__}: {self.description}"
|
||||
return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}"
|
||||
|
||||
|
||||
class PluginDaemonInternalError(PluginDaemonError):
|
||||
|
||||
@ -5,14 +5,13 @@ from __future__ import annotations
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
||||
from core.rag.splitter.text_splitter import (
|
||||
TS,
|
||||
Collection,
|
||||
Literal,
|
||||
RecursiveCharacterTextSplitter,
|
||||
Set,
|
||||
TokenTextSplitter,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -45,14 +44,6 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
|
||||
return [len(text) for text in texts]
|
||||
|
||||
if issubclass(cls, TokenTextSplitter):
|
||||
extra_kwargs = {
|
||||
"model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
|
||||
"allowed_special": allowed_special,
|
||||
"disallowed_special": disallowed_special,
|
||||
}
|
||||
kwargs = {**kwargs, **extra_kwargs}
|
||||
|
||||
return cls(length_function=_character_encoder, **kwargs)
|
||||
|
||||
|
||||
|
||||
@ -20,9 +20,6 @@ class Tool(ABC):
|
||||
The base class of a tool
|
||||
"""
|
||||
|
||||
entity: ToolEntity
|
||||
runtime: ToolRuntime
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
||||
@ -20,8 +20,6 @@ class BuiltinTool(Tool):
|
||||
:param meta: the meta data of a tool call processing
|
||||
"""
|
||||
|
||||
provider: str
|
||||
|
||||
def __init__(self, provider: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.provider = provider
|
||||
|
||||
@ -21,9 +21,6 @@ API_TOOL_DEFAULT_TIMEOUT = (
|
||||
|
||||
|
||||
class ApiTool(Tool):
|
||||
api_bundle: ApiToolBundle
|
||||
provider_id: str
|
||||
|
||||
"""
|
||||
Api tool
|
||||
"""
|
||||
|
||||
@ -8,23 +8,16 @@ from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
runtime_parameters: Optional[list[ToolParameter]]
|
||||
server_url: str
|
||||
provider_id: str
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.runtime_parameters = None
|
||||
self.server_url = server_url
|
||||
self.provider_id = provider_id
|
||||
|
||||
|
||||
@ -9,11 +9,6 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
|
||||
|
||||
|
||||
class PluginTool(Tool):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
runtime_parameters: Optional[list[ToolParameter]]
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
|
||||
) -> None:
|
||||
@ -21,7 +16,7 @@ class PluginTool(Tool):
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
self.runtime_parameters = None
|
||||
self.runtime_parameters: Optional[list[ToolParameter]] = None
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
@ -20,8 +20,6 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas
|
||||
|
||||
|
||||
class DatasetRetrieverTool(Tool):
|
||||
retrieval_tool: DatasetRetrieverBaseTool
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.retrieval_tool = retrieval_tool
|
||||
|
||||
@ -25,15 +25,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowTool(Tool):
|
||||
workflow_app_id: str
|
||||
version: str
|
||||
workflow_entities: dict[str, Any]
|
||||
workflow_call_depth: int
|
||||
thread_pool_id: Optional[str] = None
|
||||
workflow_as_tool_id: str
|
||||
|
||||
label: str
|
||||
|
||||
"""
|
||||
Workflow tool.
|
||||
"""
|
||||
|
||||
@ -136,6 +136,8 @@ def init_app(app: DifyApp):
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
|
||||
from opentelemetry.instrumentation.celery import CeleryInstrumentor
|
||||
from opentelemetry.instrumentation.flask import FlaskInstrumentor
|
||||
from opentelemetry.instrumentation.redis import RedisInstrumentor
|
||||
from opentelemetry.instrumentation.requests import RequestsInstrumentor
|
||||
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
|
||||
from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider
|
||||
from opentelemetry.propagate import set_global_textmap
|
||||
@ -234,6 +236,8 @@ def init_app(app: DifyApp):
|
||||
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
|
||||
instrument_exception_logging()
|
||||
init_sqlalchemy_instrumentor(app)
|
||||
RedisInstrumentor().instrument()
|
||||
RequestsInstrumentor().instrument()
|
||||
atexit.register(shutdown_tracer)
|
||||
|
||||
|
||||
|
||||
@ -895,6 +895,19 @@ class WorkflowAppLog(Base):
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"app_id": self.app_id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"workflow_run_id": self.workflow_run_id,
|
||||
"created_from": self.created_from,
|
||||
"created_by_role": self.created_by_role,
|
||||
"created_by": self.created_by,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
|
||||
|
||||
class ConversationVariable(Base):
|
||||
__tablename__ = "workflow_conversation_variables"
|
||||
|
||||
@ -49,6 +49,8 @@ dependencies = [
|
||||
"opentelemetry-instrumentation==0.48b0",
|
||||
"opentelemetry-instrumentation-celery==0.48b0",
|
||||
"opentelemetry-instrumentation-flask==0.48b0",
|
||||
"opentelemetry-instrumentation-redis==0.48b0",
|
||||
"opentelemetry-instrumentation-requests==0.48b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.48b0",
|
||||
"opentelemetry-propagator-b3==1.27.0",
|
||||
# opentelemetry-proto1.28.0 depends on protobuf (>=5.0,<6.0),
|
||||
|
||||
@ -13,7 +13,19 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.account import Tenant
|
||||
from models.model import App, Conversation, Message
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
Conversation,
|
||||
Message,
|
||||
MessageAgentThought,
|
||||
MessageAnnotation,
|
||||
MessageChain,
|
||||
MessageFeedback,
|
||||
MessageFile,
|
||||
)
|
||||
from models.web import SavedMessage
|
||||
from models.workflow import WorkflowAppLog
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.billing_service import BillingService
|
||||
|
||||
@ -21,6 +33,85 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClearFreePlanTenantExpiredLogs:
|
||||
@classmethod
|
||||
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None:
|
||||
"""
|
||||
Clean up message-related tables to avoid data redundancy.
|
||||
This method cleans up tables that have foreign key relationships with Message.
|
||||
|
||||
Args:
|
||||
session: Database session, the same with the one in process_tenant method
|
||||
tenant_id: Tenant ID for logging purposes
|
||||
batch_message_ids: List of message IDs to clean up
|
||||
"""
|
||||
if not batch_message_ids:
|
||||
return
|
||||
|
||||
# Clean up each related table
|
||||
related_tables = [
|
||||
(MessageFeedback, "message_feedbacks"),
|
||||
(MessageFile, "message_files"),
|
||||
(MessageAnnotation, "message_annotations"),
|
||||
(MessageChain, "message_chains"),
|
||||
(MessageAgentThought, "message_agent_thoughts"),
|
||||
(AppAnnotationHitHistory, "app_annotation_hit_histories"),
|
||||
(SavedMessage, "saved_messages"),
|
||||
]
|
||||
|
||||
for model, table_name in related_tables:
|
||||
# Query records related to expired messages
|
||||
records = (
|
||||
session.query(model)
|
||||
.filter(
|
||||
model.message_id.in_(batch_message_ids), # type: ignore
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(records) == 0:
|
||||
continue
|
||||
|
||||
# Save records before deletion
|
||||
record_ids = [record.id for record in records]
|
||||
try:
|
||||
record_data = []
|
||||
for record in records:
|
||||
try:
|
||||
if hasattr(record, "to_dict"):
|
||||
record_data.append(record.to_dict())
|
||||
else:
|
||||
# if record doesn't have to_dict method, we need to transform it to dict manually
|
||||
record_dict = {}
|
||||
for column in record.__table__.columns:
|
||||
record_dict[column.name] = getattr(record, column.name)
|
||||
record_data.append(record_dict)
|
||||
except Exception:
|
||||
logger.exception("Failed to transform %s record: %s", table_name, record.id)
|
||||
continue
|
||||
|
||||
if record_data:
|
||||
storage.save(
|
||||
f"free_plan_tenant_expired_logs/"
|
||||
f"{tenant_id}/{table_name}/{datetime.datetime.now().strftime('%Y-%m-%d')}"
|
||||
f"-{time.time()}.json",
|
||||
json.dumps(
|
||||
jsonable_encoder(record_data),
|
||||
).encode("utf-8"),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to save %s records", table_name)
|
||||
|
||||
session.query(model).filter(
|
||||
model.id.in_(record_ids), # type: ignore
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{datetime.datetime.now()}] Processed {len(record_ids)} "
|
||||
f"{table_name} records for tenant {tenant_id}"
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
|
||||
with flask_app.app_context():
|
||||
@ -58,6 +149,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
Message.id.in_(message_ids),
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
cls._clear_message_related_tables(session, tenant_id, message_ids)
|
||||
session.commit()
|
||||
|
||||
click.echo(
|
||||
@ -199,6 +291,48 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
if len(workflow_runs) < batch:
|
||||
break
|
||||
|
||||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
workflow_app_logs = (
|
||||
session.query(WorkflowAppLog)
|
||||
.filter(
|
||||
WorkflowAppLog.tenant_id == tenant_id,
|
||||
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
.limit(batch)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(workflow_app_logs) == 0:
|
||||
break
|
||||
|
||||
# save workflow app logs
|
||||
storage.save(
|
||||
f"free_plan_tenant_expired_logs/"
|
||||
f"{tenant_id}/workflow_app_logs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
|
||||
f"-{time.time()}.json",
|
||||
json.dumps(
|
||||
jsonable_encoder(
|
||||
[workflow_app_log.to_dict() for workflow_app_log in workflow_app_logs],
|
||||
),
|
||||
).encode("utf-8"),
|
||||
)
|
||||
|
||||
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
|
||||
|
||||
# delete workflow app logs
|
||||
session.query(WorkflowAppLog).filter(
|
||||
WorkflowAppLog.id.in_(workflow_app_log_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{datetime.datetime.now()}] Processed {len(workflow_app_log_ids)}"
|
||||
f" workflow app logs for tenant {tenant_id}"
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def process(cls, days: int, batch: int, tenant_ids: list[str]):
|
||||
"""
|
||||
|
||||
@ -52,6 +52,16 @@ class EnterpriseService:
|
||||
|
||||
return data.get("result", False)
|
||||
|
||||
@classmethod
|
||||
def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_codes: list[str]):
|
||||
if not app_codes:
|
||||
return {}
|
||||
body = {"userId": user_id, "appCodes": app_codes}
|
||||
data = EnterpriseRequest.send_request("POST", "/webapp/permission/batch", json=body)
|
||||
if not data:
|
||||
raise ValueError("No data found.")
|
||||
return data.get("permissions", {})
|
||||
|
||||
@classmethod
|
||||
def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings:
|
||||
if not app_id:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,739 @@
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.variables.segments import StringSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from models import App, Workflow
|
||||
from models.enums import DraftVariableType
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import (
|
||||
UpdateNotSupportedError,
|
||||
WorkflowDraftVariableService,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableService:
|
||||
"""
|
||||
Comprehensive integration tests for WorkflowDraftVariableService using testcontainers.
|
||||
|
||||
This test class covers all major functionality of the WorkflowDraftVariableService:
|
||||
- CRUD operations for workflow draft variables (Create, Read, Update, Delete)
|
||||
- Variable listing and filtering by type (conversation, system, node)
|
||||
- Variable updates and resets with proper validation
|
||||
- Variable deletion operations at different scopes
|
||||
- Special functionality like prefill and conversation ID retrieval
|
||||
- Error handling for various edge cases and invalid operations
|
||||
|
||||
All tests use the testcontainers infrastructure to ensure proper database isolation
|
||||
and realistic testing environment with actual database interactions.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""
|
||||
Mock setup for external service dependencies.
|
||||
|
||||
WorkflowDraftVariableService doesn't have external dependencies that need mocking,
|
||||
so this fixture returns an empty dictionary to maintain consistency with other test classes.
|
||||
This ensures the test structure remains consistent across different service test files.
|
||||
"""
|
||||
# WorkflowDraftVariableService doesn't have external dependencies that need mocking
|
||||
return {}
|
||||
|
||||
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None):
|
||||
"""
|
||||
Helper method to create a test app with realistic data for testing.
|
||||
|
||||
This method creates a complete App instance with all required fields populated
|
||||
using Faker for generating realistic test data. The app is configured for
|
||||
workflow mode to support workflow draft variable testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies (unused in this service)
|
||||
fake: Faker instance for generating test data, creates new instance if not provided
|
||||
|
||||
Returns:
|
||||
App: Created test app instance with all required fields populated
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
app = App()
|
||||
app.id = fake.uuid4()
|
||||
app.tenant_id = fake.uuid4()
|
||||
app.name = fake.company()
|
||||
app.description = fake.text()
|
||||
app.mode = "workflow"
|
||||
app.icon_type = "emoji"
|
||||
app.icon = "🤖"
|
||||
app.icon_background = "#FFEAD5"
|
||||
app.enable_site = True
|
||||
app.enable_api = True
|
||||
app.created_by = fake.uuid4()
|
||||
app.updated_by = app.created_by
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
return app
|
||||
|
||||
def _create_test_workflow(self, db_session_with_containers, app, fake=None):
|
||||
"""
|
||||
Helper method to create a test workflow associated with an app.
|
||||
|
||||
This method creates a Workflow instance using the proper factory method
|
||||
to ensure all required fields are set correctly. The workflow is configured
|
||||
as a draft version with basic graph structure for testing workflow variables.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app: The app to associate the workflow with
|
||||
fake: Faker instance for generating test data, creates new instance if not provided
|
||||
|
||||
Returns:
|
||||
Workflow: Created test workflow instance with proper configuration
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
workflow = Workflow.new(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features="{}",
|
||||
created_by=app.created_by,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
return workflow
|
||||
|
||||
def _create_test_variable(
|
||||
self, db_session_with_containers, app_id, node_id, name, value, variable_type="conversation", fake=None
|
||||
):
|
||||
"""
|
||||
Helper method to create a test workflow draft variable with proper configuration.
|
||||
|
||||
This method creates different types of variables (conversation, system, node) using
|
||||
the appropriate factory methods to ensure proper initialization. Each variable type
|
||||
has specific requirements and this method handles the creation logic for all types.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app_id: ID of the app to associate the variable with
|
||||
node_id: ID of the node (or special constants like CONVERSATION_VARIABLE_NODE_ID)
|
||||
name: Name of the variable for identification
|
||||
value: StringSegment value for the variable content
|
||||
variable_type: Type of variable ("conversation", "system", "node") determining creation method
|
||||
fake: Faker instance for generating test data, creates new instance if not provided
|
||||
|
||||
Returns:
|
||||
WorkflowDraftVariable: Created test variable instance with proper type configuration
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
if variable_type == "conversation":
|
||||
# Create conversation variable using the appropriate factory method
|
||||
variable = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=app_id,
|
||||
name=name,
|
||||
value=value,
|
||||
description=fake.text(max_nb_chars=20),
|
||||
)
|
||||
elif variable_type == "system":
|
||||
# Create system variable with editable flag and execution context
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=app_id,
|
||||
name=name,
|
||||
value=value,
|
||||
node_execution_id=fake.uuid4(),
|
||||
editable=True,
|
||||
)
|
||||
else: # node variable
|
||||
# Create node variable with visibility and editability settings
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app_id,
|
||||
node_id=node_id,
|
||||
name=name,
|
||||
value=value,
|
||||
node_execution_id=fake.uuid4(),
|
||||
visible=True,
|
||||
editable=True,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(variable)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting a single variable by ID successfully.
|
||||
|
||||
This test verifies that the service can retrieve a specific variable
|
||||
by its ID and that the returned variable contains the correct data.
|
||||
It ensures the basic CRUD read operation works correctly for workflow draft variables.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
test_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_variable = service.get_variable(variable.id)
|
||||
assert retrieved_variable is not None
|
||||
assert retrieved_variable.id == variable.id
|
||||
assert retrieved_variable.name == "test_var"
|
||||
assert retrieved_variable.app_id == app.id
|
||||
assert retrieved_variable.get_value().value == test_value.value
|
||||
|
||||
def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting a variable that doesn't exist.
|
||||
|
||||
This test verifies that the service returns None when trying to
|
||||
retrieve a variable with a non-existent ID. This ensures proper
|
||||
handling of missing data scenarios.
|
||||
"""
|
||||
fake = Faker()
|
||||
non_existent_id = fake.uuid4()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_variable = service.get_variable(non_existent_id)
|
||||
assert retrieved_variable is None
|
||||
|
||||
def test_get_draft_variables_by_selectors_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting variables by selectors successfully.
|
||||
|
||||
This test verifies that the service can retrieve multiple variables
|
||||
using selector pairs (node_id, variable_name) and returns the correct
|
||||
variables for each selector. This is useful for bulk variable retrieval
|
||||
operations in workflow execution contexts.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
var1_value = StringSegment(value=fake.word())
|
||||
var2_value = StringSegment(value=fake.word())
|
||||
var3_value = StringSegment(value=fake.word())
|
||||
var1 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var1", var1_value, fake=fake
|
||||
)
|
||||
var2 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var2", var2_value, fake=fake
|
||||
)
|
||||
var3 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, "test_node_1", "var3", var3_value, "node", fake=fake
|
||||
)
|
||||
selectors = [
|
||||
[CONVERSATION_VARIABLE_NODE_ID, "var1"],
|
||||
[CONVERSATION_VARIABLE_NODE_ID, "var2"],
|
||||
["test_node_1", "var3"],
|
||||
]
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors)
|
||||
assert len(retrieved_variables) == 3
|
||||
var_names = [var.name for var in retrieved_variables]
|
||||
assert "var1" in var_names
|
||||
assert "var2" in var_names
|
||||
assert "var3" in var_names
|
||||
for var in retrieved_variables:
|
||||
if var.name == "var1":
|
||||
assert var.get_value().value == var1_value.value
|
||||
elif var.name == "var2":
|
||||
assert var.get_value().value == var2_value.value
|
||||
elif var.name == "var3":
|
||||
assert var.get_value().value == var3_value.value
|
||||
|
||||
def test_list_variables_without_values_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test listing variables without values successfully with pagination.
|
||||
|
||||
This test verifies that the service can list variables with pagination
|
||||
and that the returned variables don't include their values (for performance).
|
||||
This is important for scenarios where only variable metadata is needed
|
||||
without loading the actual content.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
for i in range(5):
|
||||
test_value = StringSegment(value=fake.numerify("value##"))
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_variables_without_values(app.id, page=1, limit=3)
|
||||
assert result.total == 5
|
||||
assert len(result.variables) == 3
|
||||
assert result.variables[0].created_at >= result.variables[1].created_at
|
||||
assert result.variables[1].created_at >= result.variables[2].created_at
|
||||
for var in result.variables:
|
||||
assert var.name is not None
|
||||
assert var.app_id == app.id
|
||||
|
||||
def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test listing variables for a specific node successfully.
|
||||
|
||||
This test verifies that the service can filter and return only
|
||||
variables associated with a specific node ID. This is crucial for
|
||||
workflow execution where variables need to be scoped to specific nodes.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
node_id = fake.word()
|
||||
var1_value = StringSegment(value=fake.word())
|
||||
var2_value = StringSegment(value=fake.word())
|
||||
var3_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(db_session_with_containers, app.id, node_id, "var1", var1_value, "node", fake=fake)
|
||||
self._create_test_variable(db_session_with_containers, app.id, node_id, "var2", var3_value, "node", fake=fake)
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, "other_node", "var3", var2_value, "node", fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_node_variables(app.id, node_id)
|
||||
assert len(result.variables) == 2
|
||||
for var in result.variables:
|
||||
assert var.node_id == node_id
|
||||
assert var.app_id == app.id
|
||||
var_names = [var.name for var in result.variables]
|
||||
assert "var1" in var_names
|
||||
assert "var2" in var_names
|
||||
assert "var3" not in var_names
|
||||
|
||||
def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test listing conversation variables successfully.
|
||||
|
||||
This test verifies that the service can filter and return only
|
||||
conversation variables, excluding system and node variables.
|
||||
Conversation variables are user-facing variables that can be
|
||||
modified during conversation flows.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
conv_var1_value = StringSegment(value=fake.word())
|
||||
conv_var2_value = StringSegment(value=fake.word())
|
||||
conv_var1 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var1", conv_var1_value, fake=fake
|
||||
)
|
||||
conv_var2 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var2", conv_var2_value, fake=fake
|
||||
)
|
||||
sys_var_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var", sys_var_value, "system", fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_conversation_variables(app.id)
|
||||
assert len(result.variables) == 2
|
||||
for var in result.variables:
|
||||
assert var.node_id == CONVERSATION_VARIABLE_NODE_ID
|
||||
assert var.app_id == app.id
|
||||
assert var.get_variable_type() == DraftVariableType.CONVERSATION
|
||||
var_names = [var.name for var in result.variables]
|
||||
assert "conv_var1" in var_names
|
||||
assert "conv_var2" in var_names
|
||||
assert "sys_var" not in var_names
|
||||
|
||||
def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test updating a variable's name and value successfully.
|
||||
|
||||
This test verifies that the service can update both the name and value
|
||||
of an editable variable and that the changes are persisted correctly.
|
||||
It also checks that the last_edited_at timestamp is updated appropriately.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
original_value = StringSegment(value=fake.word())
|
||||
new_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
"original_name",
|
||||
original_value,
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
updated_variable = service.update_variable(variable, name="new_name", value=new_value)
|
||||
assert updated_variable.name == "new_name"
|
||||
assert updated_variable.get_value().value == new_value.value
|
||||
assert updated_variable.last_edited_at is not None
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(variable)
|
||||
assert variable.name == "new_name"
|
||||
assert variable.get_value().value == new_value.value
|
||||
assert variable.last_edited_at is not None
|
||||
|
||||
def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test that updating a non-editable variable raises an exception.
|
||||
|
||||
This test verifies that the service properly prevents updates to
|
||||
variables that are not marked as editable. This is important for
|
||||
maintaining data integrity and preventing unauthorized modifications
|
||||
to system-controlled variables.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
original_value = StringSegment(value=fake.word())
|
||||
new_value = StringSegment(value=fake.word())
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=app.id,
|
||||
name=fake.word(), # This is typically not editable
|
||||
value=original_value,
|
||||
node_execution_id=fake.uuid4(),
|
||||
editable=False, # Set as non-editable
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(variable)
|
||||
db.session.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
with pytest.raises(UpdateNotSupportedError) as exc_info:
|
||||
service.update_variable(variable, name="new_name", value=new_value)
|
||||
assert "variable not support updating" in str(exc_info.value)
|
||||
assert variable.id in str(exc_info.value)
|
||||
|
||||
def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test resetting conversation variable successfully.
|
||||
|
||||
This test verifies that the service can reset a conversation variable
|
||||
to its default value and clear the last_edited_at timestamp.
|
||||
This functionality is useful for reverting user modifications
|
||||
back to the original workflow configuration.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake)
|
||||
from core.variables.variables import StringVariable
|
||||
|
||||
conv_var = StringVariable(
|
||||
id=fake.uuid4(),
|
||||
name="test_conv_var",
|
||||
value="default_value",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"],
|
||||
)
|
||||
workflow.conversation_variables = [conv_var]
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
modified_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
"test_conv_var",
|
||||
modified_value,
|
||||
fake=fake,
|
||||
)
|
||||
variable.last_edited_at = fake.date_time()
|
||||
db.session.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
reset_variable = service.reset_variable(workflow, variable)
|
||||
assert reset_variable is not None
|
||||
assert reset_variable.get_value().value == "default_value"
|
||||
assert reset_variable.last_edited_at is None
|
||||
db.session.refresh(variable)
|
||||
assert variable.get_value().value == "default_value"
|
||||
assert variable.last_edited_at is None
|
||||
|
||||
def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test deleting a single variable successfully.
|
||||
|
||||
This test verifies that the service can delete a specific variable
|
||||
and that it's properly removed from the database. It ensures that
|
||||
the deletion operation is atomic and complete.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
test_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_variable(variable)
|
||||
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None
|
||||
|
||||
def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test deleting all variables for a workflow successfully.
|
||||
|
||||
This test verifies that the service can delete all variables
|
||||
associated with a specific app/workflow. This is useful for
|
||||
cleanup operations when workflows are deleted or reset.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
for i in range(3):
|
||||
test_value = StringSegment(value=fake.numerify("value##"))
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake
|
||||
)
|
||||
other_app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
other_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, other_app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), other_value, fake=fake
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
|
||||
other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
|
||||
assert len(app_variables) == 3
|
||||
assert len(other_app_variables) == 1
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_workflow_variables(app.id)
|
||||
app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
|
||||
other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
|
||||
assert len(app_variables_after) == 0
|
||||
assert len(other_app_variables_after) == 1
|
||||
|
||||
def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test deleting all variables for a specific node successfully.
|
||||
|
||||
This test verifies that the service can delete all variables
|
||||
associated with a specific node while preserving variables
|
||||
for other nodes and conversation variables. This is important
|
||||
for node-specific cleanup operations in workflow management.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
node_id = fake.word()
|
||||
for i in range(2):
|
||||
test_value = StringSegment(value=fake.numerify("node_value##"))
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, node_id, fake.word(), test_value, "node", fake=fake
|
||||
)
|
||||
other_node_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, "other_node", fake.word(), other_node_value, "node", fake=fake
|
||||
)
|
||||
conv_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), conv_value, fake=fake
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
|
||||
other_node_variables = (
|
||||
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
|
||||
)
|
||||
conv_variables = (
|
||||
db.session.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
.all()
|
||||
)
|
||||
assert len(target_node_variables) == 2
|
||||
assert len(other_node_variables) == 1
|
||||
assert len(conv_variables) == 1
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_node_variables(app.id, node_id)
|
||||
target_node_variables_after = (
|
||||
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
|
||||
)
|
||||
other_node_variables_after = (
|
||||
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
|
||||
)
|
||||
conv_variables_after = (
|
||||
db.session.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
.all()
|
||||
)
|
||||
assert len(target_node_variables_after) == 0
|
||||
assert len(other_node_variables_after) == 1
|
||||
assert len(conv_variables_after) == 1
|
||||
|
||||
def test_prefill_conversation_variable_default_values_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test prefill conversation variable default values successfully.
|
||||
|
||||
This test verifies that the service can automatically create
|
||||
conversation variables with default values based on the workflow
|
||||
configuration when none exist. This is important for initializing
|
||||
workflow variables with proper defaults from the workflow definition.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake)
|
||||
from core.variables.variables import StringVariable
|
||||
|
||||
conv_var1 = StringVariable(
|
||||
id=fake.uuid4(),
|
||||
name="conv_var1",
|
||||
value="default_value1",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var1"],
|
||||
)
|
||||
conv_var2 = StringVariable(
|
||||
id=fake.uuid4(),
|
||||
name="conv_var2",
|
||||
value="default_value2",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"],
|
||||
)
|
||||
workflow.conversation_variables = [conv_var1, conv_var2]
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.prefill_conversation_variable_default_values(workflow)
|
||||
draft_variables = (
|
||||
db.session.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
.all()
|
||||
)
|
||||
assert len(draft_variables) == 2
|
||||
var_names = [var.name for var in draft_variables]
|
||||
assert "conv_var1" in var_names
|
||||
assert "conv_var2" in var_names
|
||||
for var in draft_variables:
|
||||
assert var.app_id == app.id
|
||||
assert var.node_id == CONVERSATION_VARIABLE_NODE_ID
|
||||
assert var.editable is True
|
||||
assert var.get_variable_type() == DraftVariableType.CONVERSATION
|
||||
|
||||
def test_get_conversation_id_from_draft_variable_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting conversation ID from draft variable successfully.
|
||||
|
||||
This test verifies that the service can extract the conversation ID
|
||||
from a system variable named "conversation_id". This is important
|
||||
for maintaining conversation context across workflow executions.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
conversation_id = fake.uuid4()
|
||||
conv_id_value = StringSegment(value=conversation_id)
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
"conversation_id",
|
||||
conv_id_value,
|
||||
"system",
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
|
||||
assert retrieved_conv_id == conversation_id
|
||||
|
||||
def test_get_conversation_id_from_draft_variable_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting conversation ID when it doesn't exist.
|
||||
|
||||
This test verifies that the service returns None when no
|
||||
conversation_id variable exists for the app. This ensures
|
||||
proper handling of missing conversation context scenarios.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
|
||||
assert retrieved_conv_id is None
|
||||
|
||||
def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test listing system variables successfully.
|
||||
|
||||
This test verifies that the service can filter and return only
|
||||
system variables, excluding conversation and node variables.
|
||||
System variables are internal variables used by the workflow
|
||||
engine for maintaining state and context.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
sys_var1_value = StringSegment(value=fake.word())
|
||||
sys_var2_value = StringSegment(value=fake.word())
|
||||
sys_var1 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var1", sys_var1_value, "system", fake=fake
|
||||
)
|
||||
sys_var2 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var2", sys_var2_value, "system", fake=fake
|
||||
)
|
||||
conv_var_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var", conv_var_value, fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_system_variables(app.id)
|
||||
assert len(result.variables) == 2
|
||||
for var in result.variables:
|
||||
assert var.node_id == SYSTEM_VARIABLE_NODE_ID
|
||||
assert var.app_id == app.id
|
||||
assert var.get_variable_type() == DraftVariableType.SYS
|
||||
var_names = [var.name for var in result.variables]
|
||||
assert "sys_var1" in var_names
|
||||
assert "sys_var2" in var_names
|
||||
assert "conv_var" not in var_names
|
||||
|
||||
def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting variables by name successfully for different types.
|
||||
|
||||
This test verifies that the service can retrieve variables by name
|
||||
for different variable types (conversation, system, node). This
|
||||
functionality is important for variable lookup operations during
|
||||
workflow execution and user interactions.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
test_value = StringSegment(value=fake.word())
|
||||
conv_var = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_conv_var", test_value, fake=fake
|
||||
)
|
||||
sys_var = self._create_test_variable(
|
||||
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "test_sys_var", test_value, "system", fake=fake
|
||||
)
|
||||
node_var = self._create_test_variable(
|
||||
db_session_with_containers, app.id, "test_node", "test_node_var", test_value, "node", fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var")
|
||||
assert retrieved_conv_var is not None
|
||||
assert retrieved_conv_var.name == "test_conv_var"
|
||||
assert retrieved_conv_var.node_id == CONVERSATION_VARIABLE_NODE_ID
|
||||
retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var")
|
||||
assert retrieved_sys_var is not None
|
||||
assert retrieved_sys_var.name == "test_sys_var"
|
||||
assert retrieved_sys_var.node_id == SYSTEM_VARIABLE_NODE_ID
|
||||
retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var")
|
||||
assert retrieved_node_var is not None
|
||||
assert retrieved_node_var.name == "test_node_var"
|
||||
assert retrieved_node_var.node_id == "test_node"
|
||||
|
||||
def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting variables by name when they don't exist.
|
||||
|
||||
This test verifies that the service returns None when trying to
|
||||
retrieve variables by name that don't exist. This ensures proper
|
||||
handling of missing variable scenarios for all variable types.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var")
|
||||
assert retrieved_conv_var is None
|
||||
retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var")
|
||||
assert retrieved_sys_var is None
|
||||
retrieved_node_var = service.get_node_variable(app.id, "test_node", "non_existent_node_var")
|
||||
assert retrieved_node_var is None
|
||||
@ -0,0 +1,168 @@
|
||||
import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
||||
|
||||
|
||||
class TestClearFreePlanTenantExpiredLogs:
|
||||
"""Unit tests for ClearFreePlanTenantExpiredLogs._clear_message_related_tables method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create a mock database session."""
|
||||
session = Mock(spec=Session)
|
||||
session.query.return_value.filter.return_value.all.return_value = []
|
||||
session.query.return_value.filter.return_value.delete.return_value = 0
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage(self):
|
||||
"""Create a mock storage object."""
|
||||
storage = Mock()
|
||||
storage.save.return_value = None
|
||||
return storage
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message_ids(self):
|
||||
"""Sample message IDs for testing."""
|
||||
return ["msg-1", "msg-2", "msg-3"]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_records(self):
|
||||
"""Sample records for testing."""
|
||||
records = []
|
||||
for i in range(3):
|
||||
record = Mock()
|
||||
record.id = f"record-{i}"
|
||||
record.to_dict.return_value = {
|
||||
"id": f"record-{i}",
|
||||
"message_id": f"msg-{i}",
|
||||
"created_at": datetime.datetime.now().isoformat(),
|
||||
}
|
||||
records.append(record)
|
||||
return records
|
||||
|
||||
def test_clear_message_related_tables_empty_message_ids(self, mock_session):
|
||||
"""Test that method returns early when message_ids is empty."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", [])
|
||||
|
||||
# Should not call any database operations
|
||||
mock_session.query.assert_not_called()
|
||||
mock_storage.save.assert_not_called()
|
||||
|
||||
def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids):
|
||||
"""Test when no related records are found."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = []
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should call query for each related table but find no records
|
||||
assert mock_session.query.call_count > 0
|
||||
mock_storage.save.assert_not_called()
|
||||
|
||||
def test_clear_message_related_tables_with_records_and_to_dict(
|
||||
self, mock_session, sample_message_ids, sample_records
|
||||
):
|
||||
"""Test when records are found and have to_dict method."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = sample_records
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should call to_dict on each record (called once per table, so 7 times total)
|
||||
for record in sample_records:
|
||||
assert record.to_dict.call_count == 7
|
||||
|
||||
# Should save backup data
|
||||
assert mock_storage.save.call_count > 0
|
||||
|
||||
def test_clear_message_related_tables_with_records_no_to_dict(self, mock_session, sample_message_ids):
|
||||
"""Test when records are found but don't have to_dict method."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
# Create records without to_dict method
|
||||
records = []
|
||||
for i in range(2):
|
||||
record = Mock()
|
||||
mock_table = Mock()
|
||||
mock_id_column = Mock()
|
||||
mock_id_column.name = "id"
|
||||
mock_message_id_column = Mock()
|
||||
mock_message_id_column.name = "message_id"
|
||||
mock_table.columns = [mock_id_column, mock_message_id_column]
|
||||
record.__table__ = mock_table
|
||||
record.id = f"record-{i}"
|
||||
record.message_id = f"msg-{i}"
|
||||
del record.to_dict
|
||||
records.append(record)
|
||||
|
||||
# Mock records for first table only, empty for others
|
||||
mock_session.query.return_value.filter.return_value.all.side_effect = [
|
||||
records,
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
]
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should save backup data even without to_dict
|
||||
assert mock_storage.save.call_count > 0
|
||||
|
||||
def test_clear_message_related_tables_storage_error_continues(
|
||||
self, mock_session, sample_message_ids, sample_records
|
||||
):
|
||||
"""Test that method continues even when storage.save fails."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_storage.save.side_effect = Exception("Storage error")
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = sample_records
|
||||
|
||||
# Should not raise exception
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should still delete records even if backup fails
|
||||
assert mock_session.query.return_value.filter.return_value.delete.called
|
||||
|
||||
def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids):
|
||||
"""Test that method continues even when record serialization fails."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
record = Mock()
|
||||
record.id = "record-1"
|
||||
record.to_dict.side_effect = Exception("Serialization error")
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [record]
|
||||
|
||||
# Should not raise exception
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should still delete records even if serialization fails
|
||||
assert mock_session.query.return_value.filter.return_value.delete.called
|
||||
|
||||
def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records):
|
||||
"""Test that deletion is called for found records."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = sample_records
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should call delete for each table that has records
|
||||
assert mock_session.query.return_value.filter.return_value.delete.called
|
||||
|
||||
def test_clear_message_related_tables_logging_output(
|
||||
self, mock_session, sample_message_ids, sample_records, capsys
|
||||
):
|
||||
"""Test that logging output is generated."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = sample_records
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
pass
|
||||
36
api/uv.lock
generated
36
api/uv.lock
generated
@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.11, <3.13"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'",
|
||||
@ -1265,6 +1265,8 @@ dependencies = [
|
||||
{ name = "opentelemetry-instrumentation" },
|
||||
{ name = "opentelemetry-instrumentation-celery" },
|
||||
{ name = "opentelemetry-instrumentation-flask" },
|
||||
{ name = "opentelemetry-instrumentation-redis" },
|
||||
{ name = "opentelemetry-instrumentation-requests" },
|
||||
{ name = "opentelemetry-instrumentation-sqlalchemy" },
|
||||
{ name = "opentelemetry-propagator-b3" },
|
||||
{ name = "opentelemetry-proto" },
|
||||
@ -1448,6 +1450,8 @@ requires-dist = [
|
||||
{ name = "opentelemetry-instrumentation", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-instrumentation-celery", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-instrumentation-flask", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-instrumentation-redis", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-instrumentation-requests", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-propagator-b3", specifier = "==1.27.0" },
|
||||
{ name = "opentelemetry-proto", specifier = "==1.27.0" },
|
||||
@ -3670,6 +3674,36 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/78/3d/fcde4f8f0bf9fa1ee73a12304fa538076fb83fe0a2ae966ab0f0b7da5109/opentelemetry_instrumentation_flask-0.48b0-py3-none-any.whl", hash = "sha256:26b045420b9d76e85493b1c23fcf27517972423480dc6cf78fd6924248ba5808", size = 14588, upload-time = "2024-08-28T21:26:58.504Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-redis"
|
||||
version = "0.48b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-instrumentation" },
|
||||
{ name = "opentelemetry-semantic-conventions" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/70/be/92e98e4c7f275be3d373899a41b0a7d4df64266657d985dbbdb9a54de0d5/opentelemetry_instrumentation_redis-0.48b0.tar.gz", hash = "sha256:61e33e984b4120e1b980d9fba6e9f7ca0c8d972f9970654d8f6e9f27fa115a8c", size = 10511, upload-time = "2024-08-28T21:28:15.061Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/94/40/892f30d400091106309cc047fd3f6d76a828fedd984a953fd5386b78a2fb/opentelemetry_instrumentation_redis-0.48b0-py3-none-any.whl", hash = "sha256:48c7f2e25cbb30bde749dc0d8b9c74c404c851f554af832956b9630b27f5bcb7", size = 11610, upload-time = "2024-08-28T21:27:18.759Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-requests"
|
||||
version = "0.48b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-instrumentation" },
|
||||
{ name = "opentelemetry-semantic-conventions" },
|
||||
{ name = "opentelemetry-util-http" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/52/ac/5eb78efde21ff21d0ad5dc8c6cc6a0f8ae482ce8a46293c2f45a628b6166/opentelemetry_instrumentation_requests-0.48b0.tar.gz", hash = "sha256:67ab9bd877a0352ee0db4616c8b4ae59736ddd700c598ed907482d44f4c9a2b3", size = 14120, upload-time = "2024-08-28T21:28:16.933Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/43/df/0df9226d1b14f29d23c07e6194b9fd5ad50e7d987b7fd13df7dcf718aeb1/opentelemetry_instrumentation_requests-0.48b0-py3-none-any.whl", hash = "sha256:d4f01852121d0bd4c22f14f429654a735611d4f7bf3cf93f244bdf1489b2233d", size = 12366, upload-time = "2024-08-28T21:27:20.771Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-sqlalchemy"
|
||||
version = "0.48b0"
|
||||
|
||||
Reference in New Issue
Block a user