mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
Merge branch 'main' into feat/rag-2
# Conflicts: # web/app/components/workflow/hooks/use-workflow.ts
This commit is contained in:
@ -22,22 +22,50 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticSearchConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
username: str
|
||||
password: str
|
||||
# Regular Elasticsearch config
|
||||
host: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
|
||||
# Elastic Cloud specific config
|
||||
cloud_url: Optional[str] = None # Cloud URL for Elasticsearch Cloud
|
||||
api_key: Optional[str] = None
|
||||
|
||||
# Common config
|
||||
use_cloud: bool = False
|
||||
ca_certs: Optional[str] = None
|
||||
verify_certs: bool = False
|
||||
request_timeout: int = 100000
|
||||
retry_on_timeout: bool = True
|
||||
max_retries: int = 10000
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config HOST is required")
|
||||
if not values["port"]:
|
||||
raise ValueError("config PORT is required")
|
||||
if not values["username"]:
|
||||
raise ValueError("config USERNAME is required")
|
||||
if not values["password"]:
|
||||
raise ValueError("config PASSWORD is required")
|
||||
use_cloud = values.get("use_cloud", False)
|
||||
cloud_url = values.get("cloud_url")
|
||||
|
||||
if use_cloud:
|
||||
# Cloud configuration validation - requires cloud_url and api_key
|
||||
if not cloud_url:
|
||||
raise ValueError("cloud_url is required for Elastic Cloud")
|
||||
|
||||
api_key = values.get("api_key")
|
||||
if not api_key:
|
||||
raise ValueError("api_key is required for Elastic Cloud")
|
||||
|
||||
else:
|
||||
# Regular Elasticsearch validation
|
||||
if not values.get("host"):
|
||||
raise ValueError("config HOST is required for regular Elasticsearch")
|
||||
if not values.get("port"):
|
||||
raise ValueError("config PORT is required for regular Elasticsearch")
|
||||
if not values.get("username"):
|
||||
raise ValueError("config USERNAME is required for regular Elasticsearch")
|
||||
if not values.get("password"):
|
||||
raise ValueError("config PASSWORD is required for regular Elasticsearch")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@ -50,21 +78,69 @@ class ElasticSearchVector(BaseVector):
|
||||
self._attributes = attributes
|
||||
|
||||
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
|
||||
"""
|
||||
Initialize Elasticsearch client for both regular Elasticsearch and Elastic Cloud.
|
||||
"""
|
||||
try:
|
||||
parsed_url = urlparse(config.host)
|
||||
if parsed_url.scheme in {"http", "https"}:
|
||||
hosts = f"{config.host}:{config.port}"
|
||||
# Check if using Elastic Cloud
|
||||
client_config: dict[str, Any]
|
||||
if config.use_cloud and config.cloud_url:
|
||||
client_config = {
|
||||
"request_timeout": config.request_timeout,
|
||||
"retry_on_timeout": config.retry_on_timeout,
|
||||
"max_retries": config.max_retries,
|
||||
"verify_certs": config.verify_certs,
|
||||
}
|
||||
|
||||
# Parse cloud URL and configure hosts
|
||||
parsed_url = urlparse(config.cloud_url)
|
||||
host = f"{parsed_url.scheme}://{parsed_url.hostname}"
|
||||
if parsed_url.port:
|
||||
host += f":{parsed_url.port}"
|
||||
|
||||
client_config["hosts"] = [host]
|
||||
|
||||
# API key authentication for cloud
|
||||
client_config["api_key"] = config.api_key
|
||||
|
||||
# SSL settings
|
||||
if config.ca_certs:
|
||||
client_config["ca_certs"] = config.ca_certs
|
||||
|
||||
else:
|
||||
hosts = f"http://{config.host}:{config.port}"
|
||||
client = Elasticsearch(
|
||||
hosts=hosts,
|
||||
basic_auth=(config.username, config.password),
|
||||
request_timeout=100000,
|
||||
retry_on_timeout=True,
|
||||
max_retries=10000,
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise ConnectionError("Vector database connection error")
|
||||
# Regular Elasticsearch configuration
|
||||
parsed_url = urlparse(config.host or "")
|
||||
if parsed_url.scheme in {"http", "https"}:
|
||||
hosts = f"{config.host}:{config.port}"
|
||||
use_https = parsed_url.scheme == "https"
|
||||
else:
|
||||
hosts = f"http://{config.host}:{config.port}"
|
||||
use_https = False
|
||||
|
||||
client_config = {
|
||||
"hosts": [hosts],
|
||||
"basic_auth": (config.username, config.password),
|
||||
"request_timeout": config.request_timeout,
|
||||
"retry_on_timeout": config.retry_on_timeout,
|
||||
"max_retries": config.max_retries,
|
||||
}
|
||||
|
||||
# Only add SSL settings if using HTTPS
|
||||
if use_https:
|
||||
client_config["verify_certs"] = config.verify_certs
|
||||
if config.ca_certs:
|
||||
client_config["ca_certs"] = config.ca_certs
|
||||
|
||||
client = Elasticsearch(**client_config)
|
||||
|
||||
# Test connection
|
||||
if not client.ping():
|
||||
raise ConnectionError("Failed to connect to Elasticsearch")
|
||||
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ConnectionError(f"Vector database connection error: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
|
||||
|
||||
return client
|
||||
|
||||
@ -209,7 +285,11 @@ class ElasticSearchVector(BaseVector):
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
self._client.indices.create(index=self._collection_name, mappings=mappings)
|
||||
logger.info("Created index %s with dimension %s", self._collection_name, dim)
|
||||
else:
|
||||
logger.info("Collection %s already exists.", self._collection_name)
|
||||
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
@ -225,13 +305,51 @@ class ElasticSearchVectorFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
|
||||
# Check if ELASTICSEARCH_USE_CLOUD is explicitly set to false (boolean)
|
||||
use_cloud_env = config.get("ELASTICSEARCH_USE_CLOUD", False)
|
||||
|
||||
if use_cloud_env is False:
|
||||
# Use regular Elasticsearch with config values
|
||||
config_dict = {
|
||||
"use_cloud": False,
|
||||
"host": config.get("ELASTICSEARCH_HOST", "elasticsearch"),
|
||||
"port": config.get("ELASTICSEARCH_PORT", 9200),
|
||||
"username": config.get("ELASTICSEARCH_USERNAME", "elastic"),
|
||||
"password": config.get("ELASTICSEARCH_PASSWORD", "elastic"),
|
||||
}
|
||||
else:
|
||||
# Check for cloud configuration
|
||||
cloud_url = config.get("ELASTICSEARCH_CLOUD_URL")
|
||||
if cloud_url:
|
||||
config_dict = {
|
||||
"use_cloud": True,
|
||||
"cloud_url": cloud_url,
|
||||
"api_key": config.get("ELASTICSEARCH_API_KEY"),
|
||||
}
|
||||
else:
|
||||
# Fallback to regular Elasticsearch
|
||||
config_dict = {
|
||||
"use_cloud": False,
|
||||
"host": config.get("ELASTICSEARCH_HOST", "localhost"),
|
||||
"port": config.get("ELASTICSEARCH_PORT", 9200),
|
||||
"username": config.get("ELASTICSEARCH_USERNAME", "elastic"),
|
||||
"password": config.get("ELASTICSEARCH_PASSWORD", ""),
|
||||
}
|
||||
|
||||
# Common configuration
|
||||
config_dict.update(
|
||||
{
|
||||
"ca_certs": str(config.get("ELASTICSEARCH_CA_CERTS")) if config.get("ELASTICSEARCH_CA_CERTS") else None,
|
||||
"verify_certs": bool(config.get("ELASTICSEARCH_VERIFY_CERTS", False)),
|
||||
"request_timeout": int(config.get("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)),
|
||||
"retry_on_timeout": bool(config.get("ELASTICSEARCH_RETRY_ON_TIMEOUT", True)),
|
||||
"max_retries": int(config.get("ELASTICSEARCH_MAX_RETRIES", 10000)),
|
||||
}
|
||||
)
|
||||
|
||||
return ElasticSearchVector(
|
||||
index_name=collection_name,
|
||||
config=ElasticSearchConfig(
|
||||
host=config.get("ELASTICSEARCH_HOST", "localhost"),
|
||||
port=config.get("ELASTICSEARCH_PORT", 9200),
|
||||
username=config.get("ELASTICSEARCH_USERNAME", ""),
|
||||
password=config.get("ELASTICSEARCH_PASSWORD", ""),
|
||||
),
|
||||
config=ElasticSearchConfig(**config_dict),
|
||||
attributes=[],
|
||||
)
|
||||
|
||||
@ -13,6 +13,8 @@ SupportedComparisonOperator = Literal[
|
||||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import operator
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import requests
|
||||
@ -132,13 +133,15 @@ class NotionExtractor(BaseExtractor):
|
||||
data[property_name] = value
|
||||
row_dict = {k: v for k, v in data.items() if v}
|
||||
row_content = ""
|
||||
for key, value in row_dict.items():
|
||||
for key, value in sorted(row_dict.items(), key=operator.itemgetter(0)):
|
||||
if isinstance(value, dict):
|
||||
value_dict = {k: v for k, v in value.items() if v}
|
||||
value_content = "".join(f"{k}:{v} " for k, v in value_dict.items())
|
||||
row_content = row_content + f"{key}:{value_content}\n"
|
||||
else:
|
||||
row_content = row_content + f"{key}:{value}\n"
|
||||
if "url" in result:
|
||||
row_content = row_content + f"Row Page URL:{result.get('url', '')}\n"
|
||||
database_content.append(row_content)
|
||||
|
||||
has_more = response_data.get("has_more", False)
|
||||
|
||||
Reference in New Issue
Block a user