mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
Compare commits
20 Commits
build/rele
...
fix/1.13.2
| Author | SHA1 | Date | |
|---|---|---|---|
| 138d497418 | |||
| 5760914fb3 | |||
| 8f79989172 | |||
| 9d7ea953ea | |||
| 5b5b21502b | |||
| 44c356258f | |||
| 44fb3cd2af | |||
| 2bf6728951 | |||
| fcfa11a71a | |||
| 1730f900c1 | |||
| 12178e7aec | |||
| afe23a029b | |||
| c8560bacb3 | |||
| 0f1b8bf5f9 | |||
| 652211ad96 | |||
| c049249bc1 | |||
| 138083dfc8 | |||
| d1961c261e | |||
| a717519822 | |||
| a592c53573 |
@ -180,7 +180,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
COOKIE_DOMAIN=
|
||||
|
||||
# Vector database configuration
|
||||
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `hologres`.
|
||||
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `hologres`.
|
||||
VECTOR_STORE=weaviate
|
||||
# Prefix used to create collection name in vector database
|
||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
@ -321,20 +321,6 @@ CHROMA_DATABASE=default_database
|
||||
CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
|
||||
CHROMA_AUTH_CREDENTIALS=difyai123456
|
||||
|
||||
# AnalyticDB configuration
|
||||
ANALYTICDB_KEY_ID=your-ak
|
||||
ANALYTICDB_KEY_SECRET=your-sk
|
||||
ANALYTICDB_REGION_ID=cn-hangzhou
|
||||
ANALYTICDB_INSTANCE_ID=gp-ab123456
|
||||
ANALYTICDB_ACCOUNT=testaccount
|
||||
ANALYTICDB_PASSWORD=testpassword
|
||||
ANALYTICDB_NAMESPACE=dify
|
||||
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
|
||||
ANALYTICDB_HOST=gp-test.aliyuncs.com
|
||||
ANALYTICDB_PORT=5432
|
||||
ANALYTICDB_MIN_CONNECTION=1
|
||||
ANALYTICDB_MAX_CONNECTION=5
|
||||
|
||||
# OpenSearch configuration
|
||||
OPENSEARCH_HOST=127.0.0.1
|
||||
OPENSEARCH_PORT=9200
|
||||
|
||||
@ -78,7 +78,7 @@ class UserProfile(TypedDict):
|
||||
nickname: NotRequired[str]
|
||||
```
|
||||
|
||||
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
|
||||
- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:
|
||||
|
||||
```python
|
||||
from datetime import datetime
|
||||
|
||||
@ -608,7 +608,7 @@ def migrate_oss(
|
||||
click.style(
|
||||
"Target STORAGE_TYPE must be a cloud OSS (not 'local' or 'opendal').\n"
|
||||
"Please set STORAGE_TYPE to one of: s3, aliyun-oss, azure-blob, google-storage, tencent-cos, \n"
|
||||
"volcengine-tos, supabase, oci-storage, huawei-obs, baidu-obs, clickzetta-volume.",
|
||||
"volcengine-tos, supabase, oci-storage, huawei-obs, baidu-obs.",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
|
||||
@ -155,11 +155,9 @@ def migrate_knowledge_vector_database():
|
||||
VectorType.ORACLE,
|
||||
VectorType.ELASTICSEARCH,
|
||||
VectorType.OPENGAUSS,
|
||||
VectorType.TABLESTORE,
|
||||
VectorType.MATRIXONE,
|
||||
}
|
||||
lower_collection_vector_types = {
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.HOLOGRES,
|
||||
VectorType.CHROMA,
|
||||
VectorType.MYSCALE,
|
||||
|
||||
@ -11,7 +11,6 @@ from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
|
||||
from .storage.amazon_s3_storage_config import S3StorageConfig
|
||||
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
|
||||
from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
|
||||
from .storage.clickzetta_volume_storage_config import ClickZettaVolumeStorageConfig
|
||||
from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
|
||||
from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
|
||||
from .storage.oci_storage_config import OCIStorageConfig
|
||||
@ -20,10 +19,8 @@ from .storage.supabase_storage_config import SupabaseStorageConfig
|
||||
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
|
||||
from .vdb.analyticdb_config import AnalyticdbConfig
|
||||
from .vdb.baidu_vector_config import BaiduVectorDBConfig
|
||||
from .vdb.chroma_config import ChromaConfig
|
||||
from .vdb.clickzetta_config import ClickzettaConfig
|
||||
from .vdb.couchbase_config import CouchbaseConfig
|
||||
from .vdb.elasticsearch_config import ElasticsearchConfig
|
||||
from .vdb.hologres_config import HologresConfig
|
||||
@ -41,7 +38,6 @@ from .vdb.pgvector_config import PGVectorConfig
|
||||
from .vdb.pgvectors_config import PGVectoRSConfig
|
||||
from .vdb.qdrant_config import QdrantConfig
|
||||
from .vdb.relyt_config import RelytConfig
|
||||
from .vdb.tablestore_config import TableStoreConfig
|
||||
from .vdb.tencent_vector_config import TencentVectorDBConfig
|
||||
from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
|
||||
from .vdb.tidb_vector_config import TiDBVectorConfig
|
||||
@ -58,7 +54,6 @@ class StorageConfig(BaseSettings):
|
||||
"aliyun-oss",
|
||||
"azure-blob",
|
||||
"baidu-obs",
|
||||
"clickzetta-volume",
|
||||
"google-storage",
|
||||
"huawei-obs",
|
||||
"oci-storage",
|
||||
@ -69,7 +64,7 @@ class StorageConfig(BaseSettings):
|
||||
] = Field(
|
||||
description="Type of storage to use."
|
||||
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', "
|
||||
"'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
|
||||
"'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
|
||||
"'volcengine-tos', 'supabase'. Default is 'opendal'.",
|
||||
default="opendal",
|
||||
)
|
||||
@ -334,7 +329,6 @@ class MiddlewareConfig(
|
||||
AliyunOSSStorageConfig,
|
||||
AzureBlobStorageConfig,
|
||||
BaiduOBSStorageConfig,
|
||||
ClickZettaVolumeStorageConfig,
|
||||
GoogleCloudStorageConfig,
|
||||
HuaweiCloudOBSStorageConfig,
|
||||
OCIStorageConfig,
|
||||
@ -345,9 +339,7 @@ class MiddlewareConfig(
|
||||
VolcengineTOSStorageConfig,
|
||||
# configs of vdb and vdb providers
|
||||
VectorStoreConfig,
|
||||
AnalyticdbConfig,
|
||||
ChromaConfig,
|
||||
ClickzettaConfig,
|
||||
HologresConfig,
|
||||
HuaweiCloudConfig,
|
||||
IrisVectorConfig,
|
||||
@ -374,7 +366,6 @@ class MiddlewareConfig(
|
||||
OceanBaseVectorConfig,
|
||||
BaiduVectorDBConfig,
|
||||
OpenGaussConfig,
|
||||
TableStoreConfig,
|
||||
DatasetQueueMonitorConfig,
|
||||
MatrixoneConfig,
|
||||
):
|
||||
|
||||
12
api/configs/middleware/cache/redis_config.py
vendored
12
api/configs/middleware/cache/redis_config.py
vendored
@ -1,4 +1,4 @@
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@ -116,3 +116,13 @@ class RedisConfig(BaseSettings):
|
||||
description="Maximum connections in the Redis connection pool (unset for library default)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
|
||||
@classmethod
|
||||
def _empty_string_to_none_for_max_conns(cls, v):
|
||||
"""Allow empty string in env/.env to mean 'unset' (None)."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str) and v.strip() == "":
|
||||
return None
|
||||
return v
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Literal, Protocol
|
||||
from typing import Literal, Protocol, cast
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
@ -12,16 +12,13 @@ class RedisConfigDefaults(Protocol):
|
||||
REDIS_PASSWORD: str | None
|
||||
REDIS_DB: int
|
||||
REDIS_USE_SSL: bool
|
||||
REDIS_USE_SENTINEL: bool | None
|
||||
REDIS_USE_CLUSTERS: bool
|
||||
|
||||
|
||||
class RedisConfigDefaultsMixin:
|
||||
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
|
||||
return self
|
||||
def _redis_defaults(config: object) -> RedisConfigDefaults:
|
||||
return cast(RedisConfigDefaults, config)
|
||||
|
||||
|
||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
class RedisPubSubConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for event transport between API and workers.
|
||||
|
||||
@ -74,7 +71,7 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = self._redis_defaults()
|
||||
defaults = _redis_defaults(self)
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
|
||||
|
||||
@ -91,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
if userinfo:
|
||||
userinfo = f"{userinfo}@"
|
||||
|
||||
host = defaults.REDIS_HOST
|
||||
port = defaults.REDIS_PORT
|
||||
db = defaults.REDIS_DB
|
||||
|
||||
netloc = f"{userinfo}{host}:{port}"
|
||||
netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
|
||||
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
|
||||
|
||||
@property
|
||||
|
||||
@ -1,63 +0,0 @@
|
||||
"""ClickZetta Volume Storage Configuration"""
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ClickZettaVolumeStorageConfig(BaseSettings):
|
||||
"""Configuration for ClickZetta Volume storage."""
|
||||
|
||||
CLICKZETTA_VOLUME_USERNAME: str | None = Field(
|
||||
description="Username for ClickZetta Volume authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_PASSWORD: str | None = Field(
|
||||
description="Password for ClickZetta Volume authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_INSTANCE: str | None = Field(
|
||||
description="ClickZetta instance identifier",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_SERVICE: str = Field(
|
||||
description="ClickZetta service endpoint",
|
||||
default="api.clickzetta.com",
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_WORKSPACE: str = Field(
|
||||
description="ClickZetta workspace name",
|
||||
default="quick_start",
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_VCLUSTER: str = Field(
|
||||
description="ClickZetta virtual cluster name",
|
||||
default="default_ap",
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_SCHEMA: str = Field(
|
||||
description="ClickZetta schema name",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_TYPE: str = Field(
|
||||
description="ClickZetta volume type (table|user|external)",
|
||||
default="user",
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_NAME: str | None = Field(
|
||||
description="ClickZetta volume name for external volumes",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field(
|
||||
description="Prefix for ClickZetta volume table names",
|
||||
default="dataset_",
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field(
|
||||
description="Directory prefix for User Volume to organize Dify files",
|
||||
default="dify_km",
|
||||
)
|
||||
@ -1,49 +0,0 @@
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AnalyticdbConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
|
||||
Refer to the following documentation for details on obtaining credentials:
|
||||
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
|
||||
"""
|
||||
|
||||
ANALYTICDB_KEY_ID: str | None = Field(
|
||||
default=None, description="The Access Key ID provided by Alibaba Cloud for API authentication."
|
||||
)
|
||||
ANALYTICDB_KEY_SECRET: str | None = Field(
|
||||
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure API access."
|
||||
)
|
||||
ANALYTICDB_REGION_ID: str | None = Field(
|
||||
default=None,
|
||||
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou', 'ap-southeast-1').",
|
||||
)
|
||||
ANALYTICDB_INSTANCE_ID: str | None = Field(
|
||||
default=None,
|
||||
description="The unique identifier of the AnalyticDB instance you want to connect to.",
|
||||
)
|
||||
ANALYTICDB_ACCOUNT: str | None = Field(
|
||||
default=None,
|
||||
description="The account name used to log in to the AnalyticDB instance"
|
||||
" (usually the initial account created with the instance).",
|
||||
)
|
||||
ANALYTICDB_PASSWORD: str | None = Field(
|
||||
default=None, description="The password associated with the AnalyticDB account for database authentication."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE: str | None = Field(
|
||||
default=None, description="The namespace within AnalyticDB for schema isolation (if using namespace feature)."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE_PASSWORD: str | None = Field(
|
||||
default=None,
|
||||
description="The password for accessing the specified namespace within the AnalyticDB instance"
|
||||
" (if namespace feature is enabled).",
|
||||
)
|
||||
ANALYTICDB_HOST: str | None = Field(
|
||||
default=None, description="The host of the AnalyticDB instance you want to connect to."
|
||||
)
|
||||
ANALYTICDB_PORT: PositiveInt = Field(
|
||||
default=5432, description="The port of the AnalyticDB instance you want to connect to."
|
||||
)
|
||||
ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.")
|
||||
ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.")
|
||||
@ -1,68 +0,0 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ClickzettaConfig(BaseSettings):
|
||||
"""
|
||||
Clickzetta Lakehouse vector database configuration
|
||||
"""
|
||||
|
||||
CLICKZETTA_USERNAME: str | None = Field(
|
||||
description="Username for authenticating with Clickzetta Lakehouse",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_PASSWORD: str | None = Field(
|
||||
description="Password for authenticating with Clickzetta Lakehouse",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_INSTANCE: str | None = Field(
|
||||
description="Clickzetta Lakehouse instance ID",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_SERVICE: str | None = Field(
|
||||
description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')",
|
||||
default="api.clickzetta.com",
|
||||
)
|
||||
|
||||
CLICKZETTA_WORKSPACE: str | None = Field(
|
||||
description="Clickzetta workspace name",
|
||||
default="default",
|
||||
)
|
||||
|
||||
CLICKZETTA_VCLUSTER: str | None = Field(
|
||||
description="Clickzetta virtual cluster name",
|
||||
default="default_ap",
|
||||
)
|
||||
|
||||
CLICKZETTA_SCHEMA: str | None = Field(
|
||||
description="Database schema name in Clickzetta",
|
||||
default="public",
|
||||
)
|
||||
|
||||
CLICKZETTA_BATCH_SIZE: int | None = Field(
|
||||
description="Batch size for bulk insert operations",
|
||||
default=100,
|
||||
)
|
||||
|
||||
CLICKZETTA_ENABLE_INVERTED_INDEX: bool | None = Field(
|
||||
description="Enable inverted index for full-text search capabilities",
|
||||
default=True,
|
||||
)
|
||||
|
||||
CLICKZETTA_ANALYZER_TYPE: str | None = Field(
|
||||
description="Analyzer type for full-text search: keyword, english, chinese, unicode",
|
||||
default="chinese",
|
||||
)
|
||||
|
||||
CLICKZETTA_ANALYZER_MODE: str | None = Field(
|
||||
description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)",
|
||||
default="smart",
|
||||
)
|
||||
|
||||
CLICKZETTA_VECTOR_DISTANCE_FUNCTION: str | None = Field(
|
||||
description="Distance function for vector similarity: l2_distance or cosine_distance",
|
||||
default="cosine_distance",
|
||||
)
|
||||
@ -1,33 +0,0 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class TableStoreConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for TableStore.
|
||||
"""
|
||||
|
||||
TABLESTORE_ENDPOINT: str | None = Field(
|
||||
description="Endpoint address of the TableStore server (e.g. 'https://instance-name.cn-hangzhou.ots.aliyuncs.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TABLESTORE_INSTANCE_NAME: str | None = Field(
|
||||
description="Instance name to access TableStore server (eg. 'instance-name')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TABLESTORE_ACCESS_KEY_ID: str | None = Field(
|
||||
description="AccessKey id for the instance name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TABLESTORE_ACCESS_KEY_SECRET: str | None = Field(
|
||||
description="AccessKey secret for the instance name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: bool = Field(
|
||||
description="Whether to normalize full-text search scores to [0, 1]",
|
||||
default=False,
|
||||
)
|
||||
@ -242,7 +242,6 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
||||
VectorType.QDRANT,
|
||||
VectorType.WEAVIATE,
|
||||
VectorType.OPENSEARCH,
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.MYSCALE,
|
||||
VectorType.ORACLE,
|
||||
VectorType.ELASTICSEARCH,
|
||||
@ -255,11 +254,9 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
||||
VectorType.OPENGAUSS,
|
||||
VectorType.OCEANBASE,
|
||||
VectorType.SEEKDB,
|
||||
VectorType.TABLESTORE,
|
||||
VectorType.HUAWEI_CLOUD,
|
||||
VectorType.TENCENT,
|
||||
VectorType.MATRIXONE,
|
||||
VectorType.CLICKZETTA,
|
||||
VectorType.BAIDU,
|
||||
VectorType.ALIBABACLOUD_MYSQL,
|
||||
VectorType.IRIS,
|
||||
|
||||
@ -297,6 +297,7 @@ class DatasetDocumentListApi(Resource):
|
||||
if sort == "hit_count":
|
||||
sub_query = (
|
||||
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
.where(DocumentSegment.dataset_id == str(dataset_id))
|
||||
.group_by(DocumentSegment.document_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
@ -473,9 +473,21 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||
else:
|
||||
# some historical data may have a provider record but not be set as valid
|
||||
provider_record.is_valid = True
|
||||
|
||||
if provider_record.credential_id is None:
|
||||
provider_record.credential_id = new_record.id
|
||||
provider_record.updated_at = naive_utc_now()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
|
||||
@ -11,7 +11,6 @@ class TracingProviderEnum(StrEnum):
|
||||
LANGFUSE = "langfuse"
|
||||
LANGSMITH = "langsmith"
|
||||
OPIK = "opik"
|
||||
WEAVE = "weave"
|
||||
ALIYUN = "aliyun"
|
||||
MLFLOW = "mlflow"
|
||||
DATABRICKS = "databricks"
|
||||
@ -145,31 +144,6 @@ class OpikConfig(BaseTracingConfig):
|
||||
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
|
||||
|
||||
|
||||
class WeaveConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Weave tracing config.
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
entity: str | None = None
|
||||
project: str
|
||||
endpoint: str = "https://trace.wandb.ai"
|
||||
host: str | None = None
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# Weave only allows HTTPS for endpoint
|
||||
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def host_validator(cls, v, info: ValidationInfo):
|
||||
if v is not None and v.strip() != "":
|
||||
return validate_url(v, v, allowed_schemes=("https", "http"))
|
||||
return v
|
||||
|
||||
|
||||
class AliyunConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Aliyun tracing config.
|
||||
|
||||
@ -76,16 +76,6 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||
"trace_instance": OpikDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.WEAVE:
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
from core.ops.weave_trace.weave_trace import WeaveDataTrace
|
||||
|
||||
return {
|
||||
"config_class": WeaveConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "entity", "endpoint", "host"],
|
||||
"trace_instance": WeaveDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.ARIZE:
|
||||
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
|
||||
from core.ops.entities.config_entity import ArizeConfig
|
||||
|
||||
@ -1,98 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.ops.utils import replace_text_with_content
|
||||
|
||||
|
||||
class WeaveTokenUsage(BaseModel):
|
||||
input_tokens: int | None = None
|
||||
output_tokens: int | None = None
|
||||
total_tokens: int | None = None
|
||||
|
||||
|
||||
class WeaveMultiModel(BaseModel):
|
||||
file_list: list[str] | None = Field(None, description="List of files")
|
||||
|
||||
|
||||
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||
id: str = Field(..., description="ID of the trace")
|
||||
op: str = Field(..., description="Name of the operation")
|
||||
inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the trace")
|
||||
outputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the trace")
|
||||
attributes: Union[str, dict[str, Any], list, None] | None = Field(
|
||||
None, description="Metadata and attributes associated with trace"
|
||||
)
|
||||
exception: str | None = Field(None, description="Exception message of the trace")
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
values = info.data
|
||||
if v == {} or v is None:
|
||||
return v
|
||||
usage_metadata = {
|
||||
"input_tokens": values.get("input_tokens", 0),
|
||||
"output_tokens": values.get("output_tokens", 0),
|
||||
"total_tokens": values.get("total_tokens", 0),
|
||||
}
|
||||
file_list = values.get("file_list", [])
|
||||
if isinstance(v, str):
|
||||
if field_name == "inputs":
|
||||
return {
|
||||
"messages": {
|
||||
"role": "user",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif isinstance(v, list):
|
||||
data = {}
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
# rename text to content
|
||||
v = replace_text_with_content(data=v)
|
||||
if field_name == "inputs":
|
||||
data = {
|
||||
"messages": [
|
||||
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v
|
||||
]
|
||||
if isinstance(v, list)
|
||||
else v,
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
data = {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai" if field_name == "outputs" else "user",
|
||||
"content": str(v),
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
if isinstance(v, dict):
|
||||
v["usage_metadata"] = usage_metadata
|
||||
v["file_list"] = file_list
|
||||
return v
|
||||
return v
|
||||
@ -1,523 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
import wandb
|
||||
import weave
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from weave.trace_server.trace_server_interface import (
|
||||
CallEndReq,
|
||||
CallStartReq,
|
||||
EndedCallSchemaForInsert,
|
||||
StartedCallSchemaForInsert,
|
||||
SummaryInsertMap,
|
||||
TraceStatus,
|
||||
)
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeaveDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
weave_config: WeaveConfig,
|
||||
):
|
||||
super().__init__(weave_config)
|
||||
self.weave_api_key = weave_config.api_key
|
||||
self.project_name = weave_config.project
|
||||
self.entity = weave_config.entity
|
||||
self.host = weave_config.host
|
||||
|
||||
# Login with API key first, including host if provided
|
||||
if self.host:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
|
||||
else:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
|
||||
if not login_status:
|
||||
logger.error("Failed to login to Weights & Biases with the provided API key")
|
||||
raise ValueError("Weave login failed")
|
||||
|
||||
# Then initialize weave client
|
||||
self.weave_client = weave.init(
|
||||
project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
|
||||
)
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
self.calls: dict[str, Any] = {}
|
||||
self.project_id = f"{self.weave_client.entity}/{self.weave_client.project}"
|
||||
|
||||
def get_project_url(
|
||||
self,
|
||||
):
|
||||
try:
|
||||
project_identifier = f"{self.entity}/{self.project_name}" if self.entity else self.project_name
|
||||
project_url = f"https://wandb.ai/{project_identifier}"
|
||||
return project_url
|
||||
except Exception as e:
|
||||
logger.debug("Weave get run url failed: %s", str(e))
|
||||
raise ValueError(f"Weave get run url failed: {str(e)}")
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.debug("Trace info: %s", trace_info)
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
if trace_info.start_time is None:
|
||||
trace_info.start_time = datetime.now()
|
||||
|
||||
if trace_info.message_id:
|
||||
message_attributes = trace_info.metadata
|
||||
message_attributes["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||
|
||||
message_attributes["message_id"] = trace_info.message_id
|
||||
message_attributes["workflow_run_id"] = trace_info.workflow_run_id
|
||||
message_attributes["trace_id"] = trace_id
|
||||
message_attributes["start_time"] = trace_info.start_time
|
||||
message_attributes["end_time"] = trace_info.end_time
|
||||
message_attributes["tags"] = ["message", "workflow"]
|
||||
|
||||
message_run = WeaveTraceModel(
|
||||
id=trace_info.message_id,
|
||||
op=str(TraceTaskName.MESSAGE_TRACE),
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
total_tokens=trace_info.total_tokens,
|
||||
attributes=message_attributes,
|
||||
exception=trace_info.error,
|
||||
file_list=[],
|
||||
)
|
||||
self.start_call(message_run, parent_run_id=trace_info.workflow_run_id)
|
||||
self.finish_call(message_run)
|
||||
|
||||
workflow_attributes = trace_info.metadata
|
||||
workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id
|
||||
workflow_attributes["trace_id"] = trace_id
|
||||
workflow_attributes["start_time"] = trace_info.start_time
|
||||
workflow_attributes["end_time"] = trace_info.end_time
|
||||
workflow_attributes["tags"] = ["dify_workflow"]
|
||||
|
||||
workflow_run = WeaveTraceModel(
|
||||
file_list=trace_info.file_list,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
id=trace_info.workflow_run_id,
|
||||
op=str(TraceTaskName.WORKFLOW_TRACE),
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
attributes=workflow_attributes,
|
||||
exception=trace_info.error,
|
||||
)
|
||||
|
||||
self.start_call(workflow_run, parent_run_id=trace_info.message_id)
|
||||
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||
workflow_run_id=trace_info.workflow_run_id
|
||||
)
|
||||
|
||||
# rearrange workflow_node_executions by starting time
|
||||
workflow_node_executions = sorted(workflow_node_executions, key=lambda x: x.created_at)
|
||||
|
||||
for node_execution in workflow_node_executions:
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
attributes = {str(k): v for k, v in execution_metadata.items()}
|
||||
attributes.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
"node_execution_id": node_execution_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id,
|
||||
"app_name": node_name,
|
||||
"node_type": node_type,
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
|
||||
process_data = node_execution.process_data or {}
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
attributes.update(
|
||||
{
|
||||
"ls_provider": process_data.get("model_provider", ""),
|
||||
"ls_model_name": process_data.get("model_name", ""),
|
||||
}
|
||||
)
|
||||
attributes["tags"] = ["node_execution"]
|
||||
attributes["start_time"] = created_at
|
||||
attributes["end_time"] = finished_at
|
||||
attributes["elapsed_time"] = elapsed_time
|
||||
attributes["workflow_run_id"] = trace_info.workflow_run_id
|
||||
attributes["trace_id"] = trace_id
|
||||
node_run = WeaveTraceModel(
|
||||
total_tokens=node_total_tokens,
|
||||
op=node_type,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
file_list=trace_info.file_list,
|
||||
attributes=attributes,
|
||||
id=node_execution_id,
|
||||
exception=None,
|
||||
)
|
||||
|
||||
self.start_call(node_run, parent_run_id=trace_info.workflow_run_id)
|
||||
self.finish_call(node_run)
|
||||
|
||||
self.finish_call(workflow_run)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
# get message file data
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: MessageFile | None = trace_info.message_file_data
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
attributes = trace_info.metadata
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
message_id = message_data.id
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
attributes["user_id"] = user_id
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
end_user_id = end_user_data.session_id
|
||||
attributes["end_user_id"] = end_user_id
|
||||
|
||||
attributes["message_id"] = message_id
|
||||
attributes["start_time"] = trace_info.start_time
|
||||
attributes["end_time"] = trace_info.end_time
|
||||
attributes["tags"] = ["message", str(trace_info.conversation_mode)]
|
||||
|
||||
trace_id = trace_info.trace_id or message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
message_run = WeaveTraceModel(
|
||||
id=trace_id,
|
||||
op=str(TraceTaskName.MESSAGE_TRACE),
|
||||
input_tokens=trace_info.message_tokens,
|
||||
output_tokens=trace_info.answer_tokens,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
exception=trace_info.error,
|
||||
file_list=file_list,
|
||||
attributes=attributes,
|
||||
)
|
||||
self.start_call(message_run)
|
||||
|
||||
# create llm run parented to message run
|
||||
llm_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
input_tokens=trace_info.message_tokens,
|
||||
output_tokens=trace_info.answer_tokens,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
op="llm",
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
attributes=attributes,
|
||||
file_list=[],
|
||||
exception=None,
|
||||
)
|
||||
self.start_call(
|
||||
llm_run,
|
||||
parent_run_id=trace_id,
|
||||
)
|
||||
self.finish_call(llm_run)
|
||||
self.finish_call(message_run)
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
attributes = trace_info.metadata
|
||||
attributes["tags"] = ["moderation"]
|
||||
attributes["message_id"] = trace_info.message_id
|
||||
attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at
|
||||
attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
moderation_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.MODERATION_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs={
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
attributes=attributes,
|
||||
exception=getattr(trace_info, "error", None),
|
||||
file_list=[],
|
||||
)
|
||||
self.start_call(moderation_run, parent_run_id=trace_id)
|
||||
self.finish_call(moderation_run)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
attributes = trace_info.metadata
|
||||
attributes["message_id"] = trace_info.message_id
|
||||
attributes["tags"] = ["suggested_question"]
|
||||
attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
|
||||
attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
|
||||
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
suggested_question_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.suggested_question,
|
||||
attributes=attributes,
|
||||
exception=trace_info.error,
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.start_call(suggested_question_run, parent_run_id=trace_id)
|
||||
self.finish_call(suggested_question_run)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
attributes = trace_info.metadata
|
||||
attributes["message_id"] = trace_info.message_id
|
||||
attributes["tags"] = ["dataset_retrieval"]
|
||||
attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
|
||||
attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
|
||||
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
dataset_retrieval_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs={"documents": trace_info.documents},
|
||||
attributes=attributes,
|
||||
exception=getattr(trace_info, "error", None),
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.start_call(dataset_retrieval_run, parent_run_id=trace_id)
|
||||
self.finish_call(dataset_retrieval_run)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
attributes = trace_info.metadata
|
||||
attributes["tags"] = ["tool", trace_info.tool_name]
|
||||
attributes["start_time"] = trace_info.start_time
|
||||
attributes["end_time"] = trace_info.end_time
|
||||
|
||||
message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
|
||||
message_id = message_id or None
|
||||
trace_id = trace_info.trace_id or message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
tool_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=trace_info.tool_name,
|
||||
inputs=trace_info.tool_inputs,
|
||||
outputs=trace_info.tool_outputs,
|
||||
file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [],
|
||||
attributes=attributes,
|
||||
exception=trace_info.error,
|
||||
)
|
||||
self.start_call(tool_run, parent_run_id=trace_id)
|
||||
self.finish_call(tool_run)
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
attributes = trace_info.metadata
|
||||
attributes["tags"] = ["generate_name"]
|
||||
attributes["start_time"] = trace_info.start_time
|
||||
attributes["end_time"] = trace_info.end_time
|
||||
|
||||
name_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.GENERATE_NAME_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
attributes=attributes,
|
||||
exception=getattr(trace_info, "error", None),
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.start_call(name_run)
|
||||
self.finish_call(name_run)
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
if self.host:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
|
||||
else:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
|
||||
if not login_status:
|
||||
raise ValueError("Weave login failed")
|
||||
else:
|
||||
logger.info("Weave login successful")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug("Weave API check failed: %s", str(e))
|
||||
raise ValueError(f"Weave API check failed: {str(e)}")
|
||||
|
||||
def _normalize_time(self, dt: datetime | None) -> datetime:
|
||||
if dt is None:
|
||||
return datetime.now(UTC)
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=UTC)
|
||||
return dt
|
||||
|
||||
def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
|
||||
inputs = run_data.inputs
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
elif not isinstance(inputs, dict):
|
||||
inputs = {"inputs": str(inputs)}
|
||||
|
||||
attributes = run_data.attributes
|
||||
if attributes is None:
|
||||
attributes = {}
|
||||
elif not isinstance(attributes, dict):
|
||||
attributes = {"attributes": str(attributes)}
|
||||
|
||||
start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
|
||||
started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
|
||||
trace_id = attributes.get("trace_id") if isinstance(attributes, dict) else None
|
||||
if trace_id is None:
|
||||
trace_id = run_data.id
|
||||
|
||||
call_start_req = CallStartReq(
|
||||
start=StartedCallSchemaForInsert(
|
||||
project_id=self.project_id,
|
||||
id=run_data.id,
|
||||
op_name=str(run_data.op),
|
||||
trace_id=trace_id,
|
||||
parent_id=parent_run_id,
|
||||
started_at=started_at,
|
||||
attributes=attributes,
|
||||
inputs=inputs,
|
||||
wb_user_id=None,
|
||||
)
|
||||
)
|
||||
self.weave_client.server.call_start(call_start_req)
|
||||
self.calls[run_data.id] = {"trace_id": trace_id, "parent_id": parent_run_id}
|
||||
|
||||
def finish_call(self, run_data: WeaveTraceModel):
|
||||
call_meta = self.calls.get(run_data.id)
|
||||
if not call_meta:
|
||||
raise ValueError(f"Call with id {run_data.id} not found")
|
||||
|
||||
attributes = run_data.attributes
|
||||
if attributes is None:
|
||||
attributes = {}
|
||||
elif not isinstance(attributes, dict):
|
||||
attributes = {"attributes": str(attributes)}
|
||||
|
||||
start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
|
||||
end_time = attributes.get("end_time") if isinstance(attributes, dict) else None
|
||||
started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
|
||||
ended_at = self._normalize_time(end_time if isinstance(end_time, datetime) else None)
|
||||
elapsed_ms = int((ended_at - started_at).total_seconds() * 1000)
|
||||
if elapsed_ms < 0:
|
||||
elapsed_ms = 0
|
||||
|
||||
status_counts = {
|
||||
TraceStatus.SUCCESS: 0,
|
||||
TraceStatus.ERROR: 0,
|
||||
}
|
||||
if run_data.exception:
|
||||
status_counts[TraceStatus.ERROR] = 1
|
||||
else:
|
||||
status_counts[TraceStatus.SUCCESS] = 1
|
||||
|
||||
summary: dict[str, Any] = {
|
||||
"status_counts": status_counts,
|
||||
"weave": {"latency_ms": elapsed_ms},
|
||||
}
|
||||
|
||||
exception_str = str(run_data.exception) if run_data.exception else None
|
||||
|
||||
call_end_req = CallEndReq(
|
||||
end=EndedCallSchemaForInsert(
|
||||
project_id=self.project_id,
|
||||
id=run_data.id,
|
||||
ended_at=ended_at,
|
||||
exception=exception_str,
|
||||
output=run_data.outputs,
|
||||
summary=cast(SummaryInsertMap, summary),
|
||||
)
|
||||
)
|
||||
self.weave_client.server.call_end(call_end_req)
|
||||
@ -196,6 +196,8 @@ class ProviderManager:
|
||||
|
||||
if preferred_provider_type_record:
|
||||
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
|
||||
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
|
||||
preferred_provider_type = ProviderType.SYSTEM
|
||||
elif custom_configuration.provider or custom_configuration.models:
|
||||
preferred_provider_type = ProviderType.CUSTOM
|
||||
elif system_configuration.enabled:
|
||||
|
||||
@ -1,104 +0,0 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
|
||||
AnalyticdbVectorOpenAPI,
|
||||
AnalyticdbVectorOpenAPIConfig,
|
||||
)
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class AnalyticdbVector(BaseVector):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
api_config: AnalyticdbVectorOpenAPIConfig | None,
|
||||
sql_config: AnalyticdbVectorBySqlConfig | None,
|
||||
):
|
||||
super().__init__(collection_name)
|
||||
if api_config is not None:
|
||||
self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI(
|
||||
collection_name, api_config
|
||||
)
|
||||
else:
|
||||
if sql_config is None:
|
||||
raise ValueError("Either api_config or sql_config must be provided")
|
||||
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.ANALYTICDB
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self.analyticdb_vector._create_collection_if_not_exists(dimension)
|
||||
self.analyticdb_vector.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self.analyticdb_vector.add_texts(documents, embeddings)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return self.analyticdb_vector.text_exists(id)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
self.analyticdb_vector.delete_by_ids(ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
self.analyticdb_vector.delete_by_metadata_field(key, value)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
|
||||
|
||||
def delete(self):
|
||||
self.analyticdb_vector.delete()
|
||||
|
||||
|
||||
class AnalyticdbVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
|
||||
|
||||
if dify_config.ANALYTICDB_HOST is None:
|
||||
# implemented through OpenAPI
|
||||
apiConfig = AnalyticdbVectorOpenAPIConfig(
|
||||
access_key_id=dify_config.ANALYTICDB_KEY_ID or "",
|
||||
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "",
|
||||
region_id=dify_config.ANALYTICDB_REGION_ID or "",
|
||||
instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "",
|
||||
account=dify_config.ANALYTICDB_ACCOUNT or "",
|
||||
account_password=dify_config.ANALYTICDB_PASSWORD or "",
|
||||
namespace=dify_config.ANALYTICDB_NAMESPACE or "",
|
||||
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
|
||||
)
|
||||
sqlConfig = None
|
||||
else:
|
||||
# implemented through sql
|
||||
sqlConfig = AnalyticdbVectorBySqlConfig(
|
||||
host=dify_config.ANALYTICDB_HOST,
|
||||
port=dify_config.ANALYTICDB_PORT,
|
||||
account=dify_config.ANALYTICDB_ACCOUNT or "",
|
||||
account_password=dify_config.ANALYTICDB_PASSWORD or "",
|
||||
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
|
||||
max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
|
||||
namespace=dify_config.ANALYTICDB_NAMESPACE or "",
|
||||
)
|
||||
apiConfig = None
|
||||
return AnalyticdbVector(
|
||||
collection_name,
|
||||
apiConfig,
|
||||
sqlConfig,
|
||||
)
|
||||
@ -1,321 +0,0 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
_import_err_msg = (
|
||||
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
|
||||
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
|
||||
)
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class AnalyticdbVectorOpenAPIConfig(BaseModel):
|
||||
access_key_id: str
|
||||
access_key_secret: str
|
||||
region_id: str
|
||||
instance_id: str
|
||||
account: str
|
||||
account_password: str
|
||||
namespace: str = "dify"
|
||||
namespace_password: str | None = None
|
||||
metrics: str = "cosine"
|
||||
read_timeout: int = 60000
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict):
|
||||
if not values["access_key_id"]:
|
||||
raise ValueError("config ANALYTICDB_KEY_ID is required")
|
||||
if not values["access_key_secret"]:
|
||||
raise ValueError("config ANALYTICDB_KEY_SECRET is required")
|
||||
if not values["region_id"]:
|
||||
raise ValueError("config ANALYTICDB_REGION_ID is required")
|
||||
if not values["instance_id"]:
|
||||
raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
|
||||
if not values["account"]:
|
||||
raise ValueError("config ANALYTICDB_ACCOUNT is required")
|
||||
if not values["account_password"]:
|
||||
raise ValueError("config ANALYTICDB_PASSWORD is required")
|
||||
if not values["namespace_password"]:
|
||||
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_analyticdb_client_params(self):
|
||||
return {
|
||||
"access_key_id": self.access_key_id,
|
||||
"access_key_secret": self.access_key_secret,
|
||||
"region_id": self.region_id,
|
||||
"read_timeout": self.read_timeout,
|
||||
}
|
||||
|
||||
|
||||
class AnalyticdbVectorOpenAPI:
|
||||
def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
|
||||
try:
|
||||
from alibabacloud_gpdb20160503.client import Client # type: ignore
|
||||
from alibabacloud_tea_openapi import models as open_api_models # type: ignore
|
||||
except:
|
||||
raise ImportError(_import_err_msg)
|
||||
self._collection_name = collection_name.lower()
|
||||
self.config = config
|
||||
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
|
||||
self._client = Client(self._client_config)
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self):
|
||||
cache_key = f"vector_initialize_{self.config.instance_id}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
|
||||
if redis_client.get(database_exist_cache_key):
|
||||
return
|
||||
self._initialize_vector_database()
|
||||
self._create_namespace_if_not_exists()
|
||||
redis_client.set(database_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _initialize_vector_database(self):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models # type: ignore
|
||||
|
||||
request = gpdb_20160503_models.InitVectorDatabaseRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
)
|
||||
self._client.init_vector_database(request)
|
||||
|
||||
def _create_namespace_if_not_exists(self):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException # type: ignore
|
||||
|
||||
try:
|
||||
request = gpdb_20160503_models.DescribeNamespaceRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
)
|
||||
self._client.describe_namespace(request)
|
||||
except TeaException as e:
|
||||
if e.statusCode == 404:
|
||||
request = gpdb_20160503_models.CreateNamespaceRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
)
|
||||
self._client.create_namespace(request)
|
||||
else:
|
||||
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
try:
|
||||
request = gpdb_20160503_models.DescribeCollectionRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
)
|
||||
self._client.describe_collection(request)
|
||||
except TeaException as e:
|
||||
if e.statusCode == 404:
|
||||
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
|
||||
full_text_retrieval_fields = "page_content"
|
||||
request = gpdb_20160503_models.CreateCollectionRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
namespace=self.config.namespace,
|
||||
collection=self._collection_name,
|
||||
dimension=embedding_dimension,
|
||||
metrics=self.config.metrics,
|
||||
metadata=metadata,
|
||||
full_text_retrieval_fields=full_text_retrieval_fields,
|
||||
)
|
||||
self._client.create_collection(request)
|
||||
else:
|
||||
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
|
||||
for doc, embedding in zip(documents, embeddings, strict=True):
|
||||
if doc.metadata is not None:
|
||||
metadata = {
|
||||
"ref_doc_id": doc.metadata["doc_id"],
|
||||
"page_content": doc.page_content,
|
||||
"metadata_": json.dumps(doc.metadata),
|
||||
}
|
||||
rows.append(
|
||||
gpdb_20160503_models.UpsertCollectionDataRequestRows(
|
||||
vector=embedding,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
request = gpdb_20160503_models.UpsertCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
rows=rows,
|
||||
)
|
||||
self._client.upsert_collection_data(request)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
metrics=self.config.metrics,
|
||||
include_values=True,
|
||||
vector=None,
|
||||
content=None,
|
||||
top_k=1,
|
||||
filter=f"ref_doc_id='{id}'",
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
return len(response.body.matches.match) > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
ids_str = ",".join(f"'{id}'" for id in ids)
|
||||
ids_str = f"({ids_str})"
|
||||
request = gpdb_20160503_models.DeleteCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None,
|
||||
collection_data_filter=f"ref_doc_id IN {ids_str}",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.DeleteCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None,
|
||||
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"metadata_->>'document_id' IN ({document_ids})"
|
||||
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=query_vector,
|
||||
content=None,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=where_clause,
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score >= score_threshold:
|
||||
metadata = json.loads(match.metadata.get("metadata_"))
|
||||
metadata["score"] = match.score
|
||||
doc = Document(
|
||||
page_content=match.metadata.get("page_content"),
|
||||
vector=match.values.value,
|
||||
metadata=metadata,
|
||||
)
|
||||
documents.append(doc)
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||
return documents
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"metadata_->>'document_id' IN ({document_ids})"
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=None,
|
||||
content=query,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=where_clause,
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score >= score_threshold:
|
||||
metadata = json.loads(match.metadata.get("metadata_"))
|
||||
metadata["score"] = match.score
|
||||
doc = Document(
|
||||
page_content=match.metadata.get("page_content"),
|
||||
vector=match.values.value,
|
||||
metadata=metadata,
|
||||
)
|
||||
documents.append(doc)
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||
return documents
|
||||
|
||||
def delete(self):
|
||||
try:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.DeleteCollectionRequest(
|
||||
collection=self._collection_name,
|
||||
dbinstance_id=self.config.instance_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
region_id=self.config.region_id,
|
||||
)
|
||||
self._client.delete_collection(request)
|
||||
except Exception as e:
|
||||
raise e
|
||||
@ -1,275 +0,0 @@
|
||||
import json
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class AnalyticdbVectorBySqlConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
account: str
|
||||
account_password: str
|
||||
min_connection: int
|
||||
max_connection: int
|
||||
namespace: str = "dify"
|
||||
metrics: str = "cosine"
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict):
|
||||
if not values["host"]:
|
||||
raise ValueError("config ANALYTICDB_HOST is required")
|
||||
if not values["port"]:
|
||||
raise ValueError("config ANALYTICDB_PORT is required")
|
||||
if not values["account"]:
|
||||
raise ValueError("config ANALYTICDB_ACCOUNT is required")
|
||||
if not values["account_password"]:
|
||||
raise ValueError("config ANALYTICDB_PASSWORD is required")
|
||||
if not values["min_connection"]:
|
||||
raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
|
||||
if not values["max_connection"]:
|
||||
raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
|
||||
if values["min_connection"] > values["max_connection"]:
|
||||
raise ValueError("config ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION")
|
||||
return values
|
||||
|
||||
|
||||
class AnalyticdbVectorBySql:
|
||||
def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
|
||||
self._collection_name = collection_name.lower()
|
||||
self.databaseName = "knowledgebase"
|
||||
self.config = config
|
||||
self.table_name = f"{self.config.namespace}.{self._collection_name}"
|
||||
self.pool = None
|
||||
self._initialize()
|
||||
if not self.pool:
|
||||
self.pool = self._create_connection_pool()
|
||||
|
||||
def _initialize(self):
|
||||
cache_key = f"vector_initialize_{self.config.host}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
database_exist_cache_key = f"vector_initialize_{self.config.host}"
|
||||
if redis_client.get(database_exist_cache_key):
|
||||
return
|
||||
self._initialize_vector_database()
|
||||
redis_client.set(database_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _create_connection_pool(self):
|
||||
return psycopg2.pool.SimpleConnectionPool(
|
||||
self.config.min_connection,
|
||||
self.config.max_connection,
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
user=self.config.account,
|
||||
password=self.config.account_password,
|
||||
database=self.databaseName,
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
assert self.pool is not None, "Connection pool is not initialized"
|
||||
conn = self.pool.getconn()
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
yield cur
|
||||
finally:
|
||||
cur.close()
|
||||
conn.commit()
|
||||
self.pool.putconn(conn)
|
||||
|
||||
def _initialize_vector_database(self):
|
||||
conn = psycopg2.connect(
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
user=self.config.account,
|
||||
password=self.config.account_password,
|
||||
database="postgres",
|
||||
)
|
||||
conn.autocommit = True
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute(f"CREATE DATABASE {self.databaseName}")
|
||||
except Exception as e:
|
||||
if "already exists" not in str(e):
|
||||
raise e
|
||||
finally:
|
||||
cur.close()
|
||||
conn.close()
|
||||
self.pool = self._create_connection_pool()
|
||||
with self._get_cursor() as cur:
|
||||
conn = cur.connection
|
||||
try:
|
||||
cur.execute("CREATE EXTENSION IF NOT EXISTS zhparser;")
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
raise RuntimeError(
|
||||
"Failed to create zhparser extension. Please ensure it is available in your AnalyticDB."
|
||||
) from e
|
||||
try:
|
||||
cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
|
||||
cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
if "already exists" not in str(e):
|
||||
raise e
|
||||
cur.execute(
|
||||
"CREATE OR REPLACE FUNCTION "
|
||||
"public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
|
||||
"RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
|
||||
"SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
|
||||
"FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
|
||||
"AS words_only;$function$"
|
||||
)
|
||||
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"CREATE TABLE IF NOT EXISTS {self.table_name}("
|
||||
f"id text PRIMARY KEY,"
|
||||
f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
|
||||
f"to_tsvector TSVECTOR"
|
||||
f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
|
||||
)
|
||||
if embedding_dimension is not None:
|
||||
index_name = f"{self._collection_name}_embedding_idx"
|
||||
try:
|
||||
cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
|
||||
cur.execute(
|
||||
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
|
||||
f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
|
||||
f"pq_enable=0, external_storage=0)"
|
||||
)
|
||||
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
|
||||
except Exception as e:
|
||||
if "already exists" not in str(e):
|
||||
raise e
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
values = []
|
||||
id_prefix = str(uuid.uuid4()) + "_"
|
||||
sql = f"""
|
||||
INSERT INTO {self.table_name}
|
||||
(id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
|
||||
VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
|
||||
"""
|
||||
for i, doc in enumerate(documents):
|
||||
if doc.metadata is not None:
|
||||
values.append(
|
||||
(
|
||||
id_prefix + str(i),
|
||||
doc.metadata.get("doc_id", str(uuid.uuid4())),
|
||||
embeddings[i],
|
||||
doc.page_content,
|
||||
json.dumps(doc.metadata),
|
||||
doc.page_content,
|
||||
)
|
||||
)
|
||||
with self._get_cursor() as cur:
|
||||
psycopg2.extras.execute_batch(cur, sql, values)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
|
||||
return cur.fetchone() is not None
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
try:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id = ANY(%s)", (ids,))
|
||||
except Exception as e:
|
||||
if "does not exist" not in str(e):
|
||||
raise e
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
with self._get_cursor() as cur:
|
||||
try:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
|
||||
except Exception as e:
|
||||
if "does not exist" not in str(e):
|
||||
raise e
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = "WHERE 1=1"
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
with self._get_cursor() as cur:
|
||||
query_vector_str = json.dumps(query_vector)
|
||||
query_vector_str = "{" + query_vector_str[1:-1] + "}"
|
||||
cur.execute(
|
||||
f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
|
||||
f"t.page_content as page_content, t.metadata_ AS metadata_ "
|
||||
f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
|
||||
f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t",
|
||||
(query_vector_str,),
|
||||
)
|
||||
documents = []
|
||||
for record in cur:
|
||||
_, vector, score, page_content, metadata = record
|
||||
if score >= score_threshold:
|
||||
metadata["score"] = score
|
||||
doc = Document(
|
||||
page_content=page_content,
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT id, vector, page_content, metadata_,
|
||||
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
|
||||
ORDER BY score DESC, id DESC
|
||||
LIMIT {top_k}""",
|
||||
(f"'{query}'", f"'{query}'"),
|
||||
)
|
||||
documents = []
|
||||
for record in cur:
|
||||
_, vector, page_content, metadata, score = record
|
||||
metadata["score"] = score
|
||||
doc = Document(
|
||||
page_content=page_content,
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
def delete(self):
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
@ -1,201 +0,0 @@
|
||||
# Clickzetta Vector Database Integration
|
||||
|
||||
This module provides integration with Clickzetta Lakehouse as a vector database for Dify.
|
||||
|
||||
## Features
|
||||
|
||||
- **Vector Storage**: Store and retrieve high-dimensional vectors using Clickzetta's native VECTOR type
|
||||
- **Vector Search**: Efficient similarity search using HNSW algorithm
|
||||
- **Full-Text Search**: Leverage Clickzetta's inverted index for powerful text search capabilities
|
||||
- **Hybrid Search**: Combine vector similarity and full-text search for better results
|
||||
- **Multi-language Support**: Built-in support for Chinese, English, and Unicode text processing
|
||||
- **Scalable**: Leverage Clickzetta's distributed architecture for large-scale deployments
|
||||
|
||||
## Configuration
|
||||
|
||||
### Required Environment Variables
|
||||
|
||||
All seven configuration parameters are required:
|
||||
|
||||
```bash
|
||||
# Authentication
|
||||
CLICKZETTA_USERNAME=your_username
|
||||
CLICKZETTA_PASSWORD=your_password
|
||||
|
||||
# Instance configuration
|
||||
CLICKZETTA_INSTANCE=your_instance_id
|
||||
CLICKZETTA_SERVICE=api.clickzetta.com
|
||||
CLICKZETTA_WORKSPACE=your_workspace
|
||||
CLICKZETTA_VCLUSTER=your_vcluster
|
||||
CLICKZETTA_SCHEMA=your_schema
|
||||
```
|
||||
|
||||
### Optional Configuration
|
||||
|
||||
```bash
|
||||
# Batch processing
|
||||
CLICKZETTA_BATCH_SIZE=100
|
||||
|
||||
# Full-text search configuration
|
||||
CLICKZETTA_ENABLE_INVERTED_INDEX=true
|
||||
CLICKZETTA_ANALYZER_TYPE=chinese # Options: keyword, english, chinese, unicode
|
||||
CLICKZETTA_ANALYZER_MODE=smart # Options: max_word, smart
|
||||
|
||||
# Vector search configuration
|
||||
CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance # Options: l2_distance, cosine_distance
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Set Clickzetta as the Vector Store
|
||||
|
||||
In your Dify configuration, set:
|
||||
|
||||
```bash
|
||||
VECTOR_STORE=clickzetta
|
||||
```
|
||||
|
||||
### 2. Table Structure
|
||||
|
||||
Clickzetta will automatically create tables with the following structure:
|
||||
|
||||
```sql
|
||||
CREATE TABLE <collection_name> (
|
||||
id STRING NOT NULL,
|
||||
content STRING NOT NULL,
|
||||
metadata JSON,
|
||||
vector VECTOR(FLOAT, <dimension>) NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
);
|
||||
|
||||
-- Vector index for similarity search
|
||||
CREATE VECTOR INDEX idx_<collection_name>_vec
|
||||
ON TABLE <schema>.<collection_name>(vector)
|
||||
PROPERTIES (
|
||||
"distance.function" = "cosine_distance",
|
||||
"scalar.type" = "f32"
|
||||
);
|
||||
|
||||
-- Inverted index for full-text search (if enabled)
|
||||
CREATE INVERTED INDEX idx_<collection_name>_text
|
||||
ON <schema>.<collection_name>(content)
|
||||
PROPERTIES (
|
||||
"analyzer" = "chinese",
|
||||
"mode" = "smart"
|
||||
);
|
||||
```
|
||||
|
||||
## Full-Text Search Capabilities
|
||||
|
||||
Clickzetta supports advanced full-text search with multiple analyzers:
|
||||
|
||||
### Analyzer Types
|
||||
|
||||
1. **keyword**: No tokenization, treats the entire string as a single token
|
||||
|
||||
- Best for: Exact matching, IDs, codes
|
||||
|
||||
1. **english**: Designed for English text
|
||||
|
||||
- Features: Recognizes ASCII letters and numbers, converts to lowercase
|
||||
- Best for: English content
|
||||
|
||||
1. **chinese**: Chinese text tokenizer
|
||||
|
||||
- Features: Recognizes Chinese and English characters, removes punctuation
|
||||
- Best for: Chinese or mixed Chinese-English content
|
||||
|
||||
1. **unicode**: Multi-language tokenizer based on Unicode
|
||||
|
||||
- Features: Recognizes text boundaries in multiple languages
|
||||
- Best for: Multi-language content
|
||||
|
||||
### Analyzer Modes
|
||||
|
||||
- **max_word**: Fine-grained tokenization (more tokens)
|
||||
- **smart**: Intelligent tokenization (balanced)
|
||||
|
||||
### Full-Text Search Functions
|
||||
|
||||
- `MATCH_ALL(column, query)`: All terms must be present
|
||||
- `MATCH_ANY(column, query)`: At least one term must be present
|
||||
- `MATCH_PHRASE(column, query)`: Exact phrase matching
|
||||
- `MATCH_PHRASE_PREFIX(column, query)`: Phrase prefix matching
|
||||
- `MATCH_REGEXP(column, pattern)`: Regular expression matching
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Vector Search
|
||||
|
||||
1. **Adjust exploration factor** for accuracy vs speed trade-off:
|
||||
|
||||
```sql
|
||||
SET cz.vector.index.search.ef=64;
|
||||
```
|
||||
|
||||
1. **Use appropriate distance functions**:
|
||||
|
||||
- `cosine_distance`: Best for normalized embeddings (e.g., from language models)
|
||||
- `l2_distance`: Best for raw feature vectors
|
||||
|
||||
### Full-Text Search
|
||||
|
||||
1. **Choose the right analyzer**:
|
||||
|
||||
- Use `keyword` for exact matching
|
||||
- Use language-specific analyzers for better tokenization
|
||||
|
||||
1. **Combine with vector search**:
|
||||
|
||||
- Pre-filter with full-text search for better performance
|
||||
- Use hybrid search for improved relevance
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Connection Issues
|
||||
|
||||
1. Verify all 7 required configuration parameters are set
|
||||
1. Check network connectivity to Clickzetta service
|
||||
1. Ensure the user has proper permissions on the schema
|
||||
|
||||
### Search Performance
|
||||
|
||||
1. Verify vector index exists:
|
||||
|
||||
```sql
|
||||
SHOW INDEX FROM <schema>.<table_name>;
|
||||
```
|
||||
|
||||
1. Check if vector index is being used:
|
||||
|
||||
```sql
|
||||
EXPLAIN SELECT ... WHERE l2_distance(...) < threshold;
|
||||
```
|
||||
|
||||
Look for `vector_index_search_type` in the execution plan.
|
||||
|
||||
### Full-Text Search Not Working
|
||||
|
||||
1. Verify inverted index is created
|
||||
1. Check analyzer configuration matches your content language
|
||||
1. Use `TOKENIZE()` function to test tokenization:
|
||||
```sql
|
||||
SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart'));
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns
|
||||
1. Full-text search relevance scores are not provided by Clickzetta
|
||||
1. Inverted index creation may fail for very large existing tables (continue without error)
|
||||
1. Index naming constraints:
|
||||
- Index names must be unique within a schema
|
||||
- Only one vector index can be created per column
|
||||
- The implementation uses timestamps to ensure unique index names
|
||||
1. A column can only have one vector index at a time
|
||||
|
||||
## References
|
||||
|
||||
- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
|
||||
- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
|
||||
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)
|
||||
@ -1 +0,0 @@
|
||||
# Clickzetta Vector Database Integration for Dify
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,413 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import tablestore # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
from tablestore import BatchGetRowRequest, TableInBatchGetRowItem
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TableStoreConfig(BaseModel):
|
||||
access_key_id: str | None = None
|
||||
access_key_secret: str | None = None
|
||||
instance_name: str | None = None
|
||||
endpoint: str | None = None
|
||||
normalize_full_text_bm25_score: bool | None = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict):
|
||||
if not values["access_key_id"]:
|
||||
raise ValueError("config ACCESS_KEY_ID is required")
|
||||
if not values["access_key_secret"]:
|
||||
raise ValueError("config ACCESS_KEY_SECRET is required")
|
||||
if not values["instance_name"]:
|
||||
raise ValueError("config INSTANCE_NAME is required")
|
||||
if not values["endpoint"]:
|
||||
raise ValueError("config ENDPOINT is required")
|
||||
return values
|
||||
|
||||
|
||||
class TableStoreVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: TableStoreConfig):
|
||||
super().__init__(collection_name)
|
||||
self._config = config
|
||||
self._tablestore_client = tablestore.OTSClient(
|
||||
config.endpoint,
|
||||
config.access_key_id,
|
||||
config.access_key_secret,
|
||||
config.instance_name,
|
||||
)
|
||||
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
|
||||
self._table_name = f"{collection_name}"
|
||||
self._index_name = f"{collection_name}_idx"
|
||||
self._tags_field = f"{Field.METADATA_KEY}_tags"
|
||||
|
||||
def create_collection(self, embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
|
||||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||
docs = []
|
||||
request = BatchGetRowRequest()
|
||||
columns_to_get = [Field.METADATA_KEY, Field.CONTENT_KEY]
|
||||
rows_to_get = [[("id", _id)] for _id in ids]
|
||||
request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1))
|
||||
|
||||
result = self._tablestore_client.batch_get_row(request)
|
||||
table_result = result.get_result_by_table(self._table_name)
|
||||
for item in table_result:
|
||||
if item.is_ok and item.row:
|
||||
kv = {k: v for k, v, _ in item.row.attribute_columns}
|
||||
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY])))
|
||||
return docs
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.TABLESTORE
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
self.add_texts(documents=texts, embeddings=embeddings, **kwargs)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
|
||||
for i in range(len(documents)):
|
||||
self._write_row(
|
||||
primary_key=uuids[i],
|
||||
attributes={
|
||||
Field.CONTENT_KEY: documents[i].page_content,
|
||||
Field.VECTOR: embeddings[i],
|
||||
Field.METADATA_KEY: documents[i].metadata,
|
||||
},
|
||||
)
|
||||
return uuids
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
result = self._tablestore_client.get_row(
|
||||
table_name=self._table_name, primary_key=[("id", id)], columns_to_get=["id"]
|
||||
)
|
||||
assert isinstance(result, tuple | list)
|
||||
# Unpack the tuple result
|
||||
_, return_row, _ = result
|
||||
|
||||
return return_row is not None
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
for id in ids:
|
||||
self._delete_row(id=id)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
return self._search_by_metadata(key, value)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
self.delete_by_ids(ids)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filtered_list = None
|
||||
if document_ids_filter:
|
||||
filtered_list = ["document_id=" + item for item in document_ids_filter]
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filtered_list = None
|
||||
if document_ids_filter:
|
||||
filtered_list = ["document_id=" + item for item in document_ids_filter]
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._search_by_full_text(query, filtered_list, top_k, score_threshold)
|
||||
|
||||
def delete(self):
|
||||
self._delete_table_if_exist()
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
logger.info("Collection %s already exists.", self._collection_name)
|
||||
return
|
||||
|
||||
self._create_table_if_not_exist()
|
||||
self._create_search_index_if_not_exist(dimension)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _create_table_if_not_exist(self):
|
||||
table_list = self._tablestore_client.list_table()
|
||||
if self._table_name in table_list:
|
||||
logger.info("Tablestore system table[%s] already exists", self._table_name)
|
||||
return None
|
||||
|
||||
schema_of_primary_key = [("id", "STRING")]
|
||||
table_meta = tablestore.TableMeta(self._table_name, schema_of_primary_key)
|
||||
table_options = tablestore.TableOptions()
|
||||
reserved_throughput = tablestore.ReservedThroughput(tablestore.CapacityUnit(0, 0))
|
||||
self._tablestore_client.create_table(table_meta, table_options, reserved_throughput)
|
||||
logger.info("Tablestore create table[%s] successfully.", self._table_name)
|
||||
|
||||
def _create_search_index_if_not_exist(self, dimension: int):
|
||||
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
|
||||
assert isinstance(search_index_list, Iterable)
|
||||
if self._index_name in [t[1] for t in search_index_list]:
|
||||
logger.info("Tablestore system index[%s] already exists", self._index_name)
|
||||
return None
|
||||
|
||||
field_schemas = [
|
||||
tablestore.FieldSchema(
|
||||
Field.CONTENT_KEY,
|
||||
tablestore.FieldType.TEXT,
|
||||
analyzer=tablestore.AnalyzerType.MAXWORD,
|
||||
index=True,
|
||||
enable_sort_and_agg=False,
|
||||
store=False,
|
||||
),
|
||||
tablestore.FieldSchema(
|
||||
Field.VECTOR,
|
||||
tablestore.FieldType.VECTOR,
|
||||
vector_options=tablestore.VectorOptions(
|
||||
data_type=tablestore.VectorDataType.VD_FLOAT_32,
|
||||
dimension=dimension,
|
||||
metric_type=tablestore.VectorMetricType.VM_COSINE,
|
||||
),
|
||||
),
|
||||
tablestore.FieldSchema(
|
||||
Field.METADATA_KEY,
|
||||
tablestore.FieldType.KEYWORD,
|
||||
index=True,
|
||||
store=False,
|
||||
),
|
||||
tablestore.FieldSchema(
|
||||
self._tags_field,
|
||||
tablestore.FieldType.KEYWORD,
|
||||
index=True,
|
||||
store=False,
|
||||
is_array=True,
|
||||
),
|
||||
]
|
||||
|
||||
index_meta = tablestore.SearchIndexMeta(field_schemas)
|
||||
self._tablestore_client.create_search_index(self._table_name, self._index_name, index_meta)
|
||||
logger.info("Tablestore create system index[%s] successfully.", self._index_name)
|
||||
|
||||
def _delete_table_if_exist(self):
|
||||
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
|
||||
assert isinstance(search_index_list, Iterable)
|
||||
for resp_tuple in search_index_list:
|
||||
self._tablestore_client.delete_search_index(resp_tuple[0], resp_tuple[1])
|
||||
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
|
||||
|
||||
self._tablestore_client.delete_table(self._table_name)
|
||||
logger.info("Tablestore delete system table[%s] successfully.", self._index_name)
|
||||
|
||||
def _delete_search_index(self):
|
||||
self._tablestore_client.delete_search_index(self._table_name, self._index_name)
|
||||
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
|
||||
|
||||
def _write_row(self, primary_key: str, attributes: dict[str, Any]):
|
||||
pk = [("id", primary_key)]
|
||||
|
||||
tags = []
|
||||
for key, value in attributes[Field.METADATA_KEY].items():
|
||||
tags.append(str(key) + "=" + str(value))
|
||||
|
||||
attribute_columns = [
|
||||
(Field.CONTENT_KEY, attributes[Field.CONTENT_KEY]),
|
||||
(Field.VECTOR, json.dumps(attributes[Field.VECTOR])),
|
||||
(
|
||||
Field.METADATA_KEY,
|
||||
json.dumps(attributes[Field.METADATA_KEY]),
|
||||
),
|
||||
(self._tags_field, json.dumps(tags)),
|
||||
]
|
||||
row = tablestore.Row(pk, attribute_columns)
|
||||
self._tablestore_client.put_row(self._table_name, row)
|
||||
|
||||
def _delete_row(self, id: str):
|
||||
primary_key = [("id", id)]
|
||||
row = tablestore.Row(primary_key)
|
||||
self._tablestore_client.delete_row(self._table_name, row, None)
|
||||
|
||||
def _search_by_metadata(self, key: str, value: str) -> list[str]:
|
||||
query = tablestore.SearchQuery(
|
||||
tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)),
|
||||
limit=1000,
|
||||
get_total_count=False,
|
||||
)
|
||||
rows: list[str] = []
|
||||
next_token = None
|
||||
while True:
|
||||
if next_token is not None:
|
||||
query.next_token = next_token
|
||||
|
||||
search_response = self._tablestore_client.search(
|
||||
table_name=self._table_name,
|
||||
index_name=self._index_name,
|
||||
search_query=query,
|
||||
columns_to_get=tablestore.ColumnsToGet(
|
||||
column_names=[Field.PRIMARY_KEY], return_type=tablestore.ColumnReturnType.SPECIFIED
|
||||
),
|
||||
)
|
||||
|
||||
if search_response is not None:
|
||||
rows.extend([row[0][0][1] for row in list(search_response.rows)])
|
||||
|
||||
if search_response is None or search_response.next_token == b"":
|
||||
break
|
||||
else:
|
||||
next_token = search_response.next_token
|
||||
|
||||
return rows
|
||||
|
||||
def _search_by_vector(
|
||||
self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
|
||||
) -> list[Document]:
|
||||
knn_vector_query = tablestore.KnnVectorQuery(
|
||||
field_name=Field.VECTOR,
|
||||
top_k=top_k,
|
||||
float32_query_vector=query_vector,
|
||||
)
|
||||
if document_ids_filter:
|
||||
knn_vector_query.filter = tablestore.TermsQuery(self._tags_field, document_ids_filter)
|
||||
|
||||
sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)])
|
||||
search_query = tablestore.SearchQuery(knn_vector_query, limit=top_k, get_total_count=False, sort=sort)
|
||||
|
||||
search_response = self._tablestore_client.search(
|
||||
table_name=self._table_name,
|
||||
index_name=self._index_name,
|
||||
search_query=search_query,
|
||||
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
||||
)
|
||||
documents = []
|
||||
for search_hit in search_response.search_hits:
|
||||
if search_hit.score >= score_threshold:
|
||||
ots_column_map = {}
|
||||
for col in search_hit.row[1]:
|
||||
ots_column_map[col[0]] = col[1]
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR)
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY)
|
||||
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||
|
||||
metadata["score"] = search_hit.score
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float:
|
||||
"""
|
||||
Args:
|
||||
score: BM25 search score.
|
||||
k: decay factor, the larger the k, the steeper the low score end
|
||||
"""
|
||||
normalized_score = 1 - math.exp(-k * score)
|
||||
return max(0.0, min(1.0, normalized_score))
|
||||
|
||||
def _search_by_full_text(
|
||||
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
|
||||
) -> list[Document]:
|
||||
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
|
||||
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY))
|
||||
|
||||
if document_ids_filter:
|
||||
bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))
|
||||
|
||||
search_query = tablestore.SearchQuery(
|
||||
query=bool_query,
|
||||
sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]),
|
||||
limit=top_k,
|
||||
)
|
||||
search_response = self._tablestore_client.search(
|
||||
table_name=self._table_name,
|
||||
index_name=self._index_name,
|
||||
search_query=search_query,
|
||||
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
||||
)
|
||||
|
||||
documents = []
|
||||
for search_hit in search_response.search_hits:
|
||||
score = None
|
||||
if self._normalize_full_text_bm25_score:
|
||||
score = self._normalize_score_exp_decay(search_hit.score)
|
||||
|
||||
# skip when score is below threshold and use normalize score
|
||||
if score and score <= score_threshold:
|
||||
continue
|
||||
|
||||
ots_column_map = {}
|
||||
for col in search_hit.row[1]:
|
||||
ots_column_map[col[0]] = col[1]
|
||||
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY)
|
||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR)
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
|
||||
if score:
|
||||
metadata["score"] = score
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
if self._normalize_full_text_bm25_score:
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||
return documents
|
||||
|
||||
|
||||
class TableStoreVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TableStoreVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TABLESTORE, collection_name))
|
||||
|
||||
return TableStoreVector(
|
||||
collection_name=collection_name,
|
||||
config=TableStoreConfig(
|
||||
endpoint=dify_config.TABLESTORE_ENDPOINT,
|
||||
instance_name=dify_config.TABLESTORE_INSTANCE_NAME,
|
||||
access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID,
|
||||
access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET,
|
||||
normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE,
|
||||
),
|
||||
)
|
||||
@ -135,10 +135,6 @@ class Vector:
|
||||
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
|
||||
|
||||
return OpenSearchVectorFactory
|
||||
case VectorType.ANALYTICDB:
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
|
||||
|
||||
return AnalyticdbVectorFactory
|
||||
case VectorType.COUCHBASE:
|
||||
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory
|
||||
|
||||
@ -171,10 +167,6 @@ class Vector:
|
||||
from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory
|
||||
|
||||
return OpenGaussFactory
|
||||
case VectorType.TABLESTORE:
|
||||
from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory
|
||||
|
||||
return TableStoreVectorFactory
|
||||
case VectorType.HUAWEI_CLOUD:
|
||||
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
|
||||
|
||||
@ -183,10 +175,6 @@ class Vector:
|
||||
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
|
||||
|
||||
return MatrixoneVectorFactory
|
||||
case VectorType.CLICKZETTA:
|
||||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
|
||||
|
||||
return ClickzettaVectorFactory
|
||||
case VectorType.IRIS:
|
||||
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ from enum import StrEnum
|
||||
|
||||
class VectorType(StrEnum):
|
||||
ALIBABACLOUD_MYSQL = "alibabacloud_mysql"
|
||||
ANALYTICDB = "analyticdb"
|
||||
CHROMA = "chroma"
|
||||
MILVUS = "milvus"
|
||||
MYSCALE = "myscale"
|
||||
@ -29,9 +28,7 @@ class VectorType(StrEnum):
|
||||
OCEANBASE = "oceanbase"
|
||||
SEEKDB = "seekdb"
|
||||
OPENGAUSS = "opengauss"
|
||||
TABLESTORE = "tablestore"
|
||||
HUAWEI_CLOUD = "huawei_cloud"
|
||||
MATRIXONE = "matrixone"
|
||||
CLICKZETTA = "clickzetta"
|
||||
IRIS = "iris"
|
||||
HOLOGRES = "hologres"
|
||||
|
||||
@ -5,6 +5,7 @@ This module provides integration with Weaviate vector database for storing and r
|
||||
document embeddings used in retrieval-augmented generation workflows.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
@ -37,6 +38,32 @@ _weaviate_client: weaviate.WeaviateClient | None = None
|
||||
_weaviate_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def _shutdown_weaviate_client() -> None:
|
||||
"""
|
||||
Best-effort shutdown hook to close the module-level Weaviate client.
|
||||
|
||||
This is registered with atexit so that HTTP/gRPC resources are released
|
||||
when the Python interpreter exits.
|
||||
"""
|
||||
global _weaviate_client
|
||||
|
||||
# Ensure thread-safety when accessing the shared client instance
|
||||
with _weaviate_client_lock:
|
||||
client = _weaviate_client
|
||||
_weaviate_client = None
|
||||
|
||||
if client is not None:
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
# Best-effort cleanup; log at debug level and ignore errors.
|
||||
logger.debug("Failed to close Weaviate client during shutdown", exc_info=True)
|
||||
|
||||
|
||||
# Register the shutdown hook once per process.
|
||||
atexit.register(_shutdown_weaviate_client)
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
"""
|
||||
Configuration model for Weaviate connection settings.
|
||||
@ -85,18 +112,6 @@ class WeaviateVector(BaseVector):
|
||||
self._client = self._init_client(config)
|
||||
self._attributes = attributes
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Destructor to properly close the Weaviate client connection.
|
||||
Prevents connection leaks and resource warnings.
|
||||
"""
|
||||
if hasattr(self, "_client") and self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception as e:
|
||||
# Ignore errors during cleanup as object is being destroyed
|
||||
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
|
||||
"""
|
||||
Initializes and returns a connected Weaviate client.
|
||||
|
||||
@ -3,7 +3,6 @@ from typing import Final
|
||||
TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook"
|
||||
TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
|
||||
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
|
||||
TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info"
|
||||
|
||||
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey
|
||||
@ -47,7 +47,7 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
|
||||
|
||||
# Get trigger data passed when workflow was triggered
|
||||
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||
cast(WorkflowNodeExecutionMetadataKey, TRIGGER_INFO_METADATA_KEY): {
|
||||
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
|
||||
"provider_id": self.node_data.provider_id,
|
||||
"event_name": self.node_data.event_name,
|
||||
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
|
||||
|
||||
@ -245,6 +245,9 @@ _END_STATE = frozenset(
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
|
||||
Values in this enum are persisted as execution metadata and must stay in sync
|
||||
with every node that writes `NodeRunResult.metadata`.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
@ -266,6 +269,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
TRIGGER_INFO = "trigger_info"
|
||||
COMPLETED_REASON = "completed_reason" # completed reason for loop node
|
||||
|
||||
|
||||
|
||||
@ -101,7 +101,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
timeout=self._get_request_timeout(self.node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
http_request_config=self._http_request_config,
|
||||
max_retries=0,
|
||||
ssl_verify=self.node_data.ssl_verify,
|
||||
http_client=self._http_client,
|
||||
file_manager=self._file_manager,
|
||||
|
||||
@ -256,9 +256,13 @@ def fetch_prompt_messages(
|
||||
):
|
||||
continue
|
||||
prompt_message_content.append(content_item)
|
||||
if prompt_message_content:
|
||||
if not prompt_message_content:
|
||||
continue
|
||||
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
|
||||
prompt_message.content = prompt_message_content[0].data
|
||||
else:
|
||||
prompt_message.content = prompt_message_content
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
elif not prompt_message.is_empty():
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from dify_graph.file.models import File
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from dify_graph.variables.segments import Segment
|
||||
|
||||
|
||||
class ArrayValidation(StrEnum):
|
||||
@ -219,7 +219,7 @@ class SegmentType(StrEnum):
|
||||
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
||||
|
||||
@staticmethod
|
||||
def get_zero_value(t: SegmentType):
|
||||
def get_zero_value(t: SegmentType) -> Segment:
|
||||
# Lazy import to avoid circular dependency
|
||||
from factories import variable_factory
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Protocol, cast
|
||||
|
||||
from fastopenapi.routers import FlaskRouter
|
||||
from flask_cors import CORS
|
||||
|
||||
@ -9,6 +11,10 @@ from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
|
||||
DOCS_PREFIX = "/fastopenapi"
|
||||
|
||||
|
||||
class SupportsIncludeRouter(Protocol):
|
||||
def include_router(self, router: object, *, prefix: str = "") -> None: ...
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
|
||||
docs_enabled = dify_config.SWAGGER_UI_ENABLED
|
||||
docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
|
||||
@ -36,7 +42,7 @@ def init_app(app: DifyApp) -> None:
|
||||
_ = remote_files
|
||||
_ = setup
|
||||
|
||||
router.include_router(console_router, prefix="/console/api")
|
||||
cast(SupportsIncludeRouter, router).include_router(console_router, prefix="/console/api")
|
||||
CORS(
|
||||
app,
|
||||
resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||
|
||||
@ -69,19 +69,6 @@ class Storage:
|
||||
from extensions.storage.supabase_storage import SupabaseStorage
|
||||
|
||||
return SupabaseStorage
|
||||
case StorageType.CLICKZETTA_VOLUME:
|
||||
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
|
||||
ClickZettaVolumeConfig,
|
||||
ClickZettaVolumeStorage,
|
||||
)
|
||||
|
||||
def create_clickzetta_volume_storage():
|
||||
# ClickZettaVolumeConfig will automatically read from environment variables
|
||||
# and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set
|
||||
volume_config = ClickZettaVolumeConfig()
|
||||
return ClickZettaVolumeStorage(volume_config)
|
||||
|
||||
return create_clickzetta_volume_storage
|
||||
case _:
|
||||
raise ValueError(f"unsupported storage type {storage_type}")
|
||||
|
||||
|
||||
@ -1,5 +0,0 @@
|
||||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
from .clickzetta_volume_storage import ClickZettaVolumeStorage
|
||||
|
||||
__all__ = ["ClickZettaVolumeStorage"]
|
||||
@ -1,527 +0,0 @@
|
||||
"""ClickZetta Volume Storage Implementation
|
||||
|
||||
This module provides storage backend using ClickZetta Volume functionality.
|
||||
Supports Table Volume, User Volume, and External Volume types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import clickzetta
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
from .volume_permissions import VolumePermissionManager, check_volume_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClickZettaVolumeConfig(BaseModel):
|
||||
"""Configuration for ClickZetta Volume storage."""
|
||||
|
||||
username: str = ""
|
||||
password: str = ""
|
||||
instance: str = ""
|
||||
service: str = "api.clickzetta.com"
|
||||
workspace: str = "quick_start"
|
||||
vcluster: str = "default_ap"
|
||||
schema_name: str = "dify"
|
||||
volume_type: str = "table" # table|user|external
|
||||
volume_name: str | None = None # For external volumes
|
||||
table_prefix: str = "dataset_" # Prefix for table volume names
|
||||
dify_prefix: str = "dify_km" # Directory prefix for User Volume
|
||||
permission_check: bool = True # Enable/disable permission checking
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict):
|
||||
"""Validate the configuration values.
|
||||
|
||||
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
||||
then fall back to CLICKZETTA_* environment variables (for vector DB config).
|
||||
"""
|
||||
|
||||
# Helper function to get environment variable with fallback
|
||||
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
|
||||
# First try CLICKZETTA_VOLUME_* specific config
|
||||
volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
|
||||
if volume_value:
|
||||
return str(volume_value)
|
||||
|
||||
# Then try environment variables
|
||||
volume_env = os.getenv(volume_key)
|
||||
if volume_env:
|
||||
return volume_env
|
||||
|
||||
# Fall back to existing CLICKZETTA_* config
|
||||
fallback_env = os.getenv(fallback_key)
|
||||
if fallback_env:
|
||||
return fallback_env
|
||||
|
||||
return default or ""
|
||||
|
||||
# Apply environment variables with fallback to existing CLICKZETTA_* config
|
||||
values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
|
||||
values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
|
||||
values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
|
||||
values.setdefault(
|
||||
"service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
|
||||
)
|
||||
values.setdefault(
|
||||
"workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
|
||||
)
|
||||
values.setdefault(
|
||||
"vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
|
||||
)
|
||||
values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
|
||||
|
||||
# Volume-specific configurations (no fallback to vector DB config)
|
||||
values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
|
||||
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
|
||||
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
|
||||
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
|
||||
# Temporarily disable permission check feature, set directly to false
|
||||
values.setdefault("permission_check", False)
|
||||
|
||||
# Validate required fields
|
||||
if not values.get("username"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required")
|
||||
if not values.get("password"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required")
|
||||
if not values.get("instance"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required")
|
||||
|
||||
# Validate volume type
|
||||
volume_type = values["volume_type"]
|
||||
if volume_type not in ["table", "user", "external"]:
|
||||
raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external")
|
||||
|
||||
if volume_type == "external" and not values.get("volume_name"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class ClickZettaVolumeStorage(BaseStorage):
|
||||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
def __init__(self, config: ClickZettaVolumeConfig):
|
||||
"""Initialize ClickZetta Volume storage.
|
||||
|
||||
Args:
|
||||
config: ClickZetta Volume configuration
|
||||
"""
|
||||
self._config = config
|
||||
self._connection = None
|
||||
self._permission_manager: VolumePermissionManager | None = None
|
||||
self._init_connection()
|
||||
self._init_permission_manager()
|
||||
|
||||
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
|
||||
|
||||
def _init_connection(self):
|
||||
"""Initialize ClickZetta connection."""
|
||||
try:
|
||||
self._connection = clickzetta.connect(
|
||||
username=self._config.username,
|
||||
password=self._config.password,
|
||||
instance=self._config.instance,
|
||||
service=self._config.service,
|
||||
workspace=self._config.workspace,
|
||||
vcluster=self._config.vcluster,
|
||||
schema=self._config.schema_name,
|
||||
)
|
||||
logger.debug("ClickZetta connection established")
|
||||
except Exception:
|
||||
logger.exception("Failed to connect to ClickZetta")
|
||||
raise
|
||||
|
||||
def _init_permission_manager(self):
|
||||
"""Initialize permission manager."""
|
||||
try:
|
||||
self._permission_manager = VolumePermissionManager(
|
||||
self._connection, self._config.volume_type, self._config.volume_name
|
||||
)
|
||||
logger.debug("Permission manager initialized")
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize permission manager")
|
||||
raise
|
||||
|
||||
def _get_volume_path(self, filename: str, dataset_id: str | None = None) -> str:
|
||||
"""Get the appropriate volume path based on volume type."""
|
||||
if self._config.volume_type == "user":
|
||||
# Add dify prefix for User Volume to organize files
|
||||
return f"{self._config.dify_prefix}/{filename}"
|
||||
elif self._config.volume_type == "table":
|
||||
# Check if this should use User Volume (special directories)
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
# Use User Volume with dify prefix for special directories
|
||||
return f"{self._config.dify_prefix}/{filename}"
|
||||
|
||||
if dataset_id:
|
||||
return f"{self._config.table_prefix}{dataset_id}/{filename}"
|
||||
else:
|
||||
# Extract dataset_id from filename if not provided
|
||||
# Format: dataset_id/filename
|
||||
if "/" in filename:
|
||||
return filename
|
||||
else:
|
||||
raise ValueError("dataset_id is required for table volume or filename must include dataset_id/")
|
||||
elif self._config.volume_type == "external":
|
||||
return filename
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _get_volume_sql_prefix(self, dataset_id: str | None = None) -> str:
|
||||
"""Get SQL prefix for volume operations."""
|
||||
if self._config.volume_type == "user":
|
||||
return "USER VOLUME"
|
||||
elif self._config.volume_type == "table":
|
||||
# For Dify's current file storage pattern, most files are stored in
|
||||
# paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext"
|
||||
# These should use USER VOLUME for better compatibility
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return "USER VOLUME"
|
||||
|
||||
# Only use TABLE VOLUME for actual dataset-specific paths
|
||||
# like "dataset_12345/file.pdf" or paths with dataset_ prefix
|
||||
if dataset_id:
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
else:
|
||||
# Default table name for generic operations
|
||||
table_name = "default_dataset"
|
||||
return f"TABLE VOLUME {table_name}"
|
||||
elif self._config.volume_type == "external":
|
||||
return f"VOLUME {self._config.volume_name}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _execute_sql(self, sql: str, fetch: bool = False):
|
||||
"""Execute SQL command."""
|
||||
try:
|
||||
if self._connection is None:
|
||||
raise RuntimeError("Connection not initialized")
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute(sql)
|
||||
if fetch:
|
||||
return cursor.fetchall()
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("SQL execution failed: %s", sql)
|
||||
raise
|
||||
|
||||
def _ensure_table_volume_exists(self, dataset_id: str):
|
||||
"""Ensure table volume exists for the given dataset_id."""
|
||||
if self._config.volume_type != "table" or not dataset_id:
|
||||
return
|
||||
|
||||
# Skip for upload_files and other special directories that use USER VOLUME
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return
|
||||
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
|
||||
try:
|
||||
# Check if table exists
|
||||
check_sql = f"SHOW TABLES LIKE '{table_name}'"
|
||||
result = self._execute_sql(check_sql, fetch=True)
|
||||
|
||||
if not result:
|
||||
# Create table with volume
|
||||
create_sql = f"""
|
||||
CREATE TABLE {table_name} (
|
||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||
filename VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_filename (filename)
|
||||
) WITH VOLUME
|
||||
"""
|
||||
self._execute_sql(create_sql)
|
||||
logger.info("Created table volume: %s", table_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create table volume %s: %s", table_name, e)
|
||||
# Don't raise exception, let the operation continue
|
||||
# The table might exist but not be visible due to permissions
|
||||
|
||||
def save(self, filename: str, data: bytes):
|
||||
"""Save data to ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
data: File content as bytes
|
||||
"""
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
# Ensure table volume exists (for table volumes)
|
||||
if dataset_id:
|
||||
self._ensure_table_volume_exists(dataset_id)
|
||||
|
||||
# Check permissions (if enabled)
|
||||
if self._config.permission_check:
|
||||
# Skip permission check for special directories that use USER VOLUME
|
||||
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
if self._permission_manager is not None:
|
||||
check_volume_permission(self._permission_manager, "save", dataset_id)
|
||||
|
||||
# Write data to temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
temp_file.write(data)
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Upload to volume
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'"
|
||||
else:
|
||||
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path)
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
Path(temp_file_path).unlink(missing_ok=True)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
"""Load file content from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
"""
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
# Check permissions (if enabled)
|
||||
if self._config.permission_check:
|
||||
# Skip permission check for special directories that use USER VOLUME
|
||||
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
if self._permission_manager is not None:
|
||||
check_volume_permission(self._permission_manager, "load_once", dataset_id)
|
||||
|
||||
# Download to temporary directory
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'"
|
||||
else:
|
||||
sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
|
||||
# Find the downloaded file (may be in subdirectories)
|
||||
downloaded_file = None
|
||||
for root, _, files in os.walk(temp_dir):
|
||||
for file in files:
|
||||
if file == filename or file == os.path.basename(filename):
|
||||
downloaded_file = Path(root) / file
|
||||
break
|
||||
if downloaded_file:
|
||||
break
|
||||
|
||||
if not downloaded_file or not downloaded_file.exists():
|
||||
raise FileNotFoundError(f"Downloaded file not found: {filename}")
|
||||
|
||||
content = downloaded_file.read_bytes()
|
||||
|
||||
logger.debug("File %s loaded from ClickZetta Volume", filename)
|
||||
return content
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
"""Load file as stream from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Yields:
|
||||
File content chunks
|
||||
"""
|
||||
content = self.load_once(filename)
|
||||
batch_size = 4096
|
||||
stream = BytesIO(content)
|
||||
|
||||
while chunk := stream.read(batch_size):
|
||||
yield chunk
|
||||
|
||||
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
|
||||
|
||||
def download(self, filename: str, target_filepath: str):
|
||||
"""Download file from ClickZetta Volume to local path.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
target_filepath: Local target file path
|
||||
"""
|
||||
content = self.load_once(filename)
|
||||
|
||||
Path(target_filepath).write_bytes(content)
|
||||
|
||||
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
|
||||
|
||||
def exists(self, filename: str) -> bool:
|
||||
"""Check if file exists in ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Returns:
|
||||
True if file exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'"
|
||||
|
||||
rows = self._execute_sql(sql, fetch=True)
|
||||
|
||||
exists = len(rows) > 0 if rows else False
|
||||
logger.debug("File %s exists check: %s", filename, exists)
|
||||
return exists
|
||||
except Exception as e:
|
||||
logger.warning("Error checking file existence for %s: %s", filename, e)
|
||||
return False
|
||||
|
||||
def delete(self, filename: str):
|
||||
"""Delete file from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
"""
|
||||
if not self.exists(filename):
|
||||
logger.debug("File %s not found, skip delete", filename)
|
||||
return
|
||||
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"REMOVE {volume_prefix} FILE '{volume_path}'"
|
||||
else:
|
||||
sql = f"REMOVE {volume_prefix} FILE '{filename}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
|
||||
logger.debug("File %s deleted from ClickZetta Volume", filename)
|
||||
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
"""Scan files and directories in ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
path: Path to scan (dataset_id for table volumes)
|
||||
files: Include files in results
|
||||
directories: Include directories in results
|
||||
|
||||
Returns:
|
||||
List of file/directory paths
|
||||
"""
|
||||
try:
|
||||
# For table volumes, path is treated as dataset_id
|
||||
dataset_id = None
|
||||
if self._config.volume_type == "table":
|
||||
dataset_id = path
|
||||
path = "" # Root of the table volume
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# For User Volume, add dify prefix to path
|
||||
if volume_prefix == "USER VOLUME":
|
||||
if path:
|
||||
scan_path = f"{self._config.dify_prefix}/{path}"
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'"
|
||||
else:
|
||||
if path:
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix}"
|
||||
|
||||
rows = self._execute_sql(sql, fetch=True)
|
||||
|
||||
result = []
|
||||
if rows:
|
||||
for row in rows:
|
||||
file_path = row[0] # relative_path column
|
||||
|
||||
# For User Volume, remove dify prefix from results
|
||||
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
|
||||
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
|
||||
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
|
||||
|
||||
if files and not file_path.endswith("/") or directories and file_path.endswith("/"):
|
||||
result.append(file_path)
|
||||
|
||||
logger.debug("Scanned %d items in path %s", len(result), path)
|
||||
return result
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error scanning path %s", path)
|
||||
return []
|
||||
@ -1,518 +0,0 @@
|
||||
"""ClickZetta Volume file lifecycle management
|
||||
|
||||
This module provides file lifecycle management features including version control,
|
||||
automatic cleanup, backup and restore.
|
||||
Supports complete lifecycle management for knowledge base files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import operator
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileStatus(StrEnum):
|
||||
"""File status enumeration"""
|
||||
|
||||
ACTIVE = auto() # Active status
|
||||
ARCHIVED = auto() # Archived
|
||||
DELETED = auto() # Deleted (soft delete)
|
||||
BACKUP = auto() # Backup file
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileMetadata:
|
||||
"""File metadata"""
|
||||
|
||||
filename: str
|
||||
size: int | None
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
version: int | None
|
||||
status: FileStatus
|
||||
checksum: str | None = None
|
||||
tags: dict[str, str] | None = None
|
||||
parent_version: int | None = None
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary format"""
|
||||
data = asdict(self)
|
||||
data["created_at"] = self.created_at.isoformat()
|
||||
data["modified_at"] = self.modified_at.isoformat()
|
||||
data["status"] = self.status.value
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> FileMetadata:
|
||||
"""Create instance from dictionary"""
|
||||
data = data.copy()
|
||||
data["created_at"] = datetime.fromisoformat(data["created_at"])
|
||||
data["modified_at"] = datetime.fromisoformat(data["modified_at"])
|
||||
data["status"] = FileStatus(data["status"])
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class FileLifecycleManager:
|
||||
"""File lifecycle manager"""
|
||||
|
||||
def __init__(self, storage, dataset_id: str | None = None):
|
||||
"""Initialize lifecycle manager
|
||||
|
||||
Args:
|
||||
storage: ClickZetta Volume storage instance
|
||||
dataset_id: Dataset ID (for Table Volume)
|
||||
"""
|
||||
self._storage = storage
|
||||
self._dataset_id = dataset_id
|
||||
self._metadata_file = ".dify_file_metadata.json"
|
||||
self._version_prefix = ".versions/"
|
||||
self._backup_prefix = ".backups/"
|
||||
self._deleted_prefix = ".deleted/"
|
||||
|
||||
# Get permission manager (if exists)
|
||||
self._permission_manager: Any | None = getattr(storage, "_permission_manager", None)
|
||||
|
||||
def save_with_lifecycle(self, filename: str, data: bytes, tags: dict[str, str] | None = None) -> FileMetadata:
|
||||
"""Save file and manage lifecycle
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
data: File content
|
||||
tags: File tags
|
||||
|
||||
Returns:
|
||||
File metadata
|
||||
"""
|
||||
# Permission check
|
||||
if not self._check_permission(filename, "save"):
|
||||
from .volume_permissions import VolumePermissionError
|
||||
|
||||
raise VolumePermissionError(
|
||||
f"Permission denied for lifecycle save operation on file: {filename}",
|
||||
operation="save",
|
||||
volume_type=getattr(self._storage, "_config", {}).get("volume_type", "unknown"),
|
||||
dataset_id=self._dataset_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# 1. Check if old version exists
|
||||
metadata_dict = self._load_metadata()
|
||||
current_metadata = metadata_dict.get(filename)
|
||||
|
||||
# 2. If old version exists, create version backup
|
||||
if current_metadata:
|
||||
self._create_version_backup(filename, current_metadata)
|
||||
|
||||
# 3. Calculate file information
|
||||
now = datetime.now()
|
||||
checksum = self._calculate_checksum(data)
|
||||
new_version = (current_metadata["version"] + 1) if current_metadata else 1
|
||||
|
||||
# 4. Save new file
|
||||
self._storage.save(filename, data)
|
||||
|
||||
# 5. Create metadata
|
||||
created_at = now
|
||||
parent_version = None
|
||||
|
||||
if current_metadata:
|
||||
# If created_at is string, convert to datetime
|
||||
if isinstance(current_metadata["created_at"], str):
|
||||
created_at = datetime.fromisoformat(current_metadata["created_at"])
|
||||
else:
|
||||
created_at = current_metadata["created_at"]
|
||||
parent_version = current_metadata["version"]
|
||||
|
||||
file_metadata = FileMetadata(
|
||||
filename=filename,
|
||||
size=len(data),
|
||||
created_at=created_at,
|
||||
modified_at=now,
|
||||
version=new_version,
|
||||
status=FileStatus.ACTIVE,
|
||||
checksum=checksum,
|
||||
tags=tags or {},
|
||||
parent_version=parent_version,
|
||||
)
|
||||
|
||||
# 6. Update metadata
|
||||
metadata_dict[filename] = file_metadata.to_dict()
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s saved with lifecycle management, version %s", filename, new_version)
|
||||
return file_metadata
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to save file with lifecycle")
|
||||
raise
|
||||
|
||||
def get_file_metadata(self, filename: str) -> FileMetadata | None:
|
||||
"""Get file metadata
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
|
||||
Returns:
|
||||
File metadata, returns None if not exists
|
||||
"""
|
||||
try:
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename in metadata_dict:
|
||||
return FileMetadata.from_dict(metadata_dict[filename])
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Failed to get file metadata for %s", filename)
|
||||
return None
|
||||
|
||||
def list_file_versions(self, filename: str) -> list[FileMetadata]:
|
||||
"""List all versions of a file
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
|
||||
Returns:
|
||||
File version list, sorted by version number
|
||||
"""
|
||||
try:
|
||||
versions = []
|
||||
|
||||
# Get current version
|
||||
current_metadata = self.get_file_metadata(filename)
|
||||
if current_metadata:
|
||||
versions.append(current_metadata)
|
||||
|
||||
# Get historical versions
|
||||
try:
|
||||
version_files = self._storage.scan(self._dataset_id or "", files=True)
|
||||
for file_path in version_files:
|
||||
if file_path.startswith(f"{self._version_prefix}{filename}.v"):
|
||||
# Parse version number
|
||||
version_str = file_path.split(".v")[-1].split(".")[0]
|
||||
try:
|
||||
_ = int(version_str)
|
||||
# Simplified processing here, should actually read metadata from version file
|
||||
# Temporarily create basic metadata information
|
||||
except ValueError:
|
||||
continue
|
||||
except Exception:
|
||||
# If cannot scan version files, only return current version
|
||||
logger.exception("Failed to scan version files for %s", filename)
|
||||
|
||||
return sorted(versions, key=lambda x: x.version or 0, reverse=True)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to list file versions for %s", filename)
|
||||
return []
|
||||
|
||||
def restore_version(self, filename: str, version: int) -> bool:
|
||||
"""Restore file to specified version
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
version: Version number to restore
|
||||
|
||||
Returns:
|
||||
Whether restore succeeded
|
||||
"""
|
||||
try:
|
||||
version_filename = f"{self._version_prefix}{filename}.v{version}"
|
||||
|
||||
# Check if version file exists
|
||||
if not self._storage.exists(version_filename):
|
||||
logger.warning("Version %s of %s not found", version, filename)
|
||||
return False
|
||||
|
||||
# Read version file content
|
||||
version_data = self._storage.load_once(version_filename)
|
||||
|
||||
# Save current version as backup
|
||||
current_metadata = self.get_file_metadata(filename)
|
||||
if current_metadata:
|
||||
self._create_version_backup(filename, current_metadata.to_dict())
|
||||
|
||||
# Restore file
|
||||
self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to restore %s to version %s", filename, version)
|
||||
return False
|
||||
|
||||
def archive_file(self, filename: str) -> bool:
|
||||
"""Archive file
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
|
||||
Returns:
|
||||
Whether archive succeeded
|
||||
"""
|
||||
# Permission check
|
||||
if not self._check_permission(filename, "archive"):
|
||||
logger.warning("Permission denied for archive operation on file: %s", filename)
|
||||
return False
|
||||
|
||||
try:
|
||||
# Update file status to archived
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename not in metadata_dict:
|
||||
logger.warning("File %s not found in metadata", filename)
|
||||
return False
|
||||
|
||||
metadata_dict[filename]["status"] = FileStatus.ARCHIVED
|
||||
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
|
||||
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s archived successfully", filename)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to archive file %s", filename)
|
||||
return False
|
||||
|
||||
def soft_delete_file(self, filename: str) -> bool:
|
||||
"""Soft delete file (move to deleted directory)
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
|
||||
Returns:
|
||||
Whether delete succeeded
|
||||
"""
|
||||
# Permission check
|
||||
if not self._check_permission(filename, "delete"):
|
||||
logger.warning("Permission denied for soft delete operation on file: %s", filename)
|
||||
return False
|
||||
|
||||
try:
|
||||
# Check if file exists
|
||||
if not self._storage.exists(filename):
|
||||
logger.warning("File %s not found", filename)
|
||||
return False
|
||||
|
||||
# Read file content
|
||||
file_data = self._storage.load_once(filename)
|
||||
|
||||
# Move to deleted directory
|
||||
deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
self._storage.save(deleted_filename, file_data)
|
||||
|
||||
# Delete original file
|
||||
self._storage.delete(filename)
|
||||
|
||||
# Update metadata
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename in metadata_dict:
|
||||
metadata_dict[filename]["status"] = FileStatus.DELETED
|
||||
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s soft deleted successfully", filename)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to soft delete file %s", filename)
|
||||
return False
|
||||
|
||||
def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
|
||||
"""Cleanup old version files
|
||||
|
||||
Args:
|
||||
max_versions: Maximum number of versions to keep
|
||||
max_age_days: Maximum retention days for version files
|
||||
|
||||
Returns:
|
||||
Number of files cleaned
|
||||
"""
|
||||
try:
|
||||
cleaned_count = 0
|
||||
|
||||
# Get all version files
|
||||
try:
|
||||
all_files = self._storage.scan(self._dataset_id or "", files=True)
|
||||
version_files = [f for f in all_files if f.startswith(self._version_prefix)]
|
||||
|
||||
# Group by file
|
||||
file_versions: dict[str, list[tuple[int, str]]] = {}
|
||||
for version_file in version_files:
|
||||
# Parse filename and version
|
||||
parts = version_file[len(self._version_prefix) :].split(".v")
|
||||
if len(parts) >= 2:
|
||||
base_filename = parts[0]
|
||||
version_part = parts[1].split(".")[0]
|
||||
try:
|
||||
version_num = int(version_part)
|
||||
if base_filename not in file_versions:
|
||||
file_versions[base_filename] = []
|
||||
file_versions[base_filename].append((version_num, version_file))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Cleanup old versions for each file
|
||||
for base_filename, versions in file_versions.items():
|
||||
# Sort by version number
|
||||
versions.sort(key=operator.itemgetter(0), reverse=True)
|
||||
|
||||
# Keep the newest max_versions versions, delete the rest
|
||||
if len(versions) > max_versions:
|
||||
to_delete = versions[max_versions:]
|
||||
for version_num, version_file in to_delete:
|
||||
self._storage.delete(version_file)
|
||||
cleaned_count += 1
|
||||
logger.debug("Cleaned old version: %s", version_file)
|
||||
|
||||
logger.info("Cleaned %d old version files", cleaned_count)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not scan for version files: %s", e)
|
||||
|
||||
return cleaned_count
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to cleanup old versions")
|
||||
return 0
|
||||
|
||||
def get_storage_statistics(self) -> dict[str, Any]:
|
||||
"""Get storage statistics
|
||||
|
||||
Returns:
|
||||
Storage statistics dictionary
|
||||
"""
|
||||
try:
|
||||
metadata_dict = self._load_metadata()
|
||||
|
||||
stats: dict[str, Any] = {
|
||||
"total_files": len(metadata_dict),
|
||||
"active_files": 0,
|
||||
"archived_files": 0,
|
||||
"deleted_files": 0,
|
||||
"total_size": 0,
|
||||
"versions_count": 0,
|
||||
"oldest_file": None,
|
||||
"newest_file": None,
|
||||
}
|
||||
|
||||
oldest_date = None
|
||||
newest_date = None
|
||||
|
||||
for filename, metadata in metadata_dict.items():
|
||||
file_meta = FileMetadata.from_dict(metadata)
|
||||
|
||||
# Count file status
|
||||
if file_meta.status == FileStatus.ACTIVE:
|
||||
stats["active_files"] = (stats["active_files"] or 0) + 1
|
||||
elif file_meta.status == FileStatus.ARCHIVED:
|
||||
stats["archived_files"] = (stats["archived_files"] or 0) + 1
|
||||
elif file_meta.status == FileStatus.DELETED:
|
||||
stats["deleted_files"] = (stats["deleted_files"] or 0) + 1
|
||||
|
||||
# Count size
|
||||
stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)
|
||||
|
||||
# Count versions
|
||||
stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)
|
||||
|
||||
# Find newest and oldest files
|
||||
if oldest_date is None or file_meta.created_at < oldest_date:
|
||||
oldest_date = file_meta.created_at
|
||||
stats["oldest_file"] = filename
|
||||
|
||||
if newest_date is None or file_meta.modified_at > newest_date:
|
||||
newest_date = file_meta.modified_at
|
||||
stats["newest_file"] = filename
|
||||
|
||||
return stats
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to get storage statistics")
|
||||
return {}
|
||||
|
||||
def _create_version_backup(self, filename: str, metadata: dict):
|
||||
"""Create version backup"""
|
||||
try:
|
||||
# Read current file content
|
||||
current_data = self._storage.load_once(filename)
|
||||
|
||||
# Save as version file
|
||||
version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
|
||||
self._storage.save(version_filename, current_data)
|
||||
|
||||
logger.debug("Created version backup: %s", version_filename)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create version backup for %s: %s", filename, e)
|
||||
|
||||
def _load_metadata(self) -> dict[str, Any]:
|
||||
"""Load metadata file"""
|
||||
try:
|
||||
if self._storage.exists(self._metadata_file):
|
||||
metadata_content = self._storage.load_once(self._metadata_file)
|
||||
result = json.loads(metadata_content.decode("utf-8"))
|
||||
return dict(result) if result else {}
|
||||
else:
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load metadata: %s", e)
|
||||
return {}
|
||||
|
||||
def _save_metadata(self, metadata_dict: dict):
|
||||
"""Save metadata file"""
|
||||
try:
|
||||
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
|
||||
self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
|
||||
logger.debug("Metadata saved successfully")
|
||||
except Exception:
|
||||
logger.exception("Failed to save metadata")
|
||||
raise
|
||||
|
||||
def _calculate_checksum(self, data: bytes) -> str:
|
||||
"""Calculate file checksum"""
|
||||
import hashlib
|
||||
|
||||
return hashlib.md5(data).hexdigest()
|
||||
|
||||
def _check_permission(self, filename: str, operation: str) -> bool:
|
||||
"""Check file operation permission
|
||||
|
||||
Args:
|
||||
filename: File name
|
||||
operation: Operation type
|
||||
|
||||
Returns:
|
||||
True if permission granted, False otherwise
|
||||
"""
|
||||
# If no permission manager, allow by default
|
||||
if not self._permission_manager:
|
||||
return True
|
||||
|
||||
try:
|
||||
# Map operation type to permission
|
||||
operation_mapping = {
|
||||
"save": "save",
|
||||
"load": "load_once",
|
||||
"delete": "delete",
|
||||
"archive": "delete", # Archive requires delete permission
|
||||
"restore": "save", # Restore requires write permission
|
||||
"cleanup": "delete", # Cleanup requires delete permission
|
||||
"read": "load_once",
|
||||
"write": "save",
|
||||
}
|
||||
|
||||
mapped_operation = operation_mapping.get(operation, operation)
|
||||
|
||||
# Check permission
|
||||
result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
|
||||
return bool(result)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Permission check failed for %s operation %s", filename, operation)
|
||||
# Safe default: deny access when permission check fails
|
||||
return False
|
||||
@ -1,649 +0,0 @@
|
||||
"""ClickZetta Volume permission management mechanism
|
||||
|
||||
This module provides Volume permission checking, validation and management features.
|
||||
According to ClickZetta's permission model, different Volume types have different permission requirements.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from enum import StrEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VolumePermission(StrEnum):
|
||||
"""Volume permission type enumeration"""
|
||||
|
||||
READ = "SELECT" # Corresponds to ClickZetta's SELECT permission
|
||||
WRITE = "INSERT,UPDATE,DELETE" # Corresponds to ClickZetta's write permissions
|
||||
LIST = "SELECT" # Listing files requires SELECT permission
|
||||
DELETE = "INSERT,UPDATE,DELETE" # Deleting files requires write permissions
|
||||
USAGE = "USAGE" # Basic permission required for External Volume
|
||||
|
||||
|
||||
class VolumePermissionManager:
|
||||
"""Volume permission manager"""
|
||||
|
||||
def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: str | None = None):
|
||||
"""Initialize permission manager
|
||||
|
||||
Args:
|
||||
connection_or_config: ClickZetta connection object or configuration dictionary
|
||||
volume_type: Volume type (user|table|external)
|
||||
volume_name: Volume name (for external volume)
|
||||
"""
|
||||
# Support two initialization methods: connection object or configuration dictionary
|
||||
if isinstance(connection_or_config, dict):
|
||||
# Create connection from configuration dictionary
|
||||
import clickzetta
|
||||
|
||||
config = connection_or_config
|
||||
self._connection = clickzetta.connect(
|
||||
username=config.get("username"),
|
||||
password=config.get("password"),
|
||||
instance=config.get("instance"),
|
||||
service=config.get("service"),
|
||||
workspace=config.get("workspace"),
|
||||
vcluster=config.get("vcluster"),
|
||||
schema=config.get("schema") or config.get("database"),
|
||||
)
|
||||
self._volume_type = config.get("volume_type", volume_type)
|
||||
self._volume_name = config.get("volume_name", volume_name)
|
||||
else:
|
||||
# Use connection object directly
|
||||
self._connection = connection_or_config
|
||||
self._volume_type = volume_type
|
||||
self._volume_name = volume_name
|
||||
|
||||
if not self._connection:
|
||||
raise ValueError("Valid connection or config is required")
|
||||
if not self._volume_type:
|
||||
raise ValueError("volume_type is required")
|
||||
|
||||
self._permission_cache: dict[str, set[str]] = {}
|
||||
self._current_username = None # Will get current username from connection
|
||||
|
||||
def check_permission(self, operation: VolumePermission, dataset_id: str | None = None) -> bool:
|
||||
"""Check if user has permission to perform specific operation
|
||||
|
||||
Args:
|
||||
operation: Type of operation to perform
|
||||
dataset_id: Dataset ID (for table volume)
|
||||
|
||||
Returns:
|
||||
True if user has permission, False otherwise
|
||||
"""
|
||||
try:
|
||||
if self._volume_type == "user":
|
||||
return self._check_user_volume_permission(operation)
|
||||
elif self._volume_type == "table":
|
||||
return self._check_table_volume_permission(operation, dataset_id)
|
||||
elif self._volume_type == "external":
|
||||
return self._check_external_volume_permission(operation)
|
||||
else:
|
||||
logger.warning("Unknown volume type: %s", self._volume_type)
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception("Permission check failed")
|
||||
return False
|
||||
|
||||
def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
|
||||
"""Check User Volume permission
|
||||
|
||||
User Volume permission rules:
|
||||
- User has full permissions on their own User Volume
|
||||
- As long as user can connect to ClickZetta, they have basic User Volume permissions by default
|
||||
- Focus more on connection authentication rather than complex permission checking
|
||||
"""
|
||||
try:
|
||||
# Get current username
|
||||
current_user = self._get_current_username()
|
||||
|
||||
# Check basic connection status
|
||||
with self._connection.cursor() as cursor:
|
||||
# Simple connection test, if query can be executed user has basic permissions
|
||||
cursor.execute("SELECT 1")
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
logger.debug(
|
||||
"User Volume permission check for %s, operation %s: granted (basic connection verified)",
|
||||
current_user,
|
||||
operation.name,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"User Volume permission check failed: cannot verify basic connection for %s", current_user
|
||||
)
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception("User Volume permission check failed")
|
||||
# For User Volume, if permission check fails, it might be a configuration issue,
|
||||
# provide friendlier error message
|
||||
logger.info("User Volume permission check failed, but permission checking is disabled in this version")
|
||||
return False
|
||||
|
||||
def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: str | None) -> bool:
|
||||
"""Check Table Volume permission
|
||||
|
||||
Table Volume permission rules:
|
||||
- Table Volume permissions inherit from corresponding table permissions
|
||||
- SELECT permission -> can READ/LIST files
|
||||
- INSERT,UPDATE,DELETE permissions -> can WRITE/DELETE files
|
||||
"""
|
||||
if not dataset_id:
|
||||
logger.warning("dataset_id is required for table volume permission check")
|
||||
return False
|
||||
|
||||
table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id
|
||||
|
||||
try:
|
||||
# Check table permissions
|
||||
permissions = self._get_table_permissions(table_name)
|
||||
required_permissions = set(operation.value.split(","))
|
||||
|
||||
# Check if has all required permissions
|
||||
has_permission = required_permissions.issubset(permissions)
|
||||
|
||||
logger.debug(
|
||||
"Table Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
|
||||
table_name,
|
||||
operation.name,
|
||||
required_permissions,
|
||||
permissions,
|
||||
has_permission,
|
||||
)
|
||||
|
||||
return has_permission
|
||||
|
||||
except Exception:
|
||||
logger.exception("Table volume permission check failed for %s", table_name)
|
||||
return False
|
||||
|
||||
def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
|
||||
"""Check External Volume permission
|
||||
|
||||
External Volume permission rules:
|
||||
- Try to get permissions for External Volume
|
||||
- If permission check fails, perform fallback verification
|
||||
- For development environment, provide more lenient permission checking
|
||||
"""
|
||||
if not self._volume_name:
|
||||
logger.warning("volume_name is required for external volume permission check")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Check External Volume permissions
|
||||
permissions = self._get_external_volume_permissions(self._volume_name)
|
||||
|
||||
# External Volume permission mapping: determine required permissions based on operation type
|
||||
required_permissions = set()
|
||||
|
||||
if operation in [VolumePermission.READ, VolumePermission.LIST]:
|
||||
required_permissions.add("read")
|
||||
elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
|
||||
required_permissions.add("write")
|
||||
|
||||
# Check if has all required permissions
|
||||
has_permission = required_permissions.issubset(permissions)
|
||||
|
||||
logger.debug(
|
||||
"External Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
|
||||
self._volume_name,
|
||||
operation.name,
|
||||
required_permissions,
|
||||
permissions,
|
||||
has_permission,
|
||||
)
|
||||
|
||||
# If permission check fails, try fallback verification
|
||||
if not has_permission:
|
||||
logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name)
|
||||
|
||||
# Fallback verification: try listing Volume to verify basic access permissions
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute("SHOW VOLUMES")
|
||||
volumes = cursor.fetchall()
|
||||
for volume in volumes:
|
||||
if len(volume) > 0 and volume[0] == self._volume_name:
|
||||
logger.info("Fallback verification successful for %s", self._volume_name)
|
||||
return True
|
||||
except Exception as fallback_e:
|
||||
logger.warning("Fallback verification failed for %s: %s", self._volume_name, fallback_e)
|
||||
|
||||
return has_permission
|
||||
|
||||
except Exception:
|
||||
logger.exception("External volume permission check failed for %s", self._volume_name)
|
||||
logger.info("External Volume permission check failed, but permission checking is disabled in this version")
|
||||
return False
|
||||
|
||||
def _get_table_permissions(self, table_name: str) -> set[str]:
|
||||
"""Get user permissions for specified table
|
||||
|
||||
Args:
|
||||
table_name: Table name
|
||||
|
||||
Returns:
|
||||
Set of user permissions for this table
|
||||
"""
|
||||
cache_key = f"table:{table_name}"
|
||||
|
||||
if cache_key in self._permission_cache:
|
||||
return self._permission_cache[cache_key]
|
||||
|
||||
permissions = set()
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
# Use correct ClickZetta syntax to check current user permissions
|
||||
cursor.execute("SHOW GRANTS")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
# Parse permission results, find permissions for this table
|
||||
for grant in grants:
|
||||
if len(grant) >= 3: # Typical format: (privilege, object_type, object_name, ...)
|
||||
privilege = grant[0].upper()
|
||||
object_type = grant[1].upper() if len(grant) > 1 else ""
|
||||
object_name = grant[2] if len(grant) > 2 else ""
|
||||
|
||||
# Check if it's permission for this table
|
||||
if (
|
||||
object_type == "TABLE"
|
||||
and object_name == table_name
|
||||
or object_type == "SCHEMA"
|
||||
and object_name in table_name
|
||||
):
|
||||
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
|
||||
if privilege == "ALL":
|
||||
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
|
||||
else:
|
||||
permissions.add(privilege)
|
||||
|
||||
# If no explicit permissions found, try executing a simple query to verify permissions
|
||||
if not permissions:
|
||||
try:
|
||||
cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
|
||||
permissions.add("SELECT")
|
||||
except Exception:
|
||||
logger.debug("Cannot query table %s, no SELECT permission", table_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not check table permissions for %s: %s", table_name, e)
|
||||
# Safe default: deny access when permission check fails
|
||||
pass
|
||||
|
||||
# Cache permission information
|
||||
self._permission_cache[cache_key] = permissions
|
||||
return permissions
|
||||
|
||||
def _get_current_username(self) -> str:
|
||||
"""Get current username"""
|
||||
if self._current_username:
|
||||
return self._current_username
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute("SELECT CURRENT_USER()")
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
self._current_username = result[0]
|
||||
return str(self._current_username)
|
||||
except Exception:
|
||||
logger.exception("Failed to get current username")
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _get_user_permissions(self, username: str) -> set[str]:
|
||||
"""Get user's basic permission set"""
|
||||
cache_key = f"user_permissions:{username}"
|
||||
|
||||
if cache_key in self._permission_cache:
|
||||
return self._permission_cache[cache_key]
|
||||
|
||||
permissions = set()
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
# Use correct ClickZetta syntax to check current user permissions
|
||||
cursor.execute("SHOW GRANTS")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
# Parse permission results, find user's basic permissions
|
||||
for grant in grants:
|
||||
if len(grant) >= 3: # Typical format: (privilege, object_type, object_name, ...)
|
||||
privilege = grant[0].upper()
|
||||
_ = grant[1].upper() if len(grant) > 1 else ""
|
||||
|
||||
# Collect all relevant permissions
|
||||
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
|
||||
if privilege == "ALL":
|
||||
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
|
||||
else:
|
||||
permissions.add(privilege)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not check user permissions for %s: %s", username, e)
|
||||
# Safe default: deny access when permission check fails
|
||||
pass
|
||||
|
||||
# Cache permission information
|
||||
self._permission_cache[cache_key] = permissions
|
||||
return permissions
|
||||
|
||||
def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
|
||||
"""Get user permissions for specified External Volume
|
||||
|
||||
Args:
|
||||
volume_name: External Volume name
|
||||
|
||||
Returns:
|
||||
Set of user permissions for this Volume
|
||||
"""
|
||||
cache_key = f"external_volume:{volume_name}"
|
||||
|
||||
if cache_key in self._permission_cache:
|
||||
return self._permission_cache[cache_key]
|
||||
|
||||
permissions = set()
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
# Use correct ClickZetta syntax to check Volume permissions
|
||||
logger.info("Checking permissions for volume: %s", volume_name)
|
||||
cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
logger.info("Raw grants result for %s: %s", volume_name, grants)
|
||||
|
||||
# Parse permission results
|
||||
# Format: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
|
||||
# grantee_name, grantor_name, grant_option, granted_time)
|
||||
for grant in grants:
|
||||
logger.info("Processing grant: %s", grant)
|
||||
if len(grant) >= 5:
|
||||
granted_type = grant[0]
|
||||
privilege = grant[1].upper()
|
||||
granted_on = grant[3]
|
||||
object_name = grant[4]
|
||||
|
||||
logger.info(
|
||||
"Grant details - type: %s, privilege: %s, granted_on: %s, object_name: %s",
|
||||
granted_type,
|
||||
privilege,
|
||||
granted_on,
|
||||
object_name,
|
||||
)
|
||||
|
||||
# Check if it's permission for this Volume or hierarchical permission
|
||||
if (
|
||||
granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
|
||||
) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
|
||||
logger.info("Matching grant found for %s", volume_name)
|
||||
|
||||
if "READ" in privilege:
|
||||
permissions.add("read")
|
||||
logger.info("Added READ permission for %s", volume_name)
|
||||
if "WRITE" in privilege:
|
||||
permissions.add("write")
|
||||
logger.info("Added WRITE permission for %s", volume_name)
|
||||
if "ALTER" in privilege:
|
||||
permissions.add("alter")
|
||||
logger.info("Added ALTER permission for %s", volume_name)
|
||||
if privilege == "ALL":
|
||||
permissions.update(["read", "write", "alter"])
|
||||
logger.info("Added ALL permissions for %s", volume_name)
|
||||
|
||||
logger.info("Final permissions for %s: %s", volume_name, permissions)
|
||||
|
||||
# If no explicit permissions found, try viewing Volume list to verify basic permissions
|
||||
if not permissions:
|
||||
try:
|
||||
cursor.execute("SHOW VOLUMES")
|
||||
volumes = cursor.fetchall()
|
||||
for volume in volumes:
|
||||
if len(volume) > 0 and volume[0] == volume_name:
|
||||
permissions.add("read") # At least has read permission
|
||||
logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
|
||||
break
|
||||
except Exception:
|
||||
logger.debug("Cannot access volume %s, no basic permission", volume_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
|
||||
# When permission check fails, try basic Volume access verification
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute("SHOW VOLUMES")
|
||||
volumes = cursor.fetchall()
|
||||
for volume in volumes:
|
||||
if len(volume) > 0 and volume[0] == volume_name:
|
||||
logger.info("Basic volume access verified for %s", volume_name)
|
||||
permissions.add("read")
|
||||
permissions.add("write") # Assume has write permission
|
||||
break
|
||||
except Exception as basic_e:
|
||||
logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
|
||||
# Last fallback: assume basic permissions
|
||||
permissions.add("read")
|
||||
|
||||
# Cache permission information
|
||||
self._permission_cache[cache_key] = permissions
|
||||
return permissions
|
||||
|
||||
def clear_permission_cache(self):
|
||||
"""Clear permission cache"""
|
||||
self._permission_cache.clear()
|
||||
logger.debug("Permission cache cleared")
|
||||
|
||||
@property
|
||||
def volume_type(self) -> str | None:
|
||||
"""Get the volume type."""
|
||||
return self._volume_type
|
||||
|
||||
def get_permission_summary(self, dataset_id: str | None = None) -> dict[str, bool]:
|
||||
"""Get permission summary
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID (for table volume)
|
||||
|
||||
Returns:
|
||||
Permission summary dictionary
|
||||
"""
|
||||
summary = {}
|
||||
|
||||
for operation in VolumePermission:
|
||||
summary[operation.name.lower()] = self.check_permission(operation, dataset_id)
|
||||
|
||||
return summary
|
||||
|
||||
def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
|
||||
"""Check permission inheritance for file path
|
||||
|
||||
Args:
|
||||
file_path: File path
|
||||
operation: Operation to perform
|
||||
|
||||
Returns:
|
||||
True if user has permission, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Parse file path
|
||||
path_parts = file_path.strip("/").split("/")
|
||||
|
||||
if not path_parts:
|
||||
logger.warning("Invalid file path for permission inheritance check")
|
||||
return False
|
||||
|
||||
# For Table Volume, first layer is dataset_id
|
||||
if self._volume_type == "table":
|
||||
if len(path_parts) < 1:
|
||||
return False
|
||||
|
||||
dataset_id = path_parts[0]
|
||||
|
||||
# Check permissions for dataset
|
||||
has_dataset_permission = self.check_permission(operation, dataset_id)
|
||||
|
||||
if not has_dataset_permission:
|
||||
logger.debug("Permission denied for dataset %s", dataset_id)
|
||||
return False
|
||||
|
||||
# Check path traversal attack
|
||||
if self._contains_path_traversal(file_path):
|
||||
logger.warning("Path traversal attack detected: %s", file_path)
|
||||
return False
|
||||
|
||||
# Check if accessing sensitive directory
|
||||
if self._is_sensitive_path(file_path):
|
||||
logger.warning("Access to sensitive path denied: %s", file_path)
|
||||
return False
|
||||
|
||||
logger.debug("Permission inherited for path %s", file_path)
|
||||
return True
|
||||
|
||||
elif self._volume_type == "user":
|
||||
# User Volume permission inheritance
|
||||
current_user = self._get_current_username()
|
||||
|
||||
# Check if attempting to access other user's directory
|
||||
if len(path_parts) > 1 and path_parts[0] != current_user:
|
||||
logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
|
||||
return False
|
||||
|
||||
# Check basic permissions
|
||||
return self.check_permission(operation)
|
||||
|
||||
elif self._volume_type == "external":
|
||||
# External Volume permission inheritance
|
||||
# Check permissions for External Volume
|
||||
return self.check_permission(operation)
|
||||
|
||||
else:
|
||||
logger.warning("Unknown volume type for permission inheritance: %s", self._volume_type)
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception("Permission inheritance check failed")
|
||||
return False
|
||||
|
||||
def _contains_path_traversal(self, file_path: str) -> bool:
|
||||
"""Check if path contains path traversal attack"""
|
||||
# Check common path traversal patterns
|
||||
traversal_patterns = [
|
||||
"../",
|
||||
"..\\",
|
||||
"..%2f",
|
||||
"..%2F",
|
||||
"..%5c",
|
||||
"..%5C",
|
||||
"%2e%2e%2f",
|
||||
"%2e%2e%5c",
|
||||
"....//",
|
||||
"....\\\\",
|
||||
]
|
||||
|
||||
file_path_lower = file_path.lower()
|
||||
|
||||
for pattern in traversal_patterns:
|
||||
if pattern in file_path_lower:
|
||||
return True
|
||||
|
||||
# Check absolute path
|
||||
if file_path.startswith("/") or file_path.startswith("\\"):
|
||||
return True
|
||||
|
||||
# Check Windows drive path
|
||||
if len(file_path) >= 2 and file_path[1] == ":":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_sensitive_path(self, file_path: str) -> bool:
|
||||
"""Check if path is sensitive path"""
|
||||
sensitive_patterns = [
|
||||
"passwd",
|
||||
"shadow",
|
||||
"hosts",
|
||||
"config",
|
||||
"secrets",
|
||||
"private",
|
||||
"key",
|
||||
"certificate",
|
||||
"cert",
|
||||
"ssl",
|
||||
"database",
|
||||
"backup",
|
||||
"dump",
|
||||
"log",
|
||||
"tmp",
|
||||
]
|
||||
|
||||
file_path_lower = file_path.lower()
|
||||
|
||||
return any(pattern in file_path_lower for pattern in sensitive_patterns)
|
||||
|
||||
def validate_operation(self, operation: str, dataset_id: str | None = None) -> bool:
|
||||
"""Validate operation permission
|
||||
|
||||
Args:
|
||||
operation: Operation name (save|load|exists|delete|scan)
|
||||
dataset_id: Dataset ID
|
||||
|
||||
Returns:
|
||||
True if operation is allowed, False otherwise
|
||||
"""
|
||||
operation_mapping = {
|
||||
"save": VolumePermission.WRITE,
|
||||
"load": VolumePermission.READ,
|
||||
"load_once": VolumePermission.READ,
|
||||
"load_stream": VolumePermission.READ,
|
||||
"download": VolumePermission.READ,
|
||||
"exists": VolumePermission.READ,
|
||||
"delete": VolumePermission.DELETE,
|
||||
"scan": VolumePermission.LIST,
|
||||
}
|
||||
|
||||
if operation not in operation_mapping:
|
||||
logger.warning("Unknown operation: %s", operation)
|
||||
return False
|
||||
|
||||
volume_permission = operation_mapping[operation]
|
||||
return self.check_permission(volume_permission, dataset_id)
|
||||
|
||||
|
||||
class VolumePermissionError(Exception):
|
||||
"""Volume permission error exception"""
|
||||
|
||||
def __init__(self, message: str, operation: str, volume_type: str, dataset_id: str | None = None):
|
||||
self.operation = operation
|
||||
self.volume_type = volume_type
|
||||
self.dataset_id = dataset_id
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def check_volume_permission(permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None):
|
||||
"""Permission check decorator function
|
||||
|
||||
Args:
|
||||
permission_manager: Permission manager
|
||||
operation: Operation name
|
||||
dataset_id: Dataset ID
|
||||
|
||||
Raises:
|
||||
VolumePermissionError: If no permission
|
||||
"""
|
||||
if not permission_manager.validate_operation(operation, dataset_id):
|
||||
error_message = f"Permission denied for operation '{operation}' on {permission_manager.volume_type} volume"
|
||||
if dataset_id:
|
||||
error_message += f" (dataset: {dataset_id})"
|
||||
|
||||
raise VolumePermissionError(
|
||||
error_message,
|
||||
operation=operation,
|
||||
volume_type=permission_manager.volume_type or "unknown",
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
@ -5,7 +5,6 @@ class StorageType(StrEnum):
|
||||
ALIYUN_OSS = "aliyun-oss"
|
||||
AZURE_BLOB = "azure-blob"
|
||||
BAIDU_OBS = "baidu-obs"
|
||||
CLICKZETTA_VOLUME = "clickzetta-volume"
|
||||
GOOGLE_STORAGE = "google-storage"
|
||||
HUAWEI_OBS = "huawei-obs"
|
||||
LOCAL = "local"
|
||||
|
||||
@ -55,7 +55,7 @@ class TypeMismatchError(Exception):
|
||||
|
||||
|
||||
# Define the constant
|
||||
SEGMENT_TO_VARIABLE_MAP = {
|
||||
SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = {
|
||||
ArrayAnySegment: ArrayAnyVariable,
|
||||
ArrayBooleanSegment: ArrayBooleanVariable,
|
||||
ArrayFileSegment: ArrayFileVariable,
|
||||
@ -296,13 +296,11 @@ def segment_to_variable(
|
||||
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
|
||||
|
||||
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
|
||||
return cast(
|
||||
VariableBase,
|
||||
variable_class(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
value=segment.value,
|
||||
selector=list(selector),
|
||||
),
|
||||
return variable_class(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
value_type=segment.value_type,
|
||||
value=segment.value,
|
||||
selector=list(selector),
|
||||
)
|
||||
|
||||
@ -32,6 +32,11 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stream_with_request_context(response: object) -> Any:
|
||||
"""Bridge Flask's loosely-typed streaming helper without leaking casts into callers."""
|
||||
return cast(Any, stream_with_context)(response)
|
||||
|
||||
|
||||
def escape_like_pattern(pattern: str) -> str:
|
||||
"""
|
||||
Escape special characters in a string for safe use in SQL LIKE patterns.
|
||||
@ -286,22 +291,32 @@ def generate_text_hash(text: str) -> str:
|
||||
return sha256(hash_text.encode()).hexdigest()
|
||||
|
||||
|
||||
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
def compact_generate_response(
|
||||
response: Mapping[str, Any] | Generator[str, None, None] | RateLimitGenerator,
|
||||
) -> Response:
|
||||
if isinstance(response, Mapping):
|
||||
return Response(
|
||||
response=json.dumps(jsonable_encoder(response)),
|
||||
status=200,
|
||||
content_type="application/json; charset=utf-8",
|
||||
)
|
||||
else:
|
||||
stream_response = response
|
||||
|
||||
def generate() -> Generator:
|
||||
yield from response
|
||||
def generate() -> Generator[str, None, None]:
|
||||
yield from stream_response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
|
||||
return Response(
|
||||
_stream_with_request_context(generate()),
|
||||
status=200,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
||||
def length_prefixed_response(
|
||||
magic_number: int,
|
||||
response: Mapping[str, Any] | BaseModel | Generator[str | bytes, None, None] | RateLimitGenerator,
|
||||
) -> Response:
|
||||
"""
|
||||
This function is used to return a response with a length prefix.
|
||||
Magic number is a one byte number that indicates the type of the response.
|
||||
@ -332,7 +347,7 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
|
||||
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
|
||||
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
|
||||
|
||||
if isinstance(response, dict):
|
||||
if isinstance(response, Mapping):
|
||||
return Response(
|
||||
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
|
||||
status=200,
|
||||
@ -345,14 +360,20 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
stream_response = response
|
||||
|
||||
def generate() -> Generator[bytes, None, None]:
|
||||
for chunk in stream_response:
|
||||
if isinstance(chunk, str):
|
||||
yield pack_response_with_length_prefix(chunk.encode("utf-8"))
|
||||
else:
|
||||
yield pack_response_with_length_prefix(chunk)
|
||||
|
||||
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
|
||||
return Response(
|
||||
_stream_with_request_context(generate()),
|
||||
status=200,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
class TokenManager:
|
||||
|
||||
@ -77,12 +77,14 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
|
||||
@wraps(func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
|
||||
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
||||
pass
|
||||
elif current_user is not None and not current_user.is_authenticated:
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
|
||||
user = _get_user()
|
||||
if user is None or not user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized() # type: ignore
|
||||
# we put csrf validation here for less conflicts
|
||||
# TODO: maybe find a better place for it.
|
||||
check_csrf_token(request, current_user.id)
|
||||
check_csrf_token(request, user.id)
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
@ -7,9 +7,10 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py
|
||||
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
|
||||
|
||||
def cached_import(module_path: str, class_name: str):
|
||||
def cached_import(module_path: str, class_name: str) -> Any:
|
||||
"""
|
||||
Import a module and return the named attribute/class from it, with caching.
|
||||
|
||||
@ -20,16 +21,14 @@ def cached_import(module_path: str, class_name: str):
|
||||
Returns:
|
||||
The imported attribute/class
|
||||
"""
|
||||
if not (
|
||||
(module := sys.modules.get(module_path))
|
||||
and (spec := getattr(module, "__spec__", None))
|
||||
and getattr(spec, "_initializing", False) is False
|
||||
):
|
||||
module = sys.modules.get(module_path)
|
||||
spec = getattr(module, "__spec__", None) if module is not None else None
|
||||
if module is None or getattr(spec, "_initializing", False):
|
||||
module = import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
def import_string(dotted_path: str):
|
||||
def import_string(dotted_path: str) -> Any:
|
||||
"""
|
||||
Import a dotted module path and return the attribute/class designated by
|
||||
the last name in the path. Raise ImportError if the import failed.
|
||||
|
||||
@ -1,7 +1,48 @@
|
||||
import sys
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import NotRequired
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
JsonObject = dict[str, object]
|
||||
JsonObjectList = list[JsonObject]
|
||||
|
||||
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
|
||||
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
|
||||
|
||||
|
||||
class AccessTokenResponse(TypedDict, total=False):
|
||||
access_token: str
|
||||
|
||||
|
||||
class GitHubEmailRecord(TypedDict, total=False):
|
||||
email: str
|
||||
primary: bool
|
||||
|
||||
|
||||
class GitHubRawUserInfo(TypedDict):
|
||||
id: int | str
|
||||
login: str
|
||||
name: NotRequired[str]
|
||||
email: NotRequired[str]
|
||||
|
||||
|
||||
class GoogleRawUserInfo(TypedDict):
|
||||
sub: str
|
||||
email: str
|
||||
|
||||
|
||||
ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
|
||||
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
|
||||
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
|
||||
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -11,26 +52,38 @@ class OAuthUserInfo:
|
||||
email: str
|
||||
|
||||
|
||||
def _json_object(response: httpx.Response) -> JsonObject:
|
||||
return JSON_OBJECT_ADAPTER.validate_python(response.json())
|
||||
|
||||
|
||||
def _json_list(response: httpx.Response) -> JsonObjectList:
|
||||
return JSON_OBJECT_LIST_ADAPTER.validate_python(response.json())
|
||||
|
||||
|
||||
class OAuth:
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self):
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_user_info(self, token: str) -> OAuthUserInfo:
|
||||
raw_info = self.get_raw_user_info(token)
|
||||
return self._transform_user_info(raw_info)
|
||||
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@ -40,7 +93,7 @@ class GitHubOAuth(OAuth):
|
||||
_USER_INFO_URL = "https://api.github.com/user"
|
||||
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None):
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
@ -50,7 +103,7 @@ class GitHubOAuth(OAuth):
|
||||
params["state"] = invite_token
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> str:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
@ -60,7 +113,7 @@ class GitHubOAuth(OAuth):
|
||||
headers = {"Accept": "application/json"}
|
||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
|
||||
access_token = response_json.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
@ -68,23 +121,24 @@ class GitHubOAuth(OAuth):
|
||||
|
||||
return access_token
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
headers = {"Authorization": f"token {token}"}
|
||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
user_info = response.json()
|
||||
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
|
||||
|
||||
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_info = email_response.json()
|
||||
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
|
||||
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
|
||||
primary_email = next((email for email in email_info if email.get("primary") is True), None)
|
||||
|
||||
return {**user_info, "email": primary_email.get("email", "")}
|
||||
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
|
||||
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
email = raw_info.get("email")
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
email = payload.get("email")
|
||||
if not email:
|
||||
email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
|
||||
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
|
||||
|
||||
|
||||
class GoogleOAuth(OAuth):
|
||||
@ -92,7 +146,7 @@ class GoogleOAuth(OAuth):
|
||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None):
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"response_type": "code",
|
||||
@ -103,7 +157,7 @@ class GoogleOAuth(OAuth):
|
||||
params["state"] = invite_token
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> str:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
@ -114,7 +168,7 @@ class GoogleOAuth(OAuth):
|
||||
headers = {"Accept": "application/json"}
|
||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
|
||||
access_token = response_json.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
@ -122,11 +176,12 @@ class GoogleOAuth(OAuth):
|
||||
|
||||
return access_token
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
return _json_object(response)
|
||||
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"])
|
||||
|
||||
@ -1,25 +1,57 @@
|
||||
import sys
|
||||
import urllib.parse
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from flask_login import current_user
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.source import DataSourceOauthBinding
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class NotionPageSummary(TypedDict):
|
||||
page_id: str
|
||||
page_name: str
|
||||
page_icon: dict[str, str] | None
|
||||
parent_id: str
|
||||
type: Literal["page", "database"]
|
||||
|
||||
|
||||
class NotionSourceInfo(TypedDict):
|
||||
workspace_name: str | None
|
||||
workspace_icon: str | None
|
||||
workspace_id: str | None
|
||||
pages: list[NotionPageSummary]
|
||||
total: int
|
||||
|
||||
|
||||
SOURCE_INFO_STORAGE_ADAPTER = TypeAdapter(dict[str, object])
|
||||
NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo)
|
||||
NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary)
|
||||
|
||||
|
||||
class OAuthDataSource:
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self):
|
||||
def get_authorization_url(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@ -30,7 +62,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
|
||||
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
|
||||
|
||||
def get_authorization_url(self):
|
||||
def get_authorization_url(self) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"response_type": "code",
|
||||
@ -39,7 +71,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> None:
|
||||
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
|
||||
headers = {"Accept": "application/json"}
|
||||
auth = (self.client_id, self.client_secret)
|
||||
@ -54,13 +86,12 @@ class NotionOAuth(OAuthDataSource):
|
||||
workspace_id = response_json.get("workspace_id")
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(access_token)
|
||||
source_info = {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
"workspace_id": workspace_id,
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
source_info = self._build_source_info(
|
||||
workspace_name=workspace_name,
|
||||
workspace_icon=workspace_icon,
|
||||
workspace_id=workspace_id,
|
||||
pages=pages,
|
||||
)
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@ -70,7 +101,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
)
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = source_info
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
@ -78,25 +109,24 @@ class NotionOAuth(OAuthDataSource):
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=source_info,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
|
||||
def save_internal_access_token(self, access_token: str):
|
||||
def save_internal_access_token(self, access_token: str) -> None:
|
||||
workspace_name = self.notion_workspace_name(access_token)
|
||||
workspace_icon = None
|
||||
workspace_id = current_user.current_tenant_id
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(access_token)
|
||||
source_info = {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
"workspace_id": workspace_id,
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
source_info = self._build_source_info(
|
||||
workspace_name=workspace_name,
|
||||
workspace_icon=workspace_icon,
|
||||
workspace_id=workspace_id,
|
||||
pages=pages,
|
||||
)
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@ -106,7 +136,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
)
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = source_info
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
@ -114,13 +144,13 @@ class NotionOAuth(OAuthDataSource):
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=source_info,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
|
||||
def sync_data_source(self, binding_id: str):
|
||||
def sync_data_source(self, binding_id: str) -> None:
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@ -134,23 +164,22 @@ class NotionOAuth(OAuthDataSource):
|
||||
if data_source_binding:
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(data_source_binding.access_token)
|
||||
source_info = data_source_binding.source_info
|
||||
new_source_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
data_source_binding.source_info = new_source_info
|
||||
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
|
||||
new_source_info = self._build_source_info(
|
||||
workspace_name=source_info["workspace_name"],
|
||||
workspace_icon=source_info["workspace_icon"],
|
||||
workspace_id=source_info["workspace_id"],
|
||||
pages=pages,
|
||||
)
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source binding not found")
|
||||
|
||||
def get_authorized_pages(self, access_token: str):
|
||||
pages = []
|
||||
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
|
||||
pages: list[NotionPageSummary] = []
|
||||
page_results = self.notion_page_search(access_token)
|
||||
database_results = self.notion_database_search(access_token)
|
||||
# get page detail
|
||||
@ -187,7 +216,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
"parent_id": parent_id,
|
||||
"type": "page",
|
||||
}
|
||||
pages.append(page)
|
||||
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
|
||||
# get database detail
|
||||
for database_result in database_results:
|
||||
page_id = database_result["id"]
|
||||
@ -220,11 +249,11 @@ class NotionOAuth(OAuthDataSource):
|
||||
"parent_id": parent_id,
|
||||
"type": "database",
|
||||
}
|
||||
pages.append(page)
|
||||
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
|
||||
return pages
|
||||
|
||||
def notion_page_search(self, access_token: str):
|
||||
results = []
|
||||
def notion_page_search(self, access_token: str) -> list[dict[str, Any]]:
|
||||
results: list[dict[str, Any]] = []
|
||||
next_cursor = None
|
||||
has_more = True
|
||||
|
||||
@ -249,7 +278,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
|
||||
return results
|
||||
|
||||
def notion_block_parent_page_id(self, access_token: str, block_id: str):
|
||||
def notion_block_parent_page_id(self, access_token: str, block_id: str) -> str:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
@ -265,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
return self.notion_block_parent_page_id(access_token, parent[parent_type])
|
||||
return parent[parent_type]
|
||||
|
||||
def notion_workspace_name(self, access_token: str):
|
||||
def notion_workspace_name(self, access_token: str) -> str:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
@ -279,8 +308,8 @@ class NotionOAuth(OAuthDataSource):
|
||||
return user_info["workspace_name"]
|
||||
return "workspace"
|
||||
|
||||
def notion_database_search(self, access_token: str):
|
||||
results = []
|
||||
def notion_database_search(self, access_token: str) -> list[dict[str, Any]]:
|
||||
results: list[dict[str, Any]] = []
|
||||
next_cursor = None
|
||||
has_more = True
|
||||
|
||||
@ -303,3 +332,19 @@ class NotionOAuth(OAuthDataSource):
|
||||
next_cursor = response_json.get("next_cursor", None)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _build_source_info(
|
||||
*,
|
||||
workspace_name: str | None,
|
||||
workspace_icon: str | None,
|
||||
workspace_id: str | None,
|
||||
pages: list[NotionPageSummary],
|
||||
) -> NotionSourceInfo:
|
||||
return {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
"workspace_id": workspace_id,
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
|
||||
@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if value == "end-user":
|
||||
return cls.END_USER
|
||||
else:
|
||||
return super()._missing_(value)
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(StrEnum):
|
||||
DEBUGGING = "debugging"
|
||||
|
||||
@ -23,6 +23,9 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr
|
||||
from .model import Account
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
TriggerJsonObject = dict[str, object]
|
||||
TriggerCredentials = dict[str, str]
|
||||
|
||||
|
||||
class WorkflowTriggerLogDict(TypedDict):
|
||||
id: str
|
||||
@ -89,10 +92,14 @@ class TriggerSubscription(TypeBase):
|
||||
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
|
||||
)
|
||||
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
|
||||
parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
|
||||
properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
|
||||
parameters: Mapped[TriggerJsonObject] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription parameters JSON"
|
||||
)
|
||||
properties: Mapped[TriggerJsonObject] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription properties JSON"
|
||||
)
|
||||
|
||||
credentials: Mapped[dict[str, Any]] = mapped_column(
|
||||
credentials: Mapped[TriggerCredentials] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription credentials JSON"
|
||||
)
|
||||
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
|
||||
@ -200,8 +207,8 @@ class TriggerOAuthTenantClient(TypeBase):
|
||||
)
|
||||
|
||||
@property
|
||||
def oauth_params(self) -> Mapping[str, Any]:
|
||||
return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
|
||||
def oauth_params(self) -> Mapping[str, object]:
|
||||
return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}"))
|
||||
|
||||
|
||||
class WorkflowTriggerLog(TypeBase):
|
||||
|
||||
@ -19,21 +19,21 @@ from sqlalchemy import (
|
||||
orm,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from dify_graph.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
)
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.file.constants import maybe_file_object
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.variables import utils as variable_utils
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
|
||||
from extensions.ext_storage import Storage
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -59,6 +59,9 @@ from .types import EnumText, LongText, StringUUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SerializedWorkflowValue = dict[str, Any]
|
||||
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
|
||||
|
||||
|
||||
class WorkflowContentDict(TypedDict):
|
||||
graph: Mapping[str, Any]
|
||||
@ -405,7 +408,7 @@ class Workflow(Base): # bug
|
||||
|
||||
def rag_pipeline_user_input_form(self) -> list:
|
||||
# get user_input_form from start node
|
||||
variables: list[Any] = self.rag_pipeline_variables
|
||||
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
|
||||
|
||||
return variables
|
||||
|
||||
@ -448,17 +451,13 @@ class Workflow(Base): # bug
|
||||
def environment_variables(
|
||||
self,
|
||||
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||
# TODO: find some way to init `self._environment_variables` when instance created.
|
||||
if self._environment_variables is None:
|
||||
self._environment_variables = "{}"
|
||||
|
||||
# Use workflow.tenant_id to avoid relying on request user in background threads
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
if not tenant_id:
|
||||
return []
|
||||
|
||||
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
|
||||
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
|
||||
results = [
|
||||
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
|
||||
]
|
||||
@ -536,11 +535,7 @@ class Workflow(Base): # bug
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[VariableBase]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
|
||||
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@ -552,19 +547,20 @@ class Workflow(Base): # bug
|
||||
)
|
||||
|
||||
@property
|
||||
def rag_pipeline_variables(self) -> list[dict]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._rag_pipeline_variables is None:
|
||||
self._rag_pipeline_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
|
||||
results = list(variables_dict.values())
|
||||
return results
|
||||
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
|
||||
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
|
||||
|
||||
@rag_pipeline_variables.setter
|
||||
def rag_pipeline_variables(self, values: list[dict]) -> None:
|
||||
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
|
||||
self._rag_pipeline_variables = json.dumps(
|
||||
{item["variable"]: item for item in values},
|
||||
{
|
||||
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
|
||||
for rag_pipeline_variable in (
|
||||
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
|
||||
for item in values
|
||||
)
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@ -802,44 +798,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
|
||||
__tablename__ = "workflow_node_executions"
|
||||
|
||||
@declared_attr.directive
|
||||
@classmethod
|
||||
def __table_args__(cls) -> Any:
|
||||
return (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
# The first argument is the index name,
|
||||
# which we leave as `None`` to allow auto-generation by the ORM.
|
||||
None,
|
||||
cls.tenant_id,
|
||||
cls.workflow_id,
|
||||
cls.node_id,
|
||||
# MyPy may flag the following line because it doesn't recognize that
|
||||
# the `declared_attr` decorator passes the receiving class as the first
|
||||
# argument to this method, allowing us to reference class attributes.
|
||||
cls.created_at.desc(),
|
||||
),
|
||||
)
|
||||
__table_args__ = (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
None,
|
||||
"tenant_id",
|
||||
"workflow_id",
|
||||
"node_id",
|
||||
sa.desc("created_at"),
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
@ -948,8 +936,11 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
elif self.node_type == BuiltinNodeTypes.DATASOURCE and "datasource_info" in execution_metadata:
|
||||
datasource_info = execution_metadata["datasource_info"]
|
||||
extras["icon"] = datasource_info.get("icon")
|
||||
elif self.node_type == TRIGGER_PLUGIN_NODE_TYPE and TRIGGER_INFO_METADATA_KEY in execution_metadata:
|
||||
trigger_info = execution_metadata[TRIGGER_INFO_METADATA_KEY] or {}
|
||||
elif (
|
||||
self.node_type == TRIGGER_PLUGIN_NODE_TYPE
|
||||
and WorkflowNodeExecutionMetadataKey.TRIGGER_INFO in execution_metadata
|
||||
):
|
||||
trigger_info = execution_metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] or {}
|
||||
provider_id = trigger_info.get("provider_id")
|
||||
if provider_id:
|
||||
extras["icon"] = TriggerManager.get_trigger_plugin_icon(
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.13.1"
|
||||
version = "1.13.2"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
@ -89,7 +89,6 @@ dependencies = [
|
||||
"croniter>=6.0.0",
|
||||
"weaviate-client==4.20.4",
|
||||
"apscheduler>=3.11.0",
|
||||
"weave>=0.52.16",
|
||||
"fastopenapi[flask]>=0.7.0",
|
||||
"bleach~=6.2.0",
|
||||
]
|
||||
@ -101,6 +100,7 @@ packages = []
|
||||
|
||||
[tool.uv]
|
||||
default-groups = ["storage", "tools", "vdb"]
|
||||
constraint-dependencies = ["cryptography>=46.0.5"]
|
||||
package = false
|
||||
|
||||
[dependency-groups]
|
||||
@ -202,11 +202,8 @@ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
|
||||
# Required by vector store clients
|
||||
############################################################
|
||||
vdb = [
|
||||
"alibabacloud_gpdb20160503~=3.8.0",
|
||||
"alibabacloud_tea_openapi~=0.4.3",
|
||||
"chromadb==0.5.20",
|
||||
"clickhouse-connect~=0.14.1",
|
||||
"clickzetta-connector-python>=0.8.102",
|
||||
"couchbase~=4.5.0",
|
||||
"elasticsearch==8.14.0",
|
||||
"opensearch-py==3.1.0",
|
||||
@ -218,7 +215,6 @@ vdb = [
|
||||
"pyobvector~=0.2.17",
|
||||
"qdrant-client==1.9.0",
|
||||
"intersystems-irispython>=5.1.0",
|
||||
"tablestore==6.4.1",
|
||||
"tcvectordb~=2.0.0",
|
||||
"tidb-vector==0.0.15",
|
||||
"upstash-vector==0.8.0",
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
configs/middleware/cache/redis_pubsub_config.py
|
||||
controllers/console/app/annotation.py
|
||||
controllers/console/app/app.py
|
||||
controllers/console/app/app_import.py
|
||||
@ -46,11 +45,8 @@ core/plugin/backwards_invocation/model.py
|
||||
core/prompt/utils/extract_thread_messages.py
|
||||
core/rag/datasource/keyword/jieba/jieba.py
|
||||
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
|
||||
core/rag/datasource/vdb/baidu/baidu_vector.py
|
||||
core/rag/datasource/vdb/chroma/chroma_vector.py
|
||||
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
|
||||
core/rag/datasource/vdb/couchbase/couchbase_vector.py
|
||||
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
|
||||
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
|
||||
@ -63,7 +59,6 @@ core/rag/datasource/vdb/opensearch/opensearch_vector.py
|
||||
core/rag/datasource/vdb/oracle/oraclevector.py
|
||||
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
|
||||
core/rag/datasource/vdb/relyt/relyt_vector.py
|
||||
core/rag/datasource/vdb/tablestore/tablestore_vector.py
|
||||
core/rag/datasource/vdb/tencent/tencent_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
|
||||
@ -138,8 +133,6 @@ dify_graph/nodes/trigger_webhook/node.py
|
||||
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
|
||||
dify_graph/nodes/variable_assigner/v1/node.py
|
||||
dify_graph/nodes/variable_assigner/v2/node.py
|
||||
dify_graph/variables/types.py
|
||||
extensions/ext_fastopenapi.py
|
||||
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
||||
extensions/otel/instrumentation.py
|
||||
extensions/otel/runtime.py
|
||||
@ -147,8 +140,6 @@ extensions/storage/aliyun_oss_storage.py
|
||||
extensions/storage/aws_s3_storage.py
|
||||
extensions/storage/azure_blob_storage.py
|
||||
extensions/storage/baidu_obs_storage.py
|
||||
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
|
||||
extensions/storage/clickzetta_volume/file_lifecycle.py
|
||||
extensions/storage/google_cloud_storage.py
|
||||
extensions/storage/huawei_obs_storage.py
|
||||
extensions/storage/opendal_storage.py
|
||||
@ -156,19 +147,7 @@ extensions/storage/oracle_oci_storage.py
|
||||
extensions/storage/supabase_storage.py
|
||||
extensions/storage/tencent_cos_storage.py
|
||||
extensions/storage/volcengine_tos_storage.py
|
||||
factories/variable_factory.py
|
||||
libs/external_api.py
|
||||
libs/gmpy2_pkcs10aep_cipher.py
|
||||
libs/helper.py
|
||||
libs/login.py
|
||||
libs/module_loading.py
|
||||
libs/oauth.py
|
||||
libs/oauth_data_source.py
|
||||
models/trigger.py
|
||||
models/workflow.py
|
||||
repositories/sqlalchemy_api_workflow_node_execution_repository.py
|
||||
repositories/sqlalchemy_api_workflow_run_repository.py
|
||||
repositories/sqlalchemy_execution_extra_content_repository.py
|
||||
schedule/queue_monitor_task.py
|
||||
services/account_service.py
|
||||
services/audio_service.py
|
||||
|
||||
@ -28,7 +28,6 @@
|
||||
"baidubce.auth.bce_credentials",
|
||||
"baidubce.bce_client_configuration",
|
||||
"baidubce.services.bos.bos_client",
|
||||
"clickzetta",
|
||||
"google.cloud",
|
||||
"obs",
|
||||
"qcloud_cos",
|
||||
@ -52,4 +51,4 @@
|
||||
"reportAttributeAccessIssue": "hint",
|
||||
"pythonVersion": "3.11",
|
||||
"pythonPlatform": "All"
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import Protocol, cast
|
||||
|
||||
from sqlalchemy import asc, delete, desc, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
@ -22,6 +22,20 @@ from repositories.api_workflow_node_execution_repository import (
|
||||
)
|
||||
|
||||
|
||||
class _WorkflowNodeExecutionSnapshotRow(Protocol):
|
||||
id: str
|
||||
node_execution_id: str | None
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
index: int
|
||||
status: WorkflowNodeExecutionStatus
|
||||
elapsed_time: float | None
|
||||
created_at: datetime
|
||||
finished_at: datetime | None
|
||||
execution_metadata: str | None
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
|
||||
"""
|
||||
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
|
||||
@ -40,6 +54,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
- Thread-safe database operations using session-per-request pattern
|
||||
"""
|
||||
|
||||
_session_maker: sessionmaker[Session]
|
||||
|
||||
def __init__(self, session_maker: sessionmaker[Session]):
|
||||
"""
|
||||
Initialize the repository with a sessionmaker.
|
||||
@ -156,12 +172,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
rows = cast(Sequence[_WorkflowNodeExecutionSnapshotRow], session.execute(stmt).all())
|
||||
|
||||
return [self._row_to_snapshot(row) for row in rows]
|
||||
|
||||
@staticmethod
|
||||
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
|
||||
def _row_to_snapshot(row: _WorkflowNodeExecutionSnapshotRow) -> WorkflowNodeExecutionSnapshot:
|
||||
metadata: dict[str, object] = {}
|
||||
execution_metadata = getattr(row, "execution_metadata", None)
|
||||
if execution_metadata:
|
||||
|
||||
@ -86,15 +86,6 @@ class OpsService:
|
||||
new_decrypt_tracing_config.update({"project_url": project_url})
|
||||
except Exception:
|
||||
new_decrypt_tracing_config.update({"project_url": "https://www.comet.com/opik/"})
|
||||
if tracing_provider == "weave" and (
|
||||
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
|
||||
):
|
||||
try:
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
|
||||
new_decrypt_tracing_config.update({"project_url": project_url})
|
||||
except Exception:
|
||||
new_decrypt_tracing_config.update({"project_url": "https://wandb.ai/"})
|
||||
|
||||
if tracing_provider == "aliyun" and (
|
||||
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
|
||||
):
|
||||
|
||||
@ -30,7 +30,7 @@ from core.plugin.impl.debugging import PluginDebuggingClient
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import Provider, ProviderCredential
|
||||
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
|
||||
from models.provider_ids import GenericProviderID
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
PluginManagerService,
|
||||
@ -534,6 +534,13 @@ class PluginService:
|
||||
plugin_id = plugin.plugin_id
|
||||
logger.info("Deleting credentials for plugin: %s", plugin_id)
|
||||
|
||||
session.execute(
|
||||
delete(TenantPreferredModelProvider).where(
|
||||
TenantPreferredModelProvider.tenant_id == tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.like(f"{plugin_id}/%"),
|
||||
)
|
||||
)
|
||||
|
||||
# Delete provider credentials that match this plugin
|
||||
credential_ids = session.scalars(
|
||||
select(ProviderCredential.id).where(
|
||||
|
||||
@ -1,168 +0,0 @@
|
||||
"""Integration tests for ClickZetta Volume Storage."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
|
||||
ClickZettaVolumeConfig,
|
||||
ClickZettaVolumeStorage,
|
||||
)
|
||||
|
||||
|
||||
class TestClickZettaVolumeStorage(unittest.TestCase):
|
||||
"""Test cases for ClickZetta Volume Storage."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.config = ClickZettaVolumeConfig(
|
||||
username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
|
||||
password=os.getenv("CLICKZETTA_PASSWORD", "test_pass"),
|
||||
instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
|
||||
service=os.getenv("CLICKZETTA_SERVICE", "uat-api.clickzetta.com"),
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
|
||||
schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"),
|
||||
volume_type="table",
|
||||
table_prefix="test_dataset_",
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
|
||||
def test_user_volume_operations(self):
|
||||
"""Test basic operations with User Volume."""
|
||||
config = self.config
|
||||
config.volume_type = "user"
|
||||
|
||||
storage = ClickZettaVolumeStorage(config)
|
||||
|
||||
# Test file operations
|
||||
test_filename = "test_file.txt"
|
||||
test_content = b"Hello, ClickZetta Volume!"
|
||||
|
||||
# Save file
|
||||
storage.save(test_filename, test_content)
|
||||
|
||||
# Check if file exists
|
||||
assert storage.exists(test_filename)
|
||||
|
||||
# Load file
|
||||
loaded_content = storage.load_once(test_filename)
|
||||
assert loaded_content == test_content
|
||||
|
||||
# Test streaming
|
||||
stream_content = b""
|
||||
for chunk in storage.load_stream(test_filename):
|
||||
stream_content += chunk
|
||||
assert stream_content == test_content
|
||||
|
||||
# Test download
|
||||
with tempfile.NamedTemporaryFile() as temp_file:
|
||||
storage.download(test_filename, temp_file.name)
|
||||
downloaded_content = Path(temp_file.name).read_bytes()
|
||||
assert downloaded_content == test_content
|
||||
|
||||
# Test scan
|
||||
files = storage.scan("", files=True, directories=False)
|
||||
assert test_filename in files
|
||||
|
||||
# Delete file
|
||||
storage.delete(test_filename)
|
||||
assert not storage.exists(test_filename)
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
|
||||
def test_table_volume_operations(self):
|
||||
"""Test basic operations with Table Volume."""
|
||||
config = self.config
|
||||
config.volume_type = "table"
|
||||
|
||||
storage = ClickZettaVolumeStorage(config)
|
||||
|
||||
# Test file operations with dataset_id
|
||||
dataset_id = "12345"
|
||||
test_filename = f"{dataset_id}/test_file.txt"
|
||||
test_content = b"Hello, Table Volume!"
|
||||
|
||||
# Save file
|
||||
storage.save(test_filename, test_content)
|
||||
|
||||
# Check if file exists
|
||||
assert storage.exists(test_filename)
|
||||
|
||||
# Load file
|
||||
loaded_content = storage.load_once(test_filename)
|
||||
assert loaded_content == test_content
|
||||
|
||||
# Test scan for dataset
|
||||
files = storage.scan(dataset_id, files=True, directories=False)
|
||||
assert "test_file.txt" in files
|
||||
|
||||
# Delete file
|
||||
storage.delete(test_filename)
|
||||
assert not storage.exists(test_filename)
|
||||
|
||||
def test_config_validation(self):
|
||||
"""Test configuration validation."""
|
||||
# Test missing required fields
|
||||
with pytest.raises(ValueError):
|
||||
ClickZettaVolumeConfig(
|
||||
username="", # Empty username should fail
|
||||
password="pass",
|
||||
instance="instance",
|
||||
)
|
||||
|
||||
# Test invalid volume type
|
||||
with pytest.raises(ValueError):
|
||||
ClickZettaVolumeConfig(username="user", password="pass", instance="instance", volume_type="invalid_type")
|
||||
|
||||
# Test external volume without volume_name
|
||||
with pytest.raises(ValueError):
|
||||
ClickZettaVolumeConfig(
|
||||
username="user",
|
||||
password="pass",
|
||||
instance="instance",
|
||||
volume_type="external",
|
||||
# Missing volume_name
|
||||
)
|
||||
|
||||
def test_volume_path_generation(self):
|
||||
"""Test volume path generation for different types."""
|
||||
storage = ClickZettaVolumeStorage(self.config)
|
||||
|
||||
# Test table volume path
|
||||
path = storage._get_volume_path("test.txt", "12345")
|
||||
assert path == "test_dataset_12345/test.txt"
|
||||
|
||||
# Test path with existing dataset_id prefix
|
||||
path = storage._get_volume_path("12345/test.txt")
|
||||
assert path == "12345/test.txt"
|
||||
|
||||
# Test user volume
|
||||
storage._config.volume_type = "user"
|
||||
path = storage._get_volume_path("test.txt")
|
||||
assert path == "test.txt"
|
||||
|
||||
def test_sql_prefix_generation(self):
|
||||
"""Test SQL prefix generation for different volume types."""
|
||||
storage = ClickZettaVolumeStorage(self.config)
|
||||
|
||||
# Test table volume SQL prefix
|
||||
prefix = storage._get_volume_sql_prefix("12345")
|
||||
assert prefix == "TABLE VOLUME test_dataset_12345"
|
||||
|
||||
# Test user volume SQL prefix
|
||||
storage._config.volume_type = "user"
|
||||
prefix = storage._get_volume_sql_prefix()
|
||||
assert prefix == "USER VOLUME"
|
||||
|
||||
# Test external volume SQL prefix
|
||||
storage._config.volume_type = "external"
|
||||
storage._config.volume_name = "my_external_volume"
|
||||
prefix = storage._get_volume_sql_prefix()
|
||||
assert prefix == "VOLUME my_external_volume"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -1,49 +0,0 @@
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
|
||||
|
||||
|
||||
class AnalyticdbVectorTest(AbstractVectorTest):
|
||||
def __init__(self, config_type: str):
|
||||
super().__init__()
|
||||
# Analyticdb requires collection_name length less than 60.
|
||||
# it's ok for normal usage.
|
||||
self.collection_name = self.collection_name.replace("_test", "")
|
||||
if config_type == "sql":
|
||||
self.vector = AnalyticdbVector(
|
||||
collection_name=self.collection_name,
|
||||
sql_config=AnalyticdbVectorBySqlConfig(
|
||||
host="test_host",
|
||||
port=5432,
|
||||
account="test_account",
|
||||
account_password="test_passwd",
|
||||
namespace="difytest_namespace",
|
||||
),
|
||||
api_config=None,
|
||||
)
|
||||
else:
|
||||
self.vector = AnalyticdbVector(
|
||||
collection_name=self.collection_name,
|
||||
sql_config=None,
|
||||
api_config=AnalyticdbVectorOpenAPIConfig(
|
||||
access_key_id="test_key_id",
|
||||
access_key_secret="test_key_secret",
|
||||
region_id="test_region",
|
||||
instance_id="test_id",
|
||||
account="test_account",
|
||||
account_password="test_passwd",
|
||||
namespace="difytest_namespace",
|
||||
collection="difytest_collection",
|
||||
namespace_password="test_passwd",
|
||||
),
|
||||
)
|
||||
|
||||
def run_all_tests(self):
|
||||
self.vector.delete()
|
||||
return super().run_all_tests()
|
||||
|
||||
|
||||
def test_chroma_vector(setup_mock_redis):
|
||||
AnalyticdbVectorTest("api").run_all_tests()
|
||||
AnalyticdbVectorTest("sql").run_all_tests()
|
||||
@ -1,25 +0,0 @@
|
||||
# Clickzetta Integration Tests
|
||||
|
||||
## Running Tests
|
||||
|
||||
To run the Clickzetta integration tests, you need to set the following environment variables:
|
||||
|
||||
```bash
|
||||
export CLICKZETTA_USERNAME=your_username
|
||||
export CLICKZETTA_PASSWORD=your_password
|
||||
export CLICKZETTA_INSTANCE=your_instance
|
||||
export CLICKZETTA_SERVICE=api.clickzetta.com
|
||||
export CLICKZETTA_WORKSPACE=your_workspace
|
||||
export CLICKZETTA_VCLUSTER=your_vcluster
|
||||
export CLICKZETTA_SCHEMA=dify
|
||||
```
|
||||
|
||||
Then run the tests:
|
||||
|
||||
```bash
|
||||
pytest api/tests/integration_tests/vdb/clickzetta/
|
||||
```
|
||||
|
||||
## Security Note
|
||||
|
||||
Never commit credentials to the repository. Always use environment variables or secure credential management systems.
|
||||
@ -1,223 +0,0 @@
|
||||
import contextlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
|
||||
from core.rag.models.document import Document
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
class TestClickzettaVector(AbstractVectorTest):
|
||||
"""
|
||||
Test cases for Clickzetta vector database integration.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store(self):
|
||||
"""Create a Clickzetta vector store instance for testing."""
|
||||
# Skip test if Clickzetta credentials are not configured
|
||||
if not os.getenv("CLICKZETTA_USERNAME"):
|
||||
pytest.skip("CLICKZETTA_USERNAME is not configured")
|
||||
if not os.getenv("CLICKZETTA_PASSWORD"):
|
||||
pytest.skip("CLICKZETTA_PASSWORD is not configured")
|
||||
if not os.getenv("CLICKZETTA_INSTANCE"):
|
||||
pytest.skip("CLICKZETTA_INSTANCE is not configured")
|
||||
|
||||
config = ClickzettaConfig(
|
||||
username=os.getenv("CLICKZETTA_USERNAME", ""),
|
||||
password=os.getenv("CLICKZETTA_PASSWORD", ""),
|
||||
instance=os.getenv("CLICKZETTA_INSTANCE", ""),
|
||||
service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
|
||||
schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
|
||||
batch_size=10, # Small batch size for testing
|
||||
enable_inverted_index=True,
|
||||
analyzer_type="chinese",
|
||||
analyzer_mode="smart",
|
||||
vector_distance_function="cosine_distance",
|
||||
)
|
||||
|
||||
with setup_mock_redis():
|
||||
vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
|
||||
|
||||
yield vector
|
||||
|
||||
# Cleanup: delete the test collection
|
||||
with contextlib.suppress(Exception):
|
||||
vector.delete()
|
||||
|
||||
def test_clickzetta_vector_basic_operations(self, vector_store):
|
||||
"""Test basic CRUD operations on Clickzetta vector store."""
|
||||
# Prepare test data
|
||||
texts = [
|
||||
"这是第一个测试文档,包含一些中文内容。",
|
||||
"This is the second test document with English content.",
|
||||
"第三个文档混合了English和中文内容。",
|
||||
]
|
||||
embeddings = [
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
[0.5, 0.6, 0.7, 0.8],
|
||||
[0.9, 1.0, 1.1, 1.2],
|
||||
]
|
||||
documents = [
|
||||
Document(page_content=text, metadata={"doc_id": f"doc_{i}", "source": "test"})
|
||||
for i, text in enumerate(texts)
|
||||
]
|
||||
|
||||
# Test create (initial insert)
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test text_exists
|
||||
assert vector_store.text_exists("doc_0")
|
||||
assert not vector_store.text_exists("doc_999")
|
||||
|
||||
# Test search_by_vector
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
results = vector_store.search_by_vector(query_vector, top_k=2)
|
||||
assert len(results) > 0
|
||||
assert results[0].page_content == texts[0] # Should match the first document
|
||||
|
||||
# Test search_by_full_text (Chinese)
|
||||
results = vector_store.search_by_full_text("中文", top_k=3)
|
||||
assert len(results) >= 2 # Should find documents with Chinese content
|
||||
|
||||
# Test search_by_full_text (English)
|
||||
results = vector_store.search_by_full_text("English", top_k=3)
|
||||
assert len(results) >= 2 # Should find documents with English content
|
||||
|
||||
# Test delete_by_ids
|
||||
vector_store.delete_by_ids(["doc_0"])
|
||||
assert not vector_store.text_exists("doc_0")
|
||||
assert vector_store.text_exists("doc_1")
|
||||
|
||||
# Test delete_by_metadata_field
|
||||
vector_store.delete_by_metadata_field("source", "test")
|
||||
assert not vector_store.text_exists("doc_1")
|
||||
assert not vector_store.text_exists("doc_2")
|
||||
|
||||
def test_clickzetta_vector_advanced_search(self, vector_store):
|
||||
"""Test advanced search features of Clickzetta vector store."""
|
||||
# Prepare test data with more complex metadata
|
||||
documents = []
|
||||
embeddings = []
|
||||
for i in range(10):
|
||||
doc = Document(
|
||||
page_content=f"Document {i}: " + get_example_text(),
|
||||
metadata={
|
||||
"doc_id": f"adv_doc_{i}",
|
||||
"category": "technical" if i % 2 == 0 else "general",
|
||||
"document_id": f"doc_{i // 3}", # Group documents
|
||||
"importance": i,
|
||||
},
|
||||
)
|
||||
documents.append(doc)
|
||||
# Create varied embeddings
|
||||
embeddings.append([0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i])
|
||||
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test vector search with document filter
|
||||
query_vector = [0.5, 1.0, 1.5, 2.0]
|
||||
results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"])
|
||||
assert len(results) > 0
|
||||
# All results should belong to doc_0 or doc_1 groups
|
||||
for result in results:
|
||||
assert result.metadata["document_id"] in ["doc_0", "doc_1"]
|
||||
|
||||
# Test score threshold
|
||||
results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5)
|
||||
# Check that all results have a score above threshold
|
||||
for result in results:
|
||||
assert result.metadata.get("score", 0) >= 0.5
|
||||
|
||||
def test_clickzetta_batch_operations(self, vector_store):
|
||||
"""Test batch insertion operations."""
|
||||
# Prepare large batch of documents
|
||||
batch_size = 25
|
||||
documents = []
|
||||
embeddings = []
|
||||
|
||||
for i in range(batch_size):
|
||||
doc = Document(
|
||||
page_content=f"Batch document {i}: This is a test document for batch processing.",
|
||||
metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"},
|
||||
)
|
||||
documents.append(doc)
|
||||
embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)])
|
||||
|
||||
# Test batch insert
|
||||
vector_store.add_texts(documents=documents, embeddings=embeddings)
|
||||
|
||||
# Verify all documents were inserted
|
||||
for i in range(batch_size):
|
||||
assert vector_store.text_exists(f"batch_doc_{i}")
|
||||
|
||||
# Clean up
|
||||
vector_store.delete_by_metadata_field("batch", "test_batch")
|
||||
|
||||
def test_clickzetta_edge_cases(self, vector_store):
|
||||
"""Test edge cases and error handling."""
|
||||
# Test empty operations
|
||||
vector_store.create(texts=[], embeddings=[])
|
||||
vector_store.add_texts(documents=[], embeddings=[])
|
||||
vector_store.delete_by_ids([])
|
||||
|
||||
# Test special characters in content
|
||||
special_doc = Document(
|
||||
page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline",
|
||||
metadata={"doc_id": "special_doc", "test": "edge_case"},
|
||||
)
|
||||
embeddings = [[0.1, 0.2, 0.3, 0.4]]
|
||||
|
||||
vector_store.add_texts(documents=[special_doc], embeddings=embeddings)
|
||||
assert vector_store.text_exists("special_doc")
|
||||
|
||||
# Test search with special characters
|
||||
results = vector_store.search_by_full_text("quotes", top_k=1)
|
||||
if results: # Full-text search might not be available
|
||||
assert len(results) > 0
|
||||
|
||||
# Clean up
|
||||
vector_store.delete_by_ids(["special_doc"])
|
||||
|
||||
def test_clickzetta_full_text_search_modes(self, vector_store):
|
||||
"""Test different full-text search capabilities."""
|
||||
# Prepare documents with various language content
|
||||
documents = [
|
||||
Document(
|
||||
page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"}
|
||||
),
|
||||
Document(
|
||||
page_content="Clickzetta provides powerful Lakehouse solutions",
|
||||
metadata={"doc_id": "en_doc_1", "lang": "english"},
|
||||
),
|
||||
Document(
|
||||
page_content="Lakehouse是现代数据架构的重要组成部分", metadata={"doc_id": "cn_doc_2", "lang": "chinese"}
|
||||
),
|
||||
Document(
|
||||
page_content="Modern data architecture includes Lakehouse technology",
|
||||
metadata={"doc_id": "en_doc_2", "lang": "english"},
|
||||
),
|
||||
]
|
||||
|
||||
embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents]
|
||||
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test Chinese full-text search
|
||||
results = vector_store.search_by_full_text("Lakehouse", top_k=4)
|
||||
assert len(results) >= 2 # Should find at least documents with "Lakehouse"
|
||||
|
||||
# Test English full-text search
|
||||
results = vector_store.search_by_full_text("solutions", top_k=2)
|
||||
assert len(results) >= 1 # Should find English documents with "solutions"
|
||||
|
||||
# Test mixed search
|
||||
results = vector_store.search_by_full_text("数据架构", top_k=2)
|
||||
assert len(results) >= 1 # Should find Chinese documents with this phrase
|
||||
|
||||
# Clean up
|
||||
vector_store.delete_by_metadata_field("lang", "chinese")
|
||||
vector_store.delete_by_metadata_field("lang", "english")
|
||||
@ -1,165 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Clickzetta integration in Docker environment
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from clickzetta import connect
|
||||
|
||||
|
||||
def test_clickzetta_connection():
|
||||
"""Test direct connection to Clickzetta"""
|
||||
print("=== Testing direct Clickzetta connection ===")
|
||||
try:
|
||||
conn = connect(
|
||||
username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
|
||||
password=os.getenv("CLICKZETTA_PASSWORD", "test_password"),
|
||||
instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
|
||||
service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"),
|
||||
database=os.getenv("CLICKZETTA_SCHEMA", "dify"),
|
||||
)
|
||||
|
||||
with conn.cursor() as cursor:
|
||||
# Test basic connectivity
|
||||
cursor.execute("SELECT 1 as test")
|
||||
result = cursor.fetchone()
|
||||
print(f"✓ Connection test: {result}")
|
||||
|
||||
# Check if our test table exists
|
||||
cursor.execute("SHOW TABLES IN dify")
|
||||
tables = cursor.fetchall()
|
||||
print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}")
|
||||
|
||||
# Check if test collection exists
|
||||
test_collection = "collection_test_dataset"
|
||||
if test_collection in [t[1] for t in tables if t[0] == "dify"]:
|
||||
cursor.execute(f"DESCRIBE dify.{test_collection}")
|
||||
columns = cursor.fetchall()
|
||||
print(f"✓ Table structure for {test_collection}:")
|
||||
for col in columns:
|
||||
print(f" - {col[0]}: {col[1]}")
|
||||
|
||||
# Check for indexes
|
||||
cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
|
||||
indexes = cursor.fetchall()
|
||||
print(f"✓ Indexes on {test_collection}:")
|
||||
for idx in indexes:
|
||||
print(f" - {idx}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ Connection test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_dify_api():
|
||||
"""Test Dify API with Clickzetta backend"""
|
||||
print("\n=== Testing Dify API ===")
|
||||
base_url = "http://localhost:5001"
|
||||
|
||||
# Wait for API to be ready
|
||||
max_retries = 30
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
response = httpx.get(f"{base_url}/console/api/health")
|
||||
if response.status_code == 200:
|
||||
print("✓ Dify API is ready")
|
||||
break
|
||||
except:
|
||||
if i == max_retries - 1:
|
||||
print("✗ Dify API is not responding")
|
||||
return False
|
||||
time.sleep(2)
|
||||
|
||||
# Check vector store configuration
|
||||
try:
|
||||
# This is a simplified check - in production, you'd use proper auth
|
||||
print("✓ Dify is configured to use Clickzetta as vector store")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ API test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def verify_table_structure():
|
||||
"""Verify the table structure meets Dify requirements"""
|
||||
print("\n=== Verifying Table Structure ===")
|
||||
|
||||
expected_columns = {
|
||||
"id": "VARCHAR",
|
||||
"page_content": "VARCHAR",
|
||||
"metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta
|
||||
"vector": "ARRAY<FLOAT>",
|
||||
}
|
||||
|
||||
expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]
|
||||
|
||||
print("✓ Expected table structure:")
|
||||
for col, dtype in expected_columns.items():
|
||||
print(f" - {col}: {dtype}")
|
||||
|
||||
print("\n✓ Required metadata fields:")
|
||||
for field in expected_metadata_fields:
|
||||
print(f" - {field}")
|
||||
|
||||
print("\n✓ Index requirements:")
|
||||
print(" - Vector index (HNSW) on 'vector' column")
|
||||
print(" - Full-text index on 'page_content' (optional)")
|
||||
print(" - Functional index on metadata->>'$.doc_id' (recommended)")
|
||||
print(" - Functional index on metadata->>'$.document_id' (recommended)")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("Starting Clickzetta integration tests for Dify Docker\n")
|
||||
|
||||
tests = [
|
||||
("Direct Clickzetta Connection", test_clickzetta_connection),
|
||||
("Dify API Status", test_dify_api),
|
||||
("Table Structure Verification", verify_table_structure),
|
||||
]
|
||||
|
||||
results = []
|
||||
for test_name, test_func in tests:
|
||||
try:
|
||||
success = test_func()
|
||||
results.append((test_name, success))
|
||||
except Exception as e:
|
||||
print(f"\n✗ {test_name} crashed: {e}")
|
||||
results.append((test_name, False))
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 50)
|
||||
print("Test Summary:")
|
||||
print("=" * 50)
|
||||
|
||||
passed = sum(1 for _, success in results if success)
|
||||
total = len(results)
|
||||
|
||||
for test_name, success in results:
|
||||
status = "✅ PASSED" if success else "❌ FAILED"
|
||||
print(f"{test_name}: {status}")
|
||||
|
||||
print(f"\nTotal: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
|
||||
print("\nNext steps:")
|
||||
print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
|
||||
print("2. Access Dify at http://localhost:3000")
|
||||
print("3. Create a dataset and test vector storage with Clickzetta")
|
||||
return 0
|
||||
else:
|
||||
print("\n⚠️ Some tests failed. Please check the errors above.")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
@ -1,100 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import tablestore
|
||||
from _pytest.python_api import approx
|
||||
|
||||
from core.rag.datasource.vdb.tablestore.tablestore_vector import (
|
||||
TableStoreConfig,
|
||||
TableStoreVector,
|
||||
)
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
get_example_document,
|
||||
get_example_text,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class TableStoreVectorTest(AbstractVectorTest):
|
||||
def __init__(self, normalize_full_text_score: bool = False):
|
||||
super().__init__()
|
||||
self.vector = TableStoreVector(
|
||||
collection_name=self.collection_name,
|
||||
config=TableStoreConfig(
|
||||
endpoint=os.getenv("TABLESTORE_ENDPOINT"),
|
||||
instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"),
|
||||
access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"),
|
||||
access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"),
|
||||
normalize_full_text_bm25_score=normalize_full_text_score,
|
||||
),
|
||||
)
|
||||
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
|
||||
assert ids is not None
|
||||
assert len(ids) == 1
|
||||
assert ids[0] == self.example_doc_id
|
||||
|
||||
def create_vector(self):
|
||||
self.vector.create(
|
||||
texts=[get_example_document(doc_id=self.example_doc_id)],
|
||||
embeddings=[self.example_embedding],
|
||||
)
|
||||
while True:
|
||||
search_response = self.vector._tablestore_client.search(
|
||||
table_name=self.vector._table_name,
|
||||
index_name=self.vector._index_name,
|
||||
search_query=tablestore.SearchQuery(query=tablestore.MatchAllQuery(), get_total_count=True, limit=0),
|
||||
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
||||
)
|
||||
if search_response.total_count == 1:
|
||||
break
|
||||
|
||||
def search_by_vector(self):
|
||||
super().search_by_vector()
|
||||
docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[self.example_doc_id])
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata["doc_id"] == self.example_doc_id
|
||||
assert docs[0].metadata["score"] > 0
|
||||
|
||||
docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[str(uuid.uuid4())])
|
||||
assert len(docs) == 0
|
||||
|
||||
def search_by_full_text(self):
|
||||
super().search_by_full_text()
|
||||
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata["doc_id"] == self.example_doc_id
|
||||
if self.vector._config.normalize_full_text_bm25_score:
|
||||
assert docs[0].metadata["score"] == approx(0.1214, abs=1e-3)
|
||||
else:
|
||||
assert docs[0].metadata.get("score") is None
|
||||
|
||||
# return none if normalize_full_text_score=true and score_threshold > 0
|
||||
docs = self.vector.search_by_full_text(
|
||||
get_example_text(), document_ids_filter=[self.example_doc_id], score_threshold=0.5
|
||||
)
|
||||
if self.vector._config.normalize_full_text_bm25_score:
|
||||
assert len(docs) == 0
|
||||
else:
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata["doc_id"] == self.example_doc_id
|
||||
assert docs[0].metadata.get("score") is None
|
||||
|
||||
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
|
||||
assert len(docs) == 0
|
||||
|
||||
def run_all_tests(self):
|
||||
try:
|
||||
self.vector.delete()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return super().run_all_tests()
|
||||
|
||||
|
||||
def test_tablestore_vector(setup_mock_redis):
|
||||
TableStoreVectorTest().run_all_tests()
|
||||
TableStoreVectorTest(normalize_full_text_score=True).run_all_tests()
|
||||
TableStoreVectorTest(normalize_full_text_score=False).run_all_tests()
|
||||
@ -734,7 +734,7 @@ def test_create_provider_credential_creates_provider_record_when_missing() -> No
|
||||
def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
session = Mock()
|
||||
provider_record = SimpleNamespace(is_valid=False)
|
||||
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id="existing-cred")
|
||||
|
||||
with _patched_session(session):
|
||||
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
|
||||
@ -743,6 +743,25 @@ def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
|
||||
configuration.create_provider_credential({"api_key": "raw"}, "Main")
|
||||
|
||||
assert provider_record.is_valid is True
|
||||
assert provider_record.credential_id == "existing-cred"
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_create_provider_credential_auto_activates_when_no_active_credential() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
session = Mock()
|
||||
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id=None, updated_at=None)
|
||||
|
||||
with _patched_session(session):
|
||||
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
|
||||
with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}):
|
||||
with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record):
|
||||
with patch("core.entities.provider_configuration.ProviderCredentialsCache"):
|
||||
with patch.object(ProviderConfiguration, "switch_preferred_provider_type"):
|
||||
configuration.create_provider_credential({"api_key": "raw"}, "Main")
|
||||
|
||||
assert provider_record.is_valid is True
|
||||
assert provider_record.credential_id is not None
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@ -9,7 +9,6 @@ from core.ops.entities.config_entity import (
|
||||
OpikConfig,
|
||||
PhoenixConfig,
|
||||
TracingProviderEnum,
|
||||
WeaveConfig,
|
||||
)
|
||||
|
||||
|
||||
@ -23,7 +22,6 @@ class TestTracingProviderEnum:
|
||||
assert TracingProviderEnum.LANGFUSE == "langfuse"
|
||||
assert TracingProviderEnum.LANGSMITH == "langsmith"
|
||||
assert TracingProviderEnum.OPIK == "opik"
|
||||
assert TracingProviderEnum.WEAVE == "weave"
|
||||
assert TracingProviderEnum.ALIYUN == "aliyun"
|
||||
|
||||
|
||||
@ -228,64 +226,6 @@ class TestOpikConfig:
|
||||
OpikConfig(url="ftp://custom.comet.com/opik/api/")
|
||||
|
||||
|
||||
class TestWeaveConfig:
|
||||
"""Test cases for WeaveConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Weave configuration"""
|
||||
config = WeaveConfig(
|
||||
api_key="test_key",
|
||||
entity="test_entity",
|
||||
project="test_project",
|
||||
endpoint="https://custom.wandb.ai",
|
||||
host="https://custom.host.com",
|
||||
)
|
||||
assert config.api_key == "test_key"
|
||||
assert config.entity == "test_entity"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.wandb.ai"
|
||||
assert config.host == "https://custom.host.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = WeaveConfig(api_key="key", project="project")
|
||||
assert config.entity is None
|
||||
assert config.endpoint == "https://trace.wandb.ai"
|
||||
assert config.host is None
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig(api_key="key")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig(project="project")
|
||||
|
||||
def test_endpoint_validation_https_only(self):
|
||||
"""Test endpoint validation only allows HTTPS"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")
|
||||
|
||||
def test_host_validation_optional(self):
|
||||
"""Test host validation is optional but validates when provided"""
|
||||
config = WeaveConfig(api_key="key", project="project", host=None)
|
||||
assert config.host is None
|
||||
|
||||
config = WeaveConfig(api_key="key", project="project", host="")
|
||||
assert config.host == ""
|
||||
|
||||
config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
|
||||
assert config.host == "https://valid.host.com"
|
||||
|
||||
def test_host_validation_invalid_scheme(self):
|
||||
"""Test host validation rejects invalid schemes when provided"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
|
||||
|
||||
|
||||
class TestAliyunConfig:
|
||||
"""Test cases for AliyunConfig"""
|
||||
|
||||
@ -379,7 +319,6 @@ class TestConfigIntegration:
|
||||
LangfuseConfig(public_key="public", secret_key="secret"),
|
||||
LangSmithConfig(api_key="key", project="project"),
|
||||
OpikConfig(api_key="key"),
|
||||
WeaveConfig(api_key="key", project="project"),
|
||||
AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com"),
|
||||
]
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
106
api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
Normal file
106
api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
Normal file
@ -0,0 +1,106 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent
|
||||
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage
|
||||
from dify_graph.nodes.llm import llm_utils
|
||||
from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage
|
||||
from dify_graph.nodes.llm.exc import NoPromptFoundError
|
||||
from dify_graph.runtime import VariablePool
|
||||
|
||||
|
||||
def _fetch_prompt_messages_with_mocked_content(content):
|
||||
variable_pool = VariablePool.empty()
|
||||
model_instance = mock.MagicMock(spec=ModelInstance)
|
||||
prompt_template = [
|
||||
LLMNodeChatModelMessage(
|
||||
text="You are a classifier.",
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
edition_type="basic",
|
||||
)
|
||||
]
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.fetch_model_schema",
|
||||
return_value=mock.MagicMock(features=[]),
|
||||
),
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.handle_list_messages",
|
||||
return_value=[SystemPromptMessage(content=content)],
|
||||
),
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
return llm_utils.fetch_prompt_messages(
|
||||
sys_query=None,
|
||||
sys_files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_instance=model_instance,
|
||||
prompt_template=prompt_template,
|
||||
stop=["END"],
|
||||
memory_config=None,
|
||||
vision_enabled=False,
|
||||
vision_detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=[],
|
||||
template_renderer=None,
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out():
|
||||
with pytest.raises(NoPromptFoundError):
|
||||
_fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_flattens_single_text_content_after_filtering_unsupported_multimodal_items():
|
||||
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
TextPromptMessageContent(data="You are a classifier."),
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert stop == ["END"]
|
||||
assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")]
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_remain():
|
||||
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
TextPromptMessageContent(data="You are"),
|
||||
TextPromptMessageContent(data=" a classifier."),
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert stop == ["END"]
|
||||
assert prompt_messages == [
|
||||
SystemPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="You are"),
|
||||
TextPromptMessageContent(data=" a classifier."),
|
||||
]
|
||||
)
|
||||
]
|
||||
@ -0,0 +1,63 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
|
||||
def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
|
||||
init_params = build_test_graph_init_params(
|
||||
graph_config=graph_config,
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", files=[]),
|
||||
user_inputs={"payload": "value"},
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
return init_params, runtime_state
|
||||
|
||||
|
||||
def _build_node_config() -> NodeConfigDict:
|
||||
return NodeConfigDictAdapter.validate_python(
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": TRIGGER_PLUGIN_NODE_TYPE,
|
||||
"title": "Trigger Event",
|
||||
"plugin_id": "plugin-id",
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
"subscription_id": "subscription-id",
|
||||
"plugin_unique_identifier": "plugin-unique-identifier",
|
||||
"event_parameters": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
|
||||
init_params, runtime_state = _build_context(graph_config={})
|
||||
node = TriggerEventNode(
|
||||
id="node-1",
|
||||
config=_build_node_config(),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
"plugin_unique_identifier": "plugin-unique-identifier",
|
||||
}
|
||||
19
api/tests/unit_tests/dify_graph/node_events/test_base.py
Normal file
19
api/tests/unit_tests/dify_graph/node_events/test_base.py
Normal file
@ -0,0 +1,19 @@
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events.base import NodeRunResult
|
||||
|
||||
|
||||
def test_node_run_result_accepts_trigger_info_metadata() -> None:
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
}
|
||||
19
api/tests/unit_tests/models/test_enums_creator_user_role.py
Normal file
19
api/tests/unit_tests/models/test_enums_creator_user_role.py
Normal file
@ -0,0 +1,19 @@
|
||||
import pytest
|
||||
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
def test_creator_user_role_missing_maps_hyphen_to_enum():
|
||||
# given an alias with hyphen
|
||||
value = "end-user"
|
||||
|
||||
# when converting to enum (invokes StrEnum._missing_ override)
|
||||
role = CreatorUserRole(value)
|
||||
|
||||
# then it should map to END_USER
|
||||
assert role is CreatorUserRole.END_USER
|
||||
|
||||
|
||||
def test_creator_user_role_missing_raises_for_unknown():
|
||||
with pytest.raises(ValueError):
|
||||
CreatorUserRole("unknown")
|
||||
@ -58,7 +58,6 @@ class TestOpsService:
|
||||
("phoenix", "https://app.phoenix.arize.com/projects/"),
|
||||
("langsmith", "https://smith.langchain.com/"),
|
||||
("opik", "https://www.comet.com/opik/"),
|
||||
("weave", "https://wandb.ai/"),
|
||||
("aliyun", "https://arms.console.aliyun.com/"),
|
||||
("tencent", "https://console.cloud.tencent.com/apm"),
|
||||
("mlflow", "http://localhost:5000/"),
|
||||
@ -88,7 +87,7 @@ class TestOpsService:
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
"provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"]
|
||||
"provider", ["arize", "phoenix", "langsmith", "opik", "aliyun", "tencent", "mlflow", "databricks"]
|
||||
)
|
||||
def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider):
|
||||
# Arrange
|
||||
|
||||
5577
api/uv.lock
generated
5577
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@ -541,7 +541,7 @@ SUPABASE_URL=your-server-url
|
||||
# ------------------------------
|
||||
|
||||
# The type of vector store to use.
|
||||
# Supported values are `weaviate`, `oceanbase`, `seekdb`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`, `hologres`.
|
||||
# Supported values are `weaviate`, `oceanbase`, `seekdb`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`, `hologres`.
|
||||
VECTOR_STORE=weaviate
|
||||
# Prefix used to create collection name in vector database
|
||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
@ -646,20 +646,6 @@ PGVECTO_RS_USER=postgres
|
||||
PGVECTO_RS_PASSWORD=difyai123456
|
||||
PGVECTO_RS_DATABASE=dify
|
||||
|
||||
# analyticdb configurations, only available when VECTOR_STORE is `analyticdb`
|
||||
ANALYTICDB_KEY_ID=your-ak
|
||||
ANALYTICDB_KEY_SECRET=your-sk
|
||||
ANALYTICDB_REGION_ID=cn-hangzhou
|
||||
ANALYTICDB_INSTANCE_ID=gp-ab123456
|
||||
ANALYTICDB_ACCOUNT=testaccount
|
||||
ANALYTICDB_PASSWORD=testpassword
|
||||
ANALYTICDB_NAMESPACE=dify
|
||||
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
|
||||
ANALYTICDB_HOST=gp-test.aliyuncs.com
|
||||
ANALYTICDB_PORT=5432
|
||||
ANALYTICDB_MIN_CONNECTION=1
|
||||
ANALYTICDB_MAX_CONNECTION=5
|
||||
|
||||
# TiDB vector configurations, only available when VECTOR_STORE is `tidb_vector`
|
||||
TIDB_VECTOR_HOST=tidb
|
||||
TIDB_VECTOR_PORT=4000
|
||||
|
||||
@ -21,7 +21,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.13.1
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -63,7 +63,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.13.1
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -102,7 +102,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.13.1
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -132,7 +132,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.13.1
|
||||
image: langgenius/dify-web:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@ -247,18 +247,6 @@ x-shared-env: &shared-api-worker-env
|
||||
PGVECTO_RS_USER: ${PGVECTO_RS_USER:-postgres}
|
||||
PGVECTO_RS_PASSWORD: ${PGVECTO_RS_PASSWORD:-difyai123456}
|
||||
PGVECTO_RS_DATABASE: ${PGVECTO_RS_DATABASE:-dify}
|
||||
ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-your-ak}
|
||||
ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-your-sk}
|
||||
ANALYTICDB_REGION_ID: ${ANALYTICDB_REGION_ID:-cn-hangzhou}
|
||||
ANALYTICDB_INSTANCE_ID: ${ANALYTICDB_INSTANCE_ID:-gp-ab123456}
|
||||
ANALYTICDB_ACCOUNT: ${ANALYTICDB_ACCOUNT:-testaccount}
|
||||
ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-testpassword}
|
||||
ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify}
|
||||
ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-difypassword}
|
||||
ANALYTICDB_HOST: ${ANALYTICDB_HOST:-gp-test.aliyuncs.com}
|
||||
ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432}
|
||||
ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1}
|
||||
ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5}
|
||||
TIDB_VECTOR_HOST: ${TIDB_VECTOR_HOST:-tidb}
|
||||
TIDB_VECTOR_PORT: ${TIDB_VECTOR_PORT:-4000}
|
||||
TIDB_VECTOR_USER: ${TIDB_VECTOR_USER:-}
|
||||
@ -370,11 +358,6 @@ x-shared-env: &shared-api-worker-env
|
||||
HUAWEI_CLOUD_PASSWORD: ${HUAWEI_CLOUD_PASSWORD:-admin}
|
||||
UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-https://xxx-vector.upstash.io}
|
||||
UPSTASH_VECTOR_TOKEN: ${UPSTASH_VECTOR_TOKEN:-dify}
|
||||
TABLESTORE_ENDPOINT: ${TABLESTORE_ENDPOINT:-https://instance-name.cn-hangzhou.ots.aliyuncs.com}
|
||||
TABLESTORE_INSTANCE_NAME: ${TABLESTORE_INSTANCE_NAME:-instance-name}
|
||||
TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx}
|
||||
TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx}
|
||||
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false}
|
||||
CLICKZETTA_USERNAME: ${CLICKZETTA_USERNAME:-}
|
||||
CLICKZETTA_PASSWORD: ${CLICKZETTA_PASSWORD:-}
|
||||
CLICKZETTA_INSTANCE: ${CLICKZETTA_INSTANCE:-}
|
||||
@ -728,7 +711,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.13.1
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -770,7 +753,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.13.1
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -809,7 +792,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.13.1
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -839,7 +822,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.13.1
|
||||
image: langgenius/dify-web:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
import type { FC, JSX } from 'react'
|
||||
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
|
||||
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig } from './type'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useState } from 'react'
|
||||
@ -29,12 +29,11 @@ export type PopupProps = {
|
||||
langSmithConfig: LangSmithConfig | null
|
||||
langFuseConfig: LangFuseConfig | null
|
||||
opikConfig: OpikConfig | null
|
||||
weaveConfig: WeaveConfig | null
|
||||
aliyunConfig: AliyunConfig | null
|
||||
mlflowConfig: MLflowConfig | null
|
||||
databricksConfig: DatabricksConfig | null
|
||||
tencentConfig: TencentConfig | null
|
||||
onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig | MLflowConfig | DatabricksConfig) => void
|
||||
onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | AliyunConfig | TencentConfig | MLflowConfig | DatabricksConfig) => void
|
||||
onConfigRemoved: (provider: TracingProvider) => void
|
||||
}
|
||||
|
||||
@ -50,7 +49,6 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
langSmithConfig,
|
||||
langFuseConfig,
|
||||
opikConfig,
|
||||
weaveConfig,
|
||||
aliyunConfig,
|
||||
mlflowConfig,
|
||||
databricksConfig,
|
||||
@ -78,7 +76,7 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
}
|
||||
}, [onChooseProvider])
|
||||
|
||||
const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig) => {
|
||||
const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig) => {
|
||||
onConfigUpdated(currentProvider!, payload)
|
||||
hideConfigModal()
|
||||
}, [currentProvider, hideConfigModal, onConfigUpdated])
|
||||
@ -88,8 +86,8 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
hideConfigModal()
|
||||
}, [currentProvider, hideConfigModal, onConfigRemoved])
|
||||
|
||||
const providerAllConfigured = arizeConfig && phoenixConfig && langSmithConfig && langFuseConfig && opikConfig && weaveConfig && aliyunConfig && mlflowConfig && databricksConfig && tencentConfig
|
||||
const providerAllNotConfigured = !arizeConfig && !phoenixConfig && !langSmithConfig && !langFuseConfig && !opikConfig && !weaveConfig && !aliyunConfig && !mlflowConfig && !databricksConfig && !tencentConfig
|
||||
const providerAllConfigured = arizeConfig && phoenixConfig && langSmithConfig && langFuseConfig && opikConfig && aliyunConfig && mlflowConfig && databricksConfig && tencentConfig
|
||||
const providerAllNotConfigured = !arizeConfig && !phoenixConfig && !langSmithConfig && !langFuseConfig && !opikConfig && !aliyunConfig && !mlflowConfig && !databricksConfig && !tencentConfig
|
||||
|
||||
const switchContent = (
|
||||
<Switch
|
||||
@ -164,19 +162,6 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
/>
|
||||
)
|
||||
|
||||
const weavePanel = (
|
||||
<ProviderPanel
|
||||
type={TracingProvider.weave}
|
||||
readOnly={readOnly}
|
||||
config={weaveConfig}
|
||||
hasConfigured={!!weaveConfig}
|
||||
onConfig={handleOnConfig(TracingProvider.weave)}
|
||||
isChosen={chosenProvider === TracingProvider.weave}
|
||||
onChoose={handleOnChoose(TracingProvider.weave)}
|
||||
key="weave-provider-panel"
|
||||
/>
|
||||
)
|
||||
|
||||
const aliyunPanel = (
|
||||
<ProviderPanel
|
||||
type={TracingProvider.aliyun}
|
||||
@ -240,9 +225,6 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
if (opikConfig)
|
||||
configuredPanels.push(opikPanel)
|
||||
|
||||
if (weaveConfig)
|
||||
configuredPanels.push(weavePanel)
|
||||
|
||||
if (arizeConfig)
|
||||
configuredPanels.push(arizePanel)
|
||||
|
||||
@ -282,9 +264,6 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
if (!opikConfig)
|
||||
notConfiguredPanels.push(opikPanel)
|
||||
|
||||
if (!weaveConfig)
|
||||
notConfiguredPanels.push(weavePanel)
|
||||
|
||||
if (!aliyunConfig)
|
||||
notConfiguredPanels.push(aliyunPanel)
|
||||
|
||||
@ -319,7 +298,7 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
return aliyunConfig
|
||||
if (currentProvider === TracingProvider.tencent)
|
||||
return tencentConfig
|
||||
return weaveConfig
|
||||
return opikConfig
|
||||
}
|
||||
|
||||
return (
|
||||
@ -365,7 +344,6 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
{opikPanel}
|
||||
{mlflowPanel}
|
||||
{databricksPanel}
|
||||
{weavePanel}
|
||||
{arizePanel}
|
||||
{phoenixPanel}
|
||||
{aliyunPanel}
|
||||
|
||||
@ -6,7 +6,6 @@ export const docURL = {
|
||||
[TracingProvider.langSmith]: 'https://docs.smith.langchain.com/',
|
||||
[TracingProvider.langfuse]: 'https://docs.langfuse.com',
|
||||
[TracingProvider.opik]: 'https://www.comet.com/docs/opik/integrations/dify',
|
||||
[TracingProvider.weave]: 'https://weave-docs.wandb.ai/',
|
||||
[TracingProvider.aliyun]: 'https://help.aliyun.com/zh/arms/tracing-analysis/untitled-document-1750672984680',
|
||||
[TracingProvider.mlflow]: 'https://mlflow.org/docs/latest/genai/',
|
||||
[TracingProvider.databricks]: 'https://docs.databricks.com/aws/en/mlflow3/genai/tracing/',
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
|
||||
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig } from './type'
|
||||
import type { TracingStatus } from '@/models/app'
|
||||
import {
|
||||
RiArrowDownDoubleLine,
|
||||
@ -12,7 +12,7 @@ import * as React from 'react'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
|
||||
import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon } from '@/app/components/base/icons/src/public/tracing'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import Indicator from '@/app/components/header/indicator'
|
||||
@ -70,7 +70,6 @@ const Panel: FC = () => {
|
||||
[TracingProvider.langSmith]: LangsmithIcon,
|
||||
[TracingProvider.langfuse]: LangfuseIcon,
|
||||
[TracingProvider.opik]: OpikIcon,
|
||||
[TracingProvider.weave]: WeaveIcon,
|
||||
[TracingProvider.aliyun]: AliyunIcon,
|
||||
[TracingProvider.mlflow]: MlflowIcon,
|
||||
[TracingProvider.databricks]: DatabricksIcon,
|
||||
@ -83,12 +82,11 @@ const Panel: FC = () => {
|
||||
const [langSmithConfig, setLangSmithConfig] = useState<LangSmithConfig | null>(null)
|
||||
const [langFuseConfig, setLangFuseConfig] = useState<LangFuseConfig | null>(null)
|
||||
const [opikConfig, setOpikConfig] = useState<OpikConfig | null>(null)
|
||||
const [weaveConfig, setWeaveConfig] = useState<WeaveConfig | null>(null)
|
||||
const [aliyunConfig, setAliyunConfig] = useState<AliyunConfig | null>(null)
|
||||
const [mlflowConfig, setMLflowConfig] = useState<MLflowConfig | null>(null)
|
||||
const [databricksConfig, setDatabricksConfig] = useState<DatabricksConfig | null>(null)
|
||||
const [tencentConfig, setTencentConfig] = useState<TencentConfig | null>(null)
|
||||
const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig || opikConfig || weaveConfig || arizeConfig || phoenixConfig || aliyunConfig || mlflowConfig || databricksConfig || tencentConfig)
|
||||
const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig || opikConfig || arizeConfig || phoenixConfig || aliyunConfig || mlflowConfig || databricksConfig || tencentConfig)
|
||||
|
||||
const fetchTracingConfig = async () => {
|
||||
const getArizeConfig = async () => {
|
||||
@ -116,11 +114,6 @@ const Panel: FC = () => {
|
||||
if (!OpikHasNotConfig)
|
||||
setOpikConfig(opikConfig as OpikConfig)
|
||||
}
|
||||
const getWeaveConfig = async () => {
|
||||
const { tracing_config: weaveConfig, has_not_configured: weaveHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.weave })
|
||||
if (!weaveHasNotConfig)
|
||||
setWeaveConfig(weaveConfig as WeaveConfig)
|
||||
}
|
||||
const getAliyunConfig = async () => {
|
||||
const { tracing_config: aliyunConfig, has_not_configured: aliyunHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.aliyun })
|
||||
if (!aliyunHasNotConfig)
|
||||
@ -147,7 +140,6 @@ const Panel: FC = () => {
|
||||
getLangSmithConfig(),
|
||||
getLangFuseConfig(),
|
||||
getOpikConfig(),
|
||||
getWeaveConfig(),
|
||||
getAliyunConfig(),
|
||||
getMLflowConfig(),
|
||||
getDatabricksConfig(),
|
||||
@ -168,8 +160,6 @@ const Panel: FC = () => {
|
||||
setLangFuseConfig(tracing_config as LangFuseConfig)
|
||||
else if (provider === TracingProvider.opik)
|
||||
setOpikConfig(tracing_config as OpikConfig)
|
||||
else if (provider === TracingProvider.weave)
|
||||
setWeaveConfig(tracing_config as WeaveConfig)
|
||||
else if (provider === TracingProvider.aliyun)
|
||||
setAliyunConfig(tracing_config as AliyunConfig)
|
||||
else if (provider === TracingProvider.tencent)
|
||||
@ -187,8 +177,6 @@ const Panel: FC = () => {
|
||||
setLangFuseConfig(null)
|
||||
else if (provider === TracingProvider.opik)
|
||||
setOpikConfig(null)
|
||||
else if (provider === TracingProvider.weave)
|
||||
setWeaveConfig(null)
|
||||
else if (provider === TracingProvider.aliyun)
|
||||
setAliyunConfig(null)
|
||||
else if (provider === TracingProvider.mlflow)
|
||||
@ -240,7 +228,6 @@ const Panel: FC = () => {
|
||||
langSmithConfig={langSmithConfig}
|
||||
langFuseConfig={langFuseConfig}
|
||||
opikConfig={opikConfig}
|
||||
weaveConfig={weaveConfig}
|
||||
aliyunConfig={aliyunConfig}
|
||||
mlflowConfig={mlflowConfig}
|
||||
databricksConfig={databricksConfig}
|
||||
@ -279,7 +266,6 @@ const Panel: FC = () => {
|
||||
langSmithConfig={langSmithConfig}
|
||||
langFuseConfig={langFuseConfig}
|
||||
opikConfig={opikConfig}
|
||||
weaveConfig={weaveConfig}
|
||||
aliyunConfig={aliyunConfig}
|
||||
mlflowConfig={mlflowConfig}
|
||||
databricksConfig={databricksConfig}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
|
||||
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig } from './type'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useState } from 'react'
|
||||
@ -23,10 +23,10 @@ import { TracingProvider } from './type'
|
||||
type Props = {
|
||||
appId: string
|
||||
type: TracingProvider
|
||||
payload?: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig | null
|
||||
payload?: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig | null
|
||||
onRemoved: () => void
|
||||
onCancel: () => void
|
||||
onSaved: (payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig) => void
|
||||
onSaved: (payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig) => void
|
||||
onChosen: (provider: TracingProvider) => void
|
||||
}
|
||||
|
||||
@ -64,14 +64,6 @@ const opikConfigTemplate = {
|
||||
workspace: '',
|
||||
}
|
||||
|
||||
const weaveConfigTemplate = {
|
||||
api_key: '',
|
||||
entity: '',
|
||||
project: '',
|
||||
endpoint: '',
|
||||
host: '',
|
||||
}
|
||||
|
||||
const aliyunConfigTemplate = {
|
||||
app_name: '',
|
||||
license_key: '',
|
||||
@ -112,7 +104,7 @@ const ProviderConfigModal: FC<Props> = ({
|
||||
const isEdit = !!payload
|
||||
const isAdd = !isEdit
|
||||
const [isSaving, setIsSaving] = useState(false)
|
||||
const [config, setConfig] = useState<ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig>((() => {
|
||||
const [config, setConfig] = useState<ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig>((() => {
|
||||
if (isEdit)
|
||||
return payload
|
||||
|
||||
@ -143,7 +135,7 @@ const ProviderConfigModal: FC<Props> = ({
|
||||
else if (type === TracingProvider.tencent)
|
||||
return tencentConfigTemplate
|
||||
|
||||
return weaveConfigTemplate
|
||||
return opikConfigTemplate
|
||||
})())
|
||||
const [isShowRemoveConfirm, {
|
||||
setTrue: showRemoveConfirm,
|
||||
@ -215,14 +207,6 @@ const ProviderConfigModal: FC<Props> = ({
|
||||
// const postData = config as OpikConfig
|
||||
}
|
||||
|
||||
if (type === TracingProvider.weave) {
|
||||
const postData = config as WeaveConfig
|
||||
if (!errorMessage && !postData.api_key)
|
||||
errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' })
|
||||
if (!errorMessage && !postData.project)
|
||||
errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })
|
||||
}
|
||||
|
||||
if (type === TracingProvider.aliyun) {
|
||||
const postData = config as AliyunConfig
|
||||
if (!errorMessage && !postData.app_name)
|
||||
@ -424,47 +408,6 @@ const ProviderConfigModal: FC<Props> = ({
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
{type === TracingProvider.weave && (
|
||||
<>
|
||||
<Field
|
||||
label="API Key"
|
||||
labelClassName="!text-sm"
|
||||
isRequired
|
||||
value={(config as WeaveConfig).api_key}
|
||||
onChange={handleConfigChange('api_key')}
|
||||
placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!}
|
||||
/>
|
||||
<Field
|
||||
label={t(`${I18N_PREFIX}.project`, { ns: 'app' })!}
|
||||
labelClassName="!text-sm"
|
||||
isRequired
|
||||
value={(config as WeaveConfig).project}
|
||||
onChange={handleConfigChange('project')}
|
||||
placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })!}
|
||||
/>
|
||||
<Field
|
||||
label="Entity"
|
||||
labelClassName="!text-sm"
|
||||
value={(config as WeaveConfig).entity}
|
||||
onChange={handleConfigChange('entity')}
|
||||
placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'Entity' })!}
|
||||
/>
|
||||
<Field
|
||||
label="Endpoint"
|
||||
labelClassName="!text-sm"
|
||||
value={(config as WeaveConfig).endpoint}
|
||||
onChange={handleConfigChange('endpoint')}
|
||||
placeholder="https://trace.wandb.ai/"
|
||||
/>
|
||||
<Field
|
||||
label="Host"
|
||||
labelClassName="!text-sm"
|
||||
value={(config as WeaveConfig).host}
|
||||
onChange={handleConfigChange('host')}
|
||||
placeholder="https://api.wandb.ai"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
{type === TracingProvider.langSmith && (
|
||||
<>
|
||||
<Field
|
||||
|
||||
@ -6,7 +6,7 @@ import {
|
||||
import * as React from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
|
||||
import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig } from '@/app/components/base/icons/src/public/tracing'
|
||||
import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { TracingProvider } from './type'
|
||||
@ -30,7 +30,6 @@ const getIcon = (type: TracingProvider) => {
|
||||
[TracingProvider.langSmith]: LangsmithIconBig,
|
||||
[TracingProvider.langfuse]: LangfuseIconBig,
|
||||
[TracingProvider.opik]: OpikIconBig,
|
||||
[TracingProvider.weave]: WeaveIconBig,
|
||||
[TracingProvider.aliyun]: AliyunIconBig,
|
||||
[TracingProvider.mlflow]: MlflowIconBig,
|
||||
[TracingProvider.databricks]: DatabricksIconBig,
|
||||
|
||||
@ -4,7 +4,6 @@ export enum TracingProvider {
|
||||
langSmith = 'langsmith',
|
||||
langfuse = 'langfuse',
|
||||
opik = 'opik',
|
||||
weave = 'weave',
|
||||
aliyun = 'aliyun',
|
||||
mlflow = 'mlflow',
|
||||
databricks = 'databricks',
|
||||
@ -42,15 +41,6 @@ export type OpikConfig = {
|
||||
workspace: string
|
||||
url: string
|
||||
}
|
||||
|
||||
export type WeaveConfig = {
|
||||
api_key: string
|
||||
entity: string
|
||||
project: string
|
||||
endpoint: string
|
||||
host: string
|
||||
}
|
||||
|
||||
export type AliyunConfig = {
|
||||
app_name: string
|
||||
license_key: string
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -1,20 +0,0 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
import * as React from 'react'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import data from './WeaveIcon.json'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'WeaveIcon'
|
||||
|
||||
export default Icon
|
||||
File diff suppressed because one or more lines are too long
@ -1,20 +0,0 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
import * as React from 'react'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import data from './WeaveIconBig.json'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'WeaveIconBig'
|
||||
|
||||
export default Icon
|
||||
@ -17,5 +17,3 @@ export { default as PhoenixIconBig } from './PhoenixIconBig'
|
||||
export { default as TencentIcon } from './TencentIcon'
|
||||
export { default as TencentIconBig } from './TencentIconBig'
|
||||
export { default as TracingIcon } from './TracingIcon'
|
||||
export { default as WeaveIcon } from './WeaveIcon'
|
||||
export { default as WeaveIconBig } from './WeaveIconBig'
|
||||
|
||||
@ -265,8 +265,6 @@
|
||||
"tracing.tracing": "تتبع",
|
||||
"tracing.tracingDescription": "التقاط السياق الكامل لتنفيذ التطبيق، بما في ذلك مكالمات LLM، والسياق، والمطالبات، وطلبات HTTP، والمزيد، إلى منصة تتبع تابعة لجهة خارجية.",
|
||||
"tracing.view": "عرض",
|
||||
"tracing.weave.description": "Weave هي منصة مفتوحة المصدر لتقييم واختبار ومراقبة تطبيقات LLM.",
|
||||
"tracing.weave.title": "Weave",
|
||||
"typeSelector.advanced": "Chatflow",
|
||||
"typeSelector.agent": "Agent",
|
||||
"typeSelector.all": "كل الأنواع",
|
||||
|
||||
@ -265,8 +265,6 @@
|
||||
"tracing.tracing": "Nachverfolgung",
|
||||
"tracing.tracingDescription": "Erfassung des vollständigen Kontexts der Anwendungsausführung, einschließlich LLM-Aufrufe, Kontext, Prompts, HTTP-Anfragen und mehr, auf einer Nachverfolgungsplattform von Drittanbietern.",
|
||||
"tracing.view": "Ansehen",
|
||||
"tracing.weave.description": "Weave ist eine Open-Source-Plattform zur Bewertung, Testung und Überwachung von LLM-Anwendungen.",
|
||||
"tracing.weave.title": "Weben",
|
||||
"typeSelector.advanced": "Chatflow",
|
||||
"typeSelector.agent": "Agent",
|
||||
"typeSelector.all": "ALLE Typen",
|
||||
|
||||
@ -265,8 +265,6 @@
|
||||
"tracing.tracing": "Tracing",
|
||||
"tracing.tracingDescription": "Capture the full context of app execution, including LLM calls, context, prompts, HTTP requests, and more, to a third-party tracing platform.",
|
||||
"tracing.view": "View",
|
||||
"tracing.weave.description": "Weave is an open-source platform for evaluating, testing, and monitoring LLM applications.",
|
||||
"tracing.weave.title": "Weave",
|
||||
"typeSelector.advanced": "Chatflow",
|
||||
"typeSelector.agent": "Agent",
|
||||
"typeSelector.all": "All Types ",
|
||||
|
||||
@ -265,8 +265,6 @@
|
||||
"tracing.tracing": "Rastreo",
|
||||
"tracing.tracingDescription": "Captura el contexto completo de la ejecución de la app, incluyendo llamadas LLM, contexto, prompts, solicitudes HTTP y más, en una plataforma de rastreo de terceros.",
|
||||
"tracing.view": "Vista",
|
||||
"tracing.weave.description": "Weave es una plataforma de código abierto para evaluar, probar y monitorear aplicaciones de LLM.",
|
||||
"tracing.weave.title": "Tejer",
|
||||
"typeSelector.advanced": "Flujo de chat",
|
||||
"typeSelector.agent": "Agente",
|
||||
"typeSelector.all": "Todos los tipos",
|
||||
|
||||
@ -265,8 +265,6 @@
|
||||
"tracing.tracing": "ردیابی",
|
||||
"tracing.tracingDescription": "ثبت کامل متن اجرای برنامه، از جمله تماسهای LLM، متن، درخواستهای HTTP و بیشتر، به یک پلتفرم ردیابی شخص ثالث.",
|
||||
"tracing.view": "مشاهده",
|
||||
"tracing.weave.description": "ویو یک پلتفرم متن باز برای ارزیابی، آزمایش و نظارت بر برنامههای LLM است.",
|
||||
"tracing.weave.title": "بافندگی",
|
||||
"typeSelector.advanced": "چتفلو",
|
||||
"typeSelector.agent": "نماینده",
|
||||
"typeSelector.all": "همه انواع",
|
||||
|
||||
@ -265,8 +265,6 @@
|
||||
"tracing.tracing": "Traçage",
|
||||
"tracing.tracingDescription": "Capturez le contexte complet de l'exécution de l'application, y compris les appels LLM, le contexte, les prompts, les requêtes HTTP et plus encore, vers une plateforme de traçage tierce.",
|
||||
"tracing.view": "Vue",
|
||||
"tracing.weave.description": "Weave est une plateforme open-source pour évaluer, tester et surveiller les applications LLM.",
|
||||
"tracing.weave.title": "Tisser",
|
||||
"typeSelector.advanced": "Chatflow",
|
||||
"typeSelector.agent": "Agent",
|
||||
"typeSelector.all": "Tous Types",
|
||||
|
||||
@ -265,8 +265,6 @@
|
||||
"tracing.tracing": "ट्रेसिंग",
|
||||
"tracing.tracingDescription": "एप्लिकेशन निष्पादन का पूरा संदर्भ कैप्चर करें, जिसमें LLM कॉल, संदर्भ, प्रॉम्प्ट्स, HTTP अनुरोध और अधिक शामिल हैं, एक तृतीय-पक्ष ट्रेसिंग प्लेटफ़ॉर्म पर।",
|
||||
"tracing.view": "देखना",
|
||||
"tracing.weave.description": "वीव एक ओपन-सोर्स प्लेटफ़ॉर्म है जो LLM अनुप्रयोगों का मूल्यांकन, परीक्षण और निगरानी करने के लिए है।",
|
||||
"tracing.weave.title": "बुनना",
|
||||
"typeSelector.advanced": "चैटफ्लो",
|
||||
"typeSelector.agent": "एजेंट",
|
||||
"typeSelector.all": "सभी प्रकार",
|
||||
|
||||
@ -265,8 +265,6 @@
|
||||
"tracing.tracing": "Menelusuri",
|
||||
"tracing.tracingDescription": "Tangkap konteks lengkap eksekusi aplikasi, termasuk panggilan LLM, konteks, perintah, permintaan HTTP, dan lainnya, ke platform pelacakan pihak ketiga.",
|
||||
"tracing.view": "Melihat",
|
||||
"tracing.weave.description": "Weave adalah platform sumber terbuka untuk mengevaluasi, menguji, dan memantau aplikasi LLM.",
|
||||
"tracing.weave.title": "Weave",
|
||||
"typeSelector.advanced": "Alur obrolan",
|
||||
"typeSelector.agent": "Agen",
|
||||
"typeSelector.all": "Semua Jenis",
|
||||
|
||||
@ -265,8 +265,6 @@
|
||||
"tracing.tracing": "Tracciamento",
|
||||
"tracing.tracingDescription": "Cattura il contesto completo dell'esecuzione dell'app, incluse chiamate LLM, contesto, prompt, richieste HTTP e altro, su una piattaforma di tracciamento di terze parti.",
|
||||
"tracing.view": "Vista",
|
||||
"tracing.weave.description": "Weave è una piattaforma open-source per valutare, testare e monitorare le applicazioni LLM.",
|
||||
"tracing.weave.title": "Intrecciare",
|
||||
"typeSelector.advanced": "Flusso di chat",
|
||||
"typeSelector.agent": "Agente",
|
||||
"typeSelector.all": "TUTTI I Tipi",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user