mirror of
https://github.com/langgenius/dify.git
synced 2026-05-20 08:46:57 +08:00
Compare commits
175 Commits
chore/remo
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
| 11c652f146 | |||
| 05408af8a1 | |||
| d3ae074456 | |||
| 0b48a7e991 | |||
| 809f513ccb | |||
| d9e90d0fa0 | |||
| d1417bbe4b | |||
| 2565637e36 | |||
| cae9923e5a | |||
| a328bbbced | |||
| 5276eb689b | |||
| 4b2badb6f2 | |||
| 34a89416f7 | |||
| a13ab76002 | |||
| b04b4449db | |||
| 674cdc3521 | |||
| 2031d31ee8 | |||
| 04d62867af | |||
| 7f392b6950 | |||
| b0a3399774 | |||
| 2d5186fb28 | |||
| 06f076e0ff | |||
| 5b79f7e99d | |||
| 1cee1a25b6 | |||
| c0f237bf35 | |||
| 75d7fc0526 | |||
| c057b5c5ff | |||
| 5468c4ec96 | |||
| f4c02e4c6b | |||
| 9dc95eeb20 | |||
| 76bba64b79 | |||
| 59e96fbb2a | |||
| 06ea0f7ac2 | |||
| 730a0bef9e | |||
| 2eb37caf2e | |||
| 7e8147295b | |||
| c07686928a | |||
| d1238180ed | |||
| 969760364d | |||
| ceabfeb3a7 | |||
| c407f40e0d | |||
| 28818f2e2a | |||
| e2c52c9b0f | |||
| 1925d58369 | |||
| b79fc5d6b4 | |||
| 6649e4025e | |||
| b96f372f45 | |||
| 127fbf2c9a | |||
| 3c70d28064 | |||
| cd4d6f8a22 | |||
| 9d0906c684 | |||
| 41b6f894c0 | |||
| e7e6fe8813 | |||
| c0bdd6792f | |||
| 27b084c4d4 | |||
| 3f7a68fc77 | |||
| a252fbddfa | |||
| ff02636a4b | |||
| 63946d829e | |||
| cdcfd2ef2c | |||
| b04a3851cc | |||
| b41338cd08 | |||
| 28153df4d3 | |||
| 3bc3386535 | |||
| 7654f14241 | |||
| 194b54bae4 | |||
| 0e16d36edb | |||
| 432a6412a3 | |||
| 55d05fe52d | |||
| 0d500e6965 | |||
| 5798610f27 | |||
| a35b28dbef | |||
| 1a4288c811 | |||
| 9dc32f2318 | |||
| 7210f856c9 | |||
| ebcc1200a3 | |||
| e660d7af38 | |||
| d9ccfcbc6e | |||
| a9bcec013f | |||
| aeb7687e2c | |||
| 9355d36718 | |||
| a03ee828a3 | |||
| 7066372892 | |||
| 55f95dbc36 | |||
| 8b40de3c4e | |||
| af4b9bfa8f | |||
| b9e3130388 | |||
| 12d33652b6 | |||
| fe8cf2aff4 | |||
| d1d190374d | |||
| e1be4e6aa8 | |||
| 301a470e7a | |||
| 91251ad5a5 | |||
| 3f6644a615 | |||
| 5edc682c4a | |||
| 13c00ecfc4 | |||
| 9d545144ce | |||
| 2afa39cdcb | |||
| bb1c883be4 | |||
| 03861bcee3 | |||
| c34fc429ae | |||
| d110112863 | |||
| 934a20e745 | |||
| 7e56a244a8 | |||
| 6facd9360c | |||
| a18d7f51eb | |||
| 680ef077ae | |||
| c26be9d3f4 | |||
| 51a8f79d67 | |||
| bb73776339 | |||
| 9424bf60b0 | |||
| cbedcd2882 | |||
| 1a93af5cd0 | |||
| cd90d7ffc1 | |||
| 4bb987eca3 | |||
| 4fd4615c56 | |||
| c7d30bf09a | |||
| 59dab7deac | |||
| a60cb3b800 | |||
| 6164408da1 | |||
| 7fc40e6c9e | |||
| d625ac0bf1 | |||
| 1082f488a1 | |||
| f1c4c1a5ff | |||
| dd1cdbbd41 | |||
| 74a04afe27 | |||
| b108ea42f6 | |||
| 1aa6188b7d | |||
| bd0d10ac5c | |||
| 2162ea6a68 | |||
| 153064bbd4 | |||
| a643b05368 | |||
| 279b66bc7f | |||
| e134c1e0d5 | |||
| 9127209dd5 | |||
| a2ee151e48 | |||
| 9e3e616391 | |||
| 837b5cad86 | |||
| 1a011dc14a | |||
| bf117dd0c8 | |||
| 1e6dc62470 | |||
| 0b70eec695 | |||
| e8dc706414 | |||
| 9a2bea9287 | |||
| b95e6f6a7a | |||
| b99ba74aa4 | |||
| 7b5c371b9d | |||
| c67ce6f66d | |||
| e48d7bb097 | |||
| 24ea21db25 | |||
| 8581a68174 | |||
| f720a3bed2 | |||
| 4a56763d2f | |||
| 861f73267c | |||
| 1efd365b62 | |||
| 65c36a51ef | |||
| 19476109da | |||
| f3eb3ab4dd | |||
| 2c9e30426d | |||
| 2bb1f0906b | |||
| d5ad6aedc0 | |||
| 5ebeb34feb | |||
| c5ac191a79 | |||
| 140ad6ba4e | |||
| e03eb3a76c | |||
| 38a419d073 | |||
| c74cbb68da | |||
| 271019006e | |||
| 19bf36a716 | |||
| 48d27e250b | |||
| d06b5529b3 | |||
| 8132c444dc | |||
| cb0356e9d7 | |||
| 4d80892d7b | |||
| af754f497a |
@ -63,7 +63,7 @@ pnpm analyze-component <path> --json
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Complex state logic in component
|
||||
const Configuration: FC = () => {
|
||||
function Configuration() {
|
||||
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||
const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
|
||||
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||
@ -85,7 +85,7 @@ export const useModelConfig = (appId: string) => {
|
||||
}
|
||||
|
||||
// Component becomes cleaner
|
||||
const Configuration: FC = () => {
|
||||
function Configuration() {
|
||||
const { modelConfig, setModelConfig } = useModelConfig(appId)
|
||||
return <div>...</div>
|
||||
}
|
||||
@ -189,8 +189,6 @@ const Template = useMemo(() => {
|
||||
|
||||
**Dify Convention**:
|
||||
- This skill is for component decomposition, not query/mutation design.
|
||||
- When refactoring data fetching, follow `web/AGENTS.md`.
|
||||
- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
|
||||
- Do not introduce deprecated `useInvalid` / `useReset`.
|
||||
- Do not add thin passthrough `useQuery` wrappers during refactoring; only extract a custom hook when it truly orchestrates multiple queries/mutations or shared derived state.
|
||||
|
||||
|
||||
@ -60,8 +60,10 @@ const Template = useMemo(() => {
|
||||
**After** (complexity: ~3):
|
||||
|
||||
```typescript
|
||||
import type { ComponentType } from 'react'
|
||||
|
||||
// Define lookup table outside component
|
||||
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
|
||||
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, ComponentType<TemplateProps>>> = {
|
||||
[AppModeEnum.CHAT]: {
|
||||
[LanguagesSupported[1]]: TemplateChatZh,
|
||||
[LanguagesSupported[7]]: TemplateChatJa,
|
||||
|
||||
@ -65,10 +65,10 @@ interface ConfigurationHeaderProps {
|
||||
onPublish: () => void
|
||||
}
|
||||
|
||||
const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
|
||||
function ConfigurationHeader({
|
||||
isAdvancedMode,
|
||||
onPublish,
|
||||
}) => {
|
||||
}: ConfigurationHeaderProps) {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
@ -136,7 +136,7 @@ const AppInfo = () => {
|
||||
}
|
||||
|
||||
// ✅ After: Separate view components
|
||||
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
|
||||
function AppInfoExpanded({ appDetail, onAction }: AppInfoViewProps) {
|
||||
return (
|
||||
<div className="expanded">
|
||||
{/* Clean, focused expanded view */}
|
||||
@ -144,7 +144,7 @@ const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
|
||||
)
|
||||
}
|
||||
|
||||
const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
|
||||
function AppInfoCollapsed({ appDetail, onAction }: AppInfoViewProps) {
|
||||
return (
|
||||
<div className="collapsed">
|
||||
{/* Clean, focused collapsed view */}
|
||||
@ -203,12 +203,12 @@ interface AppInfoModalsProps {
|
||||
onSuccess: () => void
|
||||
}
|
||||
|
||||
const AppInfoModals: FC<AppInfoModalsProps> = ({
|
||||
function AppInfoModals({
|
||||
appDetail,
|
||||
activeModal,
|
||||
onClose,
|
||||
onSuccess,
|
||||
}) => {
|
||||
}: AppInfoModalsProps) {
|
||||
const handleEdit = async (data) => { /* logic */ }
|
||||
const handleDuplicate = async (data) => { /* logic */ }
|
||||
const handleDelete = async () => { /* logic */ }
|
||||
@ -296,7 +296,7 @@ interface OperationItemProps {
|
||||
onAction: (id: string) => void
|
||||
}
|
||||
|
||||
const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
|
||||
function OperationItem({ operation, onAction }: OperationItemProps) {
|
||||
return (
|
||||
<div className="operation-item">
|
||||
<span className="icon">{operation.icon}</span>
|
||||
@ -435,7 +435,7 @@ interface ChildProps {
|
||||
onSubmit: () => void
|
||||
}
|
||||
|
||||
const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
|
||||
function Child({ value, onChange, onSubmit }: ChildProps) {
|
||||
return (
|
||||
<div>
|
||||
<input value={value} onChange={e => onChange(e.target.value)} />
|
||||
|
||||
@ -112,13 +112,13 @@ export const useModelConfig = ({
|
||||
|
||||
```typescript
|
||||
// Before: 50+ lines of state management
|
||||
const Configuration: FC = () => {
|
||||
function Configuration() {
|
||||
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||
// ... lots of related state and effects
|
||||
}
|
||||
|
||||
// After: Clean component
|
||||
const Configuration: FC = () => {
|
||||
function Configuration() {
|
||||
const {
|
||||
modelConfig,
|
||||
setModelConfig,
|
||||
@ -159,8 +159,6 @@ const Configuration: FC = () => {
|
||||
|
||||
When hook extraction touches query or mutation code, do not use this reference as the source of truth for data-layer patterns.
|
||||
|
||||
- Follow `web/AGENTS.md` first.
|
||||
- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
|
||||
- Do not introduce deprecated `useInvalid` / `useReset`.
|
||||
- Do not extract thin passthrough `useQuery` hooks; only extract orchestration hooks.
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS
|
||||
- `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
|
||||
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
|
||||
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
|
||||
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.
|
||||
5. Re-check official Playwright or Cucumber docs with the available documentation tools before introducing a new framework pattern.
|
||||
|
||||
## Local Rules
|
||||
|
||||
|
||||
@ -9,18 +9,18 @@ Category: Performance
|
||||
|
||||
When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
|
||||
|
||||
## Complex prop memoization
|
||||
## Complex prop stability
|
||||
|
||||
IsUrgent: True
|
||||
IsUrgent: False
|
||||
Category: Performance
|
||||
|
||||
### Description
|
||||
|
||||
Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders.
|
||||
Only require stable object, array, or map props when there is a clear reason: the child is memoized, the value participates in effect/query dependencies, the value is part of a stable-reference API contract, or profiling/local behavior shows avoidable re-renders. Do not request `useMemo` for every inline object by default; `how-to-write-component` treats memoization as a targeted optimization.
|
||||
|
||||
Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.
|
||||
|
||||
Wrong:
|
||||
Risky:
|
||||
|
||||
```tsx
|
||||
<HeavyComp
|
||||
@ -31,7 +31,7 @@ Wrong:
|
||||
/>
|
||||
```
|
||||
|
||||
Right:
|
||||
Better when stable identity matters:
|
||||
|
||||
```tsx
|
||||
const config = useMemo(() => ({
|
||||
|
||||
@ -1,46 +0,0 @@
|
||||
---
|
||||
name: frontend-query-mutation
|
||||
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions()/mutationOptions() directly or extract a helper or use-* hook, configuring oRPC experimental_defaults/default options, handling conditional queries, cache updates/invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
|
||||
---
|
||||
|
||||
# Frontend Query & Mutation
|
||||
|
||||
## Intent
|
||||
|
||||
- Keep contract as the single source of truth in `web/contract/*`.
|
||||
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
|
||||
- Keep default cache behavior with `consoleQuery`/`marketplaceQuery` setup, and keep business orchestration in feature vertical hooks when direct contract calls are not enough.
|
||||
- Treat `web/service/use-*` query or mutation wrappers as legacy migration targets, not the preferred destination.
|
||||
- Keep abstractions minimal to preserve TypeScript inference.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Identify the change surface.
|
||||
- Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
|
||||
- Read `references/runtime-rules.md` for conditional queries, default options, cache updates/invalidation, error handling, and legacy migrations.
|
||||
- Read both references when a task spans contract shape and runtime behavior.
|
||||
2. Implement the smallest abstraction that fits the task.
|
||||
- Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
|
||||
- Extract a small shared query helper only when multiple call sites share the same extra options.
|
||||
- Create or keep feature hooks only for real orchestration or shared domain behavior.
|
||||
- When touching thin `web/service/use-*` wrappers, migrate them away when feasible.
|
||||
3. Preserve Dify conventions.
|
||||
- Keep contract inputs in `{ params, query?, body? }` shape.
|
||||
- Bind default cache updates/invalidation in `createTanstackQueryUtils(...experimental_defaults...)`; use feature hooks only for workflows that cannot be expressed as default operation behavior.
|
||||
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.
|
||||
|
||||
## Files Commonly Touched
|
||||
|
||||
- `web/contract/console/*.ts`
|
||||
- `web/contract/marketplace.ts`
|
||||
- `web/contract/router.ts`
|
||||
- `web/service/client.ts`
|
||||
- legacy `web/service/use-*.ts` files when migrating wrappers away
|
||||
- component and hook call sites using `consoleQuery` or `marketplaceQuery`
|
||||
|
||||
## References
|
||||
|
||||
- Use `references/contract-patterns.md` for contract shape, router registration, query and mutation helpers, and anti-patterns that degrade inference.
|
||||
- Use `references/runtime-rules.md` for conditional queries, invalidation, `mutate` versus `mutateAsync`, and legacy migration rules.
|
||||
|
||||
Treat this skill as the single query and mutation entry point for Dify frontend work. Keep detailed rules in the reference files instead of duplicating them in project docs.
|
||||
@ -1,4 +0,0 @@
|
||||
interface:
|
||||
display_name: "Frontend Query & Mutation"
|
||||
short_description: "Dify TanStack Query, oRPC, and default option patterns"
|
||||
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, oRPC default options, conditional queries, cache updates/invalidation, or legacy query/mutation migrations."
|
||||
@ -1,129 +0,0 @@
|
||||
# Contract Patterns
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- Intent
|
||||
- Minimal structure
|
||||
- Core workflow
|
||||
- Query usage decision rule
|
||||
- Mutation usage decision rule
|
||||
- Thin hook decision rule
|
||||
- Anti-patterns
|
||||
- Contract rules
|
||||
- Type export
|
||||
|
||||
## Intent
|
||||
|
||||
- Keep contract as the single source of truth in `web/contract/*`.
|
||||
- Default query usage to call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
|
||||
- Keep abstractions minimal and preserve TypeScript inference.
|
||||
|
||||
## Minimal Structure
|
||||
|
||||
```text
|
||||
web/contract/
|
||||
├── base.ts
|
||||
├── router.ts
|
||||
├── marketplace.ts
|
||||
└── console/
|
||||
├── billing.ts
|
||||
└── ...other domains
|
||||
web/service/client.ts
|
||||
```
|
||||
|
||||
## Core Workflow
|
||||
|
||||
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`.
|
||||
- Use `base.route({...}).output(type<...>())` as the baseline.
|
||||
- Add `.input(type<...>())` only when the request has `params`, `query`, or `body`.
|
||||
- For `GET` without input, omit `.input(...)`; do not use `.input(type<unknown>())`.
|
||||
2. Register contract in `web/contract/router.ts`.
|
||||
- Import directly from domain files and nest by API prefix.
|
||||
3. Consume from UI call sites via oRPC query utilities.
|
||||
|
||||
```typescript
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
|
||||
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
|
||||
staleTime: 5 * 60 * 1000,
|
||||
throwOnError: true,
|
||||
select: invoice => invoice.url,
|
||||
}))
|
||||
```
|
||||
|
||||
## Query Usage Decision Rule
|
||||
|
||||
1. Default to direct `*.queryOptions(...)` usage at the call site.
|
||||
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
|
||||
3. Create or keep feature hooks only for orchestration.
|
||||
- Combine multiple queries or mutations.
|
||||
- Share domain-level derived state or invalidation helpers.
|
||||
- Prefer `web/features/{domain}/hooks/*` for feature-owned workflows.
|
||||
4. Treat `web/service/use-{domain}.ts` as legacy.
|
||||
- Do not create new thin service wrappers for oRPC contracts.
|
||||
- When touching existing wrappers, inline direct `consoleQuery` or `marketplaceQuery` consumption when the wrapper is only a passthrough.
|
||||
|
||||
```typescript
|
||||
const invoicesBaseQueryOptions = () =>
|
||||
consoleQuery.billing.invoices.queryOptions({ retry: false })
|
||||
|
||||
const invoiceQuery = useQuery({
|
||||
...invoicesBaseQueryOptions(),
|
||||
throwOnError: true,
|
||||
})
|
||||
```
|
||||
|
||||
## Mutation Usage Decision Rule
|
||||
|
||||
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
|
||||
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.
|
||||
|
||||
```typescript
|
||||
const createTagMutation = useMutation(consoleQuery.tags.create.mutationOptions())
|
||||
```
|
||||
|
||||
## Thin Hook Decision Rule
|
||||
|
||||
Remove thin hooks when they only rename a single oRPC query or mutation helper.
|
||||
Keep hooks when they orchestrate business behavior across multiple operations, own local workflow state, or normalize a feature-specific API.
|
||||
Prefer feature vertical hooks for kept orchestration. Do not move new contract-first wrappers into `web/service/use-*`.
|
||||
|
||||
Use:
|
||||
|
||||
```typescript
|
||||
const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions())
|
||||
```
|
||||
|
||||
Keep:
|
||||
|
||||
```typescript
|
||||
const applyTagBindingsMutation = useApplyTagBindingsMutation()
|
||||
```
|
||||
|
||||
`useApplyTagBindingsMutation` is acceptable because it coordinates bind and unbind requests, computes deltas, and exposes a feature-level workflow rather than a single endpoint passthrough.
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
|
||||
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
|
||||
- Do not create thin `use-*` passthrough hooks for a single endpoint.
|
||||
- Do not create business-layer helpers whose only purpose is to call `consoleQuery.xxx.mutationOptions()` or `queryOptions()`.
|
||||
- Do not introduce new `web/service/use-*` files for oRPC contract passthroughs.
|
||||
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.
|
||||
|
||||
## Contract Rules
|
||||
|
||||
- Input structure: always use `{ params, query?, body? }`.
|
||||
- No-input `GET`: omit `.input(...)`; do not use `.input(type<unknown>())`.
|
||||
- Path params: use `{paramName}` in the path and match it in the `params` object.
|
||||
- Router nesting: group by API prefix, for example `/billing/*` becomes `billing: {}`.
|
||||
- No barrel files: import directly from specific files.
|
||||
- Types: import from `@/types/` and use the `type<T>()` helper.
|
||||
- Mutations: prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults, filtering, and devtools.
|
||||
|
||||
## Type Export
|
||||
|
||||
```typescript
|
||||
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
|
||||
```
|
||||
@ -1,172 +0,0 @@
|
||||
# Runtime Rules
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- Conditional queries
|
||||
- oRPC default options
|
||||
- Cache invalidation
|
||||
- Key API guide
|
||||
- `mutate` vs `mutateAsync`
|
||||
- Legacy migration
|
||||
|
||||
## Conditional Queries
|
||||
|
||||
Prefer contract-shaped `queryOptions(...)`.
|
||||
When required input is missing, prefer `input: skipToken` instead of placeholder params or non-null assertions.
|
||||
Use `enabled` only for extra business gating after the input itself is already valid.
|
||||
|
||||
```typescript
|
||||
import { skipToken, useQuery } from '@tanstack/react-query'
|
||||
|
||||
// Disable the query by skipping input construction.
|
||||
function useAccessMode(appId: string | undefined) {
|
||||
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
|
||||
input: appId
|
||||
? { params: { appId } }
|
||||
: skipToken,
|
||||
}))
|
||||
}
|
||||
|
||||
// Avoid runtime-only guards that bypass type checking.
|
||||
function useBadAccessMode(appId: string | undefined) {
|
||||
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
|
||||
input: { params: { appId: appId! } },
|
||||
enabled: !!appId,
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
## oRPC Default Options
|
||||
|
||||
Use `experimental_defaults` in `createTanstackQueryUtils` when a contract operation should always carry shared TanStack Query behavior, such as default stale time, mutation cache writes, or invalidation.
|
||||
|
||||
Place defaults at the query utility creation point in `web/service/client.ts`:
|
||||
|
||||
```typescript
|
||||
export const consoleQuery = createTanstackQueryUtils(consoleClient, {
|
||||
path: ['console'],
|
||||
experimental_defaults: {
|
||||
tags: {
|
||||
create: {
|
||||
mutationOptions: {
|
||||
onSuccess: (tag, _variables, _result, context) => {
|
||||
context.client.setQueryData(
|
||||
consoleQuery.tags.list.queryKey({
|
||||
input: {
|
||||
query: {
|
||||
type: tag.type,
|
||||
},
|
||||
},
|
||||
}),
|
||||
(oldTags: Tag[] | undefined) => oldTags ? [tag, ...oldTags] : oldTags,
|
||||
)
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
Rules:
|
||||
|
||||
- Keep defaults inline in the `consoleQuery` or `marketplaceQuery` initialization when they need sibling oRPC key builders.
|
||||
- Do not create a wrapper function solely to host `createTanstackQueryUtils`.
|
||||
- Do not split defaults into a vertical feature file if that forces handwritten operation paths such as `generateOperationKey(['console', ...])`.
|
||||
- Keep feature-level orchestration in the feature vertical; keep query utility lifecycle defaults with the query utility.
|
||||
- Prefer call-site callbacks for UI feedback only; shared cache behavior belongs in oRPC defaults when it is tied to a contract operation.
|
||||
|
||||
## Cache Invalidation
|
||||
|
||||
Bind shared invalidation in oRPC defaults when it is tied to a contract operation.
|
||||
Use feature vertical hooks only for multi-operation workflows or domain orchestration that cannot live in a single operation default.
|
||||
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.
|
||||
|
||||
Use:
|
||||
|
||||
- `.key()` for namespace or prefix invalidation
|
||||
- `.queryKey(...)` only for exact cache reads or writes such as `getQueryData` and `setQueryData`
|
||||
- `queryClient.invalidateQueries(...)` in mutation `onSuccess`
|
||||
|
||||
Do not use deprecated `useInvalid` from `use-base.ts`.
|
||||
|
||||
```typescript
|
||||
// Feature orchestration owns cache invalidation only when defaults are not enough.
|
||||
export const useUpdateAccessMode = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
return useMutation(consoleQuery.accessControl.updateAccessMode.mutationOptions({
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
|
||||
})
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// Component only adds UI behavior.
|
||||
updateAccessMode({ appId, mode }, {
|
||||
onSuccess: () => toast.success('...'),
|
||||
})
|
||||
|
||||
// Avoid putting invalidation knowledge in the component.
|
||||
mutate({ appId, mode }, {
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
|
||||
})
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
## Key API Guide
|
||||
|
||||
- `.key(...)`
|
||||
- Use for partial matching operations.
|
||||
- Prefer it for invalidation, refetch, and cancel patterns.
|
||||
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
|
||||
- `.queryKey(...)`
|
||||
- Use for a specific query's full key.
|
||||
- Prefer it for exact cache addressing and direct reads or writes.
|
||||
- `.mutationKey(...)`
|
||||
- Use for a specific mutation's full key.
|
||||
- Prefer it for mutation defaults registration, mutation-status filtering, and devtools grouping.
|
||||
|
||||
## `mutate` vs `mutateAsync`
|
||||
|
||||
Prefer `mutate` by default.
|
||||
Use `mutateAsync` only when Promise semantics are truly required, such as parallel mutations or sequential steps with result dependencies.
|
||||
|
||||
Rules:
|
||||
|
||||
- Event handlers should usually call `mutate(...)` with `onSuccess` or `onError`.
|
||||
- Every `await mutateAsync(...)` must be wrapped in `try/catch`.
|
||||
- Do not use `mutateAsync` when callbacks already express the flow clearly.
|
||||
|
||||
```typescript
|
||||
// Default case.
|
||||
mutation.mutate(data, {
|
||||
onSuccess: result => router.push(result.url),
|
||||
})
|
||||
|
||||
// Promise semantics are required.
|
||||
try {
|
||||
const order = await createOrder.mutateAsync(orderData)
|
||||
await confirmPayment.mutateAsync({ orderId: order.id, token })
|
||||
router.push(`/orders/${order.id}`)
|
||||
}
|
||||
catch (error) {
|
||||
toast.error(error instanceof Error ? error.message : 'Unknown error')
|
||||
}
|
||||
```
|
||||
|
||||
## Legacy Migration
|
||||
|
||||
When touching old code, migrate it toward these rules:
|
||||
|
||||
| Old pattern | New pattern |
|
||||
|---|---|
|
||||
| `useInvalid(key)` in service wrappers | oRPC defaults, or a feature vertical hook for real orchestration |
|
||||
| component-triggered invalidation after mutation | move invalidation into oRPC defaults or a feature vertical hook |
|
||||
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
|
||||
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |
|
||||
@ -5,7 +5,7 @@ description: Generate Vitest + React Testing Library tests for Dify frontend com
|
||||
|
||||
# Dify Frontend Testing Skill
|
||||
|
||||
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
|
||||
This skill enables Codex to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
|
||||
|
||||
> **⚠️ Authoritative Source**: This skill is derived from `web/docs/test.md`. Use Vitest mock/timer APIs (`vi.*`).
|
||||
|
||||
@ -24,35 +24,27 @@ Apply this skill when the user:
|
||||
**Do NOT apply** when:
|
||||
|
||||
- User is asking about backend/API tests (Python/pytest)
|
||||
- User is asking about E2E tests (Playwright/Cypress)
|
||||
- User is asking about E2E tests (Cucumber + Playwright under `e2e/`)
|
||||
- User is only asking conceptual questions without code context
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Tech Stack
|
||||
|
||||
| Tool | Version | Purpose |
|
||||
|------|---------|---------|
|
||||
| Vitest | 4.0.16 | Test runner |
|
||||
| React Testing Library | 16.0 | Component testing |
|
||||
| jsdom | - | Test environment |
|
||||
| nock | 14.0 | HTTP mocking |
|
||||
| TypeScript | 5.x | Type safety |
|
||||
|
||||
### Key Commands
|
||||
|
||||
Run these commands from `web/`. From the repository root, prefix them with `pnpm -C web`.
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
pnpm test
|
||||
|
||||
# Watch mode
|
||||
pnpm test:watch
|
||||
pnpm test --watch
|
||||
|
||||
# Run specific file
|
||||
pnpm test path/to/file.spec.tsx
|
||||
|
||||
# Generate coverage report
|
||||
pnpm test:coverage
|
||||
pnpm test --coverage
|
||||
|
||||
# Analyze component complexity
|
||||
pnpm analyze-component <path>
|
||||
@ -228,7 +220,10 @@ Every test should clearly separate:
|
||||
### 2. Black-Box Testing
|
||||
|
||||
- Test observable behavior, not implementation details
|
||||
- Use semantic queries (getByRole, getByLabelText)
|
||||
- Use semantic queries (`getByRole` with accessible `name`, `getByLabelText`, `getByPlaceholderText`, `getByText`, and scoped `within(...)`)
|
||||
- Treat `getByTestId` as a last resort. If a control cannot be found by role/name, label, landmark, or dialog scope, fix the component accessibility first instead of adding or relying on `data-testid`.
|
||||
- Remove production `data-testid` attributes when semantic selectors can cover the behavior. Keep them only for non-visual mocked boundaries, editor/browser shims such as Monaco, canvas/chart output, or third-party widgets with no accessible DOM in the test environment.
|
||||
- Do not assert decorative icons by test id. Assert the named control that contains them, or mark decorative icons `aria-hidden`.
|
||||
- Avoid testing internal state directly
|
||||
- **Prefer pattern matching over hardcoded strings** in assertions:
|
||||
|
||||
|
||||
@ -56,7 +56,7 @@ See [Zustand Store Testing](#zustand-store-testing) section for full details.
|
||||
|
||||
| Location | Purpose |
|
||||
|----------|---------|
|
||||
| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `next/image`, `zustand`) |
|
||||
| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `zustand`, clipboard, FloatingPortal, Monaco, localStorage`) |
|
||||
| `web/__mocks__/zustand.ts` | Zustand mock implementation (auto-resets stores after each test) |
|
||||
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
|
||||
| Test file | Test-specific mocks, inline with `vi.mock()` |
|
||||
@ -216,28 +216,21 @@ describe('Component', () => {
|
||||
})
|
||||
```
|
||||
|
||||
### 5. HTTP Mocking with Nock
|
||||
### 5. HTTP and `fetch` Mocking
|
||||
|
||||
```typescript
|
||||
import nock from 'nock'
|
||||
|
||||
const GITHUB_HOST = 'https://api.github.com'
|
||||
const GITHUB_PATH = '/repos/owner/repo'
|
||||
|
||||
const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
|
||||
return nock(GITHUB_HOST)
|
||||
.get(GITHUB_PATH)
|
||||
.delay(delayMs)
|
||||
.reply(status, body)
|
||||
}
|
||||
|
||||
describe('GithubComponent', () => {
|
||||
afterEach(() => {
|
||||
nock.cleanAll()
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should display repo info', async () => {
|
||||
mockGithubApi(200, { name: 'dify', stars: 1000 })
|
||||
vi.mocked(globalThis.fetch).mockResolvedValueOnce(
|
||||
new Response(JSON.stringify({ name: 'dify', stars: 1000 }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
)
|
||||
|
||||
render(<GithubComponent />)
|
||||
|
||||
@ -247,7 +240,12 @@ describe('GithubComponent', () => {
|
||||
})
|
||||
|
||||
it('should handle API error', async () => {
|
||||
mockGithubApi(500, { message: 'Server error' })
|
||||
vi.mocked(globalThis.fetch).mockResolvedValueOnce(
|
||||
new Response(JSON.stringify({ message: 'Server error' }), {
|
||||
status: 500,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
)
|
||||
|
||||
render(<GithubComponent />)
|
||||
|
||||
@ -258,6 +256,8 @@ describe('GithubComponent', () => {
|
||||
})
|
||||
```
|
||||
|
||||
Prefer mocking `@/service/*` modules or spying on `global.fetch` / `ky` clients with deterministic responses. Do not introduce an HTTP interception dependency such as `nock` or MSW unless it is already declared in the workspace or adding it is part of the task.
|
||||
|
||||
### 6. Context Providers
|
||||
|
||||
```typescript
|
||||
@ -332,7 +332,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
|
||||
1. **Don't mock Zustand store modules** - Use real stores with `setState()`
|
||||
1. Don't mock components you can import directly
|
||||
1. Don't create overly simplified mocks that miss conditional logic
|
||||
1. Don't forget to clean up nock after each test
|
||||
1. Don't leave HTTP mocks or service mock state leaking between tests
|
||||
1. Don't use `any` types in mocks without necessity
|
||||
|
||||
### Mock Decision Tree
|
||||
|
||||
@ -227,12 +227,12 @@ Failing tests compound:
|
||||
|
||||
**Fix failures immediately before proceeding.**
|
||||
|
||||
## Integration with Claude's Todo Feature
|
||||
## Integration with Codex's Todo Feature
|
||||
|
||||
When using Claude for multi-file testing:
|
||||
When using Codex for multi-file testing:
|
||||
|
||||
1. **Ask Claude to create a todo list** before starting
|
||||
1. **Request one file at a time** or ensure Claude processes incrementally
|
||||
1. **Create a todo list** before starting
|
||||
1. **Process one file at a time**
|
||||
1. **Verify each test passes** before asking for the next
|
||||
1. **Mark todos complete** as you progress
|
||||
|
||||
|
||||
71
.agents/skills/how-to-write-component/SKILL.md
Normal file
71
.agents/skills/how-to-write-component/SKILL.md
Normal file
@ -0,0 +1,71 @@
|
||||
---
|
||||
name: how-to-write-component
|
||||
description: React/TypeScript component style guide. Use when writing, refactoring, or reviewing React components, especially around props typing, state boundaries, shared local state with Jotai atoms, API types, query/mutation contracts, navigation, memoization, wrappers, and empty-state handling.
|
||||
---
|
||||
|
||||
# How To Write A Component
|
||||
|
||||
Use this as the decision guide for React/TypeScript component structure. Existing code is reference material, not automatic precedent; when it conflicts with these rules, adapt the approach instead of reproducing the violation.
|
||||
|
||||
## Core Defaults
|
||||
|
||||
- Search before adding UI, hooks, helpers, or styling patterns. Reuse existing base components, feature components, hooks, utilities, and design styles when they fit.
|
||||
- Group code by feature workflow, route, or ownership area: components, hooks, local types, query helpers, atoms, constants, and small utilities should live near the code that changes with them.
|
||||
- Promote code to shared only when multiple verticals need the same stable primitive. Otherwise keep it local and compose shared primitives inside the owning feature.
|
||||
- Follow Dify's CSS-first Tailwind v4 contract from `packages/dify-ui/README.md` and `packages/dify-ui/AGENTS.md`. Prefer design-system tokens, utilities, and radius mappings over generic Tailwind guidance.
|
||||
|
||||
## Ownership
|
||||
|
||||
- Put local state, queries, mutations, handlers, and derived UI data in the lowest component that uses them. Extract a purpose-built owner component only when the logic has no natural home.
|
||||
- Repeated TanStack query calls in sibling components are acceptable when each component independently consumes the data. Do not hoist a query only because it is duplicated; TanStack Query handles deduplication and cache sharing.
|
||||
- Hoist state, queries, or callbacks to a parent only when the parent consumes the data, coordinates shared loading/error/empty UI, needs one consistent snapshot, or owns a workflow spanning children.
|
||||
- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in query and API data flow.
|
||||
- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action.
|
||||
- Prefer uncontrolled DOM state and CSS variables before adding controlled props.
|
||||
|
||||
## Components, Props, And Types
|
||||
|
||||
- Type component signatures directly; do not use `FC` or `React.FC`.
|
||||
- Prefer `function` for top-level components and module helpers. Use arrow functions for local callbacks, handlers, and lambda-style APIs.
|
||||
- Prefer named exports. Use default exports only where the framework requires them, such as Next.js route files.
|
||||
- Type simple one-off props inline. Use a named `Props` type only when reused, exported, complex, or clearer.
|
||||
- Use API-generated or API-returned types at component boundaries. Keep small UI conversion helpers beside the component that needs them.
|
||||
- Name values by their domain role and backend API contract, and keep that name stable across the call chain, especially IDs like `appInstanceId`. Normalize framework or route params at the boundary.
|
||||
- Keep fallback and invariant checks at the lowest component that already handles that state; callers should pass raw values through instead of duplicating checks.
|
||||
|
||||
## Queries And Mutations
|
||||
|
||||
- Keep `web/contract/*` as the single source of truth for API shape; follow existing domain/router patterns and the `{ params, query?, body? }` input shape.
|
||||
- Consume queries directly with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`.
|
||||
- Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename `queryOptions()` or `mutationOptions()`. Extract a small `queryOptions` helper only when repeated call-site options justify it.
|
||||
- Keep feature hooks for real orchestration, workflow state, or shared domain behavior.
|
||||
- For missing required query input, use `input: skipToken`; use `enabled` only for extra business gating after the input is valid.
|
||||
- Consume mutations directly with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))`; use oRPC clients as `mutationFn` only for custom flows.
|
||||
- Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`; components may add UI feedback callbacks, but should not own shared invalidation rules.
|
||||
- Do not use deprecated `useInvalid` or `useReset`.
|
||||
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required, and wrap awaited calls in `try/catch`.
|
||||
|
||||
## Component Boundaries
|
||||
|
||||
- Use the first level below a page or tab to organize independent page sections when it adds real structure. This layer is layout/semantic first, not automatically the data owner.
|
||||
- Split deeper components by the data and state each layer actually needs. Each component should access only necessary data, and ownership should stay at the lowest consumer.
|
||||
- Keep cohesive forms, menu bodies, and one-off helpers local unless they need their own state, reuse, or semantic boundary.
|
||||
- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow.
|
||||
- Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment.
|
||||
- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible.
|
||||
- Avoid shallow wrappers and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary.
|
||||
|
||||
## You Might Not Need An Effect
|
||||
|
||||
- Use Effects only to synchronize with external systems such as browser APIs, non-React widgets, subscriptions, timers, analytics that must run because the component was shown, or imperative DOM integration.
|
||||
- Do not use Effects to transform props or state for rendering. Calculate derived values during render, and use `useMemo` only when the calculation is actually expensive.
|
||||
- Do not use Effects to handle user actions. Put action-specific logic in the event handler where the cause is known.
|
||||
- Do not use Effects to copy one state value into another state value representing the same concept. Pick one source of truth and derive the rest during render.
|
||||
- Do not reset or adjust state from props with an Effect. Prefer a `key` reset, storing a stable ID and deriving the selected object, or guarded same-component render-time adjustment when truly necessary.
|
||||
- Prefer framework data APIs or TanStack Query for data fetching instead of writing request Effects in components.
|
||||
- If an Effect still seems necessary, first name the external system it synchronizes with. If there is no external system, remove the Effect and restructure the state or event flow.
|
||||
|
||||
## Navigation And Performance
|
||||
|
||||
- Prefer `Link` for normal navigation. Use router APIs only for command-flow side effects such as mutation success, guarded redirects, or form submission.
|
||||
- Avoid `memo`, `useMemo`, and `useCallback` unless there is a clear performance reason.
|
||||
@ -1,5 +1,6 @@
|
||||
[run]
|
||||
omit =
|
||||
api/conftest.py
|
||||
api/tests/*
|
||||
api/migrations/*
|
||||
api/core/rag/datasource/vdb/*
|
||||
|
||||
60
.github/CODEOWNERS
vendored
60
.github/CODEOWNERS
vendored
@ -4,7 +4,7 @@
|
||||
# Owners can be @username, @org/team-name, or email addresses.
|
||||
# For more information, see: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
|
||||
|
||||
* @crazywoola @laipz8200 @Yeuoly
|
||||
* @crazywoola @laipz8200
|
||||
|
||||
# ESLint suppression file is maintained by autofix.ci pruning.
|
||||
/eslint-suppressions.json
|
||||
@ -85,39 +85,39 @@
|
||||
/api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
||||
|
||||
# Backend - Plugins
|
||||
/api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
/api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
/api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
||||
/api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
||||
/api/core/plugin/ @WH-2099
|
||||
/api/services/plugin/ @WH-2099
|
||||
/api/controllers/console/workspace/plugin.py @WH-2099
|
||||
/api/controllers/inner_api/plugin/ @WH-2099
|
||||
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @WH-2099
|
||||
|
||||
# Backend - Trigger/Schedule/Webhook
|
||||
/api/controllers/trigger/ @Mairuis @Yeuoly
|
||||
/api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
||||
/api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
||||
/api/core/trigger/ @Mairuis @Yeuoly
|
||||
/api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
||||
/api/services/trigger/ @Mairuis @Yeuoly
|
||||
/api/models/trigger.py @Mairuis @Yeuoly
|
||||
/api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
||||
/api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
/api/libs/schedule_utils.py @Mairuis @Yeuoly
|
||||
/api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
||||
/api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
||||
/api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
||||
/api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
||||
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
||||
/api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
||||
/api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
||||
/api/controllers/trigger/ @Mairuis
|
||||
/api/controllers/console/app/workflow_trigger.py @Mairuis
|
||||
/api/controllers/console/workspace/trigger_providers.py @Mairuis
|
||||
/api/core/trigger/ @Mairuis
|
||||
/api/core/app/layers/trigger_post_layer.py @Mairuis
|
||||
/api/services/trigger/ @Mairuis
|
||||
/api/models/trigger.py @Mairuis
|
||||
/api/fields/workflow_trigger_fields.py @Mairuis
|
||||
/api/repositories/workflow_trigger_log_repository.py @Mairuis
|
||||
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis
|
||||
/api/libs/schedule_utils.py @Mairuis
|
||||
/api/services/workflow/scheduler.py @Mairuis
|
||||
/api/schedule/trigger_provider_refresh_task.py @Mairuis
|
||||
/api/schedule/workflow_schedule_task.py @Mairuis
|
||||
/api/tasks/trigger_processing_tasks.py @Mairuis
|
||||
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis
|
||||
/api/tasks/workflow_schedule_tasks.py @Mairuis
|
||||
/api/tasks/workflow_cfs_scheduler/ @Mairuis
|
||||
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis
|
||||
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis
|
||||
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis
|
||||
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis
|
||||
|
||||
# Backend - Async Workflow
|
||||
/api/services/async_workflow_service.py @Mairuis @Yeuoly
|
||||
/api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
||||
/api/services/async_workflow_service.py @Mairuis
|
||||
/api/tasks/async_workflow_tasks.py @Mairuis
|
||||
|
||||
# Backend - Billing
|
||||
/api/services/billing_service.py @hj24 @zyssyz123
|
||||
|
||||
5
.github/actions/setup-web/action.yml
vendored
5
.github/actions/setup-web/action.yml
vendored
@ -1,8 +1,13 @@
|
||||
name: Setup Web Environment
|
||||
description: Set up Node.js, Vite+, pnpm, and web dependencies
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@8912a9102ac27614460f54aedde9e1e7f9aec20d # v6.0.5
|
||||
with:
|
||||
run_install: false
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0
|
||||
with:
|
||||
|
||||
73
.github/scripts/check-hotfix-cherry-picks.sh
vendored
Normal file
73
.github/scripts/check-hotfix-cherry-picks.sh
vendored
Normal file
@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
BASE_SHA=${BASE_SHA:-}
|
||||
HEAD_SHA=${HEAD_SHA:-}
|
||||
MAIN_REF=${MAIN_REF:-origin/main}
|
||||
REMEDIATION_HINT="Changes should be made from the main branch using git cherry-pick -x."
|
||||
|
||||
error() {
|
||||
printf 'ERROR: %s\n' "$1" >&2
|
||||
}
|
||||
|
||||
if [[ -z "$BASE_SHA" || -z "$HEAD_SHA" ]]; then
|
||||
error "BASE_SHA and HEAD_SHA are required. $REMEDIATION_HINT"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if ! git rev-parse --verify "$BASE_SHA^{commit}" > /dev/null 2>&1; then
|
||||
error "Base commit '$BASE_SHA' is not available in the local git checkout."
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if ! git rev-parse --verify "$HEAD_SHA^{commit}" > /dev/null 2>&1; then
|
||||
error "Head commit '$HEAD_SHA' is not available in the local git checkout."
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if ! git rev-parse --verify "$MAIN_REF^{commit}" > /dev/null 2>&1; then
|
||||
error "Main ref '$MAIN_REF' is not available in the local git checkout. $REMEDIATION_HINT"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
failed=0
|
||||
checked=0
|
||||
|
||||
while IFS= read -r commit_sha; do
|
||||
[[ -n "$commit_sha" ]] || continue
|
||||
|
||||
checked=$((checked + 1))
|
||||
subject=$(git log -1 --format=%s "$commit_sha")
|
||||
source_sha=$(
|
||||
git log -1 --format=%B "$commit_sha" \
|
||||
| sed -nE 's/^\(cherry picked from commit ([0-9a-fA-F]{7,64})\)$/\1/p' \
|
||||
| tail -n 1
|
||||
)
|
||||
|
||||
if [[ -z "$source_sha" ]]; then
|
||||
error "Commit $commit_sha ($subject) is missing cherry-pick provenance. $REMEDIATION_HINT"
|
||||
failed=1
|
||||
continue
|
||||
fi
|
||||
|
||||
if ! git cat-file -e "$source_sha^{commit}" 2> /dev/null; then
|
||||
error "Commit $commit_sha ($subject) references source $source_sha, but that commit is not available locally. $REMEDIATION_HINT"
|
||||
failed=1
|
||||
continue
|
||||
fi
|
||||
|
||||
if ! git merge-base --is-ancestor "$source_sha" "$MAIN_REF"; then
|
||||
error "Commit $commit_sha ($subject) references source $source_sha, but that source is not reachable from main ($MAIN_REF). $REMEDIATION_HINT"
|
||||
failed=1
|
||||
fi
|
||||
done < <(git rev-list --reverse "$BASE_SHA..$HEAD_SHA")
|
||||
|
||||
if [[ "$failed" -ne 0 ]]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$checked" -eq 0 ]]; then
|
||||
echo "No PR commits to check."
|
||||
else
|
||||
echo "Verified $checked PR commit(s) include cherry-pick provenance from main."
|
||||
fi
|
||||
42
.github/workflows/api-tests.yml
vendored
42
.github/workflows/api-tests.yml
vendored
@ -48,10 +48,23 @@ jobs:
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
run: uv run --project api pytest api/tests/unit_tests/configs/test_env_consistency.py
|
||||
|
||||
- name: Run Unit Tests
|
||||
run: uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
run: |
|
||||
uv run --project api pytest \
|
||||
-p no:benchmark \
|
||||
--timeout "${PYTEST_TIMEOUT:-20}" \
|
||||
-n auto \
|
||||
api/tests/unit_tests \
|
||||
api/providers/vdb/*/tests/unit_tests \
|
||||
api/providers/trace/*/tests/unit_tests \
|
||||
--ignore=api/tests/unit_tests/controllers
|
||||
# Controller tests register Flask routes at import time, so keep them out of xdist.
|
||||
uv run --project api pytest \
|
||||
--timeout "${PYTEST_TIMEOUT:-20}" \
|
||||
--cov-append \
|
||||
api/tests/unit_tests/controllers
|
||||
|
||||
- name: Upload unit coverage data
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
@ -96,32 +109,11 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
./docker/init-env.sh
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Sandbox
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db_postgres
|
||||
redis
|
||||
sandbox
|
||||
ssrf_proxy
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
- name: Run Integration Tests
|
||||
run: |
|
||||
uv run --project api pytest \
|
||||
-p no:benchmark \
|
||||
--start-middleware \
|
||||
-n auto \
|
||||
--timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/tests/integration_tests/workflow \
|
||||
|
||||
10
.github/workflows/autofix.yml
vendored
10
.github/workflows/autofix.yml
vendored
@ -116,6 +116,16 @@ jobs:
|
||||
if: github.event_name != 'merge_group'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Generate API docs
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd api
|
||||
uv run dev/generate_swagger_markdown_docs.py --swagger-dir ../packages/contracts/openapi --markdown-dir openapi/markdown --keep-swagger-json
|
||||
|
||||
- name: Generate frontend contracts
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: pnpm --dir packages/contracts gen-api-contract-from-openapi
|
||||
|
||||
- name: ESLint autofix
|
||||
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
|
||||
4
.github/workflows/db-migration-test.yml
vendored
4
.github/workflows/db-migration-test.yml
vendored
@ -37,7 +37,7 @@ jobs:
|
||||
- name: Prepare middleware env
|
||||
run: |
|
||||
cd docker
|
||||
cp middleware.env.example middleware.env
|
||||
cp envs/middleware.env.example middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
@ -87,7 +87,7 @@ jobs:
|
||||
- name: Prepare middleware env for MySQL
|
||||
run: |
|
||||
cd docker
|
||||
cp middleware.env.example middleware.env
|
||||
cp envs/middleware.env.example middleware.env
|
||||
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
|
||||
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env
|
||||
|
||||
17
.github/workflows/expose_service_ports.sh
vendored
17
.github/workflows/expose_service_ports.sh
vendored
@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
|
||||
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
|
||||
|
||||
echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
||||
49
.github/workflows/hotfix-cherry-pick.yml
vendored
Normal file
49
.github/workflows/hotfix-cherry-pick.yml
vendored
Normal file
@ -0,0 +1,49 @@
|
||||
name: Hotfix Cherry-Pick Provenance
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- 'hotfix/**'
|
||||
- 'lts/**'
|
||||
types:
|
||||
- opened
|
||||
- edited
|
||||
- reopened
|
||||
- ready_for_review
|
||||
- synchronize
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: hotfix-cherry-pick-${{ github.event.pull_request.number || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
check-cherry-pick-provenance:
|
||||
name: Require cherry-pick provenance
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Fetch PR base, PR head, and main
|
||||
env:
|
||||
BASE_REF: ${{ github.base_ref }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
run: |
|
||||
git fetch --no-tags --prune origin \
|
||||
"+refs/heads/main:refs/remotes/origin/main" \
|
||||
"+refs/heads/${BASE_REF}:refs/remotes/origin/${BASE_REF}" \
|
||||
"+refs/pull/${PR_NUMBER}/head:refs/remotes/pull/${PR_NUMBER}/head"
|
||||
|
||||
- name: Load checker from main
|
||||
run: git show origin/main:.github/scripts/check-hotfix-cherry-picks.sh > "$RUNNER_TEMP/check-hotfix-cherry-picks.sh"
|
||||
|
||||
- name: Check PR commits
|
||||
env:
|
||||
BASE_SHA: ${{ github.event.pull_request.base.sha }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha }}
|
||||
MAIN_REF: origin/main
|
||||
run: bash "$RUNNER_TEMP/check-hotfix-cherry-picks.sh"
|
||||
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@ -9,6 +9,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
|
||||
- uses: actions/labeler@f27b608878404679385c85cfa523b85ccb86e213 # v6.1.0
|
||||
with:
|
||||
sync-labels: true
|
||||
|
||||
18
.github/workflows/main-ci.yml
vendored
18
.github/workflows/main-ci.yml
vendored
@ -55,11 +55,8 @@ jobs:
|
||||
api:
|
||||
- 'api/**'
|
||||
- '.github/workflows/api-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.all'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/init-env.sh'
|
||||
- 'docker/middleware.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/docker-compose-template.yaml'
|
||||
- 'docker/generate_docker_compose'
|
||||
@ -86,19 +83,19 @@ jobs:
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- '.github/workflows/web-e2e.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'api/tests/integration_tests/vdb/**'
|
||||
- 'api/conftest.py'
|
||||
- 'api/tests/pytest_dify.py'
|
||||
- 'api/providers/vdb/*/tests/**'
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.all'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/init-env.sh'
|
||||
- 'docker/middleware.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- 'docker/docker-compose.pytest.ports.yaml'
|
||||
- 'docker/docker-compose.yaml'
|
||||
- 'docker/docker-compose-template.yaml'
|
||||
- 'docker/generate_docker_compose'
|
||||
@ -118,9 +115,8 @@ jobs:
|
||||
- 'api/migrations/**'
|
||||
- 'api/.env.example'
|
||||
- '.github/workflows/db-migration-test.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/middleware.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/docker-compose-template.yaml'
|
||||
- 'docker/generate_docker_compose'
|
||||
|
||||
22
.github/workflows/pyrefly-diff-comment.yml
vendored
22
.github/workflows/pyrefly-diff-comment.yml
vendored
@ -77,10 +77,28 @@ jobs:
|
||||
}
|
||||
|
||||
if (diff.trim()) {
|
||||
await github.rest.issues.createComment({
|
||||
const body = '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>';
|
||||
const marker = '### Pyrefly Diff';
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>',
|
||||
});
|
||||
const existing = comments.find((comment) => comment.body.startsWith(marker));
|
||||
|
||||
if (existing) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: existing.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
21
.github/workflows/pyrefly-diff.yml
vendored
21
.github/workflows/pyrefly-diff.yml
vendored
@ -103,9 +103,26 @@ jobs:
|
||||
].join('\n')
|
||||
: '### Pyrefly Diff\nNo changes detected.';
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
const marker = '### Pyrefly Diff';
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
const existing = comments.find((comment) => comment.body.startsWith(marker));
|
||||
|
||||
if (existing) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: existing.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
}
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@ -158,7 +158,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@fefa07e9c665b7320f08c3b525980457f22f58aa # v1.0.111
|
||||
uses: anthropics/claude-code-action@476e359e6203e73dad705c8b322e333fabbd7416 # v1.0.119
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
39
.github/workflows/vdb-tests-full.yml
vendored
39
.github/workflows/vdb-tests-full.yml
vendored
@ -48,14 +48,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
./docker/init-env.sh
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
@ -64,32 +56,13 @@ jobs:
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Full Vector Store Matrix
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
weaviate
|
||||
qdrant
|
||||
couchbase-server
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
oceanbase
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
echo $(pwd)
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
run: |
|
||||
uv run --project api pytest \
|
||||
--start-vdb \
|
||||
--vdb-services "weaviate,qdrant,couchbase-server,etcd,minio,milvus-standalone,pgvecto-rs,pgvector,chroma,elasticsearch,oceanbase" \
|
||||
--timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/providers/vdb/*/tests/integration_tests
|
||||
|
||||
31
.github/workflows/vdb-tests.yml
vendored
31
.github/workflows/vdb-tests.yml
vendored
@ -45,14 +45,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
./docker/init-env.sh
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
@ -61,31 +53,14 @@ jobs:
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores for Smoke Coverage
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
db_postgres
|
||||
redis
|
||||
weaviate
|
||||
qdrant
|
||||
pgvector
|
||||
chroma
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
echo $(pwd)
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: |
|
||||
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
uv run --project api pytest \
|
||||
--start-vdb \
|
||||
--timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/providers/vdb/vdb-chroma/tests/integration_tests \
|
||||
api/providers/vdb/vdb-pgvector/tests/integration_tests \
|
||||
api/providers/vdb/vdb-qdrant/tests/integration_tests \
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@ -250,5 +250,5 @@ scripts/stress-test/reports/
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
|
||||
.context/*
|
||||
.eslintcache
|
||||
|
||||
@ -9,6 +9,7 @@ The codebase is split into:
|
||||
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
|
||||
- **Frontend Web** (`/web`): Next.js application using TypeScript and React
|
||||
- **Docker deployment** (`/docker`): Containerized deployment configurations
|
||||
- **Dify Agent Backend** (`/dify-agent`): Backend services for managing and executing agent
|
||||
|
||||
## Backend Workflow
|
||||
|
||||
|
||||
84
Makefile
84
Makefile
@ -3,6 +3,10 @@ DOCKER_REGISTRY=langgenius
|
||||
WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
|
||||
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
|
||||
VERSION=latest
|
||||
DOCKER_DIR=docker
|
||||
DOCKER_MIDDLEWARE_ENV=$(DOCKER_DIR)/middleware.env
|
||||
DOCKER_MIDDLEWARE_ENV_EXAMPLE=$(DOCKER_DIR)/envs/middleware.env.example
|
||||
DOCKER_MIDDLEWARE_PROJECT=dify-middlewares-dev
|
||||
|
||||
# Default target - show help
|
||||
.DEFAULT_GOAL := help
|
||||
@ -17,8 +21,13 @@ dev-setup: prepare-docker prepare-web prepare-api
|
||||
# Step 1: Prepare Docker middleware
|
||||
prepare-docker:
|
||||
@echo "🐳 Setting up Docker middleware..."
|
||||
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
|
||||
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
|
||||
@if [ ! -f "$(DOCKER_MIDDLEWARE_ENV)" ]; then \
|
||||
cp "$(DOCKER_MIDDLEWARE_ENV_EXAMPLE)" "$(DOCKER_MIDDLEWARE_ENV)"; \
|
||||
echo "Docker middleware.env created"; \
|
||||
else \
|
||||
echo "Docker middleware.env already exists"; \
|
||||
fi
|
||||
@cd $(DOCKER_DIR) && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p $(DOCKER_MIDDLEWARE_PROJECT) up -d
|
||||
@echo "✅ Docker middleware started"
|
||||
|
||||
# Step 2: Prepare web environment
|
||||
@ -39,12 +48,18 @@ prepare-api:
|
||||
# Clean dev environment
|
||||
dev-clean:
|
||||
@echo "⚠️ Stopping Docker containers..."
|
||||
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
|
||||
@if [ -f "$(DOCKER_MIDDLEWARE_ENV)" ]; then \
|
||||
cd $(DOCKER_DIR) && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p $(DOCKER_MIDDLEWARE_PROJECT) down; \
|
||||
else \
|
||||
echo "Docker middleware.env does not exist, skipping compose down"; \
|
||||
fi
|
||||
@echo "🗑️ Removing volumes..."
|
||||
@rm -rf docker/volumes/db
|
||||
@rm -rf docker/volumes/mysql
|
||||
@rm -rf docker/volumes/redis
|
||||
@rm -rf docker/volumes/plugin_daemon
|
||||
@rm -rf docker/volumes/weaviate
|
||||
@rm -rf docker/volumes/sandbox/dependencies
|
||||
@rm -rf api/storage
|
||||
@echo "✅ Cleanup complete"
|
||||
|
||||
@ -68,16 +83,15 @@ lint:
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
type-check:
|
||||
@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@./dev/pyrefly-check-local
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "📝 Running type checks (pyrefly + mypy)..."
|
||||
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude '(^|/)conftest\.py$$' --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Type checks complete"
|
||||
|
||||
type-check-core:
|
||||
@echo "📝 Running core type checks (basedpyright + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "📝 Running core type checks (pyrefly + mypy)..."
|
||||
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude '(^|/)conftest\.py$$' --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Core type checks complete"
|
||||
|
||||
test:
|
||||
@ -86,7 +100,46 @@ test:
|
||||
echo "Target: $(TARGET_TESTS)"; \
|
||||
uv run --project api --dev pytest $(TARGET_TESTS); \
|
||||
else \
|
||||
PYTEST_XDIST_ARGS="-n auto" uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
|
||||
echo "Running backend unit tests"; \
|
||||
uv run --project api --dev pytest -p no:benchmark --timeout "$${PYTEST_TIMEOUT:-20}" -n auto \
|
||||
api/tests/unit_tests \
|
||||
api/providers/vdb/*/tests/unit_tests \
|
||||
api/providers/trace/*/tests/unit_tests \
|
||||
--ignore=api/tests/unit_tests/controllers; \
|
||||
uv run --project api --dev pytest --timeout "$${PYTEST_TIMEOUT:-20}" --cov-append \
|
||||
api/tests/unit_tests/controllers; \
|
||||
fi
|
||||
@echo "✅ Unit tests complete"
|
||||
|
||||
test-all:
|
||||
@echo "🧪 Running full backend test suite..."
|
||||
@if [ -n "$(TARGET_TESTS)" ]; then \
|
||||
echo "Target: $(TARGET_TESTS)"; \
|
||||
uv run --project api --dev pytest $(TARGET_TESTS); \
|
||||
else \
|
||||
echo "Running backend unit tests"; \
|
||||
uv run --project api --dev pytest -p no:benchmark --timeout "$${PYTEST_TIMEOUT:-20}" -n auto \
|
||||
api/tests/unit_tests \
|
||||
api/providers/vdb/*/tests/unit_tests \
|
||||
api/providers/trace/*/tests/unit_tests \
|
||||
--ignore=api/tests/unit_tests/controllers; \
|
||||
uv run --project api --dev pytest --timeout "$${PYTEST_TIMEOUT:-20}" --cov-append \
|
||||
api/tests/unit_tests/controllers; \
|
||||
echo "Running backend integration tests"; \
|
||||
uv run --project api --dev pytest -p no:benchmark --start-middleware -n auto \
|
||||
--timeout "$${PYTEST_TIMEOUT:-180}" \
|
||||
--cov-append \
|
||||
api/tests/integration_tests/workflow \
|
||||
api/tests/integration_tests/tools \
|
||||
api/tests/test_containers_integration_tests; \
|
||||
echo "Running VDB smoke tests"; \
|
||||
uv run --project api --dev pytest --start-vdb \
|
||||
--timeout "$${PYTEST_TIMEOUT:-180}" \
|
||||
--cov-append \
|
||||
api/providers/vdb/vdb-chroma/tests/integration_tests \
|
||||
api/providers/vdb/vdb-pgvector/tests/integration_tests \
|
||||
api/providers/vdb/vdb-qdrant/tests/integration_tests \
|
||||
api/providers/vdb/vdb-weaviate/tests/integration_tests; \
|
||||
fi
|
||||
@echo "✅ Tests complete"
|
||||
|
||||
@ -132,15 +185,16 @@ help:
|
||||
@echo " make prepare-docker - Set up Docker middleware"
|
||||
@echo " make prepare-web - Set up web environment"
|
||||
@echo " make prepare-api - Set up API environment"
|
||||
@echo " make dev-clean - Stop Docker middleware containers"
|
||||
@echo " make dev-clean - Stop Docker middleware containers and remove dev data"
|
||||
@echo ""
|
||||
@echo "Backend Code Quality:"
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
|
||||
@echo " make type-check-core - Run core type checks (basedpyright, mypy)"
|
||||
@echo " make type-check - Run type checks (pyrefly, mypy)"
|
||||
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@echo " make test-all - Run full backend tests, including Docker-backed suites"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
@echo " make build-web - Build web Docker image"
|
||||
@ -150,4 +204,4 @@ help:
|
||||
@echo " make build-push-all - Build and push all Docker images"
|
||||
|
||||
# Phony targets
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all
|
||||
|
||||
11
README.md
11
README.md
@ -76,14 +76,7 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock
|
||||
```bash
|
||||
cd dify
|
||||
cd docker
|
||||
./init-env.sh
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
On Windows PowerShell, initialize `.env`, then run `docker compose up -d` from the `docker` directory.
|
||||
|
||||
```powershell
|
||||
.\init-env.ps1
|
||||
cp .env.example .env
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
@ -144,7 +137,7 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||
### Custom configurations
|
||||
|
||||
If you need to customize the configuration, edit `docker/.env` after running the initialization script. The full reference remains in [`docker/.env.all`](docker/.env.all). After making any changes, re-run `docker compose up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
If you need to customize the configuration, edit `docker/.env`. The essential startup defaults live in [`docker/.env.example`](docker/.env.example), and optional advanced variables are split under `docker/envs/` by theme. After making any changes, re-run `docker compose up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
### Metrics Monitoring with Grafana
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ TRIGGER_URL=http://localhost:5001
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
# Collaboration mode toggle
|
||||
ENABLE_COLLABORATION_MODE=false
|
||||
ENABLE_COLLABORATION_MODE=true
|
||||
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
@ -88,6 +88,10 @@ REDIS_HEALTH_CHECK_INTERVAL=30
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
|
||||
CELERY_BACKEND=redis
|
||||
|
||||
# Ops trace retry configuration
|
||||
OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES=60
|
||||
OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS=5
|
||||
|
||||
# Database configuration
|
||||
DB_TYPE=postgresql
|
||||
DB_USERNAME=postgres
|
||||
@ -98,6 +102,8 @@ DB_DATABASE=dify
|
||||
|
||||
SQLALCHEMY_POOL_PRE_PING=true
|
||||
SQLALCHEMY_POOL_TIMEOUT=30
|
||||
# Connection pool reset behavior on return
|
||||
SQLALCHEMY_POOL_RESET_ON_RETURN=rollback
|
||||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
@ -381,7 +387,7 @@ VIKINGDB_ACCESS_KEY=your-ak
|
||||
VIKINGDB_SECRET_KEY=your-sk
|
||||
VIKINGDB_REGION=cn-shanghai
|
||||
VIKINGDB_HOST=api-vikingdb.xxx.volces.com
|
||||
VIKINGDB_SCHEMA=http
|
||||
VIKINGDB_SCHEME=http
|
||||
VIKINGDB_CONNECTION_TIMEOUT=30
|
||||
VIKINGDB_SOCKET_TIMEOUT=30
|
||||
|
||||
@ -432,8 +438,6 @@ UPLOAD_FILE_EXTENSION_BLACKLIST=
|
||||
|
||||
# Model configuration
|
||||
MULTIMODAL_SEND_FORMAT=base64
|
||||
PROMPT_GENERATION_MAX_TOKENS=512
|
||||
CODE_GENERATION_MAX_TOKENS=1024
|
||||
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
|
||||
|
||||
# Mail configuration, support: resend, smtp, sendgrid
|
||||
@ -553,7 +557,7 @@ MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
# Minimum number of workers per GraphEngine instance (default: 1)
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
GRAPH_ENGINE_MIN_WORKERS=3
|
||||
# Maximum number of workers per GraphEngine instance (default: 10)
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
# Queue depth threshold that triggers worker scale up (default: 3)
|
||||
|
||||
@ -180,6 +180,8 @@ Quick checks while iterating:
|
||||
- Format: `make format`
|
||||
- Lint (includes auto-fix): `make lint`
|
||||
- Type check: `make type-check`
|
||||
- Unit tests: `make test`
|
||||
- Full backend tests, including Docker-backed suites: `make test-all`
|
||||
- Targeted tests: `make test TARGET_TESTS=./api/tests/<target_tests>`
|
||||
|
||||
Before opening a PR / submitting:
|
||||
@ -193,6 +195,10 @@ Before opening a PR / submitting:
|
||||
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
|
||||
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
|
||||
- Document non-obvious behaviour with concise docstrings and comments.
|
||||
- For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`.
|
||||
In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response
|
||||
DTOs with `register_response_schema_models(...)`, serialize response DTOs with `dump_response(...)`,
|
||||
and avoid adding new legacy `ns.model(...)`, `@marshal_with(...)`, or GET `@ns.expect(...)` patterns.
|
||||
|
||||
### Miscellaneous
|
||||
|
||||
|
||||
@ -24,7 +24,8 @@ RUN apt-get update \
|
||||
# Install Python dependencies (workspace members under providers/vdb/)
|
||||
COPY pyproject.toml uv.lock ./
|
||||
COPY providers ./providers
|
||||
RUN uv sync --locked --no-dev
|
||||
# Trust the checked-in lock during image builds; dev-only path sources live outside the api/ context.
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
# production stage
|
||||
FROM base AS production
|
||||
|
||||
@ -99,7 +99,7 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run ruff check --fix ./ # Fix linting issues
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
uv run pyrefly check # Type checking
|
||||
```
|
||||
|
||||
## Generate TS stub
|
||||
|
||||
@ -117,7 +117,7 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
logger.warning("Failed to add trace headers to response", exc_info=True)
|
||||
return response
|
||||
|
||||
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
||||
# Capture the decorator return values so static checkers do not treat the hooks as unused.
|
||||
_ = before_request
|
||||
_ = add_trace_headers
|
||||
|
||||
@ -181,7 +181,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_import_modules,
|
||||
ext_orjson,
|
||||
ext_forward_refs,
|
||||
ext_set_secretkey,
|
||||
ext_compress,
|
||||
ext_code_based_extension,
|
||||
ext_database,
|
||||
@ -189,6 +188,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_migrate,
|
||||
ext_redis,
|
||||
ext_storage,
|
||||
ext_set_secretkey,
|
||||
ext_logstore, # Initialize logstore after storage, before celery
|
||||
ext_celery,
|
||||
ext_login,
|
||||
|
||||
1
api/clients/__init__.py
Normal file
1
api/clients/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""External service client packages."""
|
||||
74
api/clients/agent_backend/__init__.py
Normal file
74
api/clients/agent_backend/__init__.py
Normal file
@ -0,0 +1,74 @@
|
||||
"""API-side integration boundary for the Dify Agent backend.
|
||||
|
||||
Public wire DTOs come from ``dify_agent.protocol``. This package only contains
|
||||
API adapters: request building from Dify product concepts, a thin client wrapper,
|
||||
event adaptation for future workflow integration, and deterministic fakes.
|
||||
"""
|
||||
|
||||
from clients.agent_backend.client import AgentBackendRunClient, DifyAgentBackendRunClient
|
||||
from clients.agent_backend.errors import (
|
||||
AgentBackendError,
|
||||
AgentBackendHTTPError,
|
||||
AgentBackendRequestBuildError,
|
||||
AgentBackendRunFailedError,
|
||||
AgentBackendStreamError,
|
||||
AgentBackendTransportError,
|
||||
AgentBackendValidationError,
|
||||
)
|
||||
from clients.agent_backend.event_adapter import (
|
||||
AgentBackendInternalEvent,
|
||||
AgentBackendInternalEventType,
|
||||
AgentBackendRunCancelledInternalEvent,
|
||||
AgentBackendRunEventAdapter,
|
||||
AgentBackendRunFailedInternalEvent,
|
||||
AgentBackendRunPausedInternalEvent,
|
||||
AgentBackendRunStartedInternalEvent,
|
||||
AgentBackendRunSucceededInternalEvent,
|
||||
AgentBackendStreamInternalEvent,
|
||||
)
|
||||
from clients.agent_backend.factory import create_agent_backend_run_client
|
||||
from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario
|
||||
from clients.agent_backend.request_builder import (
|
||||
AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
DIFY_PLUGIN_CONTEXT_LAYER_ID,
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
AgentBackendModelConfig,
|
||||
AgentBackendOutputConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
AgentBackendWorkflowNodeRunInput,
|
||||
redact_for_agent_backend_log,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AGENT_SOUL_PROMPT_LAYER_ID",
|
||||
"DIFY_PLUGIN_CONTEXT_LAYER_ID",
|
||||
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
|
||||
"WORKFLOW_USER_PROMPT_LAYER_ID",
|
||||
"AgentBackendError",
|
||||
"AgentBackendHTTPError",
|
||||
"AgentBackendInternalEvent",
|
||||
"AgentBackendInternalEventType",
|
||||
"AgentBackendModelConfig",
|
||||
"AgentBackendOutputConfig",
|
||||
"AgentBackendRequestBuildError",
|
||||
"AgentBackendRunCancelledInternalEvent",
|
||||
"AgentBackendRunClient",
|
||||
"AgentBackendRunEventAdapter",
|
||||
"AgentBackendRunFailedError",
|
||||
"AgentBackendRunFailedInternalEvent",
|
||||
"AgentBackendRunPausedInternalEvent",
|
||||
"AgentBackendRunRequestBuilder",
|
||||
"AgentBackendRunStartedInternalEvent",
|
||||
"AgentBackendRunSucceededInternalEvent",
|
||||
"AgentBackendStreamError",
|
||||
"AgentBackendStreamInternalEvent",
|
||||
"AgentBackendTransportError",
|
||||
"AgentBackendValidationError",
|
||||
"AgentBackendWorkflowNodeRunInput",
|
||||
"DifyAgentBackendRunClient",
|
||||
"FakeAgentBackendRunClient",
|
||||
"FakeAgentBackendScenario",
|
||||
"create_agent_backend_run_client",
|
||||
"redact_for_agent_backend_log",
|
||||
]
|
||||
130
api/clients/agent_backend/client.py
Normal file
130
api/clients/agent_backend/client.py
Normal file
@ -0,0 +1,130 @@
|
||||
"""Synchronous API-side wrapper around the public ``dify-agent`` client.
|
||||
|
||||
``dify-agent`` owns the cross-service DTOs and HTTP/SSE implementation. The API
|
||||
backend keeps this thin wrapper so workflow code depends on a local protocol,
|
||||
gets API-native errors, and can use a deterministic fake in tests without
|
||||
creating another wire contract.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Protocol
|
||||
|
||||
from dify_agent.client import (
|
||||
DifyAgentClientError,
|
||||
DifyAgentHTTPError,
|
||||
DifyAgentStreamError,
|
||||
DifyAgentTimeoutError,
|
||||
DifyAgentValidationError,
|
||||
)
|
||||
from dify_agent.protocol import (
|
||||
CancelRunRequest,
|
||||
CancelRunResponse,
|
||||
CreateRunRequest,
|
||||
CreateRunResponse,
|
||||
RunEvent,
|
||||
RunStatusResponse,
|
||||
)
|
||||
|
||||
from clients.agent_backend.errors import (
|
||||
AgentBackendError,
|
||||
AgentBackendHTTPError,
|
||||
AgentBackendStreamError,
|
||||
AgentBackendTransportError,
|
||||
AgentBackendValidationError,
|
||||
)
|
||||
|
||||
|
||||
class AgentBackendRunClient(Protocol):
|
||||
"""Local boundary used by API workflow integrations to run Agent backend jobs."""
|
||||
|
||||
def create_run(self, request: CreateRunRequest) -> CreateRunResponse:
|
||||
"""Create one Agent backend run and return its accepted status."""
|
||||
|
||||
def cancel_run(self, run_id: str, request: CancelRunRequest | None = None) -> CancelRunResponse:
|
||||
"""Request explicit cancellation for one Agent backend run."""
|
||||
|
||||
def stream_events(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]:
|
||||
"""Yield public ``dify-agent`` run events in stream order."""
|
||||
|
||||
def wait_run(self, run_id: str, *, timeout_seconds: float | None = None) -> RunStatusResponse:
|
||||
"""Wait for a run to reach a terminal status and return that status."""
|
||||
|
||||
|
||||
class _DifyAgentSyncClient(Protocol):
|
||||
"""Subset of ``dify_agent.client.Client`` used by the API wrapper."""
|
||||
|
||||
def create_run_sync(self, request: CreateRunRequest) -> CreateRunResponse:
|
||||
"""Create one run synchronously."""
|
||||
|
||||
def cancel_run_sync(self, run_id: str, request: CancelRunRequest | None = None) -> CancelRunResponse:
|
||||
"""Cancel one run synchronously."""
|
||||
|
||||
def stream_events_sync(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]:
|
||||
"""Stream run events synchronously."""
|
||||
|
||||
def wait_run_sync(self, run_id: str, *, timeout_seconds: float | None = None) -> RunStatusResponse:
|
||||
"""Wait for terminal run status synchronously."""
|
||||
|
||||
|
||||
class DifyAgentBackendRunClient:
|
||||
"""Adapter from API sync call sites to ``dify_agent.client.Client`` sync methods."""
|
||||
|
||||
client: _DifyAgentSyncClient
|
||||
|
||||
def __init__(self, client: _DifyAgentSyncClient) -> None:
|
||||
self.client = client
|
||||
|
||||
def create_run(self, request: CreateRunRequest) -> CreateRunResponse:
|
||||
"""Create one run through ``POST /runs`` and normalize client exceptions."""
|
||||
try:
|
||||
return self.client.create_run_sync(request)
|
||||
except Exception as exc:
|
||||
raise _normalize_dify_agent_error(exc) from exc
|
||||
|
||||
def cancel_run(self, run_id: str, request: CancelRunRequest | None = None) -> CancelRunResponse:
|
||||
"""Cancel one run through ``POST /runs/{run_id}/cancel`` and normalize exceptions."""
|
||||
try:
|
||||
return self.client.cancel_run_sync(run_id, request=request)
|
||||
except Exception as exc:
|
||||
raise _normalize_dify_agent_error(exc) from exc
|
||||
|
||||
def stream_events(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]:
|
||||
"""Stream run events from ``/events/sse`` with the wrapped client's reconnect policy."""
|
||||
try:
|
||||
yield from self.client.stream_events_sync(run_id, after=after)
|
||||
except Exception as exc:
|
||||
raise _normalize_dify_agent_error(exc) from exc
|
||||
|
||||
def wait_run(self, run_id: str, *, timeout_seconds: float | None = None) -> RunStatusResponse:
|
||||
"""Poll run status until terminal state and normalize client exceptions."""
|
||||
try:
|
||||
return self.client.wait_run_sync(run_id, timeout_seconds=timeout_seconds)
|
||||
except Exception as exc:
|
||||
raise _normalize_dify_agent_error(exc) from exc
|
||||
|
||||
|
||||
def _normalize_dify_agent_error(exc: Exception) -> AgentBackendError:
|
||||
"""Map public ``dify-agent`` client errors to API-side integration errors."""
|
||||
match exc:
|
||||
case DifyAgentValidationError() as error:
|
||||
return AgentBackendValidationError(
|
||||
"Agent backend request or response validation failed", detail=error.detail
|
||||
)
|
||||
case DifyAgentHTTPError() as error:
|
||||
return AgentBackendHTTPError(
|
||||
f"Agent backend HTTP {error.status_code}",
|
||||
status_code=error.status_code,
|
||||
detail=error.detail,
|
||||
)
|
||||
case DifyAgentTimeoutError() as error:
|
||||
return AgentBackendTransportError(str(error))
|
||||
case DifyAgentStreamError() as error:
|
||||
return AgentBackendStreamError(str(error))
|
||||
case DifyAgentClientError() as error:
|
||||
return AgentBackendTransportError(str(error))
|
||||
case AgentBackendError() as error:
|
||||
return error
|
||||
case _:
|
||||
return AgentBackendTransportError(str(exc) or type(exc).__name__)
|
||||
61
api/clients/agent_backend/errors.py
Normal file
61
api/clients/agent_backend/errors.py
Normal file
@ -0,0 +1,61 @@
|
||||
"""API-side errors for the Dify Agent backend integration.
|
||||
|
||||
The wire protocol and low-level HTTP behaviour are owned by ``dify-agent``.
|
||||
This module only normalizes those client errors into the API backend's boundary
|
||||
so workflow/node code does not depend directly on transport-specific exception
|
||||
classes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class AgentBackendError(Exception):
|
||||
"""Base error for API-side Agent backend integration failures."""
|
||||
|
||||
|
||||
class AgentBackendRequestBuildError(AgentBackendError):
|
||||
"""Raised when Dify product/workflow state cannot be mapped to a run request."""
|
||||
|
||||
|
||||
class AgentBackendTransportError(AgentBackendError):
|
||||
"""Raised for timeout or request-level failures talking to Agent backend."""
|
||||
|
||||
|
||||
class AgentBackendHTTPError(AgentBackendTransportError):
|
||||
"""Raised for Agent backend HTTP errors after status/detail normalization."""
|
||||
|
||||
status_code: int
|
||||
detail: object
|
||||
|
||||
def __init__(self, message: str, *, status_code: int, detail: object) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentBackendValidationError(AgentBackendError):
|
||||
"""Raised for local request validation or Agent backend 422 responses."""
|
||||
|
||||
detail: object
|
||||
|
||||
def __init__(self, message: str, *, detail: object) -> None:
|
||||
self.detail = detail
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentBackendStreamError(AgentBackendError):
|
||||
"""Raised when an Agent backend event stream is malformed or exhausted."""
|
||||
|
||||
|
||||
class AgentBackendRunFailedError(AgentBackendError):
|
||||
"""Raised by callers that choose to translate a terminal failed run into an exception."""
|
||||
|
||||
run_id: str
|
||||
detail: Any
|
||||
|
||||
def __init__(self, run_id: str, detail: Any) -> None:
|
||||
self.run_id = run_id
|
||||
self.detail = detail
|
||||
super().__init__(f"Agent backend run failed: {run_id}")
|
||||
167
api/clients/agent_backend/event_adapter.py
Normal file
167
api/clients/agent_backend/event_adapter.py
Normal file
@ -0,0 +1,167 @@
|
||||
"""Adapt public ``dify-agent`` run events into API-internal event semantics.
|
||||
|
||||
The adapter does not define a new cross-service event contract. It consumes
|
||||
``dify_agent.protocol.RunEvent`` and produces small API-internal models that the
|
||||
future workflow Agent Node can map to Graphon/AppQueue events in phase 3.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, cast
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.protocol import (
|
||||
PydanticAIStreamRunEvent,
|
||||
RunCancelledEvent,
|
||||
RunEvent,
|
||||
RunFailedEvent,
|
||||
RunPausedEvent,
|
||||
RunStartedEvent,
|
||||
RunSucceededEvent,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
|
||||
|
||||
_EVENT_DATA_ADAPTER = TypeAdapter(object)
|
||||
|
||||
|
||||
class AgentBackendInternalEventType(StrEnum):
|
||||
"""API-only event labels used before Graphon/AppQueue integration."""
|
||||
|
||||
RUN_STARTED = "run_started"
|
||||
STREAM_EVENT = "stream_event"
|
||||
RUN_PAUSED = "run_paused"
|
||||
RUN_SUCCEEDED = "run_succeeded"
|
||||
RUN_FAILED = "run_failed"
|
||||
RUN_CANCELLED = "run_cancelled"
|
||||
|
||||
|
||||
class AgentBackendInternalEventBase(BaseModel):
|
||||
"""Common fields preserved from public Dify Agent run events."""
|
||||
|
||||
run_id: str
|
||||
source_event_id: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class AgentBackendRunStartedInternalEvent(AgentBackendInternalEventBase):
|
||||
"""API-internal marker for a started Agent backend run."""
|
||||
|
||||
type: Literal[AgentBackendInternalEventType.RUN_STARTED] = AgentBackendInternalEventType.RUN_STARTED
|
||||
|
||||
|
||||
class AgentBackendStreamInternalEvent(AgentBackendInternalEventBase):
|
||||
"""API-internal wrapper for one pydantic-ai stream event payload."""
|
||||
|
||||
type: Literal[AgentBackendInternalEventType.STREAM_EVENT] = AgentBackendInternalEventType.STREAM_EVENT
|
||||
event_kind: str | None = None
|
||||
data: JsonValue
|
||||
|
||||
|
||||
class AgentBackendRunSucceededInternalEvent(AgentBackendInternalEventBase):
|
||||
"""API-internal terminal success event carrying final output and session state."""
|
||||
|
||||
type: Literal[AgentBackendInternalEventType.RUN_SUCCEEDED] = AgentBackendInternalEventType.RUN_SUCCEEDED
|
||||
output: JsonValue
|
||||
session_snapshot: CompositorSessionSnapshot
|
||||
|
||||
|
||||
class AgentBackendRunPausedInternalEvent(AgentBackendInternalEventBase):
|
||||
"""API-internal resumable pause event for human handoff and Babysit flows."""
|
||||
|
||||
type: Literal[AgentBackendInternalEventType.RUN_PAUSED] = AgentBackendInternalEventType.RUN_PAUSED
|
||||
reason: str
|
||||
message: str | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
|
||||
|
||||
class AgentBackendRunFailedInternalEvent(AgentBackendInternalEventBase):
|
||||
"""API-internal terminal failure event carrying the backend-safe error text."""
|
||||
|
||||
type: Literal[AgentBackendInternalEventType.RUN_FAILED] = AgentBackendInternalEventType.RUN_FAILED
|
||||
error: str
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class AgentBackendRunCancelledInternalEvent(AgentBackendInternalEventBase):
|
||||
"""API-internal terminal cancellation event."""
|
||||
|
||||
type: Literal[AgentBackendInternalEventType.RUN_CANCELLED] = AgentBackendInternalEventType.RUN_CANCELLED
|
||||
reason: str | None = None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
type AgentBackendInternalEvent = Annotated[
|
||||
AgentBackendRunStartedInternalEvent
|
||||
| AgentBackendStreamInternalEvent
|
||||
| AgentBackendRunPausedInternalEvent
|
||||
| AgentBackendRunSucceededInternalEvent
|
||||
| AgentBackendRunFailedInternalEvent
|
||||
| AgentBackendRunCancelledInternalEvent,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class AgentBackendRunEventAdapter:
|
||||
"""Maps public ``dify-agent`` event variants to API-internal event variants."""
|
||||
|
||||
def adapt(self, event: RunEvent) -> list[AgentBackendInternalEvent]:
|
||||
"""Return zero or more API-internal events derived from one public run event."""
|
||||
match event:
|
||||
case RunStartedEvent():
|
||||
return [
|
||||
AgentBackendRunStartedInternalEvent(
|
||||
run_id=event.run_id,
|
||||
source_event_id=event.id,
|
||||
)
|
||||
]
|
||||
case PydanticAIStreamRunEvent():
|
||||
data = cast(JsonValue, _EVENT_DATA_ADAPTER.dump_python(event.data, mode="json"))
|
||||
event_kind = data.get("event_kind") if isinstance(data, dict) else None
|
||||
return [
|
||||
AgentBackendStreamInternalEvent(
|
||||
run_id=event.run_id,
|
||||
source_event_id=event.id,
|
||||
event_kind=event_kind if isinstance(event_kind, str) else None,
|
||||
data=data,
|
||||
)
|
||||
]
|
||||
case RunSucceededEvent():
|
||||
return [
|
||||
AgentBackendRunSucceededInternalEvent(
|
||||
run_id=event.run_id,
|
||||
source_event_id=event.id,
|
||||
output=event.data.output,
|
||||
session_snapshot=event.data.session_snapshot,
|
||||
)
|
||||
]
|
||||
case RunPausedEvent():
|
||||
return [
|
||||
AgentBackendRunPausedInternalEvent(
|
||||
run_id=event.run_id,
|
||||
source_event_id=event.id,
|
||||
reason=event.data.reason,
|
||||
message=event.data.message,
|
||||
session_snapshot=event.data.session_snapshot,
|
||||
)
|
||||
]
|
||||
case RunFailedEvent():
|
||||
return [
|
||||
AgentBackendRunFailedInternalEvent(
|
||||
run_id=event.run_id,
|
||||
source_event_id=event.id,
|
||||
error=event.data.error,
|
||||
reason=event.data.reason,
|
||||
)
|
||||
]
|
||||
case RunCancelledEvent():
|
||||
return [
|
||||
AgentBackendRunCancelledInternalEvent(
|
||||
run_id=event.run_id,
|
||||
source_event_id=event.id,
|
||||
reason=event.data.reason,
|
||||
message=event.data.message,
|
||||
)
|
||||
]
|
||||
raise TypeError(f"unsupported agent backend run event: {type(event).__name__}")
|
||||
22
api/clients/agent_backend/factory.py
Normal file
22
api/clients/agent_backend/factory.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""Factories for API-side Agent backend clients."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dify_agent.client import Client
|
||||
|
||||
from clients.agent_backend.client import AgentBackendRunClient, DifyAgentBackendRunClient
|
||||
from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario
|
||||
|
||||
|
||||
def create_agent_backend_run_client(
|
||||
*,
|
||||
base_url: str | None = None,
|
||||
use_fake: bool = False,
|
||||
fake_scenario: str | FakeAgentBackendScenario = FakeAgentBackendScenario.SUCCESS,
|
||||
) -> AgentBackendRunClient:
|
||||
"""Create the API-side run client without hiding the ``dify-agent`` protocol."""
|
||||
if use_fake:
|
||||
return FakeAgentBackendRunClient(scenario=FakeAgentBackendScenario(fake_scenario))
|
||||
if base_url is None:
|
||||
raise ValueError("base_url is required when creating a real Agent backend client")
|
||||
return DifyAgentBackendRunClient(Client(base_url=base_url))
|
||||
117
api/clients/agent_backend/fake_client.py
Normal file
117
api/clients/agent_backend/fake_client.py
Normal file
@ -0,0 +1,117 @@
|
||||
"""Deterministic fake Agent backend client using public ``dify-agent`` events.
|
||||
|
||||
Tests should exercise the same ``RunEvent`` DTOs as the real HTTP client. This
|
||||
fake therefore replaces the previous custom mock protocol instead of emulating a
|
||||
separate ``agent-backend.v1`` event stream.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.protocol import (
|
||||
CancelRunRequest,
|
||||
CancelRunResponse,
|
||||
CreateRunRequest,
|
||||
CreateRunResponse,
|
||||
RunEvent,
|
||||
RunFailedEvent,
|
||||
RunFailedEventData,
|
||||
RunStartedEvent,
|
||||
RunStatusResponse,
|
||||
RunSucceededEvent,
|
||||
RunSucceededEventData,
|
||||
)
|
||||
|
||||
_FIXED_TIME = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
|
||||
|
||||
class FakeAgentBackendScenario(StrEnum):
|
||||
"""Deterministic fake scenarios for API-side integration tests."""
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class FakeAgentBackendRunClient:
|
||||
"""In-memory implementation of ``AgentBackendRunClient`` for unit tests."""
|
||||
|
||||
scenario: FakeAgentBackendScenario
|
||||
run_id: str
|
||||
request: CreateRunRequest | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
scenario: FakeAgentBackendScenario = FakeAgentBackendScenario.SUCCESS,
|
||||
run_id: str = "fake-run-1",
|
||||
) -> None:
|
||||
self.scenario = scenario
|
||||
self.run_id = run_id
|
||||
self.request = None
|
||||
|
||||
def create_run(self, request: CreateRunRequest) -> CreateRunResponse:
|
||||
"""Record the request and return a deterministic accepted response."""
|
||||
self.request = request
|
||||
return CreateRunResponse(run_id=self.run_id, status="running")
|
||||
|
||||
def cancel_run(self, run_id: str, request: CancelRunRequest | None = None) -> CancelRunResponse:
|
||||
"""Return a deterministic cancellation response."""
|
||||
del request
|
||||
return CancelRunResponse(run_id=run_id, status="cancelled")
|
||||
|
||||
def stream_events(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]:
|
||||
"""Yield the deterministic public ``RunEvent`` sequence for ``run_id``."""
|
||||
for event in self._events(run_id):
|
||||
if after is not None and event.id is not None and event.id <= after:
|
||||
continue
|
||||
yield event
|
||||
|
||||
def wait_run(self, run_id: str, *, timeout_seconds: float | None = None) -> RunStatusResponse:
|
||||
"""Return a deterministic terminal status; timeout is accepted for protocol parity."""
|
||||
del timeout_seconds
|
||||
match self.scenario:
|
||||
case FakeAgentBackendScenario.SUCCESS:
|
||||
return RunStatusResponse(
|
||||
run_id=run_id,
|
||||
status="succeeded",
|
||||
created_at=_FIXED_TIME,
|
||||
updated_at=_FIXED_TIME,
|
||||
)
|
||||
case FakeAgentBackendScenario.FAILED:
|
||||
return RunStatusResponse(
|
||||
run_id=run_id,
|
||||
status="failed",
|
||||
created_at=_FIXED_TIME,
|
||||
updated_at=_FIXED_TIME,
|
||||
error="fake failure",
|
||||
)
|
||||
|
||||
def _events(self, run_id: str) -> tuple[RunEvent, ...]:
|
||||
match self.scenario:
|
||||
case FakeAgentBackendScenario.SUCCESS:
|
||||
return (
|
||||
RunStartedEvent(id="1-0", run_id=run_id, created_at=_FIXED_TIME),
|
||||
RunSucceededEvent(
|
||||
id="2-0",
|
||||
run_id=run_id,
|
||||
created_at=_FIXED_TIME,
|
||||
data=RunSucceededEventData(
|
||||
output={"text": "hello agent"},
|
||||
session_snapshot=CompositorSessionSnapshot(layers=[]),
|
||||
),
|
||||
),
|
||||
)
|
||||
case FakeAgentBackendScenario.FAILED:
|
||||
return (
|
||||
RunStartedEvent(id="1-0", run_id=run_id, created_at=_FIXED_TIME),
|
||||
RunFailedEvent(
|
||||
id="2-0",
|
||||
run_id=run_id,
|
||||
created_at=_FIXED_TIME,
|
||||
data=RunFailedEventData(error="fake failure", reason="unit_test"),
|
||||
),
|
||||
)
|
||||
192
api/clients/agent_backend/request_builder.py
Normal file
192
api/clients/agent_backend/request_builder.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""Build ``dify-agent`` run requests from API-side product concepts.
|
||||
|
||||
This module is intentionally an adapter, not a wire DTO package. The emitted
|
||||
object is always ``dify_agent.protocol.CreateRunRequest`` so the Agent backend
|
||||
protocol has a single owner. API-only context such as Agent Soul vs workflow job
|
||||
prompt is preserved in layer names and metadata until the dedicated product
|
||||
schemas land in later phases.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from agenton.layers import ExitIntent
|
||||
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
|
||||
from dify_agent.layers.dify_plugin import (
|
||||
DIFY_PLUGIN_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DifyPluginCredentialValue,
|
||||
DifyPluginLayerConfig,
|
||||
DifyPluginLLMLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
|
||||
from dify_agent.protocol import (
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
CreateRunRequest,
|
||||
ExecutionContext,
|
||||
LayerExitSignals,
|
||||
RunComposition,
|
||||
RunLayerSpec,
|
||||
RunPurpose,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
|
||||
|
||||
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
|
||||
DIFY_PLUGIN_CONTEXT_LAYER_ID = "plugin"
|
||||
|
||||
|
||||
class AgentBackendModelConfig(BaseModel):
|
||||
"""API-side model/plugin selection before it is converted to Dify Agent layers."""
|
||||
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
model_provider: str
|
||||
model: str
|
||||
user_id: str | None = None
|
||||
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AgentBackendOutputConfig(BaseModel):
|
||||
"""API-side structured output declaration for the conventional output layer."""
|
||||
|
||||
json_schema: dict[str, JsonValue]
|
||||
name: str = "final_result"
|
||||
description: str | None = None
|
||||
strict: bool | None = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
"""Inputs needed to build the first workflow-node-oriented Agent backend run request."""
|
||||
|
||||
model: AgentBackendModelConfig
|
||||
execution_context: ExecutionContext
|
||||
workflow_node_job_prompt: str
|
||||
user_prompt: str
|
||||
agent_soul_prompt: str | None = None
|
||||
purpose: RunPurpose = "workflow_node"
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
suspend_on_exit: bool = False
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator("workflow_node_job_prompt", "user_prompt")
|
||||
@classmethod
|
||||
def _reject_blank_prompt(cls, value: str) -> str:
|
||||
if not value.strip():
|
||||
raise ValueError("prompt must not be blank")
|
||||
return value
|
||||
|
||||
|
||||
class AgentBackendRunRequestBuilder:
|
||||
"""Converts API product state into the public ``dify-agent`` run protocol."""
|
||||
|
||||
def build_for_workflow_node(self, run_input: AgentBackendWorkflowNodeRunInput) -> CreateRunRequest:
|
||||
"""Build a workflow Agent Node run request without defining another wire schema."""
|
||||
layers: list[RunLayerSpec] = []
|
||||
if run_input.agent_soul_prompt:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_soul"},
|
||||
config=PromptLayerConfig(prefix=run_input.agent_soul_prompt),
|
||||
)
|
||||
)
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
RunLayerSpec(
|
||||
name=WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "workflow_node_job"},
|
||||
config=PromptLayerConfig(prefix=run_input.workflow_node_job_prompt),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "workflow_user_prompt"},
|
||||
config=PromptLayerConfig(user=run_input.user_prompt),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_PLUGIN_CONTEXT_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=DifyPluginLayerConfig(
|
||||
tenant_id=run_input.model.tenant_id,
|
||||
plugin_id=run_input.model.plugin_id,
|
||||
user_id=run_input.model.user_id,
|
||||
),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
deps={"plugin": DIFY_PLUGIN_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=DifyPluginLLMLayerConfig(
|
||||
model_provider=run_input.model.model_provider,
|
||||
model=run_input.model.model,
|
||||
credentials=run_input.model.credentials,
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.output is not None:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
type=DIFY_OUTPUT_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=DifyOutputLayerConfig(
|
||||
json_schema=run_input.output.json_schema,
|
||||
name=run_input.output.name,
|
||||
description=run_input.output.description,
|
||||
strict=run_input.output.strict,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
execution_context=run_input.execution_context,
|
||||
purpose=run_input.purpose,
|
||||
idempotency_key=run_input.idempotency_key,
|
||||
metadata=run_input.metadata,
|
||||
session_snapshot=run_input.session_snapshot,
|
||||
on_exit=LayerExitSignals(
|
||||
default=ExitIntent.SUSPEND if run_input.suspend_on_exit else ExitIntent.DELETE,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_SENSITIVE_KEY_PARTS = ("secret", "credential", "token", "password", "api_key")
|
||||
|
||||
|
||||
def redact_for_agent_backend_log(value: object) -> object:
|
||||
"""Return a JSON-like copy with credential-bearing keys redacted for logs/tests."""
|
||||
if isinstance(value, BaseModel):
|
||||
return redact_for_agent_backend_log(value.model_dump(mode="json", warnings=False))
|
||||
if isinstance(value, dict):
|
||||
redacted: dict[object, object] = {}
|
||||
for key, item in value.items():
|
||||
key_text = str(key).lower()
|
||||
if any(part in key_text for part in _SENSITIVE_KEY_PARTS):
|
||||
redacted[key] = "[REDACTED]"
|
||||
else:
|
||||
redacted[key] = redact_for_agent_backend_log(item)
|
||||
return redacted
|
||||
if isinstance(value, list):
|
||||
return [redact_for_agent_backend_log(item) for item in value]
|
||||
return value
|
||||
@ -185,9 +185,9 @@ def transform_datasource_credentials(environment: str):
|
||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||
jina_plugin_id = "langgenius/jina_datasource"
|
||||
if environment == "online":
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)
|
||||
else:
|
||||
notion_plugin_unique_identifier = None
|
||||
firecrawl_plugin_unique_identifier = None
|
||||
|
||||
@ -14,6 +14,7 @@ from libs.rsa import generate_key_pair
|
||||
from models import Tenant
|
||||
from models.model import App, AppMode, Conversation
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -23,13 +24,16 @@ DB_UPGRADE_LOCK_TTL_SECONDS = 60
|
||||
@click.command(
|
||||
"reset-encrypt-key-pair",
|
||||
help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
|
||||
"After the reset, all LLM credentials will become invalid, "
|
||||
"requiring re-entry."
|
||||
"After the reset, all LLM credentials and tool provider credentials "
|
||||
"(builtin / API / MCP) will be purged, requiring re-entry. "
|
||||
"Only support SELF_HOSTED mode.",
|
||||
)
|
||||
@click.confirmation_option(
|
||||
prompt=click.style(
|
||||
"Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red"
|
||||
"Are you sure you want to reset encrypt key pair? "
|
||||
"This will also purge builtin / API / MCP tool provider records for every tenant. "
|
||||
"This operation cannot be rolled back!",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
def reset_encrypt_key_pair():
|
||||
@ -53,6 +57,13 @@ def reset_encrypt_key_pair():
|
||||
session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id))
|
||||
session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id))
|
||||
|
||||
# Purge tool provider records that hold credentials encrypted under the
|
||||
# tenant key. Leaving them in place causes /console/api/workspaces/current/
|
||||
# tool-providers to 500 because decryption fails on stale ciphertext (#35396).
|
||||
session.execute(delete(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant.id))
|
||||
session.execute(delete(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant.id))
|
||||
session.execute(delete(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant.id))
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
|
||||
@ -23,6 +23,12 @@ class EnterpriseFeatureConfig(BaseSettings):
|
||||
ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
|
||||
)
|
||||
|
||||
ENTERPRISE_DISABLE_RUNTIME_CREDENTIAL_CHECK: bool = Field(
|
||||
default=False,
|
||||
description="If disabled, credential policy check is only performed when saving workflows."
|
||||
"This helps gain runtime performance by trading off consistency.",
|
||||
)
|
||||
|
||||
|
||||
class EnterpriseTelemetryConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
@ -23,9 +23,9 @@ class SecurityConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
SECRET_KEY: str = Field(
|
||||
description="Secret key for secure session cookie signing."
|
||||
"Make sure you are changing this key for your deployment with a strong key."
|
||||
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
|
||||
description="Secret key for secure session cookie signing. "
|
||||
"Leave empty to let Dify generate a persistent key in the storage directory, "
|
||||
"or set a strong value via the `SECRET_KEY` environment variable.",
|
||||
default="",
|
||||
)
|
||||
|
||||
@ -761,7 +761,7 @@ class WorkflowConfig(BaseSettings):
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
default=1,
|
||||
default=3,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
|
||||
@ -1137,6 +1137,18 @@ class MultiModalTransferConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class OpsTraceConfig(BaseSettings):
|
||||
OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES: PositiveInt = Field(
|
||||
description="Maximum retry attempts for transient ops trace provider dispatch failures.",
|
||||
default=60,
|
||||
)
|
||||
|
||||
OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS: PositiveInt = Field(
|
||||
description="Delay in seconds between transient ops trace provider dispatch retry attempts.",
|
||||
default=5,
|
||||
)
|
||||
|
||||
|
||||
class CeleryBeatConfig(BaseSettings):
|
||||
CELERY_BEAT_SCHEDULER_TIME: int = Field(
|
||||
description="Interval in days for Celery Beat scheduler execution, default to 1 day",
|
||||
@ -1298,7 +1310,7 @@ class PositionConfig(BaseSettings):
|
||||
class CollaborationConfig(BaseSettings):
|
||||
ENABLE_COLLABORATION_MODE: bool = Field(
|
||||
description="Whether to enable collaboration mode features across the workspace",
|
||||
default=False,
|
||||
default=True,
|
||||
)
|
||||
|
||||
|
||||
@ -1417,6 +1429,7 @@ class FeatureConfig(
|
||||
ModelLoadBalanceConfig,
|
||||
ModerationConfig,
|
||||
MultiModalTransferConfig,
|
||||
OpsTraceConfig,
|
||||
PositionConfig,
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Literal, TypedDict
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
from urllib.parse import parse_qsl, quote_plus
|
||||
|
||||
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
|
||||
@ -50,28 +50,30 @@ from .vdb.vastbase_vector_config import VastbaseVectorConfig
|
||||
from .vdb.vikingdb_config import VikingDBConfig
|
||||
from .vdb.weaviate_config import WeaviateConfig
|
||||
|
||||
_VALID_STORAGE_TYPE = Literal[
|
||||
"opendal",
|
||||
"s3",
|
||||
"aliyun-oss",
|
||||
"azure-blob",
|
||||
"baidu-obs",
|
||||
"clickzetta-volume",
|
||||
"google-storage",
|
||||
"huawei-obs",
|
||||
"oci-storage",
|
||||
"tencent-cos",
|
||||
"volcengine-tos",
|
||||
"supabase",
|
||||
"local",
|
||||
]
|
||||
|
||||
|
||||
class StorageConfig(BaseSettings):
|
||||
STORAGE_TYPE: Literal[
|
||||
"opendal",
|
||||
"s3",
|
||||
"aliyun-oss",
|
||||
"azure-blob",
|
||||
"baidu-obs",
|
||||
"clickzetta-volume",
|
||||
"google-storage",
|
||||
"huawei-obs",
|
||||
"oci-storage",
|
||||
"tencent-cos",
|
||||
"volcengine-tos",
|
||||
"supabase",
|
||||
"local",
|
||||
] = Field(
|
||||
STORAGE_TYPE: _VALID_STORAGE_TYPE = Field(
|
||||
description="Type of storage to use."
|
||||
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', "
|
||||
"'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
|
||||
"'volcengine-tos', 'supabase'. Default is 'opendal'.",
|
||||
default="opendal",
|
||||
default=cast(_VALID_STORAGE_TYPE, "opendal"),
|
||||
)
|
||||
|
||||
STORAGE_LOCAL_PATH: str = Field(
|
||||
@ -114,7 +116,7 @@ class SQLAlchemyEngineOptionsDict(TypedDict):
|
||||
pool_pre_ping: bool
|
||||
connect_args: dict[str, str]
|
||||
pool_use_lifo: bool
|
||||
pool_reset_on_return: None
|
||||
pool_reset_on_return: Literal["commit", "rollback", None]
|
||||
pool_timeout: int
|
||||
|
||||
|
||||
@ -223,6 +225,11 @@ class DatabaseConfig(BaseSettings):
|
||||
default=30,
|
||||
)
|
||||
|
||||
SQLALCHEMY_POOL_RESET_ON_RETURN: Literal["commit", "rollback", None] = Field(
|
||||
description="Connection pool reset behavior on return. Options: 'commit', 'rollback', or None",
|
||||
default="rollback",
|
||||
)
|
||||
|
||||
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
|
||||
description="Number of processes for the retrieval service, default to CPU cores.",
|
||||
default=os.cpu_count() or 1,
|
||||
@ -252,7 +259,7 @@ class DatabaseConfig(BaseSettings):
|
||||
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
|
||||
"connect_args": connect_args,
|
||||
"pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
|
||||
"pool_reset_on_return": None,
|
||||
"pool_reset_on_return": self.SQLALCHEMY_POOL_RESET_ON_RETURN,
|
||||
"pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT,
|
||||
}
|
||||
return result
|
||||
|
||||
38
api/configs/secret_key.py
Normal file
38
api/configs/secret_key.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""SECRET_KEY persistence helpers for runtime setup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
GENERATED_SECRET_KEY_FILENAME = ".dify_secret_key"
|
||||
|
||||
|
||||
def resolve_secret_key(secret_key: str) -> str:
|
||||
"""Return an explicit SECRET_KEY or a generated key persisted in storage."""
|
||||
if secret_key:
|
||||
return secret_key
|
||||
|
||||
return _load_or_create_secret_key()
|
||||
|
||||
|
||||
def _load_or_create_secret_key() -> str:
|
||||
try:
|
||||
persisted_key = storage.load_once(GENERATED_SECRET_KEY_FILENAME).decode("utf-8").strip()
|
||||
if persisted_key:
|
||||
return persisted_key
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
generated_key = secrets.token_urlsafe(48)
|
||||
|
||||
try:
|
||||
storage.save(GENERATED_SECRET_KEY_FILENAME, f"{generated_key}\n".encode())
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"SECRET_KEY is not set and could not be generated at {GENERATED_SECRET_KEY_FILENAME}. "
|
||||
"Set SECRET_KEY explicitly or make storage writable."
|
||||
) from exc
|
||||
|
||||
return generated_key
|
||||
91
api/conftest.py
Normal file
91
api/conftest.py
Normal file
@ -0,0 +1,91 @@
|
||||
"""Global pytest hooks for Dify backend tests.
|
||||
|
||||
This root conftest is loaded before package-specific conftests, which lets tests opt
|
||||
into Docker-backed middleware before application modules read environment config.
|
||||
It intentionally lives at the API root because pytest applies conftest.py files to
|
||||
tests below their directory, and this setup is shared by api/tests and api/providers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.pytest_dify import (
|
||||
DEFAULT_MIDDLEWARE_SERVICES,
|
||||
DEFAULT_VDB_SERVICES,
|
||||
DockerComposeStack,
|
||||
build_middleware_stack,
|
||||
build_vdb_stack,
|
||||
ensure_backend_test_environment,
|
||||
ensure_compose_env_files,
|
||||
parse_services,
|
||||
)
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parent.parent
|
||||
_DIFY_COMPOSE_STACKS_KEY = pytest.StashKey[list[DockerComposeStack]]()
|
||||
|
||||
# This must run at import time because package-specific conftests can import the
|
||||
# Flask app before pytest_configure hooks from this file are called.
|
||||
ensure_backend_test_environment(_REPO_ROOT)
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
group = parser.getgroup("dify")
|
||||
group.addoption(
|
||||
"--start-middleware",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Start the Docker middleware services needed by API integration tests.",
|
||||
)
|
||||
group.addoption(
|
||||
"--middleware-services",
|
||||
default=",".join(DEFAULT_MIDDLEWARE_SERVICES),
|
||||
help="Comma-separated services from docker/docker-compose.middleware.yaml to start.",
|
||||
)
|
||||
group.addoption(
|
||||
"--start-vdb",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Start vector-store Docker services for VDB integration tests.",
|
||||
)
|
||||
group.addoption(
|
||||
"--vdb-services",
|
||||
default=",".join(DEFAULT_VDB_SERVICES),
|
||||
help="Comma-separated services from docker/docker-compose.yaml to start for VDB tests.",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
config.stash[_DIFY_COMPOSE_STACKS_KEY] = []
|
||||
|
||||
|
||||
def pytest_sessionstart(session: pytest.Session) -> None:
|
||||
config = session.config
|
||||
if hasattr(config, "workerinput"):
|
||||
return
|
||||
|
||||
stacks: list[DockerComposeStack] = []
|
||||
if config.getoption("start_middleware"):
|
||||
ensure_compose_env_files(_REPO_ROOT)
|
||||
stack = build_middleware_stack(_REPO_ROOT, parse_services(config.getoption("middleware_services")))
|
||||
stack.up()
|
||||
stacks.append(stack)
|
||||
|
||||
if config.getoption("start_vdb"):
|
||||
ensure_compose_env_files(_REPO_ROOT)
|
||||
stack = build_vdb_stack(_REPO_ROOT, parse_services(config.getoption("vdb_services")))
|
||||
stack.up()
|
||||
stacks.append(stack)
|
||||
|
||||
config.stash[_DIFY_COMPOSE_STACKS_KEY] = stacks
|
||||
|
||||
|
||||
def pytest_unconfigure(config: pytest.Config) -> None:
|
||||
if hasattr(config, "workerinput"):
|
||||
return
|
||||
|
||||
stacks = config.stash.get(_DIFY_COMPOSE_STACKS_KEY, [])
|
||||
for stack in reversed(stacks):
|
||||
stack.down()
|
||||
211
api/controllers/API_SCHEMA_GUIDE.md
Normal file
211
api/controllers/API_SCHEMA_GUIDE.md
Normal file
@ -0,0 +1,211 @@
|
||||
# API Schema Guide
|
||||
|
||||
This guide describes the expected Flask-RESTX + Pydantic pattern for controller request payloads, query
|
||||
parameters, response schemas, and Swagger documentation.
|
||||
|
||||
## Principles
|
||||
|
||||
- Use Pydantic `BaseModel` for request bodies and query parameters.
|
||||
- Use `fields.base.ResponseModel` for response DTOs.
|
||||
- Keep runtime validation and Swagger documentation wired to the same Pydantic model.
|
||||
- Prefer explicit validation and serialization in controller methods over Flask-RESTX marshalling.
|
||||
- Do not add new Flask-RESTX `fields.*` dictionaries, `Namespace.model(...)` exports, or `@marshal_with(...)` for migrated or new endpoints.
|
||||
- Do not use `@ns.expect(...)` for GET query parameters. Flask-RESTX documents that as a request body.
|
||||
|
||||
## Naming
|
||||
|
||||
- Request body models: use a `Payload` suffix.
|
||||
- Example: `WorkflowRunPayload`, `DatasourceVariablesPayload`.
|
||||
- Query parameter models: use a `Query` suffix.
|
||||
- Example: `WorkflowRunListQuery`, `MessageListQuery`.
|
||||
- Response models: use a `Response` suffix and inherit from `ResponseModel`.
|
||||
- Example: `WorkflowRunDetailResponse`, `WorkflowRunNodeExecutionListResponse`.
|
||||
- Use `ListResponse` or `PaginationResponse` for wrapper responses.
|
||||
- Example: `WorkflowRunNodeExecutionListResponse`, `WorkflowRunPaginationResponse`.
|
||||
- Keep these models near the controller when they are endpoint-specific. Move them to `fields/*_fields.py` only when shared by multiple controllers.
|
||||
|
||||
## Registering Models For Swagger
|
||||
|
||||
Use helpers from `controllers.common.schema`.
|
||||
|
||||
```python
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
```
|
||||
|
||||
Register request payload and query models with `register_schema_models(...)`:
|
||||
|
||||
```python
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
WorkflowRunPayload,
|
||||
WorkflowRunListQuery,
|
||||
)
|
||||
```
|
||||
|
||||
Register response models with `register_response_schema_models(...)`:
|
||||
|
||||
```python
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
WorkflowRunDetailResponse,
|
||||
WorkflowRunPaginationResponse,
|
||||
)
|
||||
```
|
||||
|
||||
Response models are registered in Pydantic serialization mode. This matters when a response model uses
|
||||
`validation_alias` to read internal object attributes but emits public API field names. For example, a response model
|
||||
can validate from `inputs_dict` while documenting and serializing `inputs`.
|
||||
|
||||
## Request Bodies
|
||||
|
||||
For non-GET request bodies:
|
||||
|
||||
1. Define a Pydantic `Payload` model.
|
||||
2. Register it with `register_schema_models(...)`.
|
||||
3. Use `@ns.expect(ns.models[Payload.__name__])` for Swagger documentation.
|
||||
4. Validate from `ns.payload or {}` inside the controller.
|
||||
|
||||
```python
|
||||
class DraftWorkflowNodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
|
||||
|
||||
register_schema_models(console_ns, DraftWorkflowNodeRunPayload)
|
||||
|
||||
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
payload = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
result = service.run(..., inputs=payload.inputs, query=payload.query)
|
||||
return dump_response(WorkflowRunNodeExecutionResponse, result)
|
||||
```
|
||||
|
||||
## Query Parameters
|
||||
|
||||
For GET query parameters:
|
||||
|
||||
1. Define a Pydantic `Query` model.
|
||||
2. Register it with `register_schema_models(...)` if it is referenced elsewhere in docs, or only use
|
||||
`query_params_from_model(...)` if a body schema is not needed.
|
||||
3. Use `@ns.doc(params=query_params_from_model(QueryModel))`.
|
||||
4. Validate from `request.args.to_dict(flat=True)` or an explicit dict when type coercion is needed.
|
||||
|
||||
```python
|
||||
class WorkflowRunListQuery(BaseModel):
|
||||
last_id: str | None = Field(default=None, description="Last run ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||
|
||||
|
||||
@console_ns.doc(params=query_params_from_model(WorkflowRunListQuery))
|
||||
def get(self, app_model: App):
|
||||
query = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
result = service.list(..., limit=query.limit, last_id=query.last_id)
|
||||
return dump_response(WorkflowRunPaginationResponse, result)
|
||||
```
|
||||
|
||||
Do not do this for GET query parameters:
|
||||
|
||||
```python
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
def get(...):
|
||||
...
|
||||
```
|
||||
|
||||
That documents a GET request body and is not the expected contract.
|
||||
|
||||
## Responses
|
||||
|
||||
Response models should inherit from `ResponseModel`:
|
||||
|
||||
```python
|
||||
class WorkflowRunNodeExecutionResponse(ResponseModel):
|
||||
id: str
|
||||
inputs: Any = Field(default=None, validation_alias="inputs_dict")
|
||||
process_data: Any = Field(default=None, validation_alias="process_data_dict")
|
||||
outputs: Any = Field(default=None, validation_alias="outputs_dict")
|
||||
```
|
||||
|
||||
Document response models with `@ns.response(...)`:
|
||||
|
||||
```python
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Node run started successfully",
|
||||
console_ns.models[WorkflowRunNodeExecutionResponse.__name__],
|
||||
)
|
||||
def post(...):
|
||||
...
|
||||
```
|
||||
|
||||
Serialize explicitly:
|
||||
|
||||
```python
|
||||
return dump_response(WorkflowRunNodeExecutionResponse, workflow_node_execution)
|
||||
```
|
||||
|
||||
`dump_response(...)` is the preferred response serialization helper for a single Pydantic response DTO. It validates
|
||||
with `from_attributes=True` and returns `model_dump(mode="json")`, so SQLAlchemy models, plain objects, dictionaries,
|
||||
Pydantic aliases, computed fields, and `datetime` values are serialized consistently.
|
||||
|
||||
For wrapper responses, pass a dictionary with the public wrapper fields:
|
||||
|
||||
```python
|
||||
return dump_response(
|
||||
WorkflowRunPaginationResponse,
|
||||
{
|
||||
"data": workflow_runs,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
If the service can return `None`, translate that into the expected HTTP error before validation:
|
||||
|
||||
```python
|
||||
workflow_run = service.get_workflow_run(...)
|
||||
if workflow_run is None:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
return dump_response(WorkflowRunDetailResponse, workflow_run)
|
||||
```
|
||||
|
||||
Use manual `model_validate(...).model_dump(...)` only when the endpoint needs behavior that `dump_response(...)` does
|
||||
not provide, such as returning a non-dict payload, intentionally excluding fields, or composing a `(body, status)` tuple.
|
||||
|
||||
## Legacy Flask-RESTX Patterns
|
||||
|
||||
Avoid adding these patterns to new or migrated endpoints:
|
||||
|
||||
- `ns.model(...)` for new request/response DTOs.
|
||||
- Module-level exported RESTX model objects such as `workflow_run_detail_model`.
|
||||
- `fields.Nested({...})` with raw inline dict field maps.
|
||||
- `@marshal_with(...)` for response serialization.
|
||||
- `@ns.expect(...)` for GET query params.
|
||||
|
||||
Existing legacy field dictionaries may remain where an endpoint has not yet been migrated. Keep that compatibility local
|
||||
to the legacy area and avoid importing RESTX model objects from controllers.
|
||||
|
||||
## Verifying Swagger
|
||||
|
||||
For schema and documentation changes, run focused tests and generate Swagger JSON:
|
||||
|
||||
```bash
|
||||
uv run --project . pytest tests/unit_tests/controllers/common/test_schema.py
|
||||
uv run --project . pytest tests/unit_tests/commands/test_generate_swagger_specs.py tests/unit_tests/controllers/test_swagger.py
|
||||
uv run --project . dev/generate_swagger_specs.py --output-dir /tmp/dify-openapi-check
|
||||
```
|
||||
|
||||
Inspect affected endpoints with `jq`. Check that:
|
||||
|
||||
- GET parameters are `in: query`.
|
||||
- Request bodies appear only where the endpoint has a body.
|
||||
- Responses reference the expected `*Response` schema.
|
||||
- Response schemas use public serialized names, not internal validation aliases like `inputs_dict`.
|
||||
@ -2,8 +2,9 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
from pydantic import BaseModel, ConfigDict, Field, computed_field
|
||||
|
||||
from fields.base import ResponseModel
|
||||
from graphon.file import helpers as file_helpers
|
||||
from models.model import IconType
|
||||
|
||||
@ -19,6 +20,113 @@ class SystemParameters(BaseModel):
|
||||
workflow_file_upload_limit: int
|
||||
|
||||
|
||||
class SimpleResultResponse(ResponseModel):
|
||||
result: str
|
||||
|
||||
|
||||
class SimpleResultMessageResponse(ResponseModel):
|
||||
result: str
|
||||
message: str
|
||||
|
||||
|
||||
class SimpleMessageResponse(ResponseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class SimpleDataResponse(ResponseModel):
|
||||
data: str
|
||||
|
||||
|
||||
class SimpleResultDataResponse(ResponseModel):
|
||||
result: str
|
||||
data: str
|
||||
|
||||
|
||||
class SimpleResultStringListResponse(ResponseModel):
|
||||
result: str
|
||||
data: list[str]
|
||||
|
||||
|
||||
class SimpleResultOptionalDataResponse(ResponseModel):
|
||||
result: str
|
||||
data: str | None = None
|
||||
|
||||
|
||||
class AccessTokenData(ResponseModel):
|
||||
access_token: str
|
||||
|
||||
|
||||
class AccessTokenResultResponse(ResponseModel):
|
||||
result: str
|
||||
data: AccessTokenData
|
||||
|
||||
|
||||
class VerificationTokenResponse(ResponseModel):
|
||||
is_valid: bool
|
||||
email: str
|
||||
token: str
|
||||
|
||||
|
||||
class LoginStatusResponse(ResponseModel):
|
||||
logged_in: bool
|
||||
app_logged_in: bool
|
||||
|
||||
|
||||
class AccessModeResponse(ResponseModel):
|
||||
access_mode: str = Field(serialization_alias="accessMode", validation_alias="accessMode")
|
||||
|
||||
|
||||
class BooleanResultResponse(ResponseModel):
|
||||
result: bool
|
||||
|
||||
|
||||
class SuccessResponse(ResponseModel):
|
||||
success: bool
|
||||
|
||||
|
||||
class UsageCheckResponse(ResponseModel):
|
||||
is_using: bool
|
||||
|
||||
|
||||
class UsageCountResponse(ResponseModel):
|
||||
is_using: bool
|
||||
count: int
|
||||
|
||||
|
||||
class IndexInfoResponse(ResponseModel):
|
||||
welcome: str
|
||||
api_version: str
|
||||
server_version: str
|
||||
|
||||
|
||||
class AvatarUrlResponse(ResponseModel):
|
||||
avatar_url: str
|
||||
|
||||
|
||||
class TextContentResponse(ResponseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class AllowedExtensionsResponse(ResponseModel):
|
||||
allowed_extensions: list[str]
|
||||
|
||||
|
||||
class UrlResponse(ResponseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class RedirectUrlResponse(ResponseModel):
|
||||
redirect_url: str
|
||||
|
||||
|
||||
class ApiBaseUrlResponse(ResponseModel):
|
||||
api_base_url: str
|
||||
|
||||
|
||||
class NewAppResponse(ResponseModel):
|
||||
new_app_id: str
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
opening_statement: str | None = None
|
||||
suggested_questions: list[str]
|
||||
|
||||
@ -1,6 +1,21 @@
|
||||
import json
|
||||
|
||||
from pydantic import BaseModel, JsonValue
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict[str, JsonValue]
|
||||
action: str
|
||||
|
||||
|
||||
def stringify_form_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
"""Serialize default values into strings expected by human-input form clients."""
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
@ -1,6 +1,14 @@
|
||||
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
|
||||
"""Helpers for registering Pydantic models with Flask-RESTX namespaces.
|
||||
|
||||
Flask-RESTX treats `SchemaModel` bodies as opaque JSON schemas; it does not
|
||||
promote Pydantic's nested `$defs` into top-level Swagger `definitions`.
|
||||
These helpers keep that translation centralized so models registered through
|
||||
`register_schema_models` emit resolvable Swagger 2.0 references.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, NotRequired, TypedDict
|
||||
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
@ -8,10 +16,89 @@ from pydantic import BaseModel, TypeAdapter
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
|
||||
"""Register a single BaseModel with a namespace for Swagger documentation."""
|
||||
QueryParamDoc = TypedDict(
|
||||
"QueryParamDoc",
|
||||
{
|
||||
"in": NotRequired[str],
|
||||
"type": NotRequired[str],
|
||||
"items": NotRequired[dict[str, object]],
|
||||
"required": NotRequired[bool],
|
||||
"description": NotRequired[str],
|
||||
"enum": NotRequired[list[object]],
|
||||
"default": NotRequired[object],
|
||||
"minimum": NotRequired[int | float],
|
||||
"maximum": NotRequired[int | float],
|
||||
"minLength": NotRequired[int],
|
||||
"maxLength": NotRequired[int],
|
||||
"minItems": NotRequired[int],
|
||||
"maxItems": NotRequired[int],
|
||||
},
|
||||
)
|
||||
|
||||
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
|
||||
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
|
||||
|
||||
schema = _swagger_2_compatible_schema(schema)
|
||||
nested_definitions = schema.get("$defs")
|
||||
schema_to_register = dict(schema)
|
||||
if isinstance(nested_definitions, dict):
|
||||
schema_to_register.pop("$defs")
|
||||
|
||||
namespace.schema_model(name, schema_to_register)
|
||||
|
||||
if not isinstance(nested_definitions, dict):
|
||||
return
|
||||
|
||||
for nested_name, nested_schema in nested_definitions.items():
|
||||
if isinstance(nested_schema, dict):
|
||||
_register_json_schema(namespace, nested_name, nested_schema)
|
||||
|
||||
|
||||
JsonSchemaMode = Literal["validation", "serialization"]
|
||||
|
||||
|
||||
def _register_schema_model(namespace: Namespace, model: type[BaseModel], *, mode: JsonSchemaMode) -> None:
|
||||
_register_json_schema(
|
||||
namespace,
|
||||
model.__name__,
|
||||
model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0, mode=mode),
|
||||
)
|
||||
|
||||
|
||||
def _swagger_2_compatible_schema(value: Any) -> Any:
|
||||
if isinstance(value, list):
|
||||
return [_swagger_2_compatible_schema(item) for item in value]
|
||||
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
converted = {key: _swagger_2_compatible_schema(child) for key, child in value.items()}
|
||||
any_of = value.get("anyOf")
|
||||
if not isinstance(any_of, list):
|
||||
return converted
|
||||
|
||||
non_null_candidates = [
|
||||
candidate for candidate in any_of if isinstance(candidate, Mapping) and candidate.get("type") != "null"
|
||||
]
|
||||
has_null_candidate = any(isinstance(candidate, Mapping) and candidate.get("type") == "null" for candidate in any_of)
|
||||
if not has_null_candidate or len(non_null_candidates) != 1:
|
||||
return converted
|
||||
|
||||
non_null_schema = _swagger_2_compatible_schema(dict(non_null_candidates[0]))
|
||||
if not isinstance(non_null_schema, dict):
|
||||
return converted
|
||||
|
||||
converted.pop("anyOf", None)
|
||||
converted.update(non_null_schema)
|
||||
converted["x-nullable"] = True
|
||||
return converted
|
||||
|
||||
|
||||
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
|
||||
"""Register a BaseModel and its nested schema definitions for Swagger documentation."""
|
||||
|
||||
_register_schema_model(namespace, model, mode="validation")
|
||||
|
||||
|
||||
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
|
||||
@ -21,6 +108,19 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No
|
||||
register_schema_model(namespace, model)
|
||||
|
||||
|
||||
def register_response_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
|
||||
"""Register a BaseModel using its serialized response shape."""
|
||||
|
||||
_register_schema_model(namespace, model, mode="serialization")
|
||||
|
||||
|
||||
def register_response_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
|
||||
"""Register multiple response BaseModels using their serialized response shape."""
|
||||
|
||||
for model in models:
|
||||
register_response_schema_model(namespace, model)
|
||||
|
||||
|
||||
def get_or_create_model(model_name: str, field_def):
|
||||
# Import lazily to avoid circular imports between console controllers and schema helpers.
|
||||
from controllers.console import console_ns
|
||||
@ -34,15 +134,114 @@ def get_or_create_model(model_name: str, field_def):
|
||||
def register_enum_models(namespace: Namespace, *models: type[StrEnum]) -> None:
|
||||
"""Register multiple StrEnum with a namespace."""
|
||||
for model in models:
|
||||
namespace.schema_model(
|
||||
model.__name__, TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
_register_json_schema(
|
||||
namespace,
|
||||
model.__name__,
|
||||
TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def query_params_from_model(model: type[BaseModel]) -> dict[str, QueryParamDoc]:
|
||||
"""Build Flask-RESTX query parameter docs from a flat Pydantic model.
|
||||
|
||||
`Namespace.expect()` treats Pydantic schema models as request bodies, so GET
|
||||
endpoints should keep runtime validation on the Pydantic model and feed this
|
||||
derived mapping to `Namespace.doc(params=...)` for Swagger documentation.
|
||||
"""
|
||||
|
||||
schema = model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
properties = schema.get("properties", {})
|
||||
if not isinstance(properties, Mapping):
|
||||
return {}
|
||||
|
||||
required = schema.get("required", [])
|
||||
required_names = set(required) if isinstance(required, list) else set()
|
||||
|
||||
params: dict[str, QueryParamDoc] = {}
|
||||
for name, property_schema in properties.items():
|
||||
if not isinstance(name, str) or not isinstance(property_schema, Mapping):
|
||||
continue
|
||||
|
||||
params[name] = _query_param_from_property(property_schema, required=name in required_names)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _query_param_from_property(property_schema: Mapping[str, Any], *, required: bool) -> QueryParamDoc:
|
||||
param_schema = _nullable_property_schema(property_schema)
|
||||
param_doc: QueryParamDoc = {"in": "query", "required": required}
|
||||
|
||||
description = param_schema.get("description")
|
||||
if isinstance(description, str):
|
||||
param_doc["description"] = description
|
||||
|
||||
schema_type = param_schema.get("type")
|
||||
if isinstance(schema_type, str) and schema_type in {"array", "boolean", "integer", "number", "string"}:
|
||||
param_doc["type"] = schema_type
|
||||
if schema_type == "array":
|
||||
items = param_schema.get("items")
|
||||
if isinstance(items, Mapping):
|
||||
item_type = items.get("type")
|
||||
if isinstance(item_type, str):
|
||||
param_doc["items"] = {"type": item_type}
|
||||
|
||||
enum = param_schema.get("enum")
|
||||
if isinstance(enum, list):
|
||||
param_doc["enum"] = enum
|
||||
|
||||
default = param_schema.get("default")
|
||||
if default is not None:
|
||||
param_doc["default"] = default
|
||||
|
||||
minimum = param_schema.get("minimum")
|
||||
if isinstance(minimum, int | float):
|
||||
param_doc["minimum"] = minimum
|
||||
|
||||
maximum = param_schema.get("maximum")
|
||||
if isinstance(maximum, int | float):
|
||||
param_doc["maximum"] = maximum
|
||||
|
||||
min_length = param_schema.get("minLength")
|
||||
if isinstance(min_length, int):
|
||||
param_doc["minLength"] = min_length
|
||||
|
||||
max_length = param_schema.get("maxLength")
|
||||
if isinstance(max_length, int):
|
||||
param_doc["maxLength"] = max_length
|
||||
|
||||
min_items = param_schema.get("minItems")
|
||||
if isinstance(min_items, int):
|
||||
param_doc["minItems"] = min_items
|
||||
|
||||
max_items = param_schema.get("maxItems")
|
||||
if isinstance(max_items, int):
|
||||
param_doc["maxItems"] = max_items
|
||||
|
||||
return param_doc
|
||||
|
||||
|
||||
def _nullable_property_schema(property_schema: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
any_of = property_schema.get("anyOf")
|
||||
if not isinstance(any_of, list):
|
||||
return property_schema
|
||||
|
||||
non_null_candidates = [
|
||||
candidate for candidate in any_of if isinstance(candidate, Mapping) and candidate.get("type") != "null"
|
||||
]
|
||||
|
||||
if len(non_null_candidates) == 1:
|
||||
return {**property_schema, **non_null_candidates[0]}
|
||||
|
||||
return property_schema
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
|
||||
"get_or_create_model",
|
||||
"query_params_from_model",
|
||||
"register_enum_models",
|
||||
"register_response_schema_model",
|
||||
"register_response_schema_models",
|
||||
"register_schema_model",
|
||||
"register_schema_models",
|
||||
]
|
||||
|
||||
@ -33,7 +33,6 @@ for module_name in RESOURCE_MODULES:
|
||||
# Ensure resource modules are imported so route decorators are evaluated.
|
||||
# Import other controllers
|
||||
from . import (
|
||||
admin,
|
||||
apikey,
|
||||
extension,
|
||||
feature,
|
||||
@ -45,6 +44,8 @@ from . import (
|
||||
spec,
|
||||
version,
|
||||
)
|
||||
from .agent import composer as agent_composer
|
||||
from .agent import roster as agent_roster
|
||||
|
||||
# Import app controllers
|
||||
from .app import (
|
||||
@ -117,7 +118,7 @@ from .explore import (
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
|
||||
from .socketio import workflow as socketio_workflow
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
@ -142,10 +143,11 @@ api.add_namespace(console_ns)
|
||||
__all__ = [
|
||||
"account",
|
||||
"activate",
|
||||
"admin",
|
||||
"advanced_prompt_template",
|
||||
"agent",
|
||||
"agent_composer",
|
||||
"agent_providers",
|
||||
"agent_roster",
|
||||
"annotation",
|
||||
"api",
|
||||
"apikey",
|
||||
|
||||
@ -1,72 +1,11 @@
|
||||
import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService, LangContentDict
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class InsertExploreAppPayload(BaseModel):
|
||||
app_id: str = Field(...)
|
||||
desc: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
language: str = Field(...)
|
||||
category: str = Field(...)
|
||||
position: int = Field(...)
|
||||
can_trial: bool = Field(default=False)
|
||||
trial_limit: int = Field(default=0)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
class InsertExploreBannerPayload(BaseModel):
|
||||
category: str = Field(...)
|
||||
title: str = Field(...)
|
||||
description: str = Field(...)
|
||||
img_src: str = Field(..., alias="img-src")
|
||||
language: str = Field(default="en-US")
|
||||
link: str = Field(...)
|
||||
sort: int = Field(...)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreAppPayload.__name__,
|
||||
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreBannerPayload.__name__,
|
||||
InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@ -84,361 +23,3 @@ def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-apps")
|
||||
class InsertExploreAppListApi(Resource):
|
||||
@console_ns.doc("insert_explore_app")
|
||||
@console_ns.doc(description="Insert or update an app in the explore list")
|
||||
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
|
||||
@console_ns.response(200, "App updated successfully")
|
||||
@console_ns.response(201, "App inserted successfully")
|
||||
@console_ns.response(404, "App not found")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
|
||||
if not app:
|
||||
raise NotFound(f"App '{payload.app_id}' is not found")
|
||||
|
||||
site = app.site
|
||||
if not site:
|
||||
desc = payload.desc or ""
|
||||
copy_right = payload.copyright or ""
|
||||
privacy_policy = payload.privacy_policy or ""
|
||||
custom_disclaimer = payload.custom_disclaimer or ""
|
||||
else:
|
||||
desc = site.description or payload.desc or ""
|
||||
copy_right = site.copyright or payload.copyright or ""
|
||||
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
|
||||
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
recommended_app = RecommendedApp(
|
||||
app_id=app.id,
|
||||
description=desc,
|
||||
copyright=copy_right,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
language=payload.language,
|
||||
category=payload.category,
|
||||
position=payload.position,
|
||||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
if payload.can_trial:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=payload.app_id,
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=payload.trial_limit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = payload.trial_limit
|
||||
|
||||
app.is_public = True
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 201
|
||||
else:
|
||||
recommended_app.description = desc
|
||||
recommended_app.copyright = copy_right
|
||||
recommended_app.privacy_policy = privacy_policy
|
||||
recommended_app.custom_disclaimer = custom_disclaimer
|
||||
recommended_app.language = payload.language
|
||||
recommended_app.category = payload.category
|
||||
recommended_app.position = payload.position
|
||||
|
||||
if payload.can_trial:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=payload.app_id,
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=payload.trial_limit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = payload.trial_limit
|
||||
app.is_public = True
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
|
||||
class InsertExploreAppApi(Resource):
|
||||
@console_ns.doc("delete_explore_app")
|
||||
@console_ns.doc(description="Remove an app from the explore list")
|
||||
@console_ns.doc(params={"app_id": "Application ID to remove"})
|
||||
@console_ns.response(204, "App removed successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id):
|
||||
with session_factory.create_session() as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
return {"result": "success"}, 204
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
|
||||
|
||||
if app:
|
||||
app.is_public = False
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
installed_apps = (
|
||||
session.execute(
|
||||
select(InstalledApp).where(
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
|
||||
trial_app = session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
|
||||
).scalar_one_or_none()
|
||||
if trial_app:
|
||||
session.delete(trial_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-banner")
|
||||
class InsertExploreBannerApi(Resource):
|
||||
@console_ns.doc("insert_explore_banner")
|
||||
@console_ns.doc(description="Insert an explore banner")
|
||||
@console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
|
||||
@console_ns.response(201, "Banner inserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
|
||||
|
||||
banner = ExporleBanner(
|
||||
content={
|
||||
"category": payload.category,
|
||||
"title": payload.title,
|
||||
"description": payload.description,
|
||||
"img-src": payload.img_src,
|
||||
},
|
||||
link=payload.link,
|
||||
sort=payload.sort,
|
||||
language=payload.language,
|
||||
)
|
||||
db.session.add(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
|
||||
class DeleteExploreBannerApi(Resource):
|
||||
@console_ns.doc("delete_explore_banner")
|
||||
@console_ns.doc(description="Delete an explore banner")
|
||||
@console_ns.doc(params={"banner_id": "Banner ID to delete"})
|
||||
@console_ns.response(204, "Banner deleted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, banner_id):
|
||||
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
|
||||
if not banner:
|
||||
raise NotFound(f"Banner '{banner_id}' is not found")
|
||||
|
||||
db.session.delete(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class LangContentPayload(BaseModel):
|
||||
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
|
||||
title: str = Field(...)
|
||||
subtitle: str | None = Field(default=None)
|
||||
body: str = Field(...)
|
||||
title_pic_url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class UpsertNotificationPayload(BaseModel):
|
||||
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
|
||||
contents: list[LangContentPayload] = Field(..., min_length=1)
|
||||
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
|
||||
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
|
||||
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
|
||||
status: str = Field(default="active", description="'active' | 'inactive'")
|
||||
|
||||
|
||||
class BatchAddNotificationAccountsPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
user_email: list[str] = Field(..., description="List of account email addresses")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
UpsertNotificationPayload.__name__,
|
||||
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
BatchAddNotificationAccountsPayload.__name__,
|
||||
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/admin/upsert_notification")
|
||||
class UpsertNotificationApi(Resource):
|
||||
@console_ns.doc("upsert_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Create or update an in-product notification. "
|
||||
"Supply notification_id to update an existing one; omit it to create a new one. "
|
||||
"Pass at least one language variant in contents (zh / en / jp)."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
|
||||
@console_ns.response(200, "Notification upserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||
result = BillingService.upsert_notification(
|
||||
contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents],
|
||||
frequency=payload.frequency,
|
||||
status=payload.status,
|
||||
notification_id=payload.notification_id,
|
||||
start_time=payload.start_time,
|
||||
end_time=payload.end_time,
|
||||
)
|
||||
return {"result": "success", "notification_id": result.get("notificationId")}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/batch_add_notification_accounts")
|
||||
class BatchAddNotificationAccountsApi(Resource):
|
||||
@console_ns.doc("batch_add_notification_accounts")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Register target accounts for a notification by email address. "
|
||||
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
|
||||
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
|
||||
"plus a 'notification_id' field. "
|
||||
"Emails that do not match any account are silently skipped."
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Accounts added successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
from models.account import Account
|
||||
|
||||
if "file" in request.files:
|
||||
notification_id = request.form.get("notification_id", "").strip()
|
||||
if not notification_id:
|
||||
raise BadRequest("notification_id is required.")
|
||||
emails = self._parse_emails_from_file()
|
||||
else:
|
||||
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
|
||||
notification_id = payload.notification_id
|
||||
emails = payload.user_email
|
||||
|
||||
if not emails:
|
||||
raise BadRequest("No valid email addresses provided.")
|
||||
|
||||
# Resolve emails → account IDs in chunks to avoid large IN-clause
|
||||
account_ids: list[str] = []
|
||||
chunk_size = 500
|
||||
for i in range(0, len(emails), chunk_size):
|
||||
chunk = emails[i : i + chunk_size]
|
||||
rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all()
|
||||
account_ids.extend(str(row.id) for row in rows)
|
||||
|
||||
if not account_ids:
|
||||
raise BadRequest("None of the provided emails matched an existing account.")
|
||||
|
||||
# Send to dify-saas in batches of 1000
|
||||
total_count = 0
|
||||
batch_size = 1000
|
||||
for i in range(0, len(account_ids), batch_size):
|
||||
batch = account_ids[i : i + batch_size]
|
||||
result = BillingService.batch_add_notification_accounts(
|
||||
notification_id=notification_id,
|
||||
account_ids=batch,
|
||||
)
|
||||
total_count += result.get("count", 0)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"emails_provided": len(emails),
|
||||
"accounts_matched": len(account_ids),
|
||||
"count": total_count,
|
||||
}, 200
|
||||
|
||||
@staticmethod
|
||||
def _parse_emails_from_file() -> list[str]:
|
||||
"""Parse email addresses from an uploaded CSV or TXT file."""
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
raise BadRequest("Uploaded file has no filename.")
|
||||
|
||||
filename_lower = file.filename.lower()
|
||||
if not filename_lower.endswith((".csv", ".txt")):
|
||||
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
|
||||
|
||||
try:
|
||||
content = file.read().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
file.seek(0)
|
||||
content = file.read().decode("gbk")
|
||||
except UnicodeDecodeError:
|
||||
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
|
||||
|
||||
emails: list[str] = []
|
||||
if filename_lower.endswith(".csv"):
|
||||
reader = csv.reader(io.StringIO(content))
|
||||
for row in reader:
|
||||
for cell in row:
|
||||
cell = cell.strip()
|
||||
if cell:
|
||||
emails.append(cell)
|
||||
else:
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
emails.append(line)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
unique_emails: list[str] = []
|
||||
for email in emails:
|
||||
if email.lower() not in seen:
|
||||
seen.add(email.lower())
|
||||
unique_emails.append(email)
|
||||
|
||||
return unique_emails
|
||||
|
||||
3
api/controllers/console/agent/__init__.py
Normal file
3
api/controllers/console/agent/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from . import composer, roster
|
||||
|
||||
__all__ = ["composer", "roster"]
|
||||
153
api/controllers/console/agent/composer.py
Normal file
153
api/controllers/console/agent/composer.py
Normal file
@ -0,0 +1,153 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode
|
||||
from services.agent.composer_service import AgentComposerService
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
from services.entities.agent_entities import ComposerSavePayload
|
||||
|
||||
register_schema_models(console_ns, ComposerSavePayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer")
|
||||
class WorkflowAgentComposerApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def put(self, app_model, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/validate")
|
||||
class WorkflowAgentComposerValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/candidates")
|
||||
class WorkflowAgentComposerCandidatesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, node_id: str):
|
||||
return AgentComposerService.get_workflow_candidates(app_id=app_model.id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/impact")
|
||||
class WorkflowAgentComposerImpactApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
|
||||
if not current_snapshot_id:
|
||||
return {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
|
||||
return AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/save-to-roster")
|
||||
class WorkflowAgentComposerSaveToRosterApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer")
|
||||
class AgentAppComposerApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id)
|
||||
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model()
|
||||
def put(self, app_model):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_agent_app_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer/validate")
|
||||
class AgentAppComposerValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def post(self, app_model):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer/candidates")
|
||||
class AgentAppComposerCandidatesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model):
|
||||
return AgentComposerService.get_agent_app_candidates(app_id=app_model.id)
|
||||
130
api/controllers/console/agent/roster.py
Normal file
130
api/controllers/console/agent/roster.py
Normal file
@ -0,0 +1,130 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery
|
||||
|
||||
|
||||
class AgentInviteOptionsQuery(RosterListQuery):
|
||||
app_id: str | None = Field(default=None, description="Workflow app id for in-current-workflow markers")
|
||||
|
||||
|
||||
class AgentIdPath(BaseModel):
|
||||
agent_id: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AgentInviteOptionsQuery,
|
||||
AgentIdPath,
|
||||
RosterAgentCreatePayload,
|
||||
RosterAgentUpdatePayload,
|
||||
RosterListQuery,
|
||||
)
|
||||
|
||||
|
||||
def _agent_roster_service() -> AgentRosterService:
|
||||
return AgentRosterService(db.session)
|
||||
|
||||
|
||||
@console_ns.route("/agents")
|
||||
class AgentRosterListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = RosterListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return _agent_roster_service().list_roster_agents(
|
||||
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentCreatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {})
|
||||
service = _agent_roster_service()
|
||||
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload)
|
||||
return service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id), 201
|
||||
|
||||
|
||||
@console_ns.route("/agents/invite-options")
|
||||
class AgentInviteOptionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return _agent_roster_service().list_invite_options(
|
||||
tenant_id=tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
app_id=query.app_id,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>")
|
||||
class AgentRosterDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id))
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, agent_id):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
return _agent_roster_service().update_roster_agent(
|
||||
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload
|
||||
)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, agent_id):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>/versions")
|
||||
class AgentRosterVersionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>/versions/<uuid:version_id>")
|
||||
class AgentRosterVersionDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id, version_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_agent_version_detail(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
version_id=str(version_id),
|
||||
)
|
||||
@ -11,6 +11,7 @@ from werkzeug.exceptions import Forbidden
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
@ -21,12 +22,6 @@ from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class ApiKeyItem(ResponseModel):
|
||||
id: str
|
||||
type: str
|
||||
@ -37,7 +32,7 @@ class ApiKeyItem(ResponseModel):
|
||||
@field_validator("last_used_at", "created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class ApiKeyList(ResponseModel):
|
||||
|
||||
@ -34,7 +34,7 @@ class AdvancedPromptTemplateList(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True))
|
||||
prompt_args: AdvancedPromptTemplateArgs = {
|
||||
"app_mode": args.app_mode,
|
||||
"model_mode": args.model_mode,
|
||||
|
||||
@ -2,6 +2,7 @@ from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
@ -10,8 +11,6 @@ from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AgentLogQuery(BaseModel):
|
||||
message_id: str = Field(..., description="Message UUID")
|
||||
@ -23,9 +22,7 @@ class AgentLogQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(console_ns, AgentLogQuery)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
|
||||
@ -44,6 +41,6 @@ class AgentLogApi(Resource):
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model):
|
||||
"""Get agent logs"""
|
||||
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import abort, make_response, request
|
||||
from flask_restx import Resource
|
||||
@ -33,8 +34,6 @@ from services.annotation_service import (
|
||||
UpsertAnnotationArgs,
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AnnotationReplyPayload(BaseModel):
|
||||
score_threshold: float = Field(..., description="Score threshold for annotation matching")
|
||||
@ -87,17 +86,6 @@ class AnnotationFilePayload(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
def reg(model: type[BaseModel]) -> None:
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AnnotationReplyPayload)
|
||||
reg(AnnotationSettingUpdatePayload)
|
||||
reg(AnnotationListQuery)
|
||||
reg(CreateAnnotationPayload)
|
||||
reg(UpdateAnnotationPayload)
|
||||
reg(AnnotationReplyStatusQuery)
|
||||
reg(AnnotationFilePayload)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
Annotation,
|
||||
@ -105,6 +93,13 @@ register_schema_models(
|
||||
AnnotationExportList,
|
||||
AnnotationHitHistory,
|
||||
AnnotationHitHistoryList,
|
||||
AnnotationReplyPayload,
|
||||
AnnotationSettingUpdatePayload,
|
||||
AnnotationListQuery,
|
||||
CreateAnnotationPayload,
|
||||
UpdateAnnotationPayload,
|
||||
AnnotationReplyStatusQuery,
|
||||
AnnotationFilePayload,
|
||||
)
|
||||
|
||||
|
||||
@ -121,8 +116,7 @@ class AnnotationReplyActionApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
app_id = str(app_id)
|
||||
def post(self, app_id: UUID, action: Literal["enable", "disable"]):
|
||||
args = AnnotationReplyPayload.model_validate(console_ns.payload)
|
||||
match action:
|
||||
case "enable":
|
||||
@ -131,9 +125,9 @@ class AnnotationReplyActionApi(Resource):
|
||||
"embedding_provider_name": args.embedding_provider_name,
|
||||
"embedding_model_name": args.embedding_model_name,
|
||||
}
|
||||
result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
|
||||
result = AppAnnotationService.enable_app_annotation(enable_args, str(app_id))
|
||||
case "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
result = AppAnnotationService.disable_app_annotation(str(app_id))
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -148,9 +142,8 @@ class AppAnnotationSettingDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
app_id = str(app_id)
|
||||
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
|
||||
def get(self, app_id: UUID):
|
||||
result = AppAnnotationService.get_app_annotation_setting_by_app_id(str(app_id))
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -166,14 +159,13 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id, annotation_setting_id):
|
||||
app_id = str(app_id)
|
||||
def post(self, app_id: UUID, annotation_setting_id):
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
|
||||
result = AppAnnotationService.update_app_annotation_setting(str(app_id), annotation_setting_id, setting_args)
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -189,7 +181,7 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def get(self, app_id, job_id, action):
|
||||
def get(self, app_id: UUID, job_id, action):
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
@ -217,14 +209,13 @@ class AnnotationApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
def get(self, app_id: UUID):
|
||||
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
keyword = args.keyword
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(str(app_id), page, limit, keyword)
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response = AnnotationList(
|
||||
data=annotation_models,
|
||||
@ -246,8 +237,7 @@ class AnnotationApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
def post(self, app_id: UUID):
|
||||
args = CreateAnnotationPayload.model_validate(console_ns.payload)
|
||||
upsert_args: UpsertAnnotationArgs = {}
|
||||
if args.answer is not None:
|
||||
@ -258,15 +248,14 @@ class AnnotationApi(Resource):
|
||||
upsert_args["message_id"] = args.message_id
|
||||
if args.question is not None:
|
||||
upsert_args["question"] = args.question
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, str(app_id))
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_id):
|
||||
app_id = str(app_id)
|
||||
def delete(self, app_id: UUID):
|
||||
|
||||
# Use request.args.getlist to get annotation_ids array directly
|
||||
annotation_ids = request.args.getlist("annotation_id")
|
||||
@ -280,11 +269,11 @@ class AnnotationApi(Resource):
|
||||
"message": "annotation_ids are required if the parameter is provided.",
|
||||
}, 400
|
||||
|
||||
result = AppAnnotationService.delete_app_annotations_in_batch(app_id, annotation_ids)
|
||||
result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
|
||||
return result, 204
|
||||
# If no annotation_ids are provided, handle clearing all annotations
|
||||
else:
|
||||
AppAnnotationService.clear_all_annotations(app_id)
|
||||
AppAnnotationService.clear_all_annotations(str(app_id))
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@ -303,9 +292,8 @@ class AnnotationExportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
def get(self, app_id: UUID):
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(str(app_id))
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")
|
||||
|
||||
@ -331,26 +319,22 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def post(self, app_id, annotation_id):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
def post(self, app_id: UUID, annotation_id: UUID):
|
||||
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
|
||||
update_args: UpdateAnnotationArgs = {}
|
||||
if args.answer is not None:
|
||||
update_args["answer"] = args.answer
|
||||
if args.question is not None:
|
||||
update_args["question"] = args.question
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, str(app_id), str(annotation_id))
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_id, annotation_id):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
||||
def delete(self, app_id: UUID, annotation_id: UUID):
|
||||
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@ -371,11 +355,9 @@ class AnnotationBatchImportApi(Resource):
|
||||
@annotation_import_rate_limit
|
||||
@annotation_import_concurrency_limit
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
def post(self, app_id: UUID):
|
||||
from configs import dify_config
|
||||
|
||||
app_id = str(app_id)
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
@ -391,9 +373,9 @@ class AnnotationBatchImportApi(Resource):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
# Check file size before processing
|
||||
file.seek(0, 2) # Seek to end of file
|
||||
file_size = file.tell()
|
||||
file.seek(0) # Reset to beginning
|
||||
file.stream.seek(0, 2) # Seek to end of file
|
||||
file_size = file.stream.tell()
|
||||
file.stream.seek(0) # Reset to beginning
|
||||
|
||||
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
|
||||
if file_size > max_size_bytes:
|
||||
@ -406,7 +388,7 @@ class AnnotationBatchImportApi(Resource):
|
||||
if file_size == 0:
|
||||
raise ValueError("The uploaded file is empty")
|
||||
|
||||
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
||||
return AppAnnotationService.batch_import_app_annotations(str(app_id), file)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
|
||||
@ -421,8 +403,7 @@ class AnnotationBatchImportStatusApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def get(self, app_id, job_id):
|
||||
job_id = str(job_id)
|
||||
def get(self, app_id: UUID, job_id: UUID):
|
||||
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is None:
|
||||
@ -456,13 +437,11 @@ class AnnotationHitHistoryListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id, annotation_id):
|
||||
def get(self, app_id: UUID, annotation_id: UUID):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
|
||||
app_id, annotation_id, page, limit
|
||||
str(app_id), str(annotation_id), page, limit
|
||||
)
|
||||
history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
|
||||
annotation_hit_history_list, from_attributes=True
|
||||
|
||||
@ -12,8 +12,9 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.datastructures import MultiDict
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
|
||||
from controllers.common.helpers import FileInfo
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.workspace.models import LoadBalancingPayload
|
||||
@ -25,6 +26,7 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@ -32,12 +34,12 @@ from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.helper import build_icon_url
|
||||
from libs.helper import build_icon_url, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppService
|
||||
from services.app_service import AppListParams, AppService, CreateAppParams
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportMode, ImportStatus
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
@ -176,12 +178,6 @@ class AppTracePayload(BaseModel):
|
||||
type JSONValue = Any
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class Tag(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -198,7 +194,7 @@ class WorkflowPartial(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class ModelConfigPartial(ResponseModel):
|
||||
@ -212,7 +208,7 @@ class ModelConfigPartial(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class ModelConfig(ResponseModel):
|
||||
@ -273,7 +269,7 @@ class ModelConfig(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class Site(ResponseModel):
|
||||
@ -316,7 +312,7 @@ class Site(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DeletedTool(ResponseModel):
|
||||
@ -359,7 +355,7 @@ class AppPartial(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class AppDetail(ResponseModel):
|
||||
@ -389,7 +385,7 @@ class AppDetail(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class AppDetailWithSite(AppDetail):
|
||||
@ -418,6 +414,7 @@ class AppExportResponse(ResponseModel):
|
||||
|
||||
|
||||
register_enum_models(console_ns, RetrievalMethod, WorkflowExecutionStatus, DatasetPermissionEnum)
|
||||
register_response_schema_models(console_ns, RedirectUrlResponse, SimpleResultResponse)
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
@ -476,11 +473,18 @@ class AppListApi(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
|
||||
args_dict = args.model_dump()
|
||||
params = AppListParams(
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
mode=args.mode,
|
||||
name=args.name,
|
||||
tag_ids=args.tag_ids,
|
||||
is_created_by_me=args.is_created_by_me,
|
||||
)
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
|
||||
if not app_pagination:
|
||||
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json"), 200
|
||||
@ -544,9 +548,17 @@ class AppListApi(Resource):
|
||||
"""Create app"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = CreateAppPayload.model_validate(console_ns.payload)
|
||||
params = CreateAppParams(
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
mode=args.mode,
|
||||
icon_type=args.icon_type,
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
|
||||
app = app_service.create_app(current_tenant_id, params, current_user)
|
||||
app_detail = AppDetail.model_validate(app, from_attributes=True)
|
||||
return app_detail.model_dump(mode="json"), 201
|
||||
|
||||
@ -700,7 +712,7 @@ class AppExportApi(Resource):
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
"""Export app"""
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
payload = AppExportResponse(
|
||||
data=AppDslService.export_dsl(
|
||||
@ -714,6 +726,7 @@ class AppExportApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
|
||||
class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[RedirectUrlResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -839,9 +852,11 @@ class AppTraceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
"""Get app trace"""
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id)
|
||||
with session_factory.create_session() as session:
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
|
||||
|
||||
return app_trace_config
|
||||
|
||||
@ -849,18 +864,23 @@ class AppTraceApi(Resource):
|
||||
@console_ns.doc(description="Update app tracing configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppTracePayload.__name__])
|
||||
@console_ns.response(200, "Trace configuration updated successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Trace configuration updated successfully",
|
||||
console_ns.models[SimpleResultResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
# add app trace
|
||||
args = AppTracePayload.model_validate(console_ns.payload)
|
||||
|
||||
OpsTraceManager.update_app_tracing_config(
|
||||
app_id=app_id,
|
||||
app_id=app_model.id,
|
||||
enabled=args.enabled,
|
||||
tracing_provider=args.tracing_provider,
|
||||
)
|
||||
|
||||
@ -2,7 +2,7 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -33,6 +33,7 @@ class AppImportPayload(BaseModel):
|
||||
app_id: str | None = Field(None)
|
||||
|
||||
|
||||
register_enum_models(console_ns, ImportStatus)
|
||||
register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult)
|
||||
|
||||
|
||||
|
||||
@ -173,7 +173,7 @@ class TextModesApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
try:
|
||||
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
response = AudioService.transcript_tts_voices(
|
||||
tenant_id=app_model.tenant_id,
|
||||
|
||||
@ -7,6 +7,8 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
@ -37,7 +39,6 @@ from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseMessagePayload(BaseModel):
|
||||
@ -65,13 +66,8 @@ class ChatMessagePayload(BaseMessagePayload):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionMessagePayload.__name__,
|
||||
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
# define completion message api for user
|
||||
@ -130,7 +126,7 @@ class CompletionMessageStopApi(Resource):
|
||||
@console_ns.doc("stop_completion_message")
|
||||
@console_ns.doc(description="Stop a running completion message generation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@console_ns.response(200, "Task stopped successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -211,7 +207,7 @@ class ChatMessageStopApi(Resource):
|
||||
@console_ns.doc("stop_chat_message")
|
||||
@console_ns.doc(description="Stop a running chat message generation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@console_ns.response(200, "Task stopped successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -39,8 +39,6 @@ from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseConversationQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword")
|
||||
@ -70,15 +68,6 @@ class ChatConversationQuery(BaseConversationQuery):
|
||||
)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionConversationQuery.__name__,
|
||||
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatConversationQuery.__name__,
|
||||
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
CompletionConversationQuery,
|
||||
@ -89,6 +78,8 @@ register_schema_models(
|
||||
ConversationWithSummaryPaginationResponse,
|
||||
ConversationDetailResponse,
|
||||
ResultResponse,
|
||||
CompletionConversationQuery,
|
||||
ChatConversationQuery,
|
||||
)
|
||||
|
||||
|
||||
@ -107,7 +98,7 @@ class CompletionConversationApi(Resource):
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
query = sa.select(Conversation).where(
|
||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||
@ -221,7 +212,7 @@ class ChatConversationApi(Resource):
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
subquery = (
|
||||
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
|
||||
|
||||
@ -16,6 +16,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from extensions.ext_database import db
|
||||
from fields._value_type_serializer import serialize_value_type
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
@ -25,12 +26,6 @@ class ConversationVariablesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID to filter variables")
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class ConversationVariableResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -65,7 +60,7 @@ class ConversationVariableResponse(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class PaginatedConversationVariableResponse(ResponseModel):
|
||||
@ -100,7 +95,7 @@ class ConversationVariablesApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def get(self, app_model):
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
stmt = (
|
||||
select(ConversationVariable)
|
||||
|
||||
@ -3,6 +3,7 @@ from collections.abc import Sequence
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
@ -19,13 +20,12 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class InstructionGeneratePayload(BaseModel):
|
||||
flow_id: str = Field(..., description="Workflow/Flow ID")
|
||||
@ -41,16 +41,16 @@ class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(RuleGeneratePayload)
|
||||
reg(RuleCodeGeneratePayload)
|
||||
reg(RuleStructuredOutputPayload)
|
||||
reg(InstructionGeneratePayload)
|
||||
reg(InstructionTemplatePayload)
|
||||
reg(ModelConfig)
|
||||
register_enum_models(console_ns, LLMMode)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
RuleGeneratePayload,
|
||||
RuleCodeGeneratePayload,
|
||||
RuleStructuredOutputPayload,
|
||||
InstructionGeneratePayload,
|
||||
InstructionTemplatePayload,
|
||||
ModelConfig,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
|
||||
@ -13,6 +13,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import AppMCPServer
|
||||
@ -30,12 +31,6 @@ class MCPServerUpdatePayload(BaseModel):
|
||||
status: str | None = Field(default=None, description="Server status")
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class AppMCPServerResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -59,7 +54,7 @@ class AppMCPServerResponse(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse)
|
||||
|
||||
@ -9,7 +9,8 @@ from sqlalchemy import exists, func, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
@ -37,10 +38,9 @@ from fields.conversation_fields import (
|
||||
JSONValue,
|
||||
MessageFile,
|
||||
format_files_contained,
|
||||
to_timestamp,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from libs.helper import to_timestamp, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
@ -144,9 +144,7 @@ class MessageDetailResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class MessageInfiniteScrollPaginationResponse(ResponseModel):
|
||||
@ -165,6 +163,7 @@ register_schema_models(
|
||||
MessageDetailResponse,
|
||||
MessageInfiniteScrollPaginationResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
|
||||
@ -250,7 +249,7 @@ class MessageFeedbackApi(Resource):
|
||||
@console_ns.doc(description="Create or update message feedback (like/dislike)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback updated successfully")
|
||||
@console_ns.response(200, "Feedback updated successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(404, "Message not found")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
|
||||
@ -5,14 +5,15 @@ from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.ops_service import OpsService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class TraceProviderQuery(BaseModel):
|
||||
tracing_provider: str = Field(..., description="Tracing provider name")
|
||||
@ -23,13 +24,7 @@ class TraceConfigPayload(BaseModel):
|
||||
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TraceProviderQuery.__name__,
|
||||
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(console_ns, TraceProviderQuery, TraceConfigPayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trace-config")
|
||||
@ -49,11 +44,14 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
@get_app_model
|
||||
def get(self, app_model: App):
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
trace_config = OpsService.get_tracing_app_config(
|
||||
app_id=app_model.id, tracing_provider=args.tracing_provider
|
||||
)
|
||||
if not trace_config:
|
||||
return {"has_not_configured": True}
|
||||
return trace_config
|
||||
@ -71,13 +69,14 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_id):
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
"""Create a new trace app configuration"""
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=app_model.id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigIsExist()
|
||||
@ -96,13 +95,14 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, app_id):
|
||||
@get_app_model
|
||||
def patch(self, app_model: App):
|
||||
"""Update an existing trace app configuration"""
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
result = OpsService.update_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=app_model.id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
@ -119,12 +119,13 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id):
|
||||
@get_app_model
|
||||
def delete(self, app_model: App):
|
||||
"""Delete an existing trace app configuration"""
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_model.id, tracing_provider=args.tracing_provider)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -5,6 +5,7 @@ from flask import abort, jsonify, request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
@ -15,8 +16,6 @@ from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class StatisticTimeRangeQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||
@ -30,10 +29,7 @@ class StatisticTimeRangeQuery(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
StatisticTimeRangeQuery.__name__,
|
||||
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
register_schema_models(console_ns, StatisticTimeRangeQuery)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
@ -54,7 +50,7 @@ class DailyMessageStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -111,7 +107,7 @@ class DailyConversationStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -167,7 +163,7 @@ class DailyTerminalsStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -224,7 +220,7 @@ class DailyTokenCostStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -284,7 +280,7 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -360,7 +356,7 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -426,7 +422,7 @@ class AverageResponseTimeStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -482,7 +478,7 @@ class TokensPerSecondStatistic(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
||||
@ -1,19 +1,25 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import AliasChoices, BaseModel, Field, ValidationError, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||
from controllers.common.fields import NewAppResponse, SimpleResultResponse
|
||||
from controllers.common.schema import (
|
||||
register_response_schema_model,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.workflow_run import workflow_run_node_execution_model
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
@ -22,6 +28,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.helper import encrypter
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.plugin.impl.exc import PluginInvokeError
|
||||
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
|
||||
@ -34,17 +41,18 @@ from core.trigger.debug.event_selectors import (
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.online_user_fields import online_user_list_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import SimpleAccount
|
||||
from fields.workflow_run_fields import WorkflowRunNodeExecutionResponse
|
||||
from graphon.enums import NodeType
|
||||
from graphon.file import File
|
||||
from graphon.file import helpers as file_helpers
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.variables import SecretVariable, SegmentType, VariableBase
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import TimestampField, dump_response, to_timestamp, uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
@ -56,48 +64,22 @@ from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
LISTENING_RETRY_IN = 2000
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
|
||||
MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000
|
||||
WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50
|
||||
ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model("SimpleAccount", simple_account_fields)
|
||||
|
||||
from fields.workflow_fields import pipeline_variable_fields, serialize_value_type
|
||||
|
||||
conversation_variable_model = console_ns.model(
|
||||
"ConversationVariable",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"value_type": fields.String(attribute=serialize_value_type),
|
||||
"value": fields.Raw,
|
||||
"description": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
pipeline_variable_model = console_ns.model("PipelineVariable", pipeline_variable_fields)
|
||||
|
||||
# Workflow model with nested dependencies
|
||||
workflow_fields_copy = workflow_fields.copy()
|
||||
workflow_fields_copy["created_by"] = fields.Nested(simple_account_model, attribute="created_by_account")
|
||||
workflow_fields_copy["updated_by"] = fields.Nested(
|
||||
simple_account_model, attribute="updated_by_account", allow_null=True
|
||||
)
|
||||
workflow_fields_copy["conversation_variables"] = fields.List(fields.Nested(conversation_variable_model))
|
||||
workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipeline_variable_model))
|
||||
workflow_model = console_ns.model("Workflow", workflow_fields_copy)
|
||||
|
||||
# Workflow pagination model
|
||||
workflow_pagination_fields_copy = workflow_pagination_fields.copy()
|
||||
workflow_pagination_fields_copy["items"] = fields.List(fields.Nested(workflow_model), attribute="items")
|
||||
workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields_copy)
|
||||
class EnvironmentVariableResponseDict(TypedDict):
|
||||
value_type: str
|
||||
id: NotRequired[str]
|
||||
name: NotRequired[str]
|
||||
value: NotRequired[Any]
|
||||
description: NotRequired[str | None]
|
||||
|
||||
|
||||
class SyncDraftWorkflowPayload(BaseModel):
|
||||
@ -168,6 +150,110 @@ class WorkflowOnlineUsersPayload(BaseModel):
|
||||
return list(dict.fromkeys(app_id.strip() for app_id in app_ids if app_id.strip()))
|
||||
|
||||
|
||||
class WorkflowConversationVariableResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
value_type: str
|
||||
value: Any = Field(json_schema_extra={"type": "object"})
|
||||
description: str
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def _serialize_value_type(cls, value: Any) -> str:
|
||||
if hasattr(value, "exposed_type"):
|
||||
return str(value.exposed_type())
|
||||
return str(value)
|
||||
|
||||
|
||||
class PipelineVariableResponse(ResponseModel):
|
||||
label: str
|
||||
variable: str
|
||||
type: str
|
||||
belong_to_node_id: str
|
||||
max_length: int | None = None
|
||||
required: bool
|
||||
unit: str | None = None
|
||||
default_value: Any = Field(default=None, json_schema_extra={"type": "object"})
|
||||
options: list[str] | None = None
|
||||
placeholder: str | None = None
|
||||
tooltips: str | None = None
|
||||
allowed_file_types: list[str] | None = None
|
||||
allowed_file_extensions: list[str] | None = Field(
|
||||
default=None, validation_alias=AliasChoices("allowed_file_extensions", "allow_file_extension")
|
||||
)
|
||||
allowed_file_upload_methods: list[str] | None = Field(
|
||||
default=None, validation_alias=AliasChoices("allowed_file_upload_methods", "allow_file_upload_methods")
|
||||
)
|
||||
|
||||
|
||||
class WorkflowEnvironmentVariableResponse(ResponseModel):
|
||||
value_type: str
|
||||
id: str
|
||||
name: str
|
||||
value: Any = Field(json_schema_extra={"type": "object"})
|
||||
description: str
|
||||
|
||||
|
||||
class WorkflowResponse(ResponseModel):
|
||||
id: str
|
||||
graph: dict[str, Any] = Field(validation_alias=AliasChoices("graph_dict", "graph"))
|
||||
features: dict[str, Any] = Field(validation_alias=AliasChoices("features_dict", "features"))
|
||||
hash: str = Field(validation_alias=AliasChoices("unique_hash", "hash"))
|
||||
version: str
|
||||
marked_name: str
|
||||
marked_comment: str
|
||||
created_by: SimpleAccount | None = Field(
|
||||
default=None, validation_alias=AliasChoices("created_by_account", "created_by")
|
||||
)
|
||||
created_at: int
|
||||
updated_by: SimpleAccount | None = Field(
|
||||
default=None, validation_alias=AliasChoices("updated_by_account", "updated_by")
|
||||
)
|
||||
updated_at: int
|
||||
tool_published: bool
|
||||
environment_variables: list[WorkflowEnvironmentVariableResponse]
|
||||
conversation_variables: list[WorkflowConversationVariableResponse]
|
||||
rag_pipeline_variables: list[PipelineVariableResponse]
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int:
|
||||
timestamp = to_timestamp(value)
|
||||
if timestamp is None:
|
||||
raise ValueError("timestamp is required")
|
||||
return timestamp
|
||||
|
||||
@field_validator("environment_variables", mode="before")
|
||||
@classmethod
|
||||
def _serialize_environment_variables(cls, value: Any) -> list[Any]:
|
||||
if value is None:
|
||||
return []
|
||||
|
||||
return [_serialize_environment_variable(item) for item in value]
|
||||
|
||||
|
||||
class WorkflowPaginationResponse(ResponseModel):
|
||||
items: list[WorkflowResponse]
|
||||
page: int
|
||||
limit: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class WorkflowOnlineUser(ResponseModel):
|
||||
user_id: str
|
||||
username: str
|
||||
avatar: str | None = None
|
||||
|
||||
|
||||
class WorkflowOnlineUsersByApp(ResponseModel):
|
||||
app_id: str
|
||||
users: list[WorkflowOnlineUser]
|
||||
|
||||
|
||||
class WorkflowOnlineUsersResponse(ResponseModel):
|
||||
data: list[WorkflowOnlineUsersByApp]
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
node_id: str
|
||||
|
||||
@ -176,25 +262,38 @@ class DraftWorkflowTriggerRunAllPayload(BaseModel):
|
||||
node_ids: list[str]
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(SyncDraftWorkflowPayload)
|
||||
reg(AdvancedChatWorkflowRunPayload)
|
||||
reg(IterationNodeRunPayload)
|
||||
reg(LoopNodeRunPayload)
|
||||
reg(DraftWorkflowRunPayload)
|
||||
reg(DraftWorkflowNodeRunPayload)
|
||||
reg(PublishWorkflowPayload)
|
||||
reg(DefaultBlockConfigQuery)
|
||||
reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(WorkflowFeaturesPayload)
|
||||
reg(WorkflowOnlineUsersPayload)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SyncDraftWorkflowPayload,
|
||||
AdvancedChatWorkflowRunPayload,
|
||||
IterationNodeRunPayload,
|
||||
LoopNodeRunPayload,
|
||||
DraftWorkflowRunPayload,
|
||||
DraftWorkflowNodeRunPayload,
|
||||
PublishWorkflowPayload,
|
||||
DefaultBlockConfigQuery,
|
||||
ConvertToWorkflowPayload,
|
||||
WorkflowListQuery,
|
||||
WorkflowUpdatePayload,
|
||||
WorkflowFeaturesPayload,
|
||||
WorkflowOnlineUsersPayload,
|
||||
DraftWorkflowTriggerRunPayload,
|
||||
DraftWorkflowTriggerRunAllPayload,
|
||||
)
|
||||
register_response_schema_model(console_ns, WorkflowRunNodeExecutionResponse)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
WorkflowConversationVariableResponse,
|
||||
PipelineVariableResponse,
|
||||
WorkflowEnvironmentVariableResponse,
|
||||
WorkflowResponse,
|
||||
WorkflowPaginationResponse,
|
||||
WorkflowOnlineUser,
|
||||
WorkflowOnlineUsersByApp,
|
||||
WorkflowOnlineUsersResponse,
|
||||
NewAppResponse,
|
||||
SimpleResultResponse,
|
||||
)
|
||||
|
||||
|
||||
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
||||
@ -216,18 +315,56 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence
|
||||
return file_objs
|
||||
|
||||
|
||||
def _serialize_environment_variable(value: Any) -> EnvironmentVariableResponseDict | Any:
|
||||
match value:
|
||||
case SecretVariable():
|
||||
return {
|
||||
"id": value.id,
|
||||
"name": value.name,
|
||||
"value": encrypter.full_mask_token(),
|
||||
"value_type": value.value_type.value,
|
||||
"description": value.description,
|
||||
}
|
||||
|
||||
case VariableBase():
|
||||
return {
|
||||
"id": value.id,
|
||||
"name": value.name,
|
||||
"value": value.value,
|
||||
"value_type": str(value.value_type.exposed_type()),
|
||||
"description": value.description,
|
||||
}
|
||||
|
||||
case dict():
|
||||
value_type_str = value.get("value_type")
|
||||
if not isinstance(value_type_str, str):
|
||||
raise TypeError(
|
||||
f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}"
|
||||
)
|
||||
value_type = SegmentType(value_type_str).exposed_type()
|
||||
if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES:
|
||||
raise ValueError(f"Unsupported environment variable value type: {value_type}")
|
||||
return value
|
||||
|
||||
case _:
|
||||
return value
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft")
|
||||
class DraftWorkflowApi(Resource):
|
||||
@console_ns.doc("get_draft_workflow")
|
||||
@console_ns.doc(description="Get draft workflow for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Draft workflow retrieved successfully",
|
||||
console_ns.models[WorkflowResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@ -240,8 +377,8 @@ class DraftWorkflowApi(Resource):
|
||||
if not workflow:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# return workflow, if not found, return None (initiate graph by frontend)
|
||||
return workflow
|
||||
# return workflow, if not found, return 404
|
||||
return dump_response(WorkflowResponse, workflow)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -540,9 +677,12 @@ class HumanInputDeliveryTestPayload(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
reg(HumanInputFormPreviewPayload)
|
||||
reg(HumanInputFormSubmitPayload)
|
||||
reg(HumanInputDeliveryTestPayload)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
HumanInputFormPreviewPayload,
|
||||
HumanInputFormSubmitPayload,
|
||||
HumanInputDeliveryTestPayload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
|
||||
@ -732,7 +872,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
@console_ns.doc("stop_workflow_task")
|
||||
@console_ns.doc(description="Stop running workflow task")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@console_ns.response(200, "Task stopped successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(404, "Task not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -760,14 +900,17 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
@console_ns.doc(description="Run draft workflow node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Node run started successfully",
|
||||
console_ns.models[WorkflowRunNodeExecutionResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@ -799,7 +942,9 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
files=files,
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
return WorkflowRunNodeExecutionResponse.model_validate(
|
||||
workflow_node_execution, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
|
||||
@ -807,13 +952,15 @@ class PublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("get_published_workflow")
|
||||
@console_ns.doc(description="Get published workflow for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
|
||||
@console_ns.response(404, "Published workflow not found")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Published workflow retrieved successfully, or null if not found",
|
||||
console_ns.models[WorkflowResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@ -824,7 +971,10 @@ class PublishedWorkflowApi(Resource):
|
||||
workflow = workflow_service.get_published_workflow(app_model=app_model)
|
||||
|
||||
# return workflow, if not found, return None
|
||||
return workflow
|
||||
if workflow is None:
|
||||
return None
|
||||
|
||||
return dump_response(WorkflowResponse, workflow)
|
||||
|
||||
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
|
||||
@setup_required
|
||||
@ -902,7 +1052,7 @@ class DefaultBlockConfigApi(Resource):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
filters = None
|
||||
if args.q:
|
||||
@ -922,7 +1072,11 @@ class ConvertToWorkflowApi(Resource):
|
||||
@console_ns.doc("convert_to_workflow")
|
||||
@console_ns.doc(description="Convert application to workflow mode")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Application converted to workflow successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Application converted to workflow successfully",
|
||||
console_ns.models[NewAppResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Application cannot be converted")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -959,7 +1113,11 @@ class WorkflowFeaturesApi(Resource):
|
||||
@console_ns.doc("update_workflow_features")
|
||||
@console_ns.doc(description="Update draft workflow features")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Workflow features updated successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow features updated successfully",
|
||||
console_ns.models[SimpleResultResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -983,7 +1141,11 @@ class PublishedAllWorkflowApi(Resource):
|
||||
@console_ns.doc("get_all_published_workflows")
|
||||
@console_ns.doc(description="Get all published workflows for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Published workflows retrieved successfully", workflow_pagination_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Published workflows retrieved successfully",
|
||||
console_ns.models[WorkflowPaginationResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -995,7 +1157,7 @@ class PublishedAllWorkflowApi(Resource):
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
user_id = args.user_id
|
||||
@ -1015,14 +1177,14 @@ class PublishedAllWorkflowApi(Resource):
|
||||
user_id=user_id,
|
||||
named_only=named_only,
|
||||
)
|
||||
serialized_workflows = marshal(workflows, workflow_fields_copy)
|
||||
|
||||
return {
|
||||
"items": serialized_workflows,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
}
|
||||
return WorkflowPaginationResponse.model_validate(
|
||||
{
|
||||
"items": workflows,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
}
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>/restore")
|
||||
@ -1068,14 +1230,13 @@ class WorkflowByIdApi(Resource):
|
||||
@console_ns.doc(description="Update workflow by ID")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Workflow updated successfully", workflow_model)
|
||||
@console_ns.response(200, "Workflow updated successfully", console_ns.models[WorkflowResponse.__name__])
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def patch(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
@ -1109,7 +1270,7 @@ class WorkflowByIdApi(Resource):
|
||||
if not workflow:
|
||||
raise NotFound("Workflow not found")
|
||||
|
||||
return workflow
|
||||
return dump_response(WorkflowResponse, workflow)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -1143,14 +1304,17 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
@console_ns.doc("get_draft_workflow_node_last_run")
|
||||
@console_ns.doc(description="Get last run result for draft workflow node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Node last run retrieved successfully",
|
||||
console_ns.models[WorkflowRunNodeExecutionResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Node last run not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
srv = WorkflowService()
|
||||
workflow = srv.get_draft_workflow(app_model)
|
||||
@ -1163,7 +1327,7 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
)
|
||||
if node_exec is None:
|
||||
raise NotFound("last run not found")
|
||||
return node_exec
|
||||
return WorkflowRunNodeExecutionResponse.model_validate(node_exec, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run")
|
||||
@ -1391,12 +1555,16 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
@console_ns.route("/apps/workflows/online-users")
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow online users retrieved successfully",
|
||||
console_ns.models[WorkflowOnlineUsersResponse.__name__],
|
||||
)
|
||||
@console_ns.doc("get_workflow_online_users")
|
||||
@console_ns.doc(description="Get workflow online users")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def post(self):
|
||||
args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -1439,10 +1607,18 @@ class WorkflowOnlineUsersApi(Resource):
|
||||
if not isinstance(user_info, dict):
|
||||
continue
|
||||
|
||||
user_id = user_info.get("user_id")
|
||||
username = user_info.get("username")
|
||||
if not isinstance(user_id, str) or not isinstance(username, str):
|
||||
continue
|
||||
|
||||
avatar = user_info.get("avatar")
|
||||
if avatar is not None and not isinstance(avatar, str):
|
||||
avatar = None
|
||||
|
||||
if isinstance(avatar, str) and avatar and not avatar.startswith(("http://", "https://")):
|
||||
try:
|
||||
user_info["avatar"] = file_helpers.get_signed_file_url(avatar)
|
||||
avatar = file_helpers.get_signed_file_url(avatar)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to sign workflow online user avatar; using original value. "
|
||||
@ -1452,7 +1628,7 @@ class WorkflowOnlineUsersApi(Resource):
|
||||
exc,
|
||||
)
|
||||
|
||||
users.append(user_info)
|
||||
users.append({"user_id": user_id, "username": username, "avatar": avatar})
|
||||
results.append({"app_id": app_id, "users": users})
|
||||
|
||||
return {"data": results}
|
||||
return WorkflowOnlineUsersResponse.model_validate({"data": results}).model_dump(mode="json")
|
||||
|
||||
@ -16,6 +16,7 @@ from fields.base import ResponseModel
|
||||
from fields.end_user_fields import SimpleEndUser
|
||||
from fields.member_fields import SimpleAccount
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
@ -82,9 +83,7 @@ class WorkflowRunForLogResponse(ResponseModel):
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowRunForArchivedLogResponse(ResponseModel):
|
||||
@ -117,9 +116,7 @@ class WorkflowAppLogPartialResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowArchivedLogPartialResponse(ResponseModel):
|
||||
@ -133,9 +130,7 @@ class WorkflowArchivedLogPartialResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowAppLogPaginationResponse(ResponseModel):
|
||||
@ -185,7 +180,7 @@ class WorkflowAppLogApi(Resource):
|
||||
"""
|
||||
Get workflow app logs
|
||||
"""
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
@ -228,7 +223,7 @@ class WorkflowArchivedLogApi(Resource):
|
||||
"""
|
||||
Get workflow archived logs
|
||||
"""
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
|
||||
@ -1,29 +1,22 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, computed_field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import AccountWithRole
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.helper import build_avatar_url, dump_response, to_timestamp
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowCommentCreatePayload(BaseModel):
|
||||
@ -52,24 +45,159 @@ class WorkflowCommentMentionUsersPayload(BaseModel):
|
||||
users: list[AccountWithRole]
|
||||
|
||||
|
||||
for model in (
|
||||
class WorkflowCommentAccount(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
avatar: str | None = Field(default=None, exclude=True)
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
|
||||
@property
|
||||
def avatar_url(self) -> str | None:
|
||||
return build_avatar_url(self.avatar)
|
||||
|
||||
|
||||
class WorkflowCommentReply(ResponseModel):
|
||||
id: str
|
||||
content: str
|
||||
created_by: str
|
||||
created_by_account: WorkflowCommentAccount | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentMention(ResponseModel):
|
||||
mentioned_user_id: str
|
||||
mentioned_user_account: WorkflowCommentAccount | None = None
|
||||
reply_id: str | None = None
|
||||
|
||||
|
||||
class WorkflowCommentBasic(ResponseModel):
|
||||
id: str
|
||||
position_x: float
|
||||
position_y: float
|
||||
content: str
|
||||
created_by: str
|
||||
created_by_account: WorkflowCommentAccount | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
resolved: bool
|
||||
resolved_at: int | None = None
|
||||
resolved_by: str | None = None
|
||||
resolved_by_account: WorkflowCommentAccount | None = None
|
||||
reply_count: int
|
||||
mention_count: int
|
||||
participants: list[WorkflowCommentAccount]
|
||||
|
||||
@field_validator("created_at", "updated_at", "resolved_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentBasicList(ResponseModel):
|
||||
data: list[WorkflowCommentBasic]
|
||||
|
||||
|
||||
class WorkflowCommentDetail(ResponseModel):
|
||||
id: str
|
||||
position_x: float
|
||||
position_y: float
|
||||
content: str
|
||||
created_by: str
|
||||
created_by_account: WorkflowCommentAccount | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
resolved: bool
|
||||
resolved_at: int | None = None
|
||||
resolved_by: str | None = None
|
||||
resolved_by_account: WorkflowCommentAccount | None = None
|
||||
replies: list[WorkflowCommentReply]
|
||||
mentions: list[WorkflowCommentMention]
|
||||
|
||||
@field_validator("created_at", "updated_at", "resolved_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentCreate(ResponseModel):
|
||||
id: str
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentUpdate(ResponseModel):
|
||||
id: str
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentResolve(ResponseModel):
|
||||
id: str
|
||||
resolved: bool
|
||||
resolved_at: int | None = None
|
||||
resolved_by: str | None = None
|
||||
|
||||
@field_validator("resolved_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentReplyCreate(ResponseModel):
|
||||
id: str
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentReplyUpdate(ResponseModel):
|
||||
id: str
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AccountWithRole,
|
||||
WorkflowCommentMentionUsersPayload,
|
||||
WorkflowCommentCreatePayload,
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyPayload,
|
||||
):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
register_schema_models(console_ns, AccountWithRole, WorkflowCommentMentionUsersPayload)
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
WorkflowCommentAccount,
|
||||
WorkflowCommentReply,
|
||||
WorkflowCommentMention,
|
||||
WorkflowCommentBasic,
|
||||
WorkflowCommentBasicList,
|
||||
WorkflowCommentDetail,
|
||||
WorkflowCommentCreate,
|
||||
WorkflowCommentUpdate,
|
||||
WorkflowCommentResolve,
|
||||
WorkflowCommentReplyCreate,
|
||||
WorkflowCommentReplyUpdate,
|
||||
)
|
||||
|
||||
|
||||
@ -80,28 +208,26 @@ class WorkflowCommentListApi(Resource):
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@console_ns.response(200, "Comments retrieved successfully", console_ns.models[WorkflowCommentBasicList.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
return WorkflowCommentBasicList.model_validate({"data": comments}).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@console_ns.response(201, "Comment created successfully", console_ns.models[WorkflowCommentCreate.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
@ -117,7 +243,7 @@ class WorkflowCommentListApi(Resource):
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
return dump_response(WorkflowCommentCreate, result), 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
@ -127,30 +253,28 @@ class WorkflowCommentDetailApi(Resource):
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@console_ns.response(200, "Comment retrieved successfully", console_ns.models[WorkflowCommentDetail.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
return dump_response(WorkflowCommentDetail, comment)
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@console_ns.response(200, "Comment updated successfully", console_ns.models[WorkflowCommentUpdate.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
@ -167,7 +291,7 @@ class WorkflowCommentDetailApi(Resource):
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
return dump_response(WorkflowCommentUpdate, result)
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@ -197,12 +321,11 @@ class WorkflowCommentResolveApi(Resource):
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@console_ns.response(200, "Comment resolved successfully", console_ns.models[WorkflowCommentResolve.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
@ -213,7 +336,7 @@ class WorkflowCommentResolveApi(Resource):
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
return dump_response(WorkflowCommentResolve, comment)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
@ -224,12 +347,11 @@ class WorkflowCommentReplyApi(Resource):
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@console_ns.response(201, "Reply created successfully", console_ns.models[WorkflowCommentReplyCreate.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
@ -247,7 +369,7 @@ class WorkflowCommentReplyApi(Resource):
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
return dump_response(WorkflowCommentReplyCreate, result), 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
@ -258,12 +380,11 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@console_ns.response(200, "Reply updated successfully", console_ns.models[WorkflowCommentReplyUpdate.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
@ -284,7 +405,7 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
return dump_response(WorkflowCommentReplyUpdate, reply)
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
|
||||
@ -8,6 +8,7 @@ from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
@ -33,7 +34,6 @@ from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowDraftVariableListQuery(BaseModel):
|
||||
@ -56,21 +56,12 @@ class EnvironmentVariableUpdatePayload(BaseModel):
|
||||
environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableListQuery.__name__,
|
||||
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableUpdatePayload.__name__,
|
||||
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ConversationVariableUpdatePayload.__name__,
|
||||
ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
EnvironmentVariableUpdatePayload.__name__,
|
||||
EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
WorkflowDraftVariableListQuery,
|
||||
WorkflowDraftVariableUpdatePayload,
|
||||
ConversationVariableUpdatePayload,
|
||||
EnvironmentVariableUpdatePayload,
|
||||
)
|
||||
|
||||
|
||||
@ -260,7 +251,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@ -1,30 +1,28 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal, TypedDict, cast
|
||||
from typing import Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id
|
||||
from extensions.ext_database import db
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.base import ResponseModel
|
||||
from fields.workflow_run_fields import (
|
||||
advanced_chat_workflow_run_for_list_fields,
|
||||
advanced_chat_workflow_run_pagination_fields,
|
||||
workflow_run_count_fields,
|
||||
workflow_run_detail_fields,
|
||||
workflow_run_for_list_fields,
|
||||
workflow_run_node_execution_fields,
|
||||
workflow_run_node_execution_list_fields,
|
||||
workflow_run_pagination_fields,
|
||||
AdvancedChatWorkflowRunPaginationResponse,
|
||||
WorkflowRunCountResponse,
|
||||
WorkflowRunDetailResponse,
|
||||
WorkflowRunNodeExecutionListResponse,
|
||||
WorkflowRunNodeExecutionResponse,
|
||||
WorkflowRunPaginationResponse,
|
||||
)
|
||||
from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
@ -52,82 +50,6 @@ def _build_backstage_input_url(form_token: str | None) -> str | None:
|
||||
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||
EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model("SimpleAccount", simple_account_fields)
|
||||
|
||||
simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
|
||||
|
||||
# Models that depend on simple_account_fields
|
||||
workflow_run_for_list_fields_copy = workflow_run_for_list_fields.copy()
|
||||
workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
workflow_run_for_list_model = console_ns.model("WorkflowRunForList", workflow_run_for_list_fields_copy)
|
||||
|
||||
advanced_chat_workflow_run_for_list_fields_copy = advanced_chat_workflow_run_for_list_fields.copy()
|
||||
advanced_chat_workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
advanced_chat_workflow_run_for_list_model = console_ns.model(
|
||||
"AdvancedChatWorkflowRunForList", advanced_chat_workflow_run_for_list_fields_copy
|
||||
)
|
||||
|
||||
workflow_run_detail_fields_copy = workflow_run_detail_fields.copy()
|
||||
workflow_run_detail_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
workflow_run_detail_fields_copy["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
workflow_run_detail_model = console_ns.model("WorkflowRunDetail", workflow_run_detail_fields_copy)
|
||||
|
||||
workflow_run_node_execution_fields_copy = workflow_run_node_execution_fields.copy()
|
||||
workflow_run_node_execution_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
workflow_run_node_execution_fields_copy["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
workflow_run_node_execution_model = console_ns.model(
|
||||
"WorkflowRunNodeExecution", workflow_run_node_execution_fields_copy
|
||||
)
|
||||
|
||||
# Simple models without nested dependencies
|
||||
workflow_run_count_model = console_ns.model("WorkflowRunCount", workflow_run_count_fields)
|
||||
|
||||
# Pagination models that depend on list models
|
||||
advanced_chat_workflow_run_pagination_fields_copy = advanced_chat_workflow_run_pagination_fields.copy()
|
||||
advanced_chat_workflow_run_pagination_fields_copy["data"] = fields.List(
|
||||
fields.Nested(advanced_chat_workflow_run_for_list_model), attribute="data"
|
||||
)
|
||||
advanced_chat_workflow_run_pagination_model = console_ns.model(
|
||||
"AdvancedChatWorkflowRunPagination", advanced_chat_workflow_run_pagination_fields_copy
|
||||
)
|
||||
|
||||
workflow_run_pagination_fields_copy = workflow_run_pagination_fields.copy()
|
||||
workflow_run_pagination_fields_copy["data"] = fields.List(fields.Nested(workflow_run_for_list_model), attribute="data")
|
||||
workflow_run_pagination_model = console_ns.model("WorkflowRunPagination", workflow_run_pagination_fields_copy)
|
||||
|
||||
workflow_run_node_execution_list_fields_copy = workflow_run_node_execution_list_fields.copy()
|
||||
workflow_run_node_execution_list_fields_copy["data"] = fields.List(fields.Nested(workflow_run_node_execution_model))
|
||||
workflow_run_node_execution_list_model = console_ns.model(
|
||||
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
|
||||
)
|
||||
|
||||
workflow_run_export_fields = console_ns.model(
|
||||
"WorkflowRunExport",
|
||||
{
|
||||
"status": fields.String(description="Export status: success/failed"),
|
||||
"presigned_url": fields.String(description="Pre-signed URL for download", required=False),
|
||||
"presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False),
|
||||
},
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowRunListQuery(BaseModel):
|
||||
last_id: str | None = Field(default=None, description="Last run ID for pagination")
|
||||
@ -136,7 +58,7 @@ class WorkflowRunListQuery(BaseModel):
|
||||
default=None, description="Workflow run status filter"
|
||||
)
|
||||
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||
default=None, description="Filter by trigger source: debugging or app-run"
|
||||
default=None, description="Filter by trigger source: debugging or app-run. Default: debugging"
|
||||
)
|
||||
|
||||
@field_validator("last_id")
|
||||
@ -151,9 +73,15 @@ class WorkflowRunCountQuery(BaseModel):
|
||||
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||
default=None, description="Workflow run status filter"
|
||||
)
|
||||
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
|
||||
time_range: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||
),
|
||||
)
|
||||
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||
default=None, description="Filter by trigger source: debugging or app-run"
|
||||
default=None, description="Filter by trigger source: debugging or app-run. Default: debugging"
|
||||
)
|
||||
|
||||
@field_validator("time_range")
|
||||
@ -164,56 +92,69 @@ class WorkflowRunCountQuery(BaseModel):
|
||||
return time_duration(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
WorkflowRunCountQuery.__name__,
|
||||
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
class WorkflowRunExportResponse(ResponseModel):
|
||||
status: str = Field(description="Export status: success/failed")
|
||||
presigned_url: str | None = Field(default=None, description="Pre-signed URL for download")
|
||||
presigned_url_expires_at: str | None = Field(default=None, description="Pre-signed URL expiration time")
|
||||
|
||||
|
||||
class HumanInputPauseTypeResponse(TypedDict):
|
||||
class HumanInputPauseTypeResponse(ResponseModel):
|
||||
type: Literal["human_input"]
|
||||
form_id: str
|
||||
backstage_input_url: str | None
|
||||
backstage_input_url: str | None = None
|
||||
|
||||
|
||||
class PausedNodeResponse(TypedDict):
|
||||
class PausedNodeResponse(ResponseModel):
|
||||
node_id: str
|
||||
node_title: str
|
||||
pause_type: HumanInputPauseTypeResponse
|
||||
|
||||
|
||||
class WorkflowPauseDetailsResponse(TypedDict):
|
||||
paused_at: str | None
|
||||
class WorkflowPauseDetailsResponse(ResponseModel):
|
||||
paused_at: str | None = None
|
||||
paused_nodes: list[PausedNodeResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
WorkflowRunListQuery,
|
||||
WorkflowRunCountQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AdvancedChatWorkflowRunPaginationResponse,
|
||||
WorkflowRunPaginationResponse,
|
||||
WorkflowRunCountResponse,
|
||||
WorkflowRunDetailResponse,
|
||||
WorkflowRunNodeExecutionResponse,
|
||||
WorkflowRunNodeExecutionListResponse,
|
||||
WorkflowRunExportResponse,
|
||||
HumanInputPauseTypeResponse,
|
||||
PausedNodeResponse,
|
||||
WorkflowPauseDetailsResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||
class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@console_ns.doc("get_advanced_chat_workflow_runs")
|
||||
@console_ns.doc(description="Get advanced chat workflow run list")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
@console_ns.doc(params=query_params_from_model(WorkflowRunListQuery))
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow runs retrieved successfully",
|
||||
console_ns.models[AdvancedChatWorkflowRunPaginationResponse.__name__],
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(advanced_chat_workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
args: WorkflowRunListArgs = {"limit": args_model.limit}
|
||||
if args_model.last_id is not None:
|
||||
args["last_id"] = args_model.last_id
|
||||
@ -232,7 +173,9 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
app_model=app_model, args=args, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
return result
|
||||
return AdvancedChatWorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(
|
||||
mode="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/export")
|
||||
@ -240,7 +183,7 @@ class WorkflowRunExportApi(Resource):
|
||||
@console_ns.doc("get_workflow_run_export_url")
|
||||
@console_ns.doc(description="Generate a download URL for an archived workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Export URL generated", workflow_run_export_fields)
|
||||
@console_ns.response(200, "Export URL generated", console_ns.models[WorkflowRunExportResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -278,11 +221,14 @@ class WorkflowRunExportApi(Resource):
|
||||
expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS,
|
||||
)
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS)
|
||||
return {
|
||||
"status": "success",
|
||||
"presigned_url": presigned_url,
|
||||
"presigned_url_expires_at": expires_at.isoformat(),
|
||||
}, 200
|
||||
response = WorkflowRunExportResponse.model_validate(
|
||||
{
|
||||
"status": "success",
|
||||
"presigned_url": presigned_url,
|
||||
"presigned_url_expires_at": expires_at.isoformat(),
|
||||
}
|
||||
)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
|
||||
@ -290,32 +236,21 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
@console_ns.doc("get_advanced_chat_workflow_runs_count")
|
||||
@console_ns.doc(description="Get advanced chat workflow runs count statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
@console_ns.doc(params=query_params_from_model(WorkflowRunCountQuery))
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow runs count retrieved successfully",
|
||||
console_ns.models[WorkflowRunCountResponse.__name__],
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"time_range": (
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||
)
|
||||
}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get advanced chat workflow runs count statistics
|
||||
"""
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True))
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
@ -333,7 +268,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
return result
|
||||
return WorkflowRunCountResponse.model_validate(result).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs")
|
||||
@ -341,25 +276,21 @@ class WorkflowRunListApi(Resource):
|
||||
@console_ns.doc("get_workflow_runs")
|
||||
@console_ns.doc(description="Get workflow run list")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
@console_ns.doc(params=query_params_from_model(WorkflowRunListQuery))
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow runs retrieved successfully",
|
||||
console_ns.models[WorkflowRunPaginationResponse.__name__],
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
args: WorkflowRunListArgs = {"limit": args_model.limit}
|
||||
if args_model.last_id is not None:
|
||||
args["last_id"] = args_model.last_id
|
||||
@ -378,7 +309,7 @@ class WorkflowRunListApi(Resource):
|
||||
app_model=app_model, args=args, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
return result
|
||||
return WorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
|
||||
@ -386,32 +317,21 @@ class WorkflowRunCountApi(Resource):
|
||||
@console_ns.doc("get_workflow_runs_count")
|
||||
@console_ns.doc(description="Get workflow runs count statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
@console_ns.doc(params=query_params_from_model(WorkflowRunCountQuery))
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow runs count retrieved successfully",
|
||||
console_ns.models[WorkflowRunCountResponse.__name__],
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"time_range": (
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||
)
|
||||
}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow runs count statistics
|
||||
"""
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True))
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
@ -429,7 +349,7 @@ class WorkflowRunCountApi(Resource):
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
return result
|
||||
return WorkflowRunCountResponse.model_validate(result).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
|
||||
@ -437,13 +357,16 @@ class WorkflowRunDetailApi(Resource):
|
||||
@console_ns.doc("get_workflow_run_detail")
|
||||
@console_ns.doc(description="Get workflow run detail")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow run detail retrieved successfully",
|
||||
console_ns.models[WorkflowRunDetailResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_detail_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
Get workflow run detail
|
||||
@ -452,8 +375,10 @@ class WorkflowRunDetailApi(Resource):
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id)
|
||||
if workflow_run is None:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
return workflow_run
|
||||
return WorkflowRunDetailResponse.model_validate(workflow_run, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
@ -461,13 +386,16 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@console_ns.doc("get_workflow_run_node_executions")
|
||||
@console_ns.doc(description="Get workflow run node execution list")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Node executions retrieved successfully",
|
||||
console_ns.models[WorkflowRunNodeExecutionListResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_list_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
@ -482,13 +410,24 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
user=user,
|
||||
)
|
||||
|
||||
return {"data": node_executions}
|
||||
return WorkflowRunNodeExecutionListResponse.model_validate(
|
||||
{"data": node_executions}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/workflow/<string:workflow_run_id>/pause-details")
|
||||
class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
"""Console API for getting workflow pause details."""
|
||||
|
||||
@console_ns.doc("get_workflow_pause_details")
|
||||
@console_ns.doc(description="Get workflow pause details")
|
||||
@console_ns.doc(params={"workflow_run_id": "Workflow run ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow pause details retrieved successfully",
|
||||
console_ns.models[WorkflowPauseDetailsResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -515,11 +454,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
# Check if workflow is suspended
|
||||
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
if not is_paused:
|
||||
empty_response: WorkflowPauseDetailsResponse = {
|
||||
"paused_at": None,
|
||||
"paused_nodes": [],
|
||||
}
|
||||
return empty_response, 200
|
||||
empty_response = WorkflowPauseDetailsResponse(paused_at=None, paused_nodes=[])
|
||||
return empty_response.model_dump(mode="json"), 200
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
|
||||
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
|
||||
@ -530,27 +466,25 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
# Build response
|
||||
paused_at = pause_entity.paused_at if pause_entity else None
|
||||
paused_nodes: list[PausedNodeResponse] = []
|
||||
response: WorkflowPauseDetailsResponse = {
|
||||
"paused_at": paused_at.isoformat() + "Z" if paused_at else None,
|
||||
"paused_nodes": paused_nodes,
|
||||
}
|
||||
|
||||
for reason in pause_reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
paused_nodes.append(
|
||||
{
|
||||
"node_id": reason.node_id,
|
||||
"node_title": reason.node_title,
|
||||
"pause_type": {
|
||||
"type": "human_input",
|
||||
"form_id": reason.form_id,
|
||||
"backstage_input_url": _build_backstage_input_url(
|
||||
form_tokens_by_form_id.get(reason.form_id)
|
||||
),
|
||||
},
|
||||
}
|
||||
PausedNodeResponse(
|
||||
node_id=reason.node_id,
|
||||
node_title=reason.node_title,
|
||||
pause_type=HumanInputPauseTypeResponse(
|
||||
type="human_input",
|
||||
form_id=reason.form_id,
|
||||
backstage_input_url=_build_backstage_input_url(form_tokens_by_form_id.get(reason.form_id)),
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise AssertionError("unimplemented.")
|
||||
|
||||
return response, 200
|
||||
response = WorkflowPauseDetailsResponse(
|
||||
paused_at=paused_at.isoformat() + "Z" if paused_at else None,
|
||||
paused_nodes=paused_nodes,
|
||||
)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
@ -3,6 +3,7 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
@ -13,8 +14,6 @@ from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowStatisticQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
|
||||
@ -28,10 +27,7 @@ class WorkflowStatisticQuery(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowStatisticQuery.__name__,
|
||||
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
register_schema_models(console_ns, WorkflowStatisticQuery)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||
@ -53,7 +49,7 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
@ -93,7 +89,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
@ -133,7 +129,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
@ -173,7 +169,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
|
||||
@ -94,7 +94,7 @@ class WebhookTriggerApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__])
|
||||
def get(self, app_model: App):
|
||||
"""Get webhook trigger for a node"""
|
||||
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = Parser.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
node_id = args.node_id
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@ -40,16 +38,29 @@ class ActivatePayload(BaseModel):
|
||||
return timezone(value)
|
||||
|
||||
|
||||
class ActivationCheckResponse(BaseModel):
|
||||
is_valid: bool = Field(description="Whether token is valid")
|
||||
data: dict[str, Any] | None = Field(default=None, description="Activation data if valid")
|
||||
|
||||
|
||||
class ActivationResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
|
||||
class ActivationCheckData(BaseModel):
|
||||
workspace_name: str | None
|
||||
workspace_id: str | None
|
||||
email: str | None
|
||||
|
||||
|
||||
class ActivationCheckResponse(BaseModel):
|
||||
is_valid: bool = Field(description="Whether token is valid")
|
||||
data: ActivationCheckData | None = Field(default=None, description="Activation data if valid")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ActivateCheckQuery,
|
||||
ActivatePayload,
|
||||
ActivationCheckData,
|
||||
ActivationCheckResponse,
|
||||
ActivationResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/activate/check")
|
||||
@ -63,7 +74,7 @@ class ActivateCheckApi(Resource):
|
||||
console_ns.models[ActivationCheckResponse.__name__],
|
||||
)
|
||||
def get(self):
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
workspaceId = args.workspace_id
|
||||
token = args.token
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
@ -8,8 +10,6 @@ from .. import console_ns
|
||||
from ..auth.error import ApiKeyAuthFailedError
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ApiKeyAuthBindingPayload(BaseModel):
|
||||
category: str = Field(...)
|
||||
@ -17,14 +17,26 @@ class ApiKeyAuthBindingPayload(BaseModel):
|
||||
credentials: dict = Field(...)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ApiKeyAuthBindingPayload.__name__,
|
||||
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
class ApiKeyAuthDataSourceItem(ResponseModel):
|
||||
id: str
|
||||
category: str
|
||||
provider: str
|
||||
disabled: bool
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
class ApiKeyAuthDataSourceListResponse(ResponseModel):
|
||||
sources: list[ApiKeyAuthDataSourceItem]
|
||||
|
||||
|
||||
register_schema_models(console_ns, ApiKeyAuthBindingPayload)
|
||||
register_response_schema_models(console_ns, ApiKeyAuthDataSourceItem, ApiKeyAuthDataSourceListResponse)
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source")
|
||||
class ApiKeyAuthDataSource(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ApiKeyAuthDataSourceListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -74,6 +86,7 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.response(204, "Binding deleted successfully")
|
||||
def delete(self, binding_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -3,7 +3,9 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from constants.languages import get_valid_language, languages
|
||||
from controllers.common.fields import SimpleResultDataResponse, VerificationTokenResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
@ -14,17 +16,16 @@ from controllers.console.auth.error import (
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.password import valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
from services.errors.account import AccountRegisterError
|
||||
|
||||
from ..error import AccountInFreezeError, EmailSendIpLimitError
|
||||
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class EmailRegisterSendPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
@ -41,15 +42,24 @@ class EmailRegisterResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
timezone: str | None = Field(default=None)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return validate_timezone_string(value)
|
||||
|
||||
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
register_schema_models(console_ns, EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload)
|
||||
register_response_schema_models(console_ns, SimpleResultDataResponse, VerificationTokenResponse)
|
||||
|
||||
|
||||
@console_ns.route("/email-register/send-email")
|
||||
@ -57,6 +67,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
def post(self):
|
||||
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
@ -81,6 +92,7 @@ class EmailRegisterCheckApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
@console_ns.response(200, "Success", console_ns.models[VerificationTokenResponse.__name__])
|
||||
def post(self):
|
||||
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||
|
||||
@ -146,26 +158,32 @@ class EmailRegisterResetApi(Resource):
|
||||
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
account = self._create_new_account(
|
||||
email=normalized_email,
|
||||
password=args.password_confirm,
|
||||
timezone=args.timezone,
|
||||
language=args.language,
|
||||
)
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
def _create_new_account(self, email: str, password: str) -> Account | None:
|
||||
# Create new account if allowed
|
||||
account = None
|
||||
def _create_new_account(
|
||||
self,
|
||||
email: str,
|
||||
password: str,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> Account:
|
||||
try:
|
||||
account = AccountService.create_account_and_tenant(
|
||||
return AccountService.create_account_and_tenant(
|
||||
email=email,
|
||||
name=email,
|
||||
password=password,
|
||||
interface_language=languages[0],
|
||||
interface_language=get_valid_language(language),
|
||||
timezone=timezone,
|
||||
)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
return account
|
||||
|
||||
@ -28,8 +28,6 @@ from services.entities.auth_entities import (
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ForgotPasswordEmailResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
@ -3,12 +3,14 @@ import logging
|
||||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants.languages import get_valid_language
|
||||
from controllers.common.fields import SimpleResultDataResponse, SimpleResultOptionalDataResponse, SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
@ -33,6 +35,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
@ -50,7 +53,6 @@ from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -69,15 +71,23 @@ class EmailCodeLoginPayload(BaseModel):
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
timezone: str | None = Field(default=None)
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return validate_timezone_string(value)
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(LoginPayload)
|
||||
reg(EmailPayload)
|
||||
reg(EmailCodeLoginPayload)
|
||||
register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPayload)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
SimpleResultDataResponse,
|
||||
SimpleResultOptionalDataResponse,
|
||||
SimpleResultResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/login")
|
||||
@ -87,6 +97,7 @@ class LoginApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[LoginPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultOptionalDataResponse.__name__])
|
||||
@decrypt_password_field
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
@ -160,6 +171,7 @@ class LoginApi(Resource):
|
||||
@console_ns.route("/logout")
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
@ -183,6 +195,7 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
@ -210,6 +223,7 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
@ -242,6 +256,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@decrypt_code_field
|
||||
def post(self):
|
||||
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||
@ -294,6 +309,7 @@ class EmailCodeLoginApi(Resource):
|
||||
email=user_email,
|
||||
name=user_email,
|
||||
interface_language=get_valid_language(language),
|
||||
timezone=args.timezone,
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
@ -317,6 +333,7 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
@console_ns.route("/refresh-token")
|
||||
class RefreshTokenApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self):
|
||||
# Get refresh token from cookie instead of request body
|
||||
refresh_token = extract_refresh_token(request)
|
||||
|
||||
@ -12,7 +12,8 @@ from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state
|
||||
from libs.token import (
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
@ -53,6 +54,31 @@ def get_oauth_providers():
|
||||
return OAUTH_PROVIDERS
|
||||
|
||||
|
||||
def _validated_timezone(value: str | None) -> str | None:
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return validate_timezone_string(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _validated_language(value: str | None) -> str | None:
|
||||
if value and value in languages:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _preferred_interface_language(language: str | None = None) -> str:
|
||||
if language:
|
||||
return language
|
||||
|
||||
preferred_lang = request.accept_languages.best_match(languages)
|
||||
if preferred_lang and preferred_lang in languages:
|
||||
return preferred_lang
|
||||
return languages[0]
|
||||
|
||||
|
||||
@console_ns.route("/oauth/login/<provider>")
|
||||
class OAuthLogin(Resource):
|
||||
@console_ns.doc("oauth_login")
|
||||
@ -64,13 +90,19 @@ class OAuthLogin(Resource):
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
def get(self, provider: str):
|
||||
invite_token = request.args.get("invite_token") or None
|
||||
timezone = _validated_timezone(request.args.get("timezone") or None)
|
||||
language = _validated_language(request.args.get("language") or None)
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
|
||||
auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
|
||||
auth_url = oauth_provider.get_authorization_url(
|
||||
invite_token=invite_token,
|
||||
timezone=timezone,
|
||||
language=language,
|
||||
)
|
||||
return redirect(auth_url)
|
||||
|
||||
|
||||
@ -96,9 +128,10 @@ class OAuthCallback(Resource):
|
||||
|
||||
code = request.args.get("code")
|
||||
state = request.args.get("state")
|
||||
invite_token = None
|
||||
if state:
|
||||
invite_token = state
|
||||
oauth_state = decode_oauth_state(state)
|
||||
invite_token = oauth_state.get("invite_token")
|
||||
timezone = _validated_timezone(oauth_state.get("timezone"))
|
||||
language = _validated_language(oauth_state.get("language"))
|
||||
|
||||
if not code:
|
||||
return {"error": "Authorization code is required"}, 400
|
||||
@ -129,7 +162,7 @@ class OAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||
|
||||
try:
|
||||
account, oauth_new_user = _generate_account(provider, user_info)
|
||||
account, oauth_new_user = _generate_account(provider, user_info, timezone=timezone, language=language)
|
||||
except AccountNotFoundError:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
|
||||
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
|
||||
@ -184,7 +217,12 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
return account
|
||||
|
||||
|
||||
def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
|
||||
def _generate_account(
|
||||
provider: str,
|
||||
user_info: OAuthUserInfo,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> tuple[Account, bool]:
|
||||
# Get account by openid or email.
|
||||
account = _get_account_by_openid_or_email(provider, user_info)
|
||||
oauth_new_user = False
|
||||
@ -211,26 +249,19 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
||||
"30 days and is temporarily unavailable for new account registration"
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise AccountRegisterError(description=("Invalid email or password"))
|
||||
raise AccountRegisterError(description=("Invalid email or password"))
|
||||
account_name = user_info.name or "Dify"
|
||||
interface_language = _preferred_interface_language(language)
|
||||
account = RegisterService.register(
|
||||
email=normalized_email,
|
||||
name=account_name,
|
||||
password=None,
|
||||
open_id=user_info.id,
|
||||
provider=provider,
|
||||
language=interface_language,
|
||||
timezone=timezone,
|
||||
)
|
||||
|
||||
# Set interface language
|
||||
preferred_lang = request.accept_languages.best_match(languages)
|
||||
if preferred_lang and preferred_lang in languages:
|
||||
interface_language = preferred_lang
|
||||
else:
|
||||
interface_language = languages[0]
|
||||
account.interface_language = interface_language
|
||||
db.session.commit()
|
||||
|
||||
# Link account
|
||||
AccountService.link_account_integrate(provider, user_info.id, account)
|
||||
|
||||
|
||||
@ -9,7 +9,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import get_or_create_model, register_schema_model
|
||||
from controllers.common.fields import SimpleResultResponse, TextContentResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_model
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
@ -54,6 +55,7 @@ class DataSourceNotionPreviewQuery(BaseModel):
|
||||
|
||||
|
||||
register_schema_model(console_ns, NotionEstimatePayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse, TextContentResponse)
|
||||
|
||||
|
||||
integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
|
||||
@ -157,6 +159,7 @@ class DataSourceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
binding_id = str(binding_id)
|
||||
@ -289,6 +292,7 @@ class DataSourceNotionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
def get(self, page_id, page_type):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -362,6 +366,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -379,6 +384,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
|
||||
@ -8,7 +8,8 @@ from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.common.fields import ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.apikey import ApiKeyItem, ApiKeyList
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
@ -58,6 +59,8 @@ from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
register_response_schema_models(console_ns, ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
|
||||
|
||||
@ -521,6 +524,7 @@ class DatasetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Dataset deleted successfully")
|
||||
def delete(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -543,7 +547,11 @@ class DatasetUseCheckApi(Resource):
|
||||
@console_ns.doc("check_dataset_use")
|
||||
@console_ns.doc(description="Check if dataset is in use")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Dataset use status retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dataset use status retrieved successfully",
|
||||
console_ns.models[UsageCheckResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -606,63 +614,63 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
extract_settings = []
|
||||
if args["info_list"]["data_source_type"] == "upload_file":
|
||||
file_ids = args["info_list"]["file_info_list"]["file_ids"]
|
||||
file_details = db.session.scalars(
|
||||
select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
|
||||
).all()
|
||||
match args["info_list"]["data_source_type"]:
|
||||
case "upload_file":
|
||||
file_ids = args["info_list"]["file_info_list"]["file_ids"]
|
||||
file_details = db.session.scalars(
|
||||
select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
|
||||
).all()
|
||||
if file_details is None:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
if file_details is None:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
if file_details:
|
||||
for file_detail in file_details:
|
||||
if file_details:
|
||||
for file_detail in file_details:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
case "notion_import":
|
||||
notion_info_list = args["info_list"]["notion_info_list"]
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
case "website_crawl":
|
||||
website_info_list = args["info_list"]["website_info_list"]
|
||||
for url in website_info_list["urls"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif args["info_list"]["data_source_type"] == "notion_import":
|
||||
notion_info_list = args["info_list"]["notion_info_list"]
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"provider": website_info_list["provider"],
|
||||
"job_id": website_info_list["job_id"],
|
||||
"url": url,
|
||||
"tenant_id": current_tenant_id,
|
||||
"mode": "crawl",
|
||||
"only_main_content": website_info_list["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif args["info_list"]["data_source_type"] == "website_crawl":
|
||||
website_info_list = args["info_list"]["website_info_list"]
|
||||
for url in website_info_list["urls"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": website_info_list["provider"],
|
||||
"job_id": website_info_list["job_id"],
|
||||
"url": url,
|
||||
"tenant_id": current_tenant_id,
|
||||
"mode": "crawl",
|
||||
"only_main_content": website_info_list["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
else:
|
||||
raise ValueError("Data source type not support")
|
||||
case _:
|
||||
raise ValueError("Data source type not support")
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.indexing_estimate(
|
||||
@ -873,6 +881,7 @@ class DatasetEnableApiApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, dataset_id, status):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
@ -885,7 +894,7 @@ class DatasetEnableApiApi(Resource):
|
||||
class DatasetApiBaseUrlApi(Resource):
|
||||
@console_ns.doc("get_dataset_api_base_info")
|
||||
@console_ns.doc(description="Get dataset API base information")
|
||||
@console_ns.response(200, "API base info retrieved successfully")
|
||||
@console_ns.response(200, "API base info retrieved successfully", console_ns.models[ApiBaseUrlResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -15,7 +15,8 @@ from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultMessageResponse, SimpleResultResponse, UrlResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
@ -39,6 +40,7 @@ from fields.document_fields import (
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
@ -71,12 +73,6 @@ from ..wraps import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
def _normalize_enum(value: Any) -> Any:
|
||||
if isinstance(value, str) or value is None:
|
||||
return value
|
||||
@ -101,7 +97,7 @@ class DatasetResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentMetadataResponse(ResponseModel):
|
||||
@ -152,7 +148,7 @@ class DocumentResponse(ResponseModel):
|
||||
@field_validator("created_at", "disabled_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentWithSegmentsResponse(DocumentResponse):
|
||||
@ -209,6 +205,7 @@ register_schema_models(
|
||||
DocumentWithSegmentsResponse,
|
||||
DatasetAndDocumentResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultMessageResponse, SimpleResultResponse, UrlResponse)
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
@ -369,28 +366,31 @@ class DatasetDocumentListApi(Resource):
|
||||
else:
|
||||
sort_logic = asc
|
||||
|
||||
if sort == "hit_count":
|
||||
sub_query = (
|
||||
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
.where(DocumentSegment.dataset_id == str(dataset_id))
|
||||
.group_by(DocumentSegment.document_id)
|
||||
.subquery()
|
||||
)
|
||||
match sort:
|
||||
case "hit_count":
|
||||
sub_query = (
|
||||
sa.select(
|
||||
DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")
|
||||
)
|
||||
.where(DocumentSegment.dataset_id == str(dataset_id))
|
||||
.group_by(DocumentSegment.document_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
|
||||
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||
sort_logic(Document.position),
|
||||
)
|
||||
elif sort == "created_at":
|
||||
query = query.order_by(
|
||||
sort_logic(Document.created_at),
|
||||
sort_logic(Document.position),
|
||||
)
|
||||
else:
|
||||
query = query.order_by(
|
||||
desc(Document.created_at),
|
||||
desc(Document.position),
|
||||
)
|
||||
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
|
||||
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||
sort_logic(Document.position),
|
||||
)
|
||||
case "created_at":
|
||||
query = query.order_by(
|
||||
sort_logic(Document.created_at),
|
||||
sort_logic(Document.position),
|
||||
)
|
||||
case _:
|
||||
query = query.order_by(
|
||||
desc(Document.created_at),
|
||||
desc(Document.position),
|
||||
)
|
||||
|
||||
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
documents = paginated_documents.items
|
||||
@ -489,6 +489,7 @@ class DatasetDocumentListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Documents deleted successfully")
|
||||
def delete(self, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -948,6 +949,7 @@ class DocumentApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document deleted successfully")
|
||||
def delete(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
@ -973,6 +975,7 @@ class DocumentDownloadApi(DocumentResource):
|
||||
|
||||
@console_ns.doc("get_dataset_document_download_url")
|
||||
@console_ns.doc(description="Get a signed download URL for a dataset document's original uploaded file")
|
||||
@console_ns.response(200, "Download URL generated successfully", console_ns.models[UrlResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -1030,7 +1033,11 @@ class DocumentProcessingApi(DocumentResource):
|
||||
@console_ns.doc(
|
||||
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "action": "Action to perform (pause/resume)"}
|
||||
)
|
||||
@console_ns.response(200, "Processing status updated successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Processing status updated successfully",
|
||||
console_ns.models[SimpleResultResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(400, "Invalid action")
|
||||
@setup_required
|
||||
@ -1075,7 +1082,11 @@ class DocumentMetadataApi(DocumentResource):
|
||||
@console_ns.doc(description="Update document metadata")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Document metadata updated successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Document metadata updated successfully",
|
||||
console_ns.models[SimpleResultMessageResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -1129,6 +1140,7 @@ class DocumentStatusApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
@ -1166,6 +1178,7 @@ class DocumentPauseApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document paused successfully")
|
||||
def patch(self, dataset_id, document_id):
|
||||
"""pause document."""
|
||||
dataset_id = str(dataset_id)
|
||||
@ -1200,6 +1213,7 @@ class DocumentRecoverApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document resumed successfully")
|
||||
def patch(self, dataset_id, document_id):
|
||||
"""recover document."""
|
||||
dataset_id = str(dataset_id)
|
||||
@ -1232,6 +1246,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
|
||||
@console_ns.response(204, "Documents retry started successfully")
|
||||
def post(self, dataset_id):
|
||||
"""retry document."""
|
||||
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
|
||||
@ -1298,6 +1313,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id, document_id):
|
||||
"""sync website document."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -1364,7 +1380,11 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
@console_ns.doc(description="Generate summary index for documents")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
|
||||
@console_ns.response(200, "Summary generation started successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Summary generation started successfully",
|
||||
console_ns.models[SimpleResultResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid request or dataset configuration")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
|
||||
@ -10,7 +10,8 @@ from werkzeug.exceptions import Forbidden, NotFound
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import (
|
||||
@ -30,6 +31,7 @@ from core.model_manager import ModelManager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.base import ResponseModel
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import escape_like_pattern
|
||||
@ -83,6 +85,11 @@ class BatchImportPayload(BaseModel):
|
||||
upload_file_id: str
|
||||
|
||||
|
||||
class SegmentBatchImportStatusResponse(ResponseModel):
|
||||
job_id: str
|
||||
job_status: str
|
||||
|
||||
|
||||
class ChildChunkBatchUpdatePayload(BaseModel):
|
||||
chunks: list[ChildChunkUpdateArgs]
|
||||
|
||||
@ -98,6 +105,7 @@ register_schema_models(
|
||||
ChildChunkBatchUpdatePayload,
|
||||
ChildChunkUpdateArgs,
|
||||
)
|
||||
register_response_schema_models(console_ns, SegmentBatchImportStatusResponse, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
@ -217,6 +225,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Segments deleted successfully")
|
||||
def delete(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@ -252,6 +261,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id, document_id, action):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -424,6 +434,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Segment deleted successfully")
|
||||
def delete(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -464,6 +475,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
"/datasets/batch_import_status/<uuid:job_id>",
|
||||
)
|
||||
class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@console_ns.response(200, "Batch import started", console_ns.models[SegmentBatchImportStatusResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -514,6 +526,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
return {"error": str(e)}, 500
|
||||
return {"job_id": job_id, "job_status": "waiting"}, 200
|
||||
|
||||
@console_ns.response(200, "Batch import status", console_ns.models[SegmentBatchImportStatusResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -691,6 +704,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Child chunk deleted successfully")
|
||||
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
|
||||
@ -4,7 +4,8 @@ from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.common.fields import UsageCountResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
@ -27,6 +28,8 @@ from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService
|
||||
|
||||
register_response_schema_models(console_ns, UsageCountResponse)
|
||||
|
||||
|
||||
def _build_dataset_detail_model():
|
||||
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
|
||||
@ -206,6 +209,7 @@ class ExternalApiTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(204, "External knowledge API deleted successfully")
|
||||
def delete(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
@ -222,7 +226,7 @@ class ExternalApiUseCheckApi(Resource):
|
||||
@console_ns.doc("check_external_api_usage")
|
||||
@console_ns.doc(description="Check if external knowledge API is being used")
|
||||
@console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@console_ns.response(200, "Usage check completed successfully")
|
||||
@console_ns.response(200, "Usage check completed successfully", console_ns.models[UsageCountResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -8,6 +8,7 @@ from pydantic import Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
@ -19,12 +20,6 @@ from ..wraps import (
|
||||
)
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class HitTestingDocument(ResponseModel):
|
||||
id: str | None = None
|
||||
data_source_type: str | None = None
|
||||
@ -61,7 +56,7 @@ class HitTestingSegment(ResponseModel):
|
||||
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class HitTestingChildChunk(ResponseModel):
|
||||
|
||||
@ -39,11 +39,8 @@ class HitTestingPayload(BaseModel):
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_query(query: Any) -> str:
|
||||
"""Return the user-visible query string from legacy and current response shapes."""
|
||||
if isinstance(query, str):
|
||||
return query
|
||||
|
||||
def _extract_hit_testing_query(query: Any) -> str:
|
||||
"""Return the query string from the service response shape."""
|
||||
if isinstance(query, dict):
|
||||
content = query.get("content")
|
||||
if isinstance(content, str):
|
||||
@ -52,15 +49,15 @@ class DatasetsHitTestingBase:
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Coerce nullable collection fields into lists before response validation."""
|
||||
def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Ensure collection fields match the API schema before response validation."""
|
||||
if not isinstance(records, list):
|
||||
return []
|
||||
raise ValueError("Invalid hit testing records response")
|
||||
|
||||
normalized_records: list[dict[str, Any]] = []
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
raise ValueError("Invalid hit testing record response")
|
||||
|
||||
normalized_record = dict(record)
|
||||
segment = normalized_record.get("segment")
|
||||
@ -118,8 +115,8 @@ class DatasetsHitTestingBase:
|
||||
limit=10,
|
||||
)
|
||||
return {
|
||||
"query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._normalize_hit_testing_records(
|
||||
"query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._prepare_hit_testing_records(
|
||||
marshal(response.get("records", []), hit_testing_record_fields)
|
||||
),
|
||||
}
|
||||
|
||||
@ -4,7 +4,8 @@ from flask_restx import Resource, marshal_with
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
@ -21,6 +22,7 @@ from services.metadata_service import MetadataService
|
||||
register_schema_models(
|
||||
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
@ -83,6 +85,7 @@ class DatasetMetadataApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Metadata deleted successfully")
|
||||
def delete(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -113,6 +116,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -136,6 +140,7 @@ class DocumentMetadataEditApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user