diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index ef2f86d4be..56816dd462 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -348,10 +348,13 @@ class CompletionConversationApi(Resource): ) if args.keyword: + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(args.keyword) query = query.join(Message, Message.conversation_id == Conversation.id).where( or_( - Message.query.ilike(f"%{args.keyword}%"), - Message.answer.ilike(f"%{args.keyword}%"), + Message.query.ilike(f"%{escaped_keyword}%", escape="\\"), + Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"), ) ) @@ -460,7 +463,10 @@ class ChatConversationApi(Resource): query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) if args.keyword: - keyword_filter = f"%{args.keyword}%" + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(args.keyword) + keyword_filter = f"%{escaped_keyword}%" query = ( query.join( Message, @@ -469,11 +475,11 @@ class ChatConversationApi(Resource): .join(subquery, subquery.c.conversation_id == Conversation.id) .where( or_( - Message.query.ilike(keyword_filter), - Message.answer.ilike(keyword_filter), - Conversation.name.ilike(keyword_filter), - Conversation.introduction.ilike(keyword_filter), - subquery.c.from_end_user_session_id.ilike(keyword_filter), + Message.query.ilike(keyword_filter, escape="\\"), + Message.answer.ilike(keyword_filter, escape="\\"), + Conversation.name.ilike(keyword_filter, escape="\\"), + Conversation.introduction.ilike(keyword_filter, escape="\\"), + subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"), ), ) .group_by(Conversation.id) diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 5a536af6d2..16fecb41c6 100644 --- 
a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -30,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields +from libs.helper import escape_like_pattern from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile @@ -145,6 +146,8 @@ class DatasetDocumentSegmentListApi(Resource): query = query.where(DocumentSegment.hit_count >= hit_count_gte) if keyword: + # Escape special characters in keyword to prevent SQL injection via LIKE wildcards + escaped_keyword = escape_like_pattern(keyword) # Search in both content and keywords fields # Use database-specific methods for JSON array search if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": @@ -156,15 +159,15 @@ class DatasetDocumentSegmentListApi(Resource): .scalar_subquery() ), ",", - ).ilike(f"%{keyword}%") + ).ilike(f"%{escaped_keyword}%", escape="\\") else: # MySQL: Cast JSON to string for pattern matching # MySQL stores Chinese text directly in JSON without Unicode escaping - keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%") + keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\") query = query.where( or_( - DocumentSegment.content.ilike(f"%{keyword}%"), + DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"), keywords_condition, ) ) diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index a306f9ba0c..e05b70ba22 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -984,9 +984,11 @@ class 
ClickzettaVector(BaseVector): # No need for dataset_id filter since each dataset has its own table - # Use simple quote escaping for LIKE clause - escaped_query = query.replace("'", "''") - filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'") + # Escape special characters for LIKE clause to prevent SQL injection + from libs.helper import escape_like_pattern + + escaped_query = escape_like_pattern(query).replace("'", "''") + filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'") where_clause = " AND ".join(filter_clauses) search_sql = f""" diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py index b1bfabb76e..5bdb0af0b3 100644 --- a/api/core/rag/datasource/vdb/iris/iris_vector.py +++ b/api/core/rag/datasource/vdb/iris/iris_vector.py @@ -287,11 +287,15 @@ class IrisVector(BaseVector): cursor.execute(sql, (query,)) else: # Fallback to LIKE search (inefficient for large datasets) - query_pattern = f"%{query}%" + # Escape special characters for LIKE clause to prevent SQL injection + from libs.helper import escape_like_pattern + + escaped_query = escape_like_pattern(query) + query_pattern = f"%{escaped_query}%" sql = f""" SELECT TOP {top_k} id, text, meta FROM {self.schema}.{self.table_name} - WHERE text LIKE ? + WHERE text LIKE ? 
ESCAPE '\\' """ cursor.execute(sql, (query_pattern,)) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index c6339aa3ba..f8f85d141a 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1198,18 +1198,24 @@ class DatasetRetrieval: json_field = DatasetDocument.doc_metadata[metadata_name].as_string() + from libs.helper import escape_like_pattern + match condition: case "contains": - filters.append(json_field.like(f"%{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"%{escaped_value}%", escape="\\")) case "not contains": - filters.append(json_field.notlike(f"%{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\")) case "start with": - filters.append(json_field.like(f"{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"{escaped_value}%", escape="\\")) case "end with": - filters.append(json_field.like(f"%{value}")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"%{escaped_value}", escape="\\")) case "is" | "=": if isinstance(value, str): diff --git a/api/libs/helper.py b/api/libs/helper.py index 74e1808e49..07c4823727 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -32,6 +32,38 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def escape_like_pattern(pattern: str) -> str: + """ + Escape special characters in a string for safe use in SQL LIKE patterns. + + This function escapes the special characters used in SQL LIKE patterns: + - Backslash (\\) -> \\ + - Percent (%) -> \\% + - Underscore (_) -> \\_ + + The escaped pattern can then be safely used in SQL LIKE queries with the + ESCAPE '\\' clause to prevent SQL injection via LIKE wildcards. 
+ + Args: + pattern: The string pattern to escape + + Returns: + Escaped string safe for use in SQL LIKE queries + + Examples: + >>> escape_like_pattern("50% discount") + '50\\% discount' + >>> escape_like_pattern("test_data") + 'test\\_data' + >>> escape_like_pattern("path\\to\\file") + 'path\\\\to\\\\file' + """ + if not pattern: + return pattern + # Escape backslash first, then percent and underscore + return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: """ Extract tenant_id from Account or EndUser object. diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 7f44fe05a6..b73302508a 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -137,13 +137,16 @@ class AppAnnotationService: if not app: raise NotFound("App not found") if keyword: + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(keyword) stmt = ( select(MessageAnnotation) .where(MessageAnnotation.app_id == app_id) .where( or_( - MessageAnnotation.question.ilike(f"%{keyword}%"), - MessageAnnotation.content.ilike(f"%{keyword}%"), + MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"), + MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"), ) ) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) diff --git a/api/services/app_service.py b/api/services/app_service.py index ef89a4fd10..02ebfbace0 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -55,8 +55,11 @@ class AppService: if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) if args.get("name"): + from libs.helper import escape_like_pattern + name = args["name"][:30] - filters.append(App.name.ilike(f"%{name}%")) + escaped_name = escape_like_pattern(name) + filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\")) # Check if 
tag_ids is not empty to avoid WHERE false condition if args.get("tag_ids") and len(args["tag_ids"]) > 0: target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 659e7406fb..f56de36ef7 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -218,7 +218,9 @@ class ConversationService: # Apply variable_name filter if provided if variable_name: # Filter using JSON extraction to match variable names case-insensitively - escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + from libs.helper import escape_like_pattern + + escaped_variable_name = escape_like_pattern(variable_name) # Filter using JSON extraction to match variable names case-insensitively if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]: stmt = stmt.where( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ac4b25c5dc..18e5613438 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -144,7 +144,8 @@ class DatasetService: query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) if search: - query = query.where(Dataset.name.ilike(f"%{search}%")) + escaped_search = helper.escape_like_pattern(search) + query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\")) # Check if tag_ids is not empty to avoid WHERE false condition if tag_ids and len(tag_ids) > 0: @@ -3423,7 +3424,8 @@ class SegmentService: .order_by(ChildChunk.position.asc()) ) if keyword: - query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) + escaped_keyword = helper.escape_like_pattern(keyword) + query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\")) return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @classmethod @@ -3456,7 +3458,8 @@ class SegmentService: 
query = query.where(DocumentSegment.status.in_(status_list)) if keyword: - query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) + escaped_keyword = helper.escape_like_pattern(keyword) + query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\")) query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc()) paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 40faa85b9a..65dd41af43 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -35,7 +35,10 @@ class ExternalDatasetService: .order_by(ExternalKnowledgeApis.created_at.desc()) ) if search: - query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%")) + from libs.helper import escape_like_pattern + + escaped_search = escape_like_pattern(search) + query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\")) external_knowledge_apis = db.paginate( select=query, page=page, per_page=per_page, max_per_page=100, error_out=False diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 937e6593fe..bd3585acf4 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -19,7 +19,10 @@ class TagService: .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%"))) + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(keyword) + query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\"))) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) results: list = query.order_by(Tag.created_at.desc()).all() return results diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 01f0c7a55a..8574d30255 
100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -86,12 +86,19 @@ class WorkflowAppService: # Join to workflow run for filtering when needed. if keyword: - keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u") + from libs.helper import escape_like_pattern + + # Escape special characters in keyword to prevent SQL injection via LIKE wildcards + escaped_keyword = escape_like_pattern(keyword[:30]) + keyword_like_val = f"%{escaped_keyword}%" keyword_conditions = [ - WorkflowRun.inputs.ilike(keyword_like_val), - WorkflowRun.outputs.ilike(keyword_like_val), + WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"), + WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"), # filter keyword by end user session id if created by end user role - and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), + and_( + WorkflowRun.created_by_role == "end_user", + EndUser.session_id.ilike(keyword_like_val, escape="\\"), + ), ] # filter keyword by workflow run id diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index da73122cd7..5555400ca6 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -444,6 +444,78 @@ class TestAnnotationService: assert total == 1 assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content + def test_get_annotation_list_by_app_id_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention. 
+ + This test verifies: + - Special characters (%, _, \) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotations with special characters in content + annotation_with_percent = { + "question": "Question with 50% discount", + "answer": "Answer about 50% discount offer", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_percent, app.id) + + annotation_with_underscore = { + "question": "Question with test_data", + "answer": "Answer about test_data value", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_underscore, app.id) + + annotation_with_backslash = { + "question": "Question with path\\to\\file", + "answer": "Answer about path\\to\\file location", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_backslash, app.id) + + # Create annotation that should NOT match (contains % but as part of different text) + annotation_no_match = { + "question": "Question with 100% different", + "answer": "Answer about 100% different content", + } + AppAnnotationService.insert_app_annotation_directly(annotation_no_match, app.id) + + # Test 1: Search with % character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="50%" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "50%" in annotation_list[0].question or "50%" in annotation_list[0].content + + # Test 2: Search with _ character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="test_data" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "test_data" in 
annotation_list[0].question or "test_data" in annotation_list[0].content + + # Test 3: Search with \ character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="path\\to\\file" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "path\\to\\file" in annotation_list[0].question or "path\\to\\file" in annotation_list[0].content + + # Test 4: Search with % should NOT match 100% (verifies escaping works) + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="50%" + ) + # Should only find the 50% annotation, not the 100% one + assert total == 1 + assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list) + def test_get_annotation_list_by_app_id_app_not_found( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index e53392bcef..745d6c97b0 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -7,7 +7,9 @@ from constants.model_template import default_app_templates from models import Account from models.model import App, Site from services.account_service import AccountService, TenantService -from services.app_service import AppService + +# Delay import of AppService to avoid circular dependency +# from services.app_service import AppService class TestAppService: @@ -71,6 +73,9 @@ class TestAppService: } # Create app + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -109,6 +114,9 @@ class TestAppService: 
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Test different app modes @@ -159,6 +167,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() created_app = app_service.create_app(tenant.id, app_args, account) @@ -194,6 +205,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create multiple apps @@ -245,6 +259,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create apps with different modes @@ -315,6 +332,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create an app @@ -392,6 +412,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -458,6 +481,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -508,6 +534,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here 
to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -562,6 +591,9 @@ class TestAppService: "icon_background": "#74B9FF", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -617,6 +649,9 @@ class TestAppService: "icon_background": "#A29BFE", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -672,6 +707,9 @@ class TestAppService: "icon_background": "#FD79A8", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -720,6 +758,9 @@ class TestAppService: "icon_background": "#E17055", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -768,6 +809,9 @@ class TestAppService: "icon_background": "#00B894", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -826,6 +870,9 @@ class TestAppService: "icon_background": "#6C5CE7", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -862,6 +909,9 @@ class TestAppService: "icon_background": "#FDCB6E", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -899,6 +949,9 @@ class TestAppService: "icon_background": 
"#E84393", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -947,8 +1000,132 @@ class TestAppService: "icon_background": "#D63031", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Attempt to create app with invalid mode with pytest.raises(ValueError, match="invalid mode value"): app_service.create_app(tenant.id, app_args, account) + + def test_get_apps_with_special_characters_in_name( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test app retrieval with special characters in name search to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _, \) in name search are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Import here to avoid circular dependency + from services.app_service import AppService + + app_service = AppService() + + # Create apps with special characters in names + app_with_percent = app_service.create_app( + tenant.id, + { + "name": "App with 50% discount", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + app_with_underscore = app_service.create_app( + tenant.id, + { + "name": "test_data_app", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + 
"icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + app_with_backslash = app_service.create_app( + tenant.id, + { + "name": "path\\to\\app", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + # Create app that should NOT match + app_no_match = app_service.create_app( + tenant.id, + { + "name": "100% different", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + # Test 1: Search with % character + args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "App with 50% discount" + + # Test 2: Search with _ character + args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "test_data_app" + + # Test 3: Search with \ character + args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "path\\to\\app" + + # Test 4: Search with % should NOT match 100% (verifies escaping works) + args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is 
not None + assert paginated_apps.total == 1 + assert all("50%" in app.name for app in paginated_apps.items) diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 6732b8d558..e8c7f17e0b 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -1,3 +1,4 @@ +import uuid from unittest.mock import create_autospec, patch import pytest @@ -312,6 +313,85 @@ class TestTagService: result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent") assert len(result_no_match) == 0 + def test_get_tags_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test tag retrieval with special characters in keyword to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _, \) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + from extensions.ext_database import db + + # Create tags with special characters in names + tag_with_percent = Tag( + name="50% discount", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_percent.id = str(uuid.uuid4()) + db.session.add(tag_with_percent) + + tag_with_underscore = Tag( + name="test_data_tag", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_underscore.id = str(uuid.uuid4()) + db.session.add(tag_with_underscore) + + tag_with_backslash = Tag( + name="path\\to\\tag", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_backslash.id = str(uuid.uuid4()) + 
db.session.add(tag_with_backslash) + + # Create tag that should NOT match + tag_no_match = Tag( + name="100% different", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_no_match.id = str(uuid.uuid4()) + db.session.add(tag_no_match) + + db.session.commit() + + # Act & Assert: Test 1 - Search with % character + result = TagService.get_tags("app", tenant.id, keyword="50%") + assert len(result) == 1 + assert result[0].name == "50% discount" + + # Test 2 - Search with _ character + result = TagService.get_tags("app", tenant.id, keyword="test_data") + assert len(result) == 1 + assert result[0].name == "test_data_tag" + + # Test 3 - Search with \ character + result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag") + assert len(result) == 1 + assert result[0].name == "path\\to\\tag" + + # Test 4 - Search with % should NOT match 100% (verifies escaping works) + result = TagService.get_tags("app", tenant.id, keyword="50%") + assert len(result) == 1 + assert all("50%" in item.name for item in result) + def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies): """ Test tag retrieval when no tags exist. 
diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py
index 7b95944bbe..040fb826e1 100644
--- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py
@@ -10,7 +10,9 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
 from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
 from models.enums import CreatorUserRole
 from services.account_service import AccountService, TenantService
