mirror of https://github.com/langgenius/dify.git
synced 2026-02-22 19:15:47 +08:00
Merge remote-tracking branch 'origin/main' into feat/trigger
@@ -4,15 +4,15 @@ import logging
 import click
 import sqlalchemy as sa
-from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
-from models.engine import db
+from extensions.ext_database import db
+from models.provider_ids import GenericProviderID, ModelProviderID, ToolProviderID
 logger = logging.getLogger(__name__)
 class PluginDataMigration:
     @classmethod
-    def migrate(cls) -> None:
+    def migrate(cls):
         cls.migrate_db_records("providers", "provider_name", ModelProviderID)  # large table
         cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
         cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
@@ -26,7 +26,7 @@ class PluginDataMigration:
         cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
     @classmethod
-    def migrate_datasets(cls) -> None:
+    def migrate_datasets(cls):
         table_name = "datasets"
         provider_column_name = "embedding_model_provider"
@@ -46,7 +46,11 @@ limit 1000"""
             record_id = str(i.id)
             provider_name = str(i.provider_name)
             retrieval_model = i.retrieval_model
-            print(type(retrieval_model))
+            logger.debug(
+                "Processing dataset %s with retrieval model of type %s",
+                record_id,
+                type(retrieval_model),
+            )
             if record_id in failed_ids:
                 continue
@@ -126,9 +130,7 @@ limit 1000"""
         )
     @classmethod
-    def migrate_db_records(
-        cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
-    ) -> None:
+    def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]):
         click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
         processed_count = 0
@@ -175,7 +177,7 @@ limit 1000"""
                    # update jina to langgenius/jina_tool/jina etc.
                    updated_value = provider_cls(provider_name).to_string()
                    batch_updates.append((updated_value, record_id))
-                except Exception as e:
+                except Exception:
                    failed_ids.append(record_id)
                    click.echo(
                        click.style(
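The "update jina to langgenius/jina_tool/jina etc." comment above is the heart of this migration: legacy provider names are rewritten into fully qualified plugin provider IDs. A minimal illustrative sketch of that mapping, under the assumption that legacy names map onto the langgenius marketplace org with a "_tool" suffix (the real logic lives in the GenericProviderID/ToolProviderID classes and is not shown in this diff):

    # Illustrative only: approximates what provider_cls(provider_name).to_string() produces.
    def to_plugin_provider_id(provider_name: str, plugin_suffix: str = "_tool") -> str:
        # values that already look like "org/plugin/provider" are left untouched
        if provider_name.count("/") == 2:
            return provider_name
        # assumed fallback: legacy single names map onto the langgenius marketplace org
        return f"langgenius/{provider_name}{plugin_suffix}/{provider_name}"

    # to_plugin_provider_id("jina") -> "langgenius/jina_tool/jina", matching the comment above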
@@ -1,7 +1,13 @@
+import re
 from configs import dify_config
 from core.helper import marketplace
-from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID
+from core.plugin.entities.plugin import PluginDependency, PluginInstallationSource
 from core.plugin.impl.plugin import PluginInstaller
+from models.provider_ids import ModelProviderID, ToolProviderID
+# Compile regex pattern for version extraction at module level for better performance
+_VERSION_REGEX = re.compile(r":(?P<version>[0-9]+(?:\.[0-9]+){2}(?:[+-][0-9A-Za-z.-]+)?)(?:@|$)")
 class DependenciesAnalysisService:
@@ -48,6 +54,13 @@ class DependenciesAnalysisService:
         for dependency in dependencies:
             unique_identifier = dependency.value.plugin_unique_identifier
             if unique_identifier in missing_plugin_unique_identifiers:
+                # Extract version for Marketplace dependencies
+                if dependency.type == PluginDependency.Type.Marketplace:
+                    version_match = _VERSION_REGEX.search(unique_identifier)
+                    if version_match:
+                        dependency.value.version = version_match.group("version")
+                # Create and append the dependency (same for all types)
                 leaked_dependencies.append(
                     PluginDependency(
                         type=dependency.type,
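The _VERSION_REGEX added above pulls the semantic version out of a marketplace plugin unique identifier. A small sketch of the expected behaviour, assuming identifiers of the form "org/plugin:1.2.3@checksum" (the exact identifier layout is not shown in this diff):

    import re

    _VERSION_REGEX = re.compile(r":(?P<version>[0-9]+(?:\.[0-9]+){2}(?:[+-][0-9A-Za-z.-]+)?)(?:@|$)")

    # assumed example value; real values come from dependency.value.plugin_unique_identifier
    identifier = "langgenius/jina_tool:0.0.5@0123abcd"
    version_match = _VERSION_REGEX.search(identifier)
    if version_match:
        print(version_match.group("version"))  # 0.0.5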
@@ -11,7 +11,14 @@ class OAuthProxyService(BasePluginClient):
     __KEY_PREFIX__ = "oauth_proxy_context:"
     @staticmethod
-    def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str, extra_data: dict = {}):
+    def create_proxy_context(
+        user_id: str,
+        tenant_id: str,
+        plugin_id: str,
+        provider: str,
+        extra_data: dict = {},
+        credential_id: str | None = None,
+    ):
         """
         Create a proxy context for an OAuth 2.0 authorization request.
@@ -32,6 +39,8 @@ class OAuthProxyService(BasePluginClient):
             "tenant_id": tenant_id,
             "provider": provider,
         }
+        if credential_id:
+            data["credential_id"] = credential_id
         redis_client.setex(
             f"{OAuthProxyService.__KEY_PREFIX__}{context_id}",
             OAuthProxyService.__MAX_AGE__,
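The two OAuthProxyService hunks above thread an optional credential_id into the stored proxy context. A minimal sketch of the storage pattern as far as this diff shows it; the context_id generation, the serialization, the fields beyond tenant_id/provider, and the __MAX_AGE__ value are assumptions:

    import json
    from uuid import uuid4

    KEY_PREFIX = "oauth_proxy_context:"  # mirrors __KEY_PREFIX__ above
    MAX_AGE_SECONDS = 5 * 60             # assumed TTL; the real __MAX_AGE__ is not shown

    def create_proxy_context_sketch(redis_client, user_id, tenant_id, plugin_id, provider,
                                    extra_data=None, credential_id=None):
        context_id = uuid4().hex  # assumed; the diff does not show how context_id is produced
        data = {
            **(extra_data or {}),
            "user_id": user_id,
            "tenant_id": tenant_id,
            "plugin_id": plugin_id,
            "provider": provider,
        }
        if credential_id:
            data["credential_id"] = credential_id
        # the key expires automatically, so abandoned OAuth flows do not leave state behind
        redis_client.setex(f"{KEY_PREFIX}{context_id}", MAX_AGE_SECONDS, json.dumps(data))
        return context_id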
@@ -5,7 +5,7 @@ import time
 from collections.abc import Mapping, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 from uuid import uuid4
 import click
@@ -16,15 +16,17 @@ from sqlalchemy.orm import Session
 from core.agent.entities import AgentToolEntity
 from core.helper import marketplace
-from core.plugin.entities.plugin import ModelProviderID, PluginInstallationSource, ToolProviderID
+from core.plugin.entities.plugin import PluginInstallationSource
 from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
 from core.plugin.impl.plugin import PluginInstaller
 from core.tools.entities.tool_entities import ToolProviderType
+from extensions.ext_database import db
 from models.account import Tenant
-from models.engine import db
 from models.model import App, AppMode, AppModelConfig
+from models.provider_ids import ModelProviderID, ToolProviderID
 from models.tools import BuiltinToolProvider
 from models.workflow import Workflow
+from services.plugin.plugin_service import PluginService
 logger = logging.getLogger(__name__)
@@ -33,7 +35,7 @@ excluded_providers = ["time", "audio", "code", "webscraper"]
 class PluginMigration:
     @classmethod
-    def extract_plugins(cls, filepath: str, workers: int) -> None:
+    def extract_plugins(cls, filepath: str, workers: int):
         """
         Migrate plugin.
         """
@@ -55,7 +57,7 @@ class PluginMigration:
         thread_pool = ThreadPoolExecutor(max_workers=workers)
-        def process_tenant(flask_app: Flask, tenant_id: str) -> None:
+        def process_tenant(flask_app: Flask, tenant_id: str):
             with flask_app.app_context():
                 nonlocal handled_tenant_count
                 try:
@@ -99,6 +101,7 @@ class PluginMigration:
             datetime.timedelta(hours=1),
         ]
+        tenant_count = 0
         for test_interval in test_intervals:
             tenant_count = (
                 session.query(Tenant.id)
@@ -255,7 +258,7 @@ class PluginMigration:
             return []
         agent_app_model_config_ids = [
-            app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
+            app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
         ]
         rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
@@ -280,7 +283,7 @@ class PluginMigration:
         return result
     @classmethod
-    def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
+    def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> str | None:
         """
         Fetch plugin unique identifier using plugin id.
         """
@@ -291,7 +294,7 @@ class PluginMigration:
         return plugin_manifest[0].latest_package_identifier
     @classmethod
-    def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
+    def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str):
         """
         Extract unique plugins.
         """
@@ -328,7 +331,7 @@ class PluginMigration:
         return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
     @classmethod
-    def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
+    def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100):
         """
         Install plugins.
         """
@@ -348,7 +351,7 @@ class PluginMigration:
         if response.get("failed"):
             plugin_install_failed.extend(response.get("failed", []))
-        def install(tenant_id: str, plugin_ids: list[str]) -> None:
+        def install(tenant_id: str, plugin_ids: list[str]):
            logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
            # fetch plugin already installed
            installed_plugins = manager.list_plugins(tenant_id)
@@ -420,6 +423,94 @@ class PluginMigration:
                )
            )
+    @classmethod
+    def install_rag_pipeline_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
+        """
+        Install rag pipeline plugins.
+        """
+        manager = PluginInstaller()
+        plugins = cls.extract_unique_plugins(extracted_plugins)
+        plugin_install_failed = []
+        # use a fake tenant id to install all the plugins
+        fake_tenant_id = uuid4().hex
+        logger.info("Installing %s plugin instances for fake tenant %s", len(plugins["plugins"]), fake_tenant_id)
+        thread_pool = ThreadPoolExecutor(max_workers=workers)
+        response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
+        if response.get("failed"):
+            plugin_install_failed.extend(response.get("failed", []))
+        def install(
+            tenant_id: str, plugin_ids: dict[str, str], total_success_tenant: int, total_failed_tenant: int
+        ) -> None:
+            logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
+            try:
+                # fetch plugin already installed
+                installed_plugins = manager.list_plugins(tenant_id)
+                installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
+                # at most 64 plugins one batch
+                for i in range(0, len(plugin_ids), 64):
+                    batch_plugin_ids = list(plugin_ids.keys())[i : i + 64]
+                    batch_plugin_identifiers = [
+                        plugin_ids[plugin_id]
+                        for plugin_id in batch_plugin_ids
+                        if plugin_id not in installed_plugins_ids and plugin_id in plugin_ids
+                    ]
+                    PluginService.install_from_marketplace_pkg(tenant_id, batch_plugin_identifiers)
+                total_success_tenant += 1
+            except Exception:
+                logger.exception("Failed to install plugins for tenant %s", tenant_id)
+                total_failed_tenant += 1
+        page = 1
+        total_success_tenant = 0
+        total_failed_tenant = 0
+        while True:
+            # paginate
+            tenants = db.paginate(db.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
+            if tenants.items is None or len(tenants.items) == 0:
+                break
+            for tenant in tenants:
+                tenant_id = tenant.id
+                # get plugin unique identifier
+                thread_pool.submit(
+                    install,
+                    tenant_id,
+                    plugins.get("plugins", {}),
+                    total_success_tenant,
+                    total_failed_tenant,
+                )
+            page += 1
+        thread_pool.shutdown(wait=True)
+        # uninstall all the plugins for fake tenant
+        try:
+            installation = manager.list_plugins(fake_tenant_id)
+            while installation:
+                for plugin in installation:
+                    manager.uninstall(fake_tenant_id, plugin.installation_id)
+                installation = manager.list_plugins(fake_tenant_id)
+        except Exception:
+            logger.exception("Failed to get installation for tenant %s", fake_tenant_id)
+        Path(output_file).write_text(
+            json.dumps(
+                {
+                    "total_success_tenant": total_success_tenant,
+                    "total_failed_tenant": total_failed_tenant,
+                    "plugin_install_failed": plugin_install_failed,
+                }
+            )
+        )
     @classmethod
     def handle_plugin_instance_install(
         cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
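The new install_rag_pipeline_plugins method above warms a fake tenant, then fans installs out per tenant via the thread pool, pushing at most 64 marketplace packages per request and skipping plugins a tenant already has. The core batching loop, reduced to a self-contained sketch (install_batch stands in for PluginService.install_from_marketplace_pkg from the diff; everything else here is illustrative):

    from collections.abc import Callable, Mapping, Sequence

    def install_in_batches(
        tenant_id: str,
        plugin_ids: Mapping[str, str],  # plugin_id -> marketplace package identifier
        installed_ids: Sequence[str],   # plugins the tenant already has
        install_batch: Callable[[str, list[str]], None],
        batch_size: int = 64,
    ) -> None:
        pending = [pid for pid in plugin_ids if pid not in installed_ids]
        for i in range(0, len(pending), batch_size):
            batch = pending[i : i + batch_size]
            # one marketplace install request per batch of identifiers
            install_batch(tenant_id, [plugin_ids[pid] for pid in batch])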
@@ -1,7 +1,6 @@
 import logging
 from collections.abc import Mapping, Sequence
 from mimetypes import guess_type
-from typing import Optional
 from pydantic import BaseModel
 from yarl import URL
@@ -12,7 +11,6 @@ from core.helper.download import download_with_size_limit
 from core.helper.marketplace import download_plugin_pkg
 from core.plugin.entities.bundle import PluginBundleDependency
 from core.plugin.entities.plugin import (
-    GenericProviderID,
     PluginDeclaration,
     PluginEntity,
     PluginInstallation,
@@ -28,6 +26,7 @@ from core.plugin.impl.asset import PluginAssetManager
 from core.plugin.impl.debugging import PluginDebuggingClient
 from core.plugin.impl.plugin import PluginInstaller
 from extensions.ext_redis import redis_client
+from models.provider_ids import GenericProviderID
 from services.errors.plugin import PluginInstallationForbiddenError
 from services.feature_service import FeatureService, PluginInstallationScope
@@ -47,11 +46,11 @@ class PluginService:
     REDIS_TTL = 60 * 5  # 5 minutes
     @staticmethod
-    def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
+    def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
         """
         Fetch the latest plugin version
         """
-        result: dict[str, Optional[PluginService.LatestPluginCache]] = {}
+        result: dict[str, PluginService.LatestPluginCache | None] = {}
         try:
             cache_not_exists = []
@@ -110,7 +109,7 @@ class PluginService:
             raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only")
     @staticmethod
-    def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]):
+    def _check_plugin_installation_scope(plugin_verification: PluginVerification | None):
         """
         Check the plugin installation scope
         """
@@ -145,7 +144,7 @@ class PluginService:
         return manager.get_debugging_key(tenant_id)
     @staticmethod
-    def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
+    def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
         """
         List the latest versions of the plugins
         """