Mirror of https://github.com/langgenius/dify.git (synced 2026-05-12 13:17:43 +08:00)

Compare commits: feat/dify-... → feat/rbac (107 commits)
Commit SHAs:

e2153a609c 9609f003c6 a3b00a2f83 5086dfcf74 b3572c646e 97c406f0b4 d878a29c43 e44252c242 3aa8abbe14 a409a0c3a1
d90825fd8a 3dcea78e10 d1ad01339b f7807c532d 212252bb78 35696c6b2e a4d12efbb6 248e14c323 7442e91ba7 1eb6446f5d
435c8ec96c 9fa1e69904 c386908370 74cc6af59c 84df9d2f01 03df3423fc bce61ff7ce 4f669f3278 217ab4d2c3 bfee04be8c
0e0a78c305 deb97c4149 ea47036a5d 5ad05cf98a 1a1d3b1137 d0955ac5bb 4c863d403c 49dea08005 b789f821ea 7e093c3807
0e157667ad e73f720505 e22b03797c 03644a73cc 28ee67755c 451cc7bcc8 1a4671bcf7 dffb26e2e4 e7f8a9fcfc 3526afcfd9
554e57c906 ffee3b45aa 4f42753bd1 0519bc00cf d02a36e68f aa2dbe959a ba5271dd76 a329ff8777 55bcca2bc4 6c4f3e8584
944f705889 d580b16303 cd26e90ae5 b9d98f6c54 093501ab54 5ed516d912 6382ffe823 98fddce3b9 1f2805b190 fd75223fa8
8ed13671c7 7fef8ff766 34c1caec48 457b4a5e48 df28c99817 775f9212f3 4481dd2ffa 325b84c1eb a3b938dd57 5081d52872
9616ac170b 625bd7ab63 da9d9cb28f a6905b25cf ddd546ef88 8cb56713c0 ecd3dbcb6e 9c6a123687 e1d0addf41 34f1ed0ab7
7b97789fd2 5f4b086e39 339e4c8a1f 12b93290fa 5907b3f809 b32ec8741e 2c6f195362 ba3a808b9f 63e3267993 c7d96badf4
32d75fe08c 6583bcb746 7cc3db663a 17e4fee6b2 2dc080c845 6a2bb145e3 73551495c5
@@ -9,7 +9,6 @@ The codebase is split into:

- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
- **Frontend Web** (`/web`): Next.js application using TypeScript and React
- **Docker deployment** (`/docker`): Containerized deployment configurations
- **Dify Agent Backend** (`/dify-agent`): Backend services for managing and executing agents

## Backend Workflow
@@ -23,6 +23,11 @@ class EnterpriseFeatureConfig(BaseSettings):
        ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
    )

    RBAC_ENABLED: bool = Field(
        description="Enable enterprise RBAC APIs. When disabled, compatibility responses fall back to legacy roles.",
        default=False,
    )


class EnterpriseTelemetryConfig(BaseSettings):
    """
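For orientation, a minimal sketch of how a pydantic-settings flag like this is flipped from the environment. The class here is a trimmed stand-in for the real `EnterpriseFeatureConfig`, keeping only the field this hunk adds:

```python
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class EnterpriseFeatureConfig(BaseSettings):
    # Stand-in: the real config class carries many more enterprise fields.
    RBAC_ENABLED: bool = Field(
        description="Enable enterprise RBAC APIs.",
        default=False,
    )


# pydantic-settings resolves fields from the environment by name,
# so e.g. `RBAC_ENABLED=true` in docker/.env turns the feature on.
os.environ["RBAC_ENABLED"] = "true"
assert EnterpriseFeatureConfig().RBAC_ENABLED is True
```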
@@ -132,6 +132,7 @@ from .workspace import (
    model_providers,
    models,
    plugin,
    rbac,
    tool_providers,
    trigger_providers,
    workspace,

@@ -199,6 +200,7 @@ __all__ = [
    "rag_pipeline_draft_variable",
    "rag_pipeline_import",
    "rag_pipeline_workflow",
    "rbac",
    "recommended_app",
    "saved_message",
    "setup",
@@ -27,6 +27,7 @@ from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from configs import dify_config
from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.enums import WorkflowExecutionStatus

@@ -37,6 +38,7 @@ from models.model import IconType
from services.app_dsl_service import AppDslService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.enterprise import rbac_service as enterprise_rbac_service
from services.entities.dsl_entities import ImportMode, ImportStatus
from services.entities.knowledge_entities.knowledge_entities import (
    DataSource,

@@ -330,6 +332,7 @@ class AppPartial(ResponseModel):
    create_user_name: str | None = None
    author_name: str | None = None
    has_draft_trigger: bool | None = None
    permission_keys: list[str] = Field(default_factory=list)

    @computed_field(return_type=str | None)  # type: ignore
    @property

@@ -475,6 +478,20 @@ class AppListApi(Resource):
            if str(app.id) in res:
                app.access_mode = res[str(app.id)].access_mode

        if app_pagination.items:
            if dify_config.RBAC_ENABLED:
                app_ids = [str(app.id) for app in app_pagination.items]
                permission_keys_map = enterprise_rbac_service.RBACService.AppPermissions.batch_get(
                    str(current_tenant_id),
                    current_user.id,
                    app_ids,
                )
                for app in app_pagination.items:
                    app.permission_keys = permission_keys_map.get(str(app.id), [])
            else:
                for app in app_pagination.items:
                    app.permission_keys = []

        workflow_capable_app_ids = [
            str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
        ]
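The hunk above is a plain batch-fetch-then-annotate pattern: collect ids, make one RBAC round trip, then attach per-item keys. A self-contained sketch, with `SimpleNamespace` rows standing in for app records and a lambda standing in for `RBACService.AppPermissions.batch_get`:

```python
from types import SimpleNamespace


def attach_permission_keys(items, rbac_enabled, batch_get):
    """Annotate each item with its permission keys; empty lists when RBAC is off."""
    if not items:
        return
    if rbac_enabled:
        keys_map = batch_get([str(item.id) for item in items])
        for item in items:
            item.permission_keys = keys_map.get(str(item.id), [])
    else:
        for item in items:
            item.permission_keys = []


apps = [SimpleNamespace(id="app-1"), SimpleNamespace(id="app-2")]
attach_permission_keys(apps, True, lambda ids: {"app-1": ["app.acl.edit"]})
assert apps[0].permission_keys == ["app.acl.edit"]
assert apps[1].permission_keys == []  # ids missing from the map fall back to []
```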
@@ -57,6 +57,7 @@ from models.enums import ApiTokenType, SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.enterprise import rbac_service as enterprise_rbac_service

# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)

@@ -127,6 +128,14 @@ def _validate_doc_form(value: str | None) -> str | None:
    return value


def _ensure_permission_keys(dataset: Dataset, *, enabled: bool) -> None:
    if not enabled:
        setattr(dataset, "permission_keys", [])
        return
    if not isinstance(getattr(dataset, "permission_keys", None), list):
        setattr(dataset, "permission_keys", [])


class DatasetCreatePayload(BaseModel):
    name: str = Field(..., min_length=1, max_length=40)
    description: str = Field("", max_length=400)

@@ -329,6 +338,19 @@ class DatasetListApi(Resource):
            query.include_all,
        )

        for dataset in datasets:
            _ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)

        if dify_config.RBAC_ENABLED and datasets:
            dataset_ids = [str(dataset.id) for dataset in datasets]
            permission_keys_map = enterprise_rbac_service.RBACService.DatasetPermissions.batch_get(
                str(current_tenant_id),
                current_user.id,
                dataset_ids,
            )
            for dataset in datasets:
                setattr(dataset, "permission_keys", permission_keys_map.get(str(dataset.id), []))

        # check embedding setting
        provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
        configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)

@@ -410,6 +432,7 @@ class DatasetListApi(Resource):
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()

        _ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
        return marshal(dataset, dataset_detail_fields), 201


@@ -434,6 +457,7 @@ class DatasetApi(Resource):
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        _ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
        data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
            if dataset.embedding_model_provider:

@@ -503,6 +527,7 @@ class DatasetApi(Resource):
        if dataset is None:
            raise NotFound("Dataset not found.")

        _ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
        result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
        tenant_id = current_tenant_id
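`_ensure_permission_keys` does two things: when RBAC is disabled it always forces an empty list, and when enabled it only backfills a missing or non-list attribute, so a real value written by `batch_get` is never clobbered. A runnable illustration, with the helper body copied from the hunk and a `SimpleNamespace` standing in for a `Dataset` row:

```python
from types import SimpleNamespace


def _ensure_permission_keys(dataset, *, enabled):
    if not enabled:
        setattr(dataset, "permission_keys", [])
        return
    if not isinstance(getattr(dataset, "permission_keys", None), list):
        setattr(dataset, "permission_keys", [])


ds = SimpleNamespace()
_ensure_permission_keys(ds, enabled=True)   # missing attribute -> backfilled
assert ds.permission_keys == []

ds.permission_keys = ["dataset.acl.edit"]
_ensure_permission_keys(ds, enabled=True)   # valid list -> left alone
assert ds.permission_keys == ["dataset.acl.edit"]

_ensure_permission_keys(ds, enabled=False)  # disabled -> always reset
assert ds.permission_keys == []
```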
@@ -30,6 +30,7 @@ from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.enterprise import rbac_service as enterprise_rbac_service
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService

@@ -72,6 +73,19 @@ register_enum_models(console_ns, TenantAccountRole)
register_schema_models(console_ns, AccountWithRole, AccountWithRoleList)


def _serialize_member_roles(
    current_role: str | None, member_roles: list[enterprise_rbac_service.MemberRoleSummary]
) -> list[dict[str, str]]:
    if member_roles:
        return [{"id": role.id, "name": role.name} for role in member_roles]
    if current_role:
        return [{"id": current_role, "name": current_role}]
    return []


def _normalize_enum_value(value: object) -> str:
    normalized = getattr(value, "value", value)
    return str(normalized) if normalized is not None else ""


@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource):
    """List all members of current tenant."""

@@ -85,7 +99,36 @@ class MemberListApi(Resource):
        if not current_user.current_tenant:
            raise ValueError("No current tenant")
        members = TenantService.get_tenant_members(current_user.current_tenant)
        member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
        if dify_config.RBAC_ENABLED:
            member_ids = [member.id for member in members]
            member_roles = enterprise_rbac_service.RBACService.MemberRoles.batch_get(
                str(current_user.current_tenant.id),
                current_user.id,
                member_ids,
            )
            roles_map = {item.account_id: item.roles for item in member_roles}
        else:
            roles_map = {}

        serialized_members = []
        for member in members:
            current_role = _normalize_enum_value(member.current_role)
            serialized_members.append(
                {
                    "id": member.id,
                    "name": member.name,
                    "email": member.email,
                    "avatar": member.avatar,
                    "last_login_at": member.last_login_at,
                    "last_active_at": member.last_active_at,
                    "created_at": member.created_at,
                    "role": current_role,
                    "roles": _serialize_member_roles(current_role, roles_map.get(member.id, [])),
                    "status": _normalize_enum_value(member.status),
                }
            )

        member_models = TypeAdapter(list[AccountWithRole]).validate_python(serialized_members)
        response = AccountWithRoleList(accounts=member_models)
        return response.model_dump(mode="json"), 200
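`_serialize_member_roles` encodes the compatibility rule for the new `roles` field: prefer the RBAC role list, fall back to a single pseudo-role built from the legacy role string, else an empty list. A quick check of all three branches (helper body copied from the hunk, `SimpleNamespace` standing in for `MemberRoleSummary` entries):

```python
from types import SimpleNamespace


def _serialize_member_roles(current_role, member_roles):
    if member_roles:
        return [{"id": role.id, "name": role.name} for role in member_roles]
    if current_role:
        return [{"id": current_role, "name": current_role}]
    return []


rbac_roles = [SimpleNamespace(id="workspace.editor", name="Editor")]
assert _serialize_member_roles("admin", rbac_roles) == [{"id": "workspace.editor", "name": "Editor"}]
assert _serialize_member_roles("admin", []) == [{"id": "admin", "name": "admin"}]
assert _serialize_member_roles(None, []) == []
```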
api/controllers/console/workspace/rbac.py (new file, 606 lines)
@@ -0,0 +1,606 @@
from __future__ import annotations

from collections.abc import Callable
from functools import wraps
from typing import Any

from flask import request
from flask_restx import Resource
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, ValidationError, field_validator
from werkzeug.exceptions import Forbidden, NotFound

from configs import dify_config
from controllers.console import console_ns
from libs.login import current_account_with_tenant, login_required
from services.enterprise import rbac_service as svc


_LEGACY_WORKSPACE_PERMISSION_KEYS: list[str] = [
    # These keys are copied from the enterprise RBAC catalog examples in
    # `dify-rbac.md` so the legacy workspace roles stay in the same key format
    # as the enterprise RBAC surface.
    "workspace.member.manage",
    "workspace.role.manage",
]

_LEGACY_APP_PERMISSION_KEYS: list[str] = [
    "app.acl.view_layout",
    "app.acl.test_and_run",
    "app.acl.edit",
    "app.acl.access_config",
]

_LEGACY_DATASET_PERMISSION_KEYS: list[str] = [
    "dataset.acl.readonly",
    "dataset.acl.edit",
    "dataset.acl.use",
]

_LEGACY_ROLE_PERMISSION_KEYS: dict[str, list[str]] = {
    # These legacy role groups predate the RBAC refactor. The mapping keeps the
    # old workspace roles readable through the new RBAC endpoint by translating
    # each role into the closest enterprise permission keys that already exist
    # in the catalog and tests.
    "owner": [
        *_LEGACY_WORKSPACE_PERMISSION_KEYS,
        *_LEGACY_APP_PERMISSION_KEYS,
        *_LEGACY_DATASET_PERMISSION_KEYS,
    ],
    "admin": [
        *_LEGACY_WORKSPACE_PERMISSION_KEYS,
        *_LEGACY_APP_PERMISSION_KEYS,
        *_LEGACY_DATASET_PERMISSION_KEYS,
    ],
    "editor": [
        *_LEGACY_APP_PERMISSION_KEYS,
        *_LEGACY_DATASET_PERMISSION_KEYS,
    ],
    "normal": [
        "app.acl.view_layout",
        "app.acl.test_and_run",
    ],
    "dataset_operator": [
        *_LEGACY_DATASET_PERMISSION_KEYS,
    ],
}


def _current_ids() -> tuple[str, str]:
    """Return ``(tenant_id, account_id)`` for the authenticated user, or
    raise a 404 when no tenant is associated with the session.
    """

    user, tenant_id = current_account_with_tenant()
    if not tenant_id:
        raise NotFound("Current workspace not found")
    return tenant_id, user.id


def _payload(model: type[BaseModel]) -> Any:
    """Validate the JSON body against ``model`` or raise ``ValidationError``.

    ``ValidationError`` bubbles up as HTTP 400 thanks to
    ``controllers/common/helpers.py`` error handling.
    """
    try:
        return model.model_validate(console_ns.payload or {})
    except ValidationError as exc:
        # Re-raise as-is so the upstream error handler renders a 400.
        raise exc


def _dump(model: BaseModel) -> dict[str, Any]:
    return model.model_dump(mode="json")


class _PaginationQuery(BaseModel):
    model_config = ConfigDict(extra="ignore")

    page_number: int | None = Field(default=None, ge=1, validation_alias=AliasChoices("page", "page_number"))
    results_per_page: int | None = Field(
        default=None, ge=1, le=100, validation_alias=AliasChoices("limit", "results_per_page")
    )
    reverse: bool | None = None

    def to_inner_options(self) -> svc.ListOption:
        return svc.ListOption.model_validate(self.model_dump())


def _pagination_options() -> svc.ListOption:
    return _PaginationQuery.model_validate(request.args.to_dict(flat=True)).to_inner_options()
def _legacy_workspace_roles(options: svc.ListOption | None = None) -> svc.Paginated[svc.RBACRole]:
    """Return the built-in legacy workspace roles in the RBAC list shape.

    This keeps the new `/rbac/roles` endpoint compatible with the original
    Dify role model when enterprise RBAC is disabled.
    """

    legacy_roles = [
        svc.RBACRole(
            id=role_name,
            tenant_id="",
            type=svc.RBACRoleType.WORKSPACE.value,
            category="global_system_default",
            name=role_name,
            description="",
            is_builtin=True,
            permission_keys=list(_LEGACY_ROLE_PERMISSION_KEYS[role_name]),
        )
        for role_name in ("owner", "admin", "editor", "normal", "dataset_operator")
    ]

    page_number = options.page_number if options and options.page_number is not None else 1
    results_per_page = (
        options.results_per_page if options and options.results_per_page is not None else len(legacy_roles)
    )
    reverse = options.reverse if options and options.reverse is not None else False

    ordered_roles = list(reversed(legacy_roles)) if reverse else legacy_roles
    start = max(page_number - 1, 0) * results_per_page
    end = start + results_per_page
    paged_roles = ordered_roles[start:end]
    total_count = len(legacy_roles)
    total_pages = (total_count + results_per_page - 1) // results_per_page if results_per_page > 0 else 0

    return svc.Paginated[svc.RBACRole](
        data=paged_roles,
        pagination=svc.Pagination(
            total_count=total_count,
            per_page=results_per_page,
            current_page=page_number,
            total_pages=total_pages,
        ),
    )
# ---------------------------------------------------------------------------
# Permission catalogs.
# ---------------------------------------------------------------------------


@console_ns.route("/workspaces/current/rbac/role-permissions/catalog")
class RBACWorkspaceCatalogApi(Resource):
    @login_required
    def get(self):
        tenant_id, account_id = _current_ids()
        return _dump(svc.RBACService.Catalog.workspace(tenant_id, account_id))


@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/app")
class RBACAppCatalogApi(Resource):
    @login_required
    def get(self):
        tenant_id, account_id = _current_ids()
        return _dump(svc.RBACService.Catalog.app(tenant_id, account_id))


@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/dataset")
class RBACDatasetCatalogApi(Resource):
    @login_required
    def get(self):
        tenant_id, account_id = _current_ids()
        return _dump(svc.RBACService.Catalog.dataset(tenant_id, account_id))


# ---------------------------------------------------------------------------
# Roles.
# ---------------------------------------------------------------------------


class _RoleUpsertRequest(BaseModel):
    """Accepts the payload sent by the Create/Edit Role dialog."""

    name: str
    description: str = ""
    permission_keys: list[str] = []

    def to_mutation(self) -> svc.RoleMutation:
        return svc.RoleMutation(
            name=self.name,
            description=self.description,
            permission_keys=list(self.permission_keys),
        )


@console_ns.route("/workspaces/current/rbac/roles")
class RBACRolesApi(Resource):
    @login_required
    def get(self):
        tenant_id, account_id = _current_ids()
        options = _pagination_options()
        if not dify_config.RBAC_ENABLED:
            return _dump(_legacy_workspace_roles(options))
        return _dump(svc.RBACService.Roles.list(tenant_id, account_id, options=options))

    @login_required
    def post(self):
        tenant_id, account_id = _current_ids()
        request = _payload(_RoleUpsertRequest)
        role = svc.RBACService.Roles.create(tenant_id, account_id, request.to_mutation())
        return _dump(role), 201


@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>")
class RBACRoleItemApi(Resource):
    @login_required
    def get(self, role_id):
        tenant_id, account_id = _current_ids()
        return _dump(svc.RBACService.Roles.get(tenant_id, account_id, str(role_id)))

    @login_required
    def put(self, role_id):
        tenant_id, account_id = _current_ids()
        request = _payload(_RoleUpsertRequest)
        role = svc.RBACService.Roles.update(tenant_id, account_id, str(role_id), request.to_mutation())
        return _dump(role)

    @login_required
    def delete(self, role_id):
        tenant_id, account_id = _current_ids()
        svc.RBACService.Roles.delete(tenant_id, account_id, str(role_id))
        return {"result": "success"}


@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>/copy")
class RBACRoleCopyApi(Resource):
    @login_required
    def post(self, role_id):
        tenant_id, account_id = _current_ids()
        role = svc.RBACService.Roles.copy(tenant_id, account_id, str(role_id))
        return _dump(role), 201


# ---------------------------------------------------------------------------
# Access policies (tenant-level permission sets).
# ---------------------------------------------------------------------------


class _AccessPolicyCreateRequest(BaseModel):
    name: str
    resource_type: svc.RBACResourceType
    description: str = ""
    permission_keys: list[str] = []


class _AccessPolicyUpdateRequest(BaseModel):
    name: str
    description: str = ""
    permission_keys: list[str] = []


@console_ns.route("/workspaces/current/rbac/access-policies")
class RBACAccessPoliciesApi(Resource):
    @login_required
    def get(self):
        tenant_id, account_id = _current_ids()
        # `resource_type` is exposed as a query argument so the UI can show
        # only app-scoped or only dataset-scoped permission sets.
        resource_type = request.args.get("resource_type") or None
        return _dump(
            svc.RBACService.AccessPolicies.list(
                tenant_id,
                account_id,
                resource_type=resource_type,
                options=_pagination_options(),
            )
        )

    @login_required
    def post(self):
        tenant_id, account_id = _current_ids()
        request = _payload(_AccessPolicyCreateRequest)
        policy = svc.RBACService.AccessPolicies.create(
            tenant_id,
            account_id,
            svc.AccessPolicyCreate(
                name=request.name,
                resource_type=request.resource_type,
                description=request.description,
                permission_keys=list(request.permission_keys),
            ),
        )
        return _dump(policy), 201


@console_ns.route("/workspaces/current/rbac/access-policies/<uuid:policy_id>")
class RBACAccessPolicyItemApi(Resource):
    @login_required
    def get(self, policy_id):
        tenant_id, account_id = _current_ids()
        return _dump(svc.RBACService.AccessPolicies.get(tenant_id, account_id, str(policy_id)))

    @login_required
    def put(self, policy_id):
        tenant_id, account_id = _current_ids()
        request = _payload(_AccessPolicyUpdateRequest)
        policy = svc.RBACService.AccessPolicies.update(
            tenant_id,
            account_id,
            str(policy_id),
            svc.AccessPolicyUpdate(
                name=request.name,
                description=request.description,
                permission_keys=list(request.permission_keys),
            ),
        )
        return _dump(policy)

    @login_required
    def delete(self, policy_id):
        tenant_id, account_id = _current_ids()
        svc.RBACService.AccessPolicies.delete(tenant_id, account_id, str(policy_id))
        return {"result": "success"}


@console_ns.route("/workspaces/current/rbac/access-policies/<uuid:policy_id>/copy")
class RBACAccessPolicyCopyApi(Resource):
    @login_required
    def post(self, policy_id):
        tenant_id, account_id = _current_ids()
        policy = svc.RBACService.AccessPolicies.copy(tenant_id, account_id, str(policy_id))
        return _dump(policy), 201


# ---------------------------------------------------------------------------
# Per-app access (App Access Config).
# ---------------------------------------------------------------------------


class _ReplaceBindingsRequest(BaseModel):
    role_ids: list[str] = []
    account_ids: list[str] = []

    @field_validator("role_ids", "account_ids", mode="before")
    @classmethod
    def _coerce_bindings(cls, value: Any) -> list[str]:
        if value is None:
            return []
        return value


@console_ns.route("/workspaces/current/rbac/my-permissions")
class RBACMyPermissionsApi(Resource):
    @login_required
    def get(self):
        tenant_id, account_id = _current_ids()
        return _dump(
            svc.RBACService.MyPermissions.get(
                tenant_id,
                account_id,
                app_id=request.args.get("app_id") or None,
                dataset_id=request.args.get("dataset_id") or None,
            )
        )


@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policy")
class RBACAppMatrixApi(Resource):
    @login_required
    def get(self, app_id):
        tenant_id, account_id = _current_ids()
        return _dump(svc.RBACService.AppAccess.matrix(tenant_id, account_id, str(app_id)))


@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/role-bindings")
class RBACAppRoleBindingsApi(Resource):
    @login_required
    def get(self, app_id, policy_id):
        tenant_id, account_id = _current_ids()
        return _dump(
            svc.RBACService.AppAccess.list_role_bindings(tenant_id, account_id, str(app_id), str(policy_id))
        )


@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/member-bindings")
class RBACAppMemberBindingsApi(Resource):
    @login_required
    def get(self, app_id, policy_id):
        tenant_id, account_id = _current_ids()
        return _dump(
            svc.RBACService.AppAccess.list_member_bindings(tenant_id, account_id, str(app_id), str(policy_id))
        )


@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/bindings")
class RBACAppBindingsApi(Resource):
    @login_required
    def put(self, app_id, policy_id):
        tenant_id, account_id = _current_ids()
        request = _payload(_ReplaceBindingsRequest)
        return _dump(
            svc.RBACService.AppAccess.replace_bindings(
                tenant_id,
                account_id,
                str(app_id),
                str(policy_id),
                svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
            )
        )


# ---------------------------------------------------------------------------
# Per-dataset access (Knowledge Base Access Config).
# ---------------------------------------------------------------------------


@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policy")
class RBACDatasetMatrixApi(Resource):
    @login_required
    def get(self, dataset_id):
        tenant_id, account_id = _current_ids()
        return _dump(svc.RBACService.DatasetAccess.matrix(tenant_id, account_id, str(dataset_id)))


@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/role-bindings")
class RBACDatasetRoleBindingsApi(Resource):
    @login_required
    def get(self, dataset_id, policy_id):
        tenant_id, account_id = _current_ids()
        return _dump(
            svc.RBACService.DatasetAccess.list_role_bindings(
                tenant_id, account_id, str(dataset_id), str(policy_id)
            )
        )


@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/bindings")
class RBACDatasetBindingsApi(Resource):
    @login_required
    def put(self, dataset_id, policy_id):
        tenant_id, account_id = _current_ids()
        request = _payload(_ReplaceBindingsRequest)
        return _dump(
            svc.RBACService.DatasetAccess.replace_bindings(
                tenant_id,
                account_id,
                str(dataset_id),
                str(policy_id),
                svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
            )
        )


@console_ns.route(
    "/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/member-bindings"
)
class RBACDatasetMemberBindingsApi(Resource):
    @login_required
    def get(self, dataset_id, policy_id):
        tenant_id, account_id = _current_ids()
        return _dump(
            svc.RBACService.DatasetAccess.list_member_bindings(
                tenant_id, account_id, str(dataset_id), str(policy_id)
            )
        )


# ---------------------------------------------------------------------------
# Workspace-level access (Settings > Access Rules).
# ---------------------------------------------------------------------------


@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policy")
class RBACWorkspaceAppMatrixApi(Resource):
    @login_required
    def get(self):
        tenant_id, account_id = _current_ids()
        options = _pagination_options()
        return _dump(svc.RBACService.WorkspaceAccess.app_matrix(tenant_id, account_id, options=options))


@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/role-bindings")
class RBACWorkspaceAppRoleBindingsApi(Resource):
    @login_required
    def get(self, policy_id):
        tenant_id, account_id = _current_ids()
        return _dump(
            svc.RBACService.WorkspaceAccess.list_app_role_bindings(tenant_id, account_id, str(policy_id))
        )


@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/bindings")
class RBACWorkspaceAppBindingsApi(Resource):
    @login_required
    def put(self, policy_id):
        tenant_id, account_id = _current_ids()
        request = _payload(_ReplaceBindingsRequest)
        return _dump(
            svc.RBACService.WorkspaceAccess.replace_app_bindings(
                tenant_id,
                account_id,
                str(policy_id),
                svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
            )
        )


@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/member-bindings")
class RBACWorkspaceAppMemberBindingsApi(Resource):
    @login_required
    def get(self, policy_id):
        tenant_id, account_id = _current_ids()
        return _dump(
            svc.RBACService.WorkspaceAccess.list_app_member_bindings(tenant_id, account_id, str(policy_id))
        )


@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policy")
class RBACWorkspaceDatasetMatrixApi(Resource):
    @login_required
    def get(self):
        tenant_id, account_id = _current_ids()
        options = _pagination_options()
        return _dump(svc.RBACService.WorkspaceAccess.dataset_matrix(tenant_id, account_id, options=options))


@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/role-bindings")
class RBACWorkspaceDatasetRoleBindingsApi(Resource):
    @login_required
    def get(self, policy_id):
        tenant_id, account_id = _current_ids()
        return _dump(
            svc.RBACService.WorkspaceAccess.list_dataset_role_bindings(tenant_id, account_id, str(policy_id))
        )


@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/bindings")
class RBACWorkspaceDatasetBindingsApi(Resource):
    @login_required
    def put(self, policy_id):
        tenant_id, account_id = _current_ids()
        request = _payload(_ReplaceBindingsRequest)
        return _dump(
            svc.RBACService.WorkspaceAccess.replace_dataset_bindings(
                tenant_id,
                account_id,
                str(policy_id),
                svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
            )
        )


@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/member-bindings")
class RBACWorkspaceDatasetMemberBindingsApi(Resource):
    @login_required
    def get(self, policy_id):
        tenant_id, account_id = _current_ids()
        return _dump(
            svc.RBACService.WorkspaceAccess.list_dataset_member_bindings(tenant_id, account_id, str(policy_id))
        )


# ---------------------------------------------------------------------------
# Member ↔ role bindings (Settings > Members > Assign roles).
# ---------------------------------------------------------------------------


class _ReplaceMemberRolesRequest(BaseModel):
    role_ids: list[str] = []

    @field_validator("role_ids", mode="before")
    @classmethod
    def _coerce_role_ids(cls, value: Any) -> list[str]:
        if value is None:
            return []
        return value


@console_ns.route("/workspaces/current/rbac/members/<uuid:member_id>/rbac-roles")
class RBACMemberRolesApi(Resource):
    @login_required
    def get(self, member_id):
        tenant_id, account_id = _current_ids()
        return _dump(svc.RBACService.MemberRoles.get(tenant_id, account_id, str(member_id)))

    @login_required
    def put(self, member_id):
        tenant_id, account_id = _current_ids()
        request = _payload(_ReplaceMemberRolesRequest)
        return _dump(
            svc.RBACService.MemberRoles.replace(
                tenant_id,
                account_id,
                str(member_id),
                role_ids=list(request.role_ids),
            )
        )
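One subtle piece of the new controller is the hand-rolled pagination in `_legacy_workspace_roles`, which slices the five built-in roles and computes a ceiling-division page count. The arithmetic in isolation (pure Python, no Dify imports):

```python
legacy_roles = ["owner", "admin", "editor", "normal", "dataset_operator"]

page_number, results_per_page = 1, 2
start = max(page_number - 1, 0) * results_per_page  # 0
end = start + results_per_page                      # 2
assert legacy_roles[start:end] == ["owner", "admin"]

# Ceiling division: 5 roles at 2 per page -> 3 pages.
total_pages = (len(legacy_roles) + results_per_page - 1) // results_per_page
assert total_pages == 3
```

This matches the expectations asserted in the controller test further down (`total_count: 5, per_page: 2, total_pages: 3`).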
@@ -80,6 +80,7 @@ app_detail_fields = {
    "updated_at": TimestampField,
    "access_mode": fields.String,
    "tags": fields.List(fields.Nested(tag_fields)),
    "permission_keys": fields.List(fields.String),
}

prompt_config_fields = {

@@ -117,6 +118,7 @@ app_partial_fields = {
    "create_user_name": fields.String,
    "author_name": fields.String,
    "has_draft_trigger": fields.Boolean,
    "permission_keys": fields.List(fields.String),
}


@@ -197,6 +199,7 @@ app_detail_fields_with_site = {
    "deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
    "access_mode": fields.String,
    "tags": fields.List(fields.Nested(tag_fields)),
    "permission_keys": fields.List(fields.String),
    "site": fields.Nested(site_fields),
}
@@ -11,6 +11,7 @@ dataset_fields = {
    "indexing_technique": fields.String,
    "created_by": fields.String,
    "created_at": TimestampField,
    "permission_keys": fields.List(fields.String),
}

reranking_model_fields = {"reranking_provider_name": fields.String, "reranking_model_name": fields.String}

@@ -107,6 +108,7 @@ dataset_detail_fields = {
    "total_available_documents": fields.Integer,
    "enable_api": fields.Boolean,
    "is_multimodal": fields.Boolean,
    "permission_keys": fields.List(fields.String),
}

file_info_fields = {
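These are ordinary flask_restx marshalling maps; `fields.List(fields.String)` simply serialises whatever `permission_keys` list the controllers attached earlier. A minimal sketch (trimmed field map, `SimpleNamespace` standing in for a dataset row):

```python
from types import SimpleNamespace

from flask_restx import fields, marshal

dataset_fields = {
    "name": fields.String,
    "permission_keys": fields.List(fields.String),
}

row = SimpleNamespace(name="Docs", permission_keys=["dataset.acl.readonly"])
assert marshal(row, dataset_fields) == {
    "name": "Docs",
    "permission_keys": ["dataset.acl.readonly"],
}
```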
@@ -3,7 +3,7 @@ from __future__ import annotations
from datetime import datetime

from flask_restx import fields
from pydantic import computed_field, field_validator
from pydantic import Field, computed_field, field_validator

from fields.base import ResponseModel
from graphon.file import helpers as file_helpers

@@ -70,6 +70,7 @@ class AccountWithRole(_AccountAvatar):
    last_active_at: int | None = None
    created_at: int | None = None
    role: str
    roles: list[dict[str, str]] = Field(default_factory=list)
    status: str

    @field_validator("last_login_at", "last_active_at", "created_at", mode="before")
@@ -11,6 +11,8 @@ from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated

from configs import dify_config

from .base import TypeBase
from .engine import db
from .types import EnumText, LongText, StringUUID

@@ -187,10 +189,14 @@ class Account(UserMixin, TypeBase):
    # check current_user.current_tenant.current_role in ['admin', 'owner']
    @property
    def is_admin_or_owner(self):
        if dify_config.RBAC_ENABLED:
            return True
        return TenantAccountRole.is_privileged_role(self.role)

    @property
    def is_admin(self):
        if dify_config.RBAC_ENABLED:
            return True
        return TenantAccountRole.is_admin_role(self.role)

    @property

@@ -216,14 +222,20 @@ class Account(UserMixin, TypeBase):
        - `ADMIN`
        - `EDITOR`
        """
        if dify_config.RBAC_ENABLED:
            return True
        return TenantAccountRole.is_editing_role(self.role)

    @property
    def is_dataset_editor(self):
        if dify_config.RBAC_ENABLED:
            return True
        return TenantAccountRole.is_dataset_edit_role(self.role)

    @property
    def is_dataset_operator(self):
        if dify_config.RBAC_ENABLED:
            return True
        return self.role == TenantAccountRole.DATASET_OPERATOR
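The model-layer change is a blanket short-circuit: with `RBAC_ENABLED` on, every legacy role property answers True and fine-grained authorization is deferred to the enterprise RBAC service. A stripped-down sketch of the gating pattern (a stand-in class, not the real `Account` model):

```python
class LegacyRoleFacade:
    """Stand-in showing the gating pattern applied to Account's properties."""

    def __init__(self, role: str, rbac_enabled: bool):
        self.role = role
        self.rbac_enabled = rbac_enabled

    @property
    def is_admin_or_owner(self) -> bool:
        if self.rbac_enabled:
            # Legacy check bypassed; RBAC policies decide downstream.
            return True
        return self.role in {"admin", "owner"}


assert LegacyRoleFacade("normal", rbac_enabled=True).is_admin_or_owner is True
assert LegacyRoleFacade("normal", rbac_enabled=False).is_admin_or_owner is False
```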
@@ -5,6 +5,7 @@ from typing import Any

import httpx

from configs import dify_config
from core.helper.trace_id_helper import generate_traceparent_header
from services.errors.enterprise import (
    EnterpriseAPIBadRequestError,

@@ -16,6 +17,11 @@ from services.errors.enterprise import (

logger = logging.getLogger(__name__)

# Headers recognised by dify-enterprise's /inner/api/rbac/* endpoints.
# Keep in sync with pkg/enterprise/service/rbac_inner_handlers.go.
INNER_TENANT_ID_HEADER = "X-Inner-Tenant-Id"
INNER_ACCOUNT_ID_HEADER = "X-Inner-Account-Id"


class BaseRequest:
    proxies: Mapping[str, str] | None = {

@@ -49,8 +55,16 @@ class BaseRequest:
        *,
        timeout: float | httpx.Timeout | None = None,
        raise_for_status: bool = False,
        extra_headers: Mapping[str, str] | None = None,
    ) -> Any:
        headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
        if extra_headers:
            # Explicitly ignore empty values so callers can pass optional
            # headers (e.g. `X-Inner-Account-Id`) without having to branch.
            for key, value in extra_headers.items():
                if value is None or value == "":
                    continue
                headers[key] = value
        url = f"{cls.base_url}{endpoint}"
        mounts = cls._build_mounts()

@@ -119,9 +133,56 @@ class BaseRequest:

class EnterpriseRequest(BaseRequest):
    base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
    rbac_base_url = os.environ.get("ENTERPRISE_RBAC_API_URL", base_url)
    secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
    secret_key_header = "Enterprise-Api-Secret-Key"

    @classmethod
    def send_inner_rbac_request(
        cls,
        method: str,
        endpoint: str,
        *,
        tenant_id: str,
        account_id: str | None = None,
        json: Any | None = None,
        params: Mapping[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
    ) -> Any:
        """Call an /inner/api/rbac/* endpoint on dify-enterprise.

        Inner RBAC endpoints require extra headers on top of the standard
        Enterprise-Api-Secret-Key: the tenant the call targets and (optionally)
        the account acting on behalf of the workspace. This helper centralises
        both the assertions and the header wiring so callers only have to
        supply the business payload.
        """
        if not tenant_id:
            raise ValueError("tenant_id must be provided for inner RBAC requests")

        inner_headers: dict[str, str] = {INNER_TENANT_ID_HEADER: tenant_id}
        if account_id:
            inner_headers[INNER_ACCOUNT_ID_HEADER] = account_id
        url = f"{cls.rbac_base_url}{endpoint}"
        mounts = cls._build_mounts()

        try:
            traceparent = generate_traceparent_header()
            if traceparent:
                inner_headers = dict(inner_headers)
                inner_headers["traceparent"] = traceparent
        except Exception:
            logger.debug("Failed to generate traceparent header", exc_info=True)

        with httpx.Client(mounts=mounts) as client:
            request_kwargs: dict[str, Any] = {
                "json": json,
                "params": params,
                "headers": {
                    "Content-Type": "application/json",
                    cls.secret_key_header: cls.secret_key,
                    **inner_headers,
                },
            }
            if timeout is not None:
                request_kwargs["timeout"] = timeout
            response = client.request(method, url, **request_kwargs)
            if not response.is_success:
                cls._handle_error_response(response)
            return response.json()


class EnterprisePluginManagerRequest(BaseRequest):
    base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL")
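The header contract is the crux of `send_inner_rbac_request`: the secret key and tenant id are always sent, the account id only when a call acts on behalf of a user. A small sketch that assembles the headers without making an HTTP call (constants copied from the hunk; the helper name is illustrative):

```python
INNER_TENANT_ID_HEADER = "X-Inner-Tenant-Id"
INNER_ACCOUNT_ID_HEADER = "X-Inner-Account-Id"


def build_inner_rbac_headers(secret_key, tenant_id, account_id=None):
    # Mirrors the method above: tenant is mandatory, account optional.
    if not tenant_id:
        raise ValueError("tenant_id must be provided for inner RBAC requests")
    headers = {
        "Content-Type": "application/json",
        "Enterprise-Api-Secret-Key": secret_key,
        INNER_TENANT_ID_HEADER: tenant_id,
    }
    if account_id:
        headers[INNER_ACCOUNT_ID_HEADER] = account_id
    return headers


assert INNER_ACCOUNT_ID_HEADER not in build_inner_rbac_headers("sk", "tenant-1")
assert build_inner_rbac_headers("sk", "tenant-1", "acct-1")[INNER_ACCOUNT_ID_HEADER] == "acct-1"
```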
api/services/enterprise/rbac_service.py (new file, 1156 lines)
File diff suppressed because it is too large.
@@ -180,6 +180,7 @@ class SystemFeatureModel(BaseModel):
    enable_creators_platform: bool = False
    enable_trial_app: bool = False
    enable_explore_banner: bool = False
    rbac_enabled: bool = False


class FeatureService:

@@ -229,6 +230,7 @@ class FeatureService:
    def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel:
        system_features = SystemFeatureModel()
        system_features.app_dsl_version = CURRENT_APP_DSL_VERSION
        system_features.rbac_enabled = dify_config.RBAC_ENABLED

        cls._fulfill_system_params_from_env(system_features)
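`rbac_enabled` rides along on the system-features payload, which is how the web client learns whether to render the RBAC screens. A minimal sketch of that flag copy (the model is trimmed to the one field this diff adds):

```python
from pydantic import BaseModel


class SystemFeatureModel(BaseModel):
    rbac_enabled: bool = False  # the real model carries many more flags


def get_system_features(rbac_enabled_config: bool) -> SystemFeatureModel:
    features = SystemFeatureModel()
    features.rbac_enabled = rbac_enabled_config
    return features


assert get_system_features(True).model_dump() == {"rbac_enabled": True}
```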
@@ -11,6 +11,8 @@ from typing import Any
import pytest
from flask.views import MethodView

from configs import dify_config

# kombu references MethodView as a global when importing celery/kombu pools.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]

@@ -196,6 +198,7 @@ def test_app_partial_serialization_uses_aliases(app_models):
        create_user_name="Creator",
        author_name="Author",
        has_draft_trigger=True,
        permission_keys=["app.acl.view_layout"],
    )

    serialized = AppPartial.model_validate(app_obj, from_attributes=True).model_dump(mode="json")

@@ -208,6 +211,7 @@ def test_app_partial_serialization_uses_aliases(app_models):
    assert serialized["model_config"]["model"] == {"provider": "openai", "name": "gpt-4o"}
    assert serialized["workflow"]["id"] == "wf-1"
    assert serialized["tags"][0]["name"] == "Utilities"
    assert serialized["permission_keys"] == ["app.acl.view_layout"]


def test_app_detail_with_site_includes_nested_serialization(app_models):

@@ -271,6 +275,7 @@ def test_app_pagination_aliases_per_page_and_has_next(app_models):
        icon="first-icon",
        created_at=_ts(15),
        updated_at=_ts(15),
        permission_keys=["app.acl.edit"],
    )
    item_two = SimpleNamespace(
        id="app-11",

@@ -298,3 +303,52 @@ def test_app_pagination_aliases_per_page_and_has_next(app_models):
    assert len(serialized["data"]) == 2
    assert serialized["data"][0]["icon_url"] == "signed:first-icon"
    assert serialized["data"][1]["icon_url"] is None
    assert serialized["data"][0]["permission_keys"] == ["app.acl.edit"]


def test_app_list_api_attaches_permission_keys(app, app_module):
    method = app_module.AppListApi.get
    while hasattr(method, "__wrapped__"):
        method = method.__wrapped__

    app_obj = SimpleNamespace(
        id="app-1",
        name="List App",
        desc_or_prompt="Summary",
        mode_compatible_with_agent="chat",
        mode="chat",
        created_at=_ts(15),
        updated_at=_ts(15),
        permission_keys=[],
    )
    pagination = SimpleNamespace(page=1, per_page=20, total=1, has_next=False, items=[app_obj])

    with app.test_request_context("/apps"):
        with pytest.MonkeyPatch.context() as monkeypatch:
            monkeypatch.setattr(dify_config, "RBAC_ENABLED", True)
            monkeypatch.setattr(
                app_module,
                "current_account_with_tenant",
                lambda: (SimpleNamespace(id="acct-1"), "tenant-1"),
            )
            monkeypatch.setattr(
                app_module.AppService,
                "get_paginate_apps",
                lambda self, user_id, tenant_id, args_dict: pagination,
            )
            monkeypatch.setattr(
                app_module.FeatureService,
                "get_system_features",
                lambda: SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)),
            )
            monkeypatch.setattr(
                app_module.enterprise_rbac_service.RBACService.AppPermissions,
                "batch_get",
                lambda tenant_id, account_id, app_ids: {"app-1": ["app.acl.view_layout", "app.acl.edit"]},
            )

            resp, status = method(app_module.AppListApi())

    assert status == 200
    assert app_obj.permission_keys == ["app.acl.view_layout", "app.acl.edit"]
    assert resp["data"][0]["permission_keys"] == ["app.acl.view_layout", "app.acl.edit"]
@@ -93,6 +93,48 @@ class TestDatasetList:
        assert resp["total"] == 1
        assert resp["data"][0]["embedding_available"] is True

    def test_get_with_rbac_enabled_fetches_permission_keys(self, app):
        api = DatasetListApi()
        method = unwrap(api.get)

        current_user = self._mock_user()
        current_user.id = "acct-1"
        dataset = MagicMock(id="ds-1")
        datasets = [dataset]
        marshaled = [self._mock_dataset_dict()]

        with app.test_request_context("/datasets"):
            with (
                patch(
                    "controllers.console.datasets.datasets.current_account_with_tenant",
                    return_value=(current_user, "tenant-1"),
                ),
                patch("controllers.console.datasets.datasets.dify_config.RBAC_ENABLED", True),
                patch.object(
                    DatasetService,
                    "get_datasets",
                    return_value=(datasets, 1),
                ),
                patch(
                    "controllers.console.datasets.datasets.enterprise_rbac_service.RBACService.DatasetPermissions.batch_get",
                    return_value={"ds-1": ["dataset.acl.readonly", "dataset.acl.edit"]},
                ) as mock_batch_get,
                patch(
                    "controllers.console.datasets.datasets.marshal",
                    return_value=marshaled,
                ),
                patch.object(
                    ProviderManager,
                    "get_configurations",
                    return_value=MagicMock(get_models=lambda **_: []),
                ),
            ):
                resp, status = method(api)

        assert status == 200
        assert dataset.permission_keys == ["dataset.acl.readonly", "dataset.acl.edit"]
        mock_batch_get.assert_called_once_with("tenant-1", "acct-1", ["ds-1"])

    def test_get_with_ids_filter(self, app):
        api = DatasetListApi()
        method = unwrap(api.get)
@@ -1,3 +1,4 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

@@ -45,8 +46,8 @@ class TestMemberListApi:
        member.name = "Member"
        member.email = "member@test.com"
        member.avatar = "avatar.png"
        member.role = "admin"
        member.status = "active"
        member.current_role = SimpleNamespace(value="admin")
        member.status = SimpleNamespace(value="active")
        members = [member]

        with (

@@ -58,6 +59,53 @@ class TestMemberListApi:

        assert status == 200
        assert len(result["accounts"]) == 1
        assert result["accounts"][0]["role"] == "admin"
        assert result["accounts"][0]["roles"] == [{"id": "admin", "name": "admin"}]

    def test_get_with_rbac_enabled_fetches_roles_in_batch(self, app):
        api = MemberListApi()
        method = unwrap(api.get)

        tenant = MagicMock(id="tenant-1")
        user = MagicMock(id="acct-1", current_tenant=tenant)
        member = SimpleNamespace(
            id="m1",
            name="Member",
            email="member@test.com",
            avatar=None,
            last_login_at=1,
            last_active_at=2,
            created_at=3,
            current_role=SimpleNamespace(value="editor"),
            status=SimpleNamespace(value="active"),
        )
        role_item = SimpleNamespace(
            account_id="m1",
            roles=[
                SimpleNamespace(id="workspace.owner", name="Owner"),
                SimpleNamespace(id="workspace.editor", name="Editor"),
            ],
        )

        with (
            app.test_request_context("/"),
            patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "tenant-1")),
            patch("controllers.console.workspace.members.dify_config.RBAC_ENABLED", True),
            patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=[member]),
            patch(
                "controllers.console.workspace.members.enterprise_rbac_service.RBACService.MemberRoles.batch_get",
                return_value=[role_item],
            ) as mock_batch_get,
        ):
            result, status = method(api)

        assert status == 200
        assert result["accounts"][0]["role"] == "editor"
        assert result["accounts"][0]["roles"] == [
            {"id": "workspace.owner", "name": "Owner"},
            {"id": "workspace.editor", "name": "Editor"},
        ]
        mock_batch_get.assert_called_once_with("tenant-1", "acct-1", ["m1"])

    def test_get_no_tenant(self, app):
        api = MemberListApi()
api/tests/unit_tests/controllers/console/workspace/test_rbac.py (new file, 270 lines)
@@ -0,0 +1,270 @@
"""Controller tests for ``controllers.console.workspace.rbac``.

The controllers here are thin: almost every non-trivial behaviour lives in
``services.enterprise.rbac_service`` (covered by its own suite). These tests
therefore focus on the Flask-layer concerns the service layer cannot exercise:

* ``_current_ids`` raises 404 when the session has no tenant.
* The pydantic request models accept / reject bodies as expected.

We explicitly avoid "happy-path" integration tests through the full
decorator stack — those belong in e2e tests where a real Dify session is
available — to keep this suite fast and resilient to ancillary auth wiring
changes.
"""

from __future__ import annotations

import inspect
from types import SimpleNamespace
from unittest.mock import patch

import pytest
from flask import Flask
from pydantic import ValidationError
from werkzeug.exceptions import NotFound

from controllers.console.workspace import rbac as rbac_mod


@pytest.fixture
def app():
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    return flask_app


def _enabled(enabled: bool):
    return patch("controllers.console.workspace.rbac.dify_config.ENTERPRISE_ENABLED", enabled)


class TestCurrentIds:
    def test_rejects_missing_tenant(self):
        with patch("controllers.console.workspace.rbac.current_account_with_tenant") as mock_user:
            mock_user.return_value = (SimpleNamespace(id="acct-1"), None)
            with pytest.raises(NotFound):
                rbac_mod._current_ids()

    def test_returns_tuple(self):
        with patch("controllers.console.workspace.rbac.current_account_with_tenant") as mock_user:
            mock_user.return_value = (SimpleNamespace(id="acct-1"), "tenant-1")
            assert rbac_mod._current_ids() == ("tenant-1", "acct-1")


class TestPydanticModels:
    """The internal `_…Request` models are the contract between the browser
    and the controllers. We only check non-obvious branches (enum parsing,
    missing required fields) — trivial `str` fields are not worth asserting.
    """

    def test_role_upsert_requires_name(self):
        with pytest.raises(ValidationError):
            rbac_mod._RoleUpsertRequest.model_validate({})

    def test_role_upsert_to_mutation_preserves_fields(self):
        payload = rbac_mod._RoleUpsertRequest.model_validate(
            {
                "name": "Owner",
                "description": "full access",
                "permission_keys": ["workspace.member.manage"],
            }
        )
        mutation = payload.to_mutation()
        assert mutation.description == "full access"
        assert mutation.permission_keys == ["workspace.member.manage"]

    def test_access_policy_create_parses_resource_type_enum(self):
        parsed = rbac_mod._AccessPolicyCreateRequest.model_validate(
            {
                "name": "Full access",
                "resource_type": "app",
                "description": "",
                "permission_keys": [],
            }
        )
        assert parsed.resource_type is rbac_mod.svc.RBACResourceType.APP

    def test_access_policy_create_rejects_unknown_resource_type(self):
        with pytest.raises(ValidationError):
            rbac_mod._AccessPolicyCreateRequest.model_validate({"name": "bad", "resource_type": "unknown"})

    def test_replace_bindings_defaults_empty(self):
        parsed = rbac_mod._ReplaceBindingsRequest.model_validate({})
        assert parsed.role_ids == []
        assert parsed.account_ids == []

    def test_replace_bindings_coerce_null_lists(self):
        parsed = rbac_mod._ReplaceBindingsRequest.model_validate({"role_ids": None, "account_ids": None})
        assert parsed.role_ids == []
        assert parsed.account_ids == []

    def test_replace_member_roles_coerce_null_list(self):
        parsed = rbac_mod._ReplaceMemberRolesRequest.model_validate({"role_ids": None})
        assert parsed.role_ids == []

    def test_pagination_query_accepts_page_and_limit_aliases(self):
        parsed = rbac_mod._PaginationQuery.model_validate({"page": 3, "limit": 25, "reverse": True})
        assert parsed.page_number == 3
        assert parsed.results_per_page == 25
        assert parsed.reverse is True

    def test_pagination_query_accepts_legacy_inner_names(self):
        parsed = rbac_mod._PaginationQuery.model_validate(
            {"page_number": 4, "results_per_page": 30, "reverse": False}
        )
        assert parsed.page_number == 4
        assert parsed.results_per_page == 30
        assert parsed.reverse is False


class TestPaginationMapping:
    def test_roles_get_returns_legacy_compatible_roles_when_rbac_disabled(self, app):
        with (
            app.test_request_context("/workspaces/current/rbac/roles?page=1&limit=2"),
            patch("controllers.console.workspace.rbac.dify_config.RBAC_ENABLED", False),
            patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")),
            patch("controllers.console.workspace.rbac.svc.RBACService.Roles.list") as mock_list,
        ):
            response = inspect.unwrap(rbac_mod.RBACRolesApi.get)(rbac_mod.RBACRolesApi())

        assert response["data"] == [
            {
                "id": "owner",
                "tenant_id": "",
                "type": "workspace",
                "category": "global_system_default",
                "name": "owner",
                "description": "",
                "is_builtin": True,
                "permission_keys": [
                    "workspace.member.manage",
                    "workspace.role.manage",
                    "app.acl.view_layout",
                    "app.acl.test_and_run",
                    "app.acl.edit",
                    "app.acl.access_config",
                    "dataset.acl.readonly",
                    "dataset.acl.edit",
                    "dataset.acl.use",
                ],
            },
            {
                "id": "admin",
                "tenant_id": "",
                "type": "workspace",
                "category": "global_system_default",
                "name": "admin",
                "description": "",
                "is_builtin": True,
                "permission_keys": [
                    "workspace.member.manage",
                    "workspace.role.manage",
                    "app.acl.view_layout",
                    "app.acl.test_and_run",
                    "app.acl.edit",
                    "app.acl.access_config",
                    "dataset.acl.readonly",
                    "dataset.acl.edit",
                    "dataset.acl.use",
                ],
            },
        ]
        assert response["pagination"] == {
            "total_count": 5,
            "per_page": 2,
            "current_page": 1,
            "total_pages": 3,
        }
        mock_list.assert_not_called()

    def test_roles_get_forwards_outer_pagination_params(self, app):
        with (
            app.test_request_context("/workspaces/current/rbac/roles?page=2&limit=50&reverse=true"),
            patch("controllers.console.workspace.rbac.dify_config.RBAC_ENABLED", True),
            patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")),
            patch("controllers.console.workspace.rbac.svc.RBACService.Roles.list") as mock_list,
            patch("controllers.console.workspace.rbac._dump", return_value={}),
        ):
            inspect.unwrap(rbac_mod.RBACRolesApi.get)(rbac_mod.RBACRolesApi())

        _, kwargs = mock_list.call_args
        options = kwargs["options"]
        assert options.page_number == 2
        assert options.results_per_page == 50
        assert options.reverse is True

    def test_access_policies_get_forwards_outer_pagination_params(self, app):
        with (
            app.test_request_context(
                "/workspaces/current/rbac/access-policies?resource_type=app&page=3&limit=25&reverse=false"
            ),
            _enabled(True),
            patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")),
            patch("controllers.console.workspace.rbac.svc.RBACService.AccessPolicies.list") as mock_list,
            patch("controllers.console.workspace.rbac._dump", return_value={}),
        ):
            inspect.unwrap(rbac_mod.RBACAccessPoliciesApi.get)(rbac_mod.RBACAccessPoliciesApi())

        _, kwargs = mock_list.call_args
        assert kwargs["resource_type"] == "app"
        options = kwargs["options"]
        assert options.page_number == 3
        assert options.results_per_page == 25
        assert options.reverse is False

    def test_workspace_app_matrix_forwards_outer_pagination_params(self, app):
        with (
            app.test_request_context("/workspaces/current/rbac/workspace/apps/access-policy?page=4&limit=10"),
            _enabled(True),
            patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")),
            patch("controllers.console.workspace.rbac.svc.RBACService.WorkspaceAccess.app_matrix") as mock_list,
            patch("controllers.console.workspace.rbac._dump", return_value={}),
        ):
            inspect.unwrap(rbac_mod.RBACWorkspaceAppMatrixApi.get)(rbac_mod.RBACWorkspaceAppMatrixApi())

        _, kwargs = mock_list.call_args
        options = kwargs["options"]
        assert options.page_number == 4
        assert options.results_per_page == 10
        assert options.reverse is None

    def test_workspace_dataset_matrix_forwards_outer_pagination_params(self, app):
        with (
            app.test_request_context(
                "/workspaces/current/rbac/workspace/datasets/access-policy?page=5&limit=15&reverse=true"
            ),
            _enabled(True),
            patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")),
            patch("controllers.console.workspace.rbac.svc.RBACService.WorkspaceAccess.dataset_matrix") as mock_list,
            patch("controllers.console.workspace.rbac._dump", return_value={}),
|
||||
):
|
||||
inspect.unwrap(rbac_mod.RBACWorkspaceDatasetMatrixApi.get)(rbac_mod.RBACWorkspaceDatasetMatrixApi())
|
||||
|
||||
_, kwargs = mock_list.call_args
|
||||
options = kwargs["options"]
|
||||
assert options.page_number == 5
|
||||
assert options.results_per_page == 15
|
||||
assert options.reverse is True
|
||||
|
||||
|
||||
class TestRoleCopy:
|
||||
def test_role_copy_forwards_path_id(self, app):
|
||||
with (
|
||||
app.test_request_context("/workspaces/current/rbac/roles/role-1/copy", method="POST"),
|
||||
_enabled(True),
|
||||
patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")),
|
||||
patch("controllers.console.workspace.rbac.svc.RBACService.Roles.copy") as mock_copy,
|
||||
patch("controllers.console.workspace.rbac._dump", return_value={}),
|
||||
):
|
||||
inspect.unwrap(rbac_mod.RBACRoleCopyApi.post)(rbac_mod.RBACRoleCopyApi(), "role-1")
|
||||
|
||||
mock_copy.assert_called_once_with("tenant-1", "acct-1", "role-1")
|
||||
|
||||
|
||||
class TestDumpHelper:
|
||||
def test_dump_returns_plain_dict(self):
|
||||
role = rbac_mod.svc.RBACRole(id="role-1", type="workspace", name="Owner")
|
||||
dumped = rbac_mod._dump(role)
|
||||
assert isinstance(dumped, dict)
|
||||
assert "role_id" not in dumped
|
||||

@@ -1,14 +1,12 @@
"""Primarily used for testing merged cell scenarios"""

import gc
import io
import os
import tempfile
import warnings
from collections import UserDict
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock

import pytest
from docx import Document
@@ -377,23 +375,21 @@ def test_close_is_idempotent():
    extractor.temp_file.close.assert_called_once()


def test_close_handles_async_close_mock():
    async def _async_close() -> None:
        return None


def test_close_closes_awaitable_close_result():
    extractor = object.__new__(WordExtractor)
    extractor._closed = False
    extractor.temp_file = MagicMock()
    extractor.temp_file.close = AsyncMock()
    close_result = _async_close()
    extractor.temp_file.close = MagicMock(return_value=close_result)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        extractor.close()
        gc.collect()
        extractor.close()

    assert close_result.cr_frame is None
    extractor.temp_file.close.assert_called_once()
    assert not [
        warning
        for warning in caught
        if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message)
    ]


def test_extract_images_handles_invalid_external_cases(monkeypatch):

@@ -13,6 +13,7 @@ import base64
import secrets
from datetime import UTC, datetime
from uuid import uuid4
from unittest.mock import patch

import pytest

@@ -347,7 +348,15 @@ class TestAccountRolePermissions:
        account.role = TenantAccountRole.ADMIN

        # Act & Assert
        assert account.is_admin_or_owner
        with patch("models.account.dify_config.RBAC_ENABLED", False):
            assert account.is_admin_or_owner

    def test_is_admin_or_owner_with_rbac_enabled(self):
        account = Account(name="Test User", email="test@example.com")
        account.role = TenantAccountRole.NORMAL

        with patch("models.account.dify_config.RBAC_ENABLED", True):
            assert account.is_admin_or_owner

    def test_is_admin_or_owner_with_owner_role(self):
        """Test is_admin_or_owner property with owner role."""
@@ -383,8 +392,16 @@ class TestAccountRolePermissions:
        owner_account.role = TenantAccountRole.OWNER

        # Act & Assert
        assert admin_account.is_admin
        assert not owner_account.is_admin
        with patch("models.account.dify_config.RBAC_ENABLED", False):
            assert admin_account.is_admin
            assert not owner_account.is_admin

    def test_is_admin_with_rbac_enabled(self):
        account = Account(name="Test User", email="test@example.com")
        account.role = TenantAccountRole.NORMAL

        with patch("models.account.dify_config.RBAC_ENABLED", True):
            assert account.is_admin

    def test_has_edit_permission_with_editing_roles(self):
        """Test has_edit_permission property with roles that have edit permission."""
@@ -400,7 +417,15 @@ class TestAccountRolePermissions:
            account.role = role

            # Act & Assert
            assert account.has_edit_permission, f"Role {role} should have edit permission"
            with patch("models.account.dify_config.RBAC_ENABLED", False):
                assert account.has_edit_permission, f"Role {role} should have edit permission"

    def test_has_edit_permission_with_rbac_enabled(self):
        account = Account(name="Test User", email="test@example.com")
        account.role = TenantAccountRole.NORMAL

        with patch("models.account.dify_config.RBAC_ENABLED", True):
            assert account.has_edit_permission

    def test_has_edit_permission_without_editing_roles(self):
        """Test has_edit_permission property with roles that don't have edit permission."""
@@ -415,7 +440,8 @@ class TestAccountRolePermissions:
            account.role = role

            # Act & Assert
            assert not account.has_edit_permission, f"Role {role} should not have edit permission"
            with patch("models.account.dify_config.RBAC_ENABLED", False):
                assert not account.has_edit_permission, f"Role {role} should not have edit permission"

    def test_is_dataset_editor_property(self):
        """Test is_dataset_editor property."""
@@ -432,12 +458,21 @@ class TestAccountRolePermissions:
            account.role = role

            # Act & Assert
            assert account.is_dataset_editor, f"Role {role} should have dataset edit permission"
            with patch("models.account.dify_config.RBAC_ENABLED", False):
                assert account.is_dataset_editor, f"Role {role} should have dataset edit permission"

        # Test normal role doesn't have dataset edit permission
        normal_account = Account(name="Normal User", email="normal@example.com")
        normal_account.role = TenantAccountRole.NORMAL
        assert not normal_account.is_dataset_editor
        with patch("models.account.dify_config.RBAC_ENABLED", False):
            assert not normal_account.is_dataset_editor

    def test_is_dataset_editor_with_rbac_enabled(self):
        account = Account(name="Test User", email="test@example.com")
        account.role = TenantAccountRole.NORMAL

        with patch("models.account.dify_config.RBAC_ENABLED", True):
            assert account.is_dataset_editor

    def test_is_dataset_operator_property(self):
        """Test is_dataset_operator property."""
@@ -449,8 +484,16 @@ class TestAccountRolePermissions:
        normal_account.role = TenantAccountRole.NORMAL

        # Act & Assert
        assert dataset_operator.is_dataset_operator
        assert not normal_account.is_dataset_operator
        with patch("models.account.dify_config.RBAC_ENABLED", False):
            assert dataset_operator.is_dataset_operator
            assert not normal_account.is_dataset_operator

    def test_is_dataset_operator_with_rbac_enabled(self):
        account = Account(name="Test User", email="test@example.com")
        account.role = TenantAccountRole.NORMAL

        with patch("models.account.dify_config.RBAC_ENABLED", True):
            assert account.is_dataset_operator

    def test_current_role_property(self):
        """Test current_role property."""

api/tests/unit_tests/services/enterprise/test_rbac_service.py (new file, 568 lines)
@@ -0,0 +1,568 @@
"""Unit tests for services.enterprise.rbac_service.
|
||||
|
||||
The enterprise RBAC client is almost pure glue: each method turns a single
|
||||
``EnterpriseRequest.send_inner_rbac_request`` call into a pydantic response
|
||||
model. Rather than spinning up an HTTP server we monkeypatch that helper and
|
||||
assert on the arguments it received; that catches both routing regressions
|
||||
(wrong method / wrong path / wrong params) and model-shape regressions in
|
||||
one place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.enterprise import rbac_service as svc
|
||||
|
||||
MODULE = "services.enterprise.rbac_service"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_send():
|
||||
with patch(f"{MODULE}.EnterpriseRequest.send_inner_rbac_request") as send:
|
||||
yield send
|
||||
|
||||
|
||||
def _call_args(send: MagicMock) -> SimpleNamespace:
|
||||
"""Return the most recent (method, endpoint, kwargs) sent to the mock."""
|
||||
send.assert_called_once()
|
||||
args, kwargs = send.call_args
|
||||
return SimpleNamespace(method=args[0], endpoint=args[1], **kwargs)
|
||||
|
||||
|
||||
class TestCatalog:
|
||||
def test_workspace_catalog(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"groups": [{"group_key": "workspace", "group_name": "工作空间", "permissions": []}]}
|
||||
|
||||
out = svc.RBACService.Catalog.workspace("tenant-1", account_id="acct-1")
|
||||
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "GET"
|
||||
assert call.endpoint == "/rbac/role-permissions/catalog"
|
||||
assert call.tenant_id == "tenant-1"
|
||||
assert call.account_id == "acct-1"
|
||||
assert call.json is None
|
||||
assert call.params is None
|
||||
assert len(out.groups) == 1
|
||||
assert out.groups[0].group_key == "workspace"
|
||||
|
||||
def test_app_catalog_endpoint(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"groups": []}
|
||||
svc.RBACService.Catalog.app("tenant-1")
|
||||
assert mock_send.call_args.args[1] == "/rbac/role-permissions/catalog/app"
|
||||
|
||||
def test_dataset_catalog_endpoint(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"groups": []}
|
||||
svc.RBACService.Catalog.dataset("tenant-1")
|
||||
assert mock_send.call_args.args[1] == "/rbac/role-permissions/catalog/dataset"
|
||||
|
||||
|
||||
class TestRoles:
|
||||
def test_list_forwards_pagination_options(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {
|
||||
"data": [
|
||||
{
|
||||
"id": "role-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"type": "workspace",
|
||||
"category": "global_custom",
|
||||
"name": "Owner",
|
||||
"permission_keys": ["workspace.member.manage"],
|
||||
}
|
||||
],
|
||||
"pagination": {"total_count": 1, "per_page": 20, "current_page": 1, "total_pages": 1},
|
||||
}
|
||||
|
||||
out = svc.RBACService.Roles.list(
|
||||
"tenant-1",
|
||||
"acct-1",
|
||||
options=svc.ListOption(page_number=2, results_per_page=50, reverse=True),
|
||||
)
|
||||
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "GET"
|
||||
assert call.endpoint == "/rbac/roles"
|
||||
assert call.params == {"page_number": 2, "results_per_page": 50, "reverse": "true"}
|
||||
assert out.pagination and out.pagination.total_count == 1
|
||||
|
||||
def test_list_omits_params_when_default(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"data": [], "pagination": None}
|
||||
svc.RBACService.Roles.list("tenant-1")
|
||||
assert _call_args(mock_send).params is None
|
||||
|
||||
def test_list_coerces_null_permission_keys(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {
|
||||
"data": [
|
||||
{
|
||||
"id": "role-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"type": "workspace",
|
||||
"category": "global_custom",
|
||||
"name": "Owner",
|
||||
"permission_keys": None,
|
||||
}
|
||||
],
|
||||
"pagination": None,
|
||||
}
|
||||
|
||||
out = svc.RBACService.Roles.list("tenant-1")
|
||||
|
||||
assert out.data[0].permission_keys == []
|
||||
|
||||
def test_get_passes_id_query_param(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"id": "role-1", "type": "workspace", "name": "Owner"}
|
||||
svc.RBACService.Roles.get("tenant-1", "acct-1", "role-1")
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "GET"
|
||||
assert call.endpoint == "/rbac/roles/item"
|
||||
assert call.params == {"id": "role-1"}
|
||||
|
||||
def test_create_sends_body(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"id": "role-1", "type": "workspace", "name": "Owner"}
|
||||
payload = svc.RoleMutation(name="Owner", description="full access", permission_keys=["workspace.member.manage"])
|
||||
svc.RBACService.Roles.create("tenant-1", "acct-1", payload)
|
||||
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "POST"
|
||||
assert call.endpoint == "/rbac/roles"
|
||||
assert call.json == {
|
||||
"name": "Owner",
|
||||
"description": "full access",
|
||||
"permission_keys": ["workspace.member.manage"],
|
||||
"type": "workspace",
|
||||
}
|
||||
|
||||
def test_update_sends_id_param_and_body(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"id": "role-1", "type": "workspace", "name": "Owner"}
|
||||
payload = svc.RoleMutation(name="Owner", permission_keys=["x"])
|
||||
svc.RBACService.Roles.update("tenant-1", "acct-1", "role-1", payload)
|
||||
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "PUT"
|
||||
assert call.endpoint == "/rbac/roles/item"
|
||||
assert call.params == {"id": "role-1"}
|
||||
assert call.json == {"name": "Owner", "description": "", "permission_keys": ["x"], "type": "workspace"}
|
||||
|
||||
def test_delete_uses_delete_method(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"message": "success"}
|
||||
svc.RBACService.Roles.delete("tenant-1", None, "role-1")
|
||||
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "DELETE"
|
||||
assert call.endpoint == "/rbac/roles/item"
|
||||
assert call.params == {"id": "role-1"}
|
||||
assert call.account_id is None
|
||||
|
||||
def test_copy_sends_post_with_id_param(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"id": "role-1-copy", "type": "workspace", "name": "Owner copy"}
|
||||
svc.RBACService.Roles.copy("tenant-1", "acct-1", "role-1")
|
||||
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "POST"
|
||||
assert call.endpoint == "/rbac/roles/copy"
|
||||
assert call.params == {"id": "role-1"}
|
||||
assert call.account_id == "acct-1"
|
||||
|
||||
|
||||
class TestAccessPolicies:
|
||||
def test_list_filters_by_resource_type(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"data": [], "pagination": None}
|
||||
svc.RBACService.AccessPolicies.list(
|
||||
"tenant-1",
|
||||
"acct-1",
|
||||
resource_type=svc.RBACResourceType.APP,
|
||||
options=svc.ListOption(page_number=1),
|
||||
)
|
||||
call = _call_args(mock_send)
|
||||
assert call.endpoint == "/rbac/access-policies"
|
||||
assert call.params == {"page_number": 1, "resource_type": "app"}
|
||||
|
||||
def test_copy_sends_post_with_id_param(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {
|
||||
"id": "policy-1-copy",
|
||||
"resource_type": "app",
|
||||
"name": "Full access copy",
|
||||
}
|
||||
svc.RBACService.AccessPolicies.copy("tenant-1", "acct-1", "policy-1")
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "POST"
|
||||
assert call.endpoint == "/rbac/access-policies/copy"
|
||||
assert call.params == {"id": "policy-1"}
|
||||
|
||||
def test_create_serialises_resource_type_enum(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"id": "policy-1", "resource_type": "dataset", "name": "KB only"}
|
||||
payload = svc.AccessPolicyCreate(
|
||||
name="KB only",
|
||||
resource_type=svc.RBACResourceType.DATASET,
|
||||
permission_keys=["dataset.acl.readonly"],
|
||||
)
|
||||
svc.RBACService.AccessPolicies.create("tenant-1", "acct-1", payload)
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "POST"
|
||||
assert call.json == {
|
||||
"name": "KB only",
|
||||
"resource_type": "dataset",
|
||||
"description": "",
|
||||
"permission_keys": ["dataset.acl.readonly"],
|
||||
}
|
||||
|
||||
|
||||
class TestResourceAccess:
|
||||
def test_app_matrix(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"app_id": "app-1", "items": []}
|
||||
out = svc.RBACService.AppAccess.matrix("tenant-1", "acct-1", "app-1")
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "GET"
|
||||
assert call.endpoint == "/rbac/apps/access-policy"
|
||||
assert call.params == {"app_id": "app-1"}
|
||||
assert out.app_id == "app-1"
|
||||
|
||||
def test_app_role_bindings_preserve_role_name(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {
|
||||
"data": [
|
||||
{
|
||||
"id": "binding-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"access_policy_id": "policy-1",
|
||||
"resource_type": "app",
|
||||
"resource_id": "app-1",
|
||||
"role_id": "role-1",
|
||||
"role_name": "Owner",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
out = svc.RBACService.AppAccess.list_role_bindings("tenant-1", "acct-1", "app-1", "policy-1")
|
||||
|
||||
assert out.data[0].role_name == "Owner"
|
||||
|
||||
def test_app_member_bindings_preserve_account_name(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {
|
||||
"data": [
|
||||
{
|
||||
"id": "binding-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"access_policy_id": "policy-1",
|
||||
"resource_type": "app",
|
||||
"resource_id": "app-1",
|
||||
"account_id": "acct-1",
|
||||
"account_name": "Alice",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
out = svc.RBACService.AppAccess.list_member_bindings("tenant-1", "acct-1", "app-1", "policy-1")
|
||||
|
||||
assert out.data[0].account_name == "Alice"
|
||||
|
||||
def test_app_replace_bindings(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"data": []}
|
||||
payload = svc.ReplaceBindings(role_ids=["workspace.owner"], account_ids=["acct-2"])
|
||||
svc.RBACService.AppAccess.replace_bindings("tenant-1", "acct-1", "app-1", "policy-1", payload)
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "PUT"
|
||||
assert call.endpoint == "/rbac/apps/access-policy/bindings"
|
||||
assert call.params == {"app_id": "app-1", "policy_id": "policy-1"}
|
||||
assert call.json == {"role_ids": ["workspace.owner"], "account_ids": ["acct-2"]}
|
||||
|
||||
def test_dataset_replace_bindings(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"data": []}
|
||||
payload = svc.ReplaceBindings(role_ids=["workspace.editor"], account_ids=["acct-2"])
|
||||
svc.RBACService.DatasetAccess.replace_bindings("tenant-1", "acct-1", "ds-1", "policy-1", payload)
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "PUT"
|
||||
assert call.endpoint == "/rbac/datasets/access-policy/bindings"
|
||||
assert call.params == {"dataset_id": "ds-1", "policy_id": "policy-1"}
|
||||
assert call.json == {"role_ids": ["workspace.editor"], "account_ids": ["acct-2"]}
|
||||
|
||||
|
||||
class TestWorkspaceAccess:
|
||||
def test_app_matrix(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"items": [], "pagination": {"total_count": 1, "per_page": 20, "current_page": 2, "total_pages": 1}}
|
||||
out = svc.RBACService.WorkspaceAccess.app_matrix(
|
||||
"tenant-1",
|
||||
options=svc.ListOption(page_number=2, results_per_page=20),
|
||||
)
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "GET"
|
||||
assert call.endpoint == "/rbac/workspace/apps/access-policy"
|
||||
assert call.params == {"page_number": 2, "results_per_page": 20}
|
||||
assert out.pagination and out.pagination.current_page == 2
|
||||
|
||||
def test_dataset_matrix(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"items": []}
|
||||
svc.RBACService.WorkspaceAccess.dataset_matrix("tenant-1")
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "GET"
|
||||
assert call.endpoint == "/rbac/workspace/datasets/access-policy"
|
||||
assert call.params is None
|
||||
|
||||
def test_workspace_matrix_coerces_null_bindings(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {
|
||||
"items": [
|
||||
{
|
||||
"policy": {
|
||||
"id": "policy-1",
|
||||
"resource_type": "app",
|
||||
"name": "Workspace App Access",
|
||||
},
|
||||
"roles": None,
|
||||
"accounts": None,
|
||||
}
|
||||
],
|
||||
"pagination": None,
|
||||
}
|
||||
|
||||
out = svc.RBACService.WorkspaceAccess.app_matrix("tenant-1")
|
||||
|
||||
assert out.items[0].roles == []
|
||||
assert out.items[0].accounts == []
|
||||
|
||||
def test_workspace_app_replace_bindings(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"data": []}
|
||||
payload = svc.ReplaceBindings(role_ids=["workspace.editor"], account_ids=["acct-2"])
|
||||
svc.RBACService.WorkspaceAccess.replace_app_bindings(
|
||||
"tenant-1", "acct-1", "policy-1", payload
|
||||
)
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "PUT"
|
||||
assert call.endpoint == "/rbac/workspace/apps/access-policy/bindings"
|
||||
assert call.params == {"policy_id": "policy-1"}
|
||||
assert call.json == {"role_ids": ["workspace.editor"], "account_ids": ["acct-2"]}
|
||||
|
||||
def test_workspace_dataset_replace_bindings(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"data": []}
|
||||
payload = svc.ReplaceBindings(role_ids=["workspace.editor"], account_ids=["acct-2"])
|
||||
svc.RBACService.WorkspaceAccess.replace_dataset_bindings(
|
||||
"tenant-1", "acct-1", "policy-1", payload
|
||||
)
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "PUT"
|
||||
assert call.endpoint == "/rbac/workspace/datasets/access-policy/bindings"
|
||||
assert call.params == {"policy_id": "policy-1"}
|
||||
assert call.json == {"role_ids": ["workspace.editor"], "account_ids": ["acct-2"]}
|
||||
|
||||
|
||||
class TestMyPermissions:
|
||||
def test_get_without_payload_uses_get(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {
|
||||
"workspace": {"permission_keys": ["workspace.member.manage"]},
|
||||
"app": {"default_permission_keys": ["app.acl.view_layout", "app.acl.test_and_run"], "overrides": []},
|
||||
"dataset": {"default_permission_keys": [], "overrides": []},
|
||||
}
|
||||
|
||||
with patch(f"{MODULE}.dify_config.RBAC_ENABLED", True):
|
||||
out = svc.RBACService.MyPermissions.get("tenant-1", "acct-1")
|
||||
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "GET"
|
||||
assert call.endpoint == "/rbac/my-permissions"
|
||||
assert call.json is None
|
||||
assert call.params is None
|
||||
assert out.workspace.permission_keys == ["workspace.member.manage"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "workspace_keys", "app_keys", "dataset_keys"),
|
||||
[
|
||||
(
|
||||
"owner",
|
||||
["workspace.member.manage", "workspace.role.manage"],
|
||||
["app.acl.view_layout", "app.acl.test_and_run", "app.acl.edit", "app.acl.access_config"],
|
||||
["dataset.acl.readonly", "dataset.acl.edit", "dataset.acl.use"],
|
||||
),
|
||||
(
|
||||
"admin",
|
||||
["workspace.member.manage", "workspace.role.manage"],
|
||||
["app.acl.view_layout", "app.acl.test_and_run", "app.acl.edit", "app.acl.access_config"],
|
||||
["dataset.acl.readonly", "dataset.acl.edit", "dataset.acl.use"],
|
||||
),
|
||||
(
|
||||
"editor",
|
||||
[],
|
||||
["app.acl.view_layout", "app.acl.test_and_run", "app.acl.edit", "app.acl.access_config"],
|
||||
["dataset.acl.readonly", "dataset.acl.edit", "dataset.acl.use"],
|
||||
),
|
||||
(
|
||||
"normal",
|
||||
[],
|
||||
["app.acl.view_layout", "app.acl.test_and_run"],
|
||||
[],
|
||||
),
|
||||
(
|
||||
"dataset_operator",
|
||||
[],
|
||||
[],
|
||||
["dataset.acl.readonly", "dataset.acl.edit", "dataset.acl.use"],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_uses_legacy_role_permissions_when_rbac_disabled(
|
||||
self,
|
||||
mock_send: MagicMock,
|
||||
role: str,
|
||||
workspace_keys: list[str],
|
||||
app_keys: list[str],
|
||||
dataset_keys: list[str],
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = role
|
||||
with (
|
||||
patch(f"{MODULE}.dify_config.RBAC_ENABLED", False),
|
||||
patch(f"{MODULE}.session_factory.create_session", return_value=mock_session),
|
||||
):
|
||||
out = svc.RBACService.MyPermissions.get("tenant-1", "acct-1")
|
||||
|
||||
mock_send.assert_not_called()
|
||||
assert out.workspace.permission_keys == workspace_keys
|
||||
assert out.app.default_permission_keys == app_keys
|
||||
assert out.dataset.default_permission_keys == dataset_keys
|
||||
assert out.app.overrides == []
|
||||
assert out.dataset.overrides == []
|
||||
|
||||
def test_get_returns_empty_when_role_missing_and_rbac_disabled(self, mock_send: MagicMock):
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = None
|
||||
with (
|
||||
patch(f"{MODULE}.dify_config.RBAC_ENABLED", False),
|
||||
patch(f"{MODULE}.session_factory.create_session", return_value=mock_session),
|
||||
):
|
||||
out = svc.RBACService.MyPermissions.get("tenant-1", "acct-1")
|
||||
|
||||
mock_send.assert_not_called()
|
||||
assert out.workspace.permission_keys == []
|
||||
assert out.app.default_permission_keys == []
|
||||
assert out.dataset.default_permission_keys == []
|
||||
|
||||
def test_get_with_single_resource_filters(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {
|
||||
"workspace": {"permission_keys": []},
|
||||
"app": {"default_permission_keys": [], "overrides": [{"resource_id": "app-1", "permission_keys": ["app.acl.edit"]}]},
|
||||
"dataset": {"default_permission_keys": [], "overrides": []},
|
||||
}
|
||||
|
||||
with patch(f"{MODULE}.dify_config.RBAC_ENABLED", True):
|
||||
out = svc.RBACService.MyPermissions.get("tenant-1", "acct-1", app_id="app-1")
|
||||
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "GET"
|
||||
assert call.endpoint == "/rbac/my-permissions"
|
||||
assert call.params == {"app_id": "app-1"}
|
||||
assert out.app.overrides[0].resource_id == "app-1"
|
||||
|
||||
|
||||
class TestMemberRoles:
|
||||
def test_get(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {
|
||||
"account_id": "acct-2",
|
||||
"roles": [
|
||||
{
|
||||
"id": "role-1",
|
||||
"type": "workspace",
|
||||
"name": "Member",
|
||||
}
|
||||
],
|
||||
}
|
||||
out = svc.RBACService.MemberRoles.get("tenant-1", "acct-1", "acct-2")
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "GET"
|
||||
assert call.endpoint == "/rbac/members/rbac-roles"
|
||||
assert call.params == {"account_id": "acct-2"}
|
||||
assert out.account_id == "acct-2"
|
||||
assert out.roles[0].name == "Member"
|
||||
|
||||
def test_replace(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {"account_id": "acct-2", "roles": []}
|
||||
svc.RBACService.MemberRoles.replace(
|
||||
"tenant-1", "acct-1", "acct-2", role_ids=["workspace.owner", "workspace.editor"]
|
||||
)
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "PUT"
|
||||
assert call.endpoint == "/rbac/members/rbac-roles"
|
||||
assert call.params == {"account_id": "acct-2"}
|
||||
assert call.json == {"role_ids": ["workspace.owner", "workspace.editor"]}
|
||||
|
||||
def test_batch_get(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {
|
||||
"data": [
|
||||
{
|
||||
"account_id": "acct-2",
|
||||
"roles": [
|
||||
{"id": "role-1", "name": "Admin"},
|
||||
{"id": "role-2", "name": "Editor"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"account_id": "acct-3",
|
||||
"roles": [],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
out = svc.RBACService.MemberRoles.batch_get("tenant-1", "acct-1", ["acct-2", "acct-3"])
|
||||
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "POST"
|
||||
assert call.endpoint == "/rbac/members/rbac-roles/batch"
|
||||
assert call.json == {"account_ids": ["acct-2", "acct-3"]}
|
||||
assert out[0].account_id == "acct-2"
|
||||
assert len(out[0].roles) == 2
|
||||
|
||||
|
||||
class TestResourcePermissions:
|
||||
def test_app_permissions_batch_get(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {
|
||||
"data": [
|
||||
{"resource_id": "app-1", "permission_keys": ["app.acl.view_layout", "app.acl.edit"]},
|
||||
{"resource_id": "app-2", "permission_keys": []},
|
||||
]
|
||||
}
|
||||
|
||||
out = svc.RBACService.AppPermissions.batch_get("tenant-1", "acct-1", ["app-1", "app-2"])
|
||||
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "POST"
|
||||
assert call.endpoint == "/rbac/apps/permission-keys/batch"
|
||||
assert call.json == {"app_ids": ["app-1", "app-2"]}
|
||||
assert out == {
|
||||
"app-1": ["app.acl.view_layout", "app.acl.edit"],
|
||||
"app-2": [],
|
||||
}
|
||||
|
||||
def test_dataset_permissions_batch_get(self, mock_send: MagicMock):
|
||||
mock_send.return_value = {
|
||||
"data": [
|
||||
{"resource_id": "ds-1", "permission_keys": ["dataset.acl.readonly"]},
|
||||
{"resource_id": "ds-2", "permission_keys": ["dataset.acl.edit"]},
|
||||
]
|
||||
}
|
||||
|
||||
out = svc.RBACService.DatasetPermissions.batch_get("tenant-1", "acct-1", ["ds-1", "ds-2"])
|
||||
|
||||
call = _call_args(mock_send)
|
||||
assert call.method == "POST"
|
||||
assert call.endpoint == "/rbac/datasets/permission-keys/batch"
|
||||
assert call.json == {"dataset_ids": ["ds-1", "ds-2"]}
|
||||
assert out == {
|
||||
"ds-1": ["dataset.acl.readonly"],
|
||||
"ds-2": ["dataset.acl.edit"],
|
||||
}
|
||||
|
||||
|
||||
class TestListOption:
|
||||
def test_empty_produces_empty_params(self):
|
||||
assert svc.ListOption().to_params() == {}
|
||||
|
||||
def test_reverse_serialises_as_lowercase_bool(self):
|
||||
assert svc.ListOption(reverse=False).to_params()["reverse"] == "false"
|
||||
assert svc.ListOption(reverse=True).to_params()["reverse"] == "true"
|
||||
|
||||
def test_extra_overrides_merge(self):
|
||||
assert svc.ListOption(page_number=1).to_params({"resource_type": "app", "skip": None}) == {
|
||||
"page_number": 1,
|
||||
"resource_type": "app",
|
||||
}
|
||||

dify-agent/.gitignore (vendored, 1 line deleted)
@@ -1 +0,0 @@
dify-aio

@@ -1,184 +0,0 @@
# Agent Guide

## Notes for Agent (must-check)

Before changing any source code under this folder, you MUST read the surrounding docstrings and comments. These notes contain required context (invariants, edge cases, trade-offs) and are treated as part of the spec.

Look for:

- The module (file) docstring at the top of a source code file
- Docstrings on classes and functions/methods
- Paragraph/block comments for non-obvious logic

### What to write where

- Keep notes scoped: module notes cover module-wide context, class notes cover class-wide context, function/method notes cover behavioural contracts, and paragraph/block comments cover local “why”. Avoid duplicating the same content across scopes unless repetition prevents misuse.
- **Module (file) docstring**: purpose, boundaries, key invariants, and “gotchas” that a new reader must know before editing.
  - Include cross-links to the key collaborators (modules/services) when discovery is otherwise hard.
  - Prefer stable facts (invariants, contracts) over ephemeral “today we…” notes.
- **Class docstring**: responsibility, lifecycle, invariants, and how it should be used (or not used).
  - If the class is intentionally stateful, note what state exists and what methods mutate it.
  - If concurrency/async assumptions matter, state them explicitly.
- **Function/method docstring**: behavioural contract.
  - Document arguments, return shape, side effects (DB writes, external I/O, task dispatch), and raised domain exceptions.
  - Add examples only when they prevent misuse.
- **Paragraph/block comments**: explain *why* (trade-offs, historical constraints, surprising edge cases), not what the code already states.
  - Keep comments adjacent to the logic they justify; delete or rewrite comments that no longer match reality.
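
For instance, a function docstring following the behavioural-contract rule above might look like this (a minimal sketch; the function, job, and exception names are invented for illustration):

```python
def dispatch_export(dataset_id: str) -> str:
    """Queue an export task for the given dataset.

    Side effects: inserts an export-job row and dispatches a background task.
    Returns the new job id. Raises ExportNotAllowedError for archived datasets.
    """
    ...
```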

### Rules (must follow)

In this section, “notes” means module/class/function docstrings plus any relevant paragraph/block comments.

- **Before working**
  - Read the notes in the area you’ll touch; treat them as part of the spec.
  - If a docstring or comment conflicts with the current code, treat the **code as the single source of truth** and update the docstring or comment to match reality.
  - If important intent/invariants/edge cases are missing, add them in the closest docstring or comment (module for overall scope, function for behaviour).
- **During working**
  - Keep the notes in sync as you discover constraints, make decisions, or change approach.
  - If you move/rename responsibilities across modules/classes, update the affected docstrings and comments so readers can still find the “why” and the invariants.
  - Record non-obvious edge cases, trade-offs, and the test/verification plan in the nearest docstring or comment that will stay correct.
  - Keep the notes **coherent**: integrate new findings into the relevant docstrings and comments; avoid append-only “recent fix” / changelog-style additions.
- **When finishing**
  - Update the notes to reflect what changed, why, and any new edge cases/tests.
  - Remove or rewrite any comments that could be mistaken as current guidance but no longer apply.
  - Keep docstrings and comments concise and accurate; they are meant to prevent repeated rediscovery.

## Coding Style

This is the default standard for backend code in this repo. Follow it for new code and use it as the checklist when reviewing changes.

### Linting & Formatting

- Use Ruff for formatting and linting (follow `.ruff.toml`).
- Keep each line under 120 characters (including spaces).

### Naming Conventions

- Use `snake_case` for variables and functions.
- Use `PascalCase` for classes.
- Use `UPPER_CASE` for constants.

### Typing & Class Layout

- Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values).
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason.
- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`.
  - For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional).
  - Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown.

```python
from datetime import datetime
from typing import NotRequired, TypedDict


class UserProfile(TypedDict):
    user_id: str
    email: str
    created_at: datetime
    nickname: NotRequired[str]
```

- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:

```python
from datetime import datetime


class Example:
    user_id: str
    created_at: datetime

    def __init__(self, user_id: str, created_at: datetime) -> None:
        self.user_id = user_id
        self.created_at = created_at
```

- For dataclasses, prefer `field(default_factory=...)` over `field(init=False)` when a default can be provided declaratively.
- Prefer dataclasses with `slots=True` when defining lightweight data containers:

```python
from dataclasses import dataclass
from datetime import datetime


@dataclass(slots=True)
class Example:
    user_id: str
    created_at: datetime
```

### General Rules

- Use Pydantic v2 conventions.
- Use `uv` for Python package management in this repo (usually with `--project dify-agent`).
- Use `make typecheck` to run `basedpyright` against `dify-agent/src` and `dify-agent/tests`.
- Keep type checking passing after every edit you make.
- Use `pytest` for all tests in this package.
- When integrating with, implementing, or mocking a dependency, inspect the dependency's source code to confirm its API shape and runtime behavior instead of guessing from names alone.
- Prefer simple functions over small “utility classes” for lightweight helpers.
- Avoid implementing dunder methods unless it’s clearly needed and matches existing patterns.
- Keep code readable and explicit—avoid clever hacks.

### Testing

- Work in TDD style: write or update a failing test first when changing behavior, then make the implementation pass, then refactor while keeping tests and typecheck green.
- Use `make test` to run the agent pytest suite.
- Keep local tests under `dify-agent/tests/local/`.
- Mirror the `dify-agent/src/` package structure inside `dify-agent/tests/local/` so test locations stay predictable.

#### Local Tests

- Write local tests for stable, externally observable behavior that can run quickly without real external services.
- In this repo, code, comments, docs, and tests are expected to change together. Because of that, a local test is only useful if it would still be correct after an internal refactor that does not change the intended contract.
- Local tests should verify:
  - what callers and downstream code can observe and rely on
  - how the unit is expected to use its dependencies at the boundary
  - how the unit handles dependency success, failure, empty responses, malformed responses, and documented error cases
  - documented invariants, error mapping, and output/input shape guarantees
- When asserting dependency interactions, assert only the parts of the request or response that are part of the real boundary contract. Do not over-specify incidental details that callers or dependencies do not rely on.
- It is acceptable to mock dependencies in local tests, but only when the mock represents a real contract, schema, documented behavior, or known regression.
- Tests may use line-scoped type-ignore comments when intentionally exercising runtime validation paths that static typing would normally reject. Keep the ignore on the exact invalid call (see the sketch after this list).
- Do not use local tests to prove real integration, network wiring, serialization, framework configuration, or third-party runtime behavior; cover those in higher-level tests.
- Meaningless local tests include:
  - tests that only mirror the current implementation or must be updated whenever internal code changes even though the contract did not change
  - tests of private helpers, local variables, temporary state, internal branching, or exact internal call order unless those details are part of the published contract
  - tests with mocked dependency behavior that is invented only to make the current implementation pass
  - tests that add no value beyond static type checking or linting
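
As one concrete shape for the line-scoped type-ignore rule above, a local test that deliberately feeds an invalid value through runtime validation might look like this (a minimal sketch; the model and test names are invented):

```python
import pytest
from pydantic import BaseModel, ValidationError


class _Payload(BaseModel):
    name: str


def test_rejects_non_string_name() -> None:
    # The invalid argument is deliberate; the ignore is scoped to this call only.
    with pytest.raises(ValidationError):
        _Payload(name=123)  # type: ignore[arg-type]
```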

### Logging & Errors

- Never use `print`; use a module-level logger:
  - `logger = logging.getLogger(__name__)`
- Include tenant/app/workflow identifiers in log context when relevant.
- Raise domain-specific exceptions and translate them into HTTP responses in controllers.
- Log retryable events at `warning`, terminal failures at `error`.
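
A minimal sketch of these rules together (the function and identifiers are invented for illustration):

```python
import logging

logger = logging.getLogger(__name__)


def sync_workflow(tenant_id: str, workflow_id: str) -> None:
    try:
        ...  # call the upstream service here
    except TimeoutError:
        # Retryable: warn with identifiers in context, then let the caller retry.
        logger.warning("workflow sync timed out, tenant=%s workflow=%s", tenant_id, workflow_id)
        raise
    except Exception:
        # Terminal failure: record at error level with the full traceback.
        logger.exception("workflow sync failed, tenant=%s workflow=%s", tenant_id, workflow_id)
        raise
```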

### Pydantic Usage

- Define DTOs with Pydantic v2 models and forbid extras by default.
- Use `@field_validator` / `@model_validator` for domain rules.

Example:

```python
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator


class TriggerConfig(BaseModel):
    endpoint: HttpUrl
    secret: str

    model_config = ConfigDict(extra="forbid")

    @field_validator("secret")
    def ensure_secret_prefix(cls, value: str) -> str:
        if not value.startswith("dify_"):
            raise ValueError("secret must start with dify_")
        return value
```

### Generics & Protocols

- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
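
A minimal sketch of a generic cache contract along these lines (the names are illustrative, not an existing interface in this repo):

```python
from typing import Protocol, TypeVar

T = TypeVar("T")


class Cache(Protocol[T]):
    """Behavioural contract; any object with matching methods satisfies it."""

    def get(self, key: str) -> T | None: ...

    def set(self, key: str, value: T) -> None: ...


class InMemoryCache:
    """Satisfies Cache[str] structurally, with no inheritance required."""

    def __init__(self) -> None:
        self._store: dict[str, str] = {}

    def get(self, key: str) -> str | None:
        return self._store.get(key)

    def set(self, key: str, value: str) -> None:
        self._store[key] = value
```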

@@ -1,35 +0,0 @@
.DEFAULT_GOAL := help

.PHONY: help check fix typecheck test update-examples docs docs-serve

help:
	@echo "Dify agent targets:"
	@echo " make check - Run Ruff for dify-agent"
	@echo " make fix - Format and fix Ruff issues"
	@echo " make typecheck - Run basedpyright for src, examples, and tests"
	@echo " make test - Run local tests and docs/example tests"
	@echo " make update-examples - Rewrite docs example outputs when needed"
	@echo " make docs - Build MkDocs documentation"
	@echo " make docs-serve - Serve MkDocs documentation locally"

check:
	@uv run --project . python -m ruff check .

fix:
	@uv run --project . python -m ruff format .
	@uv run --project . python -m ruff check --fix .

typecheck:
	@uv run --project . python -m basedpyright --level error src examples tests

test:
	@uv run --project . python -m pytest tests

update-examples:
	@uv run --project . python -m pytest --update-examples tests/docs/test_examples.py

docs:
	@uv run --project . --group docs python -m mkdocs build --no-strict

docs-serve:
	@uv run --project . --group docs python -m mkdocs serve --no-strict

@@ -1,7 +0,0 @@
# Dify Agent

Agenton documentation lives in [`docs/agenton/guide/index.md`](docs/agenton/guide/index.md) and
[`docs/agenton/api/index.md`](docs/agenton/api/index.md).

Dify Agent runtime documentation lives in [`docs/dify-agent/index.md`](docs/dify-agent/index.md).
Build all docs with `make docs` from this directory.

@@ -1,16 +0,0 @@
from __future__ import annotations

from pathlib import Path

from mkdocs.config.defaults import MkDocsConfig
from mkdocs.structure.files import Files
from mkdocs.structure.pages import Page
from snippets import inject_snippets

DOCS_ROOT = Path(__file__).resolve().parent.parent


def on_page_markdown(markdown: str, page: Page, config: MkDocsConfig, files: Files) -> str:
    """Inject repository snippets before MkDocs renders Markdown."""
    relative_path = DOCS_ROOT / page.file.src_uri
    return inject_snippets(markdown, relative_path.parent)

@@ -1,228 +0,0 @@
from __future__ import annotations

import re
from dataclasses import dataclass
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parents[2]
SNIPPET_DIRECTIVE_PATTERN = re.compile(r"^```snippet\s+\{[^}]+\}\s*(?:```|\n```)$", re.MULTILINE)


@dataclass(frozen=True, slots=True)
class SnippetDirective:
    path: str
    title: str | None = None
    fragment: str | None = None
    highlight: str | None = None
    extra_attrs: dict[str, str] | None = None


@dataclass(frozen=True, slots=True)
class LineRange:
    start_line: int
    end_line: int

    def intersection(self, ranges: list[LineRange]) -> list[LineRange]:
        intersections: list[LineRange] = []
        for line_range in ranges:
            start_line = max(self.start_line, line_range.start_line)
            end_line = min(self.end_line, line_range.end_line)
            if start_line < end_line:
                intersections.append(LineRange(start_line, end_line))
        return intersections

    @staticmethod
    def merge(ranges: list[LineRange]) -> list[LineRange]:
        if not ranges:
            return []

        merged: list[LineRange] = []
        for line_range in sorted(ranges, key=lambda item: item.start_line):
            if not merged or merged[-1].end_line < line_range.start_line:
                merged.append(line_range)
            else:
                previous = merged[-1]
                merged[-1] = LineRange(previous.start_line, max(previous.end_line, line_range.end_line))
        return merged


@dataclass(frozen=True, slots=True)
class RenderedSnippet:
    content: str
    highlights: list[LineRange]
    original_range: LineRange


@dataclass(frozen=True, slots=True)
class ParsedFile:
    lines: list[str]
    sections: dict[str, list[LineRange]]
    lines_mapping: dict[int, int]

    def render(self, fragment_sections: list[str], highlight_sections: list[str]) -> RenderedSnippet:
        fragment_ranges: list[LineRange] = []
        if fragment_sections:
            for section_name in fragment_sections:
                fragment_ranges.extend(_section_ranges(self.sections, section_name))
            fragment_ranges = LineRange.merge(fragment_ranges)
        else:
            fragment_ranges = [LineRange(0, len(self.lines))]

        highlight_ranges: list[LineRange] = []
        for section_name in highlight_sections:
            highlight_ranges.extend(_section_ranges(self.sections, section_name))
        highlight_ranges = LineRange.merge(highlight_ranges)

        rendered_highlights: list[LineRange] = []
        rendered_lines: list[str] = []
        last_end_line = 0
        current_line = 0
        for fragment_range in fragment_ranges:
            if fragment_range.start_line > last_end_line:
                rendered_lines.append("..." if current_line == 0 else "\n...")
                current_line += 1

            for highlight_range in fragment_range.intersection(highlight_ranges):
                rendered_highlights.append(
                    LineRange(
                        highlight_range.start_line - fragment_range.start_line + current_line,
                        highlight_range.end_line - fragment_range.start_line + current_line,
                    )
                )

            for line_number in range(fragment_range.start_line, fragment_range.end_line):
                rendered_lines.append(self.lines[line_number])
                current_line += 1
            last_end_line = fragment_range.end_line

        if last_end_line < len(self.lines):
            rendered_lines.append("\n...")

        return RenderedSnippet(
            content="\n".join(rendered_lines),
            highlights=LineRange.merge(rendered_highlights),
            original_range=LineRange(
                self.lines_mapping[fragment_ranges[0].start_line],
                self.lines_mapping[fragment_ranges[-1].end_line - 1] + 1,
            ),
        )


def parse_snippet_directive(line: str) -> SnippetDirective | None:
    match = re.fullmatch(r"```snippet\s+\{([^}]+)\}\s*(?:```|\n```)", line.strip())
    if not match:
        return None

    attrs = {key: value for key, value in re.findall(r'(\w+)="([^"]*)"', match.group(1))}
    if "path" not in attrs:
        raise ValueError('Missing required key "path" in snippet directive')

    extra_attrs = {key: value for key, value in attrs.items() if key not in {"path", "title", "fragment", "highlight"}}
    return SnippetDirective(
        path=attrs["path"],
        title=attrs.get("title"),
        fragment=attrs.get("fragment"),
        highlight=attrs.get("highlight"),
        extra_attrs=extra_attrs or None,
    )


def parse_file_sections(file_path: Path) -> ParsedFile:
    input_lines = file_path.read_text(encoding="utf-8").splitlines()
    output_lines: list[str] = []
    lines_mapping: dict[int, int] = {}
    sections: dict[str, list[LineRange]] = {}
    section_starts: dict[str, int] = {}

    output_line_number = 0
    for source_line_number, line in enumerate(input_lines):
        section_match = re.search(r'\s*(?:###|///)\s*\[([^]]+)]\s*$', line)
        if section_match is None:
            output_lines.append(line)
            lines_mapping[output_line_number] = source_line_number
            output_line_number += 1
            continue

        line_before_marker = line[: section_match.start()]
        for section_name in section_match.group(1).split(","):
            section_name = section_name.strip()
            if section_name.startswith("/"):
                start_line = section_starts.pop(section_name[1:], None)
                if start_line is None:
                    raise ValueError(f"Cannot end unstarted section {section_name!r} at {file_path}")
                end_line = output_line_number + 1 if line_before_marker else output_line_number
                sections.setdefault(section_name[1:], []).append(LineRange(start_line, end_line))
            else:
                if section_name in section_starts:
                    raise ValueError(f"Cannot nest section {section_name!r} at {file_path}")
                section_starts[section_name] = output_line_number

        if line_before_marker:
            output_lines.append(line_before_marker)
            lines_mapping[output_line_number] = source_line_number
            output_line_number += 1

    if section_starts:
        raise ValueError(f"Some sections were not finished in {file_path}: {list(section_starts)}")

    return ParsedFile(lines=output_lines, sections=sections, lines_mapping=lines_mapping)


def format_highlight_lines(highlight_ranges: list[LineRange]) -> str:
    parts: list[str] = []
    for highlight_range in highlight_ranges:
        start_line = highlight_range.start_line + 1
        end_line = highlight_range.end_line
        parts.append(str(start_line) if start_line == end_line else f"{start_line}-{end_line}")
    return " ".join(parts)


def inject_snippets(markdown: str, relative_path_root: Path) -> str:
    def replace_snippet(match: re.Match[str]) -> str:
        directive = parse_snippet_directive(match.group(0))
        if directive is None:
            return match.group(0)

        file_path = _resolve_snippet_path(directive.path, relative_path_root)
        parsed_file = parse_file_sections(file_path)
        rendered = parsed_file.render(
            directive.fragment.split() if directive.fragment else [],
            directive.highlight.split() if directive.highlight else [],
        )

        attrs: list[str] = []
        title = directive.title or _default_title(file_path, rendered.original_range, bool(directive.fragment))
        if title:
            attrs.append(f'title="{title}"')
        if rendered.highlights:
            attrs.append(f'hl_lines="{format_highlight_lines(rendered.highlights)}"')
        if directive.extra_attrs:
            attrs.extend(f'{key}="{value}"' for key, value in directive.extra_attrs.items())

        attrs_text = f" {{{' '.join(attrs)}}}" if attrs else ""
        file_extension = file_path.suffix.lstrip(".") or "text"
        return f"```{file_extension}{attrs_text}\n{rendered.content}\n```"

    return SNIPPET_DIRECTIVE_PATTERN.sub(replace_snippet, markdown)


def _section_ranges(sections: dict[str, list[LineRange]], section_name: str) -> list[LineRange]:
    if section_name not in sections:
        raise ValueError(f"Unrecognized snippet section {section_name!r}; expected one of {list(sections)}")
    return sections[section_name]


def _resolve_snippet_path(path: str, relative_path_root: Path) -> Path:
    file_path = (REPO_ROOT / path[1:]).resolve() if path.startswith("/") else (relative_path_root / path).resolve()
    if not file_path.exists():
        raise FileNotFoundError(f"Snippet file {file_path} not found")
    if not file_path.is_relative_to(REPO_ROOT):
        raise ValueError(f"Snippet file {file_path} must be inside {REPO_ROOT}")
    return file_path


def _default_title(file_path: Path, original_range: LineRange, has_fragment: bool) -> str:
    relative_path = file_path.relative_to(REPO_ROOT)
    if not has_fragment:
        return str(relative_path)
    return f"{relative_path} (L{original_range.start_line + 1}-L{original_range.end_line})"

@@ -1,183 +0,0 @@
# Agenton API reference

This page summarizes the public Agenton API. Import paths are shown for the
symbols commonly used by layer authors and compositor callers.

## Layers: `agenton.layers`

### `Layer[DepsT, PromptT, UserPromptT, ToolT, ConfigT, RuntimeStateT, RuntimeHandlesT]`

Framework-neutral base class for prompt/tool layers.

Class attributes:

- `type_id: str | None`: registry id for config-backed plugin layers.
- `config_type: type[BaseModel]`: Pydantic schema for serialized layer config.
- `runtime_state_type: type[BaseModel]`: Pydantic schema for snapshot-safe
  per-session state.
- `runtime_handles_type: type[BaseModel]`: Pydantic schema for live runtime
  handles; use `arbitrary_types_allowed=True` for client/process objects.
- `deps_type: type[LayerDeps]`: inferred from the layer generic base or declared
  explicitly.

Construction and dependency APIs:

- `from_config(config: ConfigT) -> Self`: create a layer from schema-validated
  config. The default implementation raises `TypeError`.
- `dependency_names() -> frozenset[str]`: dependency fields declared by
  `deps_type`.
- `bind_deps(deps: Mapping[str, Layer | None]) -> None`: bind graph dependencies.
- `new_control(state=LifecycleState.NEW, runtime_state=None) -> LayerControl`: create
  a schema-validated per-session control.

Lifecycle hooks:

- `on_context_create(control)`
- `on_context_resume(control)`
- `on_context_suspend(control)`
- `on_context_delete(control)`
- `enter(control)` / `lifecycle_enter(control)`: async context manager entry
  surface. Override `enter()` only when a layer needs to wrap extra resources.

Prompt/tool authoring surfaces:

- `prefix_prompts -> Sequence[PromptT]`
- `suffix_prompts -> Sequence[PromptT]`
- `user_prompts -> Sequence[UserPromptT]`
- `tools -> Sequence[ToolT]`

Aggregation adapters implemented by typed layer families:

- `wrap_prompt(prompt: PromptT) -> object`
- `wrap_user_prompt(prompt: UserPromptT) -> object`
- `wrap_tool(tool: ToolT) -> object`

### `LayerControl[RuntimeStateT, RuntimeHandlesT]`

Per-layer, per-session lifecycle control.

Fields:

- `state: LifecycleState`
- `exit_intent: ExitIntent`
- `runtime_state: RuntimeStateT`
- `runtime_handles: RuntimeHandlesT`

Methods:

- `suspend_on_exit() -> None`
- `delete_on_exit() -> None`

`runtime_state` is serialized in session snapshots. `runtime_handles` is never
serialized and should be rehydrated from runtime state in resume hooks.
|
||||
|
||||
### Schema defaults and lifecycle enums
|
||||
|
||||
- `EmptyLayerConfig`
|
||||
- `EmptyRuntimeState`
|
||||
- `EmptyRuntimeHandles`
|
||||
- `LifecycleState`: `NEW`, `ACTIVE`, `SUSPENDED`, `CLOSED`
|
||||
- `ExitIntent`: `DELETE`, `SUSPEND`
|
||||
|
||||
### Typed layer families: `agenton.layers.types`
|
||||
|
||||
- `PlainLayer[DepsT, ConfigT, RuntimeStateT, RuntimeHandlesT]`
|
||||
- `PydanticAILayer[DepsT, AgentDepsT, ConfigT, RuntimeStateT, RuntimeHandlesT]`
|
||||
|
||||
Tagged aggregate item types:
|
||||
|
||||
- `PlainPromptType`, `PlainUserPromptType`, `PlainToolType`
|
||||
- `PydanticAIPromptType`, `PydanticAIUserPromptType`, `PydanticAIToolType`
|
||||
- `AllPromptTypes`, `AllUserPromptTypes`, `AllToolTypes`
|
||||
|
||||
## Compositor: `agenton.compositor`
|
||||
|
||||
### Config models
|
||||
|
||||
- `LayerNodeConfig`: `name`, `type`, `config`, `deps`, `metadata`
|
||||
- `CompositorConfig`: `schema_version`, `layers`
|
||||
|
||||
Config nodes are pure serializable graph input. Use live instances for Python
|
||||
objects and callables.
|
||||
|
||||
### Registry
|
||||
|
||||
`LayerRegistry` manually registers config-backed layer classes.
|
||||
|
||||
- `register_layer(layer_type, type_id=None) -> None`
|
||||
- `resolve(type_id) -> LayerDescriptor`
|
||||
- `descriptors() -> Mapping[str, LayerDescriptor]`
|
||||
|
||||
`LayerDescriptor` exposes `type_id`, `layer_type`, `config_type`,
|
||||
`runtime_state_type`, and `runtime_handles_type`.
|
||||
|
||||
### Builder
|
||||
|
||||
`CompositorBuilder(registry)` mixes config-backed nodes and live instances.
|
||||
|
||||
- `add_config(config) -> Self`
|
||||
- `add_config_layer(name, type, config=None, deps=None) -> Self`
|
||||
- `add_instance(name, layer, deps=None) -> Self`
|
||||
- `build(prompt_transformer=None, user_prompt_transformer=None, tool_transformer=None) -> Compositor`
|
||||
|
||||
### Compositor
|
||||
|
||||
`Compositor[PromptT, ToolT, LayerPromptT, LayerToolT, UserPromptT, LayerUserPromptT]`
|
||||
owns the ordered layer graph.
|
||||
|
||||
Construction:
|
||||
|
||||
- `Compositor(layers=..., deps_name_mapping=..., ...)`
|
||||
- `Compositor.from_config(conf, registry=..., ...)`
|
||||
|
||||
Aggregation properties:
|
||||
|
||||
- `prompts -> list[PromptT]`: prefix prompts in layer order, suffix prompts in
|
||||
reverse layer order, then optional `prompt_transformer`.
|
||||
- `user_prompts -> list[UserPromptT]`: user prompts in layer order, then optional
|
||||
`user_prompt_transformer`.
|
||||
- `tools -> list[ToolT]`: tools in layer order, then optional `tool_transformer`.
|
||||
|
||||
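For example, with two `plain.prompt` layers named `a` then `b`, prefix prompts
aggregate first-to-last and suffix prompts last-to-first — a minimal sketch (the
`PromptLayerConfig` import location is assumed):

```python {test="skip" lint="skip"}
a = PromptLayer.from_config(PromptLayerConfig(prefix="A-prefix", suffix="A-suffix"))
b = PromptLayer.from_config(PromptLayerConfig(prefix="B-prefix", suffix="B-suffix"))
compositor = Compositor(layers=OrderedDict([("a", a), ("b", b)]))

# Tagged items arrive in order: A-prefix, B-prefix, B-suffix, A-suffix
print([item.value for item in compositor.prompts])
```
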
Session APIs:

- `new_session() -> CompositorSession`
- `enter(session=None) -> AsyncIterator[CompositorSession]`
- `snapshot_session(session) -> CompositorSessionSnapshot`
- `session_from_snapshot(snapshot) -> CompositorSession`

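A minimal sketch of the session round trip, reusing a compositor built as in the
user guide:

```python {test="skip" lint="skip"}
session = compositor.new_session()
async with compositor.enter(session) as active_session:
    active_session.suspend_on_exit()  # keep layer state for later resumption

snapshot = compositor.snapshot_session(session)  # rejects active sessions
restored = compositor.session_from_snapshot(snapshot)
async with compositor.enter(restored):
    ...  # suspended layers resume through on_context_resume
```
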
### Sessions and snapshots

`CompositorSession` owns ordered layer controls.

- `suspend_on_exit() -> None`
- `delete_on_exit() -> None`
- `layer(name) -> LayerControl`

Snapshot models:

- `LayerSessionSnapshot`: `name`, `state`, `runtime_state`
- `CompositorSessionSnapshot`: `schema_version`, `layers`

Snapshots reject active sessions and exclude `runtime_handles` and `exit_intent`.

## Collection layers and transformers

### Plain layers: `agenton_collections.layers.plain`

- `PromptLayer`: config-backed layer with `PromptLayerConfig(prefix, user,
  suffix)` and `type_id = "plain.prompt"`.
- `ObjectLayer`: instance-only layer for Python objects.
- `ToolsLayer`: instance-only layer for callables.
- `DynamicToolsLayer`: instance-only layer for object-bound callables.

### Pydantic AI bridge

`agenton_collections.layers.pydantic_ai.PydanticAIBridgeLayer` exposes
pydantic-ai system prompts, user prompts, and tools while depending on an
`ObjectLayer` for `RunContext.deps`.

`agenton_collections.transformers.PYDANTIC_AI_TRANSFORMERS` provides:

- `prompt_transformer`: maps `compositor.prompts` to pydantic-ai system prompt functions.
- `user_prompt_transformer`: maps `compositor.user_prompts` to pydantic-ai `UserContent`.
- `tool_transformer`: maps `compositor.tools` to pydantic-ai tools.
@@ -1,19 +0,0 @@
# Agenton examples

The Agenton examples live under `examples/agenton/agenton_examples` and are kept
importable as a package so documentation can reference real source files.

## Basics

```snippet {path="/examples/agenton/agenton_examples/basics.py"}
```

## Pydantic AI bridge

```snippet {path="/examples/agenton/agenton_examples/pydantic_ai_bridge.py"}
```

## Session snapshots

```snippet {path="/examples/agenton/agenton_examples/session_snapshot.py"}
```
@@ -1,117 +0,0 @@
# Agenton user guide

Agenton composes shared `Layer` instances into a named graph. Treat layer
instances as reusable capability definitions: config and dependency declarations
belong on the layer class or instance, while per-session runtime values belong
on the `LayerControl` created for that layer in a `CompositorSession`.

## Config, runtime state, and runtime handles

- **Config** is serializable graph input. Config-constructible layers declare a
  `type_id` and a Pydantic `config_type`; builders validate node config before
  calling `Layer.from_config(validated_config)`.
- **Runtime state** is serializable per-layer/per-session state. Layers declare a
  Pydantic `runtime_state_type`; session snapshots persist this model with
  `model_dump(mode="json")`.
- **Runtime handles** are live Python objects such as clients, open files, or
  process handles. Layers declare a Pydantic `runtime_handles_type` with
  `arbitrary_types_allowed=True`. Handles are never serialized; resume hooks
  should rehydrate them from runtime state.

## Define a config-backed layer

Use a Pydantic model for config and pass it through the typed layer family so
`Layer.__init_subclass__` can infer the schema:

```python {test="skip" lint="skip"}
class GreetingConfig(BaseModel):
    prefix: str

    model_config = ConfigDict(extra="forbid")


@dataclass
class GreetingLayer(PlainLayer[NoLayerDeps, GreetingConfig]):
    type_id = "example.greeting"
    prefix: str

    @classmethod
    def from_config(cls, config: GreetingConfig) -> Self:
        return cls(prefix=config.prefix)

    @property
    def prefix_prompts(self) -> list[str]:
        return [self.prefix]
```

Omitted schema slots default to `EmptyLayerConfig`, `EmptyRuntimeState`, and
`EmptyRuntimeHandles`. Lifecycle hooks can annotate controls as
`LayerControl[MyState, MyHandles]` to get static checking and IDE completion for
runtime state and handles.

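For instance, a minimal sketch (`MyState` and `MyHandles` are hypothetical
schemas following the conventions above):

```python {test="skip" lint="skip"}
class MyState(BaseModel):
    counter: int = 0

    model_config = ConfigDict(extra="forbid")


class MyHandles(BaseModel):
    client: object | None = None  # live handle slot; never serialized

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)


@dataclass
class CountingLayer(PlainLayer[NoLayerDeps]):
    runtime_state_type: ClassVar[type[BaseModel]] = MyState
    runtime_handles_type: ClassVar[type[BaseModel]] = MyHandles

    @override
    async def on_context_create(self, control: LayerControl[MyState, MyHandles]) -> None:
        # The typed annotation gives static checking for both schemas.
        control.runtime_state.counter += 1
```
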
## Register layers and build a compositor

Register config-constructible layers manually:

```python {test="skip" lint="skip"}
registry = LayerRegistry()
registry.register_layer(PromptLayer)  # uses PromptLayer.type_id == "plain.prompt"
```

Use `CompositorBuilder` to mix serializable config nodes with live instances:

```python {test="skip" lint="skip"}
compositor = (
    CompositorBuilder(registry)
    .add_config(
        {
            "layers": [
                {
                    "name": "prompt",
                    "type": "plain.prompt",
                    "config": {"prefix": "Hi", "user": "Answer with examples."},
                }
            ]
        }
    )
    .add_instance(name="profile", layer=ObjectLayer(profile))
    .build()
)
```

Use `.add_instance()` for layers that require Python objects or callables, such
as `ObjectLayer`, `ToolsLayer`, and dynamic tool layers.

## System prompts and user prompts

Layers expose three prompt surfaces:

- `prefix_prompts`: system prompt fragments collected in layer order.
- `suffix_prompts`: system prompt fragments collected in reverse layer order.
- `user_prompts`: user-message fragments collected in layer order.

`PromptLayer` accepts `prefix`, `user`, and `suffix` config fields. For
pydantic-ai, `PYDANTIC_AI_TRANSFORMERS` maps `compositor.prompts` to system
prompt functions and `compositor.user_prompts` to values suitable for
`Agent.run(user_prompt=...)`.

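A minimal sketch of that wiring, condensed from the pydantic-ai bridge example
shipped with the repository (`builder`, `model`, and `bridge_layer` stand in for
objects constructed as shown elsewhere in this guide):

```python {test="skip" lint="skip"}
compositor = builder.build(**PYDANTIC_AI_TRANSFORMERS)

agent = Agent(model=model, deps_type=AgentProfile, tools=compositor.tools)
for prompt in compositor.prompts:
    _ = agent.system_prompt(prompt)  # transformed system prompt functions

result = await agent.run(compositor.user_prompts, deps=bridge_layer.run_deps)
```
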
## Session snapshot and restore

`Compositor.snapshot_session(session)` serializes non-active sessions, including
layer lifecycle state and runtime state. It rejects active sessions because live
handles cannot be snapshotted safely. Restore with
`Compositor.session_from_snapshot(snapshot)`; restored controls validate runtime
state with each layer schema and initialize empty runtime handles. Suspended
sessions resume through `on_context_resume`, where handles should be hydrated
from the restored runtime state.

Create sessions with `Compositor.new_session()` or
`Compositor.session_from_snapshot()`. `Compositor.enter()` validates that every
session control uses the target layer's runtime state and handle schemas before
any lifecycle hook runs.

See also:

- `examples/agenton/agenton_examples/basics.py`
- `examples/agenton/agenton_examples/pydantic_ai_bridge.py`
- `examples/agenton/agenton_examples/session_snapshot.py`
@@ -1,6 +0,0 @@
# Agenton documentation

- [User guide](guide/index.md) explains how to compose layers, register config-backed
  plugins, use system/user prompts, and snapshot sessions.
- [API reference](api/index.md) lists the public Agenton classes, methods, and extension
  points.
@@ -1,186 +0,0 @@
# Dify Agent Run API

The Dify Agent API exposes asynchronous agent runs backed by Agenton compositor
configuration, Pydantic AI runtime execution, Redis run records, and per-run Redis
Streams event logs. The FastAPI application lives at
`dify-agent/src/dify_agent/server/app.py`.

## Input model

Create-run requests accept a `CompositorConfig` and an optional
`CompositorSessionSnapshot`. There is **no top-level `user_prompt` field**.
User input must be supplied by Agenton layers. In the MVP server, the safe
config-constructible layer registry includes `plain.prompt`; its `config.user`
field becomes `Compositor.user_prompts` and is passed to Pydantic AI as the run
input.

Blank user input is rejected. A request with no user prompt, an empty string, or
only whitespace strings such as `"user": ["", " "]` returns `422` before a run
record is created.

The server does not implement a Pydantic AI history layer. Resumable Agenton
state is represented only by `session_snapshot`.

## Create a run

```http
POST /runs
Content-Type: application/json
```

Request:

```json
{
  "compositor": {
    "schema_version": 1,
    "layers": [
      {
        "name": "prompt",
        "type": "plain.prompt",
        "config": {
          "prefix": "You are a concise assistant.",
          "user": "Say hello from the Dify Agent API."
        }
      }
    ]
  },
  "session_snapshot": null,
  "agent_profile": {
    "provider": "test",
    "output_text": "Hello from the TestModel."
  }
}
```

Response (`202 Accepted`):

```json
{
  "run_id": "4a7f9a98-5c55-48d0-8f3e-87ef2cf81234",
  "status": "running"
}
```

The server persists the run record and schedules execution immediately in the
same FastAPI process. Redis is not used as a job queue. Run records and per-run
event streams expire after `DIFY_AGENT_RUN_RETENTION_SECONDS`, which defaults to
`259200` seconds (3 days).

`agent_profile.provider` currently supports the credential-free `test` profile.

Validation error example (`422`):

```json
{
  "detail": "compositor.user_prompts must not be empty"
}
```

## Get run status

```http
GET /runs/{run_id}
```

Response:

```json
{
  "run_id": "4a7f9a98-5c55-48d0-8f3e-87ef2cf81234",
  "status": "succeeded",
  "created_at": "2026-05-08T12:00:00Z",
  "updated_at": "2026-05-08T12:00:02Z",
  "error": null
}
```

Status values are:

- `running`
- `succeeded`
- `failed`

Unknown or expired run ids return `404` with `"run not found"`.

## Poll events

```http
GET /runs/{run_id}/events?after=0-0&limit=100
```

Cursor values are Redis Stream IDs. Use `after=0-0` to read from the beginning.
The response includes `next_cursor`; pass it as the next `after` value to continue
polling.

Response:

```json
{
  "run_id": "4a7f9a98-5c55-48d0-8f3e-87ef2cf81234",
  "events": [
    {
      "id": "1715170000000-0",
      "run_id": "4a7f9a98-5c55-48d0-8f3e-87ef2cf81234",
      "type": "run_started",
      "data": {},
      "created_at": "2026-05-08T12:00:00Z"
    }
  ],
  "next_cursor": "1715170000000-0"
}
```

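A minimal polling loop over this cursor contract, condensed from
`run_server_consumer.py` (`client` is an `httpx.AsyncClient` pointed at the
server):

```python
cursor = "0-0"
while True:
    response = await client.get(f"/runs/{run_id}/events", params={"after": cursor})
    response.raise_for_status()
    page = response.json()
    cursor = page["next_cursor"] or cursor  # resume from the last seen stream id
    if any(event["type"] in {"run_succeeded", "run_failed"} for event in page["events"]):
        break
    await asyncio.sleep(0.5)
```
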
## Stream events with SSE

```http
GET /runs/{run_id}/events/sse
```

SSE frames use the run event id as `id`, the event type as `event`, and the full
`RunEvent` JSON object as `data`:

```text
id: 1715170000000-0
event: run_started
data: {"id":"1715170000000-0","run_id":"...","type":"run_started","data":{},"created_at":"..."}

```

Replay can start from a cursor with either:

- `GET /runs/{run_id}/events/sse?after=1715170000000-0`
- `Last-Event-ID: 1715170000000-0`

If both are provided, the `after` query parameter takes precedence.

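A minimal resume sketch using the header form, adapted from
`run_server_sse_consumer.py` (the `Last-Event-ID` value is a previously observed
event id):

```python
headers = {"Last-Event-ID": "1715170000000-0"}
async with httpx.AsyncClient(base_url=API_BASE_URL, timeout=None) as client:
    async with client.stream("GET", f"/runs/{run_id}/events/sse", headers=headers) as response:
        response.raise_for_status()
        async for line in response.aiter_lines():
            print(line)  # raw SSE frames, starting after the given event id
```
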
## Event types and order

A normal successful run emits:

1. `run_started`
2. zero or more `pydantic_ai_event`
3. `agent_output`
4. `session_snapshot`
5. `run_succeeded`

A failed run emits:

1. `run_started`
2. zero or more `pydantic_ai_event`
3. `run_failed`

Each event keeps the same envelope shape and has typed `data`: `run_started` and
`run_succeeded` use `{}`, `pydantic_ai_event` uses Pydantic AI's
`AgentStreamEvent` union, `agent_output` uses `{ "output": string }`,
`session_snapshot` uses `CompositorSessionSnapshot`, and `run_failed` uses
`{ "error": string, "reason": string | null }`. The session snapshot can be sent
as `session_snapshot` in a later create-run request with the same compositor layer
names and order.

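For example, a follow-up create-run request that resumes from an earlier run
might look like this (a sketch: snapshot contents are abbreviated, and the exact
`state` spelling follows the serialized `LifecycleState` enum):

```json
{
  "compositor": {
    "schema_version": 1,
    "layers": [
      {"name": "prompt", "type": "plain.prompt", "config": {"user": "Continue from where we stopped."}}
    ]
  },
  "session_snapshot": {
    "schema_version": 1,
    "layers": [
      {"name": "prompt", "state": "suspended", "runtime_state": {}}
    ]
  },
  "agent_profile": {"provider": "test"}
}
```
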
## Consumer examples

See:

- `dify-agent/examples/dify_agent/dify_agent_examples/run_server_consumer.py` for cursor polling
- `dify-agent/examples/dify_agent/dify_agent_examples/run_server_sse_consumer.py` for SSE consumption
@@ -1,20 +0,0 @@
# Dify Agent examples

These examples live under `examples/dify_agent/dify_agent_examples`. They are
separated from Agenton examples because they depend on Dify Agent runtime services
such as the FastAPI server, Redis, or the plugin daemon.

## Run a Dify plugin-daemon backed model

```snippet {path="/examples/dify_agent/dify_agent_examples/run_pydantic_ai_agent.py"}
```

## Poll run events

```snippet {path="/examples/dify_agent/dify_agent_examples/run_server_consumer.py"}
```

## Stream run events with SSE

```snippet {path="/examples/dify_agent/dify_agent_examples/run_server_sse_consumer.py"}
```
@@ -1,127 +0,0 @@
# Operating the Dify Agent Run Server

This guide describes how to run the MVP Dify Agent API server. The server is
implemented in `dify-agent/src/dify_agent/server/app.py` and uses Redis for run
records and per-run event streams only.

## Default local startup

Start Redis, then run one FastAPI/uvicorn process:

```bash
uv run --project dify-agent uvicorn dify_agent.server.app:app --reload
```

By default, the FastAPI lifespan creates both:

- one Redis-backed run store used by HTTP routes
- one process-local scheduler that starts background `asyncio` run tasks

This means local development needs one uvicorn process plus Redis. Run execution
still happens outside request handlers, so client disconnects do not cancel the
agent run.

## Configuration

`ServerSettings` loads environment variables with the `DIFY_AGENT_` prefix. It
also reads `.env` and `dify-agent/.env` when present.

| Environment variable | Default | Description |
| --- | --- | --- |
| `DIFY_AGENT_REDIS_URL` | `redis://localhost:6379/0` | Redis connection URL. |
| `DIFY_AGENT_REDIS_PREFIX` | `dify-agent` | Prefix for Redis record and event keys. |
| `DIFY_AGENT_SHUTDOWN_GRACE_SECONDS` | `30` | Seconds to wait for active local runs during graceful shutdown before cancellation. |
| `DIFY_AGENT_RUN_RETENTION_SECONDS` | `259200` | Seconds to retain Redis run records and per-run event streams; defaults to 3 days. |

Example `.env`:

```env
DIFY_AGENT_REDIS_URL=redis://localhost:6379/0
DIFY_AGENT_REDIS_PREFIX=dify-agent-dev
DIFY_AGENT_SHUTDOWN_GRACE_SECONDS=30
DIFY_AGENT_RUN_RETENTION_SECONDS=259200
```

Run records and event streams use the same retention. Status writes refresh the
record TTL, and event writes refresh both the stream TTL and the corresponding
record TTL so active runs that keep producing events remain observable.

## Scheduling and shutdown semantics

`POST /runs` validates the compositor, persists a `running` run record, and starts
an `asyncio` task in the same process. There is no Redis job stream, consumer
group, pending reclaim, or automatic retry layer.

During FastAPI shutdown the scheduler rejects new runs, waits up to
`DIFY_AGENT_SHUTDOWN_GRACE_SECONDS` for active tasks, then cancels remaining tasks
and, on a best-effort basis, appends a `run_failed` event and marks the run
failed. A hard process crash can still leave active runs stuck as `running`;
there is no in-service recovery or worker handoff.

Horizontal scaling is possible by running multiple API processes against the same
Redis prefix, but each process executes only the runs it accepted. Redis provides
shared status/event visibility, not load balancing or queued-job recovery.

## Run inputs and session snapshots

The API does not accept a top-level `user_prompt`. Submit a `CompositorConfig`
whose Agenton layers provide user input. With the MVP registry, use
`plain.prompt` and its `config.user` field:

```json
{
  "compositor": {
    "schema_version": 1,
    "layers": [
      {
        "name": "prompt",
        "type": "plain.prompt",
        "config": {
          "prefix": "You are concise.",
          "user": "Summarize the current state."
        }
      }
    ]
  }
}
```

`config.user` can be a string or a list of strings. Empty or whitespace-only
effective prompts are rejected during create-run validation before the run is
persisted or scheduled.

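For example, the list form of `config.user` (a sketch; the parts together form
the effective user prompt):

```json
"config": {
  "prefix": "You are concise.",
  "user": ["Context: the run resumed from a snapshot.", "Question: what changed since then?"]
}
```
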
There is no Pydantic AI history layer. To resume Agenton layer state, pass the
`session_snapshot` emitted by a previous run together with a compositor that has
the same layer names and order.

## Observing runs

Use the HTTP status endpoint for coarse state and the event endpoints for detailed
progress:

- `POST /runs` creates a running run and schedules it locally.
- `GET /runs/{run_id}` returns `running`, `succeeded`, or `failed`.
- `GET /runs/{run_id}/events` polls the Redis Stream event log with `after` and
  `next_cursor` cursors.
- `GET /runs/{run_id}/events/sse` replays and streams events over SSE. The SSE
  `id` is the event Redis Stream ID. `after` query cursors take precedence over
  `Last-Event-ID` headers.

Successful runs emit `run_started`, zero or more `pydantic_ai_event`,
`agent_output`, `session_snapshot`, and `run_succeeded`. Failed runs end with
`run_failed`. Event envelopes retain `id`, `run_id`, `type`, `data`, and
`created_at`; `data` is typed per event type, including Pydantic AI's
`AgentStreamEvent` payload for `pydantic_ai_event` and `CompositorSessionSnapshot`
for `session_snapshot`.

## Examples

The repository includes simple consumers that print observed output/events:

- `dify-agent/examples/dify_agent/dify_agent_examples/run_server_consumer.py`
  creates a run and polls events.
- `dify-agent/examples/dify_agent/dify_agent_examples/run_server_sse_consumer.py`
  consumes raw SSE frames for an existing run id.

Both examples use the credential-free Pydantic AI `TestModel` profile; they still
require Redis and the API server.
@@ -1,8 +0,0 @@
# Dify Agent runtime

Dify Agent hosts Agenton-composed Pydantic AI runs behind a FastAPI API. Its
source code stays under `src/dify_agent`, while framework-neutral Agenton code
stays under `src/agenton` and `src/agenton_collections`.

See the [operations guide](guide/index.md) for local server behavior and the
[run API](api/index.md) for request and event schemas.
@@ -1,11 +0,0 @@
# Dify Agent

This documentation is split by ownership boundary:

- [Agenton](agenton/index.md) covers the framework-neutral layer compositor and reusable
  collection layers.
- [Dify Agent](dify-agent/index.md) covers the Dify runtime, HTTP API, Redis-backed run
  storage, and server examples.

The split mirrors the source tree so Agenton documentation and examples can be
moved together if Agenton is published separately later.
@@ -1 +0,0 @@
"""Runnable Agenton examples kept separate from Dify Agent runtime examples."""
@@ -1,50 +0,0 @@
"""Small CLI for listing or copying Agenton examples."""

from __future__ import annotations

import argparse
import shutil
import sys
from pathlib import Path

EXAMPLE_MODULES = (
    "basics",
    "pydantic_ai_bridge",
    "session_snapshot",
)


def cli() -> None:
    parser = argparse.ArgumentParser(
        prog="agenton_examples",
        description="List or copy Agenton examples.",
    )
    parser.add_argument("--copy-to", metavar="DEST", help="Copy example files to a new directory")
    args = parser.parse_args()

    examples_dir = Path(__file__).parent
    if args.copy_to:
        copy_to(examples_dir, Path(args.copy_to))
        return

    for module_name in EXAMPLE_MODULES:
        print(f"python -m agenton_examples.{module_name}")


def copy_to(examples_dir: Path, destination: Path) -> None:
    if destination.exists():
        print(f'Error: destination path "{destination}" already exists', file=sys.stderr)
        sys.exit(1)

    destination.mkdir(parents=True)
    copied = 0
    for source in examples_dir.glob("*.py"):
        if source.name == "__init__.py":
            continue
        shutil.copy2(source, destination / source.name)
        copied += 1
    print(f'Copied {copied} Agenton example files to "{destination}"')


if __name__ == "__main__":
    cli()
@@ -1,136 +0,0 @@
"""Run with: uv run --project dify-agent python -m agenton_examples.basics."""

from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from inspect import signature

from typing_extensions import override

from agenton.compositor import CompositorBuilder, LayerRegistry
from agenton.layers import LayerControl, LayerDeps, NoLayerDeps, PlainLayer
from agenton_collections.layers.plain import DynamicToolsLayer, ObjectLayer, PromptLayer, ToolsLayer, with_object


@dataclass(frozen=True, slots=True)
class AgentProfile:
    name: str
    audience: str
    tone: str


class ProfilePromptDeps(LayerDeps):
    profile: ObjectLayer[AgentProfile]  # pyright: ignore[reportUninitializedInstanceVariable]


@dataclass(slots=True)
class ProfilePromptLayer(PlainLayer[ProfilePromptDeps]):
    @property
    @override
    def prefix_prompts(self) -> list[str]:
        profile = self.deps.profile.value
        return [
            f"You are {profile.name}, writing for {profile.audience}.",
            f"Keep the tone {profile.tone}.",
        ]


@dataclass(slots=True)
class TraceLayer(PlainLayer[NoLayerDeps]):
    events: list[str] = field(default_factory=list)

    @override
    async def on_context_create(self, control: LayerControl) -> None:
        self.events.append("create")

    @override
    async def on_context_suspend(self, control: LayerControl) -> None:
        self.events.append("suspend")

    @override
    async def on_context_resume(self, control: LayerControl) -> None:
        self.events.append("resume")

    @override
    async def on_context_delete(self, control: LayerControl) -> None:
        self.events.append("delete")


def count_words(text: str) -> int:
    return len(text.split())


@with_object(AgentProfile)
def write_tagline(profile: AgentProfile, topic: str) -> str:
    return f"{profile.name}: {topic} for {profile.audience}, in a {profile.tone} voice."


async def main() -> None:
    profile = AgentProfile(
        name="Agenton Assistant",
        audience="engineers composing agent capabilities",
        tone="precise and friendly",
    )
    trace = TraceLayer()

    registry = LayerRegistry()
    registry.register_layer(PromptLayer)
    compositor = (
        CompositorBuilder(registry)
        .add_config(
            {
                "layers": [
                    {
                        "name": "base_prompt",
                        "type": "plain.prompt",
                        "config": {
                            "prefix": "Use config dicts for serializable layers.",
                            "user": "Explain how the composed agent should use its layers.",
                            "suffix": "Before finalizing, make the result easy to scan.",
                        },
                    },
                    {
                        "name": "extra_prompt",
                        "type": "plain.prompt",
                        "config": {
                            "prefix": "Use constructed instances for objects, local code, and callables.",
                        },
                    },
                ]
            }
        )
        .add_instance(name="profile", layer=ObjectLayer[AgentProfile](profile))
        .add_instance(name="profile_prompt", layer=ProfilePromptLayer())
        .add_instance(name="tools", layer=ToolsLayer(tool_entries=(count_words,)))
        .add_instance(
            name="dynamic_tools",
            deps={"object_layer": "profile"},
            layer=DynamicToolsLayer[AgentProfile](tool_entries=(write_tagline,)),
        )
        .add_instance(name="trace", layer=trace)
        .build()
    )

    print("Prompts:")
    for prompt in compositor.prompts:
        print(f"- {prompt.value}")

    print("\nUser prompts:")
    for prompt in compositor.user_prompts:
        print(f"- {prompt.value}")

    print("\nTools:")
    for tool in compositor.tools:
        print(f"- {tool.value.__name__}{signature(tool.value)}")
    print([tool.value("layer composition") for tool in compositor.tools])

    async with compositor.enter() as lifecycle_control:
        lifecycle_control.suspend_on_exit()
        async with compositor.enter(lifecycle_control):
            pass
    print("\nLifecycle:", trace.events)


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,119 +0,0 @@
"""Run with: uv run --project dify-agent python -m agenton_examples.pydantic_ai_bridge."""

from __future__ import annotations

import asyncio
import json
import os
from dataclasses import dataclass

from pydantic_ai import Agent, RunContext
from pydantic_ai.messages import BuiltinToolCallPart, ModelMessage, ToolCallPart
from pydantic_ai.models.openai import OpenAIChatModel  # pyright: ignore[reportDeprecated]
from pydantic_ai.models.test import TestModel

from agenton.compositor import CompositorBuilder, LayerRegistry
from agenton_collections.layers.plain import ObjectLayer, PromptLayer, ToolsLayer
from agenton_collections.layers.pydantic_ai import PydanticAIBridgeLayer
from agenton_collections.transformers import PYDANTIC_AI_TRANSFORMERS


@dataclass(frozen=True, slots=True)
class AgentProfile:
    name: str
    audience: str
    tone: str


def count_words(text: str) -> int:
    return len(text.split())


def profile_prompt(ctx: RunContext[AgentProfile]) -> str:
    profile = ctx.deps
    return f"You are {profile.name}, helping {profile.audience}."


def tone_prompt(ctx: RunContext[AgentProfile]) -> str:
    return f"Keep responses {ctx.deps.tone}."


def write_tagline(ctx: RunContext[AgentProfile], topic: str) -> str:
    profile = ctx.deps
    return f"{profile.name}: {topic} for {profile.audience}, in a {profile.tone} voice."


async def main() -> None:
    profile = AgentProfile(
        name="Agenton Assistant",
        audience="engineers composing agent capabilities",
        tone="precise and friendly",
    )
    pydantic_ai_bridge = PydanticAIBridgeLayer[AgentProfile](
        prefix=("Prefer concrete details.", profile_prompt, tone_prompt),
        user="Use the tools for 'layer composition'.",
        tool_entries=(write_tagline,),
    )

    registry = LayerRegistry()
    registry.register_layer(PromptLayer)
    compositor = (
        CompositorBuilder(registry)
        .add_config(
            {
                "layers": [
                    {
                        "name": "base_prompt",
                        "type": "plain.prompt",
                        "config": {
                            "prefix": "Use the available tools before answering.",
                            "suffix": "Return concise, inspectable output.",
                        },
                    },
                ]
            }
        )
        .add_instance(name="profile", layer=ObjectLayer[AgentProfile](profile))
        .add_instance(name="plain_tools", layer=ToolsLayer(tool_entries=(count_words,)))
        .add_instance(
            name="pydantic_ai_bridge",
            deps={"object_layer": "profile"},
            layer=pydantic_ai_bridge,
        )
        .build(**PYDANTIC_AI_TRANSFORMERS)
    )

    async with compositor.enter():
        model = (
            OpenAIChatModel("gpt-5.5")  # pyright: ignore[reportDeprecated]
            if os.getenv("OPENAI_API_KEY")
            else TestModel()
        )
        agent = Agent[AgentProfile](
            model=model,
            deps_type=AgentProfile,
            tools=compositor.tools,
        )
        for prompt in compositor.prompts:
            _ = agent.system_prompt(prompt)

        result = await agent.run(compositor.user_prompts, deps=pydantic_ai_bridge.run_deps)

        for line in _format_messages(result.all_messages()):
            print(line)


def _format_messages(messages: list[ModelMessage]) -> list[str]:
    lines: list[str] = []
    for message in messages:
        for part in message.parts:
            if isinstance(part, ToolCallPart | BuiltinToolCallPart):
                args = json.dumps(part.args, ensure_ascii=False)
                lines.append(f"{type(part).__name__}: {part.tool_name}({args})")
            else:
                lines.append(f"{type(part).__name__}: {part.content}")
    return lines


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,72 +0,0 @@
"""Run with: uv run --project dify-agent python -m agenton_examples.session_snapshot."""

from __future__ import annotations

import asyncio
from collections import OrderedDict
from dataclasses import dataclass
from typing import ClassVar

from pydantic import BaseModel, ConfigDict
from typing_extensions import override

from agenton.compositor import Compositor
from agenton.layers import LayerControl, NoLayerDeps, PlainLayer, PlainPromptType, PlainToolType


class ConnectionState(BaseModel):
    connection_id: str = "demo-connection"

    model_config = ConfigDict(extra="forbid", validate_assignment=True)


class ConnectionHandle:
    def __init__(self, connection_id: str) -> None:
        self.connection_id = connection_id


class ConnectionHandles(BaseModel):
    connection: ConnectionHandle | None = None

    model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)


@dataclass(slots=True)
class ConnectionLayer(PlainLayer[NoLayerDeps]):
    runtime_state_type: ClassVar[type[BaseModel]] = ConnectionState
    runtime_handles_type: ClassVar[type[BaseModel]] = ConnectionHandles

    @override
    async def on_context_create(self, control: LayerControl) -> None:
        assert isinstance(control.runtime_state, ConnectionState)
        assert isinstance(control.runtime_handles, ConnectionHandles)
        control.runtime_handles.connection = ConnectionHandle(control.runtime_state.connection_id)

    @override
    async def on_context_resume(self, control: LayerControl) -> None:
        assert isinstance(control.runtime_state, ConnectionState)
        assert isinstance(control.runtime_handles, ConnectionHandles)
        control.runtime_handles.connection = ConnectionHandle(f"restored:{control.runtime_state.connection_id}")


async def main() -> None:
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(
        layers=OrderedDict([("connection", ConnectionLayer())])
    )
    session = compositor.new_session()
    async with compositor.enter(session) as active_session:
        active_session.suspend_on_exit()

    snapshot = compositor.snapshot_session(session)
    print("Snapshot:", snapshot.model_dump(mode="json"))

    restored = compositor.session_from_snapshot(snapshot)
    async with compositor.enter(restored):
        handles = restored.layer("connection").runtime_handles
        assert isinstance(handles, ConnectionHandles)
        assert handles.connection is not None
        print("Rehydrated handle:", handles.connection.connection_id)


if __name__ == "__main__":
    asyncio.run(main())
@@ -1 +0,0 @@
"""Runnable Dify Agent runtime examples kept separate from Agenton examples."""
@@ -1,50 +0,0 @@
"""Small CLI for listing or copying Dify Agent examples."""

from __future__ import annotations

import argparse
import shutil
import sys
from pathlib import Path

EXAMPLE_MODULES = (
    "run_pydantic_ai_agent",
    "run_server_consumer",
    "run_server_sse_consumer",
)


def cli() -> None:
    parser = argparse.ArgumentParser(
        prog="dify_agent_examples",
        description="List or copy Dify Agent runtime examples.",
    )
    parser.add_argument("--copy-to", metavar="DEST", help="Copy example files to a new directory")
    args = parser.parse_args()

    examples_dir = Path(__file__).parent
    if args.copy_to:
        copy_to(examples_dir, Path(args.copy_to))
        return

    for module_name in EXAMPLE_MODULES:
        print(f"python -m dify_agent_examples.{module_name}")


def copy_to(examples_dir: Path, destination: Path) -> None:
    if destination.exists():
        print(f'Error: destination path "{destination}" already exists', file=sys.stderr)
        sys.exit(1)

    destination.mkdir(parents=True)
    copied = 0
    for source in examples_dir.glob("*.py"):
        if source.name == "__init__.py":
            continue
        shutil.copy2(source, destination / source.name)
        copied += 1
    print(f'Copied {copied} Dify Agent example files to "{destination}"')


if __name__ == "__main__":
    cli()
@@ -1,78 +0,0 @@
"""Run a Pydantic AI agent through the Dify plugin-daemon adapter.

Prerequisites:
- Start the plugin daemon from `dify-aio/dify/docker/docker-compose.middleware.yaml`.
- Run the Dify API with `dify-aio/dify/api/.env` so the daemon can resolve tenants/plugins.
- Fill `dify-agent/.env` with a real tenant, plugin, provider, model, and provider credentials.

Example:
    uv run --project dify-agent python -m dify_agent_examples.run_pydantic_ai_agent
"""

from __future__ import annotations

import asyncio
import json
import os
from pathlib import Path
from typing import Any

from pydantic_ai import Agent

from dify_agent import DifyLLMAdapterModel, DifyPluginDaemonProvider


PROJECT_ROOT = Path(__file__).resolve().parents[3]


def load_env_file(path: Path) -> None:
    """Load simple KEY=VALUE lines without adding a dotenv dependency."""
    if not path.exists():
        return

    for raw_line in path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'"))


def required_env(name: str) -> str:
    value = os.environ.get(name)
    if value:
        return value
    raise RuntimeError(f"Missing required environment variable: {name}")


def load_credentials() -> dict[str, Any]:
    raw_credentials = required_env("DIFY_AGENT_MODEL_CREDENTIALS_JSON")
    credentials = json.loads(raw_credentials)
    if not isinstance(credentials, dict):
        raise RuntimeError("DIFY_AGENT_MODEL_CREDENTIALS_JSON must be a JSON object")
    return credentials


async def main() -> None:
    load_env_file(PROJECT_ROOT / ".env")

    model = DifyLLMAdapterModel(
        required_env("DIFY_AGENT_MODEL_NAME"),
        DifyPluginDaemonProvider(
            tenant_id=required_env("DIFY_AGENT_TENANT_ID"),
            plugin_id=required_env("DIFY_AGENT_PLUGIN_ID"),
            plugin_provider=required_env("DIFY_AGENT_PROVIDER"),
            plugin_daemon_url=required_env("PLUGIN_DAEMON_URL"),
            plugin_daemon_api_key=required_env("PLUGIN_DAEMON_KEY"),
        ),
        credentials=load_credentials(),
    )
    agent = Agent(model=model)
    async with agent.run_stream("Explain the theory of relativity") as run:
        async for piece in run.stream_output():
            print(piece, end="", flush=True)
        print(run.usage())


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,59 +0,0 @@
"""Example consumer for the Dify Agent run server.

Requires Redis and a running API server. The server schedules runs in-process, for
example:

    uv run --project dify-agent uvicorn dify_agent.server.app:app --reload

The default request uses the credential-free pydantic-ai TestModel profile. This
script prints the created run and every event observed through cursor polling.
"""

import asyncio

import httpx


API_BASE_URL = "http://localhost:8000"


async def main() -> None:
    async with httpx.AsyncClient(base_url=API_BASE_URL, timeout=30) as client:
        create_response = await client.post(
            "/runs",
            json={
                "compositor": {
                    "schema_version": 1,
                    "layers": [
                        {
                            "name": "prompt",
                            "type": "plain.prompt",
                            "config": {
                                "prefix": "You are a concise assistant.",
                                "user": "Say hello from the Dify Agent API server example.",
                            },
                        }
                    ],
                },
                "agent_profile": {"provider": "test", "output_text": "Hello from the example TestModel."},
            },
        )
        create_response.raise_for_status()
        run = create_response.json()
        print("created run", run)

        cursor = "0-0"
        while True:
            events_response = await client.get(f"/runs/{run['run_id']}/events", params={"after": cursor})
            events_response.raise_for_status()
            page = events_response.json()
            cursor = page["next_cursor"] or cursor
            for event in page["events"]:
                print("event", event)
                if event["type"] in {"run_succeeded", "run_failed"}:
                    return
            await asyncio.sleep(0.5)


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,26 +0,0 @@
"""SSE consumer sketch for the Dify Agent run server.

Create a run with ``run_server_consumer.py`` or any HTTP client, then set RUN_ID
below and run this script while the server is available. It prints raw SSE frames
without requiring model credentials.
"""

import asyncio

import httpx


API_BASE_URL = "http://localhost:8000"
RUN_ID = "replace-with-run-id"


async def main() -> None:
    async with httpx.AsyncClient(base_url=API_BASE_URL, timeout=None) as client:
        async with client.stream("GET", f"/runs/{RUN_ID}/events/sse") as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                print(line)


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,64 +0,0 @@
site_name: Dify Agent
site_description: Agent runtime and Agenton composition framework documentation
strict: true

repo_name: langgenius/dify
repo_url: https://github.com/langgenius/dify
edit_uri: edit/main/dify-agent/docs/

nav:
  - Home: index.md
  - Agenton:
      - Overview: agenton/index.md
      - Guide: agenton/guide/index.md
      - API Reference: agenton/api/index.md
      - Examples: agenton/examples/index.md
  - Dify Agent:
      - Overview: dify-agent/index.md
      - Operations Guide: dify-agent/guide/index.md
      - Run API: dify-agent/api/index.md
      - Examples: dify-agent/examples/index.md

theme:
  name: material
  features:
    - content.code.copy
    - content.tabs.link
    - navigation.indexes
    - navigation.sections
    - navigation.tracking
    - toc.follow

markdown_extensions:
  - admonition
  - attr_list
  - md_in_html
  - pymdownx.details
  - pymdownx.highlight:
      pygments_lang_class: true
  - pymdownx.inlinehilite
  - pymdownx.superfences
  - pymdownx.tabbed:
      alternate_style: true
  - tables

plugins:
  - search
  - mkdocstrings:
      handlers:
        python:
          paths:
            - src
          options:
            docstring_style: google
            members_order: source
            separate_signature: true
            show_signature_annotations: true
            signature_crossrefs: true

watch:
  - src
  - examples

hooks:
  - docs/.hooks/main.py
@@ -1,75 +0,0 @@
[project]
name = "dify-agent"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "anyio>=4.13.0",
    "fastapi>=0.136.0",
    "graphon~=0.2.2",
    "httpx>=0.28.1",
    "logfire>=4.32.1",
    "pydantic>=2.13.3",
    "pydantic-ai-slim[anthropic,google,openai]>=1.85.1",
    "pydantic-settings>=2.12.0",
    "redis>=5",
    "sqlmodel>=0.0.38",
    "uvicorn[standard]>=0.38.0",
    "uvloop>=0.22.1",
]

[tool.setuptools.packages.find]
where = ["src", "examples/agenton", "examples/dify_agent"]
include = [
    "agenton*",
    "agenton_collections*",
    "agenton_examples*",
    "dify_agent*",
    "dify_agent_examples*",
]

[tool.pyright]
include = ["src", "examples", "tests"]
venvPath = "."
venv = ".venv"
pythonVersion = "3.12"
extraPaths = [
    "src",
    "examples/agenton",
    "examples/dify_agent",
]

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]

[tool.ruff]
line-length = 120
target-version = "py312"
include = [
    "src/**/*.py",
    "examples/**/*.py",
    "tests/**/*.py",
    "docs/**/*.py",
]

[dependency-groups]
dev = [
    "basedpyright>=1.39.3",
    "coverage[toml]>=7.10.7",
    "pytest>=9.0.3",
    "pytest-examples>=0.0.18",
    "pytest-mock>=3.14.0",
    "ruff>=0.15.11",
]
docs = [
    "mkdocs>=1.6.1,<2",
    "mkdocs-glightbox>=0.4.0",
    "mkdocs-material>=9.7.0",
    "mkdocstrings-python>=2.0.0",
]

[build-system]
requires = ["setuptools>=61"]
build-backend = "setuptools.build_meta"
@@ -1,569 +0,0 @@
"""Layer composition primitives.

The compositor owns a named, ordered set of layers. ``Compositor[PromptT,
ToolT, LayerPromptT, LayerToolT]`` is framework-neutral; callers choose layer and
exposed prompt/tool item types by annotating construction or assignment sites.
When only the first two type arguments are supplied, ``LayerPromptT`` and
``LayerToolT`` default to the corresponding exposed item types.

Layer instances are shared graph/capability definitions owned by the compositor.
Per-session runtime state belongs to each session's ``LayerControl`` objects,
not to the shared layer instances, so different sessions can enter the same
compositor without leaking generated ids or handles through ``self``.

Dependency mappings use layer-local dependency names as keys and compositor
layer names as values. System prompt aggregation depends on insertion order:
prefix prompts are collected from first to last layer, while suffix prompts are
collected in reverse. User prompts are collected from first to last layer so the
composed user message preserves graph order.

Serializable graph config uses registry type ids rather than import paths.
``CompositorBuilder`` resolves config nodes through ``LayerRegistry`` and can
mix those nodes with live layer instances for Python objects and callables.

``Compositor.enter`` enters layers in compositor order and exits them in reverse
order through ``AsyncExitStack``. It accepts an optional ``CompositorSession``
whose layer controls must match the compositor layer names and order. When
omitted, a fresh session is created. Reusing a suspended session resumes its
layer contexts; closed sessions must be replaced.

Optional prompt, user prompt, and tool transformers run after layer aggregation.
The compositor asks each layer to ``wrap_prompt``, ``wrap_user_prompt``, and
``wrap_tool`` its native values, so typed layer families can tag values without
changing their authoring contracts. When transformers are omitted, the
compositor returns those wrapped items unchanged.
"""

from collections import OrderedDict
from collections.abc import AsyncIterator, Callable, Iterable, Mapping as MappingABC, Sequence
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Generic, Mapping, TypedDict, cast

from pydantic import BaseModel, ConfigDict, Field, JsonValue
from typing_extensions import Self, TypeVar

from agenton.layers.base import Layer, LayerControl, LifecycleState
from agenton.layers.types import AllPromptTypes, AllToolTypes, AllUserPromptTypes

PromptT = TypeVar("PromptT", default=AllPromptTypes)
ToolT = TypeVar("ToolT", default=AllToolTypes)
LayerPromptT = TypeVar("LayerPromptT", default=AllPromptTypes)
LayerToolT = TypeVar("LayerToolT", default=AllToolTypes)
UserPromptT = TypeVar("UserPromptT", default=AllUserPromptTypes)
LayerUserPromptT = TypeVar("LayerUserPromptT", default=AllUserPromptTypes)


type CompositorTransformer[InputT, OutputT] = Callable[[Sequence[InputT]], Sequence[OutputT]]


class CompositorTransformerKwargs[
    PromptT,
    ToolT,
    LayerPromptT,
    LayerToolT,
    UserPromptT,
    LayerUserPromptT,
](TypedDict):
    """Keyword arguments that install prompt, user prompt, and tool transformers."""

    prompt_transformer: CompositorTransformer[LayerPromptT, PromptT]
    user_prompt_transformer: CompositorTransformer[LayerUserPromptT, UserPromptT]
    tool_transformer: CompositorTransformer[LayerToolT, ToolT]


type _ConfigModelValue[ModelT: BaseModel] = ModelT | JsonValue | str | bytes


def _validate_config_model_input[ModelT: BaseModel](
    model_type: type[ModelT],
    value: _ConfigModelValue[ModelT] | Mapping[str, object],
) -> ModelT:
    if isinstance(value, model_type):
        return value
    if isinstance(value, str | bytes):
        return model_type.model_validate_json(value)

    return model_type.model_validate(value)


class LayerNodeConfig(BaseModel):
    """Serializable config for one registry-backed layer node."""

    name: str
    type: str
    config: JsonValue = Field(default_factory=dict)
    deps: Mapping[str, str] = Field(default_factory=dict)
    metadata: Mapping[str, JsonValue] = Field(default_factory=dict)

    model_config = ConfigDict(extra="forbid")


class CompositorConfig(BaseModel):
    """Serializable config for constructing a compositor graph.

    The graph references layer implementations by registry type id. Live Python
    objects and callables are intentionally excluded; compose those with
    ``CompositorBuilder.add_instance``.
    """

    schema_version: int = 1
    layers: list[LayerNodeConfig]

    model_config = ConfigDict(extra="forbid")


type CompositorConfigValue = _ConfigModelValue[CompositorConfig] | Mapping[str, object]


def _validate_compositor_config_input(value: CompositorConfigValue) -> CompositorConfig:
    return _validate_config_model_input(CompositorConfig, value)


@dataclass(frozen=True, slots=True)
class LayerDescriptor:
    """Registry descriptor inferred from a layer class."""

    type_id: str
    layer_type: type[Layer[Any, Any, Any, Any, Any, Any, Any]]
    config_type: type[BaseModel]
    runtime_state_type: type[BaseModel]
    runtime_handles_type: type[BaseModel]


class LayerRegistry:
    """Manual registry for config-constructible layer classes.

    Registration infers config and runtime schemas from layer class attributes.
    A registered layer must have a type id, either declared as ``type_id`` on the
    class or supplied to ``register_layer``.
    """

    __slots__ = ("_descriptors",)

    _descriptors: dict[str, LayerDescriptor]

    def __init__(self) -> None:
        self._descriptors = {}

    def register_layer(
        self,
        layer_type: type[Layer[Any, Any, Any, Any, Any, Any, Any]],
        *,
        type_id: str | None = None,
    ) -> None:
        """Register ``layer_type`` under its inferred or explicit type id."""
        resolved_type_id = type_id or layer_type.type_id
        if resolved_type_id is not None and not isinstance(resolved_type_id, str):
            raise TypeError(f"Layer type id for '{layer_type.__qualname__}' must be a string.")
        if resolved_type_id is None or not resolved_type_id:
            raise ValueError(f"Layer '{layer_type.__qualname__}' must declare a type_id or be registered with one.")
        if resolved_type_id in self._descriptors:
            raise ValueError(f"Layer type id '{resolved_type_id}' is already registered.")
        self._descriptors[resolved_type_id] = LayerDescriptor(
            type_id=resolved_type_id,
            layer_type=layer_type,
            config_type=layer_type.config_type,
            runtime_state_type=layer_type.runtime_state_type,
            runtime_handles_type=layer_type.runtime_handles_type,
        )

    def resolve(self, type_id: str) -> LayerDescriptor:
        """Return the descriptor for ``type_id`` or raise ``KeyError``."""
        try:
            return self._descriptors[type_id]
        except KeyError as e:
            raise KeyError(f"Layer type id '{type_id}' is not registered.") from e

    def descriptors(self) -> Mapping[str, LayerDescriptor]:
        """Return registered descriptors keyed by type id."""
        return dict(self._descriptors)


class CompositorSession:
    """External lifecycle session for layer contexts entered by a compositor.

    A session owns one ``LayerControl`` per compositor layer name, preserving
    compositor order. Controls must be created from the matching layer schemas;
    prefer ``Compositor.new_session`` or ``Compositor.session_from_snapshot`` for
    public session construction. Broadcast methods are convenience APIs for
    setting every layer's per-entry exit intent; ``layer`` allows explicit
    per-layer control when callers need partial suspend/delete behavior. A mixed
    session with any closed layer cannot be entered again because compositor
    entry is all-or-none.
    """

    __slots__ = ("layer_controls",)

    layer_controls: OrderedDict[str, LayerControl]

    def __init__(self, layer_names: Iterable[str] | Mapping[str, LayerControl]) -> None:
        if isinstance(layer_names, MappingABC):
            self.layer_controls = OrderedDict(layer_names.items())
            return
        self.layer_controls = OrderedDict((layer_name, LayerControl()) for layer_name in layer_names)

    def suspend_on_exit(self) -> None:
        """Request suspend behavior for every layer when this entry exits."""
        for control in self.layer_controls.values():
            control.suspend_on_exit()

    def delete_on_exit(self) -> None:
        """Request delete behavior for every layer when this entry exits."""
        for control in self.layer_controls.values():
            control.delete_on_exit()

    def layer(self, name: str) -> LayerControl:
        """Return the layer control for ``name`` or raise ``KeyError``."""
        return self.layer_controls[name]


class LayerSessionSnapshot(BaseModel):
    """Serializable snapshot for one layer control."""

    name: str
    state: LifecycleState
    runtime_state: dict[str, JsonValue]

    model_config = ConfigDict(extra="forbid")


class CompositorSessionSnapshot(BaseModel):
    """Serializable compositor session snapshot.

    Snapshots include runtime state only. Live runtime handles are intentionally
    excluded and must be rehydrated by resume hooks using runtime state.
    """

    schema_version: int = 1
    layers: list[LayerSessionSnapshot]

    model_config = ConfigDict(extra="forbid")


@dataclass(frozen=True, slots=True)
class _LayerBuildEntry:
    name: str
    layer: Layer[Any, Any, Any, Any, Any, Any, Any]
    deps: Mapping[str, str]
|
||||
|
||||
class CompositorBuilder:
|
||||
"""Build compositors from registry config nodes and live instances."""
|
||||
|
||||
__slots__ = ("_registry", "_entries")
|
||||
|
||||
_registry: LayerRegistry
|
||||
_entries: list[_LayerBuildEntry]
|
||||
|
||||
def __init__(self, registry: LayerRegistry) -> None:
|
||||
self._registry = registry
|
||||
self._entries = []
|
||||
|
||||
def add_config(self, config: CompositorConfigValue) -> Self:
|
||||
"""Add all layers from a serializable compositor config."""
|
||||
conf = _validate_compositor_config_input(config)
|
||||
if conf.schema_version != 1:
|
||||
raise ValueError(f"Unsupported compositor config schema_version: {conf.schema_version}.")
|
||||
for layer_conf in conf.layers:
|
||||
self.add_config_layer(
|
||||
name=layer_conf.name,
|
||||
type=layer_conf.type,
|
||||
config=layer_conf.config,
|
||||
deps=layer_conf.deps,
|
||||
)
|
||||
return self
|
||||
|
||||
def add_config_layer(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
type: str,
|
||||
config: object | None = None,
|
||||
deps: Mapping[str, str] | None = None,
|
||||
) -> Self:
|
||||
"""Resolve, validate, and add one registry-backed layer config node."""
|
||||
descriptor = self._registry.resolve(type)
|
||||
raw_config = {} if config is None else config
|
||||
validated_config = descriptor.config_type.model_validate(raw_config)
|
||||
layer = descriptor.layer_type.from_config(cast(Any, validated_config))
|
||||
self.add_instance(name=name, layer=layer, deps=deps)
|
||||
return self
|
||||
|
||||
def add_instance(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
layer: Layer[Any, Any, Any, Any, Any, Any, Any],
|
||||
deps: Mapping[str, str] | None = None,
|
||||
) -> Self:
|
||||
"""Add a live layer instance, useful for Python objects and callables."""
|
||||
self._entries.append(_LayerBuildEntry(name=name, layer=layer, deps=dict(deps or {})))
|
||||
return self
|
||||
|
||||
def build[PromptT, ToolT, LayerPromptT, LayerToolT, UserPromptT, LayerUserPromptT](
|
||||
self,
|
||||
*,
|
||||
prompt_transformer: CompositorTransformer[LayerPromptT, PromptT] | None = None,
|
||||
user_prompt_transformer: CompositorTransformer[LayerUserPromptT, UserPromptT] | None = None,
|
||||
tool_transformer: CompositorTransformer[LayerToolT, ToolT] | None = None,
|
||||
) -> "Compositor[PromptT, ToolT, LayerPromptT, LayerToolT, UserPromptT, LayerUserPromptT]":
|
||||
"""Validate names/dependencies, bind deps, and return a compositor."""
|
||||
layers: OrderedDict[str, Layer[Any, Any, Any, Any, Any, Any, Any]] = OrderedDict()
|
||||
deps_name_mapping: dict[str, Mapping[str, str]] = {}
|
||||
for entry in self._entries:
|
||||
if entry.name in layers:
|
||||
raise ValueError(f"Duplicate layer name '{entry.name}'.")
|
||||
layers[entry.name] = entry.layer
|
||||
deps_name_mapping[entry.name] = entry.deps
|
||||
|
||||
layer_names = set(layers)
|
||||
for layer_name, deps in deps_name_mapping.items():
|
||||
declared_deps = layers[layer_name].dependency_names()
|
||||
unknown_dep_keys = set(deps) - declared_deps
|
||||
if unknown_dep_keys:
|
||||
names = ", ".join(sorted(unknown_dep_keys))
|
||||
raise ValueError(f"Layer '{layer_name}' declares unknown dependency keys: {names}.")
|
||||
missing_targets = set(deps.values()) - layer_names
|
||||
if missing_targets:
|
||||
names = ", ".join(sorted(missing_targets))
|
||||
raise ValueError(f"Layer '{layer_name}' depends on undefined layer names: {names}.")
|
||||
|
||||
return Compositor(
|
||||
layers=layers,
|
||||
deps_name_mapping=deps_name_mapping,
|
||||
prompt_transformer=prompt_transformer,
|
||||
user_prompt_transformer=user_prompt_transformer,
|
||||
tool_transformer=tool_transformer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT, UserPromptT, LayerUserPromptT]):
|
||||
"""Framework-neutral ordered layer graph with lifecycle and aggregation.
|
||||
|
||||
``prompt_transformer``, ``user_prompt_transformer``, and
|
||||
``tool_transformer`` are post-aggregation hooks: they run whenever
|
||||
``prompts``, ``user_prompts``, or ``tools`` is read, after layer
|
||||
contributions have been collected in compositor order. Use two type
|
||||
arguments for identity aggregation, four when prompt/tool layer item types
|
||||
differ from exposed item types, or all six when user prompt item types also
|
||||
differ.
|
||||
"""
|
||||
|
||||
layers: OrderedDict[str, Layer[Any, Any, Any, Any, Any, Any, Any]]
|
||||
deps_name_mapping: Mapping[str, Mapping[str, str]] = field(default_factory=dict)
|
||||
prompt_transformer: CompositorTransformer[LayerPromptT, PromptT] | None = None
|
||||
user_prompt_transformer: CompositorTransformer[LayerUserPromptT, UserPromptT] | None = None
|
||||
tool_transformer: CompositorTransformer[LayerToolT, ToolT] | None = None
|
||||
_deps_bound: bool = field(default=False, init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._bind_deps(self.deps_name_mapping)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
conf: CompositorConfigValue,
|
||||
*,
|
||||
registry: LayerRegistry,
|
||||
prompt_transformer: CompositorTransformer[LayerPromptT, PromptT] | None = None,
|
||||
user_prompt_transformer: CompositorTransformer[LayerUserPromptT, UserPromptT] | None = None,
|
||||
tool_transformer: CompositorTransformer[LayerToolT, ToolT] | None = None,
|
||||
) -> "Compositor[PromptT, ToolT, LayerPromptT, LayerToolT, UserPromptT, LayerUserPromptT]":
|
||||
"""Create a compositor from registry-backed serializable config."""
|
||||
return CompositorBuilder(registry).add_config(conf).build(
|
||||
prompt_transformer=prompt_transformer,
|
||||
user_prompt_transformer=user_prompt_transformer,
|
||||
tool_transformer=tool_transformer,
|
||||
)
|
||||
|
||||
def _bind_deps(self, deps_name_mapping: Mapping[str, Mapping[str, str]]) -> None:
|
||||
"""Resolve dependency-name mappings and bind dependencies on each layer.
|
||||
|
||||
The outer mapping key is the layer being bound. The inner mapping key is
|
||||
the dependency field declared by that layer's deps type, and the value is
|
||||
the target layer name in this compositor.
|
||||
"""
|
||||
if self._deps_bound:
|
||||
raise RuntimeError("Compositor deps are already bound.")
|
||||
|
||||
for layer_name, layer in self.layers.items():
|
||||
layer_deps = deps_name_mapping.get(layer_name, {})
|
||||
try:
|
||||
deps = {
|
||||
dep_name: self.layers[target_layer_name]
|
||||
for dep_name, target_layer_name in layer_deps.items()
|
||||
}
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
f"Layer '{layer_name}' has a dependency on layer '{e.args[0]}', "
|
||||
"which is not defined in the builder."
|
||||
) from e
|
||||
layer.bind_deps({**self.layers, **deps})
|
||||
self._deps_bound = True
|
||||
|
||||
def new_session(self) -> CompositorSession:
|
||||
"""Create a fresh lifecycle session matching this compositor's layer order."""
|
||||
return CompositorSession(
|
||||
OrderedDict((layer_name, layer.new_control()) for layer_name, layer in self.layers.items())
|
||||
)
|
||||
|
||||
def snapshot_session(self, session: CompositorSession) -> CompositorSessionSnapshot:
|
||||
"""Serialize non-active session lifecycle state and runtime state.
|
||||
|
||||
Runtime handles are live Python objects and are intentionally excluded.
|
||||
"""
|
||||
self._validate_session(session)
|
||||
active_layers = [name for name, control in session.layer_controls.items() if control.state is LifecycleState.ACTIVE]
|
||||
if active_layers:
|
||||
names = ", ".join(active_layers)
|
||||
raise RuntimeError(f"Cannot snapshot active compositor session layers: {names}.")
|
||||
return CompositorSessionSnapshot(
|
||||
layers=[
|
||||
LayerSessionSnapshot(
|
||||
name=name,
|
||||
state=control.state,
|
||||
runtime_state=cast(dict[str, JsonValue], control.runtime_state.model_dump(mode="json")),
|
||||
)
|
||||
for name, control in session.layer_controls.items()
|
||||
]
|
||||
)
|
||||
|
||||
def session_from_snapshot(self, snapshot: CompositorSessionSnapshot | JsonValue | str | bytes) -> CompositorSession:
|
||||
"""Restore a session from a snapshot and reinitialize empty handles."""
|
||||
snapshot = _validate_config_model_input(CompositorSessionSnapshot, snapshot)
|
||||
if snapshot.schema_version != 1:
|
||||
raise ValueError(f"Unsupported compositor session snapshot schema_version: {snapshot.schema_version}.")
|
||||
snapshot_layer_names = tuple(layer.name for layer in snapshot.layers)
|
||||
expected_layer_names = tuple(self.layers)
|
||||
if snapshot_layer_names != expected_layer_names:
|
||||
expected = ", ".join(expected_layer_names)
|
||||
actual = ", ".join(snapshot_layer_names)
|
||||
raise ValueError(
|
||||
"CompositorSessionSnapshot layer names must match compositor layers in order. "
|
||||
f"Expected [{expected}], got [{actual}]."
|
||||
)
|
||||
active_layers = [layer.name for layer in snapshot.layers if layer.state is LifecycleState.ACTIVE]
|
||||
if active_layers:
|
||||
names = ", ".join(active_layers)
|
||||
raise ValueError(f"Cannot restore active compositor session layers from snapshot: {names}.")
|
||||
controls = OrderedDict(
|
||||
(
|
||||
layer_snapshot.name,
|
||||
self.layers[layer_snapshot.name].new_control(
|
||||
state=layer_snapshot.state,
|
||||
runtime_state=layer_snapshot.runtime_state,
|
||||
),
|
||||
)
|
||||
for layer_snapshot in snapshot.layers
|
||||
)
|
||||
return CompositorSession(controls)
|
||||
|
||||
@asynccontextmanager
|
||||
async def enter(
|
||||
self,
|
||||
session: CompositorSession | None = None,
|
||||
) -> AsyncIterator[CompositorSession]:
|
||||
"""Enter each layer context in order and yield the active session."""
|
||||
if not self._deps_bound:
|
||||
raise RuntimeError("Compositor deps must be bound before entering context.")
|
||||
|
||||
if session is None:
|
||||
session = self.new_session()
|
||||
self._validate_session(session)
|
||||
self._ensure_session_can_enter(session)
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
for layer_name, layer in self.layers.items():
|
||||
await stack.enter_async_context(layer.enter(session.layer_controls[layer_name]))
|
||||
yield session
|
||||
|
||||
def _validate_session(self, session: CompositorSession) -> None:
|
||||
expected_layer_names = tuple(self.layers)
|
||||
actual_layer_names = tuple(session.layer_controls)
|
||||
if actual_layer_names != expected_layer_names:
|
||||
expected = ", ".join(expected_layer_names)
|
||||
actual = ", ".join(actual_layer_names)
|
||||
raise ValueError(
|
||||
"CompositorSession layer names must match compositor layers in order. "
|
||||
f"Expected [{expected}], got [{actual}]."
|
||||
)
|
||||
for layer_name, layer in self.layers.items():
|
||||
control = session.layer_controls[layer_name]
|
||||
if not isinstance(control.runtime_state, layer.runtime_state_type):
|
||||
raise TypeError(
|
||||
f"CompositorSession layer '{layer_name}' runtime_state must be "
|
||||
f"{layer.runtime_state_type.__name__}, got {type(control.runtime_state).__name__}."
|
||||
)
|
||||
if not isinstance(control.runtime_handles, layer.runtime_handles_type):
|
||||
raise TypeError(
|
||||
f"CompositorSession layer '{layer_name}' runtime_handles must be "
|
||||
f"{layer.runtime_handles_type.__name__}, got {type(control.runtime_handles).__name__}."
|
||||
)
|
||||
|
||||
def _ensure_session_can_enter(self, session: CompositorSession) -> None:
|
||||
"""Reject active or closed layer controls before any layer side effects."""
|
||||
for control in session.layer_controls.values():
|
||||
if control.state is LifecycleState.ACTIVE:
|
||||
raise RuntimeError(
|
||||
"LayerControl is already active; duplicate or nested enter is not allowed."
|
||||
)
|
||||
if control.state is LifecycleState.CLOSED:
|
||||
raise RuntimeError(
|
||||
"LayerControl is closed; create a new compositor session before entering again."
|
||||
)
|
||||
|
||||
@property
|
||||
def prompts(self) -> list[PromptT]:
|
||||
result: list[LayerPromptT] = []
|
||||
for layer in self.layers.values():
|
||||
result.extend(
|
||||
cast(LayerPromptT, layer.wrap_prompt(prompt))
|
||||
for prompt in layer.prefix_prompts
|
||||
)
|
||||
for layer in reversed(self.layers.values()):
|
||||
result.extend(
|
||||
cast(LayerPromptT, layer.wrap_prompt(prompt))
|
||||
for prompt in layer.suffix_prompts
|
||||
)
|
||||
if self.prompt_transformer is None:
|
||||
return cast(list[PromptT], result)
|
||||
return list(self.prompt_transformer(result))
|
||||
|
||||
@property
|
||||
def user_prompts(self) -> list[UserPromptT]:
|
||||
result: list[LayerUserPromptT] = []
|
||||
for layer in self.layers.values():
|
||||
result.extend(
|
||||
cast(LayerUserPromptT, layer.wrap_user_prompt(prompt))
|
||||
for prompt in layer.user_prompts
|
||||
)
|
||||
if self.user_prompt_transformer is None:
|
||||
return cast(list[UserPromptT], result)
|
||||
return list(self.user_prompt_transformer(result))
|
||||
|
||||
@property
|
||||
def tools(self) -> list[ToolT]:
|
||||
result: list[LayerToolT] = []
|
||||
for layer in self.layers.values():
|
||||
result.extend(cast(LayerToolT, layer.wrap_tool(tool)) for tool in layer.tools)
|
||||
if self.tool_transformer is None:
|
||||
return cast(list[ToolT], result)
|
||||
return list(self.tool_transformer(result))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Compositor",
|
||||
"CompositorBuilder",
|
||||
"CompositorConfig",
|
||||
"CompositorConfigValue",
|
||||
"CompositorSessionSnapshot",
|
||||
"CompositorSession",
|
||||
"CompositorTransformer",
|
||||
"CompositorTransformerKwargs",
|
||||
"LayerDescriptor",
|
||||
"LayerNodeConfig",
|
||||
"LayerRegistry",
|
||||
"LayerSessionSnapshot",
|
||||
]
|
||||
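
# A minimal end-to-end sketch of the registry/builder/session flow above,
# assuming this module is importable and that ``PromptLayer`` (type_id
# "plain.prompt", defined later in this diff) is the registered layer; the
# surrounding agent wiring is omitted.
#
#   registry = LayerRegistry()
#   registry.register_layer(PromptLayer)
#   compositor = Compositor.from_config(
#       {
#           "schema_version": 1,
#           "layers": [{"name": "persona", "type": "plain.prompt", "config": {"prefix": "Be terse."}}],
#       },
#       registry=registry,
#   )
#
#   async def run() -> None:
#       async with compositor.enter() as session:
#           session.suspend_on_exit()      # resume-able instead of closed
#           _ = compositor.prompts         # tagged prompt items, in layer order
#       snapshot = compositor.snapshot_session(session)
#       restored = compositor.session_from_snapshot(snapshot)
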
@@ -1,66 +0,0 @@
"""Layer base classes and typed layer families.
|
||||
|
||||
``agenton.layers.base`` owns the framework-neutral ``Layer`` abstraction.
|
||||
``agenton.layers.types`` binds the prompt/tool generic slots to specific layer
|
||||
families while keeping concrete reusable layers in ``agenton_collections``.
|
||||
"""
|
||||
|
||||
from agenton.layers.base import (
|
||||
EmptyLayerConfig,
|
||||
EmptyRuntimeHandles,
|
||||
EmptyRuntimeState,
|
||||
ExitIntent,
|
||||
Layer,
|
||||
LayerControl,
|
||||
LayerDeps,
|
||||
LifecycleState,
|
||||
NoLayerDeps,
|
||||
)
|
||||
from agenton.layers.types import (
|
||||
AllPromptTypes,
|
||||
AllToolTypes,
|
||||
AllUserPromptTypes,
|
||||
PlainLayer,
|
||||
PlainPrompt,
|
||||
PlainPromptType,
|
||||
PlainTool,
|
||||
PlainToolType,
|
||||
PlainUserPrompt,
|
||||
PlainUserPromptType,
|
||||
PydanticAILayer,
|
||||
PydanticAIPrompt,
|
||||
PydanticAIPromptType,
|
||||
PydanticAITool,
|
||||
PydanticAIToolType,
|
||||
PydanticAIUserPrompt,
|
||||
PydanticAIUserPromptType,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AllPromptTypes",
|
||||
"AllToolTypes",
|
||||
"AllUserPromptTypes",
|
||||
"Layer",
|
||||
"LayerDeps",
|
||||
"LayerControl",
|
||||
"LifecycleState",
|
||||
"ExitIntent",
|
||||
"EmptyLayerConfig",
|
||||
"EmptyRuntimeState",
|
||||
"EmptyRuntimeHandles",
|
||||
"NoLayerDeps",
|
||||
"PlainLayer",
|
||||
"PlainPrompt",
|
||||
"PlainPromptType",
|
||||
"PlainUserPrompt",
|
||||
"PlainUserPromptType",
|
||||
"PlainTool",
|
||||
"PlainToolType",
|
||||
"PydanticAILayer",
|
||||
"PydanticAIPrompt",
|
||||
"PydanticAIPromptType",
|
||||
"PydanticAIUserPrompt",
|
||||
"PydanticAIUserPromptType",
|
||||
"PydanticAITool",
|
||||
"PydanticAIToolType",
|
||||
]
|
||||
@@ -1,548 +0,0 @@
"""Core layer abstractions and typed dependency binding.
|
||||
|
||||
Layers declare their dependency shape with ``Layer[DepsT, PromptT, ToolT, ...]``.
|
||||
``DepsT`` must be a ``LayerDeps`` subclass whose annotated members are concrete
|
||||
``Layer`` subclasses or modern optional dependencies such as ``SomeLayer |
|
||||
None``. The optional trailing generic slots declare Pydantic schemas for config,
|
||||
serializable runtime state, and live runtime handles. The base class infers
|
||||
``deps_type`` and schema class attributes from the generic base when possible,
|
||||
while still allowing subclasses to set them explicitly for unusual inheritance
|
||||
patterns.
|
||||
|
||||
``Layer.bind_deps`` is the mutation point for dependency state. Layer
|
||||
implementations should treat ``self.deps`` as unavailable until a compositor or
|
||||
caller has resolved and bound dependencies.
|
||||
|
||||
Layer async entry uses a caller-provided ``LayerControl`` as an explicit state
|
||||
machine and per-session runtime owner. A fresh control starts in
|
||||
``LifecycleState.NEW`` and enters create logic. A suspended control resumes,
|
||||
while active or closed controls are rejected to prevent ambiguous nested or
|
||||
post-delete reuse. Exit behavior is selected per entry with ``ExitIntent`` and
|
||||
resets to delete on every successful enter. Layer instances are shared graph and
|
||||
capability definitions, so session-local serializable ids, checkpoints, and
|
||||
other snapshot data belong in ``LayerControl.runtime_state``; live clients,
|
||||
connections, and process handles belong in ``LayerControl.runtime_handles``.
|
||||
Neither category should be stored on ``self`` when it is session-local.
|
||||
|
||||
``Layer`` is framework-neutral over system prompt, user prompt, and tool item
|
||||
types. The native ``prefix_prompts``, ``suffix_prompts``, ``user_prompts``, and
|
||||
``tools`` properties are the layer authoring surface. ``wrap_prompt``,
|
||||
``wrap_user_prompt``, and ``wrap_tool`` are the compositor aggregation surface;
|
||||
typed families such as ``agenton.layers.types.PlainLayer`` implement them to tag
|
||||
native values without changing layer implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from types import UnionType
|
||||
from typing import Any, ClassVar, Generic, Mapping, Sequence, Union, cast, get_args, get_origin, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import Self, TypeVar
|
||||
|
||||
|
||||
_DepsT = TypeVar("_DepsT", bound="LayerDeps")
|
||||
_PromptT = TypeVar("_PromptT")
|
||||
_UserPromptT = TypeVar("_UserPromptT")
|
||||
_ToolT = TypeVar("_ToolT")
|
||||
_ConfigT = TypeVar("_ConfigT", bound=BaseModel, default="EmptyLayerConfig")
|
||||
_RuntimeStateT = TypeVar("_RuntimeStateT", bound=BaseModel, default="EmptyRuntimeState")
|
||||
_RuntimeHandlesT = TypeVar("_RuntimeHandlesT", bound=BaseModel, default="EmptyRuntimeHandles")
|
||||
|
||||
|
||||
class LayerDeps:
|
||||
"""Typed dependency container for a Layer.
|
||||
|
||||
Subclasses declare dependency members with annotations. Every annotated
|
||||
member must be a Layer subclass or ``LayerSubclass | None``. Optional deps
|
||||
are always assigned as attributes; missing optional values become ``None``.
|
||||
"""
|
||||
|
||||
def __init__(self, **deps: "Layer[Any, Any, Any, Any, Any, Any, Any] | None") -> None:
|
||||
dep_specs = _get_dep_specs(type(self))
|
||||
missing_names = {name for name, spec in dep_specs.items() if not spec.optional} - deps.keys()
|
||||
if missing_names:
|
||||
names = ", ".join(sorted(missing_names))
|
||||
raise ValueError(f"Missing layer dependencies: {names}.")
|
||||
|
||||
unknown_names = deps.keys() - dep_specs.keys()
|
||||
if unknown_names:
|
||||
names = ", ".join(sorted(unknown_names))
|
||||
raise ValueError(f"Unknown layer dependencies: {names}.")
|
||||
|
||||
for name, spec in dep_specs.items():
|
||||
value = deps.get(name)
|
||||
if value is None:
|
||||
if spec.optional:
|
||||
setattr(self, name, None)
|
||||
continue
|
||||
raise ValueError(f"Dependency '{name}' is required but not provided.")
|
||||
|
||||
if not isinstance(value, spec.layer_type):
|
||||
raise TypeError(
|
||||
f"Dependency '{name}' should be of type '{spec.layer_type.__name__}', "
|
||||
f"but got type '{type(value).__name__}'."
|
||||
)
|
||||
setattr(self, name, value)
|
||||
|
||||
|
||||
class NoLayerDeps(LayerDeps):
|
||||
"""Dependency container for layers that do not require other layers."""
|
||||
|
||||
|
||||
class EmptyLayerConfig(BaseModel):
|
||||
"""Default serializable config schema for layers without config."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class EmptyRuntimeState(BaseModel):
|
||||
"""Default serializable per-session runtime state schema."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
||||
|
||||
|
||||
class EmptyRuntimeHandles(BaseModel):
|
||||
"""Default live per-session runtime handle schema.
|
||||
|
||||
Handles may contain arbitrary Python objects and are intentionally excluded
|
||||
from session snapshots.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class LifecycleState(StrEnum):
|
||||
"""Externally observable lifecycle state for a layer control."""
|
||||
|
||||
NEW = "new"
|
||||
ACTIVE = "active"
|
||||
SUSPENDED = "suspended"
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class ExitIntent(StrEnum):
|
||||
"""Per-entry exit behavior requested for a layer control."""
|
||||
|
||||
DELETE = "delete"
|
||||
SUSPEND = "suspend"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LayerControl(Generic[_RuntimeStateT, _RuntimeHandlesT]):
|
||||
"""Stateful control slot passed into a layer entry context.
|
||||
|
||||
``Layer.enter`` requires the caller to provide this object. The control owns
|
||||
the layer lifecycle state, the current entry's exit intent, and arbitrary
|
||||
per-session runtime state and live handles. Call ``suspend_on_exit`` before leaving the
|
||||
context to make a later entry resume; call ``delete_on_exit`` or do nothing
|
||||
for the default delete behavior. Store session-local serializable ids,
|
||||
checkpoints, and other snapshot data in ``runtime_state``. Store live
|
||||
clients, connections, process handles, and other non-serializable objects in
|
||||
``runtime_handles``. Do not put either kind of session-local data on the
|
||||
shared layer instance.
|
||||
|
||||
``runtime_state`` intentionally persists after suspend and delete. Suspend,
|
||||
resume, and delete hooks can inspect the same values created on entry, and
|
||||
callers may inspect closed-session diagnostics after exit. Reuse is still
|
||||
governed by ``state``: a closed control cannot be entered again. Runtime
|
||||
handles are not serialized in snapshots and should be rehydrated from
|
||||
runtime state in resume hooks.
|
||||
"""
|
||||
|
||||
state: LifecycleState = LifecycleState.NEW
|
||||
exit_intent: ExitIntent = ExitIntent.DELETE
|
||||
runtime_state: _RuntimeStateT = field(default_factory=lambda: cast(_RuntimeStateT, EmptyRuntimeState()))
|
||||
runtime_handles: _RuntimeHandlesT = field(default_factory=lambda: cast(_RuntimeHandlesT, EmptyRuntimeHandles()))
|
||||
|
||||
def suspend_on_exit(self) -> None:
|
||||
"""Request suspend behavior when the current layer entry exits."""
|
||||
self.exit_intent = ExitIntent.SUSPEND
|
||||
|
||||
def delete_on_exit(self) -> None:
|
||||
"""Request delete behavior when the current layer entry exits."""
|
||||
self.exit_intent = ExitIntent.DELETE
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class LayerDepSpec:
|
||||
"""Runtime dependency specification derived from a deps annotation."""
|
||||
|
||||
layer_type: type["Layer[Any, Any, Any, Any, Any, Any, Any]"]
|
||||
optional: bool = False
|
||||
|
||||
|
||||
class Layer(
|
||||
ABC,
|
||||
Generic[_DepsT, _PromptT, _UserPromptT, _ToolT, _ConfigT, _RuntimeStateT, _RuntimeHandlesT],
|
||||
):
|
||||
"""Framework-neutral base class for prompt/tool layers.
|
||||
|
||||
Subclasses expose optional prompt fragments and tools through typed
|
||||
properties. They declare required dependencies in the ``DepsT`` container
|
||||
rather than by accepting dependencies in ``__init__``. Layer instances can be
|
||||
entered by multiple sessions, including concurrently, so lifecycle hooks
|
||||
should store session-local runtime values on the passed ``LayerControl``.
|
||||
The default async context manager handles create, resume, suspend, and
|
||||
delete transitions; layers can override ``enter`` when they need to wrap
|
||||
extra runtime resources.
|
||||
"""
|
||||
|
||||
deps_type: type[_DepsT]
|
||||
deps: _DepsT
|
||||
type_id: ClassVar[str | None] = None
|
||||
config_type: ClassVar[type[BaseModel]] = EmptyLayerConfig
|
||||
runtime_state_type: ClassVar[type[BaseModel]] = EmptyRuntimeState
|
||||
runtime_handles_type: ClassVar[type[BaseModel]] = EmptyRuntimeHandles
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
super().__init_subclass__()
|
||||
is_generic_template = _is_generic_layer_template(cls)
|
||||
deps_type = cls.__dict__.get("deps_type")
|
||||
if deps_type is None:
|
||||
deps_type = _infer_deps_type(cls) or getattr(cls, "deps_type", None)
|
||||
if deps_type is None and is_generic_template:
|
||||
return
|
||||
if deps_type is not None:
|
||||
cls.deps_type = deps_type # pyright: ignore[reportAttributeAccessIssue]
|
||||
if deps_type is None:
|
||||
raise TypeError(f"{cls.__name__} must define deps_type or inherit from Layer[DepsT].")
|
||||
if not isinstance(deps_type, type) or not issubclass(deps_type, LayerDeps):
|
||||
raise TypeError(f"{cls.__name__}.deps_type must be a LayerDeps subclass.")
|
||||
_get_dep_specs(deps_type)
|
||||
_init_schema_type(cls, "config_type", _infer_schema_type(cls, 4, "config_type"), EmptyLayerConfig)
|
||||
_init_schema_type(
|
||||
cls,
|
||||
"runtime_state_type",
|
||||
_infer_schema_type(cls, 5, "runtime_state_type"),
|
||||
EmptyRuntimeState,
|
||||
)
|
||||
_init_schema_type(
|
||||
cls,
|
||||
"runtime_handles_type",
|
||||
_infer_schema_type(cls, 6, "runtime_handles_type"),
|
||||
EmptyRuntimeHandles,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls: type[Self], config: _ConfigT) -> Self:
|
||||
"""Create a layer from schema-validated serialized config.
|
||||
|
||||
Registries/builders validate raw config with ``config_type`` before
|
||||
calling this method. Layers are not config-constructible by default.
|
||||
Subclasses that accept config should override this method and consume
|
||||
the typed Pydantic model for their schema.
|
||||
"""
|
||||
raise TypeError(f"{cls.__name__} cannot be created from config.")
|
||||
|
||||
@classmethod
|
||||
def dependency_names(cls) -> frozenset[str]:
|
||||
"""Return dependency field names declared by this layer's deps schema."""
|
||||
return frozenset(_get_dep_specs(cls.deps_type))
|
||||
|
||||
def new_control(
|
||||
self,
|
||||
*,
|
||||
state: LifecycleState = LifecycleState.NEW,
|
||||
runtime_state: object | None = None,
|
||||
) -> LayerControl[_RuntimeStateT, _RuntimeHandlesT]:
|
||||
"""Create a schema-validated per-session control for this layer.
|
||||
|
||||
``runtime_state`` is validated through ``runtime_state_type`` and live
|
||||
handles are always initialized empty through ``runtime_handles_type``.
|
||||
"""
|
||||
raw_runtime_state = {} if runtime_state is None else runtime_state
|
||||
return LayerControl(
|
||||
state=state,
|
||||
exit_intent=ExitIntent.DELETE,
|
||||
runtime_state=cast(_RuntimeStateT, self.runtime_state_type.model_validate(raw_runtime_state)),
|
||||
runtime_handles=cast(_RuntimeHandlesT, self.runtime_handles_type.model_validate({})),
|
||||
)
|
||||
|
||||
def bind_deps(self, deps: Mapping[str, "Layer[Any, Any, Any, Any, Any, Any, Any] | None"]) -> None:
|
||||
"""Bind this layer's declared dependencies from a name-to-layer mapping.
|
||||
|
||||
The mapping may include more layers than the declared dependency fields.
|
||||
Only names declared by ``deps_type`` are selected and validated. Missing
|
||||
optional deps are bound as ``None``.
|
||||
"""
|
||||
resolved_deps: dict[str, Layer[Any, Any, Any, Any, Any, Any, Any] | None] = {}
|
||||
for name, spec in _get_dep_specs(self.deps_type).items():
|
||||
if name not in deps:
|
||||
if spec.optional:
|
||||
resolved_deps[name] = None
|
||||
continue
|
||||
raise ValueError(
|
||||
f"Dependency '{name}' is required for layer '{type(self).__name__}' but not provided."
|
||||
)
|
||||
resolved_deps[name] = deps[name]
|
||||
self.deps = self.deps_type(**resolved_deps)
|
||||
|
||||
def enter(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> AbstractAsyncContextManager[None]:
|
||||
"""Return the layer's async entry context manager.
|
||||
|
||||
``control`` is the lifecycle control slot for this entry. Subclasses can
|
||||
override this to wrap extra async resources around
|
||||
``self.lifecycle_enter(control)``.
|
||||
"""
|
||||
return self.lifecycle_enter(control)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifecycle_enter(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> AsyncIterator[None]:
|
||||
"""Run the default explicit lifecycle state machine for one entry."""
|
||||
if control.state is LifecycleState.NEW:
|
||||
control.exit_intent = ExitIntent.DELETE
|
||||
await self.on_context_create(control)
|
||||
control.state = LifecycleState.ACTIVE
|
||||
elif control.state is LifecycleState.SUSPENDED:
|
||||
control.exit_intent = ExitIntent.DELETE
|
||||
await self.on_context_resume(control)
|
||||
control.state = LifecycleState.ACTIVE
|
||||
elif control.state is LifecycleState.ACTIVE:
|
||||
raise RuntimeError(
|
||||
"LayerControl is already active; duplicate or nested enter is not allowed."
|
||||
)
|
||||
elif control.state is LifecycleState.CLOSED:
|
||||
raise RuntimeError(
|
||||
"LayerControl is closed; create a new compositor session before entering again."
|
||||
)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if control.exit_intent is ExitIntent.SUSPEND:
|
||||
await self.on_context_suspend(control)
|
||||
control.state = LifecycleState.SUSPENDED
|
||||
else:
|
||||
await self.on_context_delete(control)
|
||||
control.state = LifecycleState.CLOSED
|
||||
|
||||
async def on_context_create(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None:
|
||||
"""Run when the layer context is entered from ``LifecycleState.NEW``."""
|
||||
|
||||
async def on_context_delete(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None:
|
||||
"""Run when the layer context exits with ``ExitIntent.DELETE``."""
|
||||
|
||||
async def on_context_suspend(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None:
|
||||
"""Run when the layer context exits with ``ExitIntent.SUSPEND``."""
|
||||
|
||||
async def on_context_resume(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None:
|
||||
"""Run when the layer context enters from ``LifecycleState.SUSPENDED``."""
|
||||
|
||||
@property
|
||||
def prefix_prompts(self) -> Sequence[_PromptT]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def suffix_prompts(self) -> Sequence[_PromptT]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def user_prompts(self) -> Sequence[_UserPromptT]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def tools(self) -> Sequence[_ToolT]:
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def wrap_prompt(self, prompt: _PromptT) -> object:
|
||||
"""Wrap a native prompt item for compositor aggregation."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def wrap_user_prompt(self, prompt: _UserPromptT) -> object:
|
||||
"""Wrap a native user prompt item for compositor aggregation."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def wrap_tool(self, tool: _ToolT) -> object:
|
||||
"""Wrap a native tool item for compositor aggregation."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _get_dep_specs(deps_type: type[LayerDeps]) -> dict[str, LayerDepSpec]:
|
||||
dep_specs: dict[str, LayerDepSpec] = {}
|
||||
for name, annotation in get_type_hints(deps_type).items():
|
||||
spec = _as_dep_spec(annotation)
|
||||
if spec is None:
|
||||
raise TypeError(
|
||||
f"{deps_type.__name__}.{name} must be annotated with a Layer subclass "
|
||||
"or Layer subclass | None."
|
||||
)
|
||||
dep_specs[name] = spec
|
||||
return dep_specs
|
||||
|
||||
|
||||
def _as_dep_spec(annotation: object) -> LayerDepSpec | None:
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
if origin in (UnionType, Union) and len(args) == 2 and type(None) in args:
|
||||
layer_annotation = args[0] if args[1] is type(None) else args[1]
|
||||
layer_type = _as_layer_type(layer_annotation)
|
||||
if layer_type is None:
|
||||
return None
|
||||
return LayerDepSpec(layer_type=layer_type, optional=True)
|
||||
|
||||
layer_type = _as_layer_type(annotation)
|
||||
if layer_type is None:
|
||||
return None
|
||||
return LayerDepSpec(layer_type=layer_type)
|
||||
|
||||
|
||||
def _as_layer_type(annotation: object) -> type[Layer[Any, Any, Any, Any, Any, Any, Any]] | None:
|
||||
runtime_type = get_origin(annotation) or annotation
|
||||
if isinstance(runtime_type, type) and issubclass(runtime_type, Layer):
|
||||
return cast(type[Layer[Any, Any, Any, Any, Any, Any, Any]], runtime_type)
|
||||
return None
|
||||
|
||||
|
||||
def _infer_deps_type(layer_type: type[Layer[Any, Any, Any, Any, Any, Any, Any]]) -> type[LayerDeps] | None:
|
||||
inferred = _infer_layer_generic_arg(layer_type, 0, {})
|
||||
if inferred is None:
|
||||
return None
|
||||
return _as_deps_type(inferred)
|
||||
|
||||
|
||||
def _infer_schema_type(
|
||||
layer_type: type[Layer[Any, Any, Any, Any, Any, Any, Any]],
|
||||
index: int,
|
||||
attr_name: str,
|
||||
) -> type[BaseModel] | None:
|
||||
inferred = _infer_schema_generic_arg(layer_type, attr_name, {}) or _infer_layer_generic_arg(layer_type, index, {})
|
||||
if inferred is None:
|
||||
return None
|
||||
schema_type = _as_model_type(inferred)
|
||||
if schema_type is None:
|
||||
raise TypeError(f"{layer_type.__name__}.{attr_name} must be a Pydantic BaseModel subclass.")
|
||||
return schema_type
|
||||
|
||||
|
||||
def _infer_schema_generic_arg(
|
||||
layer_type: type[Layer[Any, Any, Any, Any, Any, Any, Any]],
|
||||
attr_name: str,
|
||||
substitutions: Mapping[object, object],
|
||||
) -> object | None:
|
||||
"""Infer schema type arguments exposed by typed layer family bases."""
|
||||
expected_names = {
|
||||
"config_type": {"ConfigT", "_ConfigT"},
|
||||
"runtime_state_type": {"RuntimeStateT", "_RuntimeStateT"},
|
||||
"runtime_handles_type": {"RuntimeHandlesT", "_RuntimeHandlesT"},
|
||||
}[attr_name]
|
||||
for base in getattr(layer_type, "__orig_bases__", ()):
|
||||
origin = get_origin(base) or base
|
||||
args = tuple(_substitute_type(arg, substitutions) for arg in get_args(base))
|
||||
if not isinstance(origin, type) or not issubclass(origin, Layer):
|
||||
continue
|
||||
|
||||
params = _generic_params(origin)
|
||||
for param, arg in zip(params, args):
|
||||
if getattr(param, "__name__", None) in expected_names:
|
||||
return arg
|
||||
|
||||
next_substitutions = dict(substitutions)
|
||||
next_substitutions.update(_generic_arg_substitutions(origin, args))
|
||||
inferred = _infer_schema_generic_arg(origin, attr_name, next_substitutions)
|
||||
if inferred is not None:
|
||||
return inferred
|
||||
return None
|
||||
|
||||
|
||||
def _infer_layer_generic_arg(
|
||||
layer_type: type[Layer[Any, Any, Any, Any, Any, Any, Any]],
|
||||
index: int,
|
||||
substitutions: Mapping[object, object],
|
||||
) -> object | None:
|
||||
"""Infer one concrete ``Layer`` generic argument through inheritance.
|
||||
|
||||
This walks through intermediate generic base classes so subclasses can omit
|
||||
explicit class attributes in common cases such as ``class X(Base[YDeps])``.
|
||||
"""
|
||||
for base in getattr(layer_type, "__orig_bases__", ()):
|
||||
origin = get_origin(base) or base
|
||||
args = tuple(_substitute_type(arg, substitutions) for arg in get_args(base))
|
||||
if origin is Layer:
|
||||
if len(args) <= index:
|
||||
continue
|
||||
return args[index]
|
||||
|
||||
if not isinstance(origin, type) or not issubclass(origin, Layer):
|
||||
continue
|
||||
|
||||
next_substitutions = dict(substitutions)
|
||||
next_substitutions.update(_generic_arg_substitutions(origin, args))
|
||||
inferred = _infer_layer_generic_arg(origin, index, next_substitutions)
|
||||
if inferred is not None:
|
||||
return inferred
|
||||
return None
|
||||
|
||||
|
||||
def _init_schema_type(
|
||||
layer_type: type[Layer[Any, Any, Any, Any, Any, Any, Any]],
|
||||
attr_name: str,
|
||||
inferred_schema_type: type[BaseModel] | None,
|
||||
default_schema_type: type[BaseModel],
|
||||
) -> None:
|
||||
schema_type = layer_type.__dict__.get(attr_name)
|
||||
if schema_type is None:
|
||||
schema_type = inferred_schema_type or getattr(layer_type, attr_name, default_schema_type)
|
||||
setattr(layer_type, attr_name, schema_type)
|
||||
if not isinstance(schema_type, type) or not issubclass(schema_type, BaseModel):
|
||||
raise TypeError(f"{layer_type.__name__}.{attr_name} must be a Pydantic BaseModel subclass.")
|
||||
|
||||
|
||||
def _substitute_type(value: object, substitutions: Mapping[object, object]) -> object:
|
||||
if value in substitutions:
|
||||
return substitutions[value]
|
||||
|
||||
origin = get_origin(value)
|
||||
if origin is None:
|
||||
return value
|
||||
|
||||
args = get_args(value)
|
||||
if not args:
|
||||
return value
|
||||
|
||||
substituted_args = tuple(_substitute_type(arg, substitutions) for arg in args)
|
||||
if substituted_args == args:
|
||||
return value
|
||||
|
||||
try:
|
||||
return origin[substituted_args]
|
||||
except TypeError:
|
||||
return value
|
||||
|
||||
|
||||
def _generic_arg_substitutions(origin: type[Any], args: Sequence[object]) -> dict[object, object]:
|
||||
params = _generic_params(origin)
|
||||
return dict(zip(params, args))
|
||||
|
||||
|
||||
def _generic_params(origin: type[Any]) -> Sequence[object]:
|
||||
params = getattr(origin, "__type_params__", ())
|
||||
if not params:
|
||||
params = getattr(origin, "__parameters__", ())
|
||||
return params
|
||||
|
||||
|
||||
def _as_deps_type(value: object) -> type[LayerDeps] | None:
|
||||
runtime_type = get_origin(value) or value
|
||||
if isinstance(runtime_type, type) and issubclass(runtime_type, LayerDeps):
|
||||
return runtime_type
|
||||
return None
|
||||
|
||||
|
||||
def _as_model_type(value: object) -> type[BaseModel] | None:
|
||||
runtime_type = get_origin(value) or value
|
||||
if isinstance(runtime_type, type) and issubclass(runtime_type, BaseModel):
|
||||
return runtime_type
|
||||
return None
|
||||
|
||||
|
||||
def _is_generic_layer_template(layer_type: type[Layer[Any, Any, Any, Any, Any, Any, Any]]) -> bool:
|
||||
return bool(getattr(layer_type, "__type_params__", ())) or bool(
|
||||
getattr(layer_type, "__parameters__", ())
|
||||
)
|
||||
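
# A sketch of a concrete layer on top of this base, using the ``PlainLayer``
# family defined in the next file of this diff; ``CounterState`` is a
# hypothetical schema for illustration. Session-local data lives on the
# control, not on the shared layer instance.
#
#   class CounterState(BaseModel):
#       model_config = ConfigDict(extra="forbid", validate_assignment=True)
#       entries: int = 0
#
#   class CounterLayer(PlainLayer[NoLayerDeps, EmptyLayerConfig, CounterState]):
#       async def on_context_create(self, control) -> None:
#           control.runtime_state.entries = 1
#
#       async def on_context_resume(self, control) -> None:
#           control.runtime_state.entries += 1
#
#       @property
#       def prefix_prompts(self) -> list[str]:
#           return ["This layer counts its entries."]
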
@@ -1,185 +0,0 @@
"""Typed layer family definitions.
|
||||
|
||||
``Layer`` itself is framework-neutral. This module defines typed layer families
|
||||
that bind its system prompt, user prompt, and tool generic slots to concrete
|
||||
contracts, such as ordinary strings with plain callable tools or pydantic-ai
|
||||
prompt/tool shapes. The families keep the trailing schema generic slots open so
|
||||
concrete layers can have ``config_type``, ``runtime_state_type``, and
|
||||
``runtime_handles_type`` inferred from type arguments instead of repeated class
|
||||
attributes.
|
||||
Tagged aggregate aliases cover code paths that can accept any supported
|
||||
prompt/tool family without changing the plain and pydantic-ai layer contracts.
|
||||
Pydantic-ai names are imported for static analysis only, so ``agenton`` can be
|
||||
imported without loading that optional integration at runtime.
|
||||
Concrete reusable layers live under ``agenton_collections``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal
|
||||
|
||||
from typing_extensions import TypeVar, final, override
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_ai import Tool
|
||||
from pydantic_ai.messages import UserContent
|
||||
from pydantic_ai.tools import SystemPromptFunc
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agenton.layers.base import EmptyLayerConfig, EmptyRuntimeHandles, EmptyRuntimeState, Layer, LayerDeps
|
||||
|
||||
type PlainPrompt = str
|
||||
type PlainUserPrompt = str
|
||||
type PlainTool = Callable[..., Any]
|
||||
|
||||
|
||||
type PydanticAIPrompt[AgentDepsT] = SystemPromptFunc[AgentDepsT]
|
||||
type PydanticAIUserPrompt = UserContent
|
||||
type PydanticAITool[AgentDepsT] = Tool[AgentDepsT]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PlainPromptType:
|
||||
"""Tagged plain prompt item for aggregate prompt transformations."""
|
||||
|
||||
value: PlainPrompt
|
||||
kind: Literal["plain"] = field(default="plain", init=False)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PlainToolType:
|
||||
"""Tagged plain tool item for aggregate tool transformations."""
|
||||
|
||||
value: PlainTool
|
||||
kind: Literal["plain"] = field(default="plain", init=False)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PlainUserPromptType:
|
||||
"""Tagged plain user prompt item for aggregate user prompt transformations."""
|
||||
|
||||
value: PlainUserPrompt
|
||||
kind: Literal["plain"] = field(default="plain", init=False)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PydanticAIPromptType[AgentDepsT]:
|
||||
"""Tagged pydantic-ai prompt item for aggregate prompt transformations."""
|
||||
|
||||
value: PydanticAIPrompt[AgentDepsT]
|
||||
kind: Literal["pydantic_ai"] = field(default="pydantic_ai", init=False)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PydanticAIUserPromptType:
|
||||
"""Tagged pydantic-ai user prompt item for aggregate user prompts."""
|
||||
|
||||
value: PydanticAIUserPrompt
|
||||
kind: Literal["pydantic_ai"] = field(default="pydantic_ai", init=False)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PydanticAIToolType[AgentDepsT]:
|
||||
"""Tagged pydantic-ai tool item for aggregate tool transformations."""
|
||||
|
||||
value: PydanticAITool[AgentDepsT]
|
||||
kind: Literal["pydantic_ai"] = field(default="pydantic_ai", init=False)
|
||||
|
||||
|
||||
type AllPromptTypes = PlainPromptType | PydanticAIPromptType[Any]
|
||||
type AllUserPromptTypes = PlainUserPromptType | PydanticAIUserPromptType
|
||||
type AllToolTypes = PlainToolType | PydanticAIToolType[Any]
|
||||
|
||||
|
||||
_DepsT = TypeVar("_DepsT", bound=LayerDeps)
|
||||
_ConfigT = TypeVar("_ConfigT", bound=BaseModel, default=EmptyLayerConfig)
|
||||
_RuntimeStateT = TypeVar("_RuntimeStateT", bound=BaseModel, default=EmptyRuntimeState)
|
||||
_RuntimeHandlesT = TypeVar("_RuntimeHandlesT", bound=BaseModel, default=EmptyRuntimeHandles)
|
||||
_AgentDepsT = TypeVar("_AgentDepsT")
|
||||
|
||||
|
||||
class PlainLayer(
|
||||
Generic[_DepsT, _ConfigT, _RuntimeStateT, _RuntimeHandlesT],
|
||||
Layer[
|
||||
_DepsT,
|
||||
PlainPrompt,
|
||||
PlainUserPrompt,
|
||||
PlainTool,
|
||||
_ConfigT,
|
||||
_RuntimeStateT,
|
||||
_RuntimeHandlesT,
|
||||
],
|
||||
):
|
||||
"""Layer base for ordinary string prompts and plain-callable tools."""
|
||||
|
||||
@final
|
||||
@override
|
||||
def wrap_prompt(self, prompt: PlainPrompt) -> PlainPromptType:
|
||||
return PlainPromptType(prompt)
|
||||
|
||||
@final
|
||||
@override
|
||||
def wrap_user_prompt(self, prompt: PlainUserPrompt) -> PlainUserPromptType:
|
||||
return PlainUserPromptType(prompt)
|
||||
|
||||
@final
|
||||
@override
|
||||
def wrap_tool(self, tool: PlainTool) -> PlainToolType:
|
||||
return PlainToolType(tool)
|
||||
|
||||
|
||||
class PydanticAILayer(
|
||||
Generic[_DepsT, _AgentDepsT, _ConfigT, _RuntimeStateT, _RuntimeHandlesT],
|
||||
Layer[
|
||||
_DepsT,
|
||||
PydanticAIPrompt[_AgentDepsT],
|
||||
PydanticAIUserPrompt,
|
||||
PydanticAITool[_AgentDepsT],
|
||||
_ConfigT,
|
||||
_RuntimeStateT,
|
||||
_RuntimeHandlesT,
|
||||
],
|
||||
):
|
||||
"""Layer base for pydantic-ai prompt and tool adapters."""
|
||||
|
||||
@final
|
||||
@override
|
||||
def wrap_prompt(
|
||||
self,
|
||||
prompt: PydanticAIPrompt[_AgentDepsT],
|
||||
) -> PydanticAIPromptType[_AgentDepsT]:
|
||||
return PydanticAIPromptType(prompt)
|
||||
|
||||
@final
|
||||
@override
|
||||
def wrap_user_prompt(self, prompt: PydanticAIUserPrompt) -> PydanticAIUserPromptType:
|
||||
return PydanticAIUserPromptType(prompt)
|
||||
|
||||
@final
|
||||
@override
|
||||
def wrap_tool(self, tool: PydanticAITool[_AgentDepsT]) -> PydanticAIToolType[_AgentDepsT]:
|
||||
return PydanticAIToolType(tool)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AllPromptTypes",
|
||||
"AllUserPromptTypes",
|
||||
"AllToolTypes",
|
||||
"PlainLayer",
|
||||
"PlainPrompt",
|
||||
"PlainPromptType",
|
||||
"PlainUserPrompt",
|
||||
"PlainUserPromptType",
|
||||
"PlainTool",
|
||||
"PlainToolType",
|
||||
"PydanticAILayer",
|
||||
"PydanticAIPrompt",
|
||||
"PydanticAIPromptType",
|
||||
"PydanticAIUserPrompt",
|
||||
"PydanticAIUserPromptType",
|
||||
"PydanticAITool",
|
||||
"PydanticAIToolType",
|
||||
]
|
||||
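
# A sketch of consuming the tagged aggregates above, assuming a compositor
# built from plain layers only and that ``CompositorTransformer`` is callable
# as shown; the transformer unwraps ``PlainPromptType`` tags back to strings.
#
#   def only_plain(prompts: list[AllPromptTypes]) -> list[str]:
#       return [item.value for item in prompts if isinstance(item, PlainPromptType)]
#
#   compositor = builder.build(prompt_transformer=only_plain)
#   compositor.prompts  # -> list[str]
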
@@ -1,57 +0,0 @@
"""Convenience exports for reusable layer implementations.
|
||||
|
||||
Concrete collection layers live in family subpackages such as
|
||||
``agenton_collections.plain`` and ``agenton_collections.pydantic_ai``. The
|
||||
package root keeps the short import path for common layers while avoiding
|
||||
implementation code in ``__init__``.
|
||||
"""
|
||||
|
||||
from agenton.layers.types import (
|
||||
AllPromptTypes,
|
||||
AllToolTypes,
|
||||
PlainLayer,
|
||||
PlainPrompt,
|
||||
PlainPromptType,
|
||||
PlainTool,
|
||||
PlainToolType,
|
||||
PydanticAILayer,
|
||||
PydanticAIPrompt,
|
||||
PydanticAIPromptType,
|
||||
PydanticAITool,
|
||||
PydanticAIToolType,
|
||||
)
|
||||
from agenton_collections.layers.pydantic_ai import (
|
||||
PydanticAIBridgeLayer,
|
||||
PydanticAIBridgeLayerDeps,
|
||||
)
|
||||
from agenton_collections.layers.plain import (
|
||||
DynamicToolsLayer,
|
||||
DynamicToolsLayerDeps,
|
||||
ObjectLayer,
|
||||
PromptLayer,
|
||||
ToolsLayer,
|
||||
with_object,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AllPromptTypes",
|
||||
"AllToolTypes",
|
||||
"DynamicToolsLayer",
|
||||
"DynamicToolsLayerDeps",
|
||||
"ObjectLayer",
|
||||
"PlainLayer",
|
||||
"PlainPrompt",
|
||||
"PlainPromptType",
|
||||
"PlainTool",
|
||||
"PlainToolType",
|
||||
"PromptLayer",
|
||||
"PydanticAIBridgeLayer",
|
||||
"PydanticAIBridgeLayerDeps",
|
||||
"PydanticAILayer",
|
||||
"PydanticAIPrompt",
|
||||
"PydanticAIPromptType",
|
||||
"PydanticAITool",
|
||||
"PydanticAIToolType",
|
||||
"ToolsLayer",
|
||||
"with_object",
|
||||
]
|
||||
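
# With the re-exports above, callers keep the short path (assuming this file is
# the ``agenton_collections`` package root, as its docstring says):
#
#   from agenton_collections import ObjectLayer, PromptLayer, with_object
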
@@ -1,18 +0,0 @@
"""Reusable collection layers for the plain layer family."""
|
||||
|
||||
from agenton_collections.layers.plain.basic import ObjectLayer, PromptLayer, PromptLayerConfig, ToolsLayer
|
||||
from agenton_collections.layers.plain.dynamic_tools import (
|
||||
DynamicToolsLayer,
|
||||
DynamicToolsLayerDeps,
|
||||
with_object,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DynamicToolsLayer",
|
||||
"DynamicToolsLayerDeps",
|
||||
"ObjectLayer",
|
||||
"PromptLayer",
|
||||
"PromptLayerConfig",
|
||||
"ToolsLayer",
|
||||
"with_object",
|
||||
]
|
||||
@@ -1,95 +0,0 @@
"""Basic ready-to-compose layers for common plain use cases.
|
||||
|
||||
These layers are small concrete implementations built on
|
||||
``agenton.layers.types``. They intentionally stay free of compositor graph
|
||||
construction so they can be reused from config, examples, and higher-level
|
||||
dynamic layers.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from agenton.layers.base import NoLayerDeps
|
||||
from agenton.layers.types import PlainLayer
|
||||
|
||||
|
||||
class PromptLayerConfig(BaseModel):
|
||||
"""Serializable config schema for ``PromptLayer``."""
|
||||
|
||||
prefix: list[str] | str = Field(default_factory=list)
|
||||
user: list[str] | str = Field(default_factory=list)
|
||||
suffix: list[str] | str = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObjectLayer[ObjectT](PlainLayer[NoLayerDeps]):
|
||||
"""Layer that stores one typed object for downstream dependencies.
|
||||
|
||||
Object layers are instance-only because arbitrary Python objects are not
|
||||
serializable graph config. Add them with ``CompositorBuilder.add_instance``.
|
||||
"""
|
||||
|
||||
value: ObjectT
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptLayer(PlainLayer[NoLayerDeps, PromptLayerConfig]):
|
||||
"""Layer that contributes configured system and user prompt fragments."""
|
||||
|
||||
type_id = "plain.prompt"
|
||||
|
||||
prefix: list[str] | str = field(default_factory=list)
|
||||
user: list[str] | str = field(default_factory=list)
|
||||
suffix: list[str] | str = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: BaseModel):
|
||||
"""Create a prompt layer from validated prompt config."""
|
||||
validated_config = PromptLayerConfig.model_validate(config)
|
||||
return cls(prefix=validated_config.prefix, user=validated_config.user, suffix=validated_config.suffix)
|
||||
|
||||
@property
|
||||
def prefix_prompts(self) -> list[str]:
|
||||
if isinstance(self.prefix, str):
|
||||
return [self.prefix]
|
||||
return self.prefix
|
||||
|
||||
@property
|
||||
def suffix_prompts(self) -> list[str]:
|
||||
if isinstance(self.suffix, str):
|
||||
return [self.suffix]
|
||||
return self.suffix
|
||||
|
||||
@property
|
||||
def user_prompts(self) -> list[str]:
|
||||
if isinstance(self.user, str):
|
||||
return [self.user]
|
||||
return self.user
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolsLayer(PlainLayer[NoLayerDeps]):
|
||||
"""Layer that contributes configured plain-callable tools.
|
||||
|
||||
Tool layers are instance-only because Python callables are live objects. Add
|
||||
them with ``CompositorBuilder.add_instance``.
|
||||
"""
|
||||
|
||||
tool_entries: Sequence[Callable[..., Any]] = ()
|
||||
|
||||
@property
|
||||
def tools(self) -> list[Callable[..., Any]]:
|
||||
return list(self.tool_entries)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ObjectLayer",
|
||||
"PromptLayerConfig",
|
||||
"PromptLayer",
|
||||
"ToolsLayer",
|
||||
]
|
||||
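
# A quick sketch of the two construction paths for ``PromptLayer``; both the
# direct and the config route normalize a bare string into a one-item list.
#
#   direct = PromptLayer(prefix="You are concise.", user=["Summarize the doc."])
#   from_conf = PromptLayer.from_config(PromptLayerConfig(prefix="You are concise."))
#   assert direct.prefix_prompts == ["You are concise."]
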
@@ -1,233 +0,0 @@
"""Dynamic plain-tool layer with object-bound tool entries.
|
||||
|
||||
This module builds on ``ObjectLayer`` from ``agenton_collections.plain.basic``.
|
||||
Plain callables are exposed unchanged, while entries wrapped with
|
||||
``with_object`` bind the current object value into the first callable argument
|
||||
and expose the remaining parameters as the public tool signature.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from inspect import Parameter, Signature, iscoroutinefunction, signature
|
||||
from types import UnionType
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Concatenate,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from agenton.layers.base import LayerDeps
|
||||
from agenton.layers.types import PlainLayer
|
||||
from agenton_collections.layers.plain.basic import ObjectLayer
|
||||
|
||||
type _ObjectToolCallable[ObjectT] = Callable[Concatenate[ObjectT, ...], Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _ObjectToolEntry[ObjectT]:
|
||||
"""Tool entry whose first argument should be filled from ``ObjectLayer``."""
|
||||
|
||||
tool_entry: _ObjectToolCallable[ObjectT]
|
||||
object_type: type[ObjectT] | None = None
|
||||
|
||||
|
||||
type _DynamicToolEntry[ObjectT] = Callable[..., Any] | _ObjectToolEntry[ObjectT]
|
||||
|
||||
|
||||
def with_object[ObjectT](
|
||||
object_type: type[ObjectT],
|
||||
/,
|
||||
) -> Callable[[_ObjectToolCallable[ObjectT]], _ObjectToolEntry[ObjectT]]:
|
||||
"""Mark a tool as requiring the bound object value as its first argument."""
|
||||
def decorator(tool_entry: _ObjectToolCallable[ObjectT]) -> _ObjectToolEntry[ObjectT]:
|
||||
_validate_object_tool_annotation(tool_entry, object_type)
|
||||
return _ObjectToolEntry(tool_entry=tool_entry, object_type=object_type)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class DynamicToolsLayerDeps[ObjectT](LayerDeps):
|
||||
"""Dependencies required by ``DynamicToolsLayer``."""
|
||||
|
||||
object_layer: ObjectLayer[ObjectT] # pyright: ignore[reportUninitializedInstanceVariable]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DynamicToolsLayer[ObjectT](PlainLayer[DynamicToolsLayerDeps[ObjectT]]):
|
||||
"""Layer that exposes plain tools and object-bound tools."""
|
||||
|
||||
tool_entries: Sequence[_DynamicToolEntry[ObjectT]] = ()
|
||||
|
||||
@property
|
||||
def tools(self) -> list[Callable[..., Any]]:
|
||||
object_value = self.deps.object_layer.value
|
||||
return [
|
||||
_bind_object_argument(tool_entry.tool_entry, object_value, tool_entry.object_type)
|
||||
if isinstance(tool_entry, _ObjectToolEntry)
|
||||
else tool_entry
|
||||
for tool_entry in self.tool_entries
|
||||
]
|
||||
|
||||
|
||||
def _bind_object_argument[ObjectT](
|
||||
tool_entry: _ObjectToolCallable[ObjectT],
|
||||
object_value: ObjectT,
|
||||
object_type: type[ObjectT] | None,
|
||||
) -> Callable[..., Any]:
|
||||
_validate_object_value(tool_entry, object_value, object_type)
|
||||
if iscoroutinefunction(tool_entry):
|
||||
wrapped = _async_object_wrapper(tool_entry, object_value)
|
||||
else:
|
||||
wrapped = _sync_object_wrapper(tool_entry, object_value)
|
||||
|
||||
public_signature = _public_tool_signature(tool_entry)
|
||||
if public_signature is not None:
|
||||
setattr(wrapped, "__signature__", public_signature)
|
||||
_set_public_annotations(wrapped, tool_entry)
|
||||
return wrapped
|
||||
|
||||
|
||||
def _validate_object_tool_annotation[ObjectT](
|
||||
tool_entry: _ObjectToolCallable[ObjectT],
|
||||
object_type: type[ObjectT],
|
||||
) -> None:
|
||||
parameter = _first_object_parameter(tool_entry)
|
||||
if parameter is None:
|
||||
return
|
||||
|
||||
annotation = _parameter_annotation(tool_entry, parameter)
|
||||
if annotation is Parameter.empty:
|
||||
return
|
||||
if _annotation_accepts_object_type(annotation, object_type):
|
||||
return
|
||||
|
||||
raise TypeError(
|
||||
f"Object-bound tool '{_tool_name(tool_entry)}' first parameter should accept "
|
||||
f"'{_type_name(object_type)}'."
|
||||
)
|
||||
|
||||
|
||||
def _first_object_parameter(tool_entry: Callable[..., Any]) -> Parameter | None:
|
||||
try:
|
||||
tool_signature = signature(tool_entry)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
parameters = list(tool_signature.parameters.values())
|
||||
if not parameters:
|
||||
raise ValueError("Dynamic tools must accept the object dependency as their first parameter.")
|
||||
return parameters[0]
|
||||
|
||||
|
||||
def _parameter_annotation(tool_entry: Callable[..., Any], parameter: Parameter) -> object:
|
||||
try:
|
||||
type_hints = get_type_hints(tool_entry, include_extras=True)
|
||||
except (AttributeError, NameError, TypeError):
|
||||
return parameter.annotation
|
||||
return type_hints.get(parameter.name, parameter.annotation)
|
||||
|
||||
|
||||
def _annotation_accepts_object_type(annotation: object, object_type: type[Any]) -> bool:
|
||||
if annotation is Any or annotation is Parameter.empty:
|
||||
return True
|
||||
|
||||
origin = get_origin(annotation)
|
||||
if origin is Annotated:
|
||||
args = get_args(annotation)
|
||||
return True if not args else _annotation_accepts_object_type(args[0], object_type)
|
||||
if origin in (UnionType, Union):
|
||||
return any(
|
||||
arg is type(None) or _annotation_accepts_object_type(arg, object_type)
|
||||
for arg in get_args(annotation)
|
||||
)
|
||||
|
||||
runtime_type = origin or annotation
|
||||
if not isinstance(runtime_type, type):
|
||||
return True
|
||||
try:
|
||||
return issubclass(object_type, runtime_type)
|
||||
except TypeError:
|
||||
return True
|
||||
|
||||
|
||||
def _validate_object_value[ObjectT](
|
||||
tool_entry: _ObjectToolCallable[ObjectT],
|
||||
object_value: ObjectT,
|
||||
object_type: type[ObjectT] | None,
|
||||
) -> None:
|
||||
if object_type is None or isinstance(object_value, object_type):
|
||||
return
|
||||
raise TypeError(
|
||||
f"Object-bound tool '{_tool_name(tool_entry)}' expected object dependency "
|
||||
f"of type '{_type_name(object_type)}', but got '{type(object_value).__qualname__}'."
|
||||
)
|
||||
|
||||
|
||||
def _tool_name(tool_entry: Callable[..., Any]) -> str:
|
||||
return getattr(tool_entry, "__qualname__", getattr(tool_entry, "__name__", repr(tool_entry)))
|
||||
|
||||
|
||||
def _type_name(object_type: type[Any]) -> str:
|
||||
return object_type.__qualname__
|
||||
|
||||
|
||||
def _sync_object_wrapper[ObjectT](
|
||||
tool_entry: _ObjectToolCallable[ObjectT],
|
||||
object_value: ObjectT,
|
||||
) -> Callable[..., Any]:
|
||||
@wraps(tool_entry)
|
||||
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
return tool_entry(object_value, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _async_object_wrapper[ObjectT](
|
||||
tool_entry: _ObjectToolCallable[ObjectT],
|
||||
object_value: ObjectT,
|
||||
) -> Callable[..., Any]:
|
||||
@wraps(tool_entry)
|
||||
async def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
return await tool_entry(object_value, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _public_tool_signature(tool_entry: Callable[..., Any]) -> Signature | None:
|
||||
try:
|
||||
tool_signature = signature(tool_entry)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
parameters = list(tool_signature.parameters.values())
|
||||
if not parameters:
|
||||
raise ValueError("Dynamic tools must accept the object dependency as their first parameter.")
|
||||
return tool_signature.replace(parameters=parameters[1:])
|
||||
|
||||
|
||||
def _set_public_annotations(wrapper: Callable[..., Any], tool_entry: Callable[..., Any]) -> None:
|
||||
annotations = getattr(tool_entry, "__annotations__", None)
|
||||
if not isinstance(annotations, dict):
|
||||
return
|
||||
|
||||
try:
|
||||
parameters = list(signature(tool_entry).parameters)
|
||||
except (TypeError, ValueError):
|
||||
parameters = []
|
||||
|
||||
public_annotations = dict(annotations)
|
||||
if parameters:
|
||||
public_annotations.pop(parameters[0], None)
|
||||
wrapper.__annotations__ = public_annotations
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DynamicToolsLayer",
|
||||
"DynamicToolsLayerDeps",
|
||||
"with_object",
|
||||
]
|
||||
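A short sketch of the ``with_object`` contract above: the decorated tool's first parameter is filled from ``ObjectLayer`` and disappears from the public tool signature. Only ``with_object`` and ``DynamicToolsLayer`` are real names from this diff; the wiring comments are assumptions.

# Sketch, assuming the module above is importable.
from dataclasses import dataclass

@dataclass
class Greeter:
    greeting: str

@with_object(Greeter)
def greet(greeter: Greeter, name: str) -> str:
    """Public tool signature becomes (name: str) -> str after binding."""
    return f"{greeter.greeting}, {name}!"

# DynamicToolsLayer(tool_entries=(greet,)) later calls greet with the current
# ObjectLayer value bound as ``greeter``; callers only supply ``name``.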
@@ -1,11 +0,0 @@
"""Reusable collection layers for the pydantic-ai layer family."""

from agenton_collections.layers.pydantic_ai.bridge import (
    PydanticAIBridgeLayer,
    PydanticAIBridgeLayerDeps,
)

__all__ = [
    "PydanticAIBridgeLayer",
    "PydanticAIBridgeLayerDeps",
]
@@ -1,108 +0,0 @@
"""Pydantic AI bridge prompt and tool layer.

This module keeps pydantic-ai's callable shapes intact through
``PydanticAILayer``. The bridge layer depends on ``ObjectLayer`` so callers have
one explicit graph node that provides the object used as
``RunContext[ObjectT].deps`` in pydantic-ai prompt and tool callables.

Bridge construction accepts pydantic-ai's ergonomic input forms and normalizes
them at the layer boundary: string system prompts become zero-arg system prompt
functions, user prompts stay as pydantic-ai ``UserContent`` values, and bare
tool functions become ``Tool`` instances.
"""

from collections.abc import Sequence
from dataclasses import dataclass

from pydantic_ai import Tool
from pydantic_ai.messages import UserContent
from pydantic_ai.tools import ToolFuncEither
from typing_extensions import override

from agenton.layers.base import LayerDeps
from agenton.layers.types import PydanticAILayer, PydanticAIPrompt, PydanticAITool, PydanticAIUserPrompt
from agenton_collections.layers.plain.basic import ObjectLayer


class PydanticAIBridgeLayerDeps[ObjectT](LayerDeps):
    """Dependencies required by ``PydanticAIBridgeLayer``."""

    object_layer: ObjectLayer[ObjectT]  # pyright: ignore[reportUninitializedInstanceVariable]


@dataclass
class PydanticAIBridgeLayer[ObjectT](
    PydanticAILayer[PydanticAIBridgeLayerDeps[ObjectT], ObjectT]
):
    """Bridge layer for pydantic-ai prompts and tools using one object deps."""

    prefix: str | PydanticAIPrompt[ObjectT] | Sequence[str | PydanticAIPrompt[ObjectT]] = ()
    user: UserContent | Sequence[UserContent] = ()
    suffix: str | PydanticAIPrompt[ObjectT] | Sequence[str | PydanticAIPrompt[ObjectT]] = ()
    tool_entries: Sequence[PydanticAITool[ObjectT] | ToolFuncEither[ObjectT, ...]] = ()

    @property
    def run_deps(self) -> ObjectT:
        """Object to pass as pydantic-ai run deps for this layer."""
        return self.deps.object_layer.value

    @property
    @override
    def prefix_prompts(self) -> list[PydanticAIPrompt[ObjectT]]:
        return _normalize_prompts(self.prefix)

    @property
    @override
    def suffix_prompts(self) -> list[PydanticAIPrompt[ObjectT]]:
        return _normalize_prompts(self.suffix)

    @property
    @override
    def user_prompts(self) -> list[PydanticAIUserPrompt]:
        return _normalize_user_prompts(self.user)

    @property
    @override
    def tools(self) -> list[PydanticAITool[ObjectT]]:
        return [_normalize_tool(tool_entry) for tool_entry in self.tool_entries]


def _normalize_prompts[ObjectT](
    prompts: str | PydanticAIPrompt[ObjectT] | Sequence[str | PydanticAIPrompt[ObjectT]],
) -> list[PydanticAIPrompt[ObjectT]]:
    if isinstance(prompts, str):
        return [_normalize_prompt(prompts)]
    if isinstance(prompts, Sequence):
        return [_normalize_prompt(prompt) for prompt in prompts]
    return [prompts]


def _normalize_prompt[ObjectT](
    prompt: str | PydanticAIPrompt[ObjectT],
) -> PydanticAIPrompt[ObjectT]:
    if isinstance(prompt, str):
        # Eagerly capture the string value in a zero-arg prompt callable.
        return (lambda value: lambda: value)(prompt)
    return prompt


def _normalize_user_prompts(
    prompts: UserContent | Sequence[UserContent],
) -> list[PydanticAIUserPrompt]:
    if isinstance(prompts, str):
        return [prompts]
    if isinstance(prompts, Sequence):
        return list(prompts)
    return [prompts]


def _normalize_tool[ObjectT](
    tool_entry: PydanticAITool[ObjectT] | ToolFuncEither[ObjectT, ...],
) -> PydanticAITool[ObjectT]:
    if isinstance(tool_entry, Tool):
        return tool_entry
    return Tool(tool_entry)


__all__ = [
    "PydanticAIBridgeLayer",
    "PydanticAIBridgeLayerDeps",
]
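A minimal sketch of the normalization above: a bare string prefix becomes a zero-arg system prompt function, so pydantic-ai always sees a uniform callable shape. Deps wiring is omitted here, which is fine because ``prefix_prompts`` never touches ``deps``; whether the layer can be constructed standalone like this is an assumption.

# Sketch only.
layer = PydanticAIBridgeLayer(prefix="You are terse.")
prompt_fn = layer.prefix_prompts[0]
assert callable(prompt_fn)
assert prompt_fn() == "You are terse."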
@@ -1,11 +0,0 @@
"""Reusable compositor transformers for collection integrations."""

from agenton_collections.transformers.pydantic_ai import (
    PYDANTIC_AI_TRANSFORMERS,
    PydanticAICompositorTransformerKwargs,
)

__all__ = [
    "PYDANTIC_AI_TRANSFORMERS",
    "PydanticAICompositorTransformerKwargs",
]
@@ -1,85 +0,0 @@
"""Pydantic AI compositor transformer presets.

This module owns the pydantic-ai runtime dependency for transforming tagged
agenton system prompt, user prompt, and tool items into pydantic-ai-compatible
items.
"""

from collections.abc import Sequence
from typing import Final

from pydantic_ai import Tool

from agenton.compositor import CompositorTransformerKwargs
from agenton.layers.types import (
    AllPromptTypes,
    AllToolTypes,
    AllUserPromptTypes,
    PydanticAIPrompt,
    PydanticAITool,
    PydanticAIUserPrompt,
)

type PydanticAICompositorTransformerKwargs = CompositorTransformerKwargs[
    PydanticAIPrompt[object],
    PydanticAITool[object],
    AllPromptTypes,
    AllToolTypes,
    PydanticAIUserPrompt,
    AllUserPromptTypes,
]


def _pydantic_ai_prompt_transformer(
    prompts: Sequence[AllPromptTypes],
) -> list[PydanticAIPrompt[object]]:
    result: list[PydanticAIPrompt[object]] = []
    for prompt in prompts:
        if prompt.kind == "plain":
            # Eagerly capture the plain string as a zero-arg prompt callable.
            result.append((lambda value: lambda: value)(prompt.value))
        elif prompt.kind == "pydantic_ai":
            result.append(prompt.value)
        else:
            raise NotImplementedError(f"Unsupported prompt type: {type(prompt).__qualname__}.")
    return result


def _pydantic_ai_user_prompt_transformer(
    prompts: Sequence[AllUserPromptTypes],
) -> list[PydanticAIUserPrompt]:
    result: list[PydanticAIUserPrompt] = []
    for prompt in prompts:
        if prompt.kind in ("plain", "pydantic_ai"):
            result.append(prompt.value)
        else:
            raise NotImplementedError(f"Unsupported user prompt type: {type(prompt).__qualname__}.")
    return result


def _pydantic_ai_tool_transformer(
    tools: Sequence[AllToolTypes],
) -> list[PydanticAITool[object]]:
    result: list[PydanticAITool[object]] = []
    for tool in tools:
        if tool.kind == "plain":
            result.append(Tool(tool.value))
        elif tool.kind == "pydantic_ai":
            result.append(tool.value)
        else:
            raise NotImplementedError(f"Unsupported tool type: {type(tool).__qualname__}.")
    return result


PYDANTIC_AI_TRANSFORMERS: Final[PydanticAICompositorTransformerKwargs] = {
    "prompt_transformer": _pydantic_ai_prompt_transformer,
    "user_prompt_transformer": _pydantic_ai_user_prompt_transformer,
    "tool_transformer": _pydantic_ai_tool_transformer,
}


__all__ = [
    "PYDANTIC_AI_TRANSFORMERS",
    "PydanticAICompositorTransformerKwargs",
]
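The ``(lambda value: lambda: value)`` idiom used in both prompt transformers deserves a note: it captures the current value eagerly, avoiding the classic late-binding bug where every closure created in a loop sees only the last value. A self-contained demonstration:

# Late binding vs. eager capture (standalone Python).
late = [lambda: value for value in ("a", "b")]            # both return "b"
eager = [(lambda v: lambda: v)(value) for value in ("a", "b")]
assert [fn() for fn in late] == ["b", "b"]
assert [fn() for fn in eager] == ["a", "b"]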
@@ -1,5 +0,0 @@
"""Adapters for using Dify components inside the local agent package."""

from .adapters.llm import DifyLLMAdapterModel, DifyPluginDaemonProvider

__all__ = ["DifyLLMAdapterModel", "DifyPluginDaemonProvider"]
@@ -1 +0,0 @@
"""Adapter integrations for Dify agent components."""
@@ -1,6 +0,0 @@
"""LLM adapters for Dify plugin-daemon integrations."""

from .model import DifyLLMAdapterModel
from .provider import DifyPluginDaemonProvider

__all__ = ["DifyLLMAdapterModel", "DifyPluginDaemonProvider"]
@@ -1,798 +0,0 @@
"""Bridge Dify plugin-daemon LLM invocations into Pydantic AI's model interface.

The API and agent layers are clients of the plugin daemon, not direct hosts of provider SDK
implementations. This adapter therefore targets the plugin-daemon dispatch protocol and maps
Pydantic AI messages into the daemon's Graphon-compatible request and stream response schema.
"""

from __future__ import annotations

import base64
import re
from collections.abc import AsyncGenerator, AsyncIterator, Mapping, Sequence
from contextlib import asynccontextmanager
from dataclasses import KW_ONLY, InitVar, dataclass, field
from datetime import datetime, timezone
from typing import cast

from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMUsage
from graphon.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    AudioPromptMessageContent,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentUnionTypes,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
    VideoPromptMessageContent,
)
from typing_extensions import assert_never, override

from pydantic_ai._parts_manager import ModelResponsePartsManager
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.messages import (
    AudioUrl,
    BinaryContent,
    BuiltinToolCallPart,
    BuiltinToolReturnPart,
    CachePoint,
    CompactionPart,
    DocumentUrl,
    FilePart,
    FinishReason,
    ImageUrl,
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    MultiModalContent,
    RetryPromptPart,
    SystemPromptPart,
    TextContent,
    TextPart,
    ThinkingPart,
    ToolCallPart,
    ToolReturnPart,
    UploadedFile,
    UserContent,
    UserPromptPart,
    VideoUrl,
)
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
from pydantic_ai.profiles import ModelProfileSpec
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import RequestUsage

from .provider import DifyPluginDaemonLLMClient, DifyPluginDaemonProvider

_THINK_START = "<think>\n"
_THINK_END = "\n</think>"
_THINK_OPEN_TAG = "<think>"
_THINK_CLOSE_TAG = "</think>"
_THINK_TAG_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)
_DETAIL_HIGH = "high"


@dataclass(slots=True)
class _DifyRequestInput:
    credentials: dict[str, object]
    prompt_messages: list[PromptMessage]
    model_parameters: dict[str, object]
    tools: list[PromptMessageTool] | None
    stop_sequences: list[str] | None


@dataclass(slots=True)
class DifyLLMAdapterModel(Model[DifyPluginDaemonLLMClient]):
    """Use a Dify plugin-daemon LLM provider as a Pydantic AI model."""

    model: str
    daemon_provider: DifyPluginDaemonProvider
    _: KW_ONLY
    credentials: dict[str, object] = field(default_factory=dict, repr=False)
    model_profile: InitVar[ModelProfileSpec | None] = None
    model_settings: InitVar[ModelSettings | None] = None

    def __post_init__(
        self,
        model_profile: ModelProfileSpec | None,
        model_settings: ModelSettings | None,
    ) -> None:
        Model.__init__(
            self,
            settings=model_settings,
            profile=model_profile or self.daemon_provider.model_profile(self.model),
        )

    @property
    @override
    def provider(self) -> DifyPluginDaemonProvider:
        return self.daemon_provider

    @property
    @override
    def model_name(self) -> str:
        return self.model

    @property
    @override
    def system(self) -> str:
        return self.daemon_provider.name

    @override
    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        prepared_settings, prepared_params = self.prepare_request(
            model_settings, model_request_parameters
        )
        request_input = self._build_request_input(
            messages, prepared_settings, prepared_params
        )

        response = DifyStreamedResponse(
            model_request_parameters=prepared_params,
            chunks=self.daemon_provider.client.iter_llm_result_chunks(
                model=self.model_name,
                credentials=request_input.credentials,
                prompt_messages=request_input.prompt_messages,
                model_parameters=request_input.model_parameters,
                tools=request_input.tools,
                stop=request_input.stop_sequences,
                stream=False,
            ),
            response_model_name=self.model_name,
            provider_name_value=self.system,
        )
        # Drain the non-streaming result so the aggregated response is complete.
        async for _event in response:
            pass
        return response.get()

    @asynccontextmanager
    @override
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: object | None = None,
    ) -> AsyncGenerator[StreamedResponse, None]:
        del run_context
        prepared_settings, prepared_params = self.prepare_request(
            model_settings, model_request_parameters
        )
        request_input = self._build_request_input(
            messages, prepared_settings, prepared_params
        )

        yield DifyStreamedResponse(
            model_request_parameters=prepared_params,
            chunks=self.daemon_provider.client.iter_llm_result_chunks(
                model=self.model_name,
                credentials=request_input.credentials,
                prompt_messages=request_input.prompt_messages,
                model_parameters=request_input.model_parameters,
                tools=request_input.tools,
                stop=request_input.stop_sequences,
                stream=True,
            ),
            response_model_name=self.model_name,
            provider_name_value=self.system,
        )

    def _build_request_input(
        self,
        messages: Sequence[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> _DifyRequestInput:
        return _DifyRequestInput(
            credentials=dict(self.credentials),
            prompt_messages=_map_messages_to_prompt_messages(
                messages, model_request_parameters
            ),
            model_parameters=_map_model_settings_to_parameters(model_settings),
            tools=_map_tool_definitions_to_prompt_tools(model_request_parameters),
            stop_sequences=_get_stop_sequences(model_settings),
        )


@dataclass
class DifyStreamedResponse(StreamedResponse):
    chunks: AsyncIterator[LLMResultChunk]
    response_model_name: str
    provider_name_value: str
    _timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    _embedded_thinking_parser: "_EmbeddedThinkingParser" = field(
        default_factory=lambda: _EmbeddedThinkingParser()
    )

    @override
    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self.chunks:
            if chunk.delta.usage is not None:
                self._usage: RequestUsage = _map_usage(chunk.delta.usage)
            if chunk.delta.finish_reason is not None:
                self.finish_reason: FinishReason | None = _normalize_finish_reason(
                    chunk.delta.finish_reason
                )

            for event in _chunk_to_stream_events(
                self._parts_manager,
                chunk,
                self.provider_name_value,
                self._embedded_thinking_parser,
            ):
                yield event

        # Emit any text held back while waiting for a possible partial tag.
        for event in self._embedded_thinking_parser.flush(
            self._parts_manager, self.provider_name_value
        ):
            yield event

    @property
    @override
    def model_name(self) -> str:
        return self.response_model_name

    @property
    @override
    def provider_name(self) -> str:
        return self.provider_name_value

    @property
    @override
    def provider_url(self) -> None:
        return None

    @property
    @override
    def timestamp(self) -> datetime:
        return self._timestamp


def _map_messages_to_prompt_messages(
    messages: Sequence[ModelMessage],
    model_request_parameters: ModelRequestParameters,
) -> list[PromptMessage]:
    prompt_messages: list[PromptMessage] = []

    for message in messages:
        if isinstance(message, ModelRequest):
            prompt_messages.extend(_map_model_request_to_prompt_messages(message))
        elif isinstance(message, ModelResponse):
            assistant_message = _map_model_response_to_prompt_message(message)
            if assistant_message is not None:
                prompt_messages.append(assistant_message)
        else:
            assert_never(message)

    instruction_messages = [
        SystemPromptMessage(content=part.content)
        for part in (
            Model._get_instruction_parts(messages, model_request_parameters) or []
        )
        if part.content.strip()
    ]
    if instruction_messages:
        # Keep instructions ahead of the first non-system message.
        insert_at = next(
            (
                index
                for index, message in enumerate(prompt_messages)
                if not isinstance(message, SystemPromptMessage)
            ),
            len(prompt_messages),
        )
        prompt_messages[insert_at:insert_at] = instruction_messages

    return prompt_messages


def _map_model_request_to_prompt_messages(message: ModelRequest) -> list[PromptMessage]:
    prompt_messages: list[PromptMessage] = []

    for part in message.parts:
        if isinstance(part, SystemPromptPart):
            prompt_messages.append(SystemPromptMessage(content=part.content))
        elif isinstance(part, UserPromptPart):
            prompt_messages.append(
                UserPromptMessage(content=_map_user_prompt_content(part.content))
            )
        elif isinstance(part, ToolReturnPart):
            prompt_messages.append(_map_tool_return_part_to_prompt_message(part))
        elif isinstance(part, RetryPromptPart):
            if part.tool_name is None:
                prompt_messages.append(UserPromptMessage(content=part.model_response()))
            else:
                prompt_messages.append(
                    ToolPromptMessage(
                        content=part.model_response(),
                        tool_call_id=part.tool_call_id,
                        name=part.tool_name,
                    )
                )
        else:
            assert_never(part)

    return prompt_messages


def _map_tool_return_part_to_prompt_message(part: ToolReturnPart) -> ToolPromptMessage:
    items = part.content_items(mode="str")
    if len(items) == 1 and isinstance(items[0], str):
        content: str | list[PromptMessageContentUnionTypes] | None = items[0]
    else:
        content_items: list[PromptMessageContentUnionTypes] = []
        for item in items:
            if isinstance(item, str):
                content_items.append(TextPromptMessageContent(data=item))
            elif isinstance(item, CachePoint):
                continue
            elif _is_multi_modal_content(item):
                content_items.append(_map_multi_modal_user_content(item))
            else:
                raise UnexpectedModelBehavior(
                    f"Unsupported daemon tool message content: {type(item).__name__}"
                )
        content = content_items or None

    return ToolPromptMessage(
        content=content, tool_call_id=part.tool_call_id, name=part.tool_name
    )


def _map_model_response_to_prompt_message(
    message: ModelResponse,
) -> AssistantPromptMessage | None:
    content_parts: list[PromptMessageContentUnionTypes] = []
    tool_calls: list[AssistantPromptMessage.ToolCall] = []

    for part in message.parts:
        if isinstance(part, TextPart):
            if part.content:
                content_parts.append(TextPromptMessageContent(data=part.content))
        elif isinstance(part, ThinkingPart):
            if part.content:
                content_parts.append(
                    TextPromptMessageContent(
                        data=f"{_THINK_START}{part.content}{_THINK_END}"
                    )
                )
        elif isinstance(part, FilePart):
            content_parts.append(_map_binary_content_to_prompt_content(part.content))
        elif isinstance(part, ToolCallPart):
            tool_calls.append(
                AssistantPromptMessage.ToolCall(
                    id=part.tool_call_id or f"tool-call-{part.tool_name}",
                    type="function",
                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                        name=part.tool_name,
                        arguments=part.args_as_json_str(),
                    ),
                )
            )
        elif isinstance(
            part, BuiltinToolCallPart | BuiltinToolReturnPart | CompactionPart
        ):
            raise UnexpectedModelBehavior(
                f"Unsupported response part for daemon adapter: {type(part).__name__}"
            )
        else:
            assert_never(part)

    content = _normalize_prompt_content(content_parts)
    if content is None and not tool_calls:
        return None

    return AssistantPromptMessage(content=content, tool_calls=tool_calls)


def _map_user_prompt_content(
    content: str | Sequence[UserContent],
) -> str | list[PromptMessageContentUnionTypes] | None:
    if isinstance(content, str):
        return content

    prompt_content: list[PromptMessageContentUnionTypes] = []
    for item in content:
        if isinstance(item, CachePoint):
            continue
        if isinstance(item, str):
            prompt_content.append(TextPromptMessageContent(data=item))
        elif isinstance(item, TextContent):
            prompt_content.append(TextPromptMessageContent(data=item.content))
        elif _is_multi_modal_content(item):
            prompt_content.append(_map_multi_modal_user_content(item))
        else:
            raise UnexpectedModelBehavior(f"Unsupported user prompt content: {type(item).__name__}")
    return _normalize_prompt_content(prompt_content)


def _is_multi_modal_content(item: object) -> bool:
    return isinstance(
        item,
        ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent | UploadedFile,
    )


def _map_multi_modal_user_content(
    item: MultiModalContent,
) -> PromptMessageContentUnionTypes:
    if isinstance(item, ImageUrl):
        detail = (
            ImagePromptMessageContent.DETAIL.HIGH
            if _get_detail(item) == _DETAIL_HIGH
            else ImagePromptMessageContent.DETAIL.LOW
        )
        return ImagePromptMessageContent(
            url=item.url,
            mime_type=item.media_type,
            format=item.format,
            filename=f"{item.identifier}.{item.format}",
            detail=detail,
        )
    if isinstance(item, AudioUrl):
        return AudioPromptMessageContent(
            url=item.url,
            mime_type=item.media_type,
            format=item.format,
            filename=f"{item.identifier}.{item.format}",
        )
    if isinstance(item, VideoUrl):
        return VideoPromptMessageContent(
            url=item.url,
            mime_type=item.media_type,
            format=item.format,
            filename=f"{item.identifier}.{item.format}",
        )
    if isinstance(item, DocumentUrl):
        return DocumentPromptMessageContent(
            url=item.url,
            mime_type=item.media_type,
            format=item.format,
            filename=f"{item.identifier}.{item.format}",
        )
    if isinstance(item, BinaryContent):
        return _map_binary_content_to_prompt_content(item)
    if isinstance(item, UploadedFile):
        raise UnexpectedModelBehavior(
            "UploadedFile content is not supported by the daemon adapter"
        )
    assert_never(item)


def _map_binary_content_to_prompt_content(
    item: BinaryContent,
) -> PromptMessageContentUnionTypes:
    filename = f"{item.identifier}.{item.format}"
    if item.is_image:
        detail = (
            ImagePromptMessageContent.DETAIL.HIGH
            if _get_detail(item) == _DETAIL_HIGH
            else ImagePromptMessageContent.DETAIL.LOW
        )
        return ImagePromptMessageContent(
            base64_data=item.base64,
            mime_type=item.media_type,
            format=item.format,
            filename=filename,
            detail=detail,
        )
    if item.is_audio:
        return AudioPromptMessageContent(
            base64_data=item.base64,
            mime_type=item.media_type,
            format=item.format,
            filename=filename,
        )
    if item.is_video:
        return VideoPromptMessageContent(
            base64_data=item.base64,
            mime_type=item.media_type,
            format=item.format,
            filename=filename,
        )
    if item.is_document:
        return DocumentPromptMessageContent(
            base64_data=item.base64,
            mime_type=item.media_type,
            format=item.format,
            filename=filename,
        )
    raise UnexpectedModelBehavior(
        f"Unsupported binary media type for daemon adapter: {item.media_type}"
    )


def _normalize_prompt_content(
    content: list[PromptMessageContentUnionTypes],
) -> str | list[PromptMessageContentUnionTypes] | None:
    if not content:
        return None
    if len(content) == 1 and isinstance(content[0], TextPromptMessageContent):
        return content[0].data
    return content


def _map_tool_definitions_to_prompt_tools(
    model_request_parameters: ModelRequestParameters,
) -> list[PromptMessageTool] | None:
    tool_definitions = [
        *model_request_parameters.function_tools,
        *model_request_parameters.output_tools,
    ]
    if not tool_definitions:
        return None

    return [
        PromptMessageTool(
            name=tool_definition.name,
            description=tool_definition.description or "",
            parameters=cast(dict[str, object], tool_definition.parameters_json_schema),
        )
        for tool_definition in tool_definitions
    ]


def _map_model_settings_to_parameters(model_settings: ModelSettings | None) -> dict[str, object]:
    if not model_settings:
        return {}

    parameters: dict[str, object] = {
        key: value
        for key, value in model_settings.items()
        if value is not None and key not in {"extra_body", "stop_sequences"}
    }

    extra_body = model_settings.get("extra_body")
    if isinstance(extra_body, Mapping):
        parameters.update(cast(Mapping[str, object], extra_body))

    return parameters


def _get_stop_sequences(model_settings: ModelSettings | None) -> list[str] | None:
    if not model_settings:
        return None
    return list(model_settings.get("stop_sequences") or []) or None


def _map_usage(usage: LLMUsage) -> RequestUsage:
    return RequestUsage(
        input_tokens=usage.prompt_tokens, output_tokens=usage.completion_tokens
    )


def _normalize_finish_reason(finish_reason: str) -> FinishReason:
    lowered = finish_reason.lower()
    if lowered in {"stop", "length", "content_filter", "error", "tool_call"}:
        return cast(FinishReason, lowered)
    if lowered in {"tool_calls", "function_call", "function_calls"}:
        return "tool_call"
    return "error"


def _chunk_to_stream_events(
    parts_manager: ModelResponsePartsManager,
    chunk: LLMResultChunk,
    provider_name: str,
    embedded_thinking_parser: "_EmbeddedThinkingParser",
) -> list[ModelResponseStreamEvent]:
    events: list[ModelResponseStreamEvent] = []
    message = chunk.delta.message

    if isinstance(message.content, str):
        if message.content:
            events.extend(
                embedded_thinking_parser.parse(
                    parts_manager, message.content, provider_name
                )
            )
    elif isinstance(message.content, list):
        for part in _map_assistant_content_to_response_parts(message.content):
            if isinstance(part, TextPart):
                events.extend(
                    parts_manager.handle_text_delta(
                        vendor_part_id=None,
                        content=part.content,
                        provider_name=provider_name,
                    )
                )
            else:
                events.append(parts_manager.handle_part(vendor_part_id=None, part=part))

    for index, tool_call in enumerate(message.tool_calls):
        vendor_id = tool_call.id or f"chunk-{chunk.delta.index}-tool-{index}"
        events.append(
            parts_manager.handle_tool_call_part(
                vendor_part_id=vendor_id,
                tool_name=tool_call.function.name,
                args=tool_call.function.arguments,
                tool_call_id=tool_call.id,
                provider_name=provider_name,
            )
        )

    return events


def _map_assistant_content_to_response_parts(
    content: Sequence[PromptMessageContentUnionTypes],
) -> list[ModelResponsePart]:
    response_parts: list[ModelResponsePart] = []

    for item in content:
        if isinstance(item, TextPromptMessageContent):
            if item.data:
                response_parts.extend(_parse_assistant_text_parts(item.data))
        elif isinstance(
            item,
            ImagePromptMessageContent
            | AudioPromptMessageContent
            | VideoPromptMessageContent
            | DocumentPromptMessageContent,
        ):
            if item.url:
                raise UnexpectedModelBehavior(
                    "URL-based assistant multimodal output is not supported by the daemon adapter"
                )
            if not item.base64_data:
                continue
            response_parts.append(
                FilePart(
                    content=BinaryContent(
                        data=base64.b64decode(item.base64_data),
                        media_type=item.mime_type,
                    ),
                    provider_name=None,
                )
            )
        else:
            assert_never(item)

    return response_parts


def _get_detail(item: ImageUrl | BinaryContent) -> str | None:
    metadata = item.vendor_metadata or {}
    detail = metadata.get("detail")
    return detail if isinstance(detail, str) else None


def _parse_assistant_text_parts(content: str) -> list[ModelResponsePart]:
    response_parts: list[ModelResponsePart] = []
    cursor = 0

    for match in _THINK_TAG_PATTERN.finditer(content):
        if match.start() > cursor:
            response_parts.append(
                TextPart(content=content[cursor : match.start()], provider_name=None)
            )

        thinking_content = match.group(1).strip("\n")
        if thinking_content:
            response_parts.append(
                ThinkingPart(content=thinking_content, provider_name=None)
            )
        cursor = match.end()

    if cursor < len(content):
        response_parts.append(TextPart(content=content[cursor:], provider_name=None))

    if response_parts:
        return response_parts
    return [TextPart(content=content, provider_name=None)]


@dataclass(slots=True)
class _EmbeddedThinkingParser:
    _pending: str = ""
    _inside_thinking: bool = False

    def parse(
        self,
        parts_manager: ModelResponsePartsManager,
        content: str,
        provider_name: str,
    ) -> list[ModelResponseStreamEvent]:
        events: list[ModelResponseStreamEvent] = []
        buffer = self._pending + content
        self._pending = ""

        while buffer:
            if self._inside_thinking:
                end_index = buffer.find(_THINK_CLOSE_TAG)
                if end_index >= 0:
                    if end_index > 0:
                        events.extend(
                            parts_manager.handle_thinking_delta(
                                vendor_part_id=None,
                                content=buffer[:end_index],
                                provider_name=provider_name,
                            )
                        )
                    buffer = buffer[end_index + len(_THINK_CLOSE_TAG) :]
                    self._inside_thinking = False
                    continue

                safe_content, self._pending = _split_incomplete_tag_suffix(
                    buffer, _THINK_CLOSE_TAG
                )
                if safe_content:
                    events.extend(
                        parts_manager.handle_thinking_delta(
                            vendor_part_id=None,
                            content=safe_content,
                            provider_name=provider_name,
                        )
                    )
                break

            start_index = buffer.find(_THINK_OPEN_TAG)
            if start_index >= 0:
                if start_index > 0:
                    events.extend(
                        parts_manager.handle_text_delta(
                            vendor_part_id=None,
                            content=buffer[:start_index],
                            provider_name=provider_name,
                        )
                    )
                buffer = buffer[start_index + len(_THINK_OPEN_TAG) :]
                self._inside_thinking = True
                continue

            safe_content, self._pending = _split_incomplete_tag_suffix(
                buffer, _THINK_OPEN_TAG
            )
            if safe_content:
                events.extend(
                    parts_manager.handle_text_delta(
                        vendor_part_id=None,
                        content=safe_content,
                        provider_name=provider_name,
                    )
                )
            break

        return events

    def flush(
        self,
        parts_manager: ModelResponsePartsManager,
        provider_name: str,
    ) -> list[ModelResponseStreamEvent]:
        if not self._pending:
            return []

        pending = self._pending
        self._pending = ""
        if self._inside_thinking:
            return list(
                parts_manager.handle_thinking_delta(
                    vendor_part_id=None,
                    content=pending,
                    provider_name=provider_name,
                )
            )
        return list(
            parts_manager.handle_text_delta(
                vendor_part_id=None,
                content=pending,
                provider_name=provider_name,
            )
        )


def _split_incomplete_tag_suffix(content: str, tag: str) -> tuple[str, str]:
    for suffix_length in range(len(tag) - 1, 0, -1):
        if content.endswith(tag[:suffix_length]):
            return content[:-suffix_length], content[-suffix_length:]
    return content, ""
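A self-contained illustration of the chunk-boundary rule implemented by ``_split_incomplete_tag_suffix`` above: a possible partial ``<think>`` tag at the end of a streamed chunk is held back as pending text instead of being emitted as visible output.

# Re-implemented locally for demonstration; mirrors the helper above.
def split_incomplete_tag_suffix(content: str, tag: str) -> tuple[str, str]:
    for suffix_length in range(len(tag) - 1, 0, -1):
        if content.endswith(tag[:suffix_length]):
            return content[:-suffix_length], content[-suffix_length:]
    return content, ""

safe, pending = split_incomplete_tag_suffix("Hello <thi", "<think>")
assert (safe, pending) == ("Hello ", "<thi")  # "<thi" waits for the next chunk
safe, pending = split_incomplete_tag_suffix("Hello world", "<think>")
assert (safe, pending) == ("Hello world", "")  # no partial tag, emit everything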
@@ -1,252 +0,0 @@
"""Dify plugin-daemon provider for Pydantic AI LLM adapters."""

from __future__ import annotations

import json
from collections.abc import AsyncIterator, Callable, Mapping
from dataclasses import dataclass, field
from typing import NoReturn

import httpx
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from pydantic import BaseModel
from typing_extensions import override

from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UnexpectedModelBehavior, UserError
from pydantic_ai.providers import Provider

_DEFAULT_DAEMON_TIMEOUT: float | httpx.Timeout | None = 600.0


class PluginDaemonBasicResponse(BaseModel):
    code: int
    message: str
    data: object | None = None


@dataclass(slots=True)
class DifyPluginDaemonLLMClient:
    plugin_daemon_url: str
    plugin_daemon_api_key: str
    tenant_id: str
    plugin_id: str
    provider: str
    user_id: str | None
    http_client: httpx.AsyncClient = field(repr=False)

    def __post_init__(self) -> None:
        self.plugin_daemon_url = self.plugin_daemon_url.rstrip("/")

    async def iter_llm_result_chunks(
        self,
        *,
        model: str,
        credentials: dict[str, object],
        prompt_messages: list[PromptMessage],
        model_parameters: dict[str, object],
        tools: list[PromptMessageTool] | None,
        stop: list[str] | None,
        stream: bool,
    ) -> AsyncIterator[LLMResultChunk]:
        async for item in self._iter_stream_response(
            model_name=model,
            path=f"plugin/{self.tenant_id}/dispatch/llm/invoke",
            request_data={
                "provider": self.provider,
                "model_type": "llm",
                "model": model,
                "credentials": credentials,
                "prompt_messages": prompt_messages,
                "model_parameters": model_parameters,
                "tools": tools,
                "stop": stop,
                "stream": stream,
            },
            response_model=LLMResultChunk,
        ):
            yield item

    async def _iter_stream_response[T: BaseModel](
        self,
        *,
        model_name: str,
        path: str,
        request_data: Mapping[str, object],
        response_model: type[T],
    ) -> AsyncIterator[T]:
        payload: dict[str, object] = {"data": _to_jsonable(request_data)}
        if self.user_id is not None:
            payload["user_id"] = self.user_id

        headers = {
            "X-Api-Key": self.plugin_daemon_api_key,
            "X-Plugin-ID": self.plugin_id,
            "Content-Type": "application/json",
        }
        url = f"{self.plugin_daemon_url}/{path}"

        async with self.http_client.stream("POST", url, headers=headers, json=payload) as response:
            if response.is_error:
                body = (await response.aread()).decode("utf-8", errors="replace")
                error = _decode_plugin_daemon_error_payload(body)
                if error is not None:
                    _raise_plugin_daemon_error(
                        model_name=model_name,
                        error_type=error["error_type"],
                        message=error["message"],
                        status_code=response.status_code,
                        body=error,
                    )
                raise ModelHTTPError(response.status_code, model_name, body or None)

            async for raw_line in response.aiter_lines():
                line = raw_line.strip()
                if not line:
                    continue
                if line.startswith("data:"):
                    line = line[5:].strip()

                wrapped = PluginDaemonBasicResponse.model_validate_json(line)
                if wrapped.code != 0:
                    error = _decode_plugin_daemon_error_payload(wrapped.message)
                    if error is not None:
                        _raise_plugin_daemon_error(
                            model_name=model_name,
                            error_type=error["error_type"],
                            message=error["message"],
                            body=error,
                        )
                    raise ModelAPIError(
                        model_name,
                        f"Plugin daemon returned error code {wrapped.code}: {wrapped.message}",
                    )
                if wrapped.data is None:
                    raise UnexpectedModelBehavior("Plugin daemon returned an empty stream item")
                yield response_model.model_validate(wrapped.data)


@dataclass(slots=True, kw_only=True)
class DifyPluginDaemonProvider(Provider[DifyPluginDaemonLLMClient]):
    """Pydantic AI provider for Dify plugin-daemon dispatch requests."""

    tenant_id: str
    plugin_id: str
    plugin_provider: str
    plugin_daemon_url: str
    plugin_daemon_api_key: str = field(repr=False)
    user_id: str | None = None
    timeout: float | httpx.Timeout | None = _DEFAULT_DAEMON_TIMEOUT
    _client: DifyPluginDaemonLLMClient = field(init=False, repr=False)
    _own_http_client: httpx.AsyncClient | None = field(init=False, default=None, repr=False)
    _http_client_factory: Callable[[], httpx.AsyncClient] | None = field(init=False, default=None, repr=False)

    def __post_init__(self) -> None:
        self.plugin_daemon_url = self.plugin_daemon_url.rstrip("/")
        self._http_client_factory = self._make_http_client
        http_client = self._make_http_client()
        self._own_http_client = http_client
        self._client = DifyPluginDaemonLLMClient(
            plugin_daemon_url=self.plugin_daemon_url,
            plugin_daemon_api_key=self.plugin_daemon_api_key,
            tenant_id=self.tenant_id,
            plugin_id=self.plugin_id,
            provider=self.plugin_provider,
            user_id=self.user_id,
            http_client=http_client,
        )

    def _make_http_client(self) -> httpx.AsyncClient:
        return httpx.AsyncClient(timeout=self.timeout, trust_env=False)

    @override
    def _set_http_client(self, http_client: httpx.AsyncClient) -> None:
        self._client.http_client = http_client

    @property
    @override
    def name(self) -> str:
        return f"DifyPlugin/{self.plugin_provider}"

    @property
    @override
    def base_url(self) -> str:
        return self.plugin_daemon_url

    @property
    @override
    def client(self) -> DifyPluginDaemonLLMClient:
        return self._client


def _to_jsonable(value: object) -> object:
    if isinstance(value, BaseModel):
        return value.model_dump(mode="json")
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, list | tuple):
        return [_to_jsonable(item) for item in value]
    return value


def _decode_plugin_daemon_error_payload(raw_message: str) -> dict[str, str] | None:
    try:
        parsed = json.loads(raw_message)
    except json.JSONDecodeError:
        return None

    if not isinstance(parsed, dict):
        return None

    error_type = parsed.get("error_type")
    message = parsed.get("message")
    if not isinstance(error_type, str) or not isinstance(message, str):
        return None
    return {"error_type": error_type, "message": message}


def _raise_plugin_daemon_error(
    *,
    model_name: str,
    error_type: str,
    message: str,
    status_code: int | None = None,
    body: object | None = None,
) -> NoReturn:
    http_error_body = body or {"error_type": error_type, "message": message}

    match error_type:
        case "PluginInvokeError":
            nested_error = _decode_plugin_daemon_error_payload(message)
            if nested_error is not None:
                _raise_plugin_daemon_error(
                    model_name=model_name,
                    error_type=nested_error["error_type"],
                    message=nested_error["message"],
                    status_code=status_code,
                    body=nested_error,
                )
            raise ModelAPIError(model_name, message)
        case "PluginDaemonUnauthorizedError" | "InvokeAuthorizationError":
            raise ModelHTTPError(status_code or 401, model_name, http_error_body)
        case "PluginPermissionDeniedError":
            raise ModelHTTPError(status_code or 403, model_name, http_error_body)
        case (
            "PluginDaemonBadRequestError"
            | "InvokeBadRequestError"
            | "CredentialsValidateFailedError"
            | "PluginUniqueIdentifierError"
        ):
            raise ModelHTTPError(status_code or 400, model_name, http_error_body)
        case "EndpointSetupFailedError" | "TriggerProviderCredentialValidationError":
            raise UserError(message)
        case "PluginDaemonNotFoundError" | "PluginNotFoundError":
            raise ModelHTTPError(status_code or 404, model_name, http_error_body)
        case "InvokeRateLimitError":
            raise ModelHTTPError(status_code or 429, model_name, http_error_body)
        case "PluginDaemonInternalServerError" | "PluginDaemonInnerError":
            raise ModelHTTPError(status_code or 500, model_name, http_error_body)
        case "InvokeConnectionError" | "InvokeServerUnavailableError":
            raise ModelHTTPError(status_code or 503, model_name, http_error_body)
        case _:
            raise ModelAPIError(model_name, f"{error_type}: {message}")
@@ -1,56 +0,0 @@
"""Pydantic AI agent construction for runtime profiles.

The initial server exposes only a credential-free ``test`` profile. The factory
keeps model selection out of ``AgentRunRunner`` so production model profiles can
be added without changing storage or HTTP contracts.
"""

from collections.abc import Callable, Sequence
from typing import cast

from pydantic_ai import Agent
from pydantic_ai.messages import UserContent
from pydantic_ai.models.test import TestModel

from agenton.layers.types import PydanticAIPrompt, PydanticAITool
from dify_agent.server.schemas import AgentProfileConfig


def create_agent(
    profile: AgentProfileConfig,
    *,
    system_prompts: Sequence[PydanticAIPrompt[object]],
    tools: Sequence[PydanticAITool[object]],
) -> Agent[None, str]:
    """Create the pydantic-ai agent for one run."""
    if profile.provider == "test":
        return Agent[None, str](
            TestModel(custom_output_text=profile.output_text),
            output_type=str,
            system_prompt=materialize_static_system_prompts(system_prompts),
            tools=tools,
        )
    raise ValueError(f"Unsupported agent profile provider: {profile.provider}")


def materialize_static_system_prompts(system_prompts: Sequence[PydanticAIPrompt[object]]) -> list[str]:
    """Convert MVP static prompt callables into strings for pydantic-ai."""
    result: list[str] = []
    for prompt in system_prompts:
        if isinstance(prompt, str):
            result.append(prompt)
        elif callable(prompt):
            result.append(cast(Callable[[], str], prompt)())
        else:
            raise TypeError(f"Unsupported system prompt type: {type(prompt).__qualname__}")
    return result


def normalize_user_input(user_prompts: Sequence[UserContent]) -> str | Sequence[UserContent]:
    """Return the pydantic-ai run input while preserving multi-part prompts."""
    if len(user_prompts) == 1 and isinstance(user_prompts[0], str):
        return user_prompts[0]
    return list(user_prompts)


__all__ = ["create_agent", "materialize_static_system_prompts", "normalize_user_input"]
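A quick sketch of the materialization contract above: strings pass through and zero-arg callables are invoked once at agent construction time (assuming the module above is importable).

# Sketch using materialize_static_system_prompts from the module above.
prompts = ["Be brief.", lambda: "Answer in English."]
assert materialize_static_system_prompts(prompts) == ["Be brief.", "Answer in English."]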
@@ -1,53 +0,0 @@
"""Safe Agenton compositor construction for API-submitted configs.

Only explicitly registered layer types are constructible here. The MVP registry
contains ``PromptLayer`` so callers can provide system/user prompt fragments while
the runtime preserves hooks for richer profiles later.
"""

from typing import cast

from pydantic_ai.messages import UserContent

from agenton.compositor import Compositor, CompositorConfig, LayerRegistry
from agenton.layers.types import AllPromptTypes, AllToolTypes, AllUserPromptTypes, PydanticAIPrompt, PydanticAITool
from agenton_collections.layers.plain.basic import PromptLayer
from agenton_collections.transformers.pydantic_ai import PYDANTIC_AI_TRANSFORMERS


def create_default_layer_registry() -> LayerRegistry:
    """Return the server registry of safe config-constructible layers."""
    registry = LayerRegistry()
    registry.register_layer(PromptLayer)
    return registry


def build_pydantic_ai_compositor(
    config: CompositorConfig,
) -> Compositor[
    PydanticAIPrompt[object],
    PydanticAITool[object],
    AllPromptTypes,
    AllToolTypes,
    UserContent,
    AllUserPromptTypes,
]:
    """Build a Pydantic AI-ready compositor from a validated config."""
    return cast(
        Compositor[
            PydanticAIPrompt[object],
            PydanticAITool[object],
            AllPromptTypes,
            AllToolTypes,
            UserContent,
            AllUserPromptTypes,
        ],
        Compositor.from_config(
            config,
            registry=create_default_layer_registry(),
            **PYDANTIC_AI_TRANSFORMERS,  # pyright: ignore[reportArgumentType]
        ),
    )


__all__ = ["build_pydantic_ai_compositor", "create_default_layer_registry"]
@ -1,144 +0,0 @@
"""Event sink contracts used by the runner and storage adapters.

The runner only needs append-only event writes and status transitions, so tests
can use ``InMemoryRunEventSink`` without Redis. Production storage implements the
same protocol with Redis streams in ``dify_agent.storage.redis_run_store``.
"""

from collections import defaultdict
from typing import Protocol

from pydantic_ai.messages import AgentStreamEvent

from agenton.compositor import CompositorSessionSnapshot
from dify_agent.server.schemas import (
    AgentOutputRunEvent,
    AgentOutputRunEventData,
    EmptyRunEventData,
    PydanticAIStreamRunEvent,
    RunEvent,
    RunFailedEvent,
    RunFailedEventData,
    RunStartedEvent,
    RunStatus,
    RunSucceededEvent,
    SessionSnapshotRunEvent,
    utc_now,
)


class RunEventSink(Protocol):
    """Boundary used by runtime code to publish observable run progress."""

    async def append_event(self, event: RunEvent) -> str:
        """Persist ``event`` and return its cursor id."""
        ...

    async def update_status(self, run_id: str, status: RunStatus, error: str | None = None) -> None:
        """Persist the current run status."""
        ...


class InMemoryRunEventSink:
    """Small async-compatible sink for local unit tests and examples."""

    events: dict[str, list[RunEvent]]
    statuses: dict[str, RunStatus]
    errors: dict[str, str | None]

    def __init__(self) -> None:
        self.events = defaultdict(list)
        self.statuses = {}
        self.errors = {}

    async def append_event(self, event: RunEvent) -> str:
        """Store an event and assign a monotonic per-run cursor."""
        event_id = str(len(self.events[event.run_id]) + 1)
        stored = event.model_copy(update={"id": event_id})
        self.events[event.run_id].append(stored)
        return event_id

    async def update_status(self, run_id: str, status: RunStatus, error: str | None = None) -> None:
        """Record the latest status; timestamps are owned by run stores."""
        self.statuses[run_id] = status
        self.errors[run_id] = error


async def emit_run_event(
    sink: RunEventSink,
    *,
    event: RunEvent,
) -> str:
    """Append an already typed public run event."""
    return await sink.append_event(event)


async def emit_run_started(sink: RunEventSink, *, run_id: str) -> str:
    """Emit the first lifecycle event for one run."""
    return await emit_run_event(
        sink,
        event=RunStartedEvent(run_id=run_id, data=EmptyRunEventData(), created_at=utc_now()),
    )


async def emit_pydantic_ai_event(sink: RunEventSink, *, run_id: str, data: AgentStreamEvent) -> str:
    """Emit one typed Pydantic AI stream event."""
    return await emit_run_event(
        sink,
        event=PydanticAIStreamRunEvent(run_id=run_id, data=data, created_at=utc_now()),
    )


async def emit_agent_output(sink: RunEventSink, *, run_id: str, output: str) -> str:
    """Emit the final output text produced by the agent."""
    return await emit_run_event(
        sink,
        event=AgentOutputRunEvent(
            run_id=run_id,
            data=AgentOutputRunEventData(output=output),
            created_at=utc_now(),
        ),
    )


async def emit_session_snapshot(sink: RunEventSink, *, run_id: str, data: CompositorSessionSnapshot) -> str:
    """Emit the typed Agenton session snapshot for later resumption."""
    return await emit_run_event(
        sink,
        event=SessionSnapshotRunEvent(run_id=run_id, data=data, created_at=utc_now()),
    )


async def emit_run_succeeded(sink: RunEventSink, *, run_id: str) -> str:
    """Emit the terminal success lifecycle event."""
    return await emit_run_event(
        sink,
        event=RunSucceededEvent(run_id=run_id, data=EmptyRunEventData(), created_at=utc_now()),
    )


async def emit_run_failed(
    sink: RunEventSink,
    *,
    run_id: str,
    error: str,
    reason: str | None = None,
) -> str:
    """Emit the terminal failure lifecycle event."""
    return await emit_run_event(
        sink,
        event=RunFailedEvent(run_id=run_id, data=RunFailedEventData(error=error, reason=reason), created_at=utc_now()),
    )


__all__ = [
    "InMemoryRunEventSink",
    "RunEventSink",
    "emit_agent_output",
    "emit_pydantic_ai_event",
    "emit_run_event",
    "emit_run_failed",
    "emit_run_started",
    "emit_run_succeeded",
    "emit_session_snapshot",
]
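A self-contained usage sketch of the in-memory sink; the event ``type`` strings are the ones defined in ``dify_agent.server.schemas``:

import asyncio

async def main() -> None:
    sink = InMemoryRunEventSink()
    _ = await emit_run_started(sink, run_id="run-1")
    _ = await emit_agent_output(sink, run_id="run-1", output="hello")
    _ = await emit_run_succeeded(sink, run_id="run-1")
    await sink.update_status("run-1", "succeeded")
    # Events are stored per run with monotonic string cursors "1", "2", ...
    assert [event.type for event in sink.events["run-1"]] == ["run_started", "agent_output", "run_succeeded"]

asyncio.run(main())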
@ -1,140 +0,0 @@
"""In-process scheduling for Dify Agent runs.

The scheduler is intentionally process-local: it persists a run record, starts an
``asyncio.Task`` for ``AgentRunRunner.run()``, and keeps only a transient active
task registry. Redis remains the durable source for status and event streams, but
there is no Redis job queue or cross-process handoff. If the process crashes,
currently active runs are lost until an external operator marks or retries them.
"""

import asyncio
import logging
from collections.abc import Callable
from typing import Protocol

from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor
from dify_agent.runtime.event_sink import RunEventSink, emit_run_failed
from dify_agent.runtime.runner import AgentRunRunner
from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt
from dify_agent.server.schemas import CreateRunRequest, RunRecord

logger = logging.getLogger(__name__)


class SchedulerStoppingError(RuntimeError):
    """Raised when a create-run request arrives after shutdown has started."""


class RunStore(RunEventSink, Protocol):
    """Persistence boundary needed by the scheduler."""

    async def create_run(self, request: CreateRunRequest) -> RunRecord:
        """Persist a new run record and return it with status ``running``."""
        ...


class RunnableRun(Protocol):
    """Executable unit for one scheduled run."""

    async def run(self) -> None:
        """Run until terminal status/events have been written or cancellation occurs."""
        ...


type RunRunnerFactory = Callable[[RunRecord], RunnableRun]


class RunScheduler:
    """Owns process-local run tasks and best-effort graceful shutdown.

    ``active_tasks`` is mutated only on the event loop that calls ``create_run``
    and ``shutdown``. The task registry is not durable; it exists so the lifespan
    hook can wait for in-flight work and mark cancelled runs failed before Redis is
    closed. A lock guards the stopping flag, run persistence, and task
    registration so shutdown cannot complete while a run is between record
    creation and active-task tracking.
    """

    store: RunStore
    shutdown_grace_seconds: float
    active_tasks: dict[str, asyncio.Task[None]]
    stopping: bool
    runner_factory: RunRunnerFactory
    _lifecycle_lock: asyncio.Lock

    def __init__(
        self,
        *,
        store: RunStore,
        shutdown_grace_seconds: float = 30,
        runner_factory: RunRunnerFactory | None = None,
    ) -> None:
        self.store = store
        self.shutdown_grace_seconds = shutdown_grace_seconds
        self.active_tasks = {}
        self.stopping = False
        self.runner_factory = runner_factory or self._default_runner_factory
        self._lifecycle_lock = asyncio.Lock()

    async def create_run(self, request: CreateRunRequest) -> RunRecord:
        """Validate, persist, and schedule one run in the current process.

        The returned record is already ``running``. The background task is removed
        from ``active_tasks`` when it finishes, regardless of success or failure.
        """
        compositor = build_pydantic_ai_compositor(request.compositor)
        if not has_non_blank_user_prompt(compositor.user_prompts):
            raise ValueError(EMPTY_USER_PROMPTS_ERROR)

        async with self._lifecycle_lock:
            if self.stopping:
                raise SchedulerStoppingError("run scheduler is shutting down")
            record = await self.store.create_run(request)
            task = asyncio.create_task(self._run_record(record), name=f"dify-agent-run-{record.run_id}")
            self.active_tasks[record.run_id] = task
            task.add_done_callback(lambda _task, run_id=record.run_id: self.active_tasks.pop(run_id, None))
        return record

    async def shutdown(self) -> None:
        """Stop accepting runs, wait briefly, then cancel and fail unfinished runs."""
        async with self._lifecycle_lock:
            self.stopping = True
        if not self.active_tasks:
            return
        tasks_by_run_id = dict(self.active_tasks)
        done, pending = await asyncio.wait(tasks_by_run_id.values(), timeout=self.shutdown_grace_seconds)
        del done
        if not pending:
            return

        pending_run_ids = [run_id for run_id, task in tasks_by_run_id.items() if task in pending]
        for task in pending:
            _ = task.cancel()
        _ = await asyncio.gather(*pending, return_exceptions=True)
        for run_id in pending_run_ids:
            await self._mark_cancelled_run_failed(run_id)

    async def _run_record(self, record: RunRecord) -> None:
        """Execute a stored run and log failures already reflected in events."""
        try:
            await self.runner_factory(record).run()
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception("scheduled run failed", extra={"run_id": record.run_id})

    def _default_runner_factory(self, record: RunRecord) -> RunnableRun:
        """Create the production runner for a stored run record."""
        return AgentRunRunner(sink=self.store, request=record.request, run_id=record.run_id)

    async def _mark_cancelled_run_failed(self, run_id: str) -> None:
        """Best-effort failure event/status for shutdown-cancelled runs."""
        message = "run cancelled during server shutdown"
        try:
            _ = await emit_run_failed(self.store, run_id=run_id, error=message, reason="shutdown")
            await self.store.update_status(run_id, "failed", message)
        except Exception:
            logger.exception("failed to mark cancelled run failed", extra={"run_id": run_id})


__all__ = ["RunScheduler", "SchedulerStoppingError"]
@ -1,92 +0,0 @@
"""Runtime execution for one scheduled Dify Agent run.

The runner is storage-agnostic: it builds an Agenton compositor, enters or
resumes its session, runs pydantic-ai with ``compositor.user_prompts`` as the user
input, emits stream events, suspends the session on exit, snapshots it, and then
publishes a terminal success or failure event.
"""

from collections.abc import AsyncIterable

from pydantic_ai.messages import AgentStreamEvent

from agenton.compositor import CompositorSessionSnapshot
from dify_agent.runtime.agent_factory import create_agent, normalize_user_input
from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor
from dify_agent.runtime.event_sink import (
    RunEventSink,
    emit_agent_output,
    emit_pydantic_ai_event,
    emit_run_failed,
    emit_run_started,
    emit_run_succeeded,
    emit_session_snapshot,
)
from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt
from dify_agent.server.schemas import CreateRunRequest


class AgentRunValidationError(ValueError):
    """Raised when a run request is valid JSON but cannot execute."""


class AgentRunRunner:
    """Executes one run and writes only public run events to its sink."""

    sink: RunEventSink
    request: CreateRunRequest
    run_id: str

    def __init__(self, *, sink: RunEventSink, request: CreateRunRequest, run_id: str) -> None:
        self.sink = sink
        self.request = request
        self.run_id = run_id

    async def run(self) -> None:
        """Execute the run and emit the documented event sequence."""
        await self.sink.update_status(self.run_id, "running")
        _ = await emit_run_started(self.sink, run_id=self.run_id)

        try:
            output, session_snapshot = await self._run_agent()
        except Exception as exc:
            message = str(exc) or type(exc).__name__
            _ = await emit_run_failed(self.sink, run_id=self.run_id, error=message)
            await self.sink.update_status(self.run_id, "failed", message)
            raise

        _ = await emit_agent_output(self.sink, run_id=self.run_id, output=output)
        _ = await emit_session_snapshot(self.sink, run_id=self.run_id, data=session_snapshot)
        _ = await emit_run_succeeded(self.sink, run_id=self.run_id)
        await self.sink.update_status(self.run_id, "succeeded")

    async def _run_agent(self) -> tuple[str, CompositorSessionSnapshot]:
        """Run pydantic-ai inside an entered Agenton session."""
        compositor = build_pydantic_ai_compositor(self.request.compositor)
        session = (
            compositor.session_from_snapshot(self.request.session_snapshot)
            if self.request.session_snapshot is not None
            else compositor.new_session()
        )
        async with compositor.enter(session) as active_session:
            active_session.suspend_on_exit()
            user_prompts = compositor.user_prompts
            if not has_non_blank_user_prompt(user_prompts):
                raise AgentRunValidationError(EMPTY_USER_PROMPTS_ERROR)

            async def handle_events(_ctx: object, events: AsyncIterable[AgentStreamEvent]) -> None:
                async for event in events:
                    _ = await emit_pydantic_ai_event(self.sink, run_id=self.run_id, data=event)

            agent = create_agent(
                self.request.agent_profile,
                system_prompts=compositor.prompts,
                tools=compositor.tools,
            )
            result = await agent.run(normalize_user_input(user_prompts), event_stream_handler=handle_events)

        return result.output, compositor.snapshot_session(session)


__all__ = ["AgentRunRunner", "AgentRunValidationError"]
@ -1,29 +0,0 @@
"""Validation for effective user prompts produced by Agenton compositors.

Validation happens after safe compositor construction so scheduler and runner
paths use the same semantics as the actual pydantic-ai input. Blank string
fragments do not count as meaningful input; non-string ``UserContent`` is treated
as intentional content because rich media/message parts do not have a universal
whitespace representation.
"""

from collections.abc import Sequence

from pydantic_ai.messages import UserContent

EMPTY_USER_PROMPTS_ERROR = "compositor.user_prompts must not be empty"


def has_non_blank_user_prompt(user_prompts: Sequence[UserContent]) -> bool:
    """Return whether composed user prompts contain meaningful input."""
    for prompt in user_prompts:
        if isinstance(prompt, str):
            if prompt.strip():
                return True
        else:
            return True
    return False


__all__ = ["EMPTY_USER_PROMPTS_ERROR", "has_non_blank_user_prompt"]
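The blank-versus-meaningful rule reduces to a few concrete cases:

assert has_non_blank_user_prompt(["hello"])
assert not has_non_blank_user_prompt(["", "   ", "\n"])
assert not has_non_blank_user_prompt([])
# Non-string UserContent (e.g. an image or binary part) always counts as meaningful input.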
@ -1,59 +0,0 @@
"""FastAPI application factory for the Dify Agent run server.

The HTTP process owns Redis clients, route wiring, and a process-local scheduler.
Run execution happens in background ``asyncio`` tasks rather than request
handlers, so client disconnects do not cancel the agent runtime. Redis persists
run records and per-run event streams with configured retention only; it is not
used as a job queue.
"""

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI
from redis.asyncio import Redis

from dify_agent.runtime.run_scheduler import RunScheduler
from dify_agent.server.routes.runs import create_runs_router
from dify_agent.server.settings import ServerSettings
from dify_agent.storage.redis_run_store import RedisRunStore


def create_app(settings: ServerSettings | None = None) -> FastAPI:
    """Build the FastAPI app with one shared Redis store and local scheduler."""
    resolved_settings = settings or ServerSettings()
    state: dict[str, RedisRunStore | RunScheduler] = {}

    @asynccontextmanager
    async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
        redis = Redis.from_url(resolved_settings.redis_url)
        store = RedisRunStore(
            redis,
            prefix=resolved_settings.redis_prefix,
            run_retention_seconds=resolved_settings.run_retention_seconds,
        )
        scheduler = RunScheduler(store=store, shutdown_grace_seconds=resolved_settings.shutdown_grace_seconds)
        state["store"] = store
        state["scheduler"] = scheduler
        try:
            yield
        finally:
            await scheduler.shutdown()
            await redis.aclose()

    app = FastAPI(title="Dify Agent Run Server", version="0.1.0", lifespan=lifespan)

    def get_store() -> RedisRunStore:
        return state["store"]  # pyright: ignore[reportReturnType]

    def get_scheduler() -> RunScheduler:
        return state["scheduler"]  # pyright: ignore[reportReturnType]

    app.include_router(create_runs_router(get_store, get_scheduler))
    return app


app = create_app()


__all__ = ["app", "create_app"]
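A local-run sketch (assumptions: ``uvicorn`` is installed and a Redis instance is reachable at the configured URL; the host and port here are illustrative):

import uvicorn

# Serve the module-level app; the lifespan hook opens Redis and the scheduler.
uvicorn.run("dify_agent.server.app:app", host="127.0.0.1", port=8000)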
@ -1,96 +0,0 @@
"""FastAPI routes for asynchronous agent runs.

Controllers translate known validation and shutdown errors into HTTP status codes.
Unexpected scheduler or storage failures are intentionally left for FastAPI's
server-error handling so infrastructure problems are not reported as client input
errors. Created runs are scheduled in the current process and observed through
status polling or SSE replay backed by Redis event streams.
"""

from collections.abc import Callable
from typing import Annotated

from fastapi import APIRouter, Depends, Header, HTTPException, Query
from fastapi.responses import StreamingResponse

from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor
from dify_agent.runtime.run_scheduler import RunScheduler, SchedulerStoppingError
from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt
from dify_agent.server.schemas import CreateRunRequest, CreateRunResponse, RunEventsResponse, RunStatusResponse
from dify_agent.server.sse import sse_event_stream
from dify_agent.storage.redis_run_store import RedisRunStore, RunNotFoundError


def create_runs_router(get_store: Callable[[], RedisRunStore], get_scheduler: Callable[[], RunScheduler]) -> APIRouter:
    """Create routes bound to the application's store dependency provider."""
    router = APIRouter(prefix="/runs", tags=["runs"])

    async def store_dep() -> RedisRunStore:
        return get_store()

    async def scheduler_dep() -> RunScheduler:
        return get_scheduler()

    @router.post("", response_model=CreateRunResponse, status_code=202)
    async def create_run(
        request: CreateRunRequest,
        scheduler: Annotated[RunScheduler, Depends(scheduler_dep)],
    ) -> CreateRunResponse:
        try:
            compositor = build_pydantic_ai_compositor(request.compositor)
        except Exception as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        if not has_non_blank_user_prompt(compositor.user_prompts):
            raise HTTPException(status_code=422, detail=EMPTY_USER_PROMPTS_ERROR)

        try:
            record = await scheduler.create_run(request)
        except SchedulerStoppingError as exc:
            raise HTTPException(status_code=503, detail="run scheduler is shutting down") from exc
        return CreateRunResponse(run_id=record.run_id, status=record.status)

    @router.get("/{run_id}", response_model=RunStatusResponse)
    async def get_run_status(run_id: str, store: Annotated[RedisRunStore, Depends(store_dep)]) -> RunStatusResponse:
        try:
            record = await store.get_run(run_id)
        except RunNotFoundError as exc:
            raise HTTPException(status_code=404, detail="run not found") from exc
        return RunStatusResponse(
            run_id=record.run_id,
            status=record.status,
            created_at=record.created_at,
            updated_at=record.updated_at,
            error=record.error,
        )

    @router.get("/{run_id}/events", response_model=RunEventsResponse)
    async def get_run_events(
        run_id: str,
        store: Annotated[RedisRunStore, Depends(store_dep)],
        after: str = Query(default="0-0"),
        limit: int = Query(default=100, ge=1, le=500),
    ) -> RunEventsResponse:
        try:
            return await store.get_events(run_id, after=after, limit=limit)
        except RunNotFoundError as exc:
            raise HTTPException(status_code=404, detail="run not found") from exc

    @router.get("/{run_id}/events/sse")
    async def stream_run_events(
        run_id: str,
        store: Annotated[RedisRunStore, Depends(store_dep)],
        last_event_id: Annotated[str | None, Header(alias="Last-Event-ID")] = None,
        after: str | None = Query(default=None),
    ) -> StreamingResponse:
        cursor = after or last_event_id or "0-0"
        try:
            _ = await store.get_run(run_id)
            events = store.iter_events(run_id, after=cursor)
            return StreamingResponse(sse_event_stream(events), media_type="text/event-stream")
        except RunNotFoundError as exc:
            raise HTTPException(status_code=404, detail="run not found") from exc

    return router


__all__ = ["create_runs_router"]
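A client-side polling sketch against these routes (assumptions: ``httpx`` is installed, the server is listening on localhost:8000, and the ``plain.prompt`` layer type shown in the tests is registered; the body shape follows ``CreateRunRequest``):

import httpx

with httpx.Client(base_url="http://127.0.0.1:8000") as client:
    # POST /runs returns 202 with {"run_id": ..., "status": "running"}.
    created = client.post(
        "/runs",
        json={"compositor": {"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"user": "Say hi."}}]}},
    ).json()
    # GET /runs/{run_id} reports the current status record.
    status = client.get(f"/runs/{created['run_id']}").json()
    print(status["status"])  # "running", "succeeded", or "failed"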
@ -1,228 +0,0 @@
"""Public API schemas for the Dify Agent run server.

The server accepts only registry-backed Agenton compositor configs. This keeps
HTTP input data-only and prevents unsafe import-path construction. Run events are
append-only records; Redis stream ids (or in-memory equivalents in tests) are the
public cursors used by polling and SSE replay. Event envelopes keep the public
``id``/``run_id``/``type``/``data``/``created_at`` shape, but each ``type`` has a
typed ``data`` model so OpenAPI, Redis replay, and runtime producers agree on the
payload contract.
"""

from datetime import datetime, timezone
from typing import Annotated, Literal, TypeAlias
from uuid import uuid4

from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, field_validator
from pydantic_ai.messages import AgentStreamEvent

from agenton.compositor import CompositorConfig, CompositorSessionSnapshot

RunStatus = Literal["running", "succeeded", "failed"]
RunEventType = Literal[
    "run_started",
    "pydantic_ai_event",
    "agent_output",
    "session_snapshot",
    "run_succeeded",
    "run_failed",
]


def new_run_id() -> str:
    """Return a stable external run id."""
    return str(uuid4())


def utc_now() -> datetime:
    """Return the timestamp format used by public schemas."""
    return datetime.now(timezone.utc)


class AgentProfileConfig(BaseModel):
    """Minimal model profile for the MVP runner.

    ``test`` uses pydantic-ai's ``TestModel`` and is credential-free. Other
    profiles can be added behind this schema without changing run/event storage.
    """

    provider: Literal["test"] = "test"
    output_text: str = "Hello from the Dify Agent test model."

    model_config = ConfigDict(extra="forbid")


class CreateRunRequest(BaseModel):
    """Request body for creating one async agent run."""

    compositor: CompositorConfig
    session_snapshot: CompositorSessionSnapshot | None = None
    agent_profile: AgentProfileConfig = Field(default_factory=AgentProfileConfig)

    model_config = ConfigDict(extra="forbid")


class CreateRunResponse(BaseModel):
    """Response returned after a run has been persisted and scheduled locally."""

    run_id: str
    status: RunStatus

    model_config = ConfigDict(extra="forbid")


class RunStatusResponse(BaseModel):
    """Current server-side status for one run."""

    run_id: str
    status: RunStatus
    created_at: datetime
    updated_at: datetime
    error: str | None = None

    model_config = ConfigDict(extra="forbid")


class EmptyRunEventData(BaseModel):
    """Typed empty payload for lifecycle events that carry no extra data."""

    model_config = ConfigDict(extra="forbid")


class AgentOutputRunEventData(BaseModel):
    """Final agent output payload emitted before the session snapshot."""

    output: str

    model_config = ConfigDict(extra="forbid")


class RunFailedEventData(BaseModel):
    """Terminal failure payload shown to polling and SSE consumers."""

    error: str
    reason: str | None = None

    model_config = ConfigDict(extra="forbid")


class BaseRunEvent(BaseModel):
    """Shared append-only event envelope visible through polling and SSE."""

    id: str | None = None
    run_id: str
    created_at: datetime = Field(default_factory=utc_now)

    model_config = ConfigDict(extra="forbid")


class RunStartedEvent(BaseRunEvent):
    """Run lifecycle event emitted before runtime execution starts."""

    type: Literal["run_started"] = "run_started"
    data: EmptyRunEventData = Field(default_factory=EmptyRunEventData)


class PydanticAIStreamRunEvent(BaseRunEvent):
    """Pydantic AI stream event using the upstream typed event model."""

    type: Literal["pydantic_ai_event"] = "pydantic_ai_event"
    data: AgentStreamEvent


class AgentOutputRunEvent(BaseRunEvent):
    """Run event carrying the final agent output string."""

    type: Literal["agent_output"] = "agent_output"
    data: AgentOutputRunEventData


class SessionSnapshotRunEvent(BaseRunEvent):
    """Run event carrying the resumable Agenton session snapshot."""

    type: Literal["session_snapshot"] = "session_snapshot"
    data: CompositorSessionSnapshot


class RunSucceededEvent(BaseRunEvent):
    """Terminal success event emitted after output and session snapshot."""

    type: Literal["run_succeeded"] = "run_succeeded"
    data: EmptyRunEventData = Field(default_factory=EmptyRunEventData)


class RunFailedEvent(BaseRunEvent):
    """Terminal failure event emitted before the run status becomes failed."""

    type: Literal["run_failed"] = "run_failed"
    data: RunFailedEventData


RunEvent: TypeAlias = Annotated[
    RunStartedEvent
    | PydanticAIStreamRunEvent
    | AgentOutputRunEvent
    | SessionSnapshotRunEvent
    | RunSucceededEvent
    | RunFailedEvent,
    Field(discriminator="type"),
]
RUN_EVENT_ADAPTER = TypeAdapter(RunEvent)


class RunEventsResponse(BaseModel):
    """Cursor-paginated event log response."""

    run_id: str
    events: list[RunEvent]
    next_cursor: str | None = None

    model_config = ConfigDict(extra="forbid")


class RunRecord(BaseModel):
    """Internal representation persisted for status reads."""

    run_id: str
    status: RunStatus
    request: CreateRunRequest
    created_at: datetime = Field(default_factory=utc_now)
    updated_at: datetime = Field(default_factory=utc_now)
    error: str | None = None

    model_config = ConfigDict(extra="forbid")

    @field_validator("updated_at")
    @classmethod
    def updated_at_must_be_timezone_aware(cls, value: datetime) -> datetime:
        """Reject naive timestamps before they become JSON API values."""
        if value.tzinfo is None:
            raise ValueError("updated_at must be timezone-aware")
        return value


__all__ = [
    "AgentProfileConfig",
    "AgentOutputRunEvent",
    "AgentOutputRunEventData",
    "BaseRunEvent",
    "CreateRunRequest",
    "CreateRunResponse",
    "EmptyRunEventData",
    "PydanticAIStreamRunEvent",
    "RUN_EVENT_ADAPTER",
    "RunEvent",
    "RunEventsResponse",
    "RunFailedEvent",
    "RunFailedEventData",
    "RunRecord",
    "RunStartedEvent",
    "RunStatus",
    "RunStatusResponse",
    "RunSucceededEvent",
    "SessionSnapshotRunEvent",
    "new_run_id",
    "utc_now",
]
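The discriminated-union adapter gives producers and replay consumers one round-trip path for every event type:

event = RunStartedEvent(run_id="run-1")
payload = RUN_EVENT_ADAPTER.dump_json(event)
restored = RUN_EVENT_ADAPTER.validate_json(payload)
# The "type" discriminator selects the concrete event class on the way back in.
assert isinstance(restored, RunStartedEvent)
assert restored.run_id == "run-1"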
@ -1,26 +0,0 @@
"""Configuration for the FastAPI run server."""

from typing import ClassVar

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

DEFAULT_RUN_RETENTION_SECONDS = 3 * 24 * 60 * 60


class ServerSettings(BaseSettings):
    """Environment-backed settings for Redis persistence, retention, and local scheduling."""

    redis_url: str = "redis://localhost:6379/0"
    redis_prefix: str = "dify-agent"
    shutdown_grace_seconds: float = 30
    run_retention_seconds: int = Field(default=DEFAULT_RUN_RETENTION_SECONDS, ge=1)

    model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
        env_prefix="DIFY_AGENT_",
        env_file=(".env", "dify-agent/.env"),
        extra="ignore",
    )


__all__ = ["DEFAULT_RUN_RETENTION_SECONDS", "ServerSettings"]
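A configuration sketch; values may also come from ``DIFY_AGENT_``-prefixed environment variables (for example ``DIFY_AGENT_REDIS_URL``) or the listed ``.env`` files:

# Explicit keyword arguments override environment-derived values.
settings = ServerSettings(redis_url="redis://localhost:6379/1", run_retention_seconds=3600)
assert settings.redis_prefix == "dify-agent"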
@ -1,29 +0,0 @@
"""Server-sent event formatting for run event replay.

SSE frames use the run event id as ``id`` and the run event type as ``event`` so
browsers can resume with ``Last-Event-ID`` while clients can subscribe by event
name. Payload data is the full public ``RunEvent`` JSON object.
"""

from collections.abc import AsyncIterable, AsyncIterator

from dify_agent.server.schemas import RUN_EVENT_ADAPTER, RunEvent


def format_sse_event(event: RunEvent) -> str:
    """Serialize one event as an SSE frame."""
    lines: list[str] = []
    if event.id is not None:
        lines.append(f"id: {event.id}")
    lines.append(f"event: {event.type}")
    lines.append(f"data: {RUN_EVENT_ADAPTER.dump_json(event).decode()}")
    return "\n".join(lines) + "\n\n"


async def sse_event_stream(events: AsyncIterable[RunEvent]) -> AsyncIterator[str]:
    """Yield formatted SSE frames from public run events."""
    async for event in events:
        yield format_sse_event(event)


__all__ = ["format_sse_event", "sse_event_stream"]
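The frame shape follows directly from ``format_sse_event``:

from dify_agent.server.schemas import RunStartedEvent

frame = format_sse_event(RunStartedEvent(id="1-0", run_id="run-1"))
# "id: 1-0\nevent: run_started\ndata: {...full RunEvent JSON...}\n\n"
assert frame.startswith("id: 1-0\nevent: run_started\ndata: ")
assert frame.endswith("\n\n")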
@ -1,14 +0,0 @@
"""Redis key helpers for run records and per-run event streams."""


def run_record_key(prefix: str, run_id: str) -> str:
    """Return the Redis string key holding one serialized run record."""
    return f"{prefix}:runs:{run_id}:record"


def run_events_key(prefix: str, run_id: str) -> str:
    """Return the Redis stream key holding one run's event log."""
    return f"{prefix}:runs:{run_id}:events"


__all__ = ["run_events_key", "run_record_key"]
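The resulting key layout, directly from the helpers:

assert run_record_key("dify-agent", "abc") == "dify-agent:runs:abc:record"
assert run_events_key("dify-agent", "abc") == "dify-agent:runs:abc:events"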
@ -1,143 +0,0 @@
"""Redis-backed run records and per-run event streams.

The store writes run records as JSON strings and events as Redis streams. HTTP
event cursors are Redis stream ids; ``0-0`` means replay from the beginning for
polling and SSE. Records and streams share one retention window that is refreshed
when status or event data is written. Execution is scheduled in-process by
``dify_agent.runtime.run_scheduler``; Redis is not a job queue.
"""

from collections.abc import AsyncIterator
from typing import cast

from redis.asyncio import Redis

from dify_agent.runtime.event_sink import RunEventSink
from dify_agent.server.schemas import (
    CreateRunRequest,
    RUN_EVENT_ADAPTER,
    RunEvent,
    RunEventsResponse,
    RunRecord,
    RunStatus,
    new_run_id,
    utc_now,
)
from dify_agent.server.settings import DEFAULT_RUN_RETENTION_SECONDS
from dify_agent.storage.redis_keys import run_events_key, run_record_key


class RunNotFoundError(LookupError):
    """Raised when a requested run record does not exist."""


class RedisRunStore(RunEventSink):
    """Async Redis implementation for run records and event logs.

    ``run_retention_seconds`` is applied to both the run record key and the
    per-run Redis stream. Event writes also refresh the record TTL so long-running
    runs that keep producing events do not lose their status record mid-run.
    """

    redis: Redis
    prefix: str
    run_retention_seconds: int

    def __init__(
        self,
        redis: Redis,
        *,
        prefix: str = "dify-agent",
        run_retention_seconds: int = DEFAULT_RUN_RETENTION_SECONDS,
    ) -> None:
        if run_retention_seconds <= 0:
            raise ValueError("run_retention_seconds must be positive")
        self.redis = redis
        self.prefix = prefix
        self.run_retention_seconds = run_retention_seconds

    async def create_run(self, request: CreateRunRequest) -> RunRecord:
        """Persist a running run record without enqueueing external work."""
        run_id = new_run_id()
        record = RunRecord(run_id=run_id, status="running", request=request)
        await self.redis.set(
            run_record_key(self.prefix, run_id),
            record.model_dump_json(),
            ex=self.run_retention_seconds,
        )
        return record

    async def get_run(self, run_id: str) -> RunRecord:
        """Return one run record or raise ``RunNotFoundError``."""
        value = await self.redis.get(run_record_key(self.prefix, run_id))
        if value is None:
            raise RunNotFoundError(run_id)
        if isinstance(value, bytes):
            value = value.decode()
        return RunRecord.model_validate_json(value)

    async def update_status(self, run_id: str, status: RunStatus, error: str | None = None) -> None:
        """Update the status fields of an existing run record."""
        record = await self.get_run(run_id)
        updated = record.model_copy(update={"status": status, "updated_at": utc_now(), "error": error})
        await self.redis.set(
            run_record_key(self.prefix, run_id),
            updated.model_dump_json(),
            ex=self.run_retention_seconds,
        )

    async def append_event(self, event: RunEvent) -> str:
        """Append an event JSON payload to the run's Redis stream."""
        events_key = run_events_key(self.prefix, event.run_id)
        payload = RUN_EVENT_ADAPTER.dump_json(event, exclude={"id"}).decode()
        event_id = await self.redis.xadd(
            events_key,
            {"payload": payload},
        )
        await self.redis.expire(events_key, self.run_retention_seconds)
        await self.redis.expire(run_record_key(self.prefix, event.run_id), self.run_retention_seconds)
        return event_id.decode() if isinstance(event_id, bytes) else str(event_id)

    async def get_events(self, run_id: str, *, after: str = "0-0", limit: int = 100) -> RunEventsResponse:
        """Read a bounded page of events after ``after`` cursor."""
        await self.get_run(run_id)
        raw_events = await self.redis.xrange(run_events_key(self.prefix, run_id), min=f"({after}", count=limit)
        events = [self._decode_event(run_id, raw_id, fields) for raw_id, fields in raw_events]
        next_cursor = events[-1].id if events else after
        return RunEventsResponse(run_id=run_id, events=events, next_cursor=next_cursor)

    async def iter_events(self, run_id: str, *, after: str = "0-0") -> AsyncIterator[RunEvent]:
        """Yield replayed and future events for SSE clients."""
        await self.get_run(run_id)
        cursor = after
        while True:
            page = await self.get_events(run_id, after=cursor, limit=100)
            for event in page.events:
                if event.id is not None:
                    cursor = event.id
                yield event
            if not page.events:
                break
        while True:
            response = await self.redis.xread({run_events_key(self.prefix, run_id): cursor}, block=30_000, count=100)
            if not response:
                continue
            for _stream_name, entries in response:
                for raw_id, fields in entries:
                    event = self._decode_event(run_id, raw_id, fields)
                    if event.id is not None:
                        cursor = event.id
                    yield event

    @staticmethod
    def _decode_event(run_id: str, raw_id: object, fields: dict[object, object]) -> RunEvent:
        """Decode one Redis stream entry into a public event."""
        payload = fields.get(b"payload") or fields.get("payload")
        if isinstance(payload, bytes):
            payload = payload.decode()
        event_id = raw_id.decode() if isinstance(raw_id, bytes) else str(raw_id)
        event = RUN_EVENT_ADAPTER.validate_json(cast(str, payload))
        return event.model_copy(update={"id": event_id, "run_id": run_id})


__all__ = ["DEFAULT_RUN_RETENTION_SECONDS", "RedisRunStore", "RunNotFoundError"]
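A cursor-paging sketch over ``get_events`` (assumptions: a reachable Redis behind the store and an existing ``run_id``; ``0-0`` replays from the start, matching the polling routes):

async def drain(store: RedisRunStore, run_id: str) -> None:
    cursor = "0-0"
    while True:
        page = await store.get_events(run_id, after=cursor, limit=100)
        if not page.events:
            break
        for event in page.events:
            print(event.type)
        # next_cursor is the last stream id in the page; reuse it on the next read.
        cursor = page.next_cursor or cursor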
@ -1,65 +0,0 @@
from __future__ import annotations

from collections.abc import Iterable
from dataclasses import dataclass, field
from pathlib import Path

import pytest
from _pytest.mark import ParameterSet
from pytest_examples import CodeExample, EvalExample, find_examples
from pytest_examples.config import ExamplesConfig as BaseExamplesConfig


@dataclass
class ExamplesConfig(BaseExamplesConfig):
    known_first_party: list[str] = field(default_factory=list[str])

    def ruff_config(self) -> tuple[str, ...]:
        config = super().ruff_config()
        if self.known_first_party:
            config = (*config, "--config", f"lint.isort.known-first-party = {self.known_first_party}")
        return config


def find_doc_examples() -> Iterable[ParameterSet]:
    root_dir = Path(__file__).resolve().parents[2]
    for example in find_examples(
        root_dir / "docs",
        root_dir / "src",
        root_dir / "examples" / "agenton",
        root_dir / "examples" / "dify_agent",
    ):
        path = example.path.relative_to(root_dir)
        yield pytest.param(example, id=f"{path}:{example.start_line}")


@pytest.mark.parametrize("example", find_doc_examples())
def test_documentation_examples(example: CodeExample, eval_example: EvalExample) -> None:
    prefix_settings = example.prefix_settings()
    opt_test = prefix_settings.get("test", "")
    opt_lint = prefix_settings.get("lint", "")
    line_length = int(prefix_settings.get("line_length", "120"))

    eval_example.config = ExamplesConfig(
        ruff_ignore=["D", "Q001"],
        target_version="py312",  # pyright: ignore[reportArgumentType]
        line_length=line_length,
        isort=True,
        upgrade=True,
        quotes="double",
        known_first_party=["agenton", "agenton_collections", "dify_agent"],
    )

    if not opt_lint.startswith("skip"):
        if eval_example.update_examples:  # pragma: no cover
            eval_example.format_ruff(example)
        else:
            eval_example.lint_ruff(example)

    if opt_test.startswith("skip"):
        pytest.skip(opt_test[4:].lstrip(" -") or "running code skipped")

    if eval_example.update_examples:  # pragma: no cover
        eval_example.run_print_update(example, module_globals={"__name__": "__main__"})
    else:
        eval_example.run_print_check(example, module_globals={"__name__": "__main__"})
@ -1,46 +0,0 @@
from __future__ import annotations

import sys
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parents[2]
HOOKS_DIR = PROJECT_ROOT / "docs" / ".hooks"
sys.path.append(str(HOOKS_DIR))

from snippets import inject_snippets, parse_file_sections, parse_snippet_directive  # pyright: ignore[reportMissingImports] # noqa: E402


def test_parse_snippet_directive() -> None:
    directive = parse_snippet_directive('```snippet {path="demo.py" fragment="main" hl="1"}\n```')

    assert directive is not None
    assert directive.path == "demo.py"
    assert directive.fragment == "main"
    assert directive.extra_attrs == {"hl": "1"}


def test_parse_file_sections_and_inject_snippet(tmp_path: Path) -> None:
    source = tmp_path / "demo.py"
    source.write_text(
        """import asyncio

### [main]
async def main() -> None:
    print("hello")
### [/main]

if __name__ == "__main__":
    asyncio.run(main())
""",
        encoding="utf-8",
    )

    parsed = parse_file_sections(source)
    assert "main" in parsed.sections

    markdown = '```snippet {path="/examples/agenton/agenton_examples/session_snapshot.py"}\n```'
    rendered = inject_snippets(markdown, PROJECT_ROOT / "docs")

    assert rendered.startswith('```py {title="examples/agenton/agenton_examples/session_snapshot.py"}')
    assert "async def main() -> None:" in rendered
    assert "asyncio.run(main())" in rendered
@ -1 +0,0 @@
@ -1 +0,0 @@
@ -1,258 +0,0 @@
import asyncio
from collections import OrderedDict
from dataclasses import dataclass

from pydantic import BaseModel, ConfigDict, ValidationError
from typing_extensions import override

from agenton.compositor import Compositor, CompositorBuilder, CompositorSession, LayerRegistry
from agenton.layers import EmptyLayerConfig, LayerControl, LayerDeps, NoLayerDeps, PlainLayer, PlainPromptType, PlainToolType
from agenton_collections.layers.plain import ObjectLayer, PromptLayer


def test_registry_infers_descriptor_and_rejects_duplicate_or_missing_type_id() -> None:
    registry = LayerRegistry()
    registry.register_layer(PromptLayer)

    descriptor = registry.resolve("plain.prompt")
    assert descriptor.layer_type is PromptLayer
    assert descriptor.config_type is PromptLayer.config_type

    try:
        registry.register_layer(PromptLayer)
    except ValueError as e:
        assert str(e) == "Layer type id 'plain.prompt' is already registered."
    else:
        raise AssertionError("Expected ValueError.")

    try:
        registry.register_layer(InstanceOnlyLayer)
    except ValueError as e:
        assert "must declare a type_id" in str(e)
    else:
        raise AssertionError("Expected ValueError.")

    try:
        registry.register_layer(InstanceOnlyLayer, type_id=123)  # pyright: ignore[reportArgumentType]
    except TypeError as e:
        assert str(e) == "Layer type id for 'InstanceOnlyLayer' must be a string."
    else:
        raise AssertionError("Expected TypeError.")


@dataclass(slots=True)
class InstanceOnlyLayer(PlainLayer[NoLayerDeps]):
    pass


def test_builder_creates_config_layers_with_typed_validation() -> None:
    registry = LayerRegistry()
    registry.register_layer(PromptLayer)

    compositor = (
        CompositorBuilder(registry)
        .add_config_layer(
            name="prompt",
            type="plain.prompt",
            config={"prefix": "hello", "user": "ask politely", "suffix": ["bye"]},
        )
        .build()
    )

    assert [prompt.value for prompt in compositor.prompts] == ["hello", "bye"]
    assert [prompt.value for prompt in compositor.user_prompts] == ["ask politely"]

    try:
        CompositorBuilder(registry).add_config_layer(
            name="bad",
            type="plain.prompt",
            config={"unknown": "field"},
        )
    except ValidationError:
        pass
    else:
        raise AssertionError("Expected ValidationError.")


class ObjectConsumerDeps(LayerDeps):
    obj: ObjectLayer[str]  # pyright: ignore[reportUninitializedInstanceVariable]


@dataclass(slots=True)
class ObjectConsumerLayer(PlainLayer[ObjectConsumerDeps]):
    @property
    @override
    def prefix_prompts(self) -> list[str]:
        return [self.deps.obj.value]


def test_builder_mixes_config_and_instances_and_rejects_invalid_deps() -> None:
    registry = LayerRegistry()
    registry.register_layer(PromptLayer)

    compositor = (
        CompositorBuilder(registry)
        .add_config({"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"prefix": "cfg"}}]})
        .add_instance(name="obj", layer=ObjectLayer("instance"))
        .add_instance(name="consumer", layer=ObjectConsumerLayer(), deps={"obj": "obj"})
        .build()
    )

    assert [prompt.value for prompt in compositor.prompts] == ["cfg", "instance"]

    try:
        CompositorBuilder(registry).add_instance(
            name="consumer",
            layer=ObjectConsumerLayer(),
            deps={"missing_dep_key": "obj"},
        ).build()
    except ValueError as e:
        assert str(e) == "Layer 'consumer' declares unknown dependency keys: missing_dep_key."
    else:
        raise AssertionError("Expected ValueError.")

    try:
        CompositorBuilder(registry).add_instance(
            name="consumer",
            layer=ObjectConsumerLayer(),
            deps={"obj": "missing_target"},
        ).build()
    except ValueError as e:
        assert str(e) == "Layer 'consumer' depends on undefined layer names: missing_target."
    else:
        raise AssertionError("Expected ValueError.")


class HandleState(BaseModel):
    resource_id: str = ""

    model_config = ConfigDict(extra="forbid", validate_assignment=True)


class HandleBox:
    def __init__(self, value: str) -> None:
        self.value = value


class HandleModels(BaseModel):
    handle: HandleBox | None = None

    model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)


@dataclass(slots=True)
class HandleLayer(PlainLayer[NoLayerDeps, EmptyLayerConfig, HandleState, HandleModels]):
    created: int = 0
    resumed: int = 0

    @override
    async def on_context_create(self, control: LayerControl[HandleState, HandleModels]) -> None:
        self.created += 1
        control.runtime_handles.handle = HandleBox(control.runtime_state.resource_id)

    @override
    async def on_context_resume(self, control: LayerControl[HandleState, HandleModels]) -> None:
        self.resumed += 1
        control.runtime_handles.handle = HandleBox(f"resumed:{control.runtime_state.resource_id}")


def test_new_session_uses_layer_runtime_schemas() -> None:
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(
        layers=OrderedDict([("handle", HandleLayer())])
    )
    session = compositor.new_session()

    assert isinstance(session.layer("handle").runtime_state, HandleState)
    assert isinstance(session.layer("handle").runtime_handles, HandleModels)


def test_enter_rejects_bad_session_runtime_schemas_before_layer_hooks() -> None:
    layer = HandleLayer()
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=OrderedDict([("handle", layer)]))
    bad_session = CompositorSession(OrderedDict([("handle", LayerControl())]))

    async def run() -> None:
        async with compositor.enter(bad_session):
            pass

    try:
        asyncio.run(run())
    except TypeError as e:
        assert str(e) == (
            "CompositorSession layer 'handle' runtime_state must be HandleState, "
            "got EmptyRuntimeState."
        )
    else:
        raise AssertionError("Expected TypeError.")

    assert layer.created == 0


def test_snapshot_rejects_active_sessions_and_excludes_handles() -> None:
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(
        layers=OrderedDict([("handle", HandleLayer())])
    )
    session = compositor.session_from_snapshot(
        {"layers": [{"name": "handle", "state": "new", "runtime_state": {"resource_id": "abc"}}]}
    )

    async def run() -> None:
        async with compositor.enter(session):
            try:
                compositor.snapshot_session(session)
            except RuntimeError as e:
                assert str(e) == "Cannot snapshot active compositor session layers: handle."
            else:
                raise AssertionError("Expected RuntimeError.")

    asyncio.run(run())

    snapshot = compositor.snapshot_session(session)
    assert snapshot.model_dump(mode="json") == {
        "schema_version": 1,
        "layers": [{"name": "handle", "state": "closed", "runtime_state": {"resource_id": "abc"}}],
    }


def test_restore_validates_runtime_state_and_resume_rehydrates_handles() -> None:
    layer = HandleLayer()
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=OrderedDict([("handle", layer)]))

    try:
        compositor.session_from_snapshot(
            {"layers": [{"name": "handle", "state": "suspended", "runtime_state": {"wrong": "field"}}]}
        )
    except ValidationError:
        pass
    else:
        raise AssertionError("Expected ValidationError.")

    restored = compositor.session_from_snapshot(
        {"layers": [{"name": "handle", "state": "suspended", "runtime_state": {"resource_id": "abc"}}]}
    )

    async def run() -> None:
        async with compositor.enter(restored):
            control = restored.layer("handle")
            assert isinstance(control.runtime_handles, HandleModels)
            assert control.runtime_handles.handle is not None
            assert control.runtime_handles.handle.value == "resumed:abc"

    asyncio.run(run())

    assert layer.resumed == 1


def test_session_from_snapshot_rejects_active_layer_state() -> None:
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(
        layers=OrderedDict([("handle", HandleLayer())])
    )

    try:
        compositor.session_from_snapshot(
            {"layers": [{"name": "handle", "state": "active", "runtime_state": {"resource_id": "abc"}}]}
        )
    except ValueError as e:
        assert str(e) == "Cannot restore active compositor session layers from snapshot: handle."
    else:
        raise AssertionError("Expected ValueError.")
@ -1,298 +0,0 @@
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import count
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from agenton.compositor import Compositor, CompositorSession
|
||||
from agenton.layers import (
|
||||
ExitIntent,
|
||||
EmptyLayerConfig,
|
||||
EmptyRuntimeHandles,
|
||||
LayerControl,
|
||||
LifecycleState,
|
||||
NoLayerDeps,
|
||||
PlainLayer,
|
||||
PlainPromptType,
|
||||
PlainToolType,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class TraceLayer(PlainLayer[NoLayerDeps]):
|
||||
"""Layer that records lifecycle events observable to tests."""
|
||||
|
||||
events: list[str] = field(default_factory=list)
|
||||
|
||||
@override
|
||||
async def on_context_create(self, control: LayerControl) -> None:
|
||||
self.events.append("create")
|
||||
|
||||
@override
|
||||
async def on_context_suspend(self, control: LayerControl) -> None:
|
||||
self.events.append("suspend")
|
||||
|
||||
@override
|
||||
async def on_context_resume(self, control: LayerControl) -> None:
|
||||
self.events.append("resume")
|
||||
|
||||
@override
|
||||
async def on_context_delete(self, control: LayerControl) -> None:
|
||||
self.events.append("delete")
|
||||
|
||||
|
||||
def _compositor(*layer_names: str) -> tuple[Compositor[PlainPromptType, PlainToolType], dict[str, TraceLayer]]:
|
||||
layers = {layer_name: TraceLayer() for layer_name in layer_names}
|
||||
compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=OrderedDict(layers.items()))
|
||||
return compositor, layers
|
||||
|
||||
|
||||
def test_compositor_session_suspends_resumes_and_deletes_all_layers() -> None:
|
||||
compositor, layers = _compositor("first", "second")
|
||||
session = compositor.new_session()
|
||||
|
||||
async def run() -> None:
|
||||
async with compositor.enter(session) as active_session:
|
||||
assert active_session is session
|
||||
assert list(active_session.layer_controls) == ["first", "second"]
|
||||
active_session.suspend_on_exit()
|
||||
assert active_session.layer("first").exit_intent is ExitIntent.SUSPEND
|
||||
|
||||
assert session.layer("first").state is LifecycleState.SUSPENDED
|
||||
|
||||
async with compositor.enter(session):
|
||||
pass
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
assert layers["first"].events == ["create", "suspend", "resume", "delete"]
|
||||
assert layers["second"].events == ["create", "suspend", "resume", "delete"]
|
||||
assert session.layer("first").state is LifecycleState.CLOSED
|
||||
|
||||
|
||||
def test_compositor_enter_without_session_uses_fresh_lifecycle_each_time() -> None:
|
||||
compositor, layers = _compositor("trace")
|
||||
|
||||
async def run() -> None:
|
||||
async with compositor.enter() as session:
|
||||
session.suspend_on_exit()
|
||||
|
||||
async with compositor.enter():
|
||||
pass
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
assert layers["trace"].events == ["create", "suspend", "create", "delete"]
|
||||
|
||||
|
||||
def test_compositor_enter_rejects_session_with_mismatched_layer_names() -> None:
|
||||
compositor, _layers = _compositor("trace")
|
||||
session = CompositorSession(["other"])
|
||||
|
||||
async def run() -> None:
|
||||
async with compositor.enter(session):
|
||||
pass
|
||||
|
||||
try:
|
||||
asyncio.run(run())
|
||||
except ValueError as e:
|
||||
assert str(e) == (
|
||||
"CompositorSession layer names must match compositor layers in order. "
|
||||
"Expected [trace], got [other]."
|
||||
)
|
||||
else:
|
||||
raise AssertionError("Expected ValueError.")
|
||||
|
||||
|
||||
def test_compositor_enter_rejects_same_active_session_nested() -> None:
|
||||
compositor, _layers = _compositor("trace")
|
||||
session = compositor.new_session()
|
||||
|
||||
async def run() -> None:
|
||||
async with compositor.enter(session):
|
||||
async with compositor.enter(session):
|
||||
pass
|
||||
|
||||
try:
|
||||
asyncio.run(run())
|
||||
except RuntimeError as e:
|
||||
assert str(e) == "LayerControl is already active; duplicate or nested enter is not allowed."
|
||||
else:
|
||||
raise AssertionError("Expected RuntimeError.")
|
||||
|
||||
|
||||
def test_compositor_enter_rejects_closed_session() -> None:
|
||||
compositor, _layers = _compositor("trace")
|
||||
session = compositor.new_session()
|
||||
|
||||
async def run() -> None:
|
||||
async with compositor.enter(session):
|
||||
pass
|
||||
|
||||
async with compositor.enter(session):
|
||||
pass
|
||||
|
||||
try:
|
||||
asyncio.run(run())
|
||||
except RuntimeError as e:
|
||||
assert str(e) == "LayerControl is closed; create a new compositor session before entering again."
|
||||
else:
|
||||
raise AssertionError("Expected RuntimeError.")
|
||||
|
||||
|
||||
def test_per_layer_suspend_on_exit_only_resumes_that_layer() -> None:
|
||||
compositor, layers = _compositor("first", "second")
|
||||
session = compositor.new_session()
|
||||
|
||||
async def run() -> None:
|
||||
async with compositor.enter(session):
|
||||
session.layer("first").suspend_on_exit()
|
||||
|
||||
assert session.layer("first").state is LifecycleState.SUSPENDED
|
||||
assert session.layer("second").state is LifecycleState.CLOSED
|
||||
|
||||
async with compositor.enter(session):
|
||||
pass
|
||||
|
||||
try:
|
||||
asyncio.run(run())
|
||||
except RuntimeError as e:
|
||||
assert str(e) == "LayerControl is closed; create a new compositor session before entering again."
|
||||
else:
|
||||
raise AssertionError("Expected RuntimeError.")
|
||||
|
||||
assert layers["first"].events == ["create", "suspend"]
|
||||
assert layers["second"].events == ["create", "delete"]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class FailingCreateLayer(PlainLayer[NoLayerDeps]):
|
||||
attempts: int = 0
|
||||
|
||||
@override
|
||||
async def on_context_create(self, control: LayerControl) -> None:
|
||||
self.attempts += 1
|
||||
if self.attempts == 1:
|
||||
raise RuntimeError("create failed")
|
||||
|
||||
|
||||
def test_failed_create_keeps_control_reusable_as_new() -> None:
|
||||
layer = FailingCreateLayer()
|
||||
compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=OrderedDict([("trace", layer)]))
|
||||
session = compositor.new_session()
|
||||
|
||||
async def fail_then_retry() -> None:
|
||||
try:
|
||||
async with compositor.enter(session):
|
||||
pass
|
||||
except RuntimeError as e:
|
||||
assert str(e) == "create failed"
|
||||
else:
|
||||
raise AssertionError("Expected RuntimeError.")
|
||||
|
||||
assert session.layer("trace").state is LifecycleState.NEW
|
||||
|
||||
async with compositor.enter(session):
|
||||
pass
|
||||
|
||||
asyncio.run(fail_then_retry())
|
||||
|
||||
assert session.layer("trace").state is LifecycleState.CLOSED
|
||||
assert layer.attempts == 2
|
||||
|
||||
|
||||
@dataclass(slots=True)
class FailingResumeLayer(PlainLayer[NoLayerDeps]):
    resumed: bool = False

    @override
    async def on_context_resume(self, control: LayerControl) -> None:
        if not self.resumed:
            self.resumed = True
            raise RuntimeError("resume failed")


def test_failed_resume_keeps_control_reusable_as_suspended() -> None:
    layer = FailingResumeLayer()
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=OrderedDict([("trace", layer)]))
    session = compositor.new_session()

    async def suspend_fail_then_retry() -> None:
        async with compositor.enter(session) as active_session:
            active_session.suspend_on_exit()

        try:
            async with compositor.enter(session):
                pass
        except RuntimeError as e:
            assert str(e) == "resume failed"
        else:
            raise AssertionError("Expected RuntimeError.")

        assert session.layer("trace").state is LifecycleState.SUSPENDED

        async with compositor.enter(session):
            pass

    asyncio.run(suspend_fail_then_retry())

    assert session.layer("trace").state is LifecycleState.CLOSED


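# Per-session runtime state recorded by the layer hooks; each session gets its
# own RuntimeState instance, independent of the shared layer object.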
class RuntimeState(BaseModel):
    runtime_id: int | None = None
    resumed_runtime_id: int | None = None
    deleted_runtime_id: int | None = None

    model_config = ConfigDict(extra="forbid", validate_assignment=True)


@dataclass(slots=True)
class RuntimeStateLayer(PlainLayer[NoLayerDeps, EmptyLayerConfig, RuntimeState]):
    next_id: Iterator[int] = field(default_factory=lambda: count(1))

    @override
    async def on_context_create(self, control: LayerControl[RuntimeState, EmptyRuntimeHandles]) -> None:
        runtime_id = next(self.next_id)
        control.runtime_state.runtime_id = runtime_id

    @override
    async def on_context_resume(self, control: LayerControl[RuntimeState, EmptyRuntimeHandles]) -> None:
        control.runtime_state.resumed_runtime_id = control.runtime_state.runtime_id

    @override
    async def on_context_delete(self, control: LayerControl[RuntimeState, EmptyRuntimeHandles]) -> None:
        control.runtime_state.deleted_runtime_id = control.runtime_state.runtime_id


def test_runtime_state_is_per_session_and_survives_suspend_resume_delete() -> None:
    layer = RuntimeStateLayer()
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=OrderedDict([("trace", layer)]))
    first_session = compositor.new_session()
    second_session = compositor.new_session()

    async def run() -> None:
        async with compositor.enter(first_session) as active_session:
            active_session.suspend_on_exit()

        async with compositor.enter(second_session):
            pass

        async with compositor.enter(first_session):
            pass

    asyncio.run(run())

    assert first_session.layer("trace").runtime_state.model_dump(exclude_none=True) == {
        "runtime_id": 1,
        "resumed_runtime_id": 1,
        "deleted_runtime_id": 1,
    }
    assert second_session.layer("trace").runtime_state.model_dump(exclude_none=True) == {
        "runtime_id": 2,
        "deleted_runtime_id": 2,
    }
    assert not hasattr(layer, "runtime_id")
@ -1,163 +0,0 @@
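# Tests for the Compositor transformer hooks: prompt, user-prompt, and tool
# transformers applied after layer ordering and aggregation.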
from collections import OrderedDict
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from inspect import Parameter, signature

from typing_extensions import override

from agenton.compositor import Compositor, CompositorTransformerKwargs
from agenton.layers import NoLayerDeps, PlainLayer, PlainPromptType, PlainToolType, PlainUserPromptType

type ToolCallable = Callable[..., object]
type WrappedPrompt = tuple[str, str]
type WrappedUserPrompt = tuple[str, str]


@dataclass(slots=True)
class PromptAndToolLayer(PlainLayer[NoLayerDeps]):
    prefix: list[str]
    user: list[str]
    suffix: list[str]
    tool_entries: list[ToolCallable]

    @property
    @override
    def prefix_prompts(self) -> list[str]:
        return self.prefix

    @property
    @override
    def suffix_prompts(self) -> list[str]:
        return self.suffix

    @property
    @override
    def user_prompts(self) -> list[str]:
        return self.user

    @property
    @override
    def tools(self) -> list[ToolCallable]:
        return self.tool_entries


def base_tool() -> str:
    return "base"


def wrapped_tool() -> str:
    return "wrapped"


def wrap_prompts(prompts: Sequence[PlainPromptType]) -> list[WrappedPrompt]:
    return [("wrapped", prompt.value) for prompt in prompts]


def wrap_user_prompts(prompts: Sequence[PlainUserPromptType]) -> list[WrappedUserPrompt]:
    return [("wrapped-user", prompt.value) for prompt in prompts]


def describe_tools(tools: Sequence[PlainToolType]) -> list[str]:
    return [tool.value.__name__ for tool in tools]


def test_compositor_transformer_kwargs_keys_match_constructor_parameters() -> None:
    transformer_kwargs = set(CompositorTransformerKwargs.__required_keys__)
    parameters = signature(Compositor).parameters

    assert CompositorTransformerKwargs.__optional_keys__ == frozenset()
    assert transformer_kwargs == {
        name for name in parameters if name.endswith("_transformer")
    }
    assert all(parameters[name].kind is Parameter.KEYWORD_ONLY for name in transformer_kwargs)


def test_compositor_transformer_kwargs_keys_match_from_config_parameters() -> None:
    transformer_kwargs = set(CompositorTransformerKwargs.__required_keys__)
    parameters = signature(Compositor.from_config).parameters

    assert transformer_kwargs == {
        name for name in parameters if name.endswith("_transformer")
    }
    assert all(parameters[name].kind is Parameter.KEYWORD_ONLY for name in transformer_kwargs)


def test_compositor_transforms_prompts_to_another_type_after_layer_ordering() -> None:
    compositor: Compositor[WrappedPrompt, PlainToolType, PlainPromptType, PlainToolType] = Compositor(
        layers=OrderedDict(
            [
                (
                    "first",
                    PromptAndToolLayer(
                        prefix=["first-prefix"],
                        user=[],
                        suffix=["first-suffix"],
                        tool_entries=[],
                    ),
                ),
                (
                    "second",
                    PromptAndToolLayer(
                        prefix=["second-prefix"],
                        user=[],
                        suffix=["second-suffix"],
                        tool_entries=[],
                    ),
                ),
            ]
        ),
        prompt_transformer=wrap_prompts,
    )

    assert compositor.prompts == [
        ("wrapped", "first-prefix"),
        ("wrapped", "second-prefix"),
        ("wrapped", "second-suffix"),
        ("wrapped", "first-suffix"),
    ]


def test_compositor_transforms_tools_to_another_type_after_layer_aggregation() -> None:
    compositor: Compositor[PlainPromptType, str, PlainPromptType, PlainToolType] = Compositor(
        layers=OrderedDict(
            [
                (
                    "tools",
                    PromptAndToolLayer(prefix=[], user=[], suffix=[], tool_entries=[base_tool, wrapped_tool]),
                )
            ]
        ),
        tool_transformer=describe_tools,
    )

    assert compositor.tools == ["base_tool", "wrapped_tool"]


def test_compositor_transforms_user_prompts_after_layer_ordering() -> None:
    compositor: Compositor[
        PlainPromptType,
        PlainToolType,
        PlainPromptType,
        PlainToolType,
        WrappedUserPrompt,
        PlainUserPromptType,
    ] = Compositor(
        layers=OrderedDict(
            [
                (
                    "first",
                    PromptAndToolLayer(prefix=[], user=["first-user"], suffix=[], tool_entries=[]),
                ),
                (
                    "second",
                    PromptAndToolLayer(prefix=[], user=["second-user"], suffix=[], tool_entries=[]),
                ),
            ]
        ),
        user_prompt_transformer=wrap_user_prompts,
    )

    assert compositor.user_prompts == [
        ("wrapped-user", "first-user"),
        ("wrapped-user", "second-user"),
    ]
@ -1 +0,0 @@
@ -1,15 +0,0 @@
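# Tests that LayerDeps validates the runtime class of injected layer dependencies.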
import pytest

from agenton.layers import LayerDeps
from agenton_collections.layers.plain import ObjectLayer, PromptLayer


class ObjectLayerDeps(LayerDeps):
    """Deps container used to exercise runtime dependency validation."""

    object_layer: ObjectLayer[str]  # pyright: ignore[reportUninitializedInstanceVariable]


def test_layer_deps_rejects_mismatched_runtime_layer_class() -> None:
    with pytest.raises(TypeError, match="should be of type 'ObjectLayer'"):
        ObjectLayerDeps(object_layer=PromptLayer())
@ -1,94 +0,0 @@
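# Tests for schema inference: config, runtime-state, and runtime-handles types
# derived from PlainLayer generic parameters, with empty defaults when omitted.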
from dataclasses import dataclass

from pydantic import BaseModel, ConfigDict

from agenton.compositor import LayerRegistry
from agenton.layers import EmptyLayerConfig, EmptyRuntimeHandles, EmptyRuntimeState, LayerControl, NoLayerDeps, PlainLayer


class InferredConfig(BaseModel):
    value: str = "configured"

    model_config = ConfigDict(extra="forbid")


class InferredState(BaseModel):
    count: int = 0

    model_config = ConfigDict(extra="forbid", validate_assignment=True)


class InferredHandles(BaseModel):
    token: object | None = None

    model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)


@dataclass(slots=True)
class GenericSchemaLayer(PlainLayer[NoLayerDeps, InferredConfig, InferredState, InferredHandles]):
    type_id = "test.generic-schema"

    async def on_context_create(self, control: LayerControl[InferredState, InferredHandles]) -> None:
        control.runtime_state.count += 1
        control.runtime_handles.token = object()


@dataclass(slots=True)
class DefaultSchemaLayer(PlainLayer[NoLayerDeps]):
    type_id = "test.default-schema"


def test_layer_infers_config_runtime_state_and_handles_from_generics() -> None:
    layer = GenericSchemaLayer()
    control = layer.new_control(runtime_state={"count": 3})

    assert GenericSchemaLayer.config_type is InferredConfig
    assert GenericSchemaLayer.runtime_state_type is InferredState
    assert GenericSchemaLayer.runtime_handles_type is InferredHandles
    assert isinstance(control.runtime_state, InferredState)
    assert control.runtime_state.count == 3
    assert isinstance(control.runtime_handles, InferredHandles)


def test_layer_uses_empty_schema_defaults_when_omitted() -> None:
    layer = DefaultSchemaLayer()
    control = layer.new_control()

    assert DefaultSchemaLayer.config_type is EmptyLayerConfig
    assert DefaultSchemaLayer.runtime_state_type is EmptyRuntimeState
    assert DefaultSchemaLayer.runtime_handles_type is EmptyRuntimeHandles
    assert isinstance(control.runtime_state, EmptyRuntimeState)
    assert isinstance(control.runtime_handles, EmptyRuntimeHandles)


def test_invalid_declared_schema_type_is_rejected_clearly() -> None:
    try:

        class InvalidSchemaLayer(PlainLayer[NoLayerDeps]):
            config_type = dict  # pyright: ignore[reportAssignmentType]

    except TypeError as e:
        assert str(e) == "InvalidSchemaLayer.config_type must be a Pydantic BaseModel subclass."
    else:
        raise AssertionError("Expected TypeError.")

    try:

        class InvalidGenericSchemaLayer(PlainLayer[NoLayerDeps, dict[str, object]]):  # pyright: ignore[reportInvalidTypeArguments]
            pass

    except TypeError as e:
        assert str(e) == "InvalidGenericSchemaLayer.config_type must be a Pydantic BaseModel subclass."
    else:
        raise AssertionError("Expected TypeError.")


def test_registry_descriptor_uses_inferred_schema_types() -> None:
    registry = LayerRegistry()
    registry.register_layer(GenericSchemaLayer)

    descriptor = registry.resolve("test.generic-schema")

    assert descriptor.config_type is InferredConfig
    assert descriptor.runtime_state_type is InferredState
    assert descriptor.runtime_handles_type is InferredHandles
@ -1 +0,0 @@
@ -1,75 +0,0 @@
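# Tests for object-bound tools: with_object validation and DynamicToolsLayer
# binding objects supplied through ObjectLayer dependencies.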
import asyncio

import pytest
from pydantic_ai import Tool

from agenton_collections.layers.plain import DynamicToolsLayer, ObjectLayer, with_object


class Profile:
    """Profile object used by object-bound tool tests."""

    name: str

    def __init__(self, name: str) -> None:
        self.name = name


class OtherProfile:
    """Different runtime object used to trigger object mismatch checks."""


@with_object(Profile)
def greet(profile: Profile, topic: str) -> str:
    return f"{profile.name}: {topic}"


def test_with_object_rejects_tool_without_object_parameter() -> None:
    def tool() -> str:
        return "unused"

    with pytest.raises(ValueError, match="must accept the object dependency"):
        with_object(Profile)(tool)  # pyright: ignore[reportArgumentType]


def test_with_object_rejects_first_parameter_annotation_mismatch() -> None:
    def tool(profile: OtherProfile) -> str:
        return repr(profile)

    with pytest.raises(TypeError, match="first parameter should accept 'Profile'"):
        with_object(Profile)(tool)  # pyright: ignore[reportArgumentType]


def test_dynamic_tools_layer_rejects_mismatched_runtime_object_value() -> None:
    layer = DynamicToolsLayer[Profile](tool_entries=(greet,))
    layer.bind_deps({"object_layer": ObjectLayer[OtherProfile](OtherProfile())})

    with pytest.raises(TypeError, match="expected object dependency of type 'Profile'"):
        layer.tools


def public_greet(topic: str) -> str:
    return f"Ada: {topic}"


def test_dynamic_tools_layer_binds_object_as_pydantic_ai_equivalent_tool() -> None:
    layer = DynamicToolsLayer[Profile](tool_entries=(greet,))
    layer.bind_deps({"object_layer": ObjectLayer[Profile](Profile("Ada"))})

    expected_tool = Tool(public_greet, name="greet")
    dynamic_tool = Tool(layer.tools[0], name="greet")
    dynamic_result = asyncio.run(
        dynamic_tool.function_schema.call(
            {"topic": "layer composition"},
            None,  # pyright: ignore[reportArgumentType]
        )
    )
    expected_result = asyncio.run(
        expected_tool.function_schema.call(
            {"topic": "layer composition"},
            None,  # pyright: ignore[reportArgumentType]
        )
    )

    assert dynamic_tool.tool_def == expected_tool.tool_def
    assert dynamic_result == expected_result
@ -1,56 +0,0 @@
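# Tests that PydanticAIBridgeLayer accepts mixed string/function prompts and
# mixed Tool/plain-function tool entries.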
from collections.abc import Callable
from dataclasses import dataclass
from typing import cast

from pydantic_ai import RunContext, Tool

from agenton_collections.layers.pydantic_ai import PydanticAIBridgeLayer


@dataclass(frozen=True, slots=True)
class Profile:
    name: str


def profile_prompt(ctx: RunContext[Profile]) -> str:
    return f"Profile: {ctx.deps.name}"


def existing_tool(ctx: RunContext[Profile]) -> str:
    return ctx.deps.name


def raw_tool(ctx: RunContext[Profile], topic: str) -> str:
    return f"{ctx.deps.name}: {topic}"


def test_pydantic_ai_bridge_layer_accepts_mixed_string_and_function_prompts() -> None:
    layer = PydanticAIBridgeLayer[Profile](
        prefix=("plain prefix", profile_prompt),
        user=("first user", "second user"),
        suffix="plain suffix",
    )

    prefix_prompts = layer.prefix_prompts
    user_prompts = layer.user_prompts
    suffix_prompts = layer.suffix_prompts

    plain_prefix = cast(Callable[[], str], prefix_prompts[0])
    plain_suffix = cast(Callable[[], str], suffix_prompts[0])
    assert plain_prefix() == "plain prefix"
    assert prefix_prompts[1] is profile_prompt
    assert user_prompts == ["first user", "second user"]
    assert plain_suffix() == "plain suffix"


def test_pydantic_ai_bridge_layer_accepts_mixed_tool_and_tool_function_entries() -> None:
    pydantic_ai_tool = Tool(existing_tool)
    layer = PydanticAIBridgeLayer[Profile](
        tool_entries=(pydantic_ai_tool, raw_tool),
    )

    tools = layer.tools

    assert tools[0] is pydantic_ai_tool
    assert isinstance(tools[1], Tool)
    assert tools[1].function is raw_tool
@ -1,85 +0,0 @@
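# Tests that PYDANTIC_AI_TRANSFORMERS wrap tagged plain prompts and tools while
# passing through native pydantic-ai prompt functions and Tool instances.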
from collections.abc import Callable
from typing import cast

from pydantic_ai import Tool

from agenton.layers.types import (
    PlainPromptType,
    PlainToolType,
    PlainUserPromptType,
    PydanticAIPromptType,
    PydanticAIToolType,
    PydanticAIUserPromptType,
)
from agenton_collections.transformers.pydantic_ai import PYDANTIC_AI_TRANSFORMERS


def plain_tool(name: str) -> str:
    return f"hello {name}"


def dynamic_prompt() -> str:
    return "dynamic prompt"


def test_pydantic_ai_transformers_wrap_tagged_plain_prompts() -> None:
    prompts = [PlainPromptType("plain prompt")]

    result = PYDANTIC_AI_TRANSFORMERS["prompt_transformer"](prompts)

    assert len(result) == 1
    prompt_func = cast(Callable[[], str], result[0])
    assert prompt_func() == "plain prompt"


def test_pydantic_ai_transformers_preserve_tagged_existing_prompt_functions() -> None:
    result = PYDANTIC_AI_TRANSFORMERS["prompt_transformer"]([PydanticAIPromptType(dynamic_prompt)])

    assert result == [dynamic_prompt]


def test_pydantic_ai_transformers_accept_mixed_tagged_prompt_types() -> None:
    result = PYDANTIC_AI_TRANSFORMERS["prompt_transformer"](
        [PlainPromptType("plain prompt"), PydanticAIPromptType(dynamic_prompt)]
    )

    plain_prompt = cast(Callable[[], str], result[0])
    assert plain_prompt() == "plain prompt"
    assert result[1] is dynamic_prompt


def test_pydantic_ai_transformers_accept_tagged_user_prompt_types() -> None:
    result = PYDANTIC_AI_TRANSFORMERS["user_prompt_transformer"](
        [PlainUserPromptType("plain user"), PydanticAIUserPromptType("pydantic user")]
    )

    assert result == ["plain user", "pydantic user"]


def test_pydantic_ai_transformers_wrap_tagged_plain_tools() -> None:
    result = PYDANTIC_AI_TRANSFORMERS["tool_transformer"]([PlainToolType(plain_tool)])

    assert len(result) == 1
    tool = result[0]
    assert isinstance(tool, Tool)
    assert tool.function is plain_tool


def test_pydantic_ai_transformers_preserve_tagged_existing_tools() -> None:
    pydantic_ai_tool = Tool(plain_tool)

    result = PYDANTIC_AI_TRANSFORMERS["tool_transformer"]([PydanticAIToolType(pydantic_ai_tool)])

    assert result == [pydantic_ai_tool]


def test_pydantic_ai_transformers_accept_tagged_tool_types() -> None:
    pydantic_ai_tool = Tool(plain_tool)

    result = PYDANTIC_AI_TRANSFORMERS["tool_transformer"](
        [PlainToolType(plain_tool), PydanticAIToolType(pydantic_ai_tool)]
    )

    assert isinstance(result[0], Tool)
    assert result[0].function is plain_tool
    assert result[1] is pydantic_ai_tool