-from services.app_service import AppService
+
+# Delay import of AppService to avoid circular dependency
+# from services.app_service import AppService
 from services.workflow_app_service import WorkflowAppService
 
 
@@ -86,6 +88,9 @@ class TestWorkflowAppService:
             "api_rpm": 10,
         }
 
+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()
         app = app_service.create_app(tenant.id, app_args, account)
 
@@ -147,6 +152,9 @@ class TestWorkflowAppService:
             "api_rpm": 10,
         }
 
+        # Import here to avoid circular dependency
+        from services.app_service import AppService
+
         app_service = AppService()
         app = app_service.create_app(tenant.id, app_args, account)
 
@@ -308,6 +316,87 @@ class TestWorkflowAppService:
         assert result_no_match["total"] == 0
         assert len(result_no_match["data"]) == 0
 
+    def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test workflow app logs pagination when the keyword contains LIKE wildcard characters.
+
+        This test verifies:
+        - Special characters (%, _) in keyword are properly escaped
+        - Search treats special characters as literal characters, not wildcards
+        - Rows that only an unescaped wildcard pattern would match are not returned
+        """
+        # Arrange: Create test data
+        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
+        workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account)
+
+        from extensions.ext_database import db
+
+        service = WorkflowAppService()
+
+        def create_run_with_log(search_term: str) -> str:
+            """Persist a succeeded workflow run and its app log; return the run id."""
+            run_id = str(uuid.uuid4())
+            workflow_run = WorkflowRun(
+                id=run_id,
+                tenant_id=app.tenant_id,
+                app_id=app.id,
+                workflow_id=workflow.id,
+                type="workflow",
+                triggered_from="app-run",
+                version="1.0.0",
+                graph=json.dumps({"nodes": [], "edges": []}),
+                status="succeeded",
+                inputs=json.dumps({"search_term": search_term, "input2": "other_value"}),
+                outputs=json.dumps({"result": f"{search_term} found", "status": "success"}),
+                created_by_role=CreatorUserRole.ACCOUNT,
+                created_by=account.id,
+                created_at=datetime.now(UTC),
+            )
+            db.session.add(workflow_run)
+            db.session.flush()
+
+            workflow_app_log = WorkflowAppLog(
+                tenant_id=app.tenant_id,
+                app_id=app.id,
+                workflow_id=workflow.id,
+                workflow_run_id=run_id,
+                created_from="service-api",
+                created_by_role=CreatorUserRole.ACCOUNT,
+                created_by=account.id,
+            )
+            workflow_app_log.id = str(uuid.uuid4())
+            workflow_app_log.created_at = datetime.now(UTC)
+            db.session.add(workflow_app_log)
+            db.session.commit()
+            return run_id
+
+        def search_run_ids(keyword: str) -> list[str]:
+            """Return the workflow run ids of the app logs matching the keyword."""
+            result = service.get_paginate_workflow_app_logs(
+                session=db_session_with_containers, app_model=app, keyword=keyword, page=1, limit=20
+            )
+            return [log.workflow_run_id for log in result["data"]]
+
+        # Rows that contain the keyword characters literally
+        percent_run_id = create_run_with_log("50% discount")
+        underscore_run_id = create_run_with_log("test_data_value")
+        # Decoys: matched only if % or _ in the keyword acts as a wildcard
+        # (an unescaped "%50%%" matches anything containing "50"; "_" matches any single character)
+        percent_decoy_run_id = create_run_with_log("500 units")
+        underscore_decoy_run_id = create_run_with_log("test-data-value")
+
+        # Test 1: % in the keyword is matched literally
+        found_run_ids = search_run_ids("50%")
+        assert percent_run_id in found_run_ids
+        assert percent_decoy_run_id not in found_run_ids
+
+        # Test 2: _ in the keyword is matched literally
+        found_run_ids = search_run_ids("test_data")
+        assert underscore_run_id in found_run_ids
+        assert underscore_decoy_run_id not in found_run_ids
+
     def test_get_paginate_workflow_app_logs_with_status_filter(
         self, db_session_with_containers, mock_external_service_dependencies
     ):
diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py
index 85789bfa7e..de74eff82f 100644
--- a/api/tests/unit_tests/libs/test_helper.py
+++ b/api/tests/unit_tests/libs/test_helper.py
@@ -1,6 +1,6 @@
 import pytest
 
-from libs.helper import extract_tenant_id
+from libs.helper import escape_like_pattern, extract_tenant_id
 from models.account import Account
 from models.model import EndUser
 
@@ -63,3 +63,51 @@ class TestExtractTenantId:
 
         with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
             extract_tenant_id(dict_user)
+
+
+class TestEscapeLikePattern:
+    """Test cases for the escape_like_pattern utility function."""
+
+    def test_escape_percent_character(self):
+        """Test escaping percent character."""
+        result = escape_like_pattern("50% discount")
+        assert result == "50\\% discount"
+
+    def test_escape_underscore_character(self):
+        """Test escaping underscore character."""
+        result = escape_like_pattern("test_data")
+        assert result == "test\\_data"
+
+    def test_escape_backslash_character(self):
+        """Test escaping backslash character."""
+        result = escape_like_pattern("path\\to\\file")
+        assert result == "path\\\\to\\\\file"
+
+    def test_escape_combined_special_characters(self):
+        """Test escaping multiple special characters together."""
+        result = escape_like_pattern("file_50%\\path")
+        assert result == "file\\_50\\%\\\\path"
+
+    def test_escape_empty_string(self):
+        """Test escaping empty string returns empty string."""
+        result = escape_like_pattern("")
+        assert result == ""
+
+    def test_escape_none_handling(self):
+        """Test escaping None returns None (falsy check handles it)."""
+        # The function checks `if not pattern`, so None is falsy and returns as-is
+        result = escape_like_pattern(None)
+        assert result is None
+
+    def test_escape_normal_string_no_change(self):
+        """Test that normal strings without special characters are unchanged."""
+        result = escape_like_pattern("normal text")
+        assert result == "normal text"
+
+    def test_escape_order_matters(self):
+        """Test that backslash is escaped first to prevent double escaping."""
+        # If we escape % first, then escape \, we might get wrong results
+        # This test ensures the order is correct: \ first, then % and _
+        result = escape_like_pattern("test\\%_value")
+        # Should be: test\\\%\_value
+        assert result == "test\\\\\\%\\_value"