Merge branch 'main' into feat/rag-2

This commit is contained in:
twwu
2025-08-11 11:15:58 +08:00
214 changed files with 8987 additions and 838 deletions

View File

@ -118,26 +118,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
):
return
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id,
ConversationVariable.conversation_id == self.conversation.id,
)
with Session(db.engine) as session:
db_conversation_variables = session.scalars(stmt).all()
if not db_conversation_variables:
# Create conversation variables if they don't exist.
db_conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in self._workflow.conversation_variables
]
session.add_all(db_conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in db_conversation_variables]
session.commit()
# Initialize conversation variables
conversation_variables = self._initialize_conversation_variables()
# Create a variable pool.
system_inputs = SystemVariable(
@ -292,3 +274,100 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
)
def _initialize_conversation_variables(self) -> list[VariableUnion]:
    """
    Initialize conversation variables for the current conversation.

    This method:
    1. Loads existing variables from the database
    2. Creates new variables if none exist
    3. Syncs missing variables from the workflow definition

    :return: List of conversation variables ready for use
    """
    with Session(db.engine) as session:
        existing_variables = self._load_existing_conversation_variables(session)

        if not existing_variables:
            # First time initialization - create all variables defined by the workflow.
            existing_variables = self._create_all_conversation_variables(session)
        else:
            # Check and add any missing variables from the workflow
            # (covers workflows edited after the conversation was created).
            existing_variables = self._sync_missing_conversation_variables(session, existing_variables)

        # Convert to Variable objects for use in the workflow.
        # NOTE(review): conversion happens before commit; presumably to_variable()
        # reads only attributes already loaded on the ORM objects — confirm no
        # post-commit refresh is required.
        conversation_variables = [var.to_variable() for var in existing_variables]
        session.commit()
        return cast(list[VariableUnion], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
    """Fetch the conversation variables already persisted for this conversation.

    :param session: active database session
    :return: every ConversationVariable row matching this app/conversation pair
    """
    query = select(ConversationVariable).where(
        ConversationVariable.app_id == self.conversation.app_id,
        ConversationVariable.conversation_id == self.conversation.id,
    )
    rows = session.scalars(query).all()
    return list(rows)
def _create_all_conversation_variables(self, session: Session) -> list[ConversationVariable]:
    """Persist a fresh set of conversation variables for a brand-new conversation.

    One row is created per variable declared in the workflow definition.

    :param session: active database session
    :return: the newly created (pending) ConversationVariable rows
    """
    created: list[ConversationVariable] = []
    for workflow_variable in self._workflow.conversation_variables:
        row = ConversationVariable.from_variable(
            app_id=self.conversation.app_id,
            conversation_id=self.conversation.id,
            variable=workflow_variable,
        )
        created.append(row)

    # Nothing to stage when the workflow declares no variables.
    if created:
        session.add_all(created)
    return created
def _sync_missing_conversation_variables(
    self, session: Session, existing_variables: list[ConversationVariable]
) -> list[ConversationVariable]:
    """Backfill variables added to the workflow after this conversation started.

    Existing rows are left untouched; only variables present in the workflow
    definition but absent from the database are created (with their defaults).

    :param session: active database session
    :param existing_variables: rows already stored for this conversation
    :return: existing rows plus any newly created ones
    """
    known_ids = {row.id for row in existing_variables}
    defined_by_workflow = {v.id: v for v in self._workflow.conversation_variables}

    # IDs declared by the workflow that have no corresponding DB row yet.
    pending_ids = set(defined_by_workflow) - known_ids
    if not pending_ids:
        return existing_variables

    backfilled = []
    for variable_id in pending_ids:
        backfilled.append(
            ConversationVariable.from_variable(
                app_id=self.conversation.app_id,
                conversation_id=self.conversation.id,
                variable=defined_by_workflow[variable_id],
            )
        )
    session.add_all(backfilled)
    return existing_variables + backfilled

View File

@ -23,6 +23,7 @@ from core.app.entities.task_entities import (
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
StreamEvent,
WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
@ -180,11 +181,15 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first()
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
event=event_type,
)
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:

View File

@ -843,7 +843,7 @@ class ProviderConfiguration(BaseModel):
continue
status = ModelStatus.ACTIVE
if m.model in model_setting_map:
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
model_setting = model_setting_map[m.model_type][m.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED

View File

@ -185,6 +185,6 @@ Clickzetta supports advanced full-text search with multiple analyzers:
## References
- [Clickzetta Vector Search Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/vector-search.md)
- [Clickzetta Inverted Index Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/inverted-index.md)
- [Clickzetta SQL Functions](../../../../../../../yunqidoc/cn_markdown_20250526/sql_functions/)
- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)

View File

@ -1,9 +1,11 @@
import json
import logging
import queue
import re
import threading
import time
import uuid
from typing import Any, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional
import clickzetta # type: ignore
from pydantic import BaseModel, model_validator
@ -67,6 +69,243 @@ class ClickzettaConfig(BaseModel):
return values
class ClickzettaConnectionPool:
    """
    Global connection pool for ClickZetta connections.

    Manages connection reuse across ClickzettaVector instances. A single
    process-wide instance (see get_instance) keeps one pool per distinct
    connection configuration; idle connections expire after
    _connection_timeout seconds and are reaped by a daemon thread.
    """

    # Process-wide singleton, guarded by _lock (double-checked locking).
    _instance: Optional["ClickzettaConnectionPool"] = None
    _lock = threading.Lock()

    def __init__(self):
        # config_key -> [(connection, last_used_time)]
        self._pools: dict[str, list[tuple[Connection, float]]] = {}
        # One lock per config_key so different configurations don't contend.
        self._pool_locks: dict[str, threading.Lock] = {}
        self._max_pool_size = 5  # Maximum connections per configuration
        self._connection_timeout = 300  # 5 minutes timeout
        self._cleanup_thread: Optional[threading.Thread] = None
        self._shutdown = False
        self._start_cleanup_thread()

    @classmethod
    def get_instance(cls) -> "ClickzettaConnectionPool":
        """Get singleton instance of connection pool."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock so only one thread constructs the pool.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def _get_config_key(self, config: ClickzettaConfig) -> str:
        """Generate unique key for connection configuration.

        NOTE(review): the password is deliberately(?) excluded from the key,
        so a password change for the same username/instance may reuse
        connections built with old credentials — confirm this is intended.
        """
        return (
            f"{config.username}:{config.instance}:{config.service}:"
            f"{config.workspace}:{config.vcluster}:{config.schema_name}"
        )

    def _create_connection(self, config: ClickzettaConfig) -> "Connection":
        """Create a new ClickZetta connection.

        Retries up to 3 times with exponential backoff (1s, 2s) before
        propagating the last connection error.
        """
        max_retries = 3
        retry_delay = 1.0

        for attempt in range(max_retries):
            try:
                connection = clickzetta.connect(
                    username=config.username,
                    password=config.password,
                    instance=config.instance,
                    service=config.service,
                    workspace=config.workspace,
                    vcluster=config.vcluster,
                    schema=config.schema_name,
                )
                # Configure connection session settings
                self._configure_connection(connection)
                logger.debug("Created new ClickZetta connection (attempt %d/%d)", attempt + 1, max_retries)
                return connection
            except Exception:
                logger.exception("ClickZetta connection attempt %d/%d failed", attempt + 1, max_retries)
                if attempt < max_retries - 1:
                    time.sleep(retry_delay * (2**attempt))
                else:
                    # Last attempt: surface the original error to the caller.
                    raise

        # Unreachable in practice (the loop either returns or raises); kept
        # as a defensive guard.
        raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")

    def _configure_connection(self, connection: "Connection") -> None:
        """Configure connection session settings.

        Applies string-escaping mode and a batch of performance hints.
        Failures here are logged and swallowed so a connection with default
        settings can still be used.
        """
        try:
            with connection.cursor() as cursor:
                # Temporarily suppress ClickZetta client logging to reduce noise
                clickzetta_logger = logging.getLogger("clickzetta")
                original_level = clickzetta_logger.level
                clickzetta_logger.setLevel(logging.WARNING)

                try:
                    # Use quote mode for string literal escaping
                    cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")

                    # Apply performance optimization hints
                    performance_hints = [
                        # Vector index optimization
                        "SET cz.storage.parquet.vector.index.read.memory.cache = true",
                        "SET cz.storage.parquet.vector.index.read.local.cache = false",
                        # Query optimization
                        "SET cz.sql.table.scan.push.down.filter = true",
                        "SET cz.sql.table.scan.enable.ensure.filter = true",
                        "SET cz.storage.always.prefetch.internal = true",
                        "SET cz.optimizer.generate.columns.always.valid = true",
                        "SET cz.sql.index.prewhere.enabled = true",
                        # Storage optimization
                        "SET cz.storage.parquet.enable.io.prefetch = false",
                        "SET cz.optimizer.enable.mv.rewrite = false",
                        "SET cz.sql.dump.as.lz4 = true",
                        "SET cz.optimizer.limited.optimization.naive.query = true",
                        "SET cz.sql.table.scan.enable.push.down.log = false",
                        "SET cz.storage.use.file.format.local.stats = false",
                        "SET cz.storage.local.file.object.cache.level = all",
                        # Job execution optimization
                        "SET cz.sql.job.fast.mode = true",
                        "SET cz.storage.parquet.non.contiguous.read = true",
                        "SET cz.sql.compaction.after.commit = true",
                    ]

                    for hint in performance_hints:
                        cursor.execute(hint)
                finally:
                    # Restore original logging level
                    clickzetta_logger.setLevel(original_level)
        except Exception:
            logger.exception("Failed to configure connection, continuing with defaults")

    def _is_connection_valid(self, connection: "Connection") -> bool:
        """Check if connection is still valid (cheap SELECT 1 ping)."""
        try:
            with connection.cursor() as cursor:
                cursor.execute("SELECT 1")
                return True
        except Exception:
            return False

    def get_connection(self, config: ClickzettaConfig) -> "Connection":
        """Get a connection from the pool or create a new one.

        Expired or dead pooled connections encountered while searching are
        closed and discarded; a fresh connection is created when none remain.
        """
        config_key = self._get_config_key(config)

        # Ensure pool lock exists (double-checked under the global lock).
        if config_key not in self._pool_locks:
            with self._lock:
                if config_key not in self._pool_locks:
                    self._pool_locks[config_key] = threading.Lock()
                    self._pools[config_key] = []

        with self._pool_locks[config_key]:
            pool = self._pools[config_key]
            current_time = time.time()

            # Try to reuse existing connection (oldest first).
            while pool:
                connection, last_used = pool.pop(0)

                # Check if connection is not expired and still valid
                if current_time - last_used < self._connection_timeout and self._is_connection_valid(connection):
                    logger.debug("Reusing ClickZetta connection from pool")
                    return connection
                else:
                    # Connection expired or invalid, close it
                    try:
                        connection.close()
                    except Exception:
                        pass

        # No valid connection found, create new one
        return self._create_connection(config)

    def return_connection(self, config: ClickzettaConfig, connection: "Connection") -> None:
        """Return a connection to the pool.

        The connection is closed instead of pooled when the pool is at
        capacity, the connection fails the validity ping, or the pool for
        this configuration no longer exists.
        """
        config_key = self._get_config_key(config)

        if config_key not in self._pool_locks:
            # Pool was cleaned up, just close the connection
            try:
                connection.close()
            except Exception:
                pass
            return

        with self._pool_locks[config_key]:
            pool = self._pools[config_key]

            # Only return to pool if not at capacity and connection is valid
            if len(pool) < self._max_pool_size and self._is_connection_valid(connection):
                pool.append((connection, time.time()))
                logger.debug("Returned ClickZetta connection to pool")
            else:
                # Pool full or connection invalid, close it
                try:
                    connection.close()
                except Exception:
                    pass

    def _cleanup_expired_connections(self) -> None:
        """Clean up expired connections from all pools.

        Holds the global lock while iterating pools; each per-config pool is
        rebuilt with only the still-fresh connections.
        """
        current_time = time.time()

        with self._lock:
            for config_key in list(self._pools.keys()):
                if config_key not in self._pool_locks:
                    continue

                with self._pool_locks[config_key]:
                    pool = self._pools[config_key]
                    valid_connections = []

                    for connection, last_used in pool:
                        if current_time - last_used < self._connection_timeout:
                            valid_connections.append((connection, last_used))
                        else:
                            try:
                                connection.close()
                            except Exception:
                                pass

                    self._pools[config_key] = valid_connections

    def _start_cleanup_thread(self) -> None:
        """Start background thread for connection cleanup.

        The thread is a daemon, so it will not block interpreter shutdown;
        it wakes once a minute and reaps expired connections until
        shutdown() flips the _shutdown flag.
        """

        def cleanup_worker():
            while not self._shutdown:
                try:
                    time.sleep(60)  # Cleanup every minute
                    if not self._shutdown:
                        self._cleanup_expired_connections()
                except Exception:
                    logger.exception("Error in connection pool cleanup")

        self._cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
        self._cleanup_thread.start()

    def shutdown(self) -> None:
        """Shutdown connection pool and close all connections.

        Sets the shutdown flag (stopping the cleanup thread at its next
        wake-up) and closes every pooled connection.
        """
        self._shutdown = True

        with self._lock:
            for config_key in list(self._pools.keys()):
                if config_key not in self._pool_locks:
                    continue

                with self._pool_locks[config_key]:
                    pool = self._pools[config_key]
                    for connection, _ in pool:
                        try:
                            connection.close()
                        except Exception:
                            pass
                    pool.clear()
class ClickzettaVector(BaseVector):
"""
Clickzetta vector storage implementation.
@ -82,71 +321,74 @@ class ClickzettaVector(BaseVector):
super().__init__(collection_name)
self._config = config
self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name
self._connection: Optional["Connection"] = None
self._init_connection()
self._connection_pool = ClickzettaConnectionPool.get_instance()
self._init_write_queue()
def _init_connection(self):
"""Initialize Clickzetta connection."""
self._connection = clickzetta.connect(
username=self._config.username,
password=self._config.password,
instance=self._config.instance,
service=self._config.service,
workspace=self._config.workspace,
vcluster=self._config.vcluster,
schema=self._config.schema_name
)
def _get_connection(self) -> "Connection":
"""Get a connection from the pool."""
return self._connection_pool.get_connection(self._config)
# Set session parameters for better string handling and performance optimization
if self._connection is not None:
with self._connection.cursor() as cursor:
# Use quote mode for string literal escaping to handle quotes better
cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
logger.info("Set string literal escape mode to 'quote' for better quote handling")
def _return_connection(self, connection: "Connection") -> None:
"""Return a connection to the pool."""
self._connection_pool.return_connection(self._config, connection)
# Performance optimization hints for vector operations
self._set_performance_hints(cursor)
class ConnectionContext:
"""Context manager for borrowing and returning connections."""
def _set_performance_hints(self, cursor):
"""Set ClickZetta performance optimization hints for vector operations."""
def __init__(self, vector_instance: "ClickzettaVector"):
self.vector = vector_instance
self.connection: Optional[Connection] = None
def __enter__(self) -> "Connection":
self.connection = self.vector._get_connection()
return self.connection
def __exit__(self, exc_type, exc_val, exc_tb):
if self.connection:
self.vector._return_connection(self.connection)
def get_connection_context(self) -> "ClickzettaVector.ConnectionContext":
"""Get a connection context manager."""
return self.ConnectionContext(self)
def _parse_metadata(self, raw_metadata: str, row_id: str) -> dict:
"""
Parse metadata from JSON string with proper error handling and fallback.
Args:
raw_metadata: Raw JSON string from database
row_id: Row ID for fallback document_id
Returns:
Parsed metadata dict with guaranteed required fields
"""
try:
# Performance optimization hints for vector operations and query processing
performance_hints = [
# Vector index optimization
"SET cz.storage.parquet.vector.index.read.memory.cache = true",
"SET cz.storage.parquet.vector.index.read.local.cache = false",
if raw_metadata:
metadata = json.loads(raw_metadata)
# Query optimization
"SET cz.sql.table.scan.push.down.filter = true",
"SET cz.sql.table.scan.enable.ensure.filter = true",
"SET cz.storage.always.prefetch.internal = true",
"SET cz.optimizer.generate.columns.always.valid = true",
"SET cz.sql.index.prewhere.enabled = true",
# Handle double-encoded JSON
if isinstance(metadata, str):
metadata = json.loads(metadata)
# Storage optimization
"SET cz.storage.parquet.enable.io.prefetch = false",
"SET cz.optimizer.enable.mv.rewrite = false",
"SET cz.sql.dump.as.lz4 = true",
"SET cz.optimizer.limited.optimization.naive.query = true",
"SET cz.sql.table.scan.enable.push.down.log = false",
"SET cz.storage.use.file.format.local.stats = false",
"SET cz.storage.local.file.object.cache.level = all",
# Ensure we have a dict
if not isinstance(metadata, dict):
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, TypeError):
logger.exception("JSON parsing failed for metadata")
# Fallback: extract document_id with regex
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', raw_metadata or "")
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
# Job execution optimization
"SET cz.sql.job.fast.mode = true",
"SET cz.storage.parquet.non.contiguous.read = true",
"SET cz.sql.compaction.after.commit = true"
]
# Ensure required fields are set
metadata["doc_id"] = row_id # segment id
for hint in performance_hints:
cursor.execute(hint)
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata:
metadata["document_id"] = row_id # fallback to segment id
logger.info("Applied %d performance optimization hints for ClickZetta vector operations", len(performance_hints))
except Exception:
# Catch any errors setting performance hints but continue with defaults
logger.exception("Failed to set some performance hints, continuing with default settings")
return metadata
@classmethod
def _init_write_queue(cls):
@ -205,24 +447,33 @@ class ClickzettaVector(BaseVector):
return "clickzetta"
def _ensure_connection(self) -> "Connection":
"""Ensure connection is available and return it."""
if self._connection is None:
raise RuntimeError("Database connection not initialized")
return self._connection
"""Get a connection from the pool."""
return self._get_connection()
def _table_exists(self) -> bool:
"""Check if the table exists."""
try:
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}")
return True
except (RuntimeError, ValueError) as e:
if "table or view not found" in str(e).lower():
with self.get_connection_context() as connection:
with connection.cursor() as cursor:
cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}")
return True
except Exception as e:
error_message = str(e).lower()
# Handle ClickZetta specific "table or view not found" errors
if any(
phrase in error_message
for phrase in ["table or view not found", "czlh-42000", "semantic analysis exception"]
):
logger.debug("Table %s.%s does not exist", self._config.schema_name, self._table_name)
return False
else:
# Re-raise if it's a different error
raise
# For other connection/permission errors, log warning but return False to avoid blocking cleanup
logger.exception(
"Table existence check failed for %s.%s, assuming it doesn't exist",
self._config.schema_name,
self._table_name,
)
return False
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""Create the collection and add initial documents."""
@ -254,17 +505,17 @@ class ClickzettaVector(BaseVector):
) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
"""
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(create_table_sql)
logger.info("Created table %s.%s", self._config.schema_name, self._table_name)
with self.get_connection_context() as connection:
with connection.cursor() as cursor:
cursor.execute(create_table_sql)
logger.info("Created table %s.%s", self._config.schema_name, self._table_name)
# Create vector index
self._create_vector_index(cursor)
# Create vector index
self._create_vector_index(cursor)
# Create inverted index for full-text search if enabled
if self._config.enable_inverted_index:
self._create_inverted_index(cursor)
# Create inverted index for full-text search if enabled
if self._config.enable_inverted_index:
self._create_inverted_index(cursor)
def _create_vector_index(self, cursor):
"""Create HNSW vector index for similarity search."""
@ -298,9 +549,7 @@ class ClickzettaVector(BaseVector):
logger.info("Created vector index: %s", index_name)
except (RuntimeError, ValueError) as e:
error_msg = str(e).lower()
if ("already exists" in error_msg or
"already has index" in error_msg or
"with the same type" in error_msg):
if "already exists" in error_msg or "already has index" in error_msg or "with the same type" in error_msg:
logger.info("Vector index already exists: %s", e)
else:
logger.exception("Failed to create vector index")
@ -318,9 +567,11 @@ class ClickzettaVector(BaseVector):
for idx in existing_indexes:
idx_str = str(idx).lower()
# More precise check: look for inverted index specifically on the content column
if ("inverted" in idx_str and
Field.CONTENT_KEY.value.lower() in idx_str and
(index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)):
if (
"inverted" in idx_str
and Field.CONTENT_KEY.value.lower() in idx_str
and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)
):
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx)
return
except (RuntimeError, ValueError) as e:
@ -340,11 +591,12 @@ class ClickzettaVector(BaseVector):
except (RuntimeError, ValueError) as e:
error_msg = str(e).lower()
# Handle ClickZetta specific error messages
if (("already exists" in error_msg or
"already has index" in error_msg or
"with the same type" in error_msg or
"cannot create inverted index" in error_msg) and
"already has index" in error_msg):
if (
"already exists" in error_msg
or "already has index" in error_msg
or "with the same type" in error_msg
or "cannot create inverted index" in error_msg
) and "already has index" in error_msg:
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value)
# Try to get the existing index name for logging
try:
@ -360,7 +612,6 @@ class ClickzettaVector(BaseVector):
logger.warning("Failed to create inverted index: %s", e)
# Continue without inverted index - full-text search will fall back to LIKE
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""Add documents with embeddings to the collection."""
if not documents:
@ -370,14 +621,20 @@ class ClickzettaVector(BaseVector):
total_batches = (len(documents) + batch_size - 1) // batch_size
for i in range(0, len(documents), batch_size):
batch_docs = documents[i:i + batch_size]
batch_embeddings = embeddings[i:i + batch_size]
batch_docs = documents[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
# Execute batch insert through write queue
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)
def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]],
batch_index: int, batch_size: int, total_batches: int):
def _insert_batch(
self,
batch_docs: list[Document],
batch_embeddings: list[list[float]],
batch_index: int,
batch_size: int,
total_batches: int,
):
"""Insert a batch of documents using parameterized queries (executed in write worker thread)."""
if not batch_docs or not batch_embeddings:
logger.warning("Empty batch provided, skipping insertion")
@ -411,7 +668,7 @@ class ClickzettaVector(BaseVector):
# According to ClickZetta docs, vector should be formatted as array string
# for external systems: '[1.0, 2.0, 3.0]'
vector_str = '[' + ','.join(map(str, embedding)) + ']'
vector_str = "[" + ",".join(map(str, embedding)) + "]"
data_rows.append([doc_id, content, metadata_json, vector_str])
# Check if we have any valid data to insert
@ -427,37 +684,53 @@ class ClickzettaVector(BaseVector):
f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
)
connection = self._ensure_connection()
with connection.cursor() as cursor:
try:
# Set session-level hints for batch insert operations
# Note: executemany doesn't support hints parameter, so we set them as session variables
cursor.execute("SET cz.sql.job.fast.mode = true")
cursor.execute("SET cz.sql.compaction.after.commit = true")
cursor.execute("SET cz.storage.always.prefetch.internal = true")
with self.get_connection_context() as connection:
with connection.cursor() as cursor:
try:
# Set session-level hints for batch insert operations
# Note: executemany doesn't support hints parameter, so we set them as session variables
# Temporarily suppress ClickZetta client logging to reduce noise
clickzetta_logger = logging.getLogger("clickzetta")
original_level = clickzetta_logger.level
clickzetta_logger.setLevel(logging.WARNING)
cursor.executemany(insert_sql, data_rows)
logger.info(
f"Inserted batch {batch_index // batch_size + 1}/{total_batches} "
f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)"
)
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
logger.exception("Parameterized SQL execution failed for %d documents: %s", len(data_rows), e)
logger.exception("SQL template: %s", insert_sql)
logger.exception("Sample data row: %s", data_rows[0] if data_rows else 'None')
raise
try:
cursor.execute("SET cz.sql.job.fast.mode = true")
cursor.execute("SET cz.sql.compaction.after.commit = true")
cursor.execute("SET cz.storage.always.prefetch.internal = true")
finally:
# Restore original logging level
clickzetta_logger.setLevel(original_level)
cursor.executemany(insert_sql, data_rows)
logger.info(
"Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)",
batch_index // batch_size + 1,
total_batches,
len(data_rows),
vector_dimension,
)
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
logger.exception("SQL template: %s", insert_sql)
logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")
raise
def text_exists(self, id: str) -> bool:
"""Check if a document exists by ID."""
# Check if table exists first
if not self._table_exists():
return False
safe_id = self._safe_doc_id(id)
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(
f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?",
[safe_id]
)
result = cursor.fetchone()
return result[0] > 0 if result else False
with self.get_connection_context() as connection:
with connection.cursor() as cursor:
cursor.execute(
f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?",
binding_params=[safe_id],
)
result = cursor.fetchone()
return result[0] > 0 if result else False
def delete_by_ids(self, ids: list[str]) -> None:
"""Delete documents by IDs."""
@ -475,13 +748,14 @@ class ClickzettaVector(BaseVector):
def _delete_by_ids_impl(self, ids: list[str]) -> None:
"""Implementation of delete by IDs (executed in write worker thread)."""
safe_ids = [self._safe_doc_id(id) for id in ids]
# Create properly escaped string literals for SQL
id_list = ",".join(f"'{id}'" for id in safe_ids)
sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})"
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(sql)
# Use parameterized query to prevent SQL injection
placeholders = ",".join("?" for _ in safe_ids)
sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({placeholders})"
with self.get_connection_context() as connection:
with connection.cursor() as cursor:
cursor.execute(sql, binding_params=safe_ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
"""Delete documents by metadata field."""
@ -495,17 +769,28 @@ class ClickzettaVector(BaseVector):
def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
"""Implementation of delete by metadata field (executed in write worker thread)."""
connection = self._ensure_connection()
with connection.cursor() as cursor:
# Using JSON path to filter with parameterized query
# Note: JSON path requires literal key name, cannot be parameterized
# Use json_extract_string function for ClickZetta compatibility
sql = (f"DELETE FROM {self._config.schema_name}.{self._table_name} "
f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?")
cursor.execute(sql, [value])
with self.get_connection_context() as connection:
with connection.cursor() as cursor:
# Using JSON path to filter with parameterized query
# Note: JSON path requires literal key name, cannot be parameterized
# Use json_extract_string function for ClickZetta compatibility
sql = (
f"DELETE FROM {self._config.schema_name}.{self._table_name} "
f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?"
)
cursor.execute(sql, binding_params=[value])
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Search for documents by vector similarity."""
# Check if table exists first
if not self._table_exists():
logger.warning(
"Table %s.%s does not exist, returning empty results",
self._config.schema_name,
self._table_name,
)
return []
top_k = kwargs.get("top_k", 10)
score_threshold = kwargs.get("score_threshold", 0.0)
document_ids_filter = kwargs.get("document_ids_filter")
@ -532,15 +817,15 @@ class ClickzettaVector(BaseVector):
distance_func = "COSINE_DISTANCE"
if score_threshold > 0:
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, "
f"{query_vector_str}) < {2 - score_threshold}")
filter_clauses.append(
f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}"
)
else:
# For L2 distance, smaller is better
distance_func = "L2_DISTANCE"
if score_threshold > 0:
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, "
f"{query_vector_str}) < {score_threshold}")
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}")
where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1"
@ -556,55 +841,31 @@ class ClickzettaVector(BaseVector):
"""
documents = []
connection = self._ensure_connection()
with connection.cursor() as cursor:
# Use hints parameter for vector search optimization
search_hints = {
'hints': {
'sdk.job.timeout': 60, # Increase timeout for vector search
'cz.sql.job.fast.mode': True,
'cz.storage.parquet.vector.index.read.memory.cache': True
with self.get_connection_context() as connection:
with connection.cursor() as cursor:
# Use hints parameter for vector search optimization
search_hints = {
"hints": {
"sdk.job.timeout": 60, # Increase timeout for vector search
"cz.sql.job.fast.mode": True,
"cz.storage.parquet.vector.index.read.memory.cache": True,
}
}
}
cursor.execute(search_sql, parameters=search_hints)
results = cursor.fetchall()
cursor.execute(search_sql, search_hints)
results = cursor.fetchall()
for row in results:
# Parse metadata from JSON string (may be double-encoded)
try:
if row[2]:
metadata = json.loads(row[2])
for row in results:
# Parse metadata using centralized method
metadata = self._parse_metadata(row[2], row[0])
# If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str):
metadata = json.loads(metadata)
if not isinstance(metadata, dict):
metadata = {}
# Add score based on distance
if self._config.vector_distance_function == "cosine_distance":
metadata["score"] = 1 - (row[3] / 2)
else:
metadata = {}
except (json.JSONDecodeError, TypeError) as e:
logger.error("JSON parsing failed: %s", e)
# Fallback: extract document_id with regex
import re
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
metadata["score"] = 1 / (1 + row[3])
# Ensure required fields are set
metadata["doc_id"] = row[0] # segment id
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id
# Add score based on distance
if self._config.vector_distance_function == "cosine_distance":
metadata["score"] = 1 - (row[3] / 2)
else:
metadata["score"] = 1 / (1 + row[3])
doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc)
doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc)
return documents
@ -614,6 +875,15 @@ class ClickzettaVector(BaseVector):
logger.warning("Full-text search is not enabled. Enable inverted index in config.")
return []
# Check if table exists first
if not self._table_exists():
logger.warning(
"Table %s.%s does not exist, returning empty results",
self._config.schema_name,
self._table_name,
)
return []
top_k = kwargs.get("top_k", 10)
document_ids_filter = kwargs.get("document_ids_filter")
@ -649,61 +919,70 @@ class ClickzettaVector(BaseVector):
"""
documents = []
connection = self._ensure_connection()
with connection.cursor() as cursor:
try:
# Use hints parameter for full-text search optimization
fulltext_hints = {
'hints': {
'sdk.job.timeout': 30, # Timeout for full-text search
'cz.sql.job.fast.mode': True,
'cz.sql.index.prewhere.enabled': True
with self.get_connection_context() as connection:
with connection.cursor() as cursor:
try:
# Use hints parameter for full-text search optimization
fulltext_hints = {
"hints": {
"sdk.job.timeout": 30, # Timeout for full-text search
"cz.sql.job.fast.mode": True,
"cz.sql.index.prewhere.enabled": True,
}
}
}
cursor.execute(search_sql, parameters=fulltext_hints)
results = cursor.fetchall()
cursor.execute(search_sql, fulltext_hints)
results = cursor.fetchall()
for row in results:
# Parse metadata from JSON string (may be double-encoded)
try:
if row[2]:
metadata = json.loads(row[2])
for row in results:
# Parse metadata from JSON string (may be double-encoded)
try:
if row[2]:
metadata = json.loads(row[2])
# If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str):
metadata = json.loads(metadata)
# If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str):
metadata = json.loads(metadata)
if not isinstance(metadata, dict):
if not isinstance(metadata, dict):
metadata = {}
else:
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, TypeError) as e:
logger.error("JSON parsing failed: %s", e)
# Fallback: extract document_id with regex
import re
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
except (json.JSONDecodeError, TypeError) as e:
logger.exception("JSON parsing failed")
# Fallback: extract document_id with regex
# Ensure required fields are set
metadata["doc_id"] = row[0] # segment id
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id
# Ensure required fields are set
metadata["doc_id"] = row[0] # segment id
# Add a relevance score for full-text search
metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores
doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc)
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
logger.exception("Full-text search failed")
# Fallback to LIKE search if full-text search fails
return self._search_by_like(query, **kwargs)
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id
# Add a relevance score for full-text search
metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores
doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc)
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
logger.exception("Full-text search failed")
# Fallback to LIKE search if full-text search fails
return self._search_by_like(query, **kwargs)
return documents
def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]:
"""Fallback search using LIKE operator."""
# Check if table exists first
if not self._table_exists():
logger.warning(
"Table %s.%s does not exist, returning empty results",
self._config.schema_name,
self._table_name,
)
return []
top_k = kwargs.get("top_k", 10)
document_ids_filter = kwargs.get("document_ids_filter")
@ -735,62 +1014,37 @@ class ClickzettaVector(BaseVector):
"""
documents = []
connection = self._ensure_connection()
with connection.cursor() as cursor:
# Use hints parameter for LIKE search optimization
like_hints = {
'hints': {
'sdk.job.timeout': 20, # Timeout for LIKE search
'cz.sql.job.fast.mode': True
with self.get_connection_context() as connection:
with connection.cursor() as cursor:
# Use hints parameter for LIKE search optimization
like_hints = {
"hints": {
"sdk.job.timeout": 20, # Timeout for LIKE search
"cz.sql.job.fast.mode": True,
}
}
}
cursor.execute(search_sql, parameters=like_hints)
results = cursor.fetchall()
cursor.execute(search_sql, like_hints)
results = cursor.fetchall()
for row in results:
# Parse metadata from JSON string (may be double-encoded)
try:
if row[2]:
metadata = json.loads(row[2])
for row in results:
# Parse metadata using centralized method
metadata = self._parse_metadata(row[2], row[0])
# If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str):
metadata = json.loads(metadata)
if not isinstance(metadata, dict):
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, TypeError) as e:
logger.error("JSON parsing failed: %s", e)
# Fallback: extract document_id with regex
import re
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
# Ensure required fields are set
metadata["doc_id"] = row[0] # segment id
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id
metadata["score"] = 0.5 # Lower score for LIKE search
doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc)
metadata["score"] = 0.5 # Lower score for LIKE search
doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc)
return documents
def delete(self) -> None:
"""Delete the entire collection."""
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
with self.get_connection_context() as connection:
with connection.cursor() as cursor:
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
def _format_vector_simple(self, vector: list[float]) -> str:
"""Simple vector formatting for SQL queries."""
return ','.join(map(str, vector))
return ",".join(map(str, vector))
def _safe_doc_id(self, doc_id: str) -> str:
"""Ensure doc_id is safe for SQL and doesn't contain special characters."""
@ -799,13 +1053,12 @@ class ClickzettaVector(BaseVector):
# Remove or replace potentially problematic characters
safe_id = str(doc_id)
# Only allow alphanumeric, hyphens, underscores
safe_id = ''.join(c for c in safe_id if c.isalnum() or c in '-_')
safe_id = "".join(c for c in safe_id if c.isalnum() or c in "-_")
if not safe_id: # If all characters were removed
return str(uuid.uuid4())
return safe_id[:255] # Limit length
class ClickzettaVectorFactory(AbstractVectorFactory):
"""Factory for creating Clickzetta vector instances."""
@ -831,4 +1084,3 @@ class ClickzettaVectorFactory(AbstractVectorFactory):
collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()
return ClickzettaVector(collection_name=collection_name, config=config)

