mirror of https://github.com/infiniflow/ragflow.git (synced 2026-03-06 08:06:43 +08:00)
### What problem does this PR solve?

When multiple columns are used as content columns in the RDBMS connector, the generated document text is chunked by TxtParser, which strips newline delimiters during merge. As a result, field names and values from different columns are concatenated without any separator, making the content unreadable.

Changes:

- txt_parser.py: restore the newline separator when merging adjacent text segments within a chunk, so that split sections are not concatenated directly
- rdbms_connector.py: use a double newline between fields and place each field value on a new line after the bracketed field name, giving TxtParser clearer boundaries to work with

Closes #13001

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: tunsuytang <tunsuytang@tencent.com>
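For illustration, a minimal sketch of the content layout this change produces for a row with two content columns. The `name`/`description` columns and their values are hypothetical; the snippet only mirrors the formatting logic in `_row_to_document` below.

```python
# Hypothetical row and content columns, mirroring the formatting in _row_to_document.
row = {"name": "Widget A", "description": "A small widget for testing."}
content_columns = ["name", "description"]

# Bracketed field name, value on its own line, fields separated by a blank line,
# so TxtParser keeps field boundaries when it chunks and re-merges the text.
content = "\n\n".join(f"【{col}】:\n{row[col]}" for col in content_columns)
print(content)
# 【name】:
# Widget A
#
# 【description】:
# A small widget for testing.
```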
407 lines · 15 KiB · Python
"""RDBMS (MySQL/PostgreSQL) data source connector for importing data from relational databases."""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from enum import Enum
|
|
from typing import Any, Dict, Generator, Optional, Union
|
|
|
|
from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE
|
|
from common.data_source.exceptions import (
|
|
ConnectorMissingCredentialError,
|
|
ConnectorValidationError,
|
|
)
|
|
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
|
|
from common.data_source.models import Document
|
|
|
|
|
|
class DatabaseType(str, Enum):
    """Supported database types."""
    MYSQL = "mysql"
    POSTGRESQL = "postgresql"

class RDBMSConnector(LoadConnector, PollConnector):
    """
    RDBMS connector for importing data from MySQL and PostgreSQL databases.

    This connector allows users to:
    1. Connect to a MySQL or PostgreSQL database
    2. Execute a SQL query to extract data
    3. Map columns to content (for vectorization) and metadata
    4. Sync data in batch or incremental mode using a timestamp column
    """
    def __init__(
        self,
        db_type: str,
        host: str,
        port: int,
        database: str,
        query: str,
        content_columns: str,
        metadata_columns: Optional[str] = None,
        id_column: Optional[str] = None,
        timestamp_column: Optional[str] = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        """
        Initialize the RDBMS connector.

        Args:
            db_type: Database type ('mysql' or 'postgresql')
            host: Database host
            port: Database port
            database: Database name
            query: SQL query to execute (e.g., "SELECT * FROM products WHERE status = 'active'")
            content_columns: Comma-separated column names to use for document content
            metadata_columns: Comma-separated column names to use as metadata (optional)
            id_column: Column to use as unique document ID (optional, will generate hash if not provided)
            timestamp_column: Column to use for incremental sync (optional, must be datetime/timestamp type)
            batch_size: Number of documents per batch
        """
        self.db_type = DatabaseType(db_type.lower())
        self.host = host.strip()
        self.port = port
        self.database = database.strip()
        self.query = query.strip()
        self.content_columns = [c.strip() for c in content_columns.split(",") if c.strip()]
        self.metadata_columns = [c.strip() for c in (metadata_columns or "").split(",") if c.strip()]
        self.id_column = id_column.strip() if id_column else None
        self.timestamp_column = timestamp_column.strip() if timestamp_column else None
        self.batch_size = batch_size

        self._connection = None
        self._credentials: Dict[str, Any] = {}

    def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
        """Load database credentials."""
        logging.debug(f"Loading credentials for {self.db_type} database: {self.database}")

        required_keys = ["username", "password"]
        for key in required_keys:
            if not credentials.get(key):
                raise ConnectorMissingCredentialError(f"RDBMS ({self.db_type}): missing {key}")

        self._credentials = credentials
        return None

    def _get_connection(self):
        """Create and return a database connection."""
        if self._connection is not None:
            return self._connection

        username = self._credentials.get("username")
        password = self._credentials.get("password")

        if self.db_type == DatabaseType.MYSQL:
            try:
                import mysql.connector
            except ImportError:
                raise ConnectorValidationError(
                    "MySQL connector not installed. Please install mysql-connector-python."
                )
            try:
                self._connection = mysql.connector.connect(
                    host=self.host,
                    port=self.port,
                    database=self.database,
                    user=username,
                    password=password,
                    charset='utf8mb4',
                    use_unicode=True,
                )
            except Exception as e:
                raise ConnectorValidationError(f"Failed to connect to MySQL: {e}")
        elif self.db_type == DatabaseType.POSTGRESQL:
            try:
                import psycopg2
            except ImportError:
                raise ConnectorValidationError(
                    "PostgreSQL connector not installed. Please install psycopg2-binary."
                )
            try:
                self._connection = psycopg2.connect(
                    host=self.host,
                    port=self.port,
                    dbname=self.database,
                    user=username,
                    password=password,
                )
            except Exception as e:
                raise ConnectorValidationError(f"Failed to connect to PostgreSQL: {e}")

        return self._connection

    def _close_connection(self):
        """Close the database connection."""
        if self._connection is not None:
            try:
                self._connection.close()
            except Exception:
                pass
            self._connection = None

    def _get_tables(self) -> list[str]:
        """Get list of all tables in the database."""
        connection = self._get_connection()
        cursor = connection.cursor()

        try:
            if self.db_type == DatabaseType.MYSQL:
                cursor.execute("SHOW TABLES")
            else:
                cursor.execute(
                    "SELECT table_name FROM information_schema.tables "
                    "WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
                )
            tables = [row[0] for row in cursor.fetchall()]
            return tables
        finally:
            cursor.close()

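    # Illustrative example of the filtering done by the method below; the query,
    # column name, and timestamps here are hypothetical:
    #   query            = "SELECT * FROM products WHERE status = 'active'"
    #   timestamp_column = "updated_at"
    #   result (MySQL)   = "SELECT * FROM products WHERE status = 'active'
    #                       AND updated_at > '2024-01-01 00:00:00' AND updated_at <= '2024-01-02 00:00:00'"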
    def _build_query_with_time_filter(
        self,
        start: Optional[datetime] = None,
        end: Optional[datetime] = None,
    ) -> str:
        """Build the query with optional time filtering for incremental sync."""
        if not self.query:
            return ""  # Will be handled by table discovery
        base_query = self.query.rstrip(";")

        if not self.timestamp_column or (start is None and end is None):
            return base_query

        has_where = "where" in base_query.lower()
        connector = " AND" if has_where else " WHERE"

        time_conditions = []
        if start is not None:
            if self.db_type == DatabaseType.MYSQL:
                time_conditions.append(f"{self.timestamp_column} > '{start.strftime('%Y-%m-%d %H:%M:%S')}'")
            else:
                time_conditions.append(f"{self.timestamp_column} > '{start.isoformat()}'")

        if end is not None:
            if self.db_type == DatabaseType.MYSQL:
                time_conditions.append(f"{self.timestamp_column} <= '{end.strftime('%Y-%m-%d %H:%M:%S')}'")
            else:
                time_conditions.append(f"{self.timestamp_column} <= '{end.isoformat()}'")

        if time_conditions:
            return f"{base_query}{connector} {' AND '.join(time_conditions)}"

        return base_query

    def _row_to_document(self, row: Union[tuple, list, Dict[str, Any]], column_names: list) -> Document:
        """Convert a database row to a Document."""
        row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row

        content_parts = []
        for col in self.content_columns:
            if col in row_dict and row_dict[col] is not None:
                value = row_dict[col]
                if isinstance(value, (dict, list)):
                    value = json.dumps(value, ensure_ascii=False)
                # Use brackets around field name and put value on a new line
                # so that TxtParser preserves field boundaries after chunking.
                content_parts.append(f"【{col}】:\n{value}")

        content = "\n\n".join(content_parts)

        if self.id_column and self.id_column in row_dict:
            doc_id = f"{self.db_type}:{self.database}:{row_dict[self.id_column]}"
        else:
            content_hash = hashlib.md5(content.encode()).hexdigest()
            doc_id = f"{self.db_type}:{self.database}:{content_hash}"

        metadata = {}
        for col in self.metadata_columns:
            if col in row_dict and row_dict[col] is not None:
                value = row_dict[col]
                if isinstance(value, datetime):
                    value = value.isoformat()
                elif isinstance(value, (dict, list)):
                    value = json.dumps(value, ensure_ascii=False)
                else:
                    value = str(value)
                metadata[col] = value

        doc_updated_at = datetime.now(timezone.utc)
        if self.timestamp_column and self.timestamp_column in row_dict:
            ts_value = row_dict[self.timestamp_column]
            if isinstance(ts_value, datetime):
                if ts_value.tzinfo is None:
                    doc_updated_at = ts_value.replace(tzinfo=timezone.utc)
                else:
                    doc_updated_at = ts_value

        first_content_col = self.content_columns[0] if self.content_columns else "record"
        semantic_id = str(row_dict.get(first_content_col, "database_record")).replace("\n", " ").replace("\r", " ").strip()[:100]

        return Document(
            id=doc_id,
            blob=content.encode("utf-8"),
            source=DocumentSource(self.db_type.value),
            semantic_identifier=semantic_id,
            extension=".txt",
            doc_updated_at=doc_updated_at,
            size_bytes=len(content.encode("utf-8")),
            metadata=metadata if metadata else None,
        )

    def _yield_documents_from_query(
        self,
        query: str,
    ) -> Generator[list[Document], None, None]:
        """Generate documents from a single query."""
        connection = self._get_connection()
        cursor = connection.cursor()

        try:
            logging.info(f"Executing query: {query[:200]}...")
            cursor.execute(query)
            column_names = [desc[0] for desc in cursor.description]

            batch: list[Document] = []
            for row in cursor:
                try:
                    doc = self._row_to_document(row, column_names)
                    batch.append(doc)

                    if len(batch) >= self.batch_size:
                        yield batch
                        batch = []
                except Exception as e:
                    logging.warning(f"Error converting row to document: {e}")
                    continue

            if batch:
                yield batch

        finally:
            try:
                cursor.fetchall()
            except Exception:
                pass
            cursor.close()

    def _yield_documents(
        self,
        start: Optional[datetime] = None,
        end: Optional[datetime] = None,
    ) -> Generator[list[Document], None, None]:
        """Generate documents from database query results."""
        if self.query:
            query = self._build_query_with_time_filter(start, end)
            yield from self._yield_documents_from_query(query)
        else:
            tables = self._get_tables()
            logging.info(f"No query specified. Loading all {len(tables)} tables: {tables}")
            for table in tables:
                query = f"SELECT * FROM {table}"
                logging.info(f"Loading table: {table}")
                yield from self._yield_documents_from_query(query)

        self._close_connection()

    def load_from_state(self) -> Generator[list[Document], None, None]:
        """Load all documents from the database (full sync)."""
        logging.debug(f"Loading all records from {self.db_type} database: {self.database}")
        return self._yield_documents()

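    # Illustrative: start/end are Unix epoch seconds. For example, poll_source(1704067200, 1704153600)
    # (hypothetical values) limits results to rows whose timestamp_column lies in
    # (2024-01-01 00:00:00 UTC, 2024-01-02 00:00:00 UTC].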
    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> Generator[list[Document], None, None]:
        """Poll for new/updated documents since the last sync (incremental sync)."""
        if not self.timestamp_column:
            logging.warning(
                "No timestamp column configured for incremental sync. "
                "Falling back to full sync."
            )
            return self.load_from_state()

        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

        logging.debug(
            f"Polling {self.db_type} database {self.database} "
            f"from {start_datetime} to {end_datetime}"
        )

        return self._yield_documents(start_datetime, end_datetime)

    def validate_connector_settings(self) -> None:
        """Validate connector settings by testing the connection."""
        if not self._credentials:
            raise ConnectorMissingCredentialError("RDBMS credentials not loaded.")

        if not self.host:
            raise ConnectorValidationError("Database host is required.")

        if not self.database:
            raise ConnectorValidationError("Database name is required.")

        if not self.content_columns:
            raise ConnectorValidationError(
                "At least one content column must be specified."
            )

        try:
            connection = self._get_connection()
            cursor = connection.cursor()

            test_query = "SELECT 1"
            cursor.execute(test_query)
            cursor.fetchone()
            cursor.close()

            logging.info(f"Successfully connected to {self.db_type} database: {self.database}")

        except ConnectorValidationError:
            self._close_connection()
            raise
        except Exception as e:
            self._close_connection()
            raise ConnectorValidationError(
                f"Failed to connect to {self.db_type} database: {str(e)}"
            )
        finally:
            self._close_connection()

if __name__ == "__main__":
    import os

    credentials_dict = {
        "username": os.environ.get("DB_USERNAME", "root"),
        "password": os.environ.get("DB_PASSWORD", ""),
    }

    connector = RDBMSConnector(
        db_type="mysql",
        host=os.environ.get("DB_HOST", "localhost"),
        port=int(os.environ.get("DB_PORT", "3306")),
        database=os.environ.get("DB_NAME", "test"),
        query="SELECT * FROM products LIMIT 10",
        content_columns="name,description",
        metadata_columns="id,category,price",
        id_column="id",
        timestamp_column="updated_at",
    )

    try:
        connector.load_credentials(credentials_dict)
        connector.validate_connector_settings()

        for batch in connector.load_from_state():
            print(f"Batch of {len(batch)} documents:")
            for doc in batch:
                print(f"  - {doc.id}: {doc.semantic_identifier}")
            break

    except Exception as e:
        print(f"Error: {e}")