mirror of https://github.com/langgenius/dify.git (synced 2026-04-29 23:18:05 +08:00)

Merge branch 'main' into feat/rag-2
@@ -118,26 +118,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         ):
             return

-        # Init conversation variables
-        stmt = select(ConversationVariable).where(
-            ConversationVariable.app_id == self.conversation.app_id,
-            ConversationVariable.conversation_id == self.conversation.id,
-        )
-        with Session(db.engine) as session:
-            db_conversation_variables = session.scalars(stmt).all()
-            if not db_conversation_variables:
-                # Create conversation variables if they don't exist.
-                db_conversation_variables = [
-                    ConversationVariable.from_variable(
-                        app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
-                    )
-                    for variable in self._workflow.conversation_variables
-                ]
-                session.add_all(db_conversation_variables)
-            # Convert database entities to variables.
-            conversation_variables = [item.to_variable() for item in db_conversation_variables]
-
-            session.commit()
+        # Initialize conversation variables
+        conversation_variables = self._initialize_conversation_variables()

         # Create a variable pool.
         system_inputs = SystemVariable(
@@ -292,3 +274,100 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             message_id=message_id,
             trace_manager=app_generate_entity.trace_manager,
         )
+
+    def _initialize_conversation_variables(self) -> list[VariableUnion]:
+        """
+        Initialize conversation variables for the current conversation.
+
+        This method:
+        1. Loads existing variables from the database
+        2. Creates new variables if none exist
+        3. Syncs missing variables from the workflow definition
+
+        :return: List of conversation variables ready for use
+        """
+        with Session(db.engine) as session:
+            existing_variables = self._load_existing_conversation_variables(session)
+
+            if not existing_variables:
+                # First time initialization - create all variables
+                existing_variables = self._create_all_conversation_variables(session)
+            else:
+                # Check and add any missing variables from the workflow
+                existing_variables = self._sync_missing_conversation_variables(session, existing_variables)
+
+            # Convert to Variable objects for use in the workflow
+            conversation_variables = [var.to_variable() for var in existing_variables]
+
+            session.commit()
+            return cast(list[VariableUnion], conversation_variables)
+
+    def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
+        """
+        Load existing conversation variables from the database.
+
+        :param session: Database session
+        :return: List of existing conversation variables
+        """
+        stmt = select(ConversationVariable).where(
+            ConversationVariable.app_id == self.conversation.app_id,
+            ConversationVariable.conversation_id == self.conversation.id,
+        )
+        return list(session.scalars(stmt).all())
+
+    def _create_all_conversation_variables(self, session: Session) -> list[ConversationVariable]:
+        """
+        Create all conversation variables for a new conversation.
+
+        :param session: Database session
+        :return: List of created conversation variables
+        """
+        new_variables = [
+            ConversationVariable.from_variable(
+                app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
+            )
+            for variable in self._workflow.conversation_variables
+        ]
+
+        if new_variables:
+            session.add_all(new_variables)
+
+        return new_variables
+
+    def _sync_missing_conversation_variables(
+        self, session: Session, existing_variables: list[ConversationVariable]
+    ) -> list[ConversationVariable]:
+        """
+        Sync missing conversation variables from the workflow definition.
+
+        This handles the case where new variables are added to a workflow
+        after conversations have already been created.
+
+        :param session: Database session
+        :param existing_variables: List of existing conversation variables
+        :return: Updated list including any newly created variables
+        """
+        # Get IDs of existing and workflow variables
+        existing_ids = {var.id for var in existing_variables}
+        workflow_variables = {var.id: var for var in self._workflow.conversation_variables}
+
+        # Find missing variable IDs
+        missing_ids = set(workflow_variables.keys()) - existing_ids
+
+        if not missing_ids:
+            return existing_variables
+
+        # Create missing variables with their default values
+        new_variables = [
+            ConversationVariable.from_variable(
+                app_id=self.conversation.app_id,
+                conversation_id=self.conversation.id,
+                variable=workflow_variables[var_id],
+            )
+            for var_id in missing_ids
+        ]
+
+        session.add_all(new_variables)
+
+        # Return combined list
+        return existing_variables + new_variables
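The set-difference rule in `_sync_missing_conversation_variables` is the heart of the change; a minimal standalone sketch with hypothetical IDs (the real method operates on `ConversationVariable` rows, not strings):

```python
# Minimal sketch of the sync rule used above (hypothetical data, not Dify's models).
workflow_vars = {"a": "greeting", "b": "count", "c": "locale"}  # variable id -> variable
existing_ids = {"a", "b"}  # already persisted for this conversation

missing_ids = set(workflow_vars) - existing_ids
assert missing_ids == {"c"}  # only the newly added workflow variable gets created
```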
@@ -23,6 +23,7 @@ from core.app.entities.task_entities import (
     MessageFileStreamResponse,
     MessageReplaceStreamResponse,
     MessageStreamResponse,
+    StreamEvent,
     WorkflowTaskState,
 )
 from core.llm_generator.llm_generator import LLMGenerator
@@ -180,11 +181,15 @@ class MessageCycleManager:
         :param message_id: message id
         :return:
         """
+        message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first()
+        event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
+
         return MessageStreamResponse(
             task_id=self._application_generate_entity.task_id,
             id=message_id,
             answer=answer,
             from_variable_selector=from_variable_selector,
+            event=event_type,
         )

     def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
@@ -843,7 +843,7 @@ class ProviderConfiguration(BaseModel):
                 continue

             status = ModelStatus.ACTIVE
-            if m.model in model_setting_map:
+            if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
                 model_setting = model_setting_map[m.model_type][m.model]
                 if model_setting.enabled is False:
                     status = ModelStatus.DISABLED
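The rewritten condition guards both dictionary levels before indexing; the old test checked the model name against the outer map (which is keyed by model type) and then indexed by model type anyway, which could raise KeyError. The pattern in isolation (toy data, not Dify's real settings map):

```python
# Nested-lookup guard in isolation (illustrative data).
model_setting_map = {"llm": {"gpt-4": {"enabled": False}}}

model_type, model = "text-embedding", "embed-v1"

# Both keys are verified before indexing, so absent model types are skipped
# instead of raising KeyError.
if model_type in model_setting_map and model in model_setting_map[model_type]:
    setting = model_setting_map[model_type][model]
```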
@@ -185,6 +185,6 @@ Clickzetta supports advanced full-text search with multiple analyzers:

 ## References

-- [Clickzetta Vector Search Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/vector-search.md)
-- [Clickzetta Inverted Index Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/inverted-index.md)
-- [Clickzetta SQL Functions](../../../../../../../yunqidoc/cn_markdown_20250526/sql_functions/)
+- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
+- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
+- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)
@@ -1,9 +1,11 @@
 import json
+import logging
 import queue
+import re
 import threading
 import time
 import uuid
-from typing import Any, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Optional

 import clickzetta  # type: ignore
 from pydantic import BaseModel, model_validator
@@ -67,6 +69,243 @@ class ClickzettaConfig(BaseModel):
         return values


+class ClickzettaConnectionPool:
+    """
+    Global connection pool for ClickZetta connections.
+    Manages connection reuse across ClickzettaVector instances.
+    """
+
+    _instance: Optional["ClickzettaConnectionPool"] = None
+    _lock = threading.Lock()
+
+    def __init__(self):
+        self._pools: dict[str, list[tuple[Connection, float]]] = {}  # config_key -> [(connection, last_used_time)]
+        self._pool_locks: dict[str, threading.Lock] = {}
+        self._max_pool_size = 5  # Maximum connections per configuration
+        self._connection_timeout = 300  # 5 minutes timeout
+        self._cleanup_thread: Optional[threading.Thread] = None
+        self._shutdown = False
+        self._start_cleanup_thread()
+
+    @classmethod
+    def get_instance(cls) -> "ClickzettaConnectionPool":
+        """Get singleton instance of connection pool."""
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = cls()
+        return cls._instance
+
+    def _get_config_key(self, config: ClickzettaConfig) -> str:
+        """Generate unique key for connection configuration."""
+        return (
+            f"{config.username}:{config.instance}:{config.service}:"
+            f"{config.workspace}:{config.vcluster}:{config.schema_name}"
+        )
+
+    def _create_connection(self, config: ClickzettaConfig) -> "Connection":
+        """Create a new ClickZetta connection."""
+        max_retries = 3
+        retry_delay = 1.0
+
+        for attempt in range(max_retries):
+            try:
+                connection = clickzetta.connect(
+                    username=config.username,
+                    password=config.password,
+                    instance=config.instance,
+                    service=config.service,
+                    workspace=config.workspace,
+                    vcluster=config.vcluster,
+                    schema=config.schema_name,
+                )
+
+                # Configure connection session settings
+                self._configure_connection(connection)
+                logger.debug("Created new ClickZetta connection (attempt %d/%d)", attempt + 1, max_retries)
+                return connection
+            except Exception:
+                logger.exception("ClickZetta connection attempt %d/%d failed", attempt + 1, max_retries)
+                if attempt < max_retries - 1:
+                    time.sleep(retry_delay * (2**attempt))
+                else:
+                    raise
+
+        raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")
+
+    def _configure_connection(self, connection: "Connection") -> None:
+        """Configure connection session settings."""
+        try:
+            with connection.cursor() as cursor:
+                # Temporarily suppress ClickZetta client logging to reduce noise
+                clickzetta_logger = logging.getLogger("clickzetta")
+                original_level = clickzetta_logger.level
+                clickzetta_logger.setLevel(logging.WARNING)
+
+                try:
+                    # Use quote mode for string literal escaping
+                    cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
+
+                    # Apply performance optimization hints
+                    performance_hints = [
+                        # Vector index optimization
+                        "SET cz.storage.parquet.vector.index.read.memory.cache = true",
+                        "SET cz.storage.parquet.vector.index.read.local.cache = false",
+                        # Query optimization
+                        "SET cz.sql.table.scan.push.down.filter = true",
+                        "SET cz.sql.table.scan.enable.ensure.filter = true",
+                        "SET cz.storage.always.prefetch.internal = true",
+                        "SET cz.optimizer.generate.columns.always.valid = true",
+                        "SET cz.sql.index.prewhere.enabled = true",
+                        # Storage optimization
+                        "SET cz.storage.parquet.enable.io.prefetch = false",
+                        "SET cz.optimizer.enable.mv.rewrite = false",
+                        "SET cz.sql.dump.as.lz4 = true",
+                        "SET cz.optimizer.limited.optimization.naive.query = true",
+                        "SET cz.sql.table.scan.enable.push.down.log = false",
+                        "SET cz.storage.use.file.format.local.stats = false",
+                        "SET cz.storage.local.file.object.cache.level = all",
+                        # Job execution optimization
+                        "SET cz.sql.job.fast.mode = true",
+                        "SET cz.storage.parquet.non.contiguous.read = true",
+                        "SET cz.sql.compaction.after.commit = true",
+                    ]
+
+                    for hint in performance_hints:
+                        cursor.execute(hint)
+                finally:
+                    # Restore original logging level
+                    clickzetta_logger.setLevel(original_level)
+
+        except Exception:
+            logger.exception("Failed to configure connection, continuing with defaults")
+
+    def _is_connection_valid(self, connection: "Connection") -> bool:
+        """Check if connection is still valid."""
+        try:
+            with connection.cursor() as cursor:
+                cursor.execute("SELECT 1")
+                return True
+        except Exception:
+            return False
+
+    def get_connection(self, config: ClickzettaConfig) -> "Connection":
+        """Get a connection from the pool or create a new one."""
+        config_key = self._get_config_key(config)
+
+        # Ensure pool lock exists
+        if config_key not in self._pool_locks:
+            with self._lock:
+                if config_key not in self._pool_locks:
+                    self._pool_locks[config_key] = threading.Lock()
+                    self._pools[config_key] = []
+
+        with self._pool_locks[config_key]:
+            pool = self._pools[config_key]
+            current_time = time.time()
+
+            # Try to reuse existing connection
+            while pool:
+                connection, last_used = pool.pop(0)
+
+                # Check if connection is not expired and still valid
+                if current_time - last_used < self._connection_timeout and self._is_connection_valid(connection):
+                    logger.debug("Reusing ClickZetta connection from pool")
+                    return connection
+                else:
+                    # Connection expired or invalid, close it
+                    try:
+                        connection.close()
+                    except Exception:
+                        pass
+
+        # No valid connection found, create new one
+        return self._create_connection(config)
+
+    def return_connection(self, config: ClickzettaConfig, connection: "Connection") -> None:
+        """Return a connection to the pool."""
+        config_key = self._get_config_key(config)
+
+        if config_key not in self._pool_locks:
+            # Pool was cleaned up, just close the connection
+            try:
+                connection.close()
+            except Exception:
+                pass
+            return
+
+        with self._pool_locks[config_key]:
+            pool = self._pools[config_key]
+
+            # Only return to pool if not at capacity and connection is valid
+            if len(pool) < self._max_pool_size and self._is_connection_valid(connection):
+                pool.append((connection, time.time()))
+                logger.debug("Returned ClickZetta connection to pool")
+            else:
+                # Pool full or connection invalid, close it
+                try:
+                    connection.close()
+                except Exception:
+                    pass
+
+    def _cleanup_expired_connections(self) -> None:
+        """Clean up expired connections from all pools."""
+        current_time = time.time()
+
+        with self._lock:
+            for config_key in list(self._pools.keys()):
+                if config_key not in self._pool_locks:
+                    continue
+
+                with self._pool_locks[config_key]:
+                    pool = self._pools[config_key]
+                    valid_connections = []
+
+                    for connection, last_used in pool:
+                        if current_time - last_used < self._connection_timeout:
+                            valid_connections.append((connection, last_used))
+                        else:
+                            try:
+                                connection.close()
+                            except Exception:
+                                pass
+
+                    self._pools[config_key] = valid_connections
+
+    def _start_cleanup_thread(self) -> None:
+        """Start background thread for connection cleanup."""
+
+        def cleanup_worker():
+            while not self._shutdown:
+                try:
+                    time.sleep(60)  # Cleanup every minute
+                    if not self._shutdown:
+                        self._cleanup_expired_connections()
+                except Exception:
+                    logger.exception("Error in connection pool cleanup")
+
+        self._cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
+        self._cleanup_thread.start()
+
+    def shutdown(self) -> None:
+        """Shutdown connection pool and close all connections."""
+        self._shutdown = True
+
+        with self._lock:
+            for config_key in list(self._pools.keys()):
+                if config_key not in self._pool_locks:
+                    continue
+
+                with self._pool_locks[config_key]:
+                    pool = self._pools[config_key]
+                    for connection, _ in pool:
+                        try:
+                            connection.close()
+                        except Exception:
+                            pass
+                    pool.clear()
+
+
 class ClickzettaVector(BaseVector):
     """
     Clickzetta vector storage implementation.
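The pool is a process-wide singleton built with double-checked locking (`get_instance` checks `_instance` once without the lock, then again under it, so the common path is lock-free). A usage sketch, assuming a populated `ClickzettaConfig`; the field values below are placeholders and the real config carries more options:

```python
# Hypothetical usage sketch of the pool above; config values are placeholders.
config = ClickzettaConfig(username="u", password="p", instance="i", service="s",
                          workspace="w", vcluster="v", schema_name="dify")

pool = ClickzettaConnectionPool.get_instance()
conn = pool.get_connection(config)        # reuses a pooled connection when one is fresh and valid
try:
    with conn.cursor() as cursor:
        cursor.execute("SELECT 1")
finally:
    pool.return_connection(config, conn)  # kept only if the pool has room and the connection pings
```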
@@ -82,71 +321,74 @@ class ClickzettaVector(BaseVector):
         super().__init__(collection_name)
         self._config = config
         self._table_name = collection_name.replace("-", "_").lower()  # Ensure valid table name
-        self._connection: Optional["Connection"] = None
-        self._init_connection()
+        self._connection_pool = ClickzettaConnectionPool.get_instance()
         self._init_write_queue()

-    def _init_connection(self):
-        """Initialize Clickzetta connection."""
-        self._connection = clickzetta.connect(
-            username=self._config.username,
-            password=self._config.password,
-            instance=self._config.instance,
-            service=self._config.service,
-            workspace=self._config.workspace,
-            vcluster=self._config.vcluster,
-            schema=self._config.schema_name
-        )
-
-        # Set session parameters for better string handling and performance optimization
-        if self._connection is not None:
-            with self._connection.cursor() as cursor:
-                # Use quote mode for string literal escaping to handle quotes better
-                cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
-                logger.info("Set string literal escape mode to 'quote' for better quote handling")
-
-                # Performance optimization hints for vector operations
-                self._set_performance_hints(cursor)
-
-    def _set_performance_hints(self, cursor):
-        """Set ClickZetta performance optimization hints for vector operations."""
-        try:
-            # Performance optimization hints for vector operations and query processing
-            performance_hints = [
-                # Vector index optimization
-                "SET cz.storage.parquet.vector.index.read.memory.cache = true",
-                "SET cz.storage.parquet.vector.index.read.local.cache = false",
-                # Query optimization
-                "SET cz.sql.table.scan.push.down.filter = true",
-                "SET cz.sql.table.scan.enable.ensure.filter = true",
-                "SET cz.storage.always.prefetch.internal = true",
-                "SET cz.optimizer.generate.columns.always.valid = true",
-                "SET cz.sql.index.prewhere.enabled = true",
-                # Storage optimization
-                "SET cz.storage.parquet.enable.io.prefetch = false",
-                "SET cz.optimizer.enable.mv.rewrite = false",
-                "SET cz.sql.dump.as.lz4 = true",
-                "SET cz.optimizer.limited.optimization.naive.query = true",
-                "SET cz.sql.table.scan.enable.push.down.log = false",
-                "SET cz.storage.use.file.format.local.stats = false",
-                "SET cz.storage.local.file.object.cache.level = all",
-                # Job execution optimization
-                "SET cz.sql.job.fast.mode = true",
-                "SET cz.storage.parquet.non.contiguous.read = true",
-                "SET cz.sql.compaction.after.commit = true"
-            ]
-
-            for hint in performance_hints:
-                cursor.execute(hint)
-
-            logger.info("Applied %d performance optimization hints for ClickZetta vector operations", len(performance_hints))
-
-        except Exception:
-            # Catch any errors setting performance hints but continue with defaults
-            logger.exception("Failed to set some performance hints, continuing with default settings")
+    def _get_connection(self) -> "Connection":
+        """Get a connection from the pool."""
+        return self._connection_pool.get_connection(self._config)
+
+    def _return_connection(self, connection: "Connection") -> None:
+        """Return a connection to the pool."""
+        self._connection_pool.return_connection(self._config, connection)
+
+    class ConnectionContext:
+        """Context manager for borrowing and returning connections."""
+
+        def __init__(self, vector_instance: "ClickzettaVector"):
+            self.vector = vector_instance
+            self.connection: Optional[Connection] = None
+
+        def __enter__(self) -> "Connection":
+            self.connection = self.vector._get_connection()
+            return self.connection
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            if self.connection:
+                self.vector._return_connection(self.connection)
+
+    def get_connection_context(self) -> "ClickzettaVector.ConnectionContext":
+        """Get a connection context manager."""
+        return self.ConnectionContext(self)
+
+    def _parse_metadata(self, raw_metadata: str, row_id: str) -> dict:
+        """
+        Parse metadata from JSON string with proper error handling and fallback.
+
+        Args:
+            raw_metadata: Raw JSON string from database
+            row_id: Row ID for fallback document_id
+
+        Returns:
+            Parsed metadata dict with guaranteed required fields
+        """
+        try:
+            if raw_metadata:
+                metadata = json.loads(raw_metadata)
+
+                # Handle double-encoded JSON
+                if isinstance(metadata, str):
+                    metadata = json.loads(metadata)
+
+                # Ensure we have a dict
+                if not isinstance(metadata, dict):
+                    metadata = {}
+            else:
+                metadata = {}
+        except (json.JSONDecodeError, TypeError):
+            logger.exception("JSON parsing failed for metadata")
+            # Fallback: extract document_id with regex
+            doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', raw_metadata or "")
+            metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
+
+        # Ensure required fields are set
+        metadata["doc_id"] = row_id  # segment id
+
+        # Ensure document_id exists (critical for Dify's format_retrieval_documents)
+        if "document_id" not in metadata:
+            metadata["document_id"] = row_id  # fallback to segment id
+
+        return metadata

     @classmethod
     def _init_write_queue(cls):
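A sketch of how `get_connection_context` is meant to be used inside the class; the method name `count_rows` is illustrative, not part of the diff:

```python
# Sketch of using the context manager above (hypothetical method on ClickzettaVector).
def count_rows(self) -> int:
    with self.get_connection_context() as connection:  # __enter__ borrows from the pool
        with connection.cursor() as cursor:
            cursor.execute(f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name}")
            row = cursor.fetchone()
            return row[0] if row else 0
    # __exit__ has returned the connection to the pool by this point
```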
@@ -205,24 +447,33 @@ class ClickzettaVector(BaseVector):
         return "clickzetta"

     def _ensure_connection(self) -> "Connection":
-        """Ensure connection is available and return it."""
-        if self._connection is None:
-            raise RuntimeError("Database connection not initialized")
-        return self._connection
+        """Get a connection from the pool."""
+        return self._get_connection()

     def _table_exists(self) -> bool:
         """Check if the table exists."""
         try:
-            connection = self._ensure_connection()
-            with connection.cursor() as cursor:
-                cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}")
-            return True
-        except (RuntimeError, ValueError) as e:
-            if "table or view not found" in str(e).lower():
+            with self.get_connection_context() as connection:
+                with connection.cursor() as cursor:
+                    cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}")
+                return True
+        except Exception as e:
+            error_message = str(e).lower()
+            # Handle ClickZetta specific "table or view not found" errors
+            if any(
+                phrase in error_message
+                for phrase in ["table or view not found", "czlh-42000", "semantic analysis exception"]
+            ):
                 logger.debug("Table %s.%s does not exist", self._config.schema_name, self._table_name)
                 return False
-            else:
-                # Re-raise if it's a different error
-                raise
+            # For other connection/permission errors, log warning but return False to avoid blocking cleanup
+            logger.exception(
+                "Table existence check failed for %s.%s, assuming it doesn't exist",
+                self._config.schema_name,
+                self._table_name,
+            )
+            return False

     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         """Create the collection and add initial documents."""
@@ -254,17 +505,17 @@ class ClickzettaVector(BaseVector):
         ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
         """

-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            cursor.execute(create_table_sql)
-            logger.info("Created table %s.%s", self._config.schema_name, self._table_name)
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                cursor.execute(create_table_sql)
+                logger.info("Created table %s.%s", self._config.schema_name, self._table_name)

-            # Create vector index
-            self._create_vector_index(cursor)
+                # Create vector index
+                self._create_vector_index(cursor)

-            # Create inverted index for full-text search if enabled
-            if self._config.enable_inverted_index:
-                self._create_inverted_index(cursor)
+                # Create inverted index for full-text search if enabled
+                if self._config.enable_inverted_index:
+                    self._create_inverted_index(cursor)

     def _create_vector_index(self, cursor):
         """Create HNSW vector index for similarity search."""
@@ -298,9 +549,7 @@ class ClickzettaVector(BaseVector):
             logger.info("Created vector index: %s", index_name)
         except (RuntimeError, ValueError) as e:
             error_msg = str(e).lower()
-            if ("already exists" in error_msg or
-                "already has index" in error_msg or
-                "with the same type" in error_msg):
+            if "already exists" in error_msg or "already has index" in error_msg or "with the same type" in error_msg:
                 logger.info("Vector index already exists: %s", e)
             else:
                 logger.exception("Failed to create vector index")
@@ -318,9 +567,11 @@ class ClickzettaVector(BaseVector):
             for idx in existing_indexes:
                 idx_str = str(idx).lower()
                 # More precise check: look for inverted index specifically on the content column
-                if ("inverted" in idx_str and
-                    Field.CONTENT_KEY.value.lower() in idx_str and
-                    (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)):
+                if (
+                    "inverted" in idx_str
+                    and Field.CONTENT_KEY.value.lower() in idx_str
+                    and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)
+                ):
                     logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx)
                     return
         except (RuntimeError, ValueError) as e:
@@ -340,11 +591,12 @@ class ClickzettaVector(BaseVector):
         except (RuntimeError, ValueError) as e:
             error_msg = str(e).lower()
             # Handle ClickZetta specific error messages
-            if (("already exists" in error_msg or
-                 "already has index" in error_msg or
-                 "with the same type" in error_msg or
-                 "cannot create inverted index" in error_msg) and
-                "already has index" in error_msg):
+            if (
+                "already exists" in error_msg
+                or "already has index" in error_msg
+                or "with the same type" in error_msg
+                or "cannot create inverted index" in error_msg
+            ) and "already has index" in error_msg:
                 logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value)
                 # Try to get the existing index name for logging
                 try:
@@ -360,7 +612,6 @@ class ClickzettaVector(BaseVector):
                 logger.warning("Failed to create inverted index: %s", e)
                 # Continue without inverted index - full-text search will fall back to LIKE

-
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         """Add documents with embeddings to the collection."""
         if not documents:
@@ -370,14 +621,20 @@ class ClickzettaVector(BaseVector):
         total_batches = (len(documents) + batch_size - 1) // batch_size

         for i in range(0, len(documents), batch_size):
-            batch_docs = documents[i:i + batch_size]
-            batch_embeddings = embeddings[i:i + batch_size]
+            batch_docs = documents[i : i + batch_size]
+            batch_embeddings = embeddings[i : i + batch_size]

             # Execute batch insert through write queue
             self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)

-    def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]],
-                      batch_index: int, batch_size: int, total_batches: int):
+    def _insert_batch(
+        self,
+        batch_docs: list[Document],
+        batch_embeddings: list[list[float]],
+        batch_index: int,
+        batch_size: int,
+        total_batches: int,
+    ):
         """Insert a batch of documents using parameterized queries (executed in write worker thread)."""
         if not batch_docs or not batch_embeddings:
             logger.warning("Empty batch provided, skipping insertion")
@@ -411,7 +668,7 @@ class ClickzettaVector(BaseVector):

             # According to ClickZetta docs, vector should be formatted as array string
             # for external systems: '[1.0, 2.0, 3.0]'
-            vector_str = '[' + ','.join(map(str, embedding)) + ']'
+            vector_str = "[" + ",".join(map(str, embedding)) + "]"
             data_rows.append([doc_id, content, metadata_json, vector_str])

         # Check if we have any valid data to insert
@@ -427,37 +684,53 @@ class ClickzettaVector(BaseVector):
             f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
         )

-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            try:
-                # Set session-level hints for batch insert operations
-                # Note: executemany doesn't support hints parameter, so we set them as session variables
-                cursor.execute("SET cz.sql.job.fast.mode = true")
-                cursor.execute("SET cz.sql.compaction.after.commit = true")
-                cursor.execute("SET cz.storage.always.prefetch.internal = true")
-
-                cursor.executemany(insert_sql, data_rows)
-                logger.info(
-                    f"Inserted batch {batch_index // batch_size + 1}/{total_batches} "
-                    f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)"
-                )
-            except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
-                logger.exception("Parameterized SQL execution failed for %d documents: %s", len(data_rows), e)
-                logger.exception("SQL template: %s", insert_sql)
-                logger.exception("Sample data row: %s", data_rows[0] if data_rows else 'None')
-                raise
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                try:
+                    # Set session-level hints for batch insert operations
+                    # Note: executemany doesn't support hints parameter, so we set them as session variables
+                    # Temporarily suppress ClickZetta client logging to reduce noise
+                    clickzetta_logger = logging.getLogger("clickzetta")
+                    original_level = clickzetta_logger.level
+                    clickzetta_logger.setLevel(logging.WARNING)
+
+                    try:
+                        cursor.execute("SET cz.sql.job.fast.mode = true")
+                        cursor.execute("SET cz.sql.compaction.after.commit = true")
+                        cursor.execute("SET cz.storage.always.prefetch.internal = true")
+                    finally:
+                        # Restore original logging level
+                        clickzetta_logger.setLevel(original_level)
+
+                    cursor.executemany(insert_sql, data_rows)
+                    logger.info(
+                        "Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)",
+                        batch_index // batch_size + 1,
+                        total_batches,
+                        len(data_rows),
+                        vector_dimension,
+                    )
+                except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
+                    logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
+                    logger.exception("SQL template: %s", insert_sql)
+                    logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")
+                    raise

     def text_exists(self, id: str) -> bool:
         """Check if a document exists by ID."""
         # Check if table exists first
         if not self._table_exists():
             return False

         safe_id = self._safe_doc_id(id)
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            cursor.execute(
-                f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?",
-                [safe_id]
-            )
-            result = cursor.fetchone()
-            return result[0] > 0 if result else False
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                cursor.execute(
+                    f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?",
+                    binding_params=[safe_id],
+                )
+                result = cursor.fetchone()
+                return result[0] > 0 if result else False
@@ -475,13 +748,14 @@ class ClickzettaVector(BaseVector):
     def _delete_by_ids_impl(self, ids: list[str]) -> None:
         """Implementation of delete by IDs (executed in write worker thread)."""
         safe_ids = [self._safe_doc_id(id) for id in ids]
-        # Create properly escaped string literals for SQL
-        id_list = ",".join(f"'{id}'" for id in safe_ids)
-        sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})"
-
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            cursor.execute(sql)
+        # Use parameterized query to prevent SQL injection
+        placeholders = ",".join("?" for _ in safe_ids)
+        sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({placeholders})"
+
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                cursor.execute(sql, binding_params=safe_ids)

     def delete_by_metadata_field(self, key: str, value: str) -> None:
         """Delete documents by metadata field."""
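Switching from inlined string literals to `?` placeholders is the standard SQL-injection guard: the values travel to the driver separately from the SQL text, so they are always treated as data. The placeholder construction in isolation (plain Python, no ClickZetta dependency):

```python
# Standalone sketch of the placeholder construction used above.
ids = ["doc-1", "doc-2", "doc'); DROP TABLE t; --"]

placeholders = ",".join("?" for _ in ids)  # -> "?,?,?"
sql = f"DELETE FROM some_schema.some_table WHERE id IN ({placeholders})"

# The third id is bound as a literal value rather than spliced into the SQL,
# e.g. cursor.execute(sql, binding_params=ids) in the code above.
assert placeholders == "?,?,?"
```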
@@ -495,17 +769,28 @@ class ClickzettaVector(BaseVector):

     def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
         """Implementation of delete by metadata field (executed in write worker thread)."""
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            # Using JSON path to filter with parameterized query
-            # Note: JSON path requires literal key name, cannot be parameterized
-            # Use json_extract_string function for ClickZetta compatibility
-            sql = (f"DELETE FROM {self._config.schema_name}.{self._table_name} "
-                   f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?")
-            cursor.execute(sql, [value])
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                # Using JSON path to filter with parameterized query
+                # Note: JSON path requires literal key name, cannot be parameterized
+                # Use json_extract_string function for ClickZetta compatibility
+                sql = (
+                    f"DELETE FROM {self._config.schema_name}.{self._table_name} "
+                    f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?"
+                )
+                cursor.execute(sql, binding_params=[value])

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         """Search for documents by vector similarity."""
         # Check if table exists first
         if not self._table_exists():
+            logger.warning(
+                "Table %s.%s does not exist, returning empty results",
+                self._config.schema_name,
+                self._table_name,
+            )
             return []

         top_k = kwargs.get("top_k", 10)
         score_threshold = kwargs.get("score_threshold", 0.0)
         document_ids_filter = kwargs.get("document_ids_filter")
@@ -532,15 +817,15 @@ class ClickzettaVector(BaseVector):
             distance_func = "COSINE_DISTANCE"
             if score_threshold > 0:
                 query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
-                filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, "
-                                      f"{query_vector_str}) < {2 - score_threshold}")
+                filter_clauses.append(
+                    f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}"
+                )
         else:
             # For L2 distance, smaller is better
             distance_func = "L2_DISTANCE"
             if score_threshold > 0:
                 query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
-                filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, "
-                                      f"{query_vector_str}) < {score_threshold}")
+                filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}")

         where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1"

@@ -556,55 +841,31 @@ class ClickzettaVector(BaseVector):
         """

         documents = []
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            # Use hints parameter for vector search optimization
-            search_hints = {
-                'hints': {
-                    'sdk.job.timeout': 60,  # Increase timeout for vector search
-                    'cz.sql.job.fast.mode': True,
-                    'cz.storage.parquet.vector.index.read.memory.cache': True
-                }
-            }
-            cursor.execute(search_sql, parameters=search_hints)
-            results = cursor.fetchall()
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                # Use hints parameter for vector search optimization
+                search_hints = {
+                    "hints": {
+                        "sdk.job.timeout": 60,  # Increase timeout for vector search
+                        "cz.sql.job.fast.mode": True,
+                        "cz.storage.parquet.vector.index.read.memory.cache": True,
+                    }
+                }
+                cursor.execute(search_sql, search_hints)
+                results = cursor.fetchall()

-            for row in results:
-                # Parse metadata from JSON string (may be double-encoded)
-                try:
-                    if row[2]:
-                        metadata = json.loads(row[2])
-
-                        # If result is a string, it's double-encoded JSON - parse again
-                        if isinstance(metadata, str):
-                            metadata = json.loads(metadata)
-
-                        if not isinstance(metadata, dict):
-                            metadata = {}
-                    else:
-                        metadata = {}
-                except (json.JSONDecodeError, TypeError) as e:
-                    logger.error("JSON parsing failed: %s", e)
-                    # Fallback: extract document_id with regex
-                    import re
-                    doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
-                    metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
-
-                # Ensure required fields are set
-                metadata["doc_id"] = row[0]  # segment id
-
-                # Ensure document_id exists (critical for Dify's format_retrieval_documents)
-                if "document_id" not in metadata:
-                    metadata["document_id"] = row[0]  # fallback to segment id
-
-                # Add score based on distance
-                if self._config.vector_distance_function == "cosine_distance":
-                    metadata["score"] = 1 - (row[3] / 2)
-                else:
-                    metadata["score"] = 1 / (1 + row[3])
-
-                doc = Document(page_content=row[1], metadata=metadata)
-                documents.append(doc)
+                for row in results:
+                    # Parse metadata using centralized method
+                    metadata = self._parse_metadata(row[2], row[0])
+
+                    # Add score based on distance
+                    if self._config.vector_distance_function == "cosine_distance":
+                        metadata["score"] = 1 - (row[3] / 2)
+                    else:
+                        metadata["score"] = 1 / (1 + row[3])
+
+                    doc = Document(page_content=row[1], metadata=metadata)
+                    documents.append(doc)

         return documents

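The distance-to-score mapping is worth spelling out: COSINE_DISTANCE lies in [0, 2], so `1 - d/2` maps it onto [0, 1] with 1 meaning identical direction, while L2 distance is unbounded, so `1/(1 + d)` squashes it into (0, 1]. A quick standalone check:

```python
# Standalone check of the distance-to-score mappings used above.
def cosine_score(distance: float) -> float:
    return 1 - distance / 2   # COSINE_DISTANCE in [0, 2] -> score in [0, 1]

def l2_score(distance: float) -> float:
    return 1 / (1 + distance)  # L2_DISTANCE in [0, inf) -> score in (0, 1]

assert cosine_score(0.0) == 1.0  # identical vectors
assert cosine_score(2.0) == 0.0  # opposite vectors
assert l2_score(0.0) == 1.0
assert 0.0 < l2_score(9.0) <= 0.1
```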
@@ -614,6 +875,15 @@ class ClickzettaVector(BaseVector):
             logger.warning("Full-text search is not enabled. Enable inverted index in config.")
             return []

+        # Check if table exists first
+        if not self._table_exists():
+            logger.warning(
+                "Table %s.%s does not exist, returning empty results",
+                self._config.schema_name,
+                self._table_name,
+            )
+            return []
+
         top_k = kwargs.get("top_k", 10)
         document_ids_filter = kwargs.get("document_ids_filter")

@@ -649,61 +919,70 @@ class ClickzettaVector(BaseVector):
         """

         documents = []
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            try:
-                # Use hints parameter for full-text search optimization
-                fulltext_hints = {
-                    'hints': {
-                        'sdk.job.timeout': 30,  # Timeout for full-text search
-                        'cz.sql.job.fast.mode': True,
-                        'cz.sql.index.prewhere.enabled': True
-                    }
-                }
-                cursor.execute(search_sql, parameters=fulltext_hints)
-                results = cursor.fetchall()
-
-                for row in results:
-                    # Parse metadata from JSON string (may be double-encoded)
-                    try:
-                        if row[2]:
-                            metadata = json.loads(row[2])
-
-                            # If result is a string, it's double-encoded JSON - parse again
-                            if isinstance(metadata, str):
-                                metadata = json.loads(metadata)
-
-                            if not isinstance(metadata, dict):
-                                metadata = {}
-                        else:
-                            metadata = {}
-                    except (json.JSONDecodeError, TypeError) as e:
-                        logger.error("JSON parsing failed: %s", e)
-                        # Fallback: extract document_id with regex
-                        import re
-                        doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
-                        metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
-
-                    # Ensure required fields are set
-                    metadata["doc_id"] = row[0]  # segment id
-
-                    # Ensure document_id exists (critical for Dify's format_retrieval_documents)
-                    if "document_id" not in metadata:
-                        metadata["document_id"] = row[0]  # fallback to segment id
-
-                    # Add a relevance score for full-text search
-                    metadata["score"] = 1.0  # Clickzetta doesn't provide relevance scores
-                    doc = Document(page_content=row[1], metadata=metadata)
-                    documents.append(doc)
-            except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
-                logger.exception("Full-text search failed")
-                # Fallback to LIKE search if full-text search fails
-                return self._search_by_like(query, **kwargs)
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                try:
+                    # Use hints parameter for full-text search optimization
+                    fulltext_hints = {
+                        "hints": {
+                            "sdk.job.timeout": 30,  # Timeout for full-text search
+                            "cz.sql.job.fast.mode": True,
+                            "cz.sql.index.prewhere.enabled": True,
+                        }
+                    }
+                    cursor.execute(search_sql, fulltext_hints)
+                    results = cursor.fetchall()
+
+                    for row in results:
+                        # Parse metadata from JSON string (may be double-encoded)
+                        try:
+                            if row[2]:
+                                metadata = json.loads(row[2])
+
+                                # If result is a string, it's double-encoded JSON - parse again
+                                if isinstance(metadata, str):
+                                    metadata = json.loads(metadata)
+
+                                if not isinstance(metadata, dict):
+                                    metadata = {}
+                            else:
+                                metadata = {}
+                        except (json.JSONDecodeError, TypeError) as e:
+                            logger.exception("JSON parsing failed")
+                            # Fallback: extract document_id with regex
+                            doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
+                            metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
+
+                        # Ensure required fields are set
+                        metadata["doc_id"] = row[0]  # segment id
+
+                        # Ensure document_id exists (critical for Dify's format_retrieval_documents)
+                        if "document_id" not in metadata:
+                            metadata["document_id"] = row[0]  # fallback to segment id
+
+                        # Add a relevance score for full-text search
+                        metadata["score"] = 1.0  # Clickzetta doesn't provide relevance scores
+                        doc = Document(page_content=row[1], metadata=metadata)
+                        documents.append(doc)
+                except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
+                    logger.exception("Full-text search failed")
+                    # Fallback to LIKE search if full-text search fails
+                    return self._search_by_like(query, **kwargs)

         return documents

     def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]:
         """Fallback search using LIKE operator."""
+        # Check if table exists first
+        if not self._table_exists():
+            logger.warning(
+                "Table %s.%s does not exist, returning empty results",
+                self._config.schema_name,
+                self._table_name,
+            )
+            return []
+
         top_k = kwargs.get("top_k", 10)
         document_ids_filter = kwargs.get("document_ids_filter")

@@ -735,62 +1014,37 @@ class ClickzettaVector(BaseVector):
         """

         documents = []
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            # Use hints parameter for LIKE search optimization
-            like_hints = {
-                'hints': {
-                    'sdk.job.timeout': 20,  # Timeout for LIKE search
-                    'cz.sql.job.fast.mode': True
-                }
-            }
-            cursor.execute(search_sql, parameters=like_hints)
-            results = cursor.fetchall()
-
-            for row in results:
-                # Parse metadata from JSON string (may be double-encoded)
-                try:
-                    if row[2]:
-                        metadata = json.loads(row[2])
-
-                        # If result is a string, it's double-encoded JSON - parse again
-                        if isinstance(metadata, str):
-                            metadata = json.loads(metadata)
-
-                        if not isinstance(metadata, dict):
-                            metadata = {}
-                    else:
-                        metadata = {}
-                except (json.JSONDecodeError, TypeError) as e:
-                    logger.error("JSON parsing failed: %s", e)
-                    # Fallback: extract document_id with regex
-                    import re
-                    doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
-                    metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
-
-                # Ensure required fields are set
-                metadata["doc_id"] = row[0]  # segment id
-
-                # Ensure document_id exists (critical for Dify's format_retrieval_documents)
-                if "document_id" not in metadata:
-                    metadata["document_id"] = row[0]  # fallback to segment id
-
-                metadata["score"] = 0.5  # Lower score for LIKE search
-                doc = Document(page_content=row[1], metadata=metadata)
-                documents.append(doc)
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                # Use hints parameter for LIKE search optimization
+                like_hints = {
+                    "hints": {
+                        "sdk.job.timeout": 20,  # Timeout for LIKE search
+                        "cz.sql.job.fast.mode": True,
+                    }
+                }
+                cursor.execute(search_sql, like_hints)
+                results = cursor.fetchall()
+
+                for row in results:
+                    # Parse metadata using centralized method
+                    metadata = self._parse_metadata(row[2], row[0])
+
+                    metadata["score"] = 0.5  # Lower score for LIKE search
+                    doc = Document(page_content=row[1], metadata=metadata)
+                    documents.append(doc)

         return documents

     def delete(self) -> None:
         """Delete the entire collection."""
-        connection = self._ensure_connection()
-        with connection.cursor() as cursor:
-            cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
+        with self.get_connection_context() as connection:
+            with connection.cursor() as cursor:
+                cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")

     def _format_vector_simple(self, vector: list[float]) -> str:
         """Simple vector formatting for SQL queries."""
-        return ','.join(map(str, vector))
+        return ",".join(map(str, vector))

     def _safe_doc_id(self, doc_id: str) -> str:
         """Ensure doc_id is safe for SQL and doesn't contain special characters."""
@@ -799,13 +1053,12 @@ class ClickzettaVector(BaseVector):
         # Remove or replace potentially problematic characters
         safe_id = str(doc_id)
         # Only allow alphanumeric, hyphens, underscores
-        safe_id = ''.join(c for c in safe_id if c.isalnum() or c in '-_')
+        safe_id = "".join(c for c in safe_id if c.isalnum() or c in "-_")
         if not safe_id:  # If all characters were removed
             return str(uuid.uuid4())
         return safe_id[:255]  # Limit length

-
 class ClickzettaVectorFactory(AbstractVectorFactory):
     """Factory for creating Clickzetta vector instances."""

@@ -831,4 +1084,3 @@ class ClickzettaVectorFactory(AbstractVectorFactory):
         collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()

         return ClickzettaVector(collection_name=collection_name, config=config)
-
@@ -246,6 +246,10 @@ class TencentVector(BaseVector):
         return self._get_search_res(res, score_threshold)

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filter = None
+        if document_ids_filter:
+            filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
         if not self._enable_hybrid_search:
             return []
         res = self._client.hybrid_search(
@@ -269,6 +273,7 @@ class TencentVector(BaseVector):
             ),
             retrieve_vector=False,
             limit=kwargs.get("top_k", 4),
+            filter=filter,
         )
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         return self._get_search_res(res, score_threshold)
@@ -62,7 +62,7 @@ class WordExtractor(BaseExtractor):

     def extract(self) -> list[Document]:
         """Load given path as single page."""
-        content = self.parse_docx(self.file_path, "storage")
+        content = self.parse_docx(self.file_path)
         return [
             Document(
                 page_content=content,
@@ -189,23 +189,8 @@ class WordExtractor(BaseExtractor):
             paragraph_content.append(run.text)
         return "".join(paragraph_content).strip()

-    def _parse_paragraph(self, paragraph, image_map):
-        paragraph_content = []
-        for run in paragraph.runs:
-            if run.element.xpath(".//a:blip"):
-                for blip in run.element.xpath(".//a:blip"):
-                    embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
-                    if embed_id:
-                        rel_target = run.part.rels[embed_id].target_ref
-                        if rel_target in image_map:
-                            paragraph_content.append(image_map[rel_target])
-            if run.text.strip():
-                paragraph_content.append(run.text.strip())
-        return " ".join(paragraph_content) if paragraph_content else ""
-
-    def parse_docx(self, docx_path, image_folder):
+    def parse_docx(self, docx_path):
         doc = DocxDocument(docx_path)
-        os.makedirs(image_folder, exist_ok=True)

         content = []

@@ -29,7 +29,7 @@ from core.tools.errors import (
     ToolProviderCredentialValidationError,
     ToolProviderNotFoundError,
 )
-from core.tools.utils.message_transformer import ToolFileMessageTransformer
+from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
 from core.tools.workflow_as_tool.tool import WorkflowTool
 from extensions.ext_database import db
 from models.enums import CreatorUserRole
@@ -247,7 +247,8 @@ class ToolEngine:
             )
         elif response.type == ToolInvokeMessage.MessageType.JSON:
             result += json.dumps(
-                cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
+                safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
+                ensure_ascii=False,
             )
         else:
             result += str(response.message)
@@ -1,7 +1,14 @@
 import logging
 from collections.abc import Generator
+from datetime import date, datetime
+from decimal import Decimal
 from mimetypes import guess_extension
-from typing import Optional
+from typing import Optional, cast
+from uuid import UUID

+import numpy as np
+import pytz
+from flask_login import current_user
+
 from core.file import File, FileTransferMethod, FileType
 from core.tools.entities.tool_entities import ToolInvokeMessage
@@ -10,6 +17,41 @@ from core.tools.tool_file_manager import ToolFileManager
 logger = logging.getLogger(__name__)


+def safe_json_value(v):
+    if isinstance(v, datetime):
+        tz_name = getattr(current_user, "timezone", None) if current_user is not None else None
+        if not tz_name:
+            tz_name = "UTC"
+        return v.astimezone(pytz.timezone(tz_name)).isoformat()
+    elif isinstance(v, date):
+        return v.isoformat()
+    elif isinstance(v, UUID):
+        return str(v)
+    elif isinstance(v, Decimal):
+        return float(v)
+    elif isinstance(v, bytes):
+        try:
+            return v.decode("utf-8")
+        except UnicodeDecodeError:
+            return v.hex()
+    elif isinstance(v, memoryview):
+        return v.tobytes().hex()
+    elif isinstance(v, np.ndarray):
+        return v.tolist()
+    elif isinstance(v, dict):
+        return safe_json_dict(v)
+    elif isinstance(v, list | tuple | set):
+        return [safe_json_value(i) for i in v]
+    else:
+        return v
+
+
+def safe_json_dict(d):
+    if not isinstance(d, dict):
+        raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
+    return {k: safe_json_value(v) for k, v in d.items()}
+
+
 class ToolFileMessageTransformer:
     @classmethod
     def transform_tool_invoke_messages(
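`safe_json_value` recursively rewrites JSON-hostile Python values before `json.dumps` sees them; the datetime branch additionally localizes to the current user's timezone. A small usage sketch with illustrative values (the datetime case is omitted here because it needs a Flask request context):

```python
# Quick sketch of what safe_json_dict produces (illustrative values).
import json
from datetime import date
from decimal import Decimal
from uuid import UUID

payload = {
    "day": date(2025, 1, 1),                             # -> "2025-01-01"
    "price": Decimal("9.99"),                            # -> 9.99 (float)
    "id": UUID("12345678-1234-5678-1234-567812345678"),  # -> str
    "blob": b"\xff\xfe",                                 # undecodable bytes -> hex string
}
print(json.dumps(safe_json_dict(payload)))  # serializes without TypeError
```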
@@ -113,6 +155,12 @@ class ToolFileMessageTransformer:
                     )
                 else:
                     yield message

+            elif message.type == ToolInvokeMessage.MessageType.JSON:
+                if isinstance(message.message, ToolInvokeMessage.JsonMessage):
+                    json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
+                    json_msg.json_object = safe_json_value(json_msg.json_object)
+                yield message
             else:
                 yield message
@@ -119,6 +119,13 @@ class ObjectSegment(Segment):


 class ArraySegment(Segment):
+    @property
+    def text(self) -> str:
+        # Return empty string for empty arrays instead of "[]"
+        if not self.value:
+            return ""
+        return super().text
+
     @property
     def markdown(self) -> str:
         items = []
@@ -155,6 +162,9 @@ class ArrayStringSegment(ArraySegment):

     @property
     def text(self) -> str:
+        # Return empty string for empty arrays instead of "[]"
+        if not self.value:
+            return ""
         return json.dumps(self.value, ensure_ascii=False)

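The effect of the two `text` overrides, sketched with plain values (a hypothetical stand-in, not the actual Segment classes):

```python
# Behavior sketch of the empty-array rendering change (illustrative function).
import json

def array_text(value: list) -> str:
    if not value:
        return ""  # empty arrays now render as "", not "[]"
    return json.dumps(value, ensure_ascii=False)

assert array_text([]) == ""
assert array_text(["a", "b"]) == '["a", "b"]'
```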
@@ -168,7 +168,57 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
 def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
     """Extract text from a file based on its file extension."""
     match file_extension:
-        case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml":
+        case (
+            ".txt"
+            | ".markdown"
+            | ".md"
+            | ".html"
+            | ".htm"
+            | ".xml"
+            | ".c"
+            | ".h"
+            | ".cpp"
+            | ".hpp"
+            | ".cc"
+            | ".cxx"
+            | ".c++"
+            | ".py"
+            | ".js"
+            | ".ts"
+            | ".jsx"
+            | ".tsx"
+            | ".java"
+            | ".php"
+            | ".rb"
+            | ".go"
+            | ".rs"
+            | ".swift"
+            | ".kt"
+            | ".scala"
+            | ".sh"
+            | ".bash"
+            | ".bat"
+            | ".ps1"
+            | ".sql"
+            | ".r"
+            | ".m"
+            | ".pl"
+            | ".lua"
+            | ".vim"
+            | ".asm"
+            | ".s"
+            | ".css"
+            | ".scss"
+            | ".less"
+            | ".sass"
+            | ".ini"
+            | ".cfg"
+            | ".conf"
+            | ".toml"
+            | ".env"
+            | ".log"
+            | ".vtt"
+        ):
             return _extract_text_from_plain_text(file_content)
         case ".json":
             return _extract_text_from_json(file_content)
@@ -194,8 +244,6 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
             return _extract_text_from_eml(file_content)
         case ".msg":
             return _extract_text_from_msg(file_content)
-        case ".vtt":
-            return _extract_text_from_vtt(file_content)
         case ".properties":
             return _extract_text_from_properties(file_content)
         case _: