feat: add category for plugins

This commit is contained in:
Yeuoly
2024-10-16 13:03:50 +08:00
parent 276701e1b7
commit a81293cf5a
5 changed files with 48 additions and 23 deletions

View File

@ -3,7 +3,7 @@ from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.plugin.entities.base import BasePluginEntity
@ -54,6 +54,12 @@ class PluginResourceRequirements(BaseModel):
permission: Optional[Permission]
class PluginCategory(str, Enum):
Tool = "tool"
Model = "model"
Extension = "extension"
class PluginDeclaration(BaseModel):
class Plugins(BaseModel):
tools: Optional[list[str]] = Field(default_factory=list)
@ -65,6 +71,7 @@ class PluginDeclaration(BaseModel):
name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
icon: str
label: I18nObject
category: PluginCategory
created_at: datetime.datetime
resource: PluginResourceRequirements
plugins: Plugins
@ -72,6 +79,18 @@ class PluginDeclaration(BaseModel):
model: Optional[ProviderEntity] = None
endpoint: Optional[EndpointProviderDeclaration] = None
@model_validator(mode="before")
@classmethod
def validate_category(cls, values: dict) -> dict:
# auto detect category
if values.get("tool"):
values["category"] = PluginCategory.Tool
elif values.get("model"):
values["category"] = PluginCategory.Model
else:
values["category"] = PluginCategory.Extension
return values
class PluginEntity(BasePluginEntity):
name: str

View File

@ -128,3 +128,8 @@ class PluginInstallTask(BasePluginEntity):
total_plugins: int = Field(description="The total number of plugins to be installed.")
completed_plugins: int = Field(description="The number of plugins that have been installed.")
plugins: list[PluginInstallTaskPluginStatus] = Field(description="The status of the plugins.")
class PluginInstallTaskStartResponse(BasePluginEntity):
all_installed: bool = Field(description="Whether all plugins are installed.")
task_id: str = Field(description="The ID of the install task.")