View File

@ -246,6 +246,10 @@ class TencentVector(BaseVector):
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
if not self._enable_hybrid_search:
return []
res = self._client.hybrid_search(
@ -269,6 +273,7 @@ class TencentVector(BaseVector):
),
retrieve_vector=False,
limit=kwargs.get("top_k", 4),
filter=filter,
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)

View File

@ -62,7 +62,7 @@ class WordExtractor(BaseExtractor):
def extract(self) -> list[Document]:
"""Load given path as single page."""
content = self.parse_docx(self.file_path, "storage")
content = self.parse_docx(self.file_path)
return [
Document(
page_content=content,
@ -189,23 +189,8 @@ class WordExtractor(BaseExtractor):
paragraph_content.append(run.text)
return "".join(paragraph_content).strip()
def _parse_paragraph(self, paragraph, image_map):
paragraph_content = []
for run in paragraph.runs:
if run.element.xpath(".//a:blip"):
for blip in run.element.xpath(".//a:blip"):
embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
if embed_id:
rel_target = run.part.rels[embed_id].target_ref
if rel_target in image_map:
paragraph_content.append(image_map[rel_target])
if run.text.strip():
paragraph_content.append(run.text.strip())
return " ".join(paragraph_content) if paragraph_content else ""
def parse_docx(self, docx_path, image_folder):
def parse_docx(self, docx_path):
doc = DocxDocument(docx_path)
os.makedirs(image_folder, exist_ok=True)
content = []

View File

@ -29,7 +29,7 @@ from core.tools.errors import (
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.enums import CreatorUserRole
@ -247,7 +247,8 @@ class ToolEngine:
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
result += json.dumps(
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
ensure_ascii=False,
)
else:
result += str(response.message)

View File

@ -1,7 +1,14 @@
import logging
from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Optional
from typing import Optional, cast
from uuid import UUID
import numpy as np
import pytz
from flask_login import current_user
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
@ -10,6 +17,41 @@ from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)
def safe_json_value(v):
    """Recursively coerce *v* into a JSON-serializable value.

    Datetimes are rendered in the current user's timezone (UTC fallback),
    dates as ISO strings, UUIDs as strings, Decimals as floats, bytes as
    UTF-8 text (hex on decode failure), memoryviews/ndarrays as primitive
    equivalents, and containers are converted element-wise. Anything else
    is returned unchanged.
    """
    if isinstance(v, datetime):
        # Prefer the requesting user's timezone when one is set; UTC otherwise.
        tz_name = getattr(current_user, "timezone", None) if current_user is not None else None
        return v.astimezone(pytz.timezone(tz_name or "UTC")).isoformat()
    if isinstance(v, date):
        return v.isoformat()
    if isinstance(v, UUID):
        return str(v)
    if isinstance(v, Decimal):
        return float(v)
    if isinstance(v, bytes):
        # UTF-8 text when possible; fall back to a hex dump for binary payloads.
        try:
            return v.decode("utf-8")
        except UnicodeDecodeError:
            return v.hex()
    if isinstance(v, memoryview):
        return v.tobytes().hex()
    if isinstance(v, np.ndarray):
        return v.tolist()
    if isinstance(v, dict):
        return {key: safe_json_value(value) for key, value in v.items()}
    if isinstance(v, (list, tuple, set)):
        return [safe_json_value(item) for item in v]
    return v
def safe_json_dict(d):
    """Return a copy of *d* with every value made JSON-serializable.

    :param d: The dictionary whose values are converted via ``safe_json_value``.
    :return: A new dict with the same keys and converted values.
    :raises TypeError: if *d* is not a ``dict``.
    """
    if not isinstance(d, dict):
        raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
    converted = {}
    for key, value in d.items():
        converted[key] = safe_json_value(value)
    return converted
class ToolFileMessageTransformer:
@classmethod
def transform_tool_invoke_messages(
@ -113,6 +155,12 @@ class ToolFileMessageTransformer:
)
else:
yield message
elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
json_msg.json_object = safe_json_value(json_msg.json_object)
yield message
else:
yield message

View File

@ -119,6 +119,13 @@ class ObjectSegment(Segment):
class ArraySegment(Segment):
@property
def text(self) -> str:
# Return empty string for empty arrays instead of "[]"
if not self.value:
return ""
return super().text
@property
def markdown(self) -> str:
items = []
@ -155,6 +162,9 @@ class ArrayStringSegment(ArraySegment):
@property
def text(self) -> str:
# Return empty string for empty arrays instead of "[]"
if not self.value:
return ""
return json.dumps(self.value, ensure_ascii=False)

View File

@ -168,7 +168,57 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
"""Extract text from a file based on its file extension."""
match file_extension:
case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml":
case (
".txt"
| ".markdown"
| ".md"
| ".html"
| ".htm"
| ".xml"
| ".c"
| ".h"
| ".cpp"
| ".hpp"
| ".cc"
| ".cxx"
| ".c++"
| ".py"
| ".js"
| ".ts"
| ".jsx"
| ".tsx"
| ".java"
| ".php"
| ".rb"
| ".go"
| ".rs"
| ".swift"
| ".kt"
| ".scala"
| ".sh"
| ".bash"
| ".bat"
| ".ps1"
| ".sql"
| ".r"
| ".m"
| ".pl"
| ".lua"
| ".vim"
| ".asm"
| ".s"
| ".css"
| ".scss"
| ".less"
| ".sass"
| ".ini"
| ".cfg"
| ".conf"
| ".toml"
| ".env"
| ".log"
| ".vtt"
):
return _extract_text_from_plain_text(file_content)
case ".json":
return _extract_text_from_json(file_content)
@ -194,8 +244,6 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
return _extract_text_from_eml(file_content)
case ".msg":
return _extract_text_from_msg(file_content)
case ".vtt":
return _extract_text_from_vtt(file_content)
case ".properties":
return _extract_text_from_properties(file_content)
case _: