mirror of
https://github.com/langgenius/dify.git
synced 2026-05-24 02:47:53 +08:00
Compare commits
107 Commits
codex/remo
...
feat/ui-on
| Author | SHA1 | Date | |
|---|---|---|---|
| 0d12b5ab1c | |||
| 514dcae189 | |||
| 228dd84a91 | |||
| 336ddad096 | |||
| 92bb9a17b7 | |||
| b8868dab90 | |||
| 94225682cd | |||
| 18b6568c2a | |||
| a3a9ded29b | |||
| de78a26920 | |||
| c54d029e7c | |||
| ad4b9dc2c3 | |||
| cdec0c69a6 | |||
| 53acc3726c | |||
| b1d393f4d9 | |||
| 62e9bdd70d | |||
| d36c76c20e | |||
| f525e1a5eb | |||
| e2f779b20d | |||
| e198d6305c | |||
| 5e67514265 | |||
| b63896de87 | |||
| e463389f2c | |||
| cda348ca10 | |||
| ca48050666 | |||
| 9c0f592f34 | |||
| b70241ad36 | |||
| 4abe622b2e | |||
| 16c32c82e3 | |||
| 46424513d1 | |||
| 2c4baa20d8 | |||
| b0ae553f2e | |||
| 0266a12ee5 | |||
| 9d7765d5fd | |||
| d4ef983f42 | |||
| 018f36711d | |||
| dacd333e4a | |||
| b079a26314 | |||
| 7e953ebe0b | |||
| b4d28fca54 | |||
| 728c6b8201 | |||
| f56e23b5fd | |||
| 5600cefa53 | |||
| 561eb9cbd2 | |||
| 83766ca694 | |||
| 678be94d22 | |||
| 9e852429be | |||
| d93c5028f1 | |||
| 54f189305e | |||
| a610a24507 | |||
| 05e8a94bb5 | |||
| b2e2e7b60b | |||
| e7d2e66ff5 | |||
| c51069685c | |||
| 28c208f36a | |||
| 53a1386b87 | |||
| 0e366c7300 | |||
| 939bdde373 | |||
| 13dfa3aba4 | |||
| 2705a7c1db | |||
| 258a751b8c | |||
| 5a35d3d9cd | |||
| c3fbafae83 | |||
| f727c8f838 | |||
| 90af4c39b4 | |||
| f7c3a4e4cb | |||
| be7d043edd | |||
| cef8fe3a4b | |||
| afe0e6c393 | |||
| 37309b931e | |||
| 6a83c6705c | |||
| 3e75d5e443 | |||
| 7be8a5b883 | |||
| 80dcb344f4 | |||
| b029c9b1cd | |||
| 6cb97e9201 | |||
| 4ef2e952bd | |||
| cc5545339c | |||
| 0a8c46a3a7 | |||
| 65770903d1 | |||
| 5a6ba2ffb5 | |||
| aa53afe07d | |||
| 4740a89f4a | |||
| 328db3d67a | |||
| 88062fb247 | |||
| 045da59220 | |||
| 948b0f6bc7 | |||
| 14a59f6e44 | |||
| f9f361113e | |||
| eea6f59307 | |||
| 718f69dc43 | |||
| 82a2ba9264 | |||
| 6c8e032fbb | |||
| 28c2c3bfd3 | |||
| 9d463e1024 | |||
| 7f87616625 | |||
| 43a04ed0c2 | |||
| 5083edd0ce | |||
| 8306fa41b9 | |||
| 8f33305e90 | |||
| 7077a43c1c | |||
| 884a43ae0a | |||
| 914f89f478 | |||
| 163153db18 | |||
| 49d890d514 | |||
| 0292bc2728 | |||
| 5c21120977 |
@ -1,6 +1,6 @@
|
||||
---
|
||||
name: frontend-code-review
|
||||
description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules."
|
||||
description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support pending-change and focused file reviews while applying checklist rules, shared component reuse checks, and React component structure guidance from how-to-write-component."
|
||||
---
|
||||
|
||||
# Frontend Code Review
|
||||
@ -16,10 +16,12 @@ Stick to the checklist below for every applicable file and mode.
|
||||
## Checklist
|
||||
See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), [references/business-logic.md](references/business-logic.md) for the living checklist split by category—treat it as the canonical set of rules to follow.
|
||||
|
||||
When reviewing React/TypeScript components, also apply the repo-local `how-to-write-component` skill as the component architecture checklist. In particular, check ownership boundaries, props and API types, query/mutation usage, navigation choices, effect usage, unnecessary wrappers, and unnecessary memoization.
|
||||
|
||||
Flag each rule violation with urgency metadata so future reviewers can prioritize fixes.
|
||||
|
||||
## Review Process
|
||||
1. Open the relevant component/module. Gather lines that relate to class names, React Flow hooks, prop memoization, and styling.
|
||||
1. Open the relevant component/module. Gather lines that relate to shared base/dify-ui component reuse, class names, styling/CSS imports, file size and component boundaries, i18n keys, behavior-sensitive UI interactions, React Flow hooks, and prop memoization.
|
||||
2. For each rule in the review point, note where the code deviates and capture a representative snippet.
|
||||
3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic).
|
||||
|
||||
@ -70,4 +72,3 @@ If you use Template A (i.e., there are issues to fix) and at least one issue req
|
||||
## Code review
|
||||
No issues found.
|
||||
```
|
||||
|
||||
|
||||
@ -13,3 +13,29 @@ Node components are also used when creating a RAG Pipe from a template, but in t
|
||||
### Suggested Fix
|
||||
|
||||
Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`.
|
||||
|
||||
## Locale keys must be complete
|
||||
|
||||
IsUrgent: True
|
||||
Category: Business Logic
|
||||
|
||||
### Description
|
||||
|
||||
When adding or changing user-facing i18n keys, ensure every supported locale file has the same key set as `web/i18n/en-US/`. Do not add only English keys or only a partial subset of locales; `pnpm i18n:check --file <name>` should pass for the touched translation file.
|
||||
|
||||
### Suggested Fix
|
||||
|
||||
Add matching keys to every existing supported locale file for the touched translation namespace, keeping key paths aligned with the English entry.
|
||||
|
||||
## Preserve behavior-sensitive interactions
|
||||
|
||||
IsUrgent: True
|
||||
Category: Business Logic
|
||||
|
||||
### Description
|
||||
|
||||
When changing existing navigation, sidebar, dropdown, webapp list, or app-switching UI, compare behavior against the existing implementation before approving the change. Watch for regressions in expand/collapse arrows, hover persistence, pin/delete controls, routing, keyboard/focus handling, and open-state ownership.
|
||||
|
||||
### Suggested Fix
|
||||
|
||||
Reuse or extend the existing component when it already owns the interaction logic. If a refactor is needed, preserve the old interaction contract and add or update focused tests for the changed behavior.
|
||||
|
||||
@ -7,12 +7,12 @@ Category: Code Quality
|
||||
|
||||
### Description
|
||||
|
||||
Ensure conditional CSS is handled via the shared `classNames` instead of custom ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain.
|
||||
Ensure conditional CSS and multi-line class composition are handled via the shared `cn` helper instead of custom ternaries, string concatenation, array `.join(' ')`, or template strings. Centralizing class logic keeps components consistent and easier to maintain.
|
||||
|
||||
### Suggested Fix
|
||||
|
||||
```ts
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
const classNames = cn(isActive ? 'text-primary-600' : 'text-gray-500')
|
||||
```
|
||||
|
||||
@ -25,7 +25,34 @@ Category: Code Quality
|
||||
|
||||
Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead.
|
||||
|
||||
Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.
|
||||
## CSS files must be scoped
|
||||
|
||||
IsUrgent: True
|
||||
Category: Code Quality
|
||||
|
||||
### Description
|
||||
|
||||
When CSS is truly necessary, use component-scoped `*.module.css`. Do not add component-level CSS through plain `.css` files, and do not import component CSS from `globals.css`; both patterns risk style leakage across the app.
|
||||
|
||||
## Split oversized components cautiously
|
||||
|
||||
Category: Code Quality
|
||||
|
||||
### Description
|
||||
|
||||
When a frontend file grows large or mixes multiple responsibilities, suggest splitting it into focused components, hooks, or utilities. Prefer shallow local structure that matches existing repo patterns, such as a sibling `components/` folder, and avoid deep folder hierarchies unless the surrounding code already uses them.
|
||||
|
||||
## Reuse base and dify-ui components before hand-rolling UI
|
||||
|
||||
Category: Code Quality
|
||||
|
||||
### Description
|
||||
|
||||
Before approving new or modified frontend UI, check whether the code manually recreates behavior or styling already owned by `@langgenius/dify-ui/*` or `web/app/components/base/*`. Common examples include `Button`, `Input`, `ToggleGroup`, `Popover`, `DropdownMenu`, `AlertDialog`, `Switch`, `Avatar`, `ScrollArea`, `toast`, and existing feature components. Prefer composing existing primitives instead of duplicating borders, focus states, disabled states, segmented controls, inputs, overlays, or buttons.
|
||||
|
||||
### Suggested Fix
|
||||
|
||||
Replace hand-written UI chrome with the nearest shared primitive, keeping feature-specific layout, state ownership, labels, and workflow behavior local.
|
||||
|
||||
## Classname ordering for easy overrides
|
||||
|
||||
@ -36,9 +63,11 @@ When writing components, always place the incoming `className` prop after the co
|
||||
Example:
|
||||
|
||||
```tsx
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
|
||||
const Button = ({ className }) => {
|
||||
return <div className={cn('bg-primary-600', className)}></div>
|
||||
}
|
||||
```
|
||||
|
||||
Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.
|
||||
|
||||
@ -43,3 +43,14 @@ const config = useMemo(() => ({
|
||||
config={config}
|
||||
/>
|
||||
```
|
||||
|
||||
## Custom SVG icon generation
|
||||
|
||||
IsUrgent: False
|
||||
Category: Performance
|
||||
|
||||
### Description
|
||||
|
||||
New custom SVG icons should be added to `packages/iconify-collections/assets/...`, generated with `pnpm --filter @dify/iconify-collections generate`, checked with `pnpm --filter @dify/iconify-collections check:dimensions`, and consumed through Tailwind `i-custom-*` classes. Do not add new generated React icon components or JSON files under `web/app/components/base/icons/src/...` for new custom SVG icons.
|
||||
|
||||
When reviewing generated `packages/iconify-collections/custom-*/icons.json` diffs, verify unrelated existing icons did not lose or change intrinsic `width` / `height`.
|
||||
33
.agents/skills/karpathy-guidelines/SKILL.md
Normal file
33
.agents/skills/karpathy-guidelines/SKILL.md
Normal file
@ -0,0 +1,33 @@
|
||||
---
|
||||
name: karpathy-guidelines
|
||||
description: Lightweight coding guardrails for making focused, simple, and verifiable changes in this repo. Use for all coding work.
|
||||
---
|
||||
|
||||
# Karpathy Guidelines
|
||||
|
||||
Use this skill whenever you touch code in this repository.
|
||||
|
||||
## Principles
|
||||
|
||||
- Keep the change small and directly tied to the user request.
|
||||
- Prefer the simplest implementation that fits the existing codebase.
|
||||
- Read the nearby code first, then match its patterns.
|
||||
- Avoid unrelated refactors, broad rewrites, or style churn.
|
||||
- Preserve existing behavior unless the user explicitly asked to change it.
|
||||
- Treat regressions as a signal to narrow the change, not to add workaround layers.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Inspect the current implementation and tests around the change.
|
||||
2. Make the smallest coherent edit.
|
||||
3. Add or update focused tests when the behavior changes or the risk is non-trivial.
|
||||
4. Run the narrowest relevant verification first.
|
||||
5. Report exactly what was verified and anything left unverified.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
- Does this change solve the stated problem without expanding scope?
|
||||
- Did it preserve existing route/component/data-flow semantics?
|
||||
- Are new abstractions justified by real complexity?
|
||||
- Are tests focused on the behavior that could regress?
|
||||
- Are unrelated files and generated artifacts left alone?
|
||||
@ -63,8 +63,8 @@ jobs:
|
||||
id: render
|
||||
run: |
|
||||
comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
|
||||
--base "$GITHUB_WORKSPACE/base_report.json" \
|
||||
< "$GITHUB_WORKSPACE/pr_report.json")"
|
||||
--base base_report.json \
|
||||
< pr_report.json)"
|
||||
|
||||
{
|
||||
echo "### Pyrefly Type Coverage"
|
||||
|
||||
4
.github/workflows/pyrefly-type-coverage.yml
vendored
4
.github/workflows/pyrefly-type-coverage.yml
vendored
@ -65,9 +65,6 @@ jobs:
|
||||
# Save structured data for the fork-PR comment workflow
|
||||
cp /tmp/pyrefly_report_pr.json pr_report.json
|
||||
cp /tmp/pyrefly_report_base.json base_report.json
|
||||
# Keep fork-PR comments correct while the trusted workflow_run job is
|
||||
# still using the default-branch renderer, which resolves --base from api/.
|
||||
cp /tmp/pyrefly_report_base.json api/base_report.json
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
@ -80,7 +77,6 @@ jobs:
|
||||
path: |
|
||||
pr_report.json
|
||||
base_report.json
|
||||
api/base_report.json
|
||||
pr_number.txt
|
||||
|
||||
- name: Comment PR with type coverage
|
||||
|
||||
4
.github/workflows/style.yml
vendored
4
.github/workflows/style.yml
vendored
@ -47,10 +47,6 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --directory api --dev lint-imports
|
||||
|
||||
- name: Run Response Contract Linter
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --project api --dev python api/dev/lint_response_contracts.py --fail-on-mismatch
|
||||
|
||||
- name: Run Type Checks
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: make type-check-core
|
||||
|
||||
4
MOCKS_TO_REMOVE_BEFORE_RELEASE.md
Normal file
4
MOCKS_TO_REMOVE_BEFORE_RELEASE.md
Normal file
@ -0,0 +1,4 @@
|
||||
# Mocks to Remove Before Release
|
||||
|
||||
- `emptyAppList=true`: frontend URL preview flag for forcing the `/apps` page into the first-empty state. Remove the parser and rendering override before release.
|
||||
- `emptyDataList=true`: frontend URL preview flag for forcing the `/datasets` page into the first-empty state. Remove the parser and rendering override before release.
|
||||
11
Makefile
11
Makefile
@ -75,19 +75,13 @@ check:
|
||||
@echo "✅ Code check complete"
|
||||
|
||||
lint:
|
||||
@echo "🔧 Running ruff format, check with fixes, response contract lint, import linter, and dotenv-linter..."
|
||||
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
|
||||
@uv run --project api --dev ruff format ./api
|
||||
@uv run --project api --dev ruff check --fix ./api
|
||||
@$(MAKE) api-contract-lint
|
||||
@uv run --directory api --dev lint-imports
|
||||
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
api-contract-lint:
|
||||
@echo "🔎 Linting Flask response contracts..."
|
||||
@uv run --project api --dev python api/dev/lint_response_contracts.py
|
||||
@echo "✅ Response contract lint complete"
|
||||
|
||||
type-check:
|
||||
@echo "📝 Running type checks (pyrefly + mypy)..."
|
||||
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
|
||||
@ -197,7 +191,6 @@ help:
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make api-contract-lint - Check Flask response docs against returned schemas"
|
||||
@echo " make type-check - Run type checks (pyrefly, mypy)"
|
||||
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@ -211,4 +204,4 @@ help:
|
||||
@echo " make build-push-all - Build and push all Docker images"
|
||||
|
||||
# Phony targets
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint api-contract-lint type-check test test-all
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all
|
||||
|
||||
@ -657,7 +657,6 @@ PLUGIN_REMOTE_INSTALL_PORT=5003
|
||||
PLUGIN_REMOTE_INSTALL_HOST=localhost
|
||||
PLUGIN_MAX_PACKAGE_SIZE=15728640
|
||||
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
|
||||
PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400
|
||||
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
|
||||
|
||||
# Marketplace configuration
|
||||
@ -768,7 +767,6 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode while use redis as event bus.
|
||||
# It's highly recommended to enable this for large deployments.
|
||||
EVENT_BUS_REDIS_USE_CLUSTERS=false
|
||||
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
|
||||
|
||||
# Whether to Enable human input timeout check task
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
||||
|
||||
@ -195,7 +195,6 @@ Before opening a PR / submitting:
|
||||
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
|
||||
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
|
||||
- Document non-obvious behaviour with concise docstrings and comments.
|
||||
- For `204 No Content` responses, return an empty body only; never return a dict, model, or other payload.
|
||||
- For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`.
|
||||
In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response
|
||||
DTOs with `register_response_schema_models(...)`, serialize response DTOs with `dump_response(...)`,
|
||||
|
||||
@ -49,7 +49,6 @@ class AgentBackendModelConfig(BaseModel):
|
||||
model: str
|
||||
user_id: str | None = None
|
||||
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
|
||||
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
@ -139,7 +138,6 @@ class AgentBackendRunRequestBuilder:
|
||||
model_provider=run_input.model.model_provider,
|
||||
model=run_input.model.model,
|
||||
credentials=run_input.model.credentials,
|
||||
model_settings=run_input.model.model_settings or None,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@ -11,7 +11,6 @@ from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.tools.utils.system_encryption import encrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
@ -21,6 +20,7 @@ from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from configs.extra.agent_backend_config import AgentBackendConfig
|
||||
from configs.extra.archive_config import ArchiveStorageConfig
|
||||
from configs.extra.notion_config import NotionConfig
|
||||
from configs.extra.sentry_config import SentryConfig
|
||||
@ -6,7 +5,6 @@ from configs.extra.sentry_config import SentryConfig
|
||||
|
||||
class ExtraServiceConfig(
|
||||
# place the configs in alphabet order
|
||||
AgentBackendConfig,
|
||||
ArchiveStorageConfig,
|
||||
NotionConfig,
|
||||
SentryConfig,
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AgentBackendConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for the Agent backend runtime integration.
|
||||
"""
|
||||
|
||||
AGENT_BACKEND_BASE_URL: str | None = Field(
|
||||
description="Base URL for the Dify Agent backend service.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
AGENT_BACKEND_USE_FAKE: bool = Field(
|
||||
description="Use the deterministic in-process fake Agent backend client.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
AGENT_BACKEND_FAKE_SCENARIO: str = Field(
|
||||
description="Scenario used by the fake Agent backend client.",
|
||||
default="success",
|
||||
)
|
||||
@ -265,11 +265,6 @@ class PluginConfig(BaseSettings):
|
||||
default=60 * 60,
|
||||
)
|
||||
|
||||
PLUGIN_MODEL_PROVIDERS_CACHE_TTL: PositiveInt = Field(
|
||||
description="TTL in seconds for caching tenant plugin model providers in Redis",
|
||||
default=60 * 60 * 24,
|
||||
)
|
||||
|
||||
PLUGIN_MAX_FILE_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size (bytes) for plugin-generated files",
|
||||
default=50 * 1024 * 1024,
|
||||
|
||||
@ -2,7 +2,6 @@ from typing import Literal, Protocol, cast
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic.types import NonNegativeInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@ -71,24 +70,6 @@ class RedisPubSubConfig(BaseSettings):
|
||||
default=600,
|
||||
)
|
||||
|
||||
PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field(
|
||||
validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"),
|
||||
description=(
|
||||
"Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to "
|
||||
"finish before returning. Bounds the tail latency between a terminal event being delivered to "
|
||||
"an SSE client and the response stream actually closing.\n\n"
|
||||
"The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout "
|
||||
"for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for "
|
||||
"the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() "
|
||||
"return promptly while the daemon listener thread cleans itself up on the next poll "
|
||||
"boundary - safe because the listener holds no critical state and exits within one poll "
|
||||
"window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up "
|
||||
"and logs a warning. Default 2000ms preserves the pre-change behaviour.\n\n"
|
||||
"Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
|
||||
),
|
||||
default=2000,
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = _redis_defaults(self)
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.fields import FieldInfo
|
||||
@ -48,7 +48,6 @@ class ApolloSettingsSource(RemoteSettingsSource):
|
||||
self.namespace = configs["APOLLO_NAMESPACE"]
|
||||
self.remote_configs = self.client.get_all_dicts(self.namespace)
|
||||
|
||||
@override
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
if not isinstance(self.remote_configs, dict):
|
||||
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
@ -41,7 +41,6 @@ class NacosSettingsSource(RemoteSettingsSource):
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to parse config: {e}")
|
||||
|
||||
@override
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
field_value = self.remote_configs.get(field_name)
|
||||
if field_value is None:
|
||||
|
||||
@ -10,7 +10,7 @@ import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Protocol, final, override, runtime_checkable
|
||||
from typing import Any, Protocol, final, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -133,12 +133,10 @@ class NullAppContext(AppContext):
|
||||
self._config = config or {}
|
||||
self._extensions: dict[str, Any] = {}
|
||||
|
||||
@override
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value by key."""
|
||||
return self._config.get(key, default)
|
||||
|
||||
@override
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get extension by name."""
|
||||
return self._extensions.get(name)
|
||||
@ -148,7 +146,6 @@ class NullAppContext(AppContext):
|
||||
self._extensions[name] = extension
|
||||
|
||||
@contextmanager
|
||||
@override
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter null context (no-op)."""
|
||||
yield
|
||||
|
||||
@ -6,7 +6,7 @@ import contextvars
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, final, override
|
||||
from typing import Any, final
|
||||
|
||||
from flask import Flask, current_app, g
|
||||
|
||||
@ -30,18 +30,15 @@ class FlaskAppContext(AppContext):
|
||||
"""
|
||||
self._flask_app = flask_app
|
||||
|
||||
@override
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value from Flask app config."""
|
||||
return self._flask_app.config.get(key, default)
|
||||
|
||||
@override
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get Flask extension by name."""
|
||||
return self._flask_app.extensions.get(name)
|
||||
|
||||
@contextmanager
|
||||
@override
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter Flask app context."""
|
||||
with self._flask_app.app_context():
|
||||
|
||||
@ -36,24 +36,6 @@ class FileInfo(BaseModel):
|
||||
size: int
|
||||
|
||||
|
||||
def decode_remote_url(url: str, query_string: bytes | str = b"") -> str:
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
if isinstance(query_string, bytes):
|
||||
raw_query = query_string.decode()
|
||||
else:
|
||||
raw_query = query_string
|
||||
if not raw_query:
|
||||
return decoded_url
|
||||
|
||||
if decoded_url.endswith(("?", "&")):
|
||||
separator = ""
|
||||
elif urllib.parse.urlsplit(decoded_url).query:
|
||||
separator = "&"
|
||||
else:
|
||||
separator = "?"
|
||||
return f"{decoded_url}{separator}{raw_query}"
|
||||
|
||||
|
||||
def guess_file_info_from_response(response: httpx.Response):
|
||||
url = str(response.url)
|
||||
# Try to extract filename from URL
|
||||
|
||||
@ -146,7 +146,7 @@ class BaseApiKeyResource(Resource):
|
||||
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
|
||||
db.session.commit()
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
||||
|
||||
@ -269,12 +269,12 @@ class AnnotationApi(Resource):
|
||||
"message": "annotation_ids are required if the parameter is provided.",
|
||||
}, 400
|
||||
|
||||
AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
|
||||
return "", 204
|
||||
result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
|
||||
return result, 204
|
||||
# If no annotation_ids are provided, handle clearing all annotations
|
||||
else:
|
||||
AppAnnotationService.clear_all_annotations(str(app_id))
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
|
||||
@ -335,7 +335,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@edit_permission_required
|
||||
def delete(self, app_id: UUID, annotation_id: UUID):
|
||||
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
|
||||
|
||||
@ -633,7 +633,7 @@ class AppApi(Resource):
|
||||
app_service = AppService()
|
||||
app_service.delete_app(app_model)
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/copy")
|
||||
|
||||
@ -29,6 +29,9 @@ from fields.conversation_fields import (
|
||||
from fields.conversation_fields import (
|
||||
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ResultResponse,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
@ -74,6 +77,7 @@ register_schema_models(
|
||||
ConversationMessageDetailResponse,
|
||||
ConversationWithSummaryPaginationResponse,
|
||||
ConversationDetailResponse,
|
||||
ResultResponse,
|
||||
CompletionConversationQuery,
|
||||
ChatConversationQuery,
|
||||
)
|
||||
@ -190,7 +194,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return "", 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
|
||||
@ -343,7 +347,7 @@ class ChatConversationDetailApi(Resource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return "", 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
|
||||
@ -128,6 +128,6 @@ class TraceAppConfigApi(Resource):
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_model.id, tracing_provider=args.tracing_provider)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@ -311,7 +311,7 @@ class WorkflowCommentDetailApi(Resource):
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
@ -431,7 +431,7 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
|
||||
@ -93,4 +93,4 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -1,16 +1,15 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.fields import ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.apikey import ApiKeyItem, ApiKeyList
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
@ -31,10 +30,26 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from fields.app_fields import app_detail_kernel_fields, related_app_list
|
||||
from fields.dataset_fields import (
|
||||
content_fields,
|
||||
dataset_detail_fields,
|
||||
dataset_fields,
|
||||
dataset_query_detail_fields,
|
||||
dataset_retrieval_model_fields,
|
||||
doc_metadata_fields,
|
||||
external_knowledge_info_fields,
|
||||
external_retrieval_model_fields,
|
||||
file_info_fields,
|
||||
icon_info_fields,
|
||||
keyword_setting_fields,
|
||||
reranking_model_fields,
|
||||
tag_fields,
|
||||
vector_setting_fields,
|
||||
weighted_score_fields,
|
||||
)
|
||||
from fields.document_fields import document_status_fields
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import build_icon_url, dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.url_utils import normalize_api_base_url
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
@ -46,6 +61,58 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
|
||||
|
||||
register_response_schema_models(console_ns, ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
|
||||
|
||||
tag_model = get_or_create_model("Tag", tag_fields)
|
||||
|
||||
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
|
||||
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
|
||||
|
||||
weighted_score_fields_copy = weighted_score_fields.copy()
|
||||
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
|
||||
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
|
||||
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
|
||||
|
||||
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
|
||||
|
||||
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
|
||||
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
|
||||
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
|
||||
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
|
||||
|
||||
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
|
||||
|
||||
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
|
||||
|
||||
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
|
||||
|
||||
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
|
||||
|
||||
dataset_detail_fields_copy = dataset_detail_fields.copy()
|
||||
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
|
||||
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
|
||||
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
|
||||
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
|
||||
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
|
||||
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
|
||||
dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
|
||||
|
||||
file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
|
||||
|
||||
content_fields_copy = content_fields.copy()
|
||||
content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
|
||||
content_model = get_or_create_model("DatasetContent", content_fields_copy)
|
||||
|
||||
dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
|
||||
dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
|
||||
dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
|
||||
|
||||
app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
|
||||
related_app_list_copy = related_app_list.copy()
|
||||
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
|
||||
related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
|
||||
|
||||
|
||||
def _validate_indexing_technique(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
@ -141,165 +208,9 @@ class ConsoleDatasetListQuery(BaseModel):
|
||||
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
|
||||
|
||||
|
||||
class DatasetListItemResponse(DatasetDetailResponse):
|
||||
partial_member_list: list[str]
|
||||
|
||||
|
||||
class DatasetListResponse(ResponseModel):
|
||||
data: list[DatasetListItemResponse]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
|
||||
partial_member_list: list[str] | None = None
|
||||
|
||||
|
||||
class DatasetQueryFileInfoResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
size: int
|
||||
extension: str
|
||||
mime_type: str
|
||||
source_url: str
|
||||
|
||||
|
||||
class DatasetQueryContentResponse(ResponseModel):
|
||||
content_type: str
|
||||
content: str
|
||||
file_info: DatasetQueryFileInfoResponse | None = None
|
||||
|
||||
|
||||
class DatasetQueryDetailResponse(ResponseModel):
|
||||
id: str
|
||||
queries: list[DatasetQueryContentResponse]
|
||||
source: str
|
||||
source_app_id: str | None
|
||||
created_by_role: str
|
||||
created_by: str
|
||||
created_at: int
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DatasetQueryListResponse(ResponseModel):
|
||||
data: list[DatasetQueryDetailResponse]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class RelatedAppResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
mode: str = Field(validation_alias="mode_compatible_with_agent")
|
||||
icon_type: str | None
|
||||
icon: str | None
|
||||
icon_background: str | None
|
||||
icon_url: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_icon_url(self) -> "RelatedAppResponse":
|
||||
self.icon_url = self.icon_url or build_icon_url(self.icon_type, self.icon)
|
||||
return self
|
||||
|
||||
|
||||
class RelatedAppListResponse(ResponseModel):
|
||||
data: list[RelatedAppResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class DocumentStatusResponse(ResponseModel):
|
||||
id: str
|
||||
indexing_status: str
|
||||
processing_started_at: int | None
|
||||
parsing_completed_at: int | None
|
||||
cleaning_completed_at: int | None
|
||||
splitting_completed_at: int | None
|
||||
completed_at: int | None
|
||||
paused_at: int | None
|
||||
error: str | None
|
||||
stopped_at: int | None
|
||||
completed_segments: int | None = None
|
||||
total_segments: int | None = None
|
||||
|
||||
@field_validator(
|
||||
"processing_started_at",
|
||||
"parsing_completed_at",
|
||||
"cleaning_completed_at",
|
||||
"splitting_completed_at",
|
||||
"completed_at",
|
||||
"paused_at",
|
||||
"stopped_at",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentStatusListResponse(ResponseModel):
|
||||
data: list[DocumentStatusResponse]
|
||||
|
||||
|
||||
class ErrorDocsResponse(DocumentStatusListResponse):
|
||||
total: int
|
||||
|
||||
|
||||
class IndexingEstimatePreviewItemResponse(ResponseModel):
|
||||
content: str
|
||||
child_chunks: list[str] | None = None
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class IndexingEstimateQaPreviewItemResponse(ResponseModel):
|
||||
question: str
|
||||
answer: str
|
||||
|
||||
|
||||
class IndexingEstimateResponse(ResponseModel):
|
||||
total_segments: int
|
||||
preview: list[IndexingEstimatePreviewItemResponse]
|
||||
qa_preview: list[IndexingEstimateQaPreviewItemResponse] | None = None
|
||||
|
||||
|
||||
class RetrievalSettingResponse(ResponseModel):
|
||||
retrieval_method: list[str]
|
||||
|
||||
|
||||
class PartialMemberListResponse(ResponseModel):
|
||||
data: list[str]
|
||||
|
||||
|
||||
class AutoDisableLogsResponse(ResponseModel):
|
||||
document_ids: list[str]
|
||||
count: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
DatasetDetailResponse,
|
||||
DatasetDetailWithPartialMembersResponse,
|
||||
DatasetListResponse,
|
||||
DatasetQueryListResponse,
|
||||
IndexingEstimateResponse,
|
||||
RelatedAppListResponse,
|
||||
DocumentStatusListResponse,
|
||||
ErrorDocsResponse,
|
||||
RetrievalSettingResponse,
|
||||
PartialMemberListResponse,
|
||||
AutoDisableLogsResponse,
|
||||
)
|
||||
|
||||
|
||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||
@ -382,8 +293,17 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
||||
class DatasetListApi(Resource):
|
||||
@console_ns.doc("get_datasets")
|
||||
@console_ns.doc(description="Get list of datasets")
|
||||
@console_ns.doc(params=query_params_from_model(ConsoleDatasetListQuery))
|
||||
@console_ns.response(200, "Datasets retrieved successfully", console_ns.models[DatasetListResponse.__name__])
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"page": "Page number (default: 1)",
|
||||
"limit": "Number of items per page (default: 20)",
|
||||
"ids": "Filter by dataset IDs (list)",
|
||||
"keyword": "Search keyword",
|
||||
"tag_ids": "Filter by tag IDs (list)",
|
||||
"include_all": "Include all datasets (default: false)",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Datasets retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -422,7 +342,7 @@ class DatasetListApi(Resource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
|
||||
dataset_ids = [item["id"] for item in data if item.get("permission") == "partial_members"]
|
||||
partial_members_map: dict[str, list[str]] = {}
|
||||
if dataset_ids:
|
||||
@ -459,12 +379,12 @@ class DatasetListApi(Resource):
|
||||
"total": total,
|
||||
"page": query.page,
|
||||
}
|
||||
return dump_response(DatasetListResponse, response), 200
|
||||
return response, 200
|
||||
|
||||
@console_ns.doc("create_dataset")
|
||||
@console_ns.doc(description="Create a new dataset")
|
||||
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
|
||||
@console_ns.response(201, "Dataset created successfully", console_ns.models[DatasetDetailResponse.__name__])
|
||||
@console_ns.response(201, "Dataset created successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -493,7 +413,7 @@ class DatasetListApi(Resource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return dump_response(DatasetDetailResponse, dataset), 201
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>")
|
||||
@ -501,11 +421,7 @@ class DatasetApi(Resource):
|
||||
@console_ns.doc("get_dataset")
|
||||
@console_ns.doc(description="Get dataset details")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dataset retrieved successfully",
|
||||
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Dataset retrieved successfully", dataset_detail_model)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -521,7 +437,7 @@ class DatasetApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = dump_response(DatasetDetailResponse, dataset)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.embedding_model_provider:
|
||||
provider_id = ModelProviderID(dataset.embedding_model_provider)
|
||||
@ -554,11 +470,7 @@ class DatasetApi(Resource):
|
||||
@console_ns.doc("update_dataset")
|
||||
@console_ns.doc(description="Update dataset details")
|
||||
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dataset updated successfully",
|
||||
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -594,7 +506,7 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = dump_response(DatasetDetailResponse, dataset)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
@ -623,7 +535,7 @@ class DatasetApi(Resource):
|
||||
try:
|
||||
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
else:
|
||||
raise NotFound("Dataset not found.")
|
||||
except services.errors.dataset.DatasetInUseError:
|
||||
@ -655,11 +567,7 @@ class DatasetQueryApi(Resource):
|
||||
@console_ns.doc("get_dataset_queries")
|
||||
@console_ns.doc(description="Get dataset query history")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Query history retrieved successfully",
|
||||
console_ns.models[DatasetQueryListResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Query history retrieved successfully", dataset_query_detail_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -681,24 +589,20 @@ class DatasetQueryApi(Resource):
|
||||
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
|
||||
|
||||
response = {
|
||||
"data": dataset_queries,
|
||||
"data": marshal(dataset_queries, dataset_query_detail_model),
|
||||
"has_more": len(dataset_queries) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return dump_response(DatasetQueryListResponse, response), 200
|
||||
return response, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/indexing-estimate")
|
||||
class DatasetIndexingEstimateApi(Resource):
|
||||
@console_ns.doc("estimate_dataset_indexing")
|
||||
@console_ns.doc(description="Estimate dataset indexing cost")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Indexing estimate calculated successfully",
|
||||
console_ns.models[IndexingEstimateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Indexing estimate calculated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -795,14 +699,11 @@ class DatasetRelatedAppListApi(Resource):
|
||||
@console_ns.doc("get_dataset_related_apps")
|
||||
@console_ns.doc(description="Get applications related to dataset")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Related apps retrieved successfully",
|
||||
console_ns.models[RelatedAppListResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Related apps retrieved successfully", related_app_list_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(related_app_list_model)
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -823,7 +724,7 @@ class DatasetRelatedAppListApi(Resource):
|
||||
if app_model:
|
||||
related_apps.append(app_model)
|
||||
|
||||
return dump_response(RelatedAppListResponse, {"data": related_apps, "total": len(related_apps)}), 200
|
||||
return {"data": related_apps, "total": len(related_apps)}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/indexing-status")
|
||||
@ -831,11 +732,7 @@ class DatasetIndexingStatusApi(Resource):
|
||||
@console_ns.doc("get_dataset_indexing_status")
|
||||
@console_ns.doc(description="Get dataset indexing status")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Indexing status retrieved successfully",
|
||||
console_ns.models[DocumentStatusListResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Indexing status retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -881,8 +778,9 @@ class DatasetIndexingStatusApi(Resource):
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
documents_status.append(document_dict)
|
||||
return dump_response(DocumentStatusListResponse, {"data": documents_status}), 200
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/api-keys")
|
||||
@ -975,7 +873,7 @@ class DatasetApiDeleteApi(Resource):
|
||||
db.session.delete(key)
|
||||
db.session.commit()
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<string:status>")
|
||||
@ -1009,18 +907,13 @@ class DatasetApiBaseUrlApi(Resource):
|
||||
class DatasetRetrievalSettingApi(Resource):
|
||||
@console_ns.doc("get_dataset_retrieval_setting")
|
||||
@console_ns.doc(description="Get dataset retrieval settings")
|
||||
@console_ns.response(
|
||||
200, "Retrieval settings retrieved successfully", console_ns.models[RetrievalSettingResponse.__name__]
|
||||
)
|
||||
@console_ns.response(200, "Retrieval settings retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
return dump_response(
|
||||
RetrievalSettingResponse,
|
||||
_get_retrieval_methods_by_vector_type(vector_type, is_mock=False),
|
||||
)
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
|
||||
@ -1028,19 +921,12 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
@console_ns.doc("get_dataset_retrieval_setting_mock")
|
||||
@console_ns.doc(description="Get mock dataset retrieval settings by vector type")
|
||||
@console_ns.doc(params={"vector_type": "Vector store type"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Mock retrieval settings retrieved successfully",
|
||||
console_ns.models[RetrievalSettingResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Mock retrieval settings retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
return dump_response(
|
||||
RetrievalSettingResponse,
|
||||
_get_retrieval_methods_by_vector_type(vector_type, is_mock=True),
|
||||
)
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
|
||||
@ -1048,7 +934,7 @@ class DatasetErrorDocs(Resource):
|
||||
@console_ns.doc("get_dataset_error_docs")
|
||||
@console_ns.doc(description="Get dataset error documents")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Error documents retrieved successfully", console_ns.models[ErrorDocsResponse.__name__])
|
||||
@console_ns.response(200, "Error documents retrieved successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -1060,7 +946,7 @@ class DatasetErrorDocs(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
|
||||
|
||||
return dump_response(ErrorDocsResponse, {"data": results, "total": len(results)}), 200
|
||||
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users")
|
||||
@ -1068,11 +954,7 @@ class DatasetPermissionUserListApi(Resource):
|
||||
@console_ns.doc("get_dataset_permission_users")
|
||||
@console_ns.doc(description="Get dataset permission user list")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Permission users retrieved successfully",
|
||||
console_ns.models[PartialMemberListResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Permission users retrieved successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -1091,7 +973,9 @@ class DatasetPermissionUserListApi(Resource):
|
||||
|
||||
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
|
||||
return dump_response(PartialMemberListResponse, {"data": partial_members_list}), 200
|
||||
return {
|
||||
"data": partial_members_list,
|
||||
}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs")
|
||||
@ -1099,11 +983,7 @@ class DatasetAutoDisableLogApi(Resource):
|
||||
@console_ns.doc("get_dataset_auto_disable_logs")
|
||||
@console_ns.doc(description="Get dataset auto disable logs")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Auto disable logs retrieved successfully",
|
||||
console_ns.models[AutoDisableLogsResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Auto disable logs retrieved successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -1113,4 +993,4 @@ class DatasetAutoDisableLogApi(Resource):
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
return dump_response(AutoDisableLogsResponse, DatasetService.get_dataset_auto_disable_logs(dataset_id_str)), 200
|
||||
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
|
||||
|
||||
@ -504,7 +504,7 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/init")
|
||||
@ -966,7 +966,7 @@ class DocumentApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
|
||||
@ -1204,7 +1204,7 @@ class DocumentPauseApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot pause completed document.")
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
|
||||
@ -1236,7 +1236,7 @@ class DocumentRecoverApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Document is not in paused status.")
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/retry")
|
||||
@ -1279,7 +1279,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
# retry document
|
||||
DocumentService.retry_document(dataset_id, retry_documents)
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
|
||||
|
||||
@ -251,7 +251,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
SegmentService.delete_segments(segment_ids, document, dataset)
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
|
||||
@ -467,7 +467,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -754,7 +754,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
SegmentService.delete_child_chunk(child_chunk, dataset)
|
||||
except ChildChunkDeleteIndexServiceError as e:
|
||||
raise ChildChunkDeleteIndexError(str(e))
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -218,7 +218,7 @@ class ExternalApiTemplateApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
|
||||
|
||||
@ -1,18 +1,14 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource
|
||||
from flask_restx import Resource, marshal_with
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import (
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
@ -26,12 +22,7 @@ from services.metadata_service import MetadataService
|
||||
register_schema_models(
|
||||
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
@ -40,7 +31,7 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -53,22 +44,18 @@ class DatasetMetadataCreateApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
|
||||
return dump_response(DatasetMetadataResponse, metadata), 201
|
||||
return metadata, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(
|
||||
200, "Metadata retrieved successfully", console_ns.models[DatasetMetadataListResponse.__name__]
|
||||
)
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
metadata = MetadataService.get_dataset_metadatas(dataset)
|
||||
return dump_response(DatasetMetadataListResponse, metadata), 200
|
||||
return MetadataService.get_dataset_metadatas(dataset), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
@ -77,7 +64,7 @@ class DatasetMetadataApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -92,7 +79,7 @@ class DatasetMetadataApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
|
||||
return dump_response(DatasetMetadataResponse, metadata), 200
|
||||
return metadata, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -109,8 +96,7 @@ class DatasetMetadataApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
|
||||
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/metadata/built-in")
|
||||
@ -119,14 +105,9 @@ class DatasetMetadataBuiltInFieldApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Built-in fields retrieved successfully",
|
||||
console_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
|
||||
)
|
||||
def get(self):
|
||||
built_in_fields = MetadataService.get_built_in_fields()
|
||||
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
|
||||
return {"fields": built_in_fields}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
@ -135,7 +116,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Action completed successfully")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -149,8 +130,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
|
||||
return "", 204
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
@ -160,10 +140,7 @@ class DocumentMetadataEditApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
|
||||
@console_ns.response(
|
||||
204,
|
||||
"Documents metadata updated successfully",
|
||||
)
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -176,5 +153,4 @@ class DocumentMetadataEditApi(Resource):
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Frontend callers only await success and invalidate caches; no response body is consumed.
|
||||
return "", 204
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -105,7 +105,7 @@ class ConversationApi(InstalledAppResource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return "", 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
|
||||
@ -270,7 +270,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
db.session.delete(installed_app)
|
||||
db.session.commit()
|
||||
|
||||
return "", 204
|
||||
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
|
||||
def patch(self, installed_app):
|
||||
|
||||
@ -76,4 +76,4 @@ class SavedMessageApi(InstalledAppResource):
|
||||
|
||||
SavedMessageService.delete(app_model, current_user, message_id)
|
||||
|
||||
return "", 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
@ -204,4 +204,4 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
|
||||
APIBasedExtensionService.delete(extension_data_from_db)
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -3,12 +3,12 @@ from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
|
||||
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
|
||||
register_response_schema_models(console_ns, FeatureModel, SystemFeatureModel)
|
||||
|
||||
|
||||
@console_ns.route("/features")
|
||||
@ -28,32 +28,7 @@ class FeatureApi(Resource):
|
||||
"""Get feature configuration for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = FeatureService.get_features(
|
||||
current_tenant_id,
|
||||
exclude_vector_space=True,
|
||||
).model_dump()
|
||||
payload.pop("vector_space", None)
|
||||
return payload
|
||||
|
||||
|
||||
@console_ns.route("/features/vector-space")
|
||||
class FeatureVectorSpaceApi(Resource):
|
||||
@console_ns.doc("get_tenant_feature_vector_space")
|
||||
@console_ns.doc(description="Get vector-space usage and limit for current tenant")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.models[LimitationModel.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get vector-space usage and limit for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_vector_space(current_tenant_id).model_dump()
|
||||
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/system-features")
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -33,7 +34,7 @@ class GetRemoteFileInfo(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[RemoteFileInfo.__name__])
|
||||
@login_required
|
||||
def get(self, url: str):
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
|
||||
@ -194,7 +194,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id
|
||||
)
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
|
||||
|
||||
@ -259,7 +259,7 @@ class ModelProviderModelApi(Resource):
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
|
||||
@ -395,7 +395,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
|
||||
|
||||
@ -15,7 +15,6 @@ from controllers.console import console_ns
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from fields.base import ResponseModel
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@ -23,6 +22,7 @@ from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermissi
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
||||
class ParserList(BaseModel):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Literal, override
|
||||
from typing import Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
@ -76,13 +76,11 @@ def _enum_value(value):
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
return _enum_value(obj.status)
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
status = _enum_value(obj.status)
|
||||
if status == WorkflowExecutionStatus.PAUSED.value:
|
||||
|
||||
@ -1,17 +1,13 @@
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
register_enum_models,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
@ -21,10 +17,9 @@ from controllers.service_api.wraps import (
|
||||
)
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from fields.base import ResponseModel
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import DataSetTag
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
@ -124,21 +119,6 @@ class TagUnbindingPayload(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class KnowledgeTagResponse(ResponseModel):
|
||||
model_config = ConfigDict(coerce_numbers_to_str=True)
|
||||
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
# TODO: The public Service API docs expose binding_count as string|null.
|
||||
# Keep matching the old RESTX fields.String coercion until that contract is intentionally migrated.
|
||||
binding_count: str | None = None
|
||||
|
||||
|
||||
class KnowledgeTagListResponse(RootModel[list[KnowledgeTagResponse]]):
|
||||
pass
|
||||
|
||||
|
||||
class DatasetListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
limit: int = Field(default=20, description="Number of items per page")
|
||||
@ -147,29 +127,6 @@ class DatasetListQuery(BaseModel):
|
||||
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
|
||||
|
||||
|
||||
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
|
||||
partial_member_list: list[str] | None = None
|
||||
|
||||
|
||||
# todo: duplicate code, but the partial_member_list has different nullability
|
||||
class DatasetListResponse(ResponseModel):
|
||||
data: list[DatasetDetailResponse]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class DatasetBoundTagResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class DatasetBoundTagListResponse(ResponseModel):
|
||||
data: list[DatasetBoundTagResponse]
|
||||
total: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
DatasetCreatePayload,
|
||||
@ -180,17 +137,9 @@ register_schema_models(
|
||||
TagBindingPayload,
|
||||
TagUnbindingPayload,
|
||||
DatasetListQuery,
|
||||
DataSetTag,
|
||||
)
|
||||
register_response_schema_models(
|
||||
service_api_ns,
|
||||
SimpleResultResponse,
|
||||
KnowledgeTagResponse,
|
||||
KnowledgeTagListResponse,
|
||||
DatasetDetailResponse,
|
||||
DatasetDetailWithPartialMembersResponse,
|
||||
DatasetListResponse,
|
||||
DatasetBoundTagListResponse,
|
||||
)
|
||||
register_response_schema_models(service_api_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets")
|
||||
@ -205,18 +154,9 @@ class DatasetListApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(params=query_params_from_model(DatasetListQuery))
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Datasets retrieved successfully",
|
||||
service_api_ns.models[DatasetListResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id):
|
||||
"""Resource for getting datasets."""
|
||||
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
|
||||
if "tag_ids" in request.args:
|
||||
query_params["tag_ids"] = request.args.getlist("tag_ids")
|
||||
query = DatasetListQuery.model_validate(query_params)
|
||||
query = DatasetListQuery.model_validate(request.args.to_dict())
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
@ -235,17 +175,22 @@ class DatasetListApi(DatasetApiResource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
|
||||
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if (
|
||||
item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index]
|
||||
and item["embedding_model_provider"] # pyrefly: ignore[bad-index]
|
||||
):
|
||||
item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation]
|
||||
ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index]
|
||||
)
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index]
|
||||
if item_model in model_names:
|
||||
item["embedding_available"] = True
|
||||
item["embedding_available"] = True # type: ignore
|
||||
else:
|
||||
item["embedding_available"] = False
|
||||
item["embedding_available"] = False # type: ignore
|
||||
else:
|
||||
item["embedding_available"] = True
|
||||
item["embedding_available"] = True # type: ignore
|
||||
response = {
|
||||
"data": data,
|
||||
"has_more": len(datasets) == query.limit,
|
||||
@ -253,7 +198,7 @@ class DatasetListApi(DatasetApiResource):
|
||||
"total": total,
|
||||
"page": query.page,
|
||||
}
|
||||
return dump_response(DatasetListResponse, response), 200
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset")
|
||||
@ -265,11 +210,6 @@ class DatasetListApi(DatasetApiResource):
|
||||
400: "Bad request - invalid parameters",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Dataset created successfully",
|
||||
service_api_ns.models[DatasetDetailResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id):
|
||||
"""Resource for creating datasets."""
|
||||
@ -313,7 +253,7 @@ class DatasetListApi(DatasetApiResource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return dump_response(DatasetDetailResponse, dataset), 200
|
||||
return marshal(dataset, dataset_detail_fields), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>")
|
||||
@ -331,11 +271,6 @@ class DatasetApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Dataset retrieved successfully",
|
||||
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
def get(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -345,7 +280,7 @@ class DatasetApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = dump_response(DatasetDetailResponse, dataset)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
# check embedding setting
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
@ -377,13 +312,7 @@ class DatasetApi(DatasetApiResource):
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({"partial_member_list": part_users_list})
|
||||
|
||||
return (
|
||||
DatasetDetailWithPartialMembersResponse.model_validate(data).model_dump(
|
||||
mode="json",
|
||||
exclude={"partial_member_list"} if "partial_member_list" not in data else set(),
|
||||
),
|
||||
200,
|
||||
)
|
||||
return data, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset")
|
||||
@ -397,11 +326,6 @@ class DatasetApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Dataset updated successfully",
|
||||
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -452,7 +376,7 @@ class DatasetApi(DatasetApiResource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = dump_response(DatasetDetailResponse, dataset)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
@ -465,7 +389,7 @@ class DatasetApi(DatasetApiResource):
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
result_data.update({"partial_member_list": partial_member_list})
|
||||
|
||||
return DatasetDetailWithPartialMembersResponse.model_validate(result_data).model_dump(mode="json"), 200
|
||||
return result_data, 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset")
|
||||
@service_api_ns.doc(description="Delete a dataset")
|
||||
@ -578,7 +502,7 @@ class DocumentStatusApi(DatasetApiResource):
|
||||
except ValueError as e:
|
||||
raise InvalidActionError(str(e))
|
||||
|
||||
return dump_response(SimpleResultResponse, {"result": "success"}), 200
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags")
|
||||
@ -591,18 +515,14 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tags retrieved successfully",
|
||||
service_api_ns.models[KnowledgeTagListResponse.__name__],
|
||||
)
|
||||
def get(self, _):
|
||||
"""Get all knowledge type tags."""
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
tags = TagService.get_tags("knowledge", cid)
|
||||
return dump_response(KnowledgeTagListResponse, tags), 200
|
||||
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
|
||||
return [tag.model_dump(mode="json") for tag in tag_models], 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset_tag")
|
||||
@ -614,11 +534,6 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tag created successfully",
|
||||
service_api_ns.models[KnowledgeTagResponse.__name__],
|
||||
)
|
||||
def post(self, _):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
@ -628,10 +543,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
|
||||
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0},
|
||||
)
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
).model_dump(mode="json")
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
|
||||
@ -644,11 +558,6 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tag updated successfully",
|
||||
service_api_ns.models[KnowledgeTagResponse.__name__],
|
||||
)
|
||||
def patch(self, _):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -660,10 +569,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count},
|
||||
)
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
).model_dump(mode="json")
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
|
||||
@ -748,11 +656,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tags retrieved successfully",
|
||||
service_api_ns.models[DatasetBoundTagListResponse.__name__],
|
||||
)
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
@ -760,4 +663,5 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
assert current_user.current_tenant_id is not None
|
||||
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
|
||||
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200
|
||||
response = {"data": tags_list, "total": len(tags)}
|
||||
return response, 200
|
||||
|
||||
@ -1,19 +1,15 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_model, register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
from fields.dataset_fields import (
|
||||
DatasetMetadataActionResponse,
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
@ -31,13 +27,7 @@ register_schema_models(
|
||||
DocumentMetadataOperation,
|
||||
MetadataOperationData,
|
||||
)
|
||||
register_response_schema_models(
|
||||
service_api_ns,
|
||||
DatasetMetadataActionResponse,
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
register_response_schema_models(service_api_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
@ -53,9 +43,6 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
201, "Metadata created successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create metadata for a dataset."""
|
||||
@ -68,7 +55,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
|
||||
return dump_response(DatasetMetadataResponse, metadata), 201
|
||||
return marshal(metadata, dataset_metadata_fields), 201
|
||||
|
||||
@service_api_ns.doc("get_dataset_metadata")
|
||||
@service_api_ns.doc(description="Get all metadata for a dataset")
|
||||
@ -80,17 +67,13 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Metadata retrieved successfully", service_api_ns.models[DatasetMetadataListResponse.__name__]
|
||||
)
|
||||
def get(self, tenant_id, dataset_id):
|
||||
"""Get all metadata for a dataset."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
metadata = MetadataService.get_dataset_metadatas(dataset)
|
||||
return dump_response(DatasetMetadataListResponse, metadata), 200
|
||||
return MetadataService.get_dataset_metadatas(dataset), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
@ -106,9 +89,6 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
404: "Dataset or metadata not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Metadata updated successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id, dataset_id, metadata_id):
|
||||
"""Update metadata name."""
|
||||
@ -122,7 +102,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
|
||||
return dump_response(DatasetMetadataResponse, metadata), 200
|
||||
return marshal(metadata, dataset_metadata_fields), 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset_metadata")
|
||||
@service_api_ns.doc(description="Delete metadata")
|
||||
@ -134,7 +114,6 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
404: "Dataset or metadata not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(204, "Metadata deleted successfully")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id, metadata_id):
|
||||
"""Delete metadata."""
|
||||
@ -159,15 +138,10 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Built-in fields retrieved successfully",
|
||||
service_api_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id, dataset_id):
|
||||
"""Get all built-in metadata fields."""
|
||||
built_in_fields = MetadataService.get_built_in_fields()
|
||||
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
|
||||
return {"fields": built_in_fields}, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
@ -183,7 +157,9 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Action completed successfully", service_api_ns.models[DatasetMetadataActionResponse.__name__]
|
||||
200,
|
||||
"Action completed successfully",
|
||||
service_api_ns.models[SimpleResultResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
|
||||
@ -199,7 +175,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
@ -218,7 +194,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Documents metadata updated successfully",
|
||||
service_api_ns.models[DatasetMetadataActionResponse.__name__],
|
||||
service_api_ns.models[SimpleResultResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
@ -233,4 +209,4 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -11,7 +11,7 @@ register_response_schema_models(service_api_ns, IndexInfoResponse)
|
||||
@service_api_ns.route("/")
|
||||
class IndexApi(Resource):
|
||||
@service_api_ns.response(200, "Success", service_api_ns.models[IndexInfoResponse.__name__])
|
||||
def get(self) -> dict[str, str]:
|
||||
def get(self):
|
||||
return {
|
||||
"welcome": "Dify OpenAPI",
|
||||
"api_version": "v1",
|
||||
|
||||
@ -136,7 +136,7 @@ class ConversationApi(WebApiResource):
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
return "", 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/name")
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
import services
|
||||
@ -58,7 +59,7 @@ class RemoteFileInfoApi(WebApiResource):
|
||||
Raises:
|
||||
HTTPException: If the remote file cannot be accessed
|
||||
"""
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
# failed back to get method
|
||||
|
||||
@ -112,4 +112,4 @@ class SavedMessageApi(WebApiResource):
|
||||
|
||||
SavedMessageService.delete(app_model, end_user, message_id)
|
||||
|
||||
return "", 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import json
|
||||
from typing import override
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from graphon.file import file_manager
|
||||
@ -67,7 +66,6 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
@override
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import json
|
||||
from typing import override
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
@ -52,7 +51,6 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
||||
|
||||
return historic_prompt
|
||||
|
||||
@override
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
||||
@ -23,7 +23,6 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
self.declaration = declaration
|
||||
self.meta_version = meta_version
|
||||
|
||||
@override
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
return self.declaration.parameters
|
||||
|
||||
@ -35,7 +34,6 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
|
||||
return params
|
||||
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
|
||||
@ -55,7 +55,6 @@ from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVarLoader,
|
||||
WorkflowDraftVariableService,
|
||||
@ -146,15 +145,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation = None
|
||||
conversation_id = args.get("conversation_id")
|
||||
if conversation_id:
|
||||
try:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
conversation = None
|
||||
else:
|
||||
raise
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
|
||||
# parse files
|
||||
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast, override
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -20,7 +20,6 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -60,7 +59,6 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -78,7 +76,6 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, Any, None]:
|
||||
@ -110,7 +107,6 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, Any, None]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast, override
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
@ -16,7 +16,6 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -38,7 +37,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -56,7 +54,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -88,7 +85,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast, override
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
@ -16,7 +16,6 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -38,7 +37,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -56,7 +54,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -88,7 +85,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast, override
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
@ -16,7 +16,6 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -37,7 +36,6 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -55,7 +53,6 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -86,7 +83,6 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol, override
|
||||
from typing import Any, Protocol
|
||||
|
||||
from graphon.enums import NodeType
|
||||
|
||||
@ -29,6 +29,5 @@ class DraftVariableSaverFactory(Protocol):
|
||||
|
||||
|
||||
class NoopDraftVariableSaver(DraftVariableSaver):
|
||||
@override
|
||||
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
|
||||
return None
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -23,7 +21,6 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
||||
self._app_mode = app_mode
|
||||
self._message_id = str(message_id)
|
||||
|
||||
@override
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -21,7 +19,6 @@ class PipelineQueueManager(AppQueueManager):
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
@override
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -21,7 +19,6 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
@override
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast, override
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -18,7 +18,6 @@ class WorkflowAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -30,7 +29,6 @@ class WorkflowAppGenerateResponseConverter(
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -42,7 +40,6 @@ class WorkflowAppGenerateResponseConverter(
|
||||
return cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -76,7 +73,6 @@ class WorkflowAppGenerateResponseConverter(
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import override
|
||||
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -32,11 +31,9 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
) -> None:
|
||||
self._scope_getter = scope_getter
|
||||
|
||||
@override
|
||||
def current_scope(self) -> FileAccessScope | None:
|
||||
return self._scope_getter()
|
||||
|
||||
@override
|
||||
def apply_upload_file_filters(
|
||||
self,
|
||||
stmt: Select[tuple[UploadFile]],
|
||||
@ -65,7 +62,6 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def apply_tool_file_filters(
|
||||
self,
|
||||
stmt: Select[tuple[ToolFile]],
|
||||
@ -82,7 +78,6 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
|
||||
return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id)
|
||||
|
||||
@override
|
||||
def get_upload_file(
|
||||
self,
|
||||
*,
|
||||
@ -100,7 +95,6 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
)
|
||||
return session.scalar(stmt)
|
||||
|
||||
@override
|
||||
def get_tool_file(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -8,7 +8,6 @@ scope updates that matter to chat applications.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
|
||||
@ -24,11 +23,9 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
super().__init__()
|
||||
self._conversation_variable_updater = conversation_variable_updater
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if not isinstance(event, NodeRunVariableUpdatedEvent):
|
||||
return
|
||||
@ -47,6 +44,5 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
|
||||
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Literal, Self, override
|
||||
from typing import Annotated, Literal, Self
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Engine
|
||||
@ -83,7 +83,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def _get_repo(self) -> APIWorkflowRunRepository:
|
||||
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""
|
||||
Called when graph execution starts.
|
||||
@ -93,7 +92,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
"""
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
@ -134,7 +132,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
pause_reasons=event.reasons,
|
||||
)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when graph execution ends.
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import override
|
||||
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
|
||||
|
||||
@ -11,11 +9,9 @@ class SuspendLayer(GraphEngineLayer):
|
||||
super().__init__()
|
||||
self._paused = False
|
||||
|
||||
@override
|
||||
def on_graph_start(self):
|
||||
self._paused = False
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle the paused event, stash runtime state into storage and wait for resume.
|
||||
@ -23,7 +19,6 @@ class SuspendLayer(GraphEngineLayer):
|
||||
if isinstance(event, GraphRunPausedEvent):
|
||||
self._paused = True
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None):
|
||||
""" """
|
||||
self._paused = False
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import ClassVar, override
|
||||
from typing import ClassVar
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
|
||||
|
||||
@ -63,7 +63,6 @@ class TimeSliceLayer(GraphEngineLayer):
|
||||
except Exception:
|
||||
logger.exception("scheduler error during check if the workflow need to be suspended")
|
||||
|
||||
@override
|
||||
def on_graph_start(self):
|
||||
"""
|
||||
Start timer to check if the workflow need to be suspended.
|
||||
@ -79,11 +78,9 @@ class TimeSliceLayer(GraphEngineLayer):
|
||||
id=self.schedule_id,
|
||||
)
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
self.stopped = True
|
||||
# remove the scheduler
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, ClassVar, override
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
@ -37,11 +37,9 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
self.start_time = start_time
|
||||
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
|
||||
|
||||
@override
|
||||
def on_graph_start(self):
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Update trigger log with success or failure.
|
||||
@ -84,6 +82,5 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
|
||||
@ -7,7 +7,7 @@ import os
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Literal, override
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
|
||||
@ -40,19 +40,15 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
self._file_access_controller = file_access_controller
|
||||
|
||||
@property
|
||||
@override
|
||||
def multimodal_send_format(self) -> str:
|
||||
return dify_config.MULTIMODAL_SEND_FORMAT
|
||||
|
||||
@override
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
|
||||
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
|
||||
@override
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
|
||||
return storage.load(path, stream=stream)
|
||||
|
||||
@override
|
||||
def load_file_bytes(self, *, file: File) -> bytes:
|
||||
storage_key = self._resolve_storage_key(file=file)
|
||||
data = storage.load(storage_key, stream=False)
|
||||
@ -60,7 +56,6 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
raise ValueError(f"file {storage_key} is not a bytes object")
|
||||
return data
|
||||
|
||||
@override
|
||||
def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
|
||||
if file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return file.remote_url
|
||||
@ -91,7 +86,6 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
)
|
||||
return None
|
||||
|
||||
@override
|
||||
def resolve_upload_file_url(
|
||||
self,
|
||||
*,
|
||||
@ -107,12 +101,10 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
query["as_attachment"] = "true"
|
||||
return f"{url}?{urllib.parse.urlencode(query)}"
|
||||
|
||||
@override
|
||||
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
|
||||
self._assert_tool_file_access(tool_file_id=tool_file_id)
|
||||
return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
|
||||
|
||||
@override
|
||||
def verify_preview_signature(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -12,7 +12,7 @@ state.
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Union, override
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.helper.trace_id_helper import ParentTraceContext
|
||||
@ -98,14 +98,12 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
# ------------------------------------------------------------------
|
||||
# GraphEngineLayer lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
self._workflow_execution = None
|
||||
self._node_execution_cache.clear()
|
||||
self._node_snapshots.clear()
|
||||
self._node_sequence = 0
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
match event:
|
||||
case GraphRunStartedEvent():
|
||||
@ -133,7 +131,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
case NodeRunPauseRequestedEvent():
|
||||
self._handle_node_pause_requested(event)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
return
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
@ -24,10 +22,8 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
@override
|
||||
def get_icon_url(self, tenant_id: str) -> str:
|
||||
return self.icon
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -19,14 +19,12 @@ class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderContro
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -67,6 +67,5 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
@ -19,7 +17,6 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -68,6 +67,5 @@ class OnlineDriveDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
@ -19,7 +17,6 @@ class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderCont
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -47,6 +47,5 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
@ -23,7 +21,6 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -6,7 +6,7 @@ import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from sqlalchemy import func, select
|
||||
@ -1889,7 +1889,6 @@ class ProviderConfigurations(BaseModel):
|
||||
key = str(ModelProviderID(key))
|
||||
return key in self.configurations
|
||||
|
||||
@override
|
||||
def __iter__(self):
|
||||
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
|
||||
yield from self.configurations.items()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, TypedDict, override
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -29,7 +29,6 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
"""the unique name of external data tool"""
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -51,7 +50,6 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
if not api_based_extension:
|
||||
raise ValueError("api_based_extension_id is invalid")
|
||||
|
||||
@override
|
||||
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from textwrap import dedent
|
||||
from typing import override
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
@ -7,12 +6,10 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
|
||||
class JavascriptCodeProvider(CodeNodeProvider):
|
||||
@staticmethod
|
||||
@override
|
||||
def get_language() -> str:
|
||||
return CodeLanguage.JAVASCRIPT
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_default_code(cls) -> str:
|
||||
return dedent(
|
||||
"""
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
from textwrap import dedent
|
||||
from typing import override
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
|
||||
class NodeJsTemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
@override
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f""" {cls._code_placeholder}
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from textwrap import dedent
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
@ -10,7 +10,6 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
_template_b64_placeholder: str = "{{template_b64}}"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def transform_response(cls, response: str):
|
||||
"""
|
||||
Transform response to dict
|
||||
@ -20,7 +19,6 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
return {"result": cls.extract_result_str_from_response(response)}
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Override base class to use base64 encoding for template code.
|
||||
@ -36,7 +34,6 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
return script
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f"""
|
||||
import jinja2
|
||||
@ -64,7 +61,6 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
return runner_script
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_preload_script(cls) -> str:
|
||||
preload_script = dedent("""
|
||||
import jinja2
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from textwrap import dedent
|
||||
from typing import override
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
@ -7,12 +6,10 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
|
||||
class Python3CodeProvider(CodeNodeProvider):
|
||||
@staticmethod
|
||||
@override
|
||||
def get_language() -> str:
|
||||
return CodeLanguage.PYTHON3
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_default_code(cls) -> str:
|
||||
return dedent(
|
||||
"""
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
from textwrap import dedent
|
||||
from typing import override
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
|
||||
class Python3TemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
@override
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f""" {cls._code_placeholder}
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
@ -47,7 +47,6 @@ class SingletonProviderCredentialsCache(ProviderCredentialsCache):
|
||||
provider_identity=provider_identity,
|
||||
)
|
||||
|
||||
@override
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_type = kwargs["provider_type"]
|
||||
@ -62,7 +61,6 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
||||
def __init__(self, tenant_id: str, provider: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
|
||||
|
||||
@override
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider = kwargs["provider"]
|
||||
|
||||
@ -43,16 +43,13 @@ request_error = httpx.RequestError
|
||||
max_retries_exceeded_error = MaxRetriesExceededError
|
||||
|
||||
|
||||
def _create_proxy_mounts(verify: bool) -> dict[str, httpx.HTTPTransport]:
|
||||
"""Build per-scheme proxy transports with the same TLS policy as the SSRF client."""
|
||||
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
|
||||
return {
|
||||
"http://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTP_URL,
|
||||
verify=verify,
|
||||
),
|
||||
"https://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTPS_URL,
|
||||
verify=verify,
|
||||
),
|
||||
}
|
||||
|
||||
@ -67,7 +64,7 @@ def _build_ssrf_client(verify: bool) -> httpx.Client:
|
||||
|
||||
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
return httpx.Client(
|
||||
mounts=_create_proxy_mounts(verify=verify),
|
||||
mounts=_create_proxy_mounts(),
|
||||
verify=verify,
|
||||
limits=_SSRF_CLIENT_LIMITS,
|
||||
)
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
import flask
|
||||
|
||||
@ -16,7 +15,6 @@ class TraceContextFilter(logging.Filter):
|
||||
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
|
||||
"""
|
||||
|
||||
@override
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
@ -56,7 +54,6 @@ class IdentityContextFilter(logging.Filter):
|
||||
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
|
||||
"""
|
||||
|
||||
@override
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
identity = self._extract_identity()
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, NotRequired, TypedDict, override
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
import orjson
|
||||
|
||||
@ -58,7 +58,6 @@ class StructuredJSONFormatter(logging.Formatter):
|
||||
super().__init__()
|
||||
self._service_name = service_name or dify_config.APPLICATION_NAME
|
||||
|
||||
@override
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
log_dict = self._build_log_dict(record)
|
||||
try:
|
||||
|
||||
@ -7,7 +7,7 @@ authentication failures and retries operations after refreshing tokens.
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -159,7 +159,6 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
# Reset retry flag after operation completes
|
||||
self._has_retried = False
|
||||
|
||||
@override
|
||||
def __enter__(self):
|
||||
"""Enter the context manager with retry support."""
|
||||
|
||||
@ -169,7 +168,6 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
|
||||
return self._execute_with_retry(initialize_with_retry)
|
||||
|
||||
@override
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""
|
||||
List available tools from the MCP server with auth retry.
|
||||
@ -182,7 +180,6 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
"""
|
||||
return self._execute_with_retry(super().list_tools)
|
||||
|
||||
@override
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||
"""
|
||||
Invoke a tool on the MCP server with auth retry.
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import queue
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol, override
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
@ -159,7 +159,6 @@ class ClientSession(
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
@override
|
||||
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
|
||||
"""Send a progress notification."""
|
||||
self.send_notification(
|
||||
@ -327,7 +326,6 @@ class ClientSession(
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
@ -353,7 +351,6 @@ class ClientSession(
|
||||
with responder:
|
||||
return responder.respond(types.ClientResult(root=types.EmptyResult()))
|
||||
|
||||
@override
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
@ -361,7 +358,6 @@ class ClientSession(
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
self._message_handler(req)
|
||||
|
||||
@override
|
||||
def _received_notification(self, notification: types.ServerNotification):
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
@ -25,7 +25,6 @@ class ApiModeration(Moderation):
|
||||
name: str = "api"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -44,7 +43,6 @@ class ApiModeration(Moderation):
|
||||
if not extension:
|
||||
raise ValueError("API-based Extension not found. Please check it again.")
|
||||
|
||||
@override
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
@ -61,7 +59,6 @@ class ApiModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
@override
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
|
||||
@ -8,7 +8,6 @@ class KeywordsModeration(Moderation):
|
||||
name: str = "keywords"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -29,7 +28,6 @@ class KeywordsModeration(Moderation):
|
||||
if len(keywords_row_len) > 100:
|
||||
raise ValueError("the number of rows for the keywords must be less than 100")
|
||||
|
||||
@override
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
@ -51,7 +49,6 @@ class KeywordsModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
@override
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
@ -9,7 +9,6 @@ class OpenAIModeration(Moderation):
|
||||
name: str = "openai_moderation"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -20,7 +19,6 @@ class OpenAIModeration(Moderation):
|
||||
"""
|
||||
cls._validate_inputs_and_outputs_config(config, True)
|
||||
|
||||
@override
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
@ -38,7 +36,6 @@ class OpenAIModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
@override
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import override
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
@ -12,7 +11,6 @@ class PluginDaemonError(Exception):
|
||||
def __init__(self, description: str):
|
||||
self.description = description
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
# returns the class name and description
|
||||
return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}"
|
||||
|
||||
@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from typing import IO, Any, Literal, cast, overload, override
|
||||
from threading import Lock
|
||||
from typing import IO, Any, Literal, cast, overload
|
||||
|
||||
from pydantic import ValidationError
|
||||
from redis import RedisError
|
||||
@ -12,9 +13,9 @@ from configs import dify_config
|
||||
from core.llm_generator.output_parser.structured_output import (
|
||||
invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
@ -100,38 +101,35 @@ class _PluginStructuredOutputModelInstance:
|
||||
|
||||
|
||||
class PluginModelRuntime(ModelRuntime):
|
||||
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope.
|
||||
|
||||
Provider discovery goes through ``PluginService`` so the plugin lifecycle
|
||||
methods and provider reads share one tenant-scoped cache owner.
|
||||
"""
|
||||
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str | None
|
||||
client: PluginModelClient
|
||||
_plugin_service: type[PluginService]
|
||||
_provider_entities: tuple[ProviderEntity, ...] | None
|
||||
_provider_entities_lock: Lock
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str | None,
|
||||
client: PluginModelClient,
|
||||
plugin_service: type[PluginService],
|
||||
) -> None:
|
||||
def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
|
||||
if client is None:
|
||||
raise ValueError("client is required.")
|
||||
if plugin_service is None:
|
||||
raise ValueError("plugin_service is required.")
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.client = client
|
||||
self._plugin_service = plugin_service
|
||||
self._provider_entities = None
|
||||
self._provider_entities_lock = Lock()
|
||||
|
||||
@override
|
||||
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
|
||||
return self._plugin_service.fetch_plugin_model_providers(tenant_id=self.tenant_id, client=self.client)
|
||||
if self._provider_entities is not None:
|
||||
return self._provider_entities
|
||||
|
||||
with self._provider_entities_lock:
|
||||
if self._provider_entities is None:
|
||||
self._provider_entities = tuple(
|
||||
self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
|
||||
)
|
||||
|
||||
return self._provider_entities
|
||||
|
||||
@override
|
||||
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
|
||||
provider_schema = self._get_provider_schema(provider)
|
||||
|
||||
@ -174,7 +172,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
mime_type = image_mime_types.get(extension, "image/png")
|
||||
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
|
||||
|
||||
@override
|
||||
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
self.client.validate_provider_credentials(
|
||||
@ -185,7 +182,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
@override
|
||||
def validate_model_credentials(
|
||||
self,
|
||||
*,
|
||||
@ -205,7 +201,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_model_schema(
|
||||
self,
|
||||
*,
|
||||
@ -299,7 +294,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunk, None, None]: ...
|
||||
|
||||
@override
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@ -363,7 +357,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
@override
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
@ -403,7 +396,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
*,
|
||||
@ -430,7 +422,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
tools=list(tools) if tools else None,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_text_embedding(
|
||||
self,
|
||||
*,
|
||||
@ -452,7 +443,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_multimodal_embedding(
|
||||
self,
|
||||
*,
|
||||
@ -474,7 +464,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_text_embedding_num_tokens(
|
||||
self,
|
||||
*,
|
||||
@ -494,7 +483,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_rerank(
|
||||
self,
|
||||
*,
|
||||
@ -520,7 +508,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
*,
|
||||
@ -546,7 +533,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_tts(
|
||||
self,
|
||||
*,
|
||||
@ -568,7 +554,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_tts_model_voices(
|
||||
self,
|
||||
*,
|
||||
@ -588,7 +573,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
language=language,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_speech_to_text(
|
||||
self,
|
||||
*,
|
||||
@ -608,7 +592,6 @@ class PluginModelRuntime(ModelRuntime):
|
||||
file=file,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_moderation(
|
||||
self,
|
||||
*,
|
||||
@ -628,6 +611,34 @@ class PluginModelRuntime(ModelRuntime):
|
||||
text=text,
|
||||
)
|
||||
|
||||
def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
|
||||
"""
|
||||
Expose a bare provider alias only for the canonical provider mapping.
|
||||
|
||||
Multiple plugins can publish the same short provider slug. If every
|
||||
provider entity keeps that slug in ``provider_name``, callers that still
|
||||
resolve by short name become order-dependent. Restrict the alias to the
|
||||
provider selected by ``ModelProviderID`` so legacy short-name lookups
|
||||
remain deterministic while the runtime surface stays canonical.
|
||||
"""
|
||||
try:
|
||||
canonical_provider_id = ModelProviderID(provider.provider)
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
if canonical_provider_id.plugin_id != provider.plugin_id:
|
||||
return ""
|
||||
if canonical_provider_id.provider_name != provider.provider:
|
||||
return ""
|
||||
|
||||
return provider.provider
|
||||
|
||||
def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
|
||||
declaration = provider.declaration.model_copy(deep=True)
|
||||
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
|
||||
declaration.provider_name = self._get_provider_short_name_alias(provider)
|
||||
return declaration
|
||||
|
||||
def _get_provider_schema(self, provider: str) -> ProviderEntity:
|
||||
providers = self.fetch_model_providers()
|
||||
provider_entity = next((item for item in providers if item.provider == provider), None)
|
||||
|
||||
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
@ -118,7 +117,6 @@ def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
client=PluginModelClient(),
|
||||
plugin_service=PluginService,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, TypedDict, override
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel
|
||||
@ -29,7 +29,6 @@ class Jieba(BaseKeyword):
|
||||
super().__init__(dataset)
|
||||
self._config = KeywordTableConfig()
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -49,7 +48,6 @@ class Jieba(BaseKeyword):
|
||||
|
||||
return self
|
||||
|
||||
@override
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -74,14 +72,12 @@ class Jieba(BaseKeyword):
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
if keyword_table is None:
|
||||
return False
|
||||
return id in set.union(*keyword_table.values())
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -91,7 +87,6 @@ class Jieba(BaseKeyword):
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
@override
|
||||
def search(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
|
||||
@ -127,7 +122,6 @@ class Jieba(BaseKeyword):
|
||||
|
||||
return documents
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user