mirror of
https://github.com/langgenius/dify.git
synced 2026-01-28 15:56:00 +08:00
Compare commits
2 Commits
fix/parall
...
refactor/s
| Author | SHA1 | Date | |
|---|---|---|---|
| d8f4eddc4c | |||
| 7b56f16255 |
@ -46,30 +46,23 @@ Comprehensive performance optimization guide for React and Next.js applications,
|
||||
- 4.3 [Use SWR for Automatic Deduplication](#43-use-swr-for-automatic-deduplication)
|
||||
- 4.4 [Version and Minimize localStorage Data](#44-version-and-minimize-localstorage-data)
|
||||
5. [Re-render Optimization](#5-re-render-optimization) — **MEDIUM**
|
||||
- 5.1 [Calculate Derived State During Rendering](#51-calculate-derived-state-during-rendering)
|
||||
- 5.2 [Defer State Reads to Usage Point](#52-defer-state-reads-to-usage-point)
|
||||
- 5.3 [Do not wrap a simple expression with a primitive result type in useMemo](#53-do-not-wrap-a-simple-expression-with-a-primitive-result-type-in-usememo)
|
||||
- 5.4 [Extract Default Non-primitive Parameter Value from Memoized Component to Constant](#54-extract-default-non-primitive-parameter-value-from-memoized-component-to-constant)
|
||||
- 5.5 [Extract to Memoized Components](#55-extract-to-memoized-components)
|
||||
- 5.6 [Narrow Effect Dependencies](#56-narrow-effect-dependencies)
|
||||
- 5.7 [Put Interaction Logic in Event Handlers](#57-put-interaction-logic-in-event-handlers)
|
||||
- 5.8 [Subscribe to Derived State](#58-subscribe-to-derived-state)
|
||||
- 5.9 [Use Functional setState Updates](#59-use-functional-setstate-updates)
|
||||
- 5.10 [Use Lazy State Initialization](#510-use-lazy-state-initialization)
|
||||
- 5.11 [Use Transitions for Non-Urgent Updates](#511-use-transitions-for-non-urgent-updates)
|
||||
- 5.12 [Use useRef for Transient Values](#512-use-useref-for-transient-values)
|
||||
- 5.1 [Defer State Reads to Usage Point](#51-defer-state-reads-to-usage-point)
|
||||
- 5.2 [Extract to Memoized Components](#52-extract-to-memoized-components)
|
||||
- 5.3 [Narrow Effect Dependencies](#53-narrow-effect-dependencies)
|
||||
- 5.4 [Subscribe to Derived State](#54-subscribe-to-derived-state)
|
||||
- 5.5 [Use Functional setState Updates](#55-use-functional-setstate-updates)
|
||||
- 5.6 [Use Lazy State Initialization](#56-use-lazy-state-initialization)
|
||||
- 5.7 [Use Transitions for Non-Urgent Updates](#57-use-transitions-for-non-urgent-updates)
|
||||
6. [Rendering Performance](#6-rendering-performance) — **MEDIUM**
|
||||
- 6.1 [Animate SVG Wrapper Instead of SVG Element](#61-animate-svg-wrapper-instead-of-svg-element)
|
||||
- 6.2 [CSS content-visibility for Long Lists](#62-css-content-visibility-for-long-lists)
|
||||
- 6.3 [Hoist Static JSX Elements](#63-hoist-static-jsx-elements)
|
||||
- 6.4 [Optimize SVG Precision](#64-optimize-svg-precision)
|
||||
- 6.5 [Prevent Hydration Mismatch Without Flickering](#65-prevent-hydration-mismatch-without-flickering)
|
||||
- 6.6 [Suppress Expected Hydration Mismatches](#66-suppress-expected-hydration-mismatches)
|
||||
- 6.7 [Use Activity Component for Show/Hide](#67-use-activity-component-for-showhide)
|
||||
- 6.8 [Use Explicit Conditional Rendering](#68-use-explicit-conditional-rendering)
|
||||
- 6.9 [Use useTransition Over Manual Loading States](#69-use-usetransition-over-manual-loading-states)
|
||||
- 6.6 [Use Activity Component for Show/Hide](#66-use-activity-component-for-showhide)
|
||||
- 6.7 [Use Explicit Conditional Rendering](#67-use-explicit-conditional-rendering)
|
||||
7. [JavaScript Performance](#7-javascript-performance) — **LOW-MEDIUM**
|
||||
- 7.1 [Avoid Layout Thrashing](#71-avoid-layout-thrashing)
|
||||
- 7.1 [Batch DOM CSS Changes](#71-batch-dom-css-changes)
|
||||
- 7.2 [Build Index Maps for Repeated Lookups](#72-build-index-maps-for-repeated-lookups)
|
||||
- 7.3 [Cache Property Access in Loops](#73-cache-property-access-in-loops)
|
||||
- 7.4 [Cache Repeated Function Calls](#74-cache-repeated-function-calls)
|
||||
@ -82,9 +75,8 @@ Comprehensive performance optimization guide for React and Next.js applications,
|
||||
- 7.11 [Use Set/Map for O(1) Lookups](#711-use-setmap-for-o1-lookups)
|
||||
- 7.12 [Use toSorted() Instead of sort() for Immutability](#712-use-tosorted-instead-of-sort-for-immutability)
|
||||
8. [Advanced Patterns](#8-advanced-patterns) — **LOW**
|
||||
- 8.1 [Initialize App Once, Not Per Mount](#81-initialize-app-once-not-per-mount)
|
||||
- 8.2 [Store Event Handlers in Refs](#82-store-event-handlers-in-refs)
|
||||
- 8.3 [useEffectEvent for Stable Callback Refs](#83-useeffectevent-for-stable-callback-refs)
|
||||
- 8.1 [Store Event Handlers in Refs](#81-store-event-handlers-in-refs)
|
||||
- 8.2 [useLatest for Stable Callback Refs](#82-uselatest-for-stable-callback-refs)
|
||||
|
||||
---
|
||||
|
||||
@ -200,21 +192,6 @@ const { user, config, profile } = await all({
|
||||
})
|
||||
```
|
||||
|
||||
**Alternative without extra dependencies:**
|
||||
|
||||
```typescript
|
||||
const userPromise = fetchUser()
|
||||
const profilePromise = userPromise.then(user => fetchProfile(user.id))
|
||||
|
||||
const [user, config, profile] = await Promise.all([
|
||||
userPromise,
|
||||
fetchConfig(),
|
||||
profilePromise
|
||||
])
|
||||
```
|
||||
|
||||
We can also create all the promises first, and do `Promise.all()` at the end.
|
||||
|
||||
Reference: [https://github.com/shuding/better-all](https://github.com/shuding/better-all)
|
||||
|
||||
### 1.3 Prevent Waterfall Chains in API Routes
|
||||
@ -1283,43 +1260,7 @@ function cachePrefs(user: FullUser) {
|
||||
|
||||
Reducing unnecessary re-renders minimizes wasted computation and improves UI responsiveness.
|
||||
|
||||
### 5.1 Calculate Derived State During Rendering
|
||||
|
||||
**Impact: MEDIUM (avoids redundant renders and state drift)**
|
||||
|
||||
If a value can be computed from current props/state, do not store it in state or update it in an effect. Derive it during render to avoid extra renders and state drift. Do not set state in effects solely in response to prop changes; prefer derived values or keyed resets instead.
|
||||
|
||||
**Incorrect: redundant state and effect**
|
||||
|
||||
```tsx
|
||||
function Form() {
|
||||
const [firstName, setFirstName] = useState('First')
|
||||
const [lastName, setLastName] = useState('Last')
|
||||
const [fullName, setFullName] = useState('')
|
||||
|
||||
useEffect(() => {
|
||||
setFullName(firstName + ' ' + lastName)
|
||||
}, [firstName, lastName])
|
||||
|
||||
return <p>{fullName}</p>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct: derive during render**
|
||||
|
||||
```tsx
|
||||
function Form() {
|
||||
const [firstName, setFirstName] = useState('First')
|
||||
const [lastName, setLastName] = useState('Last')
|
||||
const fullName = firstName + ' ' + lastName
|
||||
|
||||
return <p>{fullName}</p>
|
||||
}
|
||||
```
|
||||
|
||||
Reference: [https://react.dev/learn/you-might-not-need-an-effect](https://react.dev/learn/you-might-not-need-an-effect)
|
||||
|
||||
### 5.2 Defer State Reads to Usage Point
|
||||
### 5.1 Defer State Reads to Usage Point
|
||||
|
||||
**Impact: MEDIUM (avoids unnecessary subscriptions)**
|
||||
|
||||
@ -1354,71 +1295,7 @@ function ShareButton({ chatId }: { chatId: string }) {
|
||||
}
|
||||
```
|
||||
|
||||
### 5.3 Do not wrap a simple expression with a primitive result type in useMemo
|
||||
|
||||
**Impact: LOW-MEDIUM (wasted computation on every render)**
|
||||
|
||||
When an expression is simple (few logical or arithmetical operators) and has a primitive result type (boolean, number, string), do not wrap it in `useMemo`.
|
||||
|
||||
Calling `useMemo` and comparing hook dependencies may consume more resources than the expression itself.
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
function Header({ user, notifications }: Props) {
|
||||
const isLoading = useMemo(() => {
|
||||
return user.isLoading || notifications.isLoading
|
||||
}, [user.isLoading, notifications.isLoading])
|
||||
|
||||
if (isLoading) return <Skeleton />
|
||||
// return some markup
|
||||
}
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
function Header({ user, notifications }: Props) {
|
||||
const isLoading = user.isLoading || notifications.isLoading
|
||||
|
||||
if (isLoading) return <Skeleton />
|
||||
// return some markup
|
||||
}
|
||||
```
|
||||
|
||||
### 5.4 Extract Default Non-primitive Parameter Value from Memoized Component to Constant
|
||||
|
||||
**Impact: MEDIUM (restores memoization by using a constant for default value)**
|
||||
|
||||
When memoized component has a default value for some non-primitive optional parameter, such as an array, function, or object, calling the component without that parameter results in broken memoization. This is because new value instances are created on every rerender, and they do not pass strict equality comparison in `memo()`.
|
||||
|
||||
To address this issue, extract the default value into a constant.
|
||||
|
||||
**Incorrect: `onClick` has different values on every rerender**
|
||||
|
||||
```tsx
|
||||
const UserAvatar = memo(function UserAvatar({ onClick = () => {} }: { onClick?: () => void }) {
|
||||
// ...
|
||||
})
|
||||
|
||||
// Used without optional onClick
|
||||
<UserAvatar />
|
||||
```
|
||||
|
||||
**Correct: stable default value**
|
||||
|
||||
```tsx
|
||||
const NOOP = () => {};
|
||||
|
||||
const UserAvatar = memo(function UserAvatar({ onClick = NOOP }: { onClick?: () => void }) {
|
||||
// ...
|
||||
})
|
||||
|
||||
// Used without optional onClick
|
||||
<UserAvatar />
|
||||
```
|
||||
|
||||
### 5.5 Extract to Memoized Components
|
||||
### 5.2 Extract to Memoized Components
|
||||
|
||||
**Impact: MEDIUM (enables early returns)**
|
||||
|
||||
@ -1458,7 +1335,7 @@ function Profile({ user, loading }: Props) {
|
||||
|
||||
**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, manual memoization with `memo()` and `useMemo()` is not necessary. The compiler automatically optimizes re-renders.
|
||||
|
||||
### 5.6 Narrow Effect Dependencies
|
||||
### 5.3 Narrow Effect Dependencies
|
||||
|
||||
**Impact: LOW (minimizes effect re-runs)**
|
||||
|
||||
@ -1499,48 +1376,7 @@ useEffect(() => {
|
||||
}, [isMobile])
|
||||
```
|
||||
|
||||
### 5.7 Put Interaction Logic in Event Handlers
|
||||
|
||||
**Impact: MEDIUM (avoids effect re-runs and duplicate side effects)**
|
||||
|
||||
If a side effect is triggered by a specific user action (submit, click, drag), run it in that event handler. Do not model the action as state + effect; it makes effects re-run on unrelated changes and can duplicate the action.
|
||||
|
||||
**Incorrect: event modeled as state + effect**
|
||||
|
||||
```tsx
|
||||
function Form() {
|
||||
const [submitted, setSubmitted] = useState(false)
|
||||
const theme = useContext(ThemeContext)
|
||||
|
||||
useEffect(() => {
|
||||
if (submitted) {
|
||||
post('/api/register')
|
||||
showToast('Registered', theme)
|
||||
}
|
||||
}, [submitted, theme])
|
||||
|
||||
return <button onClick={() => setSubmitted(true)}>Submit</button>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct: do it in the handler**
|
||||
|
||||
```tsx
|
||||
function Form() {
|
||||
const theme = useContext(ThemeContext)
|
||||
|
||||
function handleSubmit() {
|
||||
post('/api/register')
|
||||
showToast('Registered', theme)
|
||||
}
|
||||
|
||||
return <button onClick={handleSubmit}>Submit</button>
|
||||
}
|
||||
```
|
||||
|
||||
Reference: [https://react.dev/learn/removing-effect-dependencies#should-this-code-move-to-an-event-handler](https://react.dev/learn/removing-effect-dependencies#should-this-code-move-to-an-event-handler)
|
||||
|
||||
### 5.8 Subscribe to Derived State
|
||||
### 5.4 Subscribe to Derived State
|
||||
|
||||
**Impact: MEDIUM (reduces re-render frequency)**
|
||||
|
||||
@ -1565,7 +1401,7 @@ function Sidebar() {
|
||||
}
|
||||
```
|
||||
|
||||
### 5.9 Use Functional setState Updates
|
||||
### 5.5 Use Functional setState Updates
|
||||
|
||||
**Impact: MEDIUM (prevents stale closures and unnecessary callback recreations)**
|
||||
|
||||
@ -1643,7 +1479,7 @@ function TodoList() {
|
||||
|
||||
**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler can automatically optimize some cases, but functional updates are still recommended for correctness and to prevent stale closure bugs.
|
||||
|
||||
### 5.10 Use Lazy State Initialization
|
||||
### 5.6 Use Lazy State Initialization
|
||||
|
||||
**Impact: MEDIUM (wasted computation on every render)**
|
||||
|
||||
@ -1697,7 +1533,7 @@ Use lazy initialization when computing initial values from localStorage/sessionS
|
||||
|
||||
For simple primitives (`useState(0)`), direct references (`useState(props.value)`), or cheap literals (`useState({})`), the function form is unnecessary.
|
||||
|
||||
### 5.11 Use Transitions for Non-Urgent Updates
|
||||
### 5.7 Use Transitions for Non-Urgent Updates
|
||||
|
||||
**Impact: MEDIUM (maintains UI responsiveness)**
|
||||
|
||||
@ -1733,75 +1569,6 @@ function ScrollTracker() {
|
||||
}
|
||||
```
|
||||
|
||||
### 5.12 Use useRef for Transient Values
|
||||
|
||||
**Impact: MEDIUM (avoids unnecessary re-renders on frequent updates)**
|
||||
|
||||
When a value changes frequently and you don't want a re-render on every update (e.g., mouse trackers, intervals, transient flags), store it in `useRef` instead of `useState`. Keep component state for UI; use refs for temporary DOM-adjacent values. Updating a ref does not trigger a re-render.
|
||||
|
||||
**Incorrect: renders every update**
|
||||
|
||||
```tsx
|
||||
function Tracker() {
|
||||
const [lastX, setLastX] = useState(0)
|
||||
|
||||
useEffect(() => {
|
||||
const onMove = (e: MouseEvent) => setLastX(e.clientX)
|
||||
window.addEventListener('mousemove', onMove)
|
||||
return () => window.removeEventListener('mousemove', onMove)
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
position: 'fixed',
|
||||
top: 0,
|
||||
left: lastX,
|
||||
width: 8,
|
||||
height: 8,
|
||||
background: 'black',
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct: no re-render for tracking**
|
||||
|
||||
```tsx
|
||||
function Tracker() {
|
||||
const lastXRef = useRef(0)
|
||||
const dotRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
useEffect(() => {
|
||||
const onMove = (e: MouseEvent) => {
|
||||
lastXRef.current = e.clientX
|
||||
const node = dotRef.current
|
||||
if (node) {
|
||||
node.style.transform = `translateX(${e.clientX}px)`
|
||||
}
|
||||
}
|
||||
window.addEventListener('mousemove', onMove)
|
||||
return () => window.removeEventListener('mousemove', onMove)
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={dotRef}
|
||||
style={{
|
||||
position: 'fixed',
|
||||
top: 0,
|
||||
left: 0,
|
||||
width: 8,
|
||||
height: 8,
|
||||
background: 'black',
|
||||
transform: 'translateX(0px)',
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. Rendering Performance
|
||||
@ -2031,33 +1798,7 @@ The inline script executes synchronously before showing the element, ensuring th
|
||||
|
||||
This pattern is especially useful for theme toggles, user preferences, authentication states, and any client-only data that should render immediately without flashing default values.
|
||||
|
||||
### 6.6 Suppress Expected Hydration Mismatches
|
||||
|
||||
**Impact: LOW-MEDIUM (avoids noisy hydration warnings for known differences)**
|
||||
|
||||
In SSR frameworks (e.g., Next.js), some values are intentionally different on server vs client (random IDs, dates, locale/timezone formatting). For these *expected* mismatches, wrap the dynamic text in an element with `suppressHydrationWarning` to prevent noisy warnings. Do not use this to hide real bugs. Don’t overuse it.
|
||||
|
||||
**Incorrect: known mismatch warnings**
|
||||
|
||||
```tsx
|
||||
function Timestamp() {
|
||||
return <span>{new Date().toLocaleString()}</span>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct: suppress expected mismatch only**
|
||||
|
||||
```tsx
|
||||
function Timestamp() {
|
||||
return (
|
||||
<span suppressHydrationWarning>
|
||||
{new Date().toLocaleString()}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 6.7 Use Activity Component for Show/Hide
|
||||
### 6.6 Use Activity Component for Show/Hide
|
||||
|
||||
**Impact: MEDIUM (preserves state/DOM)**
|
||||
|
||||
@ -2079,7 +1820,7 @@ function Dropdown({ isOpen }: Props) {
|
||||
|
||||
Avoids expensive re-renders and state loss.
|
||||
|
||||
### 6.8 Use Explicit Conditional Rendering
|
||||
### 6.7 Use Explicit Conditional Rendering
|
||||
|
||||
**Impact: LOW (prevents rendering 0 or NaN)**
|
||||
|
||||
@ -2115,80 +1856,6 @@ function Badge({ count }: { count: number }) {
|
||||
// When count = 5, renders: <div><span class="badge">5</span></div>
|
||||
```
|
||||
|
||||
### 6.9 Use useTransition Over Manual Loading States
|
||||
|
||||
**Impact: LOW (reduces re-renders and improves code clarity)**
|
||||
|
||||
Use `useTransition` instead of manual `useState` for loading states. This provides built-in `isPending` state and automatically manages transitions.
|
||||
|
||||
**Incorrect: manual loading state**
|
||||
|
||||
```tsx
|
||||
function SearchResults() {
|
||||
const [query, setQuery] = useState('')
|
||||
const [results, setResults] = useState([])
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
|
||||
const handleSearch = async (value: string) => {
|
||||
setIsLoading(true)
|
||||
setQuery(value)
|
||||
const data = await fetchResults(value)
|
||||
setResults(data)
|
||||
setIsLoading(false)
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<input onChange={(e) => handleSearch(e.target.value)} />
|
||||
{isLoading && <Spinner />}
|
||||
<ResultsList results={results} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct: useTransition with built-in pending state**
|
||||
|
||||
```tsx
|
||||
import { useTransition, useState } from 'react'
|
||||
|
||||
function SearchResults() {
|
||||
const [query, setQuery] = useState('')
|
||||
const [results, setResults] = useState([])
|
||||
const [isPending, startTransition] = useTransition()
|
||||
|
||||
const handleSearch = (value: string) => {
|
||||
setQuery(value) // Update input immediately
|
||||
|
||||
startTransition(async () => {
|
||||
// Fetch and update results
|
||||
const data = await fetchResults(value)
|
||||
setResults(data)
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<input onChange={(e) => handleSearch(e.target.value)} />
|
||||
{isPending && <Spinner />}
|
||||
<ResultsList results={results} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
|
||||
- **Automatic pending state**: No need to manually manage `setIsLoading(true/false)`
|
||||
|
||||
- **Error resilience**: Pending state correctly resets even if the transition throws
|
||||
|
||||
- **Better responsiveness**: Keeps the UI responsive during updates
|
||||
|
||||
- **Interrupt handling**: New transitions automatically cancel pending ones
|
||||
|
||||
Reference: [https://react.dev/reference/react/useTransition](https://react.dev/reference/react/useTransition)
|
||||
|
||||
---
|
||||
|
||||
## 7. JavaScript Performance
|
||||
@ -2197,28 +1864,16 @@ Reference: [https://react.dev/reference/react/useTransition](https://react.dev/r
|
||||
|
||||
Micro-optimizations for hot paths can add up to meaningful improvements.
|
||||
|
||||
### 7.1 Avoid Layout Thrashing
|
||||
### 7.1 Batch DOM CSS Changes
|
||||
|
||||
**Impact: MEDIUM (prevents forced synchronous layouts and reduces performance bottlenecks)**
|
||||
**Impact: MEDIUM (reduces reflows/repaints)**
|
||||
|
||||
Avoid interleaving style writes with layout reads. When you read a layout property (like `offsetWidth`, `getBoundingClientRect()`, or `getComputedStyle()`) between style changes, the browser is forced to trigger a synchronous reflow.
|
||||
|
||||
**This is OK: browser batches style changes**
|
||||
|
||||
```typescript
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
// Each line invalidates style, but browser batches the recalculation
|
||||
element.style.width = '100px'
|
||||
element.style.height = '200px'
|
||||
element.style.backgroundColor = 'blue'
|
||||
element.style.border = '1px solid black'
|
||||
}
|
||||
```
|
||||
|
||||
**Incorrect: interleaved reads and writes force reflows**
|
||||
|
||||
```typescript
|
||||
function layoutThrashing(element: HTMLElement) {
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
element.style.width = '100px'
|
||||
const width = element.offsetWidth // Forces reflow
|
||||
element.style.height = '200px'
|
||||
@ -2228,63 +1883,18 @@ function layoutThrashing(element: HTMLElement) {
|
||||
|
||||
**Correct: batch writes, then read once**
|
||||
|
||||
```typescript
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
// Batch all writes together
|
||||
element.style.width = '100px'
|
||||
element.style.height = '200px'
|
||||
element.style.backgroundColor = 'blue'
|
||||
element.style.border = '1px solid black'
|
||||
|
||||
// Read after all writes are done (single reflow)
|
||||
const { width, height } = element.getBoundingClientRect()
|
||||
}
|
||||
```
|
||||
|
||||
**Correct: batch reads, then writes**
|
||||
|
||||
```typescript
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
element.classList.add('highlighted-box')
|
||||
|
||||
|
||||
const { width, height } = element.getBoundingClientRect()
|
||||
}
|
||||
```
|
||||
|
||||
**Better: use CSS classes**
|
||||
|
||||
**React example:**
|
||||
|
||||
```tsx
|
||||
// Incorrect: interleaving style changes with layout queries
|
||||
function Box({ isHighlighted }: { isHighlighted: boolean }) {
|
||||
const ref = useRef<HTMLDivElement>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (ref.current && isHighlighted) {
|
||||
ref.current.style.width = '100px'
|
||||
const width = ref.current.offsetWidth // Forces layout
|
||||
ref.current.style.height = '200px'
|
||||
}
|
||||
}, [isHighlighted])
|
||||
|
||||
return <div ref={ref}>Content</div>
|
||||
}
|
||||
|
||||
// Correct: toggle class
|
||||
function Box({ isHighlighted }: { isHighlighted: boolean }) {
|
||||
return (
|
||||
<div className={isHighlighted ? 'highlighted-box' : ''}>
|
||||
Content
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Prefer CSS classes over inline styles when possible. CSS files are cached by the browser, and classes provide better separation of concerns and are easier to maintain.
|
||||
|
||||
See [this gist](https://gist.github.com/paulirish/5d52fb081b3570c81e3a) and [CSS Triggers](https://csstriggers.com/) for more information on layout-forcing operations.
|
||||
|
||||
### 7.2 Build Index Maps for Repeated Lookups
|
||||
|
||||
**Impact: LOW-MEDIUM (1M ops to 2K ops)**
|
||||
@ -2812,45 +2422,7 @@ const sorted = [...items].sort((a, b) => a.value - b.value)
|
||||
|
||||
Advanced patterns for specific cases that require careful implementation.
|
||||
|
||||
### 8.1 Initialize App Once, Not Per Mount
|
||||
|
||||
**Impact: LOW-MEDIUM (avoids duplicate init in development)**
|
||||
|
||||
Do not put app-wide initialization that must run once per app load inside `useEffect([])` of a component. Components can remount and effects will re-run. Use a module-level guard or top-level init in the entry module instead.
|
||||
|
||||
**Incorrect: runs twice in dev, re-runs on remount**
|
||||
|
||||
```tsx
|
||||
function Comp() {
|
||||
useEffect(() => {
|
||||
loadFromStorage()
|
||||
checkAuthToken()
|
||||
}, [])
|
||||
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
**Correct: once per app load**
|
||||
|
||||
```tsx
|
||||
let didInit = false
|
||||
|
||||
function Comp() {
|
||||
useEffect(() => {
|
||||
if (didInit) return
|
||||
didInit = true
|
||||
loadFromStorage()
|
||||
checkAuthToken()
|
||||
}, [])
|
||||
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
Reference: [https://react.dev/learn/you-might-not-need-an-effect#initializing-the-application](https://react.dev/learn/you-might-not-need-an-effect#initializing-the-application)
|
||||
|
||||
### 8.2 Store Event Handlers in Refs
|
||||
### 8.1 Store Event Handlers in Refs
|
||||
|
||||
**Impact: LOW (stable subscriptions)**
|
||||
|
||||
@ -2886,12 +2458,24 @@ function useWindowEvent(event: string, handler: (e) => void) {
|
||||
|
||||
`useEffectEvent` provides a cleaner API for the same pattern: it creates a stable function reference that always calls the latest version of the handler.
|
||||
|
||||
### 8.3 useEffectEvent for Stable Callback Refs
|
||||
### 8.2 useLatest for Stable Callback Refs
|
||||
|
||||
**Impact: LOW (prevents effect re-runs)**
|
||||
|
||||
Access latest values in callbacks without adding them to dependency arrays. Prevents effect re-runs while avoiding stale closures.
|
||||
|
||||
**Implementation:**
|
||||
|
||||
```typescript
|
||||
function useLatest<T>(value: T) {
|
||||
const ref = useRef(value)
|
||||
useLayoutEffect(() => {
|
||||
ref.current = value
|
||||
}, [value])
|
||||
return ref
|
||||
}
|
||||
```
|
||||
|
||||
**Incorrect: effect re-runs on every callback change**
|
||||
|
||||
```tsx
|
||||
@ -2905,17 +2489,15 @@ function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
|
||||
}
|
||||
```
|
||||
|
||||
**Correct: using React's useEffectEvent**
|
||||
**Correct: stable effect, fresh callback**
|
||||
|
||||
```tsx
|
||||
import { useEffectEvent } from 'react';
|
||||
|
||||
function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
|
||||
const [query, setQuery] = useState('')
|
||||
const onSearchEvent = useEffectEvent(onSearch)
|
||||
const onSearchRef = useLatest(onSearch)
|
||||
|
||||
useEffect(() => {
|
||||
const timeout = setTimeout(() => onSearchEvent(query), 300)
|
||||
const timeout = setTimeout(() => onSearchRef.current(query), 300)
|
||||
return () => clearTimeout(timeout)
|
||||
}, [query])
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@ metadata:
|
||||
|
||||
# Vercel React Best Practices
|
||||
|
||||
Comprehensive performance optimization guide for React and Next.js applications, maintained by Vercel. Contains 57 rules across 8 categories, prioritized by impact to guide automated refactoring and code generation.
|
||||
Comprehensive performance optimization guide for React and Next.js applications, maintained by Vercel. Contains 45 rules across 8 categories, prioritized by impact to guide automated refactoring and code generation.
|
||||
|
||||
## When to Apply
|
||||
|
||||
@ -53,10 +53,8 @@ Reference these guidelines when:
|
||||
|
||||
### 3. Server-Side Performance (HIGH)
|
||||
|
||||
- `server-auth-actions` - Authenticate server actions like API routes
|
||||
- `server-cache-react` - Use React.cache() for per-request deduplication
|
||||
- `server-cache-lru` - Use LRU cache for cross-request caching
|
||||
- `server-dedup-props` - Avoid duplicate serialization in RSC props
|
||||
- `server-serialization` - Minimize data passed to client components
|
||||
- `server-parallel-fetching` - Restructure components to parallelize fetches
|
||||
- `server-after-nonblocking` - Use after() for non-blocking operations
|
||||
@ -65,23 +63,16 @@ Reference these guidelines when:
|
||||
|
||||
- `client-swr-dedup` - Use SWR for automatic request deduplication
|
||||
- `client-event-listeners` - Deduplicate global event listeners
|
||||
- `client-passive-event-listeners` - Use passive listeners for scroll
|
||||
- `client-localstorage-schema` - Version and minimize localStorage data
|
||||
|
||||
### 5. Re-render Optimization (MEDIUM)
|
||||
|
||||
- `rerender-defer-reads` - Don't subscribe to state only used in callbacks
|
||||
- `rerender-memo` - Extract expensive work into memoized components
|
||||
- `rerender-memo-with-default-value` - Hoist default non-primitive props
|
||||
- `rerender-dependencies` - Use primitive dependencies in effects
|
||||
- `rerender-derived-state` - Subscribe to derived booleans, not raw values
|
||||
- `rerender-derived-state-no-effect` - Derive state during render, not effects
|
||||
- `rerender-functional-setstate` - Use functional setState for stable callbacks
|
||||
- `rerender-lazy-state-init` - Pass function to useState for expensive values
|
||||
- `rerender-simple-expression-in-memo` - Avoid memo for simple primitives
|
||||
- `rerender-move-effect-to-event` - Put interaction logic in event handlers
|
||||
- `rerender-transitions` - Use startTransition for non-urgent updates
|
||||
- `rerender-use-ref-transient-values` - Use refs for transient frequent values
|
||||
|
||||
### 6. Rendering Performance (MEDIUM)
|
||||
|
||||
@ -90,10 +81,8 @@ Reference these guidelines when:
|
||||
- `rendering-hoist-jsx` - Extract static JSX outside components
|
||||
- `rendering-svg-precision` - Reduce SVG coordinate precision
|
||||
- `rendering-hydration-no-flicker` - Use inline script for client-only data
|
||||
- `rendering-hydration-suppress-warning` - Suppress expected mismatches
|
||||
- `rendering-activity` - Use Activity component for show/hide
|
||||
- `rendering-conditional-render` - Use ternary, not && for conditionals
|
||||
- `rendering-usetransition-loading` - Prefer useTransition for loading state
|
||||
|
||||
### 7. JavaScript Performance (LOW-MEDIUM)
|
||||
|
||||
@ -113,7 +102,6 @@ Reference these guidelines when:
|
||||
### 8. Advanced Patterns (LOW)
|
||||
|
||||
- `advanced-event-handler-refs` - Store event handlers in refs
|
||||
- `advanced-init-once` - Initialize app once per app load
|
||||
- `advanced-use-latest` - useLatest for stable callback refs
|
||||
|
||||
## How to Use
|
||||
@ -123,6 +111,7 @@ Read individual rule files for detailed explanations and code examples:
|
||||
```
|
||||
rules/async-parallel.md
|
||||
rules/bundle-barrel-imports.md
|
||||
rules/_sections.md
|
||||
```
|
||||
|
||||
Each rule file contains:
|
||||
|
||||
@ -1,42 +0,0 @@
|
||||
---
|
||||
title: Initialize App Once, Not Per Mount
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: avoids duplicate init in development
|
||||
tags: initialization, useEffect, app-startup, side-effects
|
||||
---
|
||||
|
||||
## Initialize App Once, Not Per Mount
|
||||
|
||||
Do not put app-wide initialization that must run once per app load inside `useEffect([])` of a component. Components can remount and effects will re-run. Use a module-level guard or top-level init in the entry module instead.
|
||||
|
||||
**Incorrect (runs twice in dev, re-runs on remount):**
|
||||
|
||||
```tsx
|
||||
function Comp() {
|
||||
useEffect(() => {
|
||||
loadFromStorage()
|
||||
checkAuthToken()
|
||||
}, [])
|
||||
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (once per app load):**
|
||||
|
||||
```tsx
|
||||
let didInit = false
|
||||
|
||||
function Comp() {
|
||||
useEffect(() => {
|
||||
if (didInit) return
|
||||
didInit = true
|
||||
loadFromStorage()
|
||||
checkAuthToken()
|
||||
}, [])
|
||||
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
Reference: [Initializing the application](https://react.dev/learn/you-might-not-need-an-effect#initializing-the-application)
|
||||
@ -1,30 +0,0 @@
|
||||
---
|
||||
title: Suppress Expected Hydration Mismatches
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: avoids noisy hydration warnings for known differences
|
||||
tags: rendering, hydration, ssr, nextjs
|
||||
---
|
||||
|
||||
## Suppress Expected Hydration Mismatches
|
||||
|
||||
In SSR frameworks (e.g., Next.js), some values are intentionally different on server vs client (random IDs, dates, locale/timezone formatting). For these *expected* mismatches, wrap the dynamic text in an element with `suppressHydrationWarning` to prevent noisy warnings. Do not use this to hide real bugs. Don’t overuse it.
|
||||
|
||||
**Incorrect (known mismatch warnings):**
|
||||
|
||||
```tsx
|
||||
function Timestamp() {
|
||||
return <span>{new Date().toLocaleString()}</span>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (suppress expected mismatch only):**
|
||||
|
||||
```tsx
|
||||
function Timestamp() {
|
||||
return (
|
||||
<span suppressHydrationWarning>
|
||||
{new Date().toLocaleString()}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
```
|
||||
@ -1,75 +0,0 @@
|
||||
---
|
||||
title: Use useTransition Over Manual Loading States
|
||||
impact: LOW
|
||||
impactDescription: reduces re-renders and improves code clarity
|
||||
tags: rendering, transitions, useTransition, loading, state
|
||||
---
|
||||
|
||||
## Use useTransition Over Manual Loading States
|
||||
|
||||
Use `useTransition` instead of manual `useState` for loading states. This provides built-in `isPending` state and automatically manages transitions.
|
||||
|
||||
**Incorrect (manual loading state):**
|
||||
|
||||
```tsx
|
||||
function SearchResults() {
|
||||
const [query, setQuery] = useState('')
|
||||
const [results, setResults] = useState([])
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
|
||||
const handleSearch = async (value: string) => {
|
||||
setIsLoading(true)
|
||||
setQuery(value)
|
||||
const data = await fetchResults(value)
|
||||
setResults(data)
|
||||
setIsLoading(false)
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<input onChange={(e) => handleSearch(e.target.value)} />
|
||||
{isLoading && <Spinner />}
|
||||
<ResultsList results={results} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (useTransition with built-in pending state):**
|
||||
|
||||
```tsx
|
||||
import { useTransition, useState } from 'react'
|
||||
|
||||
function SearchResults() {
|
||||
const [query, setQuery] = useState('')
|
||||
const [results, setResults] = useState([])
|
||||
const [isPending, startTransition] = useTransition()
|
||||
|
||||
const handleSearch = (value: string) => {
|
||||
setQuery(value) // Update input immediately
|
||||
|
||||
startTransition(async () => {
|
||||
// Fetch and update results
|
||||
const data = await fetchResults(value)
|
||||
setResults(data)
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<input onChange={(e) => handleSearch(e.target.value)} />
|
||||
{isPending && <Spinner />}
|
||||
<ResultsList results={results} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
|
||||
- **Automatic pending state**: No need to manually manage `setIsLoading(true/false)`
|
||||
- **Error resilience**: Pending state correctly resets even if the transition throws
|
||||
- **Better responsiveness**: Keeps the UI responsive during updates
|
||||
- **Interrupt handling**: New transitions automatically cancel pending ones
|
||||
|
||||
Reference: [useTransition](https://react.dev/reference/react/useTransition)
|
||||
@ -1,40 +0,0 @@
|
||||
---
|
||||
title: Calculate Derived State During Rendering
|
||||
impact: MEDIUM
|
||||
impactDescription: avoids redundant renders and state drift
|
||||
tags: rerender, derived-state, useEffect, state
|
||||
---
|
||||
|
||||
## Calculate Derived State During Rendering
|
||||
|
||||
If a value can be computed from current props/state, do not store it in state or update it in an effect. Derive it during render to avoid extra renders and state drift. Do not set state in effects solely in response to prop changes; prefer derived values or keyed resets instead.
|
||||
|
||||
**Incorrect (redundant state and effect):**
|
||||
|
||||
```tsx
|
||||
function Form() {
|
||||
const [firstName, setFirstName] = useState('First')
|
||||
const [lastName, setLastName] = useState('Last')
|
||||
const [fullName, setFullName] = useState('')
|
||||
|
||||
useEffect(() => {
|
||||
setFullName(firstName + ' ' + lastName)
|
||||
}, [firstName, lastName])
|
||||
|
||||
return <p>{fullName}</p>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (derive during render):**
|
||||
|
||||
```tsx
|
||||
function Form() {
|
||||
const [firstName, setFirstName] = useState('First')
|
||||
const [lastName, setLastName] = useState('Last')
|
||||
const fullName = firstName + ' ' + lastName
|
||||
|
||||
return <p>{fullName}</p>
|
||||
}
|
||||
```
|
||||
|
||||
References: [You Might Not Need an Effect](https://react.dev/learn/you-might-not-need-an-effect)
|
||||
@ -1,45 +0,0 @@
|
||||
---
|
||||
title: Put Interaction Logic in Event Handlers
|
||||
impact: MEDIUM
|
||||
impactDescription: avoids effect re-runs and duplicate side effects
|
||||
tags: rerender, useEffect, events, side-effects, dependencies
|
||||
---
|
||||
|
||||
## Put Interaction Logic in Event Handlers
|
||||
|
||||
If a side effect is triggered by a specific user action (submit, click, drag), run it in that event handler. Do not model the action as state + effect; it makes effects re-run on unrelated changes and can duplicate the action.
|
||||
|
||||
**Incorrect (event modeled as state + effect):**
|
||||
|
||||
```tsx
|
||||
function Form() {
|
||||
const [submitted, setSubmitted] = useState(false)
|
||||
const theme = useContext(ThemeContext)
|
||||
|
||||
useEffect(() => {
|
||||
if (submitted) {
|
||||
post('/api/register')
|
||||
showToast('Registered', theme)
|
||||
}
|
||||
}, [submitted, theme])
|
||||
|
||||
return <button onClick={() => setSubmitted(true)}>Submit</button>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (do it in the handler):**
|
||||
|
||||
```tsx
|
||||
function Form() {
|
||||
const theme = useContext(ThemeContext)
|
||||
|
||||
function handleSubmit() {
|
||||
post('/api/register')
|
||||
showToast('Registered', theme)
|
||||
}
|
||||
|
||||
return <button onClick={handleSubmit}>Submit</button>
|
||||
}
|
||||
```
|
||||
|
||||
Reference: [Should this code move to an event handler?](https://react.dev/learn/removing-effect-dependencies#should-this-code-move-to-an-event-handler)
|
||||
@ -1,73 +0,0 @@
|
||||
---
|
||||
title: Use useRef for Transient Values
|
||||
impact: MEDIUM
|
||||
impactDescription: avoids unnecessary re-renders on frequent updates
|
||||
tags: rerender, useref, state, performance
|
||||
---
|
||||
|
||||
## Use useRef for Transient Values
|
||||
|
||||
When a value changes frequently and you don't want a re-render on every update (e.g., mouse trackers, intervals, transient flags), store it in `useRef` instead of `useState`. Keep component state for UI; use refs for temporary DOM-adjacent values. Updating a ref does not trigger a re-render.
|
||||
|
||||
**Incorrect (renders every update):**
|
||||
|
||||
```tsx
|
||||
function Tracker() {
|
||||
const [lastX, setLastX] = useState(0)
|
||||
|
||||
useEffect(() => {
|
||||
const onMove = (e: MouseEvent) => setLastX(e.clientX)
|
||||
window.addEventListener('mousemove', onMove)
|
||||
return () => window.removeEventListener('mousemove', onMove)
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
position: 'fixed',
|
||||
top: 0,
|
||||
left: lastX,
|
||||
width: 8,
|
||||
height: 8,
|
||||
background: 'black',
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (no re-render for tracking):**
|
||||
|
||||
```tsx
|
||||
function Tracker() {
|
||||
const lastXRef = useRef(0)
|
||||
const dotRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
useEffect(() => {
|
||||
const onMove = (e: MouseEvent) => {
|
||||
lastXRef.current = e.clientX
|
||||
const node = dotRef.current
|
||||
if (node) {
|
||||
node.style.transform = `translateX(${e.clientX}px)`
|
||||
}
|
||||
}
|
||||
window.addEventListener('mousemove', onMove)
|
||||
return () => window.removeEventListener('mousemove', onMove)
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={dotRef}
|
||||
style={{
|
||||
position: 'fixed',
|
||||
top: 0,
|
||||
left: 0,
|
||||
width: 8,
|
||||
height: 8,
|
||||
background: 'black',
|
||||
transform: 'translateX(0px)',
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
```
|
||||
23
.github/workflows/autofix.yml
vendored
23
.github/workflows/autofix.yml
vendored
@ -79,29 +79,6 @@ jobs:
|
||||
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install web dependencies
|
||||
run: |
|
||||
cd web
|
||||
pnpm install --frozen-lockfile
|
||||
|
||||
- name: ESLint autofix
|
||||
run: |
|
||||
cd web
|
||||
pnpm lint:fix || true
|
||||
|
||||
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
||||
- name: mdformat
|
||||
run: |
|
||||
|
||||
2
.github/workflows/style.yml
vendored
2
.github/workflows/style.yml
vendored
@ -125,7 +125,7 @@ jobs:
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run type-check
|
||||
run: pnpm run type-check:tsgo
|
||||
|
||||
- name: Web dead code check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
||||
@ -1,27 +0,0 @@
|
||||
# Notes: `large_language_model.py`
|
||||
|
||||
## Purpose
|
||||
|
||||
Provides the base `LargeLanguageModel` implementation used by the model runtime to invoke plugin-backed LLMs and to
|
||||
bridge plugin daemon streaming semantics back into API-layer entities (`LLMResult`, `LLMResultChunk`).
|
||||
|
||||
## Key behaviors / invariants
|
||||
|
||||
- `invoke(..., stream=False)` still calls the plugin in streaming mode and then synthesizes a single `LLMResult` from
|
||||
the first yielded `LLMResultChunk`.
|
||||
- Plugin invocation is wrapped by `_invoke_llm_via_plugin(...)`, and `stream=False` normalization is handled by
|
||||
`_normalize_non_stream_plugin_result(...)` / `_build_llm_result_from_first_chunk(...)`.
|
||||
- Tool call deltas are merged incrementally via `_increase_tool_call(...)` to support multiple provider chunking
|
||||
patterns (IDs anchored to first chunk, every chunk, or missing entirely).
|
||||
- A tool-call delta with an empty `id` requires at least one existing tool call; otherwise we raise `ValueError` to
|
||||
surface invalid delta sequences explicitly.
|
||||
- Callback invocation is centralized in `_run_callbacks(...)` to ensure consistent error handling/logging.
|
||||
- For compatibility with dify issue `#17799`, `prompt_messages` may be removed by the plugin daemon in chunks and must
|
||||
be re-attached in this layer before callbacks/consumers use them.
|
||||
- Callback hooks (`on_before_invoke`, `on_new_chunk`, `on_after_invoke`, `on_invoke_error`) must not break invocation
|
||||
unless `callback.raise_error` is true.
|
||||
|
||||
## Test focus
|
||||
|
||||
- `api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py` validates tool-call delta merging and
|
||||
patches `_gen_tool_call_id` for deterministic IDs.
|
||||
@ -33,9 +33,6 @@ TRIGGER_URL=http://localhost:5001
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
# Collaboration mode toggle
|
||||
ENABLE_COLLABORATION_MODE=false
|
||||
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
|
||||
@ -27,9 +27,7 @@ ignore_imports =
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
|
||||
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
|
||||
core.workflow.nodes.node_factory -> core.workflow.graph
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
|
||||
@ -59,252 +57,6 @@ ignore_imports =
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
||||
|
||||
[importlinter:contract:workflow-external-imports]
|
||||
name = Workflow External Imports
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow
|
||||
forbidden_modules =
|
||||
configs
|
||||
controllers
|
||||
extensions
|
||||
models
|
||||
services
|
||||
tasks
|
||||
core.agent
|
||||
core.app
|
||||
core.base
|
||||
core.callback_handler
|
||||
core.datasource
|
||||
core.db
|
||||
core.entities
|
||||
core.errors
|
||||
core.extension
|
||||
core.external_data_tool
|
||||
core.file
|
||||
core.helper
|
||||
core.hosting_configuration
|
||||
core.indexing_runner
|
||||
core.llm_generator
|
||||
core.logging
|
||||
core.mcp
|
||||
core.memory
|
||||
core.model_manager
|
||||
core.moderation
|
||||
core.ops
|
||||
core.plugin
|
||||
core.prompt
|
||||
core.provider_manager
|
||||
core.rag
|
||||
core.repositories
|
||||
core.schemas
|
||||
core.tools
|
||||
core.trigger
|
||||
core.variables
|
||||
ignore_imports =
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.graph_engine.layers.observability -> configs
|
||||
core.workflow.graph_engine.layers.observability -> extensions.otel.runtime
|
||||
core.workflow.graph_engine.layers.persistence -> core.ops.ops_trace_manager
|
||||
core.workflow.graph_engine.worker_management.worker_pool -> configs
|
||||
core.workflow.nodes.agent.agent_node -> core.model_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.provider_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor
|
||||
core.workflow.nodes.datasource.datasource_node -> models.model
|
||||
core.workflow.nodes.datasource.datasource_node -> models.tools
|
||||
core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
|
||||
core.workflow.nodes.document_extractor.node -> configs
|
||||
core.workflow.nodes.document_extractor.node -> core.file.file_manager
|
||||
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.http_request.entities -> configs
|
||||
core.workflow.nodes.http_request.executor -> configs
|
||||
core.workflow.nodes.http_request.executor -> core.file.file_manager
|
||||
core.workflow.nodes.http_request.node -> configs
|
||||
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.llm.llm_utils -> configs
|
||||
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.file.models
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.llm.llm_utils -> models.model
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
|
||||
core.workflow.nodes.llm.node -> core.tools.signature
|
||||
core.workflow.nodes.template_transform.template_transform_node -> configs
|
||||
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_manager
|
||||
core.workflow.workflow_entry -> configs
|
||||
core.workflow.workflow_entry -> models.workflow
|
||||
core.workflow.nodes.agent.agent_node -> core.agent.entities
|
||||
core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
|
||||
core.workflow.graph_engine.layers.persistence -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.start.entities -> core.app.app_config.entities
|
||||
core.workflow.nodes.start.start_node -> core.app.app_config.entities
|
||||
core.workflow.workflow_entry -> core.app.apps.exc
|
||||
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
|
||||
core.workflow.workflow_entry -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
core.workflow.node_events.node -> core.file
|
||||
core.workflow.nodes.agent.agent_node -> core.file
|
||||
core.workflow.nodes.datasource.datasource_node -> core.file
|
||||
core.workflow.nodes.datasource.datasource_node -> core.file.enums
|
||||
core.workflow.nodes.document_extractor.node -> core.file
|
||||
core.workflow.nodes.http_request.executor -> core.file.enums
|
||||
core.workflow.nodes.http_request.node -> core.file
|
||||
core.workflow.nodes.http_request.node -> core.file.file_manager
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.file.models
|
||||
core.workflow.nodes.list_operator.node -> core.file
|
||||
core.workflow.nodes.llm.file_saver -> core.file
|
||||
core.workflow.nodes.llm.llm_utils -> core.variables.segments
|
||||
core.workflow.nodes.llm.node -> core.file
|
||||
core.workflow.nodes.llm.node -> core.file.file_manager
|
||||
core.workflow.nodes.llm.node -> core.file.models
|
||||
core.workflow.nodes.loop.entities -> core.variables.types
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.file
|
||||
core.workflow.nodes.protocols -> core.file
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.file.models
|
||||
core.workflow.nodes.tool.tool_node -> core.file
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
core.workflow.nodes.tool.tool_node -> models
|
||||
core.workflow.nodes.trigger_webhook.node -> core.file
|
||||
core.workflow.runtime.variable_pool -> core.file
|
||||
core.workflow.runtime.variable_pool -> core.file.file_manager
|
||||
core.workflow.system_variable -> core.file.models
|
||||
core.workflow.utils.condition.processor -> core.file
|
||||
core.workflow.utils.condition.processor -> core.file.file_manager
|
||||
core.workflow.workflow_entry -> core.file.models
|
||||
core.workflow.workflow_type_encoder -> core.file.models
|
||||
core.workflow.nodes.agent.agent_node -> models.model
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
|
||||
core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
|
||||
core.workflow.nodes.datasource.datasource_node -> core.variables.variables
|
||||
core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.llm.node -> core.helper.code_executor
|
||||
core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor
|
||||
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
core.workflow.nodes.llm.node -> core.model_manager
|
||||
core.workflow.graph_engine.layers.persistence -> core.ops.entities.trace_entity
|
||||
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.llm.node -> models.dataset
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
|
||||
core.workflow.nodes.llm.file_saver -> core.tools.signature
|
||||
core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.errors
|
||||
core.workflow.conversation_variable_updater -> core.variables
|
||||
core.workflow.graph_engine.entities.commands -> core.variables.variables
|
||||
core.workflow.nodes.agent.agent_node -> core.variables.segments
|
||||
core.workflow.nodes.answer.answer_node -> core.variables
|
||||
core.workflow.nodes.code.code_node -> core.variables.segments
|
||||
core.workflow.nodes.code.code_node -> core.variables.types
|
||||
core.workflow.nodes.code.entities -> core.variables.types
|
||||
core.workflow.nodes.datasource.datasource_node -> core.variables.segments
|
||||
core.workflow.nodes.document_extractor.node -> core.variables
|
||||
core.workflow.nodes.document_extractor.node -> core.variables.segments
|
||||
core.workflow.nodes.http_request.executor -> core.variables.segments
|
||||
core.workflow.nodes.http_request.node -> core.variables.segments
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments
|
||||
core.workflow.nodes.list_operator.node -> core.variables
|
||||
core.workflow.nodes.list_operator.node -> core.variables.segments
|
||||
core.workflow.nodes.llm.node -> core.variables
|
||||
core.workflow.nodes.loop.loop_node -> core.variables
|
||||
core.workflow.nodes.parameter_extractor.entities -> core.variables.types
|
||||
core.workflow.nodes.parameter_extractor.exc -> core.variables.types
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types
|
||||
core.workflow.nodes.tool.tool_node -> core.variables.segments
|
||||
core.workflow.nodes.tool.tool_node -> core.variables.variables
|
||||
core.workflow.nodes.trigger_webhook.node -> core.variables.types
|
||||
core.workflow.nodes.trigger_webhook.node -> core.variables.variables
|
||||
core.workflow.nodes.variable_aggregator.entities -> core.variables.types
|
||||
core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments
|
||||
core.workflow.nodes.variable_assigner.common.helpers -> core.variables
|
||||
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts
|
||||
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types
|
||||
core.workflow.nodes.variable_assigner.v1.node -> core.variables
|
||||
core.workflow.nodes.variable_assigner.v2.helpers -> core.variables
|
||||
core.workflow.nodes.variable_assigner.v2.node -> core.variables
|
||||
core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts
|
||||
core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments
|
||||
core.workflow.runtime.read_only_wrappers -> core.variables.segments
|
||||
core.workflow.runtime.variable_pool -> core.variables
|
||||
core.workflow.runtime.variable_pool -> core.variables.consts
|
||||
core.workflow.runtime.variable_pool -> core.variables.segments
|
||||
core.workflow.runtime.variable_pool -> core.variables.variables
|
||||
core.workflow.utils.condition.processor -> core.variables
|
||||
core.workflow.utils.condition.processor -> core.variables.segments
|
||||
core.workflow.variable_loader -> core.variables
|
||||
core.workflow.variable_loader -> core.variables.consts
|
||||
core.workflow.workflow_type_encoder -> core.variables
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.workflow_entry -> extensions.otel.runtime
|
||||
core.workflow.nodes.agent.agent_node -> models
|
||||
core.workflow.nodes.base.node -> models.enums
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider_ids
|
||||
core.workflow.nodes.llm.node -> models.model
|
||||
core.workflow.workflow_entry -> models.enums
|
||||
core.workflow.nodes.agent.agent_node -> services
|
||||
core.workflow.nodes.tool.tool_node -> services
|
||||
|
||||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
type = layers
|
||||
|
||||
171
api/README.md
171
api/README.md
@ -1,6 +1,6 @@
|
||||
# Dify Backend API
|
||||
|
||||
## Setup and Run
|
||||
## Usage
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
@ -8,77 +8,48 @@
|
||||
> [`uv`](https://docs.astral.sh/uv/) as the package manager
|
||||
> for Dify API backend service.
|
||||
|
||||
`uv` and `pnpm` are required to run the setup and development commands below.
|
||||
1. Start the docker-compose stack
|
||||
|
||||
### Using scripts (recommended)
|
||||
|
||||
The scripts resolve paths relative to their location, so you can run them from anywhere.
|
||||
|
||||
1. Run setup (copies env files and installs dependencies).
|
||||
The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
|
||||
|
||||
```bash
|
||||
./dev/setup
|
||||
cd ../docker
|
||||
cp middleware.env.example middleware.env
|
||||
# change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
|
||||
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
cd ../api
|
||||
```
|
||||
|
||||
1. Review `api/.env`, `web/.env.local`, and `docker/middleware.env` values (see the `SECRET_KEY` note below).
|
||||
1. Copy `.env.example` to `.env`
|
||||
|
||||
1. Start middleware (PostgreSQL/Redis/Weaviate).
|
||||
|
||||
```bash
|
||||
./dev/start-docker-compose
|
||||
```cli
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
1. Start backend (runs migrations first).
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
|
||||
|
||||
```bash
|
||||
./dev/start-api
|
||||
1. Generate a `SECRET_KEY` in the `.env` file.
|
||||
|
||||
bash for Linux
|
||||
|
||||
```bash for Linux
|
||||
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
|
||||
```
|
||||
|
||||
1. Start Dify [web](../web) service.
|
||||
bash for Mac
|
||||
|
||||
```bash
|
||||
./dev/start-web
|
||||
```bash for Mac
|
||||
secret_key=$(openssl rand -base64 42)
|
||||
sed -i '' "/^SECRET_KEY=/c\\
|
||||
SECRET_KEY=${secret_key}" .env
|
||||
```
|
||||
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
1. Create environment.
|
||||
|
||||
1. Optional: start the worker service (async tasks, runs from `api`).
|
||||
|
||||
```bash
|
||||
./dev/start-worker
|
||||
```
|
||||
|
||||
1. Optional: start Celery Beat (scheduled tasks).
|
||||
|
||||
```bash
|
||||
./dev/start-beat
|
||||
```
|
||||
|
||||
### Manual commands
|
||||
|
||||
<details>
|
||||
<summary>Show manual setup and run steps</summary>
|
||||
|
||||
These commands assume you start from the repository root.
|
||||
|
||||
1. Start the docker-compose stack.
|
||||
|
||||
The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
|
||||
|
||||
```bash
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
# Use mysql or another vector database profile if you are not using postgres/weaviate.
|
||||
docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
```
|
||||
|
||||
1. Copy env files.
|
||||
|
||||
```bash
|
||||
cp api/.env.example api/.env
|
||||
cp web/.env.example web/.env.local
|
||||
```
|
||||
|
||||
1. Install UV if needed.
|
||||
Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies.
|
||||
First, you need to add the uv package manager, if you don't have it already.
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
@ -86,96 +57,60 @@ These commands assume you start from the repository root.
|
||||
brew install uv
|
||||
```
|
||||
|
||||
1. Install API dependencies.
|
||||
1. Install dependencies
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv sync --group dev
|
||||
uv sync --dev
|
||||
```
|
||||
|
||||
1. Install web dependencies.
|
||||
1. Run migrate
|
||||
|
||||
Before the first launch, migrate the database to the latest version.
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm install
|
||||
cd ..
|
||||
```
|
||||
|
||||
1. Start backend (runs migrations first, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run flask db upgrade
|
||||
```
|
||||
|
||||
1. Start backend
|
||||
|
||||
```bash
|
||||
uv run flask run --host 0.0.0.0 --port=5001 --debug
|
||||
```
|
||||
|
||||
1. Start Dify [web](../web) service (in a new terminal).
|
||||
1. Start Dify [web](../web) service.
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm dev:inspect
|
||||
```
|
||||
1. Setup your application by visiting `http://localhost:3000`.
|
||||
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
1. Optional: start the worker service (async tasks, in a new terminal).
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
```
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
```
|
||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||
|
||||
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery beat
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Environment notes
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
|
||||
|
||||
- Generate a `SECRET_KEY` in the `.env` file.
|
||||
|
||||
bash for Linux
|
||||
|
||||
```bash
|
||||
sed -i "/^SECRET_KEY=/c\\SECRET_KEY=$(openssl rand -base64 42)" .env
|
||||
```
|
||||
|
||||
bash for Mac
|
||||
|
||||
```bash
|
||||
secret_key=$(openssl rand -base64 42)
|
||||
sed -i '' "/^SECRET_KEY=/c\\
|
||||
SECRET_KEY=${secret_key}" .env
|
||||
```
|
||||
```bash
|
||||
uv run celery -A app.celery beat
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
1. Install dependencies for both the backend and the test environment
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv sync --group dev
|
||||
uv sync --dev
|
||||
```
|
||||
|
||||
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md)
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run pytest # Run all tests
|
||||
uv run pytest tests/unit_tests/ # Unit tests only
|
||||
uv run pytest tests/integration_tests/ # Integration tests
|
||||
|
||||
# Code quality
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run ruff check --fix ./ # Fix linting issues
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
../dev/reformat # Run all formatters and linters
|
||||
uv run ruff check --fix ./ # Fix linting issues
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
```
|
||||
|
||||
@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Application configuration definitions, including file access settings.
|
||||
|
||||
Invariants:
|
||||
- File access settings drive signed URL expiration and base URLs.
|
||||
|
||||
Tests:
|
||||
- Config parsing tests under tests/unit_tests/configs.
|
||||
@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
- Registers file-related API namespaces and routes for files service.
|
||||
- Includes app-assets and sandbox archive proxy controllers.
|
||||
|
||||
Invariants:
|
||||
- files_ns must include all file controller modules to register routes.
|
||||
|
||||
Tests:
|
||||
- Coverage via controller unit tests and route registration smoke checks.
|
||||
@ -1,14 +0,0 @@
|
||||
Summary:
|
||||
- App assets download proxy endpoint (signed URL verification, stream from storage).
|
||||
|
||||
Invariants:
|
||||
- Validates AssetPath fields (UUIDs, asset_type allowlist).
|
||||
- Verifies tenant-scoped signature and expiration before reading storage.
|
||||
- URL uses expires_at/nonce/sign query params.
|
||||
|
||||
Edge Cases:
|
||||
- Missing files return NotFound.
|
||||
- Invalid signature or expired link returns Forbidden.
|
||||
|
||||
Tests:
|
||||
- Verify signature validation and invalid/expired cases.
|
||||
@ -1,13 +0,0 @@
|
||||
Summary:
|
||||
- App assets upload proxy endpoint (signed URL verification, upload to storage).
|
||||
|
||||
Invariants:
|
||||
- Validates AssetPath fields (UUIDs, asset_type allowlist).
|
||||
- Verifies tenant-scoped signature and expiration before writing storage.
|
||||
- URL uses expires_at/nonce/sign query params.
|
||||
|
||||
Edge Cases:
|
||||
- Invalid signature or expired link returns Forbidden.
|
||||
|
||||
Tests:
|
||||
- Verify signature validation and invalid/expired cases.
|
||||
@ -1,14 +0,0 @@
|
||||
Summary:
|
||||
- Sandbox archive upload/download proxy endpoints (signed URL verification, stream to storage).
|
||||
|
||||
Invariants:
|
||||
- Validates tenant_id and sandbox_id UUIDs.
|
||||
- Verifies tenant-scoped signature and expiration before storage access.
|
||||
- URL uses expires_at/nonce/sign query params.
|
||||
|
||||
Edge Cases:
|
||||
- Missing archive returns NotFound.
|
||||
- Invalid signature or expired link returns Forbidden.
|
||||
|
||||
Tests:
|
||||
- Add unit tests for signature validation if needed.
|
||||
@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Collects file assets and emits FileAsset entries with storage keys.
|
||||
|
||||
Invariants:
|
||||
- Storage keys are derived via AppAssetStorage for draft files.
|
||||
|
||||
Tests:
|
||||
- Covered by asset build pipeline tests.
|
||||
@ -1,14 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Builds skill artifacts from markdown assets and uploads resolved outputs.
|
||||
|
||||
Invariants:
|
||||
- Reads draft asset content via AppAssetStorage refs.
|
||||
- Writes resolved artifacts via AppAssetStorage refs.
|
||||
- FileAsset storage keys are derived via AppAssetStorage.
|
||||
|
||||
Edge Cases:
|
||||
- Missing or invalid JSON content yields empty skill content/metadata.
|
||||
|
||||
Tests:
|
||||
- Build pipeline unit tests covering compile/upload paths.
|
||||
@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Converts AppAssetFileTree to FileAsset items for packaging.
|
||||
|
||||
Invariants:
|
||||
- Storage keys for assets are derived via AppAssetStorage.
|
||||
|
||||
Tests:
|
||||
- Used in packaging/service tests for asset bundles.
|
||||
@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Builds AssetItem entries for asset trees using AssetPath-derived storage keys.
|
||||
|
||||
Invariants:
|
||||
- Uses AssetPath to compute draft storage keys.
|
||||
|
||||
Tests:
|
||||
- Covered by asset parsing and packaging tests.
|
||||
@ -1,20 +0,0 @@
|
||||
Summary:
|
||||
- Defines AssetPath facade + typed asset path classes for app-asset storage access.
|
||||
- Maps asset paths to storage keys and generates presigned or signed-proxy URLs.
|
||||
- Signs proxy URLs using tenant private keys and enforces expiration.
|
||||
- Exposes app_asset_storage singleton for reuse.
|
||||
|
||||
Invariants:
|
||||
- AssetPathBase fields (tenant_id/app_id/resource_id/node_id) must be UUIDs.
|
||||
- AssetPath.from_components enforces valid types and resolved node_id presence.
|
||||
- Storage keys are derived internally via AssetPathBase.get_storage_key; callers never supply raw paths.
|
||||
- AppAssetStorage.storage returns the cached presign wrapper (not the raw storage).
|
||||
|
||||
Edge Cases:
|
||||
- Storage backends without presign support must fall back to signed proxy URLs.
|
||||
- Signed proxy verification enforces expiration and tenant-scoped signing keys.
|
||||
- Upload URLs also fall back to signed proxy endpoints when presign is unsupported.
|
||||
- load_or_none treats SilentStorage "File Not Found" bytes as missing.
|
||||
|
||||
Tests:
|
||||
- Unit tests for ref validation, storage key mapping, and signed URL verification.
|
||||
@ -1,10 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Extracts asset files from a zip and persists them into app asset storage.
|
||||
|
||||
Invariants:
|
||||
- Rejects path traversal/absolute/backslash paths.
|
||||
- Saves extracted files via AppAssetStorage draft refs.
|
||||
|
||||
Tests:
|
||||
- Zip security edge cases and tree construction tests.
|
||||
@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Downloads published app asset zip into sandbox and extracts it.
|
||||
|
||||
Invariants:
|
||||
- Uses AppAssetStorage to generate download URLs for build zips (internal URL).
|
||||
|
||||
Tests:
|
||||
- Sandbox initialization integration tests.
|
||||
@ -1,12 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Downloads draft/resolved assets into sandbox for draft execution.
|
||||
|
||||
Invariants:
|
||||
- Uses AppAssetStorage to generate download URLs for draft/resolved refs (internal URL).
|
||||
|
||||
Edge Cases:
|
||||
- No nodes -> returns early.
|
||||
|
||||
Tests:
|
||||
- Sandbox draft initialization tests.
|
||||
@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
- Sandbox lifecycle wrapper (ready/cancel/fail signals, mount/unmount, release).
|
||||
|
||||
Invariants:
|
||||
- wait_ready raises with the original initialization error as the cause.
|
||||
- release always attempts unmount and environment release, logging failures.
|
||||
|
||||
Tests:
|
||||
- Covered by sandbox lifecycle/unit tests and workflow execution error handling.
|
||||
@ -1,2 +0,0 @@
|
||||
Summary:
|
||||
- Sandbox security helper modules.
|
||||
@ -1,13 +0,0 @@
|
||||
Summary:
|
||||
- Generates and verifies signed URLs for sandbox archive upload/download.
|
||||
|
||||
Invariants:
|
||||
- tenant_id and sandbox_id must be UUIDs.
|
||||
- Signatures are tenant-scoped and include operation, expiry, and nonce.
|
||||
|
||||
Edge Cases:
|
||||
- Missing tenant private key raises ValueError.
|
||||
- Expired or tampered signatures are rejected.
|
||||
|
||||
Tests:
|
||||
- Add unit tests if sandbox archive signature behavior expands.
|
||||
@ -1,12 +0,0 @@
|
||||
Summary:
|
||||
- Manages sandbox archive uploads/downloads for workspace persistence.
|
||||
|
||||
Invariants:
|
||||
- Archive storage key is sandbox/<tenant_id>/<sandbox_id>.tar.gz.
|
||||
- Signed URLs are tenant-scoped and use external files URL.
|
||||
|
||||
Edge Cases:
|
||||
- Missing archive skips mount.
|
||||
|
||||
Tests:
|
||||
- Covered indirectly via sandbox integration tests.
|
||||
@ -1,9 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Loads/saves skill bundles to app asset storage.
|
||||
|
||||
Invariants:
|
||||
- Skill bundles use AppAssetStorage refs and JSON serialization.
|
||||
|
||||
Tests:
|
||||
- Covered by skill bundle build/load unit tests.
|
||||
@ -1,14 +0,0 @@
|
||||
Summary:
|
||||
- App asset CRUD, publish/build pipeline, and presigned URL generation.
|
||||
|
||||
Invariants:
|
||||
- Asset storage access goes through AppAssetStorage + AssetPath, using app_asset_storage singleton.
|
||||
- Tree operations require tenant/app scoping and lock for mutation.
|
||||
- Asset zips are packaged via raw storage with storage keys from AppAssetStorage.
|
||||
|
||||
Edge Cases:
|
||||
- File nodes larger than preview limit are rejected.
|
||||
- Deletion runs asynchronously; storage failures are logged.
|
||||
|
||||
Tests:
|
||||
- Unit tests for storage URL generation and publish/build flows.
|
||||
@ -1,10 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Imports app bundles, including asset extraction into app asset storage.
|
||||
|
||||
Invariants:
|
||||
- Asset imports respect zip security checks and tenant/app scoping.
|
||||
- Draft asset packaging uses AppAssetStorage for key mapping.
|
||||
|
||||
Tests:
|
||||
- Bundle import unit tests and zip validation coverage.
|
||||
@ -1,6 +0,0 @@
|
||||
Summary:
|
||||
Summary:
|
||||
- Unit tests for AppAssetStorage ref validation, key mapping, and signing.
|
||||
|
||||
Tests:
|
||||
- Covers valid/invalid refs, signature verify, expiration handling, and proxy URL generation.
|
||||
19
api/app.py
19
api/app.py
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
@ -9,15 +8,10 @@ def is_db_command() -> bool:
|
||||
|
||||
|
||||
# create app
|
||||
flask_app = None
|
||||
socketio_app = None
|
||||
|
||||
if is_db_command():
|
||||
from app_factory import create_migrations_app
|
||||
|
||||
app = create_migrations_app()
|
||||
socketio_app = app
|
||||
flask_app = app
|
||||
else:
|
||||
# Gunicorn and Celery handle monkey patching automatically in production by
|
||||
# specifying the `gevent` worker class. Manual monkey patching is not required here.
|
||||
@ -28,15 +22,8 @@ else:
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
socketio_app, flask_app = create_app()
|
||||
app = flask_app
|
||||
celery = flask_app.extensions["celery"]
|
||||
app = create_app()
|
||||
celery = app.extensions["celery"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
from gevent import pywsgi
|
||||
from geventwebsocket.handler import WebSocketHandler # type: ignore[reportMissingTypeStubs]
|
||||
|
||||
host = os.environ.get("HOST", "0.0.0.0")
|
||||
port = int(os.environ.get("PORT", 5001))
|
||||
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
|
||||
server.serve_forever()
|
||||
app.run(host="0.0.0.0", port=5001)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import socketio # type: ignore[reportMissingTypeStubs]
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
@ -9,7 +8,6 @@ from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_socketio import sio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -62,18 +60,14 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
return dify_app
|
||||
|
||||
|
||||
def create_app() -> tuple[socketio.WSGIApp, DifyApp]:
|
||||
def create_app() -> DifyApp:
|
||||
start_time = time.perf_counter()
|
||||
app = create_flask_app_with_configs()
|
||||
initialize_extensions(app)
|
||||
|
||||
sio.app = app
|
||||
socketio_app = socketio.WSGIApp(sio, app)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
|
||||
return socketio_app, app
|
||||
return app
|
||||
|
||||
|
||||
def initialize_extensions(app: DifyApp):
|
||||
@ -87,7 +81,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_fastopenapi,
|
||||
ext_forward_refs,
|
||||
ext_hosting_provider,
|
||||
ext_import_modules,
|
||||
@ -135,7 +128,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_proxy_fix,
|
||||
ext_blueprints,
|
||||
ext_commands,
|
||||
ext_fastopenapi,
|
||||
ext_otel,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
|
||||
340
api/commands.py
340
api/commands.py
@ -951,346 +951,6 @@ def clean_workflow_runs(
|
||||
)
|
||||
|
||||
|
||||
@click.command(
|
||||
"archive-workflow-runs",
|
||||
help="Archive workflow runs for paid plan tenants to S3-compatible storage.",
|
||||
)
|
||||
@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.")
|
||||
@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.")
|
||||
@click.option(
|
||||
"--from-days-ago",
|
||||
default=None,
|
||||
type=click.IntRange(min=0),
|
||||
help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
|
||||
)
|
||||
@click.option(
|
||||
"--to-days-ago",
|
||||
default=None,
|
||||
type=click.IntRange(min=0),
|
||||
help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
|
||||
)
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Archive runs created at or after this timestamp (UTC if no timezone).",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Archive runs created before this timestamp (UTC if no timezone).",
|
||||
)
|
||||
@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.")
|
||||
@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.")
|
||||
@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.")
|
||||
@click.option("--dry-run", is_flag=True, help="Preview without archiving.")
|
||||
@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.")
|
||||
def archive_workflow_runs(
|
||||
tenant_ids: str | None,
|
||||
before_days: int,
|
||||
from_days_ago: int | None,
|
||||
to_days_ago: int | None,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
batch_size: int,
|
||||
workers: int,
|
||||
limit: int | None,
|
||||
dry_run: bool,
|
||||
delete_after_archive: bool,
|
||||
):
|
||||
"""
|
||||
Archive workflow runs for paid plan tenants older than the specified days.
|
||||
|
||||
This command archives the following tables to storage:
|
||||
- workflow_node_executions
|
||||
- workflow_node_execution_offload
|
||||
- workflow_pauses
|
||||
- workflow_pause_reasons
|
||||
- workflow_trigger_logs
|
||||
|
||||
The workflow_runs and workflow_app_logs tables are preserved for UI listing.
|
||||
"""
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver
|
||||
|
||||
run_started_at = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Starting workflow run archiving at {run_started_at.isoformat()}.",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
click.echo(click.style("start-from and end-before must be provided together.", fg="red"))
|
||||
return
|
||||
|
||||
if (from_days_ago is None) ^ (to_days_ago is None):
|
||||
click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red"))
|
||||
return
|
||||
|
||||
if from_days_ago is not None and to_days_ago is not None:
|
||||
if start_from or end_before:
|
||||
click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red"))
|
||||
return
|
||||
if from_days_ago <= to_days_ago:
|
||||
click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red"))
|
||||
return
|
||||
now = datetime.datetime.now()
|
||||
start_from = now - datetime.timedelta(days=from_days_ago)
|
||||
end_before = now - datetime.timedelta(days=to_days_ago)
|
||||
before_days = 0
|
||||
|
||||
if start_from and end_before and start_from >= end_before:
|
||||
click.echo(click.style("start-from must be earlier than end-before.", fg="red"))
|
||||
return
|
||||
if workers < 1:
|
||||
click.echo(click.style("workers must be at least 1.", fg="red"))
|
||||
return
|
||||
|
||||
archiver = WorkflowRunArchiver(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
workers=workers,
|
||||
tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None,
|
||||
limit=limit,
|
||||
dry_run=dry_run,
|
||||
delete_after_archive=delete_after_archive,
|
||||
)
|
||||
summary = archiver.run()
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, "
|
||||
f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, "
|
||||
f"time={summary.total_elapsed_time:.2f}s",
|
||||
fg="cyan",
|
||||
)
|
||||
)
|
||||
|
||||
run_finished_at = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = run_finished_at - run_started_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Workflow run archiving completed. start={run_started_at.isoformat()} "
|
||||
f"end={run_finished_at.isoformat()} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command(
|
||||
"restore-workflow-runs",
|
||||
help="Restore archived workflow runs from S3-compatible storage.",
|
||||
)
|
||||
@click.option(
|
||||
"--tenant-ids",
|
||||
required=False,
|
||||
help="Tenant IDs (comma-separated).",
|
||||
)
|
||||
@click.option("--run-id", required=False, help="Workflow run ID to restore.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
|
||||
)
|
||||
@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.")
|
||||
@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.")
|
||||
@click.option("--dry-run", is_flag=True, help="Preview without restoring.")
|
||||
def restore_workflow_runs(
|
||||
tenant_ids: str | None,
|
||||
run_id: str | None,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
workers: int,
|
||||
limit: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Restore an archived workflow run from storage to the database.
|
||||
|
||||
This restores the following tables:
|
||||
- workflow_node_executions
|
||||
- workflow_node_execution_offload
|
||||
- workflow_pauses
|
||||
- workflow_pause_reasons
|
||||
- workflow_trigger_logs
|
||||
"""
|
||||
from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
|
||||
|
||||
parsed_tenant_ids = None
|
||||
if tenant_ids:
|
||||
parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
|
||||
if not parsed_tenant_ids:
|
||||
raise click.BadParameter("tenant-ids must not be empty")
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
if run_id is None and (start_from is None or end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before are required for batch restore.")
|
||||
if workers < 1:
|
||||
raise click.BadParameter("workers must be at least 1")
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers)
|
||||
if run_id:
|
||||
results = [restorer.restore_by_run_id(run_id)]
|
||||
else:
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
results = restorer.restore_batch(
|
||||
parsed_tenant_ids,
|
||||
start_date=start_from,
|
||||
end_date=end_before,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
|
||||
successes = sum(1 for result in results if result.success)
|
||||
failures = len(results) - successes
|
||||
|
||||
if failures == 0:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Restore completed successfully. success={successes} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command(
|
||||
"delete-archived-workflow-runs",
|
||||
help="Delete archived workflow runs from the database.",
|
||||
)
|
||||
@click.option(
|
||||
"--tenant-ids",
|
||||
required=False,
|
||||
help="Tenant IDs (comma-separated).",
|
||||
)
|
||||
@click.option("--run-id", required=False, help="Workflow run ID to delete.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
|
||||
)
|
||||
@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.")
|
||||
@click.option("--dry-run", is_flag=True, help="Preview without deleting.")
|
||||
def delete_archived_workflow_runs(
|
||||
tenant_ids: str | None,
|
||||
run_id: str | None,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
limit: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Delete archived workflow runs from the database.
|
||||
"""
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
|
||||
|
||||
parsed_tenant_ids = None
|
||||
if tenant_ids:
|
||||
parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
|
||||
if not parsed_tenant_ids:
|
||||
raise click.BadParameter("tenant-ids must not be empty")
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
if run_id is None and (start_from is None or end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before are required for batch delete.")
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
target_desc = f"workflow run {run_id}" if run_id else "workflow runs"
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Starting delete of {target_desc} at {start_time.isoformat()}.",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run)
|
||||
if run_id:
|
||||
results = [deleter.delete_by_run_id(run_id)]
|
||||
else:
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
results = deleter.delete_batch(
|
||||
parsed_tenant_ids,
|
||||
start_date=start_from,
|
||||
end_date=end_before,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
for result in results:
|
||||
if result.success:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} "
|
||||
f"workflow run {result.run_id} (tenant={result.tenant_id})",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Failed to delete workflow run {result.run_id}: {result.error}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
|
||||
successes = sum(1 for result in results if result.success)
|
||||
failures = len(results) - successes
|
||||
|
||||
if failures == 0:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Delete completed successfully. success={successes} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Delete completed with failures. success={successes} failed={failures} duration={elapsed}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
|
||||
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
|
||||
def clear_orphaned_file_records(force: bool):
|
||||
|
||||
@ -90,10 +90,7 @@ class DifyConfig(
|
||||
"Defaults to api/bin when unset."
|
||||
),
|
||||
)
|
||||
DIFY_PORT: int = Field(
|
||||
default=5001,
|
||||
description="Port used by Dify to communicate with the host machine.",
|
||||
)
|
||||
|
||||
# Before adding any config,
|
||||
# please consider to arrange it in the proper config group of existed or added
|
||||
# for better readability and maintainability.
|
||||
|
||||
@ -1240,13 +1240,6 @@ class PositionConfig(BaseSettings):
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
|
||||
|
||||
class CollaborationConfig(BaseSettings):
|
||||
ENABLE_COLLABORATION_MODE: bool = Field(
|
||||
description="Whether to enable collaboration mode features across the workspace",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class LoginConfig(BaseSettings):
|
||||
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
|
||||
description="whether to enable email code login",
|
||||
@ -1366,7 +1359,6 @@ class FeatureConfig(
|
||||
WorkflowConfig,
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
CollaborationConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
|
||||
@ -3,7 +3,6 @@ Flask App Context - Flask implementation of AppContext interface.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, final
|
||||
@ -119,7 +118,6 @@ class FlaskExecutionContext:
|
||||
self._context_vars = context_vars
|
||||
self._user = user
|
||||
self._flask_app = flask_app
|
||||
self._local = threading.local()
|
||||
|
||||
@property
|
||||
def app_context(self) -> FlaskAppContext:
|
||||
@ -138,39 +136,47 @@ class FlaskExecutionContext:
|
||||
|
||||
def __enter__(self) -> "FlaskExecutionContext":
|
||||
"""Enter the Flask execution context."""
|
||||
# Restore non-Flask context variables to avoid leaking Flask tokens across threads
|
||||
# Restore context variables
|
||||
for var, val in self._context_vars.items():
|
||||
var.set(val)
|
||||
|
||||
# Save current user from g if available
|
||||
saved_user = None
|
||||
if hasattr(g, "_login_user"):
|
||||
saved_user = g._login_user
|
||||
|
||||
# Enter Flask app context
|
||||
cm = self._app_context.enter()
|
||||
self._local.cm = cm
|
||||
cm.__enter__()
|
||||
self._cm = self._app_context.enter()
|
||||
self._cm.__enter__()
|
||||
|
||||
# Restore user in new app context
|
||||
if self._user is not None:
|
||||
g._login_user = self._user
|
||||
if saved_user is not None:
|
||||
g._login_user = saved_user
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
"""Exit the Flask execution context."""
|
||||
cm = getattr(self._local, "cm", None)
|
||||
if cm is not None:
|
||||
cm.__exit__(*args)
|
||||
if hasattr(self, "_cm"):
|
||||
self._cm.__exit__(*args)
|
||||
|
||||
@contextmanager
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter Flask execution context as context manager."""
|
||||
# Restore non-Flask context variables to avoid leaking Flask tokens across threads
|
||||
# Restore context variables
|
||||
for var, val in self._context_vars.items():
|
||||
var.set(val)
|
||||
|
||||
# Save current user from g if available
|
||||
saved_user = None
|
||||
if hasattr(g, "_login_user"):
|
||||
saved_user = g._login_user
|
||||
|
||||
# Enter Flask app context
|
||||
with self._flask_app.app_context():
|
||||
# Restore user in new app context
|
||||
if self._user is not None:
|
||||
g._login_user = self._user
|
||||
if saved_user is not None:
|
||||
g._login_user = saved_user
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@ -32,7 +32,6 @@ for module_name in RESOURCE_MODULES:
|
||||
|
||||
# Ensure resource modules are imported so route decorators are evaluated.
|
||||
# Import other controllers
|
||||
# Sandbox file browser
|
||||
from . import (
|
||||
admin,
|
||||
apikey,
|
||||
@ -40,7 +39,6 @@ from . import (
|
||||
feature,
|
||||
init_validate,
|
||||
ping,
|
||||
sandbox_files,
|
||||
setup,
|
||||
spec,
|
||||
version,
|
||||
@ -66,7 +64,6 @@ from .app import (
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
workflow_comment,
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
@ -118,7 +115,6 @@ from .explore import (
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
@ -201,7 +197,6 @@ __all__ = [
|
||||
"rag_pipeline_import",
|
||||
"rag_pipeline_workflow",
|
||||
"recommended_app",
|
||||
"sandbox_files",
|
||||
"sandbox_providers",
|
||||
"saved_message",
|
||||
"setup",
|
||||
@ -216,7 +211,6 @@ __all__ = [
|
||||
"website",
|
||||
"workflow",
|
||||
"workflow_app_log",
|
||||
"workflow_comment",
|
||||
"workflow_draft_variable",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
from flask import request
|
||||
@ -28,7 +27,6 @@ from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, Workflow
|
||||
from models.model import IconType
|
||||
from models.workflow_features import WorkflowFeatures
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -37,11 +35,6 @@ from services.feature_service import FeatureService
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
|
||||
class RuntimeType(StrEnum):
|
||||
CLASSIC = "classic"
|
||||
SANDBOXED = "sandboxed"
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
@ -106,11 +99,6 @@ class AppExportQuery(BaseModel):
|
||||
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
|
||||
|
||||
|
||||
class AppExportBundleQuery(BaseModel):
|
||||
include_secret: bool = Field(default=False, description="Include secrets in export")
|
||||
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
|
||||
|
||||
|
||||
class AppNamePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="Name to check")
|
||||
|
||||
@ -336,7 +324,6 @@ class AppPartial(ResponseModel):
|
||||
create_user_name: str | None = None
|
||||
author_name: str | None = None
|
||||
has_draft_trigger: bool | None = None
|
||||
runtime_type: RuntimeType = RuntimeType.CLASSIC
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
@ -468,7 +455,6 @@ class AppListApi(Resource):
|
||||
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
|
||||
]
|
||||
draft_trigger_app_ids: set[str] = set()
|
||||
sandbox_app_ids: set[str] = set()
|
||||
if workflow_capable_app_ids:
|
||||
draft_workflows = (
|
||||
db.session.execute(
|
||||
@ -486,10 +472,6 @@ class AppListApi(Resource):
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
for workflow in draft_workflows:
|
||||
# Check sandbox feature
|
||||
if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
|
||||
sandbox_app_ids.add(str(workflow.app_id))
|
||||
|
||||
try:
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
@ -500,7 +482,6 @@ class AppListApi(Resource):
|
||||
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
app.runtime_type = RuntimeType.SANDBOXED if str(app.id) in sandbox_app_ids else RuntimeType.CLASSIC
|
||||
|
||||
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
|
||||
return pagination_model.model_dump(mode="json"), 200
|
||||
@ -669,36 +650,6 @@ class AppExportApi(Resource):
|
||||
return payload.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/export-bundle")
|
||||
class AppExportBundleApi(Resource):
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
from io import BytesIO
|
||||
|
||||
from flask import send_file
|
||||
|
||||
from services.app_bundle_service import AppBundleService
|
||||
|
||||
args = AppExportBundleQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
result = AppBundleService.export_bundle(
|
||||
app_model=app_model,
|
||||
include_secret=args.include_secret,
|
||||
workflow_id=args.workflow_id,
|
||||
)
|
||||
|
||||
return send_file(
|
||||
BytesIO(result.zip_bytes),
|
||||
mimetype="application/zip",
|
||||
as_attachment=True,
|
||||
download_name=result.filename,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@console_ns.doc("check_app_name")
|
||||
|
||||
@ -4,12 +4,12 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppAssetFileRequiredError,
|
||||
AppAssetNodeNotFoundError,
|
||||
AppAssetPathConflictError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_asset_entities import BatchUploadNode
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
@ -47,26 +47,6 @@ class CreateFilePayload(BaseModel):
|
||||
return v or None
|
||||
|
||||
|
||||
class GetUploadUrlPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
size: int = Field(..., ge=0)
|
||||
parent_id: str | None = None
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def strip_name(cls, v: str) -> str:
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
@field_validator("parent_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, v: str | None) -> str | None:
|
||||
return v or None
|
||||
|
||||
|
||||
class BatchUploadPayload(BaseModel):
|
||||
children: list[BatchUploadNode] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class UpdateFileContentPayload(BaseModel):
|
||||
content: str
|
||||
|
||||
@ -89,9 +69,6 @@ def reg(cls: type[BaseModel]) -> None:
|
||||
|
||||
reg(CreateFolderPayload)
|
||||
reg(CreateFilePayload)
|
||||
reg(GetUploadUrlPayload)
|
||||
reg(BatchUploadNode)
|
||||
reg(BatchUploadPayload)
|
||||
reg(UpdateFileContentPayload)
|
||||
reg(RenameNodePayload)
|
||||
reg(MoveNodePayload)
|
||||
@ -130,6 +107,31 @@ class AppAssetFolderResource(Resource):
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files")
|
||||
class AppAssetFileResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
file = request.files.get("file")
|
||||
if not file:
|
||||
raise AppAssetFileRequiredError()
|
||||
|
||||
payload = CreateFilePayload.model_validate(request.form.to_dict())
|
||||
content = file.read()
|
||||
|
||||
try:
|
||||
node = AppAssetService.create_file(app_model, current_user.id, payload.name, content, payload.parent_id)
|
||||
return node.model_dump(), 201
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>")
|
||||
class AppAssetFileDetailResource(Resource):
|
||||
@setup_required
|
||||
@ -241,6 +243,22 @@ class AppAssetNodeReorderResource(Resource):
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/publish")
|
||||
class AppAssetPublishResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
published = AppAssetService.publish(app_model, current_user.id)
|
||||
return {
|
||||
"id": published.id,
|
||||
"version": published.version,
|
||||
"asset_tree": published.asset_tree.model_dump(),
|
||||
}, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>/download-url")
|
||||
class AppAssetFileDownloadUrlResource(Resource):
|
||||
@setup_required
|
||||
@ -254,68 +272,3 @@ class AppAssetFileDownloadUrlResource(Resource):
|
||||
return {"download_url": download_url}
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files/upload")
|
||||
class AppAssetFileUploadUrlResource(Resource):
|
||||
@console_ns.expect(console_ns.models[GetUploadUrlPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = GetUploadUrlPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node, upload_url = AppAssetService.get_file_upload_url(
|
||||
app_model, current_user.id, payload.name, payload.size, payload.parent_id
|
||||
)
|
||||
return {"node": node.model_dump(), "upload_url": upload_url}, 201
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/batch-upload")
|
||||
class AppAssetBatchUploadResource(Resource):
|
||||
@console_ns.expect(console_ns.models[BatchUploadPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Create nodes from tree structure and return upload URLs.
|
||||
|
||||
Input:
|
||||
{
|
||||
"children": [
|
||||
{"name": "folder1", "node_type": "folder", "children": [
|
||||
{"name": "file1.txt", "node_type": "file", "size": 1024}
|
||||
]},
|
||||
{"name": "root.txt", "node_type": "file", "size": 512}
|
||||
]
|
||||
}
|
||||
|
||||
Output:
|
||||
{
|
||||
"children": [
|
||||
{"id": "xxx", "name": "folder1", "node_type": "folder", "children": [
|
||||
{"id": "yyy", "name": "file1.txt", "node_type": "file", "size": 1024, "upload_url": "..."}
|
||||
]},
|
||||
{"id": "zzz", "name": "root.txt", "node_type": "file", "size": 512, "upload_url": "..."}
|
||||
]
|
||||
}
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = BatchUploadPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
result_children = AppAssetService.batch_create_from_tree(app_model, current_user.id, payload.children)
|
||||
return {"children": [child.model_dump() for child in result_children]}, 201
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
@ -51,14 +51,6 @@ class AppImportPayload(BaseModel):
|
||||
app_id: str | None = None
|
||||
|
||||
|
||||
class AppImportBundlePayload(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
@ -147,55 +139,3 @@ class AppImportCheckDependenciesApi(Resource):
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports-bundle")
|
||||
class AppImportBundleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_model)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
from flask import request
|
||||
|
||||
from core.app.entities.app_bundle_entities import BundleFormatError
|
||||
from services.app_bundle_service import AppBundleService
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if "file" not in request.files:
|
||||
return {"error": "No file provided"}, 400
|
||||
|
||||
file = request.files["file"]
|
||||
if not file.filename or not file.filename.endswith(".zip"):
|
||||
return {"error": "Invalid file format, expected .zip"}, 400
|
||||
|
||||
zip_bytes = file.read()
|
||||
|
||||
form_data = request.form.to_dict()
|
||||
args = AppImportBundlePayload.model_validate(form_data)
|
||||
|
||||
try:
|
||||
result = AppBundleService.import_bundle(
|
||||
account=current_user,
|
||||
zip_bytes=zip_bytes,
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
icon_type=args.icon_type,
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
except BundleFormatError as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
@ -82,13 +82,13 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
class DraftWorkflowNotExist(BaseHTTPException):
|
||||
error_code = "draft_workflow_not_exist"
|
||||
description = "Draft workflow need to be initialized."
|
||||
code = 404
|
||||
code = 400
|
||||
|
||||
|
||||
class DraftWorkflowNotSync(BaseHTTPException):
|
||||
error_code = "draft_workflow_not_sync"
|
||||
description = "Workflow graph might have been modified, please refresh and resubmit."
|
||||
code = 409
|
||||
code = 400
|
||||
|
||||
|
||||
class TracingConfigNotExist(BaseHTTPException):
|
||||
|
||||
@ -16,11 +16,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.context_models import (
|
||||
AvailableVarPayload,
|
||||
CodeContextPayload,
|
||||
ParameterInfoPayload,
|
||||
)
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
@ -63,21 +58,22 @@ class InstructionTemplatePayload(BaseModel):
|
||||
class ContextGeneratePayload(BaseModel):
|
||||
"""Payload for generating extractor code node."""
|
||||
|
||||
workflow_id: str = Field(..., description="Workflow ID")
|
||||
node_id: str = Field(..., description="Current tool/llm node ID")
|
||||
parameter_name: str = Field(..., description="Parameter name to generate code for")
|
||||
language: str = Field(default="python3", description="Code language (python3/javascript)")
|
||||
prompt_messages: list[dict[str, Any]] = Field(
|
||||
..., description="Multi-turn conversation history, last message is the current instruction"
|
||||
)
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes")
|
||||
parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend")
|
||||
code_context: CodeContextPayload | None = Field(
|
||||
default=None, description="Existing code node context for incremental generation"
|
||||
)
|
||||
|
||||
|
||||
class SuggestedQuestionsPayload(BaseModel):
|
||||
"""Payload for generating suggested questions."""
|
||||
|
||||
workflow_id: str = Field(..., description="Workflow ID")
|
||||
node_id: str = Field(..., description="Current tool/llm node ID")
|
||||
parameter_name: str = Field(..., description="Parameter name")
|
||||
language: str = Field(
|
||||
default="English", description="Language for generated questions (e.g. English, Chinese, Japanese)"
|
||||
)
|
||||
@ -86,8 +82,6 @@ class SuggestedQuestionsPayload(BaseModel):
|
||||
alias="model_config",
|
||||
description="Model configuration (optional, uses system default if not provided)",
|
||||
)
|
||||
available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes")
|
||||
parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
@ -334,15 +328,17 @@ class ContextGenerateApi(Resource):
|
||||
args = ContextGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
prompt_messages = deserialize_prompt_messages(args.prompt_messages)
|
||||
|
||||
try:
|
||||
return LLMGenerator.generate_with_context(
|
||||
tenant_id=current_tenant_id,
|
||||
workflow_id=args.workflow_id,
|
||||
node_id=args.node_id,
|
||||
parameter_name=args.parameter_name,
|
||||
language=args.language,
|
||||
prompt_messages=deserialize_prompt_messages(args.prompt_messages),
|
||||
prompt_messages=prompt_messages,
|
||||
model_config=args.model_config_data,
|
||||
available_vars=args.available_vars,
|
||||
parameter_info=args.parameter_info,
|
||||
code_context=args.code_context,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -366,12 +362,14 @@ class SuggestedQuestionsApi(Resource):
|
||||
def post(self):
|
||||
args = SuggestedQuestionsPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return LLMGenerator.generate_suggested_questions(
|
||||
tenant_id=current_tenant_id,
|
||||
workflow_id=args.workflow_id,
|
||||
node_id=args.node_id,
|
||||
parameter_name=args.parameter_name,
|
||||
language=args.language,
|
||||
available_vars=args.available_vars,
|
||||
parameter_info=args.parameter_info,
|
||||
model_config=args.model_config_data,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
|
||||
@ -32,10 +32,8 @@ from core.trigger.debug.event_selectors import (
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.online_user_fields import online_user_list_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
@ -45,7 +43,6 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -185,14 +182,6 @@ class WorkflowUpdatePayload(BaseModel):
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class WorkflowFeaturesPayload(BaseModel):
|
||||
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
|
||||
|
||||
|
||||
class WorkflowOnlineUsersQuery(BaseModel):
|
||||
workflow_ids: str = Field(..., description="Comma-separated workflow IDs")
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
node_id: str
|
||||
|
||||
@ -225,8 +214,6 @@ reg(DefaultBlockConfigQuery)
|
||||
reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(WorkflowFeaturesPayload)
|
||||
reg(WorkflowOnlineUsersQuery)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
reg(NestedNodeGraphPayload)
|
||||
@ -495,7 +482,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
@ -533,7 +520,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
@ -699,14 +686,13 @@ class PublishedWorkflowApi(Resource):
|
||||
"""
|
||||
Publish workflow
|
||||
"""
|
||||
from services.app_bundle_service import AppBundleService
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
workflow = AppBundleService.publish(
|
||||
workflow = workflow_service.publish_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
@ -817,31 +803,6 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/features")
|
||||
class WorkflowFeaturesApi(Resource):
|
||||
"""Update draft workflow features."""
|
||||
|
||||
@console_ns.expect(console_ns.models[WorkflowFeaturesPayload.__name__])
|
||||
@console_ns.doc("update_workflow_features")
|
||||
@console_ns.doc(description="Update draft workflow features")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Workflow features updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
features = args.features
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@ -1050,7 +1011,6 @@ class DraftWorkflowTriggerRunApi(Resource):
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
|
||||
workflow_args = dict(event.workflow_args)
|
||||
|
||||
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
|
||||
return helper.compact_generate_response(
|
||||
AppGenerateService.generate(
|
||||
@ -1199,7 +1159,6 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
|
||||
try:
|
||||
workflow_args = dict(trigger_debug_event.workflow_args)
|
||||
|
||||
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
@ -1270,32 +1229,3 @@ class NestedNodeGraphApi(Resource):
|
||||
response = service.generate_nested_node_graph(tenant_id=app_model.tenant_id, request=request)
|
||||
|
||||
return response.model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/apps/workflows/online-users")
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
|
||||
@console_ns.doc("get_workflow_online_users")
|
||||
@console_ns.doc(description="Get workflow online users")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def get(self):
|
||||
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workflow_ids = [workflow_id.strip() for workflow_id in args.workflow_ids.split(",") if workflow_id.strip()]
|
||||
|
||||
results = []
|
||||
for workflow_id in workflow_ids:
|
||||
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}")
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
try:
|
||||
users.append(json.loads(user_info_json))
|
||||
except Exception:
|
||||
continue
|
||||
results.append({"workflow_id": workflow_id, "users": users})
|
||||
|
||||
return {"data": results}
|
||||
|
||||
@ -11,10 +11,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import (
|
||||
build_workflow_app_log_pagination_model,
|
||||
build_workflow_archived_log_pagination_model,
|
||||
)
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
@ -64,7 +61,6 @@ console_ns.schema_model(
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
|
||||
workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
|
||||
@ -103,33 +99,3 @@ class WorkflowAppLogApi(Resource):
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
|
||||
class WorkflowArchivedLogApi(Resource):
|
||||
@console_ns.doc("get_workflow_archived_logs")
|
||||
@console_ns.doc(description="Get workflow archived execution logs")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
|
||||
@console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_archived_log_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow archived logs
|
||||
"""
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with Session(db.engine) as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
|
||||
@ -1,317 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.member_fields import account_with_role_fields
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowCommentCreatePayload(BaseModel):
|
||||
position_x: float = Field(..., description="Comment X position")
|
||||
position_y: float = Field(..., description="Comment Y position")
|
||||
content: str = Field(..., description="Comment content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Comment content")
|
||||
position_x: float | None = Field(default=None, description="Comment X position")
|
||||
position_y: float | None = Field(default=None, description="Comment Y position")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyCreatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
for model in (
|
||||
WorkflowCommentCreatePayload,
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyCreatePayload,
|
||||
WorkflowCommentReplyUpdatePayload,
|
||||
):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
)
|
||||
workflow_comment_mention_users_model = console_ns.model(
|
||||
"WorkflowCommentMentionUsers",
|
||||
{"users": fields.List(fields.Nested(account_with_role_fields))},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(204, "Comment deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@console_ns.doc("create_workflow_comment_reply")
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=payload.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@console_ns.doc("update_workflow_comment_reply")
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
reply_id=reply_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.response(204, "Reply deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@console_ns.doc("workflow_comment_mention_users")
|
||||
@console_ns.doc(description="Get all users in current tenant for mentions")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_mention_users_model)
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
return {"users": members}
|
||||
@ -22,8 +22,8 @@ from core.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment,
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
@ -44,16 +44,6 @@ class WorkflowDraftVariableUpdatePayload(BaseModel):
|
||||
value: Any | None = Field(default=None, description="Variable value")
|
||||
|
||||
|
||||
class ConversationVariableUpdatePayload(BaseModel):
|
||||
conversation_variables: list[dict[str, Any]] = Field(
|
||||
..., description="Conversation variables for the draft workflow"
|
||||
)
|
||||
|
||||
|
||||
class EnvironmentVariableUpdatePayload(BaseModel):
|
||||
environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableListQuery.__name__,
|
||||
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
@ -62,14 +52,6 @@ console_ns.schema_model(
|
||||
WorkflowDraftVariableUpdatePayload.__name__,
|
||||
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ConversationVariableUpdatePayload.__name__,
|
||||
ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
EnvironmentVariableUpdatePayload.__name__,
|
||||
EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
@ -270,7 +252,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
def delete(self, app_model: App):
|
||||
# FIXME(Mairuis): move to SandboxArtifactService
|
||||
current_user, _ = current_account_with_tenant()
|
||||
SandboxManager.delete_draft_storage(app_model.tenant_id, current_user.id)
|
||||
SandboxManager.delete_storage(app_model.tenant_id, current_user.id)
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -407,7 +389,7 @@ class VariableApi(Resource):
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
|
||||
new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
@ -500,34 +482,6 @@ class ConversationVariableCollectionApi(Resource):
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
@console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_conversation_variables")
|
||||
@console_ns.doc(description="Update conversation variables for workflow draft")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Conversation variables updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def post(self, app_model: App):
|
||||
payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
conversation_variables_list = payload.conversation_variables
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_conversation_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@ -579,31 +533,3 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
@console_ns.expect(console_ns.models[EnvironmentVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_environment_variables")
|
||||
@console_ns.doc(description="Update environment variables for workflow draft")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Environment variables updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
environment_variables_list = payload.environment_variables
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_environment_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_run_fields import (
|
||||
@ -22,17 +19,14 @@ from fields.workflow_run_fields import (
|
||||
workflow_run_node_execution_list_fields,
|
||||
workflow_run_pagination_fields,
|
||||
)
|
||||
from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
|
||||
from libs.custom_inputs import time_duration
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
# Workflow run status choices for filtering
|
||||
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||
EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
@ -99,15 +93,6 @@ workflow_run_node_execution_list_model = console_ns.model(
|
||||
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
|
||||
)
|
||||
|
||||
workflow_run_export_fields = console_ns.model(
|
||||
"WorkflowRunExport",
|
||||
{
|
||||
"status": fields.String(description="Export status: success/failed"),
|
||||
"presigned_url": fields.String(description="Pre-signed URL for download", required=False),
|
||||
"presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False),
|
||||
},
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
@ -196,56 +181,6 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/export")
|
||||
class WorkflowRunExportApi(Resource):
|
||||
@console_ns.doc("get_workflow_run_export_url")
|
||||
@console_ns.doc(description="Generate a download URL for an archived workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Export URL generated", workflow_run_export_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App, run_id: str):
|
||||
tenant_id = str(app_model.tenant_id)
|
||||
app_id = str(app_model.id)
|
||||
run_id_str = str(run_id)
|
||||
|
||||
run_created_at = db.session.scalar(
|
||||
select(WorkflowArchiveLog.run_created_at)
|
||||
.where(
|
||||
WorkflowArchiveLog.tenant_id == tenant_id,
|
||||
WorkflowArchiveLog.app_id == app_id,
|
||||
WorkflowArchiveLog.workflow_run_id == run_id_str,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if not run_created_at:
|
||||
return {"code": "archive_log_not_found", "message": "workflow run archive not found"}, 404
|
||||
|
||||
prefix = (
|
||||
f"{tenant_id}/app_id={app_id}/year={run_created_at.strftime('%Y')}/"
|
||||
f"month={run_created_at.strftime('%m')}/workflow_run_id={run_id_str}"
|
||||
)
|
||||
archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
|
||||
|
||||
try:
|
||||
archive_storage = get_archive_storage()
|
||||
except ArchiveStorageNotConfiguredError as e:
|
||||
return {"code": "archive_storage_not_configured", "message": str(e)}, 500
|
||||
|
||||
presigned_url = archive_storage.generate_presigned_url(
|
||||
archive_key,
|
||||
expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS,
|
||||
)
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS)
|
||||
return {
|
||||
"status": "success",
|
||||
"presigned_url": presigned_url,
|
||||
"presigned_url_expires_at": expires_at.isoformat(),
|
||||
}, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
|
||||
class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
@console_ns.doc("get_advanced_chat_workflow_runs_count")
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import console_ns
|
||||
@ -40,21 +39,5 @@ class SystemFeatureApi(Resource):
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Get system-wide feature configuration
|
||||
|
||||
NOTE: This endpoint is unauthenticated by design, as it provides system features
|
||||
data required for dashboard initialization.
|
||||
|
||||
Authentication would create circular dependency (can't login without dashboard loading).
|
||||
|
||||
Only non-sensitive configuration data should be returned by this endpoint.
|
||||
"""
|
||||
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
|
||||
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
|
||||
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
|
||||
# raise `Unauthorized` exception if authentication token is not provided.
|
||||
try:
|
||||
is_authenticated = current_user.is_authenticated
|
||||
except Unauthorized:
|
||||
is_authenticated = False
|
||||
return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()
|
||||
"""Get system-wide feature configuration"""
|
||||
return FeatureService.get_system_features().model_dump()
|
||||
|
||||
@ -1,17 +1,17 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.fastopenapi import console_router
|
||||
from . import console_ns
|
||||
|
||||
|
||||
class PingResponse(BaseModel):
|
||||
result: str = Field(description="Health check result", examples=["pong"])
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/ping",
|
||||
response_model=PingResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
def ping() -> PingResponse:
|
||||
"""Health check endpoint for connection testing."""
|
||||
return PingResponse(result="pong")
|
||||
@console_ns.route("/ping")
|
||||
class PingApi(Resource):
|
||||
@console_ns.doc("health_check")
|
||||
@console_ns.doc(description="Health check endpoint for connection testing")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
|
||||
)
|
||||
def get(self):
|
||||
"""Health check endpoint for connection testing"""
|
||||
return {"result": "pong"}
|
||||
|
||||
@ -1,87 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.sandbox.sandbox_file_service import SandboxFileService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SandboxFileListQuery(BaseModel):
|
||||
path: str | None = Field(default=None, description="Workspace relative path")
|
||||
recursive: bool = Field(default=False, description="List recursively")
|
||||
|
||||
|
||||
class SandboxFileDownloadRequest(BaseModel):
|
||||
path: str = Field(..., description="Workspace relative file path")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
SandboxFileListQuery.__name__,
|
||||
SandboxFileListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
SandboxFileDownloadRequest.__name__,
|
||||
SandboxFileDownloadRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
SANDBOX_FILE_NODE_FIELDS = {
|
||||
"path": fields.String,
|
||||
"is_dir": fields.Boolean,
|
||||
"size": fields.Raw,
|
||||
"mtime": fields.Raw,
|
||||
}
|
||||
|
||||
|
||||
SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS = {
|
||||
"download_url": fields.String,
|
||||
"expires_in": fields.Integer,
|
||||
"export_id": fields.String,
|
||||
}
|
||||
|
||||
|
||||
sandbox_file_node_model = console_ns.model("SandboxFileNode", SANDBOX_FILE_NODE_FIELDS)
|
||||
sandbox_file_download_ticket_model = console_ns.model(
|
||||
"SandboxFileDownloadTicket", SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/sandboxes/<string:sandbox_id>/files")
|
||||
class SandboxFilesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[SandboxFileListQuery.__name__])
|
||||
@console_ns.marshal_list_with(sandbox_file_node_model)
|
||||
def get(self, sandbox_id: str):
|
||||
args = SandboxFileListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore[arg-type]
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return [
|
||||
e.__dict__
|
||||
for e in SandboxFileService.list_files(
|
||||
tenant_id=tenant_id,
|
||||
sandbox_id=sandbox_id,
|
||||
path=args.path,
|
||||
recursive=args.recursive,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@console_ns.route("/sandboxes/<string:sandbox_id>/files/download")
|
||||
class SandboxFileDownloadApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[SandboxFileDownloadRequest.__name__])
|
||||
@console_ns.marshal_with(sandbox_file_download_ticket_model)
|
||||
def post(self, sandbox_id: str):
|
||||
payload = SandboxFileDownloadRequest.model_validate(console_ns.payload or {})
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
res = SandboxFileService.download_file(tenant_id=tenant_id, sandbox_id=sandbox_id, path=payload.path)
|
||||
return res.__dict__
|
||||
@ -1,19 +1,20 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.fastopenapi import console_router
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup, db
|
||||
from services.account_service import RegisterService, TenantService
|
||||
|
||||
from . import console_ns
|
||||
from .error import AlreadySetupError, NotInitValidateError
|
||||
from .init_validate import get_init_validate_status
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SetupRequestPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Admin email address")
|
||||
@ -27,66 +28,78 @@ class SetupRequestPayload(BaseModel):
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
class SetupStatusResponse(BaseModel):
|
||||
step: Literal["not_started", "finished"] = Field(description="Setup step status")
|
||||
setup_at: str | None = Field(default=None, description="Setup completion time (ISO format)")
|
||||
|
||||
|
||||
class SetupResponse(BaseModel):
|
||||
result: str = Field(description="Setup result", examples=["success"])
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/setup",
|
||||
response_model=SetupStatusResponse,
|
||||
tags=["console"],
|
||||
console_ns.schema_model(
|
||||
SetupRequestPayload.__name__,
|
||||
SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
def get_setup_status_api() -> SetupStatusResponse:
|
||||
"""Get system setup status."""
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
setup_status = get_setup_status()
|
||||
if setup_status and not isinstance(setup_status, bool):
|
||||
return SetupStatusResponse(step="finished", setup_at=setup_status.setup_at.isoformat())
|
||||
if setup_status:
|
||||
return SetupStatusResponse(step="finished")
|
||||
return SetupStatusResponse(step="not_started")
|
||||
return SetupStatusResponse(step="finished")
|
||||
|
||||
|
||||
@console_router.post(
|
||||
"/setup",
|
||||
response_model=SetupResponse,
|
||||
tags=["console"],
|
||||
status_code=201,
|
||||
)
|
||||
@only_edition_self_hosted
|
||||
def setup_system(payload: SetupRequestPayload) -> SetupResponse:
|
||||
"""Initialize system setup with admin account."""
|
||||
if get_setup_status():
|
||||
raise AlreadySetupError()
|
||||
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
if not get_init_validate_status():
|
||||
raise NotInitValidateError()
|
||||
|
||||
normalized_email = payload.email.lower()
|
||||
|
||||
RegisterService.setup(
|
||||
email=normalized_email,
|
||||
name=payload.name,
|
||||
password=payload.password,
|
||||
ip_address=extract_remote_ip(request),
|
||||
language=payload.language,
|
||||
@console_ns.route("/setup")
|
||||
class SetupApi(Resource):
|
||||
@console_ns.doc("get_setup_status")
|
||||
@console_ns.doc(description="Get system setup status")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"SetupStatusResponse",
|
||||
{
|
||||
"step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
|
||||
"setup_at": fields.String(description="Setup completion time (ISO format)", required=False),
|
||||
},
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Get system setup status"""
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
setup_status = get_setup_status()
|
||||
# Check if setup_status is a DifySetup object rather than a bool
|
||||
if setup_status and not isinstance(setup_status, bool):
|
||||
return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
|
||||
elif setup_status:
|
||||
return {"step": "finished"}
|
||||
return {"step": "not_started"}
|
||||
return {"step": "finished"}
|
||||
|
||||
return SetupResponse(result="success")
|
||||
@console_ns.doc("setup_system")
|
||||
@console_ns.doc(description="Initialize system setup with admin account")
|
||||
@console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
|
||||
@console_ns.response(
|
||||
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
|
||||
)
|
||||
@console_ns.response(400, "Already setup or validation failed")
|
||||
@only_edition_self_hosted
|
||||
def post(self):
|
||||
"""Initialize system setup with admin account"""
|
||||
# is set up
|
||||
if get_setup_status():
|
||||
raise AlreadySetupError()
|
||||
|
||||
# is tenant created
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
if not get_init_validate_status():
|
||||
raise NotInitValidateError()
|
||||
|
||||
args = SetupRequestPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
# setup
|
||||
RegisterService.setup(
|
||||
email=normalized_email,
|
||||
name=args.name,
|
||||
password=args.password,
|
||||
ip_address=extract_remote_ip(request),
|
||||
language=args.language,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
def get_setup_status() -> DifySetup | bool | None:
|
||||
def get_setup_status():
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
return db.session.query(DifySetup).first()
|
||||
|
||||
return True
|
||||
else:
|
||||
return True
|
||||
|
||||
@ -1 +0,0 @@
|
||||
|
||||
@ -1,108 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from flask import Request as FlaskRequest
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
|
||||
from services.account_service import AccountService
|
||||
from services.workflow_collaboration_service import WorkflowCollaborationService
|
||||
|
||||
repository = WorkflowCollaborationRepository()
|
||||
collaboration_service = WorkflowCollaborationService(repository, sio)
|
||||
|
||||
|
||||
def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
|
||||
return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))
|
||||
|
||||
|
||||
@_sio_on("connect")
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
try:
|
||||
request_environ = FlaskRequest(environ)
|
||||
token = extract_access_token(request_environ)
|
||||
except Exception:
|
||||
logging.exception("Failed to extract token")
|
||||
token = None
|
||||
|
||||
if not token:
|
||||
logging.warning("Socket connect rejected: missing token (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
if not user.has_edit_permission:
|
||||
logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
|
||||
collaboration_service.save_session(sid, user)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logging.exception("Socket authentication failed")
|
||||
return False
|
||||
|
||||
|
||||
@_sio_on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event. Each session (tab) is treated as an independent collaborator.
|
||||
"""
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
result = collaboration_service.register_session(workflow_id, sid)
|
||||
if not result:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
user_id, is_leader = result
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
|
||||
|
||||
|
||||
@_sio_on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle session disconnect event. Remove the specific session from online users.
|
||||
"""
|
||||
collaboration_service.disconnect_session(sid)
|
||||
|
||||
|
||||
@_sio_on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, include:
|
||||
1. mouse_move
|
||||
2. vars_and_features_update
|
||||
3. sync_request (ask leader to update graph)
|
||||
4. app_state_update
|
||||
5. mcp_server_update
|
||||
6. workflow_update
|
||||
7. comments_update
|
||||
8. node_panel_presence
|
||||
"""
|
||||
return collaboration_service.relay_collaboration_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("graph_event")
|
||||
def handle_graph_event(sid, data):
|
||||
"""
|
||||
Handle graph events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_graph_event(sid, data)
|
||||
@ -1,11 +1,15 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.fastopenapi import console_router
|
||||
|
||||
from . import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -14,60 +18,68 @@ class VersionQuery(BaseModel):
|
||||
current_version: str = Field(..., description="Current application version")
|
||||
|
||||
|
||||
class VersionFeatures(BaseModel):
|
||||
can_replace_logo: bool = Field(description="Whether logo replacement is supported")
|
||||
model_load_balancing_enabled: bool = Field(description="Whether model load balancing is enabled")
|
||||
|
||||
|
||||
class VersionResponse(BaseModel):
|
||||
version: str = Field(description="Latest version number")
|
||||
release_date: str = Field(description="Release date of latest version")
|
||||
release_notes: str = Field(description="Release notes for latest version")
|
||||
can_auto_update: bool = Field(description="Whether auto-update is supported")
|
||||
features: VersionFeatures = Field(description="Feature flags and capabilities")
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/version",
|
||||
response_model=VersionResponse,
|
||||
tags=["console"],
|
||||
console_ns.schema_model(
|
||||
VersionQuery.__name__,
|
||||
VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
def check_version_update(query: VersionQuery) -> VersionResponse:
|
||||
"""Check for application version updates."""
|
||||
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||
|
||||
result = VersionResponse(
|
||||
version=dify_config.project.version,
|
||||
release_date="",
|
||||
release_notes="",
|
||||
can_auto_update=False,
|
||||
features=VersionFeatures(
|
||||
can_replace_logo=dify_config.CAN_REPLACE_LOGO,
|
||||
model_load_balancing_enabled=dify_config.MODEL_LB_ENABLED,
|
||||
|
||||
@console_ns.route("/version")
|
||||
class VersionApi(Resource):
|
||||
@console_ns.doc("check_version_update")
|
||||
@console_ns.doc(description="Check for application version updates")
|
||||
@console_ns.expect(console_ns.models[VersionQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"VersionResponse",
|
||||
{
|
||||
"version": fields.String(description="Latest version number"),
|
||||
"release_date": fields.String(description="Release date of latest version"),
|
||||
"release_notes": fields.String(description="Release notes for latest version"),
|
||||
"can_auto_update": fields.Boolean(description="Whether auto-update is supported"),
|
||||
"features": fields.Raw(description="Feature flags and capabilities"),
|
||||
},
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Check for application version updates"""
|
||||
args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||
|
||||
if not check_update_url:
|
||||
return result
|
||||
result = {
|
||||
"version": dify_config.project.version,
|
||||
"release_date": "",
|
||||
"release_notes": "",
|
||||
"can_auto_update": False,
|
||||
"features": {
|
||||
"can_replace_logo": dify_config.CAN_REPLACE_LOGO,
|
||||
"model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED,
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
response = httpx.get(
|
||||
check_update_url,
|
||||
params={"current_version": query.current_version},
|
||||
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
|
||||
)
|
||||
content = response.json()
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
result.version = query.current_version
|
||||
if not check_update_url:
|
||||
return result
|
||||
|
||||
try:
|
||||
response = httpx.get(
|
||||
check_update_url,
|
||||
params={"current_version": args.current_version},
|
||||
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
|
||||
)
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
result["version"] = args.current_version
|
||||
return result
|
||||
|
||||
content = json.loads(response.content)
|
||||
if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
|
||||
result["version"] = content["version"]
|
||||
result["release_date"] = content["releaseDate"]
|
||||
result["release_notes"] = content["releaseNotes"]
|
||||
result["can_auto_update"] = content["canAutoUpdate"]
|
||||
return result
|
||||
latest_version = content.get("version", result.version)
|
||||
if _has_new_version(latest_version=latest_version, current_version=f"{query.current_version}"):
|
||||
result.version = latest_version
|
||||
result.release_date = content.get("releaseDate", "")
|
||||
result.release_notes = content.get("releaseNotes", "")
|
||||
result.can_auto_update = content.get("canAutoUpdate", False)
|
||||
return result
|
||||
|
||||
|
||||
def _has_new_version(*, latest_version: str, current_version: str) -> bool:
|
||||
|
||||
@ -36,7 +36,6 @@ from controllers.console.wraps import (
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -74,10 +73,6 @@ class AccountAvatarPayload(BaseModel):
|
||||
avatar: str
|
||||
|
||||
|
||||
class AccountAvatarQuery(BaseModel):
|
||||
avatar: str = Field(..., description="Avatar file ID")
|
||||
|
||||
|
||||
class AccountInterfaceLanguagePayload(BaseModel):
|
||||
interface_language: str
|
||||
|
||||
@ -163,7 +158,6 @@ def reg(cls: type[BaseModel]):
|
||||
reg(AccountInitPayload)
|
||||
reg(AccountNamePayload)
|
||||
reg(AccountAvatarPayload)
|
||||
reg(AccountAvatarQuery)
|
||||
reg(AccountInterfaceLanguagePayload)
|
||||
reg(AccountInterfaceThemePayload)
|
||||
reg(AccountTimezonePayload)
|
||||
@ -254,18 +248,6 @@ class AccountNameApi(Resource):
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
|
||||
@console_ns.doc("get_account_avatar")
|
||||
@console_ns.doc(description="Get account avatar url")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(args.avatar)
|
||||
return {"avatar_url": avatar_url}
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -1,3 +0,0 @@
|
||||
from fastopenapi.routers import FlaskRouter
|
||||
|
||||
console_router = FlaskRouter()
|
||||
@ -14,28 +14,15 @@ api = ExternalApi(
|
||||
|
||||
files_ns = Namespace("files", description="File operations", path="/")
|
||||
|
||||
from . import (
|
||||
app_assets_download,
|
||||
app_assets_upload,
|
||||
image_preview,
|
||||
sandbox_archive,
|
||||
sandbox_file_downloads,
|
||||
storage_download,
|
||||
tool_files,
|
||||
upload,
|
||||
)
|
||||
from . import image_preview, storage_download, tool_files, upload
|
||||
|
||||
api.add_namespace(files_ns)
|
||||
|
||||
__all__ = [
|
||||
"api",
|
||||
"app_assets_download",
|
||||
"app_assets_upload",
|
||||
"bp",
|
||||
"files_ns",
|
||||
"image_preview",
|
||||
"sandbox_archive",
|
||||
"sandbox_file_downloads",
|
||||
"storage_download",
|
||||
"tool_files",
|
||||
"upload",
|
||||
|
||||
@ -1,77 +0,0 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.files import files_ns
|
||||
from core.app_assets.storage import AppAssetSigner, AssetPath, app_asset_storage
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppAssetDownloadQuery(BaseModel):
|
||||
expires_at: int = Field(..., description="Unix timestamp when the link expires")
|
||||
nonce: str = Field(..., description="Random string for signature")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
AppAssetDownloadQuery.__name__,
|
||||
AppAssetDownloadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/app-assets/<string:asset_type>/<string:tenant_id>/<string:app_id>/<string:resource_id>/download")
|
||||
@files_ns.route(
|
||||
"/app-assets/<string:asset_type>/<string:tenant_id>/<string:app_id>/<string:resource_id>/<string:sub_resource_id>/download"
|
||||
)
|
||||
class AppAssetDownloadApi(Resource):
|
||||
def get(
|
||||
self,
|
||||
asset_type: str,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
resource_id: str,
|
||||
sub_resource_id: str | None = None,
|
||||
):
|
||||
args = AppAssetDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
asset_path = AssetPath.from_components(
|
||||
asset_type=asset_type,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
resource_id=resource_id,
|
||||
sub_resource_id=sub_resource_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise Forbidden(str(exc)) from exc
|
||||
|
||||
if not AppAssetSigner.verify_download_signature(
|
||||
asset_path=asset_path,
|
||||
expires_at=args.expires_at,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired download link")
|
||||
|
||||
storage_key = app_asset_storage.get_storage_key(asset_path)
|
||||
|
||||
try:
|
||||
generator = storage.load_stream(storage_key)
|
||||
except FileNotFoundError as exc:
|
||||
raise NotFound("File not found") from exc
|
||||
|
||||
encoded_filename = quote(storage_key.split("/")[-1])
|
||||
|
||||
return Response(
|
||||
generator,
|
||||
mimetype="application/octet-stream",
|
||||
direct_passthrough=True,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
},
|
||||
)
|
||||
@ -1,60 +0,0 @@
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.files import files_ns
|
||||
from core.app_assets.storage import AppAssetSigner, AssetPath, app_asset_storage
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppAssetUploadQuery(BaseModel):
|
||||
expires_at: int = Field(..., description="Unix timestamp when the link expires")
|
||||
nonce: str = Field(..., description="Random string for signature")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
AppAssetUploadQuery.__name__,
|
||||
AppAssetUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/app-assets/<string:asset_type>/<string:tenant_id>/<string:app_id>/<string:resource_id>/upload")
|
||||
@files_ns.route(
|
||||
"/app-assets/<string:asset_type>/<string:tenant_id>/<string:app_id>/<string:resource_id>/<string:sub_resource_id>/upload"
|
||||
)
|
||||
class AppAssetUploadApi(Resource):
|
||||
def put(
|
||||
self,
|
||||
asset_type: str,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
resource_id: str,
|
||||
sub_resource_id: str | None = None,
|
||||
):
|
||||
args = AppAssetUploadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
asset_path = AssetPath.from_components(
|
||||
asset_type=asset_type,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
resource_id=resource_id,
|
||||
sub_resource_id=sub_resource_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise Forbidden(str(exc)) from exc
|
||||
|
||||
if not AppAssetSigner.verify_upload_signature(
|
||||
asset_path=asset_path,
|
||||
expires_at=args.expires_at,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired upload link")
|
||||
|
||||
content = request.get_data()
|
||||
app_asset_storage.save(asset_path, content)
|
||||
return Response(status=204)
|
||||
@ -1,76 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.files import files_ns
|
||||
from core.sandbox.security.archive_signer import SandboxArchivePath, SandboxArchiveSigner
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SandboxArchiveQuery(BaseModel):
|
||||
expires_at: int = Field(..., description="Unix timestamp when the link expires")
|
||||
nonce: str = Field(..., description="Random string for signature")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
SandboxArchiveQuery.__name__,
|
||||
SandboxArchiveQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/sandbox-archives/<string:tenant_id>/<string:sandbox_id>/download")
|
||||
class SandboxArchiveDownloadApi(Resource):
|
||||
def get(self, tenant_id: str, sandbox_id: str):
|
||||
args = SandboxArchiveQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
archive_path = SandboxArchivePath(tenant_id=UUID(tenant_id), sandbox_id=UUID(sandbox_id))
|
||||
except ValueError as exc:
|
||||
raise Forbidden(str(exc)) from exc
|
||||
|
||||
if not SandboxArchiveSigner.verify_download_signature(
|
||||
archive_path=archive_path,
|
||||
expires_at=args.expires_at,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired download link")
|
||||
|
||||
try:
|
||||
generator = storage.load_stream(archive_path.get_storage_key())
|
||||
except FileNotFoundError as exc:
|
||||
raise NotFound("Archive not found") from exc
|
||||
|
||||
return Response(
|
||||
generator,
|
||||
mimetype="application/gzip",
|
||||
direct_passthrough=True,
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/sandbox-archives/<string:tenant_id>/<string:sandbox_id>/upload")
|
||||
class SandboxArchiveUploadApi(Resource):
|
||||
def put(self, tenant_id: str, sandbox_id: str):
|
||||
args = SandboxArchiveQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
archive_path = SandboxArchivePath(tenant_id=UUID(tenant_id), sandbox_id=UUID(sandbox_id))
|
||||
except ValueError as exc:
|
||||
raise Forbidden(str(exc)) from exc
|
||||
|
||||
if not SandboxArchiveSigner.verify_upload_signature(
|
||||
archive_path=archive_path,
|
||||
expires_at=args.expires_at,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired upload link")
|
||||
|
||||
storage.save(archive_path.get_storage_key(), request.get_data())
|
||||
return Response(status=204)
|
||||
@ -1,96 +0,0 @@
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.files import files_ns
|
||||
from core.sandbox.security.sandbox_file_signer import SandboxFileDownloadPath, SandboxFileSigner
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SandboxFileDownloadQuery(BaseModel):
|
||||
expires_at: int = Field(..., description="Unix timestamp when the link expires")
|
||||
nonce: str = Field(..., description="Random string for signature")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
SandboxFileDownloadQuery.__name__,
|
||||
SandboxFileDownloadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route(
|
||||
"/sandbox-file-downloads/<string:tenant_id>/<string:sandbox_id>/<string:export_id>/<path:filename>/download"
|
||||
)
|
||||
class SandboxFileDownloadDownloadApi(Resource):
|
||||
def get(self, tenant_id: str, sandbox_id: str, export_id: str, filename: str):
|
||||
args = SandboxFileDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
export_path = SandboxFileDownloadPath(
|
||||
tenant_id=UUID(tenant_id),
|
||||
sandbox_id=UUID(sandbox_id),
|
||||
export_id=export_id,
|
||||
filename=filename,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise Forbidden(str(exc)) from exc
|
||||
|
||||
if not SandboxFileSigner.verify_download_signature(
|
||||
export_path=export_path,
|
||||
expires_at=args.expires_at,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired download link")
|
||||
|
||||
try:
|
||||
generator = storage.load_stream(export_path.get_storage_key())
|
||||
except FileNotFoundError as exc:
|
||||
raise NotFound("File not found") from exc
|
||||
|
||||
encoded_filename = quote(filename.split("/")[-1])
|
||||
|
||||
return Response(
|
||||
generator,
|
||||
mimetype="application/octet-stream",
|
||||
direct_passthrough=True,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route(
|
||||
"/sandbox-file-downloads/<string:tenant_id>/<string:sandbox_id>/<string:export_id>/<path:filename>/upload"
|
||||
)
|
||||
class SandboxFileDownloadUploadApi(Resource):
|
||||
def put(self, tenant_id: str, sandbox_id: str, export_id: str, filename: str):
|
||||
args = SandboxFileDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
export_path = SandboxFileDownloadPath(
|
||||
tenant_id=UUID(tenant_id),
|
||||
sandbox_id=UUID(sandbox_id),
|
||||
export_id=export_id,
|
||||
filename=filename,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise Forbidden(str(exc)) from exc
|
||||
|
||||
if not SandboxFileSigner.verify_upload_signature(
|
||||
export_path=export_path,
|
||||
expires_at=args.expires_at,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired upload link")
|
||||
|
||||
storage.save(export_path.get_storage_key(), request.get_data())
|
||||
return Response(status=204)
|
||||
@ -261,6 +261,17 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by upload file."""
|
||||
args = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
@ -269,18 +280,6 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
args = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = dataset.chunk_structure or "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
|
||||
if not indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
@ -371,6 +370,17 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by upload file."""
|
||||
args = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
@ -379,18 +389,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
args = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = dataset.chunk_structure or "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
|
||||
@ -11,9 +11,7 @@ from controllers.service_api.wraps import DatasetApiResource, cloud_edition_bill
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
MetadataArgs,
|
||||
MetadataDetail,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
@ -24,13 +22,7 @@ class MetadataUpdatePayload(BaseModel):
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, MetadataUpdatePayload)
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
MetadataArgs,
|
||||
MetadataDetail,
|
||||
DocumentMetadataOperation,
|
||||
MetadataOperationData,
|
||||
)
|
||||
register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
|
||||
@ -17,15 +17,5 @@ class SystemFeatureApi(Resource):
|
||||
|
||||
Returns:
|
||||
dict: System feature configuration object
|
||||
|
||||
This endpoint is akin to the `SystemFeatureApi` endpoint in api/controllers/console/feature.py,
|
||||
except it is intended for use by the web app, instead of the console dashboard.
|
||||
|
||||
NOTE: This endpoint is unauthenticated by design, as it provides system features
|
||||
data required for webapp initialization.
|
||||
|
||||
Authentication would create circular dependency (can't authenticate without webapp loading).
|
||||
|
||||
Only non-sensitive configuration data should be returned by this endpoint.
|
||||
"""
|
||||
return FeatureService.get_system_features().model_dump()
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from flask_restx import Resource, reqparse
|
||||
from jwt import InvalidTokenError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
@ -20,7 +18,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import decode_jwt_token
|
||||
from libs.helper import EmailStr
|
||||
from libs.helper import email
|
||||
from libs.passport import PassportService
|
||||
from libs.password import valid_password
|
||||
from libs.token import (
|
||||
@ -32,35 +30,10 @@ from services.app_service import AppService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
class EmailCodeLoginSendPayload(BaseModel):
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
|
||||
|
||||
class EmailCodeLoginVerifyPayload(BaseModel):
|
||||
email: EmailStr
|
||||
code: str
|
||||
token: str = Field(min_length=1)
|
||||
|
||||
|
||||
register_schema_models(web_ns, LoginPayload, EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload)
|
||||
|
||||
|
||||
@web_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
"""Resource for web app email/password login."""
|
||||
|
||||
@web_ns.expect(web_ns.models[LoginPayload.__name__])
|
||||
@setup_required
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("web_app_login")
|
||||
@ -77,10 +50,15 @@ class LoginApi(Resource):
|
||||
@decrypt_password_field
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
payload = LoginPayload.model_validate(web_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("password", type=valid_password, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
account = WebAppAuthService.authenticate(payload.email, payload.password)
|
||||
account = WebAppAuthService.authenticate(args["email"], args["password"])
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
@ -167,7 +145,6 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("send_email_code_login")
|
||||
@web_ns.doc(description="Send email verification code for login")
|
||||
@web_ns.expect(web_ns.models[EmailCodeLoginSendPayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Email code sent successfully",
|
||||
@ -176,14 +153,19 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
payload = EmailCodeLoginSendPayload.model_validate(web_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if payload.language == "zh-Hans":
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = WebAppAuthService.get_user_through_email(payload.email)
|
||||
account = WebAppAuthService.get_user_through_email(args["email"])
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
@ -197,7 +179,6 @@ class EmailCodeLoginApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("verify_email_code_login")
|
||||
@web_ns.doc(description="Verify email code and complete login")
|
||||
@web_ns.expect(web_ns.models[EmailCodeLoginVerifyPayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Email code verified and login successful",
|
||||
@ -208,11 +189,17 @@ class EmailCodeLoginApi(Resource):
|
||||
)
|
||||
@decrypt_code_field
|
||||
def post(self):
|
||||
payload = EmailCodeLoginVerifyPayload.model_validate(web_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = payload.email.lower()
|
||||
user_email = args["email"].lower()
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
|
||||
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
@ -223,10 +210,10 @@ class EmailCodeLoginApi(Resource):
|
||||
if normalized_token_email != user_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != payload.code:
|
||||
if token_data["code"] != args["code"]:
|
||||
raise EmailCodeError()
|
||||
|
||||
WebAppAuthService.revoke_email_code_login_token(payload.token)
|
||||
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
if not account:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
@ -1,10 +1,8 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
CompletionRequestError,
|
||||
@ -29,22 +27,19 @@ from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
|
||||
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(web_ns, WorkflowRunPayload)
|
||||
|
||||
|
||||
@web_ns.route("/workflows/run")
|
||||
class WorkflowRunApi(WebApiResource):
|
||||
@web_ns.doc("Run Workflow")
|
||||
@web_ns.doc(description="Execute a workflow with provided inputs and files.")
|
||||
@web_ns.expect(web_ns.models[WorkflowRunPayload.__name__])
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the workflow", "type": "object", "required": True},
|
||||
"files": {"description": "Files to be processed by the workflow", "type": "array", "required": False},
|
||||
}
|
||||
)
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
@ -63,8 +58,12 @@ class WorkflowRunApi(WebApiResource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
payload = WorkflowRunPayload.model_validate(web_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -15,9 +13,6 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from constants import UUID_NIL
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from controllers.console.app.workflow import LoopNodeRunPayload
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
@ -313,7 +308,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: LoopNodeRunPayload,
|
||||
args: Mapping,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
"""
|
||||
@ -329,7 +324,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.inputs is None:
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
@ -347,7 +342,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs),
|
||||
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
@ -186,9 +186,6 @@ class StreamEventBuffer:
|
||||
idx = self._tool_call_id_map[tool_call_id]
|
||||
self.tool_calls[idx]["result"] = result
|
||||
self.tool_calls[idx]["elapsed_time"] = tool_elapsed_time
|
||||
# Remove from map after result is recorded, so that subsequent calls
|
||||
# with the same tool_call_id are treated as new tool calls
|
||||
del self._tool_call_id_map[tool_call_id]
|
||||
|
||||
def finalize(self) -> None:
|
||||
"""Finalize the buffer, flushing any pending data."""
|
||||
|
||||
@ -9,13 +9,13 @@ from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
RagPipelineGenerateEntity,
|
||||
)
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -46,9 +44,6 @@ from models.workflow_features import WorkflowFeatures
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from controllers.console.app.workflow import LoopNodeRunPayload
|
||||
|
||||
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -390,7 +385,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: LoopNodeRunPayload,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
@ -406,7 +401,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.inputs is None:
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
@ -422,7 +417,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs or {}),
|
||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
|
||||
@ -25,7 +25,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
@ -54,6 +53,7 @@ from core.workflow.graph_events import (
|
||||
)
|
||||
from core.workflow.graph_events.graph import GraphRunAbortedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
@ -166,22 +166,18 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
# Determine which type of single node execution and get graph/variable_pool
|
||||
if single_iteration_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_inputs=dict(single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
elif single_loop_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=single_loop_run.node_id,
|
||||
user_inputs=dict(single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
|
||||
@ -318,6 +314,44 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_loop(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
"""
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle event
|
||||
|
||||
@ -20,13 +20,16 @@ class AppAssetNode(BaseModel):
|
||||
order: int = Field(default=0, description="Sort order within parent folder, lower values first")
|
||||
extension: str = Field(default="", description="File extension without dot, empty for folders")
|
||||
size: int = Field(default=0, description="File size in bytes, 0 for folders")
|
||||
checksum: str = Field(default="", description="SHA-256 checksum of file content, empty for folders")
|
||||
|
||||
@classmethod
|
||||
def create_folder(cls, node_id: str, name: str, parent_id: str | None = None) -> AppAssetNode:
|
||||
return cls(id=node_id, node_type=AssetNodeType.FOLDER, name=name, parent_id=parent_id)
|
||||
|
||||
@classmethod
|
||||
def create_file(cls, node_id: str, name: str, parent_id: str | None = None, size: int = 0) -> AppAssetNode:
|
||||
def create_file(
|
||||
cls, node_id: str, name: str, parent_id: str | None = None, size: int = 0, checksum: str = ""
|
||||
) -> AppAssetNode:
|
||||
return cls(
|
||||
id=node_id,
|
||||
node_type=AssetNodeType.FILE,
|
||||
@ -34,6 +37,7 @@ class AppAssetNode(BaseModel):
|
||||
parent_id=parent_id,
|
||||
extension=name.rsplit(".", 1)[-1] if "." in name else "",
|
||||
size=size,
|
||||
checksum=checksum,
|
||||
)
|
||||
|
||||
|
||||
@ -44,39 +48,10 @@ class AppAssetNodeView(BaseModel):
|
||||
path: str = Field(description="Full path from root, e.g. '/folder/file.txt'")
|
||||
extension: str = Field(default="", description="File extension without dot")
|
||||
size: int = Field(default=0, description="File size in bytes")
|
||||
checksum: str = Field(default="", description="SHA-256 checksum of file content")
|
||||
children: list[AppAssetNodeView] = Field(default_factory=list, description="Child nodes for folders")
|
||||
|
||||
|
||||
class BatchUploadNode(BaseModel):
|
||||
"""Structure for batch upload_url tree nodes, used for both input and output."""
|
||||
|
||||
name: str
|
||||
node_type: AssetNodeType
|
||||
size: int = 0
|
||||
children: list[BatchUploadNode] = []
|
||||
id: str | None = None
|
||||
upload_url: str | None = None
|
||||
|
||||
def to_app_asset_nodes(self, parent_id: str | None = None) -> list[AppAssetNode]:
|
||||
"""
|
||||
Generate IDs and convert to AppAssetNode list.
|
||||
Mutates self to set id field.
|
||||
"""
|
||||
from uuid import uuid4
|
||||
|
||||
self.id = str(uuid4())
|
||||
nodes: list[AppAssetNode] = []
|
||||
|
||||
if self.node_type == AssetNodeType.FOLDER:
|
||||
nodes.append(AppAssetNode.create_folder(self.id, self.name, parent_id))
|
||||
for child in self.children:
|
||||
nodes.extend(child.to_app_asset_nodes(self.id))
|
||||
else:
|
||||
nodes.append(AppAssetNode.create_file(self.id, self.name, parent_id, self.size))
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
class TreeNodeNotFoundError(Exception):
|
||||
"""Tree internal: node not found"""
|
||||
|
||||
@ -217,11 +192,12 @@ class AppAssetFileTree(BaseModel):
|
||||
self.nodes.append(node)
|
||||
return node
|
||||
|
||||
def update(self, node_id: str, size: int) -> AppAssetNode:
|
||||
def update(self, node_id: str, size: int, checksum: str) -> AppAssetNode:
|
||||
node = self.get(node_id)
|
||||
if not node or node.node_type != AssetNodeType.FILE:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
node.size = size
|
||||
node.checksum = checksum
|
||||
return node
|
||||
|
||||
def rename(self, node_id: str, new_name: str) -> AppAssetNode:
|
||||
@ -308,6 +284,7 @@ class AppAssetFileTree(BaseModel):
|
||||
path=path,
|
||||
extension=node.extension,
|
||||
size=node.size,
|
||||
checksum=node.checksum,
|
||||
children=child_views,
|
||||
)
|
||||
|
||||
|
||||
@ -1,42 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Constants
|
||||
BUNDLE_DSL_FILENAME_PATTERN = re.compile(r"^[^/]+\.ya?ml$")
|
||||
BUNDLE_MAX_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
|
||||
# Exceptions
|
||||
class BundleFormatError(Exception):
|
||||
"""Raised when bundle format is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ZipSecurityError(Exception):
|
||||
"""Raised when zip file contains security violations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Entities
|
||||
class BundleExportResult(BaseModel):
|
||||
zip_bytes: bytes = Field(description="ZIP file content as bytes")
|
||||
filename: str = Field(description="Suggested filename for the ZIP")
|
||||
|
||||
|
||||
class SourceFileEntry(BaseModel):
|
||||
path: str = Field(description="File path within the ZIP")
|
||||
node_id: str = Field(description="Node ID in the asset tree")
|
||||
|
||||
|
||||
class ExtractedFile(BaseModel):
|
||||
path: str = Field(description="Relative path of the extracted file")
|
||||
content: bytes = Field(description="File content as bytes")
|
||||
|
||||
|
||||
class ExtractedFolder(BaseModel):
|
||||
path: str = Field(description="Relative path of the extracted folder")
|
||||
@ -1,3 +0,0 @@
|
||||
from .node_factory import DifyNodeFactory
|
||||
|
||||
__all__ = ["DifyNodeFactory"]
|
||||
@ -4,8 +4,9 @@ from .entities import (
|
||||
FileAsset,
|
||||
SkillAsset,
|
||||
)
|
||||
from .packager import AssetPackager, AssetZipPackager
|
||||
from .packager import AssetPackager, ZipPackager
|
||||
from .parser import AssetItemParser, AssetParser, FileAssetParser, SkillAssetParser
|
||||
from .paths import AssetPaths
|
||||
|
||||
__all__ = [
|
||||
"AppAssetsAttrs",
|
||||
@ -13,9 +14,10 @@ __all__ = [
|
||||
"AssetItemParser",
|
||||
"AssetPackager",
|
||||
"AssetParser",
|
||||
"AssetZipPackager",
|
||||
"AssetPaths",
|
||||
"FileAsset",
|
||||
"FileAssetParser",
|
||||
"SkillAsset",
|
||||
"SkillAssetParser",
|
||||
"ZipPackager",
|
||||
]
|
||||
|
||||
@ -1,17 +1,15 @@
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
|
||||
from core.app_assets.entities import AssetItem, FileAsset
|
||||
from core.app_assets.storage import AppAssetStorage, AssetPath
|
||||
from core.app_assets.paths import AssetPaths
|
||||
|
||||
from .base import BuildContext
|
||||
|
||||
|
||||
class FileBuilder:
|
||||
_nodes: list[tuple[AppAssetNode, str]]
|
||||
_storage: AppAssetStorage
|
||||
|
||||
def __init__(self, storage: AppAssetStorage) -> None:
|
||||
def __init__(self) -> None:
|
||||
self._nodes = []
|
||||
self._storage = storage
|
||||
|
||||
def accept(self, node: AppAssetNode) -> bool:
|
||||
return True
|
||||
@ -26,7 +24,7 @@ class FileBuilder:
|
||||
path=path,
|
||||
file_name=node.name,
|
||||
extension=node.extension or "",
|
||||
storage_key=self._storage.get_storage_key(AssetPath.draft(ctx.tenant_id, ctx.app_id, node.id)),
|
||||
storage_key=AssetPaths.draft_file(ctx.tenant_id, ctx.app_id, node.id),
|
||||
)
|
||||
for node, path in self._nodes
|
||||
]
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
from core.app_assets.builder.file_builder import FileBuilder
|
||||
from core.app_assets.builder.skill_builder import SkillBuilder
|
||||
from core.app_assets.entities import AssetItem
|
||||
|
||||
from .base import AssetBuilder, BuildContext
|
||||
@ -7,8 +9,8 @@ from .base import AssetBuilder, BuildContext
|
||||
class AssetBuildPipeline:
|
||||
_builders: list[AssetBuilder]
|
||||
|
||||
def __init__(self, builders: list[AssetBuilder]) -> None:
|
||||
self._builders = builders
|
||||
def __init__(self, builders: list[AssetBuilder] | None = None) -> None:
|
||||
self._builders = builders or [SkillBuilder(), FileBuilder()]
|
||||
|
||||
def build_all(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]:
|
||||
# 1. Distribute: each node goes to first accepting builder
|
||||
|
||||
@ -1,15 +1,14 @@
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
|
||||
from core.app_assets.entities import AssetItem, FileAsset
|
||||
from core.app_assets.storage import AppAssetStorage, AssetPath, AssetPathBase
|
||||
from core.skill.entities.skill_bundle import SkillBundle
|
||||
from core.app_assets.paths import AssetPaths
|
||||
from core.skill.entities.skill_document import SkillDocument
|
||||
from core.skill.skill_compiler import SkillCompiler
|
||||
from core.skill.skill_manager import SkillManager
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
from .base import BuildContext
|
||||
|
||||
@ -19,27 +18,24 @@ class _LoadedSkill:
|
||||
node: AppAssetNode
|
||||
path: str
|
||||
content: str
|
||||
metadata: dict[str, Any]
|
||||
metadata: dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CompiledSkill:
|
||||
node: AppAssetNode
|
||||
path: str
|
||||
ref: AssetPathBase
|
||||
storage_key: str
|
||||
resolved_key: str
|
||||
content_bytes: bytes
|
||||
|
||||
|
||||
class SkillBuilder:
|
||||
_nodes: list[tuple[AppAssetNode, str]]
|
||||
_max_workers: int
|
||||
_storage: AppAssetStorage
|
||||
|
||||
def __init__(self, storage: AppAssetStorage, max_workers: int = 8) -> None:
|
||||
def __init__(self, max_workers: int = 8) -> None:
|
||||
self._nodes = []
|
||||
self._max_workers = max_workers
|
||||
self._storage = storage
|
||||
|
||||
def accept(self, node: AppAssetNode) -> bool:
|
||||
return node.extension == "md"
|
||||
@ -49,8 +45,6 @@ class SkillBuilder:
|
||||
|
||||
def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]:
|
||||
if not self._nodes:
|
||||
bundle = SkillBundle(assets_id=ctx.build_id)
|
||||
SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle)
|
||||
return []
|
||||
|
||||
# 1. Load all skills (parallel IO)
|
||||
@ -68,13 +62,12 @@ class SkillBuilder:
|
||||
artifact = artifact_set.get(skill.node.id)
|
||||
if artifact is None:
|
||||
continue
|
||||
resolved_ref = AssetPath.resolved(ctx.tenant_id, ctx.app_id, ctx.build_id, skill.node.id)
|
||||
resolved_key = AssetPaths.build_resolved_file(ctx.tenant_id, ctx.app_id, ctx.build_id, skill.node.id)
|
||||
to_upload.append(
|
||||
_CompiledSkill(
|
||||
node=skill.node,
|
||||
path=skill.path,
|
||||
ref=resolved_ref,
|
||||
storage_key=self._storage.get_storage_key(resolved_ref),
|
||||
resolved_key=resolved_key,
|
||||
content_bytes=artifact.content.encode("utf-8"),
|
||||
)
|
||||
)
|
||||
@ -89,26 +82,19 @@ class SkillBuilder:
|
||||
path=s.path,
|
||||
file_name=s.node.name,
|
||||
extension=s.node.extension or "",
|
||||
storage_key=s.storage_key,
|
||||
storage_key=s.resolved_key,
|
||||
)
|
||||
for s in to_upload
|
||||
]
|
||||
|
||||
def _load_all(self, ctx: BuildContext) -> list[_LoadedSkill]:
|
||||
def load_one(node: AppAssetNode, path: str) -> _LoadedSkill:
|
||||
draft_key = AssetPaths.draft_file(ctx.tenant_id, ctx.app_id, node.id)
|
||||
try:
|
||||
draft_ref = AssetPath.draft(ctx.tenant_id, ctx.app_id, node.id)
|
||||
data = json.loads(self._storage.load(draft_ref))
|
||||
content = ""
|
||||
metadata: dict[str, Any] = {}
|
||||
if isinstance(data, dict):
|
||||
data_dict = cast(dict[str, Any], data)
|
||||
content_value = data_dict.get("content", "")
|
||||
content = content_value if isinstance(content_value, str) else str(content_value)
|
||||
metadata_value = data_dict.get("metadata", {})
|
||||
if isinstance(metadata_value, dict):
|
||||
metadata = cast(dict[str, Any], metadata_value)
|
||||
except (FileNotFoundError, json.JSONDecodeError, TypeError, ValueError):
|
||||
data = json.loads(storage.load_once(draft_key))
|
||||
content = data.get("content", "") if isinstance(data, dict) else ""
|
||||
metadata = data.get("metadata", {}) if isinstance(data, dict) else {}
|
||||
except Exception:
|
||||
content = ""
|
||||
metadata = {}
|
||||
return _LoadedSkill(node=node, path=path, content=content, metadata=metadata)
|
||||
@ -119,7 +105,7 @@ class SkillBuilder:
|
||||
|
||||
def _upload_all(self, skills: list[_CompiledSkill]) -> None:
|
||||
def upload_one(skill: _CompiledSkill) -> None:
|
||||
self._storage.save(skill.ref, skill.content_bytes)
|
||||
storage.save(skill.resolved_key, skill.content_bytes)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
|
||||
futures = [executor.submit(upload_one, skill) for skill in skills]
|
||||
|
||||
@ -5,4 +5,3 @@ from libs.attr_map import AttrKey
|
||||
class AppAssetsAttrs:
|
||||
# Skill artifact set
|
||||
FILE_TREE = AttrKey("file_tree", AppAssetFileTree)
|
||||
APP_ASSETS_ID = AttrKey("app_assets_id", str)
|
||||
|
||||
@ -1,42 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree, AssetNodeType
|
||||
from core.app_assets.entities import FileAsset
|
||||
from core.app_assets.entities.assets import AssetItem
|
||||
from core.app_assets.storage import AppAssetStorage, AssetPath
|
||||
|
||||
|
||||
def tree_to_asset_items(
|
||||
tree: AppAssetFileTree,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
storage: AppAssetStorage,
|
||||
) -> list[AssetItem]:
|
||||
"""
|
||||
Convert AppAssetFileTree to list of FileAsset for packaging.
|
||||
|
||||
Args:
|
||||
tree: The asset file tree to convert
|
||||
tenant_id: Tenant ID for storage key generation
|
||||
app_id: App ID for storage key generation
|
||||
storage: App asset storage for key mapping
|
||||
|
||||
Returns:
|
||||
List of FileAsset items ready for packaging
|
||||
"""
|
||||
items: list[AssetItem] = []
|
||||
for node in tree.nodes:
|
||||
if node.node_type == AssetNodeType.FILE:
|
||||
path = tree.get_path(node.id)
|
||||
asset_path = AssetPath.draft(tenant_id, app_id, node.id)
|
||||
storage_key = storage.get_storage_key(asset_path)
|
||||
items.append(
|
||||
FileAsset(
|
||||
asset_id=node.id,
|
||||
path=path,
|
||||
file_name=node.name,
|
||||
extension=node.extension or "",
|
||||
storage_key=storage_key,
|
||||
)
|
||||
)
|
||||
return items
|
||||
@ -1,7 +1,7 @@
|
||||
from .asset_zip_packager import AssetZipPackager
|
||||
from .base import AssetPackager
|
||||
from .zip_packager import ZipPackager
|
||||
|
||||
__all__ = [
|
||||
"AssetPackager",
|
||||
"AssetZipPackager",
|
||||
"ZipPackager",
|
||||
]
|
||||
|
||||
@ -1,75 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import zipfile
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from threading import Lock
|
||||
|
||||
from core.app_assets.entities import AssetItem
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
|
||||
class AssetZipPackager:
|
||||
"""
|
||||
Unified ZIP packager for assets.
|
||||
Automatically creates directory entries from asset paths.
|
||||
"""
|
||||
|
||||
def __init__(self, storage: BaseStorage, *, max_workers: int = 8) -> None:
|
||||
self._storage = storage
|
||||
self._max_workers = max_workers
|
||||
|
||||
def package(self, assets: list[AssetItem], *, prefix: str = "") -> bytes:
|
||||
"""
|
||||
Package assets into a ZIP file.
|
||||
|
||||
Args:
|
||||
assets: List of assets to package
|
||||
prefix: Optional prefix to add to all paths in the ZIP
|
||||
|
||||
Returns:
|
||||
ZIP file content as bytes
|
||||
"""
|
||||
zip_buffer = io.BytesIO()
|
||||
|
||||
# Extract folder paths from asset paths
|
||||
folder_paths = self._extract_folder_paths(assets, prefix)
|
||||
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
# Create directory entries
|
||||
for folder_path in sorted(folder_paths):
|
||||
zf.writestr(zipfile.ZipInfo(folder_path + "/"), "")
|
||||
|
||||
# Write files in parallel
|
||||
if assets:
|
||||
self._write_files_parallel(zf, assets, prefix)
|
||||
|
||||
return zip_buffer.getvalue()
|
||||
|
||||
def _extract_folder_paths(self, assets: list[AssetItem], prefix: str) -> set[str]:
|
||||
"""Extract all folder paths from asset paths."""
|
||||
folders: set[str] = set()
|
||||
for asset in assets:
|
||||
full_path = f"{prefix}/{asset.path}" if prefix else asset.path
|
||||
parts = full_path.split("/")[:-1] # Remove filename
|
||||
folders.update("/".join(parts[:i]) for i in range(1, len(parts) + 1))
|
||||
return folders
|
||||
|
||||
def _write_files_parallel(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
assets: list[AssetItem],
|
||||
prefix: str,
|
||||
) -> None:
|
||||
lock = Lock()
|
||||
|
||||
def load_and_write(asset: AssetItem) -> None:
|
||||
content = self._storage.load_once(asset.get_storage_key())
|
||||
full_path = f"{prefix}/{asset.path}" if prefix else asset.path
|
||||
with lock:
|
||||
zf.writestr(full_path, content)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
|
||||
futures = [executor.submit(load_and_write, a) for a in assets]
|
||||
for future in futures:
|
||||
future.result()
|
||||
42
api/core/app_assets/packager/zip_packager.py
Normal file
42
api/core/app_assets/packager/zip_packager.py
Normal file
@ -0,0 +1,42 @@
|
||||
import io
|
||||
import zipfile
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.app_assets.entities import AssetItem
|
||||
|
||||
from .base import AssetPackager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_storage import Storage
|
||||
|
||||
|
||||
class ZipPackager(AssetPackager):
|
||||
_storage: "Storage"
|
||||
|
||||
def __init__(self, storage: "Storage") -> None:
|
||||
self._storage = storage
|
||||
|
||||
def package(self, assets: list[AssetItem]) -> bytes:
|
||||
zip_buffer = io.BytesIO()
|
||||
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
lock = Lock()
|
||||
# FOR DELVELPMENT AND TESTING ONLY, TODO: optimize
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
futures: list[Future[None]] = []
|
||||
for asset in assets:
|
||||
|
||||
def _write_asset(a: AssetItem) -> None:
|
||||
content = self._storage.load_once(a.get_storage_key())
|
||||
with lock:
|
||||
zf.writestr(a.path, content)
|
||||
|
||||
futures.append(executor.submit(_write_asset, asset))
|
||||
|
||||
# Wait for all futures to complete
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
return zip_buffer.getvalue()
|
||||
@ -1,6 +1,6 @@
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
from core.app_assets.entities import AssetItem
|
||||
from core.app_assets.storage import AssetPath, app_asset_storage
|
||||
from core.app_assets.paths import AssetPaths
|
||||
|
||||
from .base import AssetItemParser, FileAssetParser
|
||||
|
||||
@ -15,7 +15,7 @@ class AssetParser:
|
||||
self._tree = tree
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._parsers: dict[str, AssetItemParser] = {}
|
||||
self._parsers = {}
|
||||
self._default_parser = FileAssetParser()
|
||||
|
||||
def register(self, extension: str, parser: AssetItemParser) -> None:
|
||||
@ -26,10 +26,10 @@ class AssetParser:
|
||||
|
||||
for node in self._tree.walk_files():
|
||||
path = self._tree.get_path(node.id).lstrip("/")
|
||||
storage_key = app_asset_storage.get_storage_key(AssetPath.draft(self._tenant_id, self._app_id, node.id))
|
||||
storage_key = AssetPaths.draft_file(self._tenant_id, self._app_id, node.id)
|
||||
extension = node.extension or ""
|
||||
|
||||
parser: AssetItemParser = self._parsers.get(extension, self._default_parser)
|
||||
parser = self._parsers.get(extension, self._default_parser)
|
||||
asset = parser.parse(node.id, path, node.name, extension, storage_key)
|
||||
assets.append(asset)
|
||||
|
||||
|
||||
18
api/core/app_assets/paths.py
Normal file
18
api/core/app_assets/paths.py
Normal file
@ -0,0 +1,18 @@
|
||||
class AssetPaths:
|
||||
_BASE = "app_assets"
|
||||
|
||||
@staticmethod
|
||||
def draft_file(tenant_id: str, app_id: str, node_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/draft/{node_id}"
|
||||
|
||||
@staticmethod
|
||||
def build_zip(tenant_id: str, app_id: str, assets_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}.zip"
|
||||
|
||||
@staticmethod
|
||||
def build_resolved_file(tenant_id: str, app_id: str, assets_id: str, node_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}/resolved/{node_id}"
|
||||
|
||||
@staticmethod
|
||||
def build_skill_artifact_set(tenant_id: str, app_id: str, assets_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}/skill_artifact_set.json"
|
||||
@ -1,393 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
from extensions.storage.cached_presign_storage import CachedPresignStorage
|
||||
from extensions.storage.silent_storage import SilentStorage
|
||||
from libs import rsa
|
||||
|
||||
_ASSET_BASE = "app_assets"
|
||||
_SILENT_STORAGE_NOT_FOUND = b"File Not Found"
|
||||
_PATH_TEMPLATES: dict[str, str] = {
|
||||
"draft": f"{_ASSET_BASE}/{{t}}/{{a}}/draft/{{r}}",
|
||||
"build-zip": f"{_ASSET_BASE}/{{t}}/{{a}}/artifacts/{{r}}.zip",
|
||||
"resolved": f"{_ASSET_BASE}/{{t}}/{{a}}/artifacts/{{r}}/resolved/{{s}}",
|
||||
"skill-bundle": f"{_ASSET_BASE}/{{t}}/{{a}}/artifacts/{{r}}/skill_artifact_set.json",
|
||||
"source-zip": f"{_ASSET_BASE}/{{t}}/{{a}}/sources/{{r}}.zip",
|
||||
}
|
||||
_ASSET_PATH_REGISTRY: dict[str, tuple[bool, Callable[..., AssetPathBase]]] = {}
|
||||
|
||||
|
||||
def _require_uuid(value: str, field_name: str) -> None:
|
||||
try:
|
||||
UUID(value)
|
||||
except (ValueError, TypeError) as exc:
|
||||
raise ValueError(f"{field_name} must be a UUID") from exc
|
||||
|
||||
|
||||
def register_asset_path(asset_type: str, *, requires_node: bool, factory: Callable[..., AssetPathBase]) -> None:
|
||||
_ASSET_PATH_REGISTRY[asset_type] = (requires_node, factory)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetPathBase:
|
||||
asset_type: ClassVar[str]
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
resource_id: str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
_require_uuid(self.tenant_id, "tenant_id")
|
||||
_require_uuid(self.app_id, "app_id")
|
||||
_require_uuid(self.resource_id, "resource_id")
|
||||
|
||||
def get_storage_key(self) -> str:
|
||||
return _PATH_TEMPLATES[self.asset_type].format(
|
||||
t=self.tenant_id,
|
||||
a=self.app_id,
|
||||
r=self.resource_id,
|
||||
s=self.signature_sub_resource_id() or "",
|
||||
)
|
||||
|
||||
def signature_resource_id(self) -> str:
|
||||
return self.resource_id
|
||||
|
||||
def signature_sub_resource_id(self) -> str:
|
||||
return ""
|
||||
|
||||
def proxy_path_parts(self) -> list[str]:
|
||||
parts = [self.asset_type, self.tenant_id, self.app_id, self.signature_resource_id()]
|
||||
sub_resource_id = self.signature_sub_resource_id()
|
||||
if sub_resource_id:
|
||||
parts.append(sub_resource_id)
|
||||
return parts
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _DraftAssetPath(AssetPathBase):
|
||||
asset_type: ClassVar[str] = "draft"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _BuildZipAssetPath(AssetPathBase):
|
||||
asset_type: ClassVar[str] = "build-zip"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ResolvedAssetPath(AssetPathBase):
|
||||
asset_type: ClassVar[str] = "resolved"
|
||||
node_id: str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
_require_uuid(self.node_id, "node_id")
|
||||
|
||||
def signature_sub_resource_id(self) -> str:
|
||||
return self.node_id
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _SkillBundleAssetPath(AssetPathBase):
|
||||
asset_type: ClassVar[str] = "skill-bundle"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _SourceZipAssetPath(AssetPathBase):
|
||||
asset_type: ClassVar[str] = "source-zip"
|
||||
|
||||
|
||||
class AssetPath:
|
||||
@staticmethod
|
||||
def draft(tenant_id: str, app_id: str, node_id: str) -> AssetPathBase:
|
||||
return _DraftAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=node_id)
|
||||
|
||||
@staticmethod
|
||||
def build_zip(tenant_id: str, app_id: str, assets_id: str) -> AssetPathBase:
|
||||
return _BuildZipAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=assets_id)
|
||||
|
||||
@staticmethod
|
||||
def resolved(tenant_id: str, app_id: str, assets_id: str, node_id: str) -> AssetPathBase:
|
||||
return _ResolvedAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=assets_id, node_id=node_id)
|
||||
|
||||
@staticmethod
|
||||
def skill_bundle(tenant_id: str, app_id: str, assets_id: str) -> AssetPathBase:
|
||||
return _SkillBundleAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=assets_id)
|
||||
|
||||
@staticmethod
|
||||
def source_zip(tenant_id: str, app_id: str, workflow_id: str) -> AssetPathBase:
|
||||
return _SourceZipAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=workflow_id)
|
||||
|
||||
@staticmethod
|
||||
def from_components(
|
||||
asset_type: str,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
resource_id: str,
|
||||
sub_resource_id: str | None = None,
|
||||
) -> AssetPathBase:
|
||||
entry = _ASSET_PATH_REGISTRY.get(asset_type)
|
||||
if not entry:
|
||||
raise ValueError(f"Unsupported asset type: {asset_type}")
|
||||
requires_node, factory = entry
|
||||
if requires_node and not sub_resource_id:
|
||||
raise ValueError("resolved assets require node_id")
|
||||
if not requires_node and sub_resource_id:
|
||||
raise ValueError(f"{asset_type} assets do not accept node_id")
|
||||
if requires_node:
|
||||
return factory(tenant_id, app_id, resource_id, sub_resource_id)
|
||||
return factory(tenant_id, app_id, resource_id)
|
||||
|
||||
|
||||
register_asset_path("draft", requires_node=False, factory=AssetPath.draft)
|
||||
register_asset_path("build-zip", requires_node=False, factory=AssetPath.build_zip)
|
||||
register_asset_path("resolved", requires_node=True, factory=AssetPath.resolved)
|
||||
register_asset_path("skill-bundle", requires_node=False, factory=AssetPath.skill_bundle)
|
||||
register_asset_path("source-zip", requires_node=False, factory=AssetPath.source_zip)
|
||||
|
||||
|
||||
class AppAssetSigner:
|
||||
SIGNATURE_PREFIX = "app-asset"
|
||||
SIGNATURE_VERSION = "v1"
|
||||
OPERATION_DOWNLOAD = "download"
|
||||
OPERATION_UPLOAD = "upload"
|
||||
|
||||
@classmethod
|
||||
def create_download_signature(cls, asset_path: AssetPathBase, expires_at: int, nonce: str) -> str:
|
||||
return cls._create_signature(
|
||||
asset_path=asset_path,
|
||||
operation=cls.OPERATION_DOWNLOAD,
|
||||
expires_at=expires_at,
|
||||
nonce=nonce,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_upload_signature(cls, asset_path: AssetPathBase, expires_at: int, nonce: str) -> str:
|
||||
return cls._create_signature(
|
||||
asset_path=asset_path,
|
||||
operation=cls.OPERATION_UPLOAD,
|
||||
expires_at=expires_at,
|
||||
nonce=nonce,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def verify_download_signature(cls, asset_path: AssetPathBase, expires_at: int, nonce: str, sign: str) -> bool:
|
||||
return cls._verify_signature(
|
||||
asset_path=asset_path,
|
||||
operation=cls.OPERATION_DOWNLOAD,
|
||||
expires_at=expires_at,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def verify_upload_signature(cls, asset_path: AssetPathBase, expires_at: int, nonce: str, sign: str) -> bool:
|
||||
return cls._verify_signature(
|
||||
asset_path=asset_path,
|
||||
operation=cls.OPERATION_UPLOAD,
|
||||
expires_at=expires_at,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _verify_signature(
|
||||
cls,
|
||||
*,
|
||||
asset_path: AssetPathBase,
|
||||
operation: str,
|
||||
expires_at: int,
|
||||
nonce: str,
|
||||
sign: str,
|
||||
) -> bool:
|
||||
if expires_at <= 0:
|
||||
return False
|
||||
|
||||
expected_sign = cls._create_signature(
|
||||
asset_path=asset_path,
|
||||
operation=operation,
|
||||
expires_at=expires_at,
|
||||
nonce=nonce,
|
||||
)
|
||||
if not hmac.compare_digest(sign, expected_sign):
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
if expires_at < current_time:
|
||||
return False
|
||||
|
||||
if expires_at - current_time > dify_config.FILES_ACCESS_TIMEOUT:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _create_signature(cls, *, asset_path: AssetPathBase, operation: str, expires_at: int, nonce: str) -> str:
|
||||
key = cls._tenant_key(asset_path.tenant_id)
|
||||
message = cls._signature_message(
|
||||
asset_path=asset_path,
|
||||
operation=operation,
|
||||
expires_at=expires_at,
|
||||
nonce=nonce,
|
||||
)
|
||||
sign = hmac.new(key, message.encode(), hashlib.sha256).digest()
|
||||
return base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
@classmethod
|
||||
def _signature_message(cls, *, asset_path: AssetPathBase, operation: str, expires_at: int, nonce: str) -> str:
|
||||
sub_resource_id = asset_path.signature_sub_resource_id()
|
||||
return (
|
||||
f"{cls.SIGNATURE_PREFIX}|{cls.SIGNATURE_VERSION}|{operation}|"
|
||||
f"{asset_path.asset_type}|{asset_path.tenant_id}|{asset_path.app_id}|"
|
||||
f"{asset_path.signature_resource_id()}|{sub_resource_id}|{expires_at}|{nonce}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _tenant_key(cls, tenant_id: str) -> bytes:
|
||||
try:
|
||||
rsa_key, _ = rsa.get_decrypt_decoding(tenant_id)
|
||||
except rsa.PrivkeyNotFoundError as exc:
|
||||
raise ValueError(f"Tenant private key missing for tenant_id={tenant_id}") from exc
|
||||
private_key = rsa_key.export_key()
|
||||
return hashlib.sha256(private_key).digest()
|
||||
|
||||
|
||||
class AppAssetStorage:
|
||||
_base_storage: BaseStorage
|
||||
_storage: CachedPresignStorage
|
||||
|
||||
def __init__(self, storage: BaseStorage, *, redis_client: Any, cache_key_prefix: str = "app_assets") -> None:
|
||||
self._base_storage = storage
|
||||
self._storage = CachedPresignStorage(
|
||||
storage=storage,
|
||||
redis_client=redis_client,
|
||||
cache_key_prefix=cache_key_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def storage(self) -> BaseStorage:
|
||||
return self._storage
|
||||
|
||||
def save(self, asset_path: AssetPathBase, content: bytes) -> None:
|
||||
self._storage.save(self.get_storage_key(asset_path), content)
|
||||
|
||||
def load(self, asset_path: AssetPathBase) -> bytes:
|
||||
return self._storage.load_once(self.get_storage_key(asset_path))
|
||||
|
||||
def load_or_none(self, asset_path: AssetPathBase) -> bytes | None:
|
||||
try:
|
||||
data = self._storage.load_once(self.get_storage_key(asset_path))
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
if data == _SILENT_STORAGE_NOT_FOUND:
|
||||
return None
|
||||
return data
|
||||
|
||||
def delete(self, asset_path: AssetPathBase) -> None:
|
||||
self._storage.delete(self.get_storage_key(asset_path))
|
||||
|
||||
def get_storage_key(self, asset_path: AssetPathBase) -> str:
|
||||
return asset_path.get_storage_key()
|
||||
|
||||
def get_download_url(self, asset_path: AssetPathBase, expires_in: int = 3600) -> str:
|
||||
storage_key = self.get_storage_key(asset_path)
|
||||
try:
|
||||
return self._storage.get_download_url(storage_key, expires_in)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
return self._generate_signed_proxy_download_url(asset_path, expires_in)
|
||||
|
||||
def get_download_urls(
|
||||
self,
|
||||
asset_paths: Iterable[AssetPathBase],
|
||||
expires_in: int = 3600,
|
||||
) -> list[str]:
|
||||
asset_paths_list = list(asset_paths)
|
||||
storage_keys = [self.get_storage_key(asset_path) for asset_path in asset_paths_list]
|
||||
|
||||
try:
|
||||
return self._storage.get_download_urls(storage_keys, expires_in)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
return [self._generate_signed_proxy_download_url(asset_path, expires_in) for asset_path in asset_paths_list]
|
||||
|
||||
def get_upload_url(
|
||||
self,
|
||||
asset_path: AssetPathBase,
|
||||
expires_in: int = 3600,
|
||||
) -> str:
|
||||
storage_key = self.get_storage_key(asset_path)
|
||||
try:
|
||||
return self._storage.get_upload_url(storage_key, expires_in)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
return self._generate_signed_proxy_upload_url(asset_path, expires_in)
|
||||
|
||||
def _generate_signed_proxy_download_url(self, asset_path: AssetPathBase, expires_in: int) -> str:
|
||||
expires_in = min(expires_in, dify_config.FILES_ACCESS_TIMEOUT)
|
||||
expires_at = int(time.time()) + max(expires_in, 1)
|
||||
nonce = os.urandom(16).hex()
|
||||
sign = AppAssetSigner.create_download_signature(asset_path=asset_path, expires_at=expires_at, nonce=nonce)
|
||||
|
||||
base_url = dify_config.FILES_URL
|
||||
url = self._build_proxy_url(base_url=base_url, asset_path=asset_path, action="download")
|
||||
query = urllib.parse.urlencode({"expires_at": expires_at, "nonce": nonce, "sign": sign})
|
||||
return f"{url}?{query}"
|
||||
|
||||
def _generate_signed_proxy_upload_url(self, asset_path: AssetPathBase, expires_in: int) -> str:
|
||||
expires_in = min(expires_in, dify_config.FILES_ACCESS_TIMEOUT)
|
||||
expires_at = int(time.time()) + max(expires_in, 1)
|
||||
nonce = os.urandom(16).hex()
|
||||
sign = AppAssetSigner.create_upload_signature(asset_path=asset_path, expires_at=expires_at, nonce=nonce)
|
||||
|
||||
base_url = dify_config.FILES_URL
|
||||
url = self._build_proxy_url(base_url=base_url, asset_path=asset_path, action="upload")
|
||||
query = urllib.parse.urlencode({"expires_at": expires_at, "nonce": nonce, "sign": sign})
|
||||
return f"{url}?{query}"
|
||||
|
||||
@staticmethod
|
||||
def _build_proxy_url(*, base_url: str, asset_path: AssetPathBase, action: str) -> str:
|
||||
encoded_parts = [urllib.parse.quote(part, safe="") for part in asset_path.proxy_path_parts()]
|
||||
path = "/".join(encoded_parts)
|
||||
return f"{base_url}/files/app-assets/{path}/{action}"
|
||||
|
||||
|
||||
class _LazyAppAssetStorage:
|
||||
_instance: AppAssetStorage | None
|
||||
_cache_key_prefix: str
|
||||
|
||||
def __init__(self, *, cache_key_prefix: str) -> None:
|
||||
self._instance = None
|
||||
self._cache_key_prefix = cache_key_prefix
|
||||
|
||||
def _get_instance(self) -> AppAssetStorage:
|
||||
if self._instance is None:
|
||||
if not hasattr(storage, "storage_runner"):
|
||||
raise RuntimeError("Storage is not initialized; call storage.init_app before using app_asset_storage")
|
||||
self._instance = AppAssetStorage(
|
||||
storage=SilentStorage(storage.storage_runner),
|
||||
redis_client=redis_client,
|
||||
cache_key_prefix=self._cache_key_prefix,
|
||||
)
|
||||
return self._instance
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
return getattr(self._get_instance(), name)
|
||||
|
||||
|
||||
app_asset_storage = _LazyAppAssetStorage(cache_key_prefix="app_assets")
|
||||
@ -1,5 +0,0 @@
|
||||
from .source_zip_extractor import SourceZipExtractor
|
||||
|
||||
__all__ = [
|
||||
"SourceZipExtractor",
|
||||
]
|
||||
@ -1,98 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import zipfile
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
|
||||
from core.app.entities.app_bundle_entities import ExtractedFile, ExtractedFolder, ZipSecurityError
|
||||
from core.app_assets.storage import AssetPath
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app_assets.storage import AppAssetStorage
|
||||
|
||||
|
||||
class SourceZipExtractor:
|
||||
def __init__(self, storage: AppAssetStorage) -> None:
|
||||
self._storage = storage
|
||||
|
||||
def extract_entries(
|
||||
self, zip_bytes: bytes, *, expected_prefix: str
|
||||
) -> tuple[list[ExtractedFolder], list[ExtractedFile]]:
|
||||
folders: list[ExtractedFolder] = []
|
||||
files: list[ExtractedFile] = []
|
||||
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zf:
|
||||
for info in zf.infolist():
|
||||
name = info.filename
|
||||
self._validate_path(name)
|
||||
|
||||
if not name.startswith(expected_prefix):
|
||||
continue
|
||||
|
||||
relative_path = name[len(expected_prefix) :].lstrip("/")
|
||||
if not relative_path:
|
||||
continue
|
||||
|
||||
if info.is_dir():
|
||||
folders.append(ExtractedFolder(path=relative_path.rstrip("/")))
|
||||
else:
|
||||
content = zf.read(info)
|
||||
files.append(ExtractedFile(path=relative_path, content=content))
|
||||
|
||||
return folders, files
|
||||
|
||||
def build_tree_and_save(
|
||||
self,
|
||||
folders: list[ExtractedFolder],
|
||||
files: list[ExtractedFile],
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
) -> AppAssetFileTree:
|
||||
tree = AppAssetFileTree()
|
||||
path_to_node_id: dict[str, str] = {}
|
||||
|
||||
all_folder_paths = {f.path for f in folders}
|
||||
for file in files:
|
||||
self._ensure_parent_folders(file.path, all_folder_paths)
|
||||
|
||||
sorted_folders = sorted(all_folder_paths, key=lambda p: p.count("/"))
|
||||
for folder_path in sorted_folders:
|
||||
node_id = str(uuid4())
|
||||
name = folder_path.rsplit("/", 1)[-1]
|
||||
parent_path = folder_path.rsplit("/", 1)[0] if "/" in folder_path else None
|
||||
parent_id = path_to_node_id.get(parent_path) if parent_path else None
|
||||
|
||||
node = AppAssetNode.create_folder(node_id, name, parent_id)
|
||||
tree.add(node)
|
||||
path_to_node_id[folder_path] = node_id
|
||||
|
||||
sorted_files = sorted(files, key=lambda f: f.path)
|
||||
for file in sorted_files:
|
||||
node_id = str(uuid4())
|
||||
name = file.path.rsplit("/", 1)[-1]
|
||||
parent_path = file.path.rsplit("/", 1)[0] if "/" in file.path else None
|
||||
parent_id = path_to_node_id.get(parent_path) if parent_path else None
|
||||
|
||||
node = AppAssetNode.create_file(node_id, name, parent_id, len(file.content))
|
||||
tree.add(node)
|
||||
|
||||
asset_path = AssetPath.draft(tenant_id, app_id, node_id)
|
||||
self._storage.save(asset_path, file.content)
|
||||
|
||||
return tree
|
||||
|
||||
def _validate_path(self, path: str) -> None:
|
||||
if ".." in path:
|
||||
raise ZipSecurityError(f"Path traversal detected: {path}")
|
||||
if path.startswith("/"):
|
||||
raise ZipSecurityError(f"Absolute path detected: {path}")
|
||||
if "\\" in path:
|
||||
raise ZipSecurityError(f"Backslash in path: {path}")
|
||||
|
||||
def _ensure_parent_folders(self, file_path: str, folder_set: set[str]) -> None:
|
||||
parts = file_path.split("/")[:-1]
|
||||
for i in range(1, len(parts) + 1):
|
||||
parent = "/".join(parts[:i])
|
||||
folder_set.add(parent)
|
||||
@ -1,62 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class VariableSelectorPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
variable: str = Field(..., description="Variable name used in generated code")
|
||||
value_selector: list[str] = Field(..., description="Path to upstream node output, format: [node_id, output_name]")
|
||||
|
||||
|
||||
class CodeOutputPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: str = Field(..., description="Output variable type")
|
||||
|
||||
|
||||
class CodeContextPayload(BaseModel):
|
||||
# From web/app/components/workflow/nodes/tool/components/context-generate-modal/index.tsx (code node snapshot).
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
code: str = Field(..., description="Existing code in the Code node")
|
||||
outputs: dict[str, CodeOutputPayload] | None = Field(
|
||||
default=None, description="Existing output definitions for the Code node"
|
||||
)
|
||||
variables: list[VariableSelectorPayload] | None = Field(
|
||||
default=None, description="Existing variable selectors used by the Code node"
|
||||
)
|
||||
|
||||
|
||||
class AvailableVarPayload(BaseModel):
|
||||
# From web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts (available variables).
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
value_selector: list[str] = Field(..., description="Path to upstream node output")
|
||||
type: str = Field(..., description="Variable type, e.g. string, number, array[object]")
|
||||
description: str | None = Field(default=None, description="Optional variable description")
|
||||
node_id: str | None = Field(default=None, description="Source node ID")
|
||||
node_title: str | None = Field(default=None, description="Source node title")
|
||||
node_type: str | None = Field(default=None, description="Source node type")
|
||||
json_schema: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
alias="schema",
|
||||
description="Optional JSON schema for object variables",
|
||||
)
|
||||
|
||||
|
||||
class ParameterInfoPayload(BaseModel):
|
||||
# From web/app/components/workflow/nodes/tool/use-config.ts (ToolParameter metadata).
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str = Field(..., description="Target parameter name")
|
||||
type: str = Field(default="string", description="Target parameter type")
|
||||
description: str = Field(default="", description="Parameter description")
|
||||
required: bool | None = Field(default=None, description="Whether the parameter is required")
|
||||
options: list[str] | None = Field(default=None, description="Allowed option values")
|
||||
min: float | None = Field(default=None, description="Minimum numeric value")
|
||||
max: float | None = Field(default=None, description="Maximum numeric value")
|
||||
default: str | int | float | bool | None = Field(default=None, description="Default value")
|
||||
multiple: bool | None = Field(default=None, description="Whether the parameter accepts multiple values")
|
||||
label: str | None = Field(default=None, description="Optional display label")
|
||||
@ -1,16 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
import json_repair
|
||||
|
||||
from core.llm_generator.context_models import (
|
||||
AvailableVarPayload,
|
||||
CodeContextPayload,
|
||||
ParameterInfoPayload,
|
||||
)
|
||||
from core.llm_generator.output_models import (
|
||||
CodeNodeStructuredOutput,
|
||||
InstructionModifyOutput,
|
||||
@ -137,6 +132,9 @@ class LLMGenerator:
|
||||
return []
|
||||
|
||||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
|
||||
questions: Sequence[str] = []
|
||||
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
@ -404,24 +402,24 @@ class LLMGenerator:
|
||||
def generate_with_context(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
workflow_id: str,
|
||||
node_id: str,
|
||||
parameter_name: str,
|
||||
language: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_config: dict,
|
||||
available_vars: Sequence[AvailableVarPayload],
|
||||
parameter_info: ParameterInfoPayload,
|
||||
code_context: CodeContextPayload,
|
||||
) -> dict:
|
||||
"""
|
||||
Generate extractor code node based on conversation context.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant/workspace ID
|
||||
workflow_id: Workflow ID
|
||||
node_id: Current tool/llm node ID
|
||||
parameter_name: Parameter name to generate code for
|
||||
language: Code language (python3/javascript)
|
||||
prompt_messages: Multi-turn conversation history (last message is instruction)
|
||||
model_config: Model configuration (provider, name, completion_params)
|
||||
available_vars: Client-provided available variables with types/schema
|
||||
parameter_info: Client-provided parameter metadata (type/constraints)
|
||||
code_context: Client-provided existing code node context
|
||||
|
||||
Returns:
|
||||
dict with CodeNodeData format:
|
||||
@ -432,12 +430,43 @@ class LLMGenerator:
|
||||
- message: Explanation
|
||||
- error: Error message if any
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# available_vars/parameter_info/code_context are provided by the frontend context-generate modal.
|
||||
# See web/app/components/workflow/nodes/tool/components/context-generate-modal/hooks/use-context-generate.ts
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
# Get workflow
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(App).where(App.id == workflow_id)
|
||||
app = session.scalar(stmt)
|
||||
if not app:
|
||||
return cls._error_response(f"App {workflow_id} not found")
|
||||
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
if not workflow:
|
||||
return cls._error_response(f"Workflow for app {workflow_id} not found")
|
||||
|
||||
# Get upstream nodes via edge backtracking
|
||||
upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
|
||||
|
||||
# Get current node info
|
||||
current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
|
||||
if not current_node:
|
||||
return cls._error_response(f"Node {node_id} not found")
|
||||
|
||||
# Get parameter info
|
||||
parameter_info = cls._get_parameter_info(
|
||||
tenant_id=tenant_id,
|
||||
node_data=current_node.get("data", {}),
|
||||
parameter_name=parameter_name,
|
||||
)
|
||||
|
||||
# Build system prompt
|
||||
system_prompt = cls._build_extractor_system_prompt(
|
||||
available_vars=available_vars, parameter_info=parameter_info, language=language, code_context=code_context
|
||||
upstream_nodes=upstream_nodes,
|
||||
current_node=current_node,
|
||||
parameter_info=parameter_info,
|
||||
language=language,
|
||||
)
|
||||
|
||||
# Construct complete prompt_messages with system prompt
|
||||
@ -475,11 +504,9 @@ class LLMGenerator:
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return {
|
||||
**response.model_dump(),
|
||||
"code_language": language,
|
||||
"error": "",
|
||||
}
|
||||
return cls._parse_code_node_output(
|
||||
response.structured_output, language, parameter_info.get("type", "string")
|
||||
)
|
||||
|
||||
except InvokeError as e:
|
||||
return cls._error_response(str(e))
|
||||
@ -503,30 +530,55 @@ class LLMGenerator:
|
||||
def generate_suggested_questions(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
workflow_id: str,
|
||||
node_id: str,
|
||||
parameter_name: str,
|
||||
language: str,
|
||||
available_vars: Sequence[AvailableVarPayload],
|
||||
parameter_info: ParameterInfoPayload,
|
||||
model_config: dict,
|
||||
model_config: dict | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Generate suggested questions for context generation.
|
||||
|
||||
Returns dict with questions array and error field.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
# Get workflow context (reuse existing logic)
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(App).where(App.id == workflow_id)
|
||||
app = session.scalar(stmt)
|
||||
if not app:
|
||||
return {"questions": [], "error": f"App {workflow_id} not found"}
|
||||
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
if not workflow:
|
||||
return {"questions": [], "error": f"Workflow for app {workflow_id} not found"}
|
||||
|
||||
upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
|
||||
current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
|
||||
if not current_node:
|
||||
return {"questions": [], "error": f"Node {node_id} not found"}
|
||||
|
||||
parameter_info = cls._get_parameter_info(
|
||||
tenant_id=tenant_id,
|
||||
node_data=current_node.get("data", {}),
|
||||
parameter_name=parameter_name,
|
||||
)
|
||||
|
||||
# available_vars/parameter_info are provided by the frontend context-generate modal.
|
||||
# See web/app/components/workflow/nodes/tool/components/context-generate-modal/hooks/use-context-generate.ts
|
||||
# Build prompt
|
||||
system_prompt = cls._build_suggested_questions_prompt(
|
||||
available_vars=available_vars,
|
||||
upstream_nodes=upstream_nodes,
|
||||
current_node=current_node,
|
||||
parameter_info=parameter_info,
|
||||
language=language,
|
||||
)
|
||||
|
||||
prompt_messages: list[PromptMessage] = [
|
||||
UserPromptMessage(content=system_prompt),
|
||||
SystemPromptMessage(content=system_prompt),
|
||||
]
|
||||
|
||||
# Get model instance - use default if model_config not provided
|
||||
@ -552,6 +604,7 @@ class LLMGenerator:
|
||||
return {"questions": [], "error": f"Model schema not found for {model_name}"}
|
||||
|
||||
completion_params = model_config.get("completion_params", {}) if model_config else {}
|
||||
model_parameters = {**completion_params, "max_tokens": 256}
|
||||
try:
|
||||
response = invoke_llm_with_pydantic_model(
|
||||
provider=model_instance.provider,
|
||||
@ -559,12 +612,13 @@ class LLMGenerator:
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
output_model=SuggestedQuestionsOutput,
|
||||
model_parameters=completion_params,
|
||||
model_parameters=model_parameters,
|
||||
stream=False,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return {"questions": response.questions, "error": ""}
|
||||
questions = response.structured_output.get("questions", []) if response.structured_output else []
|
||||
return {"questions": questions, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
return {"questions": [], "error": str(e)}
|
||||
@ -575,101 +629,213 @@ class LLMGenerator:
|
||||
@classmethod
|
||||
def _build_suggested_questions_prompt(
|
||||
cls,
|
||||
available_vars: Sequence[AvailableVarPayload],
|
||||
parameter_info: ParameterInfoPayload,
|
||||
upstream_nodes: list[dict],
|
||||
current_node: dict,
|
||||
parameter_info: dict,
|
||||
language: str = "English",
|
||||
) -> str:
|
||||
"""Build minimal prompt for suggested questions generation."""
|
||||
parameter_block = cls._format_parameter_info(parameter_info)
|
||||
available_vars_block = cls._format_available_vars(
|
||||
available_vars,
|
||||
max_items=30,
|
||||
max_schema_chars=400,
|
||||
max_description_chars=120,
|
||||
)
|
||||
# Simplify upstream nodes to reduce tokens
|
||||
sources = [f"{n['title']}({','.join(n.get('outputs', {}).keys())})" for n in upstream_nodes[:5]]
|
||||
param_type = parameter_info.get("type", "string")
|
||||
param_desc = parameter_info.get("description", "")[:100]
|
||||
|
||||
return f"""Suggest exactly 3 short instructions that would help generate code for the target parameter.
|
||||
|
||||
## Target Parameter
|
||||
{parameter_block}
|
||||
|
||||
## Available Variables
|
||||
{available_vars_block}
|
||||
|
||||
## Constraints
|
||||
- Output exactly 3 instructions.
|
||||
- Use {language}.
|
||||
- Keep each instruction short and practical.
|
||||
- Do not include code or variable syntax in the instructions.
|
||||
"""
|
||||
return f"""Suggest 3 code generation questions for extracting data.
|
||||
Sources: {", ".join(sources)}
|
||||
Target: {parameter_info.get("name")}({param_type}) - {param_desc}
|
||||
Output 3 short, practical questions in {language}."""
|
||||
|
||||
@classmethod
|
||||
def _format_parameter_info(cls, parameter_info: ParameterInfoPayload) -> str:
|
||||
payload = parameter_info.model_dump(mode="python", by_alias=True)
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
def _get_upstream_nodes(cls, graph_dict: Mapping[str, Any], node_id: str) -> list[dict]:
|
||||
"""
|
||||
Get all upstream nodes via edge backtracking.
|
||||
|
||||
Traverses the graph backwards from node_id to collect all reachable nodes.
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
nodes = {n["id"]: n for n in graph_dict.get("nodes", [])}
|
||||
edges = graph_dict.get("edges", [])
|
||||
|
||||
# Build reverse adjacency list
|
||||
reverse_adj: dict[str, list[str]] = defaultdict(list)
|
||||
for edge in edges:
|
||||
reverse_adj[edge["target"]].append(edge["source"])
|
||||
|
||||
# BFS to find all upstream nodes
|
||||
visited: set[str] = set()
|
||||
queue = [node_id]
|
||||
upstream: list[dict] = []
|
||||
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
for source in reverse_adj.get(current, []):
|
||||
if source not in visited:
|
||||
visited.add(source)
|
||||
queue.append(source)
|
||||
if source in nodes:
|
||||
upstream.append(cls._extract_node_info(nodes[source]))
|
||||
|
||||
return upstream
|
||||
|
||||
@classmethod
|
||||
def _format_available_vars(
|
||||
cls,
|
||||
available_vars: Sequence[AvailableVarPayload],
|
||||
*,
|
||||
max_items: int,
|
||||
max_schema_chars: int,
|
||||
max_description_chars: int,
|
||||
) -> str:
|
||||
payload = [item.model_dump(mode="python", by_alias=True) for item in available_vars]
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
def _get_node_by_id(cls, graph_dict: Mapping[str, Any], node_id: str) -> dict | None:
|
||||
"""Get node by ID from graph."""
|
||||
for node in graph_dict.get("nodes", []):
|
||||
if node["id"] == node_id:
|
||||
return node
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _format_code_context(cls, code_context: CodeContextPayload | None) -> str:
|
||||
if not code_context:
|
||||
return ""
|
||||
code = code_context.code
|
||||
outputs = code_context.outputs
|
||||
variables = code_context.variables
|
||||
if not code and not outputs and not variables:
|
||||
return ""
|
||||
payload = code_context.model_dump(mode="python", by_alias=True)
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
def _extract_node_info(cls, node: dict) -> dict:
|
||||
"""Extract minimal node info with outputs based on node type."""
|
||||
node_type = node["data"]["type"]
|
||||
node_data = node.get("data", {})
|
||||
|
||||
# Build outputs based on node type (only type, no description to reduce tokens)
|
||||
outputs: dict[str, str] = {}
|
||||
match node_type:
|
||||
case "start":
|
||||
for var in node_data.get("variables", []):
|
||||
name = var.get("variable", var.get("name", ""))
|
||||
outputs[name] = var.get("type", "string")
|
||||
case "llm":
|
||||
outputs["text"] = "string"
|
||||
case "code":
|
||||
for name, output in node_data.get("outputs", {}).items():
|
||||
outputs[name] = output.get("type", "string")
|
||||
case "http-request":
|
||||
outputs = {"body": "string", "status_code": "number", "headers": "object"}
|
||||
case "knowledge-retrieval":
|
||||
outputs["result"] = "array[object]"
|
||||
case "tool":
|
||||
outputs = {"text": "string", "json": "object"}
|
||||
case _:
|
||||
outputs["output"] = "string"
|
||||
|
||||
info: dict = {
|
||||
"id": node["id"],
|
||||
"title": node_data.get("title", node["id"]),
|
||||
"outputs": outputs,
|
||||
}
|
||||
# Only include description if not empty
|
||||
desc = node_data.get("desc", "")
|
||||
if desc:
|
||||
info["desc"] = desc
|
||||
|
||||
return info
|
||||
|
||||
@classmethod
|
||||
def _get_parameter_info(cls, tenant_id: str, node_data: dict, parameter_name: str) -> dict:
|
||||
"""Get parameter info from tool schema using ToolManager."""
|
||||
default_info = {"name": parameter_name, "type": "string", "description": ""}
|
||||
|
||||
if node_data.get("type") != "tool":
|
||||
return default_info
|
||||
|
||||
try:
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
provider_type_str = node_data.get("provider_type", "")
|
||||
provider_type = ToolProviderType(provider_type_str) if provider_type_str else ToolProviderType.BUILT_IN
|
||||
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
provider_type=provider_type,
|
||||
provider_id=node_data.get("provider_id", ""),
|
||||
tool_name=node_data.get("tool_name", ""),
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
parameters = tool_runtime.get_merged_runtime_parameters()
|
||||
for param in parameters:
|
||||
if param.name == parameter_name:
|
||||
return {
|
||||
"name": param.name,
|
||||
"type": param.type.value if hasattr(param.type, "value") else str(param.type),
|
||||
"description": param.llm_description
|
||||
or (param.human_description.en_US if param.human_description else ""),
|
||||
"required": param.required,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug("Failed to get parameter info from ToolManager: %s", e)
|
||||
|
||||
return default_info
|
||||
|
||||
@classmethod
|
||||
def _build_extractor_system_prompt(
|
||||
cls,
|
||||
available_vars: Sequence[AvailableVarPayload],
|
||||
parameter_info: ParameterInfoPayload,
|
||||
upstream_nodes: list[dict],
|
||||
current_node: dict,
|
||||
parameter_info: dict,
|
||||
language: str,
|
||||
code_context: CodeContextPayload,
|
||||
) -> str:
|
||||
"""Build system prompt for extractor code generation."""
|
||||
param_type = parameter_info.type or "string"
|
||||
parameter_block = cls._format_parameter_info(parameter_info)
|
||||
available_vars_block = cls._format_available_vars(
|
||||
available_vars,
|
||||
max_items=80,
|
||||
max_schema_chars=800,
|
||||
max_description_chars=160,
|
||||
)
|
||||
code_context_block = cls._format_code_context(code_context)
|
||||
code_context_section = f"\n{code_context_block}\n" if code_context_block else "\n"
|
||||
return f"""You are a code generator for Dify workflow automation.
|
||||
upstream_json = json.dumps(upstream_nodes, indent=2, ensure_ascii=False)
|
||||
param_type = parameter_info.get("type", "string")
|
||||
return f"""You are a code generator for workflow automation.
|
||||
|
||||
Generate {language} code to extract/transform available variables for the target parameter.
|
||||
If user is not talking about the code node, provide the existing data or blank data for user, following the schema.
|
||||
Generate {language} code to extract/transform upstream node outputs for the target parameter.
|
||||
|
||||
## Target Parameter
|
||||
{parameter_block}
|
||||
## Upstream Nodes
|
||||
{upstream_json}
|
||||
|
||||
## Available Variables
|
||||
{available_vars_block}
|
||||
{code_context_section}## Requirements
|
||||
- Use only the listed value_selector paths.
|
||||
- Do not invent variables or fields that are not listed.
|
||||
- Write a main function that returns type: {param_type}.
|
||||
- Respect target constraints (options/min/max/default/multiple) if provided.
|
||||
- If existing code is provided, adapt it instead of rewriting from scratch.
|
||||
- Return only JSON that matches the provided schema.
|
||||
## Target
|
||||
Node: {current_node["data"].get("title", current_node["id"])}
|
||||
Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")}
|
||||
|
||||
## Requirements
|
||||
- Write a main function that returns type: {param_type}
|
||||
- Use value_selector format: ["node_id", "output_name"]
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _parse_code_node_output(cls, content: Mapping[str, Any] | None, language: str, parameter_type: str) -> dict:
|
||||
"""
|
||||
Parse structured output to CodeNodeData format.
|
||||
|
||||
Args:
|
||||
content: Structured output dict from invoke_llm_with_structured_output
|
||||
language: Code language
|
||||
parameter_type: Expected parameter type
|
||||
|
||||
Returns dict with variables, code_language, code, outputs, message, error.
|
||||
"""
|
||||
if content is None:
|
||||
return cls._error_response("Empty or invalid response from LLM")
|
||||
|
||||
# Validate and normalize variables
|
||||
variables = [
|
||||
{"variable": v.get("variable", ""), "value_selector": v.get("value_selector", [])}
|
||||
for v in content.get("variables", [])
|
||||
if isinstance(v, dict)
|
||||
]
|
||||
|
||||
# Convert outputs from array format [{name, type}] to dict format {name: {type}}
|
||||
# Array format is required for OpenAI/Azure strict JSON schema compatibility
|
||||
raw_outputs = content.get("outputs", [])
|
||||
if isinstance(raw_outputs, list):
|
||||
outputs = {
|
||||
item.get("name", "result"): {"type": item.get("type", parameter_type)}
|
||||
for item in raw_outputs
|
||||
if isinstance(item, dict) and item.get("name")
|
||||
}
|
||||
if not outputs:
|
||||
outputs = {"result": {"type": parameter_type}}
|
||||
else:
|
||||
outputs = raw_outputs or {"result": {"type": parameter_type}}
|
||||
|
||||
return {
|
||||
"variables": variables,
|
||||
"code_language": language,
|
||||
"code": content.get("code", ""),
|
||||
"outputs": outputs,
|
||||
"message": content.get("explanation", ""),
|
||||
"error": "",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def instruction_modify_legacy(
|
||||
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
|
||||
@ -725,7 +891,7 @@ If user is not talking about the code node, provide the existing data or blank d
|
||||
raise ValueError("Workflow not found for the given app model.")
|
||||
last_run = workflow_service.get_node_last_run(app_model=app, workflow=workflow, node_id=node_id)
|
||||
try:
|
||||
node_type = last_run.node_type
|
||||
node_type = cast(WorkflowNodeExecutionModel, last_run).node_type
|
||||
except Exception:
|
||||
try:
|
||||
node_type = [it for it in workflow.graph_dict["graph"]["nodes"] if it["id"] == node_id][0]["data"][
|
||||
@ -844,7 +1010,7 @@ If user is not talking about the code node, provide the existing data or blank d
|
||||
model_parameters=model_parameters,
|
||||
stream=False,
|
||||
)
|
||||
return response.model_dump(mode="python")
|
||||
return response.structured_output or {}
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"error": f"Failed to generate code. Error: {error}"}
|
||||
|
||||
@ -55,7 +55,7 @@ class CodeNodeStructuredOutput(BaseModel):
|
||||
outputs: list[CodeNodeOutputItem] = Field(
|
||||
description="Output variable definitions specifying name and type for each return value"
|
||||
)
|
||||
message: str = Field(description="Brief explanation of what the generated code does")
|
||||
explanation: str = Field(description="Brief explanation of what the generated code does")
|
||||
|
||||
|
||||
class InstructionModifyOutput(BaseModel):
|
||||
|
||||
@ -26,7 +26,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ParameterRule
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
|
||||
|
||||
|
||||
class ResponseFormat(StrEnum):
|
||||
@ -44,12 +44,6 @@ class SpecialModelType(StrEnum):
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
# Tool name for structured output via tool call
|
||||
STRUCTURED_OUTPUT_TOOL_NAME = "structured_output"
|
||||
|
||||
# Features that indicate tool call support
|
||||
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
@ -138,24 +132,20 @@ def invoke_llm_with_structured_output(
|
||||
file IDs in the output will be automatically converted to File objects.
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# handle native json schema
|
||||
model_parameters_with_json_schema: dict[str, Any] = {
|
||||
**(model_parameters or {}),
|
||||
}
|
||||
|
||||
# Determine structured output strategy
|
||||
|
||||
if model_schema.support_structure_output:
|
||||
# Priority 1: Native JSON schema support
|
||||
model_parameters_with_json_schema = _handle_native_json_schema(
|
||||
model_parameters = _handle_native_json_schema(
|
||||
provider, model_schema, json_schema, model_parameters_with_json_schema, model_schema.parameter_rules
|
||||
)
|
||||
elif _supports_tool_call(model_schema):
|
||||
# Priority 2: Tool call based structured output
|
||||
structured_output_tool = _create_structured_output_tool(json_schema)
|
||||
tools = [structured_output_tool]
|
||||
else:
|
||||
# Priority 3: Prompt-based fallback
|
||||
# Set appropriate response format based on model capabilities
|
||||
_set_response_format(model_parameters_with_json_schema, model_schema.parameter_rules)
|
||||
|
||||
# handle prompt based schema
|
||||
prompt_messages = _handle_prompt_based_schema(
|
||||
prompt_messages=prompt_messages,
|
||||
structured_output_schema=json_schema,
|
||||
@ -172,11 +162,12 @@ def invoke_llm_with_structured_output(
|
||||
)
|
||||
|
||||
if isinstance(llm_result, LLMResult):
|
||||
# Non-streaming result
|
||||
structured_output = _extract_structured_output(llm_result)
|
||||
if not isinstance(llm_result.message.content, str):
|
||||
raise OutputParserError(
|
||||
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
|
||||
)
|
||||
|
||||
# Fill missing fields with default values
|
||||
structured_output = fill_defaults_from_schema(structured_output, json_schema)
|
||||
structured_output = _parse_structured_output(llm_result.message.content)
|
||||
|
||||
# Convert file references if tenant_id is provided
|
||||
if tenant_id is not None:
|
||||
@ -198,16 +189,13 @@ def invoke_llm_with_structured_output(
|
||||
|
||||
def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
result_text: str = ""
|
||||
tool_call_args: dict[str, str] = {} # tool_call_id -> arguments
|
||||
prompt_messages: Sequence[PromptMessage] = []
|
||||
system_fingerprint: str | None = None
|
||||
|
||||
for event in llm_result:
|
||||
if isinstance(event, LLMResultChunk):
|
||||
prompt_messages = event.prompt_messages
|
||||
system_fingerprint = event.system_fingerprint
|
||||
|
||||
# Collect text content
|
||||
if isinstance(event.delta.message.content, str):
|
||||
result_text += event.delta.message.content
|
||||
elif isinstance(event.delta.message.content, list):
|
||||
@ -215,13 +203,6 @@ def invoke_llm_with_structured_output(
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
result_text += item.data
|
||||
|
||||
# Collect tool call arguments
|
||||
if event.delta.message.tool_calls:
|
||||
for tool_call in event.delta.message.tool_calls:
|
||||
call_id = tool_call.id or ""
|
||||
if tool_call.function.arguments:
|
||||
tool_call_args[call_id] = tool_call_args.get(call_id, "") + tool_call.function.arguments
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
@ -229,11 +210,7 @@ def invoke_llm_with_structured_output(
|
||||
delta=event.delta,
|
||||
)
|
||||
|
||||
# Extract structured output: prefer tool call, fallback to text
|
||||
structured_output = _extract_structured_output_from_stream(result_text, tool_call_args)
|
||||
|
||||
# Fill missing fields with default values
|
||||
structured_output = fill_defaults_from_schema(structured_output, json_schema)
|
||||
structured_output = _parse_structured_output(result_text)
|
||||
|
||||
# Convert file references if tenant_id is provided
|
||||
if tenant_id is not None:
|
||||
@ -259,6 +236,7 @@ def invoke_llm_with_structured_output(
|
||||
return generator()
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_pydantic_model(
|
||||
*,
|
||||
provider: str,
|
||||
@ -273,7 +251,24 @@ def invoke_llm_with_pydantic_model(
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> T:
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
|
||||
def invoke_llm_with_pydantic_model(
|
||||
*,
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
output_model: type[T],
|
||||
model_parameters: Mapping | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput:
|
||||
"""
|
||||
Invoke large language model with a Pydantic output model.
|
||||
|
||||
@ -304,7 +299,7 @@ def invoke_llm_with_pydantic_model(
|
||||
raise OutputParserError("Structured output is empty")
|
||||
|
||||
validated_output = _validate_structured_output(output_model, structured_output)
|
||||
return output_model.model_validate(validated_output)
|
||||
return result.model_copy(update={"structured_output": validated_output})
|
||||
|
||||
|
||||
def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]:
|
||||
@ -314,150 +309,12 @@ def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]:
|
||||
def _validate_structured_output(
|
||||
output_model: type[T],
|
||||
structured_output: Mapping[str, Any],
|
||||
) -> T:
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
validated_output = output_model.model_validate(structured_output)
|
||||
except ValidationError as exc:
|
||||
raise OutputParserError(f"Structured output validation failed: {exc}") from exc
|
||||
return validated_output
|
||||
|
||||
|
||||
def _supports_tool_call(model_schema: AIModelEntity) -> bool:
|
||||
"""Check if model supports tool call feature."""
|
||||
return bool(set(model_schema.features or []) & TOOL_CALL_FEATURES)
|
||||
|
||||
|
||||
def _create_structured_output_tool(json_schema: Mapping[str, Any]) -> PromptMessageTool:
|
||||
"""Create a tool definition for structured output."""
|
||||
return PromptMessageTool(
|
||||
name=STRUCTURED_OUTPUT_TOOL_NAME,
|
||||
description="Generate structured output according to the provided schema. "
|
||||
"You MUST call this function to provide your response in the required format.",
|
||||
parameters=dict(json_schema),
|
||||
)
|
||||
|
||||
|
||||
def _extract_structured_output(llm_result: LLMResult) -> Mapping[str, Any]:
|
||||
"""
|
||||
Extract structured output from LLM result (non-streaming).
|
||||
First tries to extract from tool_calls (if present), then falls back to text content.
|
||||
"""
|
||||
# Try to extract from tool call first
|
||||
tool_calls = llm_result.message.tool_calls
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.function.name == STRUCTURED_OUTPUT_TOOL_NAME:
|
||||
return _parse_tool_call_arguments(tool_call.function.arguments)
|
||||
|
||||
# Fallback to text content parsing
|
||||
content = llm_result.message.content
|
||||
if not isinstance(content, str):
|
||||
raise OutputParserError(f"Failed to parse structured output, LLM result is not a string: {content}")
|
||||
return _parse_structured_output(content)
|
||||
|
||||
|
||||
def _extract_structured_output_from_stream(
|
||||
result_text: str,
|
||||
tool_call_args: dict[str, str],
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Extract structured output from streaming collected data.
|
||||
First tries to parse from collected tool call arguments, then falls back to text content.
|
||||
"""
|
||||
# Try to parse from tool call arguments first
|
||||
if tool_call_args:
|
||||
# Use the first non-empty tool call arguments
|
||||
for arguments in tool_call_args.values():
|
||||
if arguments.strip():
|
||||
return _parse_tool_call_arguments(arguments)
|
||||
|
||||
# Fallback to text content parsing
|
||||
if not result_text:
|
||||
raise OutputParserError("No tool call arguments and no text content to parse")
|
||||
return _parse_structured_output(result_text)
|
||||
|
||||
|
||||
def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]:
|
||||
"""Parse JSON from tool call arguments."""
|
||||
if not arguments:
|
||||
raise OutputParserError("Tool call arguments is empty")
|
||||
|
||||
try:
|
||||
parsed = json.loads(arguments)
|
||||
if not isinstance(parsed, dict):
|
||||
raise OutputParserError(f"Tool call arguments is not a dict: {arguments}")
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
# Try to repair malformed JSON
|
||||
repaired = json_repair.loads(arguments)
|
||||
if not isinstance(repaired, dict):
|
||||
raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
|
||||
return cast(dict, repaired)
|
||||
|
||||
|
||||
def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:
|
||||
"""Get default empty value for a JSON schema type."""
|
||||
# Handle array of types (e.g., ["string", "null"])
|
||||
if isinstance(type_name, list):
|
||||
# Use the first non-null type
|
||||
type_name = next((t for t in type_name if t != "null"), None)
|
||||
|
||||
if type_name == "string":
|
||||
return ""
|
||||
elif type_name == "object":
|
||||
return {}
|
||||
elif type_name == "array":
|
||||
return []
|
||||
elif type_name in {"number", "integer"}:
|
||||
return 0
|
||||
elif type_name == "boolean":
|
||||
return False
|
||||
elif type_name == "null" or type_name is None:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def fill_defaults_from_schema(
|
||||
output: Mapping[str, Any],
|
||||
json_schema: Mapping[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Fill missing required fields in output with default empty values based on JSON schema.
|
||||
|
||||
Only fills default values for fields that are marked as required in the schema.
|
||||
Recursively processes nested objects to fill their required fields as well.
|
||||
|
||||
Default values by type:
|
||||
- string → ""
|
||||
- object → {} (with nested required fields filled)
|
||||
- array → []
|
||||
- number/integer → 0
|
||||
- boolean → False
|
||||
- null → None
|
||||
"""
|
||||
result = dict(output)
|
||||
properties = json_schema.get("properties", {})
|
||||
required_fields = set(json_schema.get("required", []))
|
||||
|
||||
for prop_name, prop_schema in properties.items():
|
||||
prop_type = prop_schema.get("type")
|
||||
is_required = prop_name in required_fields
|
||||
|
||||
if prop_name not in result:
|
||||
# Field is missing from output
|
||||
if is_required:
|
||||
# Only fill default value for required fields
|
||||
if prop_type == "object" and "properties" in prop_schema:
|
||||
# Create empty object and recursively fill its required fields
|
||||
result[prop_name] = fill_defaults_from_schema({}, prop_schema)
|
||||
else:
|
||||
result[prop_name] = _get_default_value_for_type(prop_type)
|
||||
elif isinstance(result[prop_name], dict) and prop_type == "object" and "properties" in prop_schema:
|
||||
# Field exists and is an object, recursively fill nested required fields
|
||||
result[prop_name] = fill_defaults_from_schema(result[prop_name], prop_schema)
|
||||
|
||||
return result
|
||||
return validated_output.model_dump(mode="python")
|
||||
|
||||
|
||||
def _handle_native_json_schema(
|
||||
|
||||
@ -304,10 +304,9 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc
|
||||
Now, generate a JSON Schema based on my description
|
||||
""" # noqa: E501
|
||||
|
||||
STRUCTURED_OUTPUT_PROMPT = """You’re an AI that accepts any input but only output in JSON. You must always output in JSON format.
|
||||
STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format.
|
||||
constraints:
|
||||
- You must output in JSON format.
|
||||
- You mustn't output in plain text.
|
||||
- Do not output boolean value, use string type instead.
|
||||
- Do not output integer or float value, use number type instead.
|
||||
eg:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Generator, Iterator, Sequence
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
@ -30,142 +30,6 @@ def _gen_tool_call_id() -> str:
|
||||
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
|
||||
|
||||
|
||||
def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None:
|
||||
if not callbacks:
|
||||
return
|
||||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
invoke(callback)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise
|
||||
logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e)
|
||||
|
||||
|
||||
def _get_or_create_tool_call(
|
||||
existing_tools_calls: list[AssistantPromptMessage.ToolCall],
|
||||
tool_call_id: str,
|
||||
) -> AssistantPromptMessage.ToolCall:
|
||||
"""
|
||||
Get or create a tool call by ID.
|
||||
|
||||
If `tool_call_id` is empty, returns the most recently created tool call.
|
||||
"""
|
||||
if not tool_call_id:
|
||||
if not existing_tools_calls:
|
||||
raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta")
|
||||
return existing_tools_calls[-1]
|
||||
|
||||
tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||
)
|
||||
existing_tools_calls.append(tool_call)
|
||||
|
||||
return tool_call
|
||||
|
||||
|
||||
def _merge_tool_call_delta(
|
||||
tool_call: AssistantPromptMessage.ToolCall,
|
||||
delta: AssistantPromptMessage.ToolCall,
|
||||
) -> None:
|
||||
if delta.id:
|
||||
tool_call.id = delta.id
|
||||
if delta.type:
|
||||
tool_call.type = delta.type
|
||||
if delta.function.name:
|
||||
tool_call.function.name = delta.function.name
|
||||
if delta.function.arguments:
|
||||
tool_call.function.arguments += delta.function.arguments
|
||||
|
||||
|
||||
def _build_llm_result_from_first_chunk(
|
||||
model: str,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
chunks: Iterator[LLMResultChunk],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Build a single `LLMResult` from the first returned chunk.
|
||||
|
||||
This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.
|
||||
"""
|
||||
content = ""
|
||||
content_list: list[PromptMessageContentUnionTypes] = []
|
||||
usage = LLMUsage.empty_usage()
|
||||
system_fingerprint: str | None = None
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
|
||||
first_chunk = next(chunks, None)
|
||||
if first_chunk is not None:
|
||||
if isinstance(first_chunk.delta.message.content, str):
|
||||
content += first_chunk.delta.message.content
|
||||
elif isinstance(first_chunk.delta.message.content, list):
|
||||
content_list.extend(first_chunk.delta.message.content)
|
||||
|
||||
if first_chunk.delta.message.tool_calls:
|
||||
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
|
||||
|
||||
usage = first_chunk.delta.usage or LLMUsage.empty_usage()
|
||||
system_fingerprint = first_chunk.system_fingerprint
|
||||
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=content or content_list,
|
||||
tool_calls=tools_calls,
|
||||
),
|
||||
usage=usage,
|
||||
system_fingerprint=system_fingerprint,
|
||||
)
|
||||
|
||||
|
||||
def _invoke_llm_via_plugin(
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
model_parameters: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_llm(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
model_parameters=model_parameters,
|
||||
prompt_messages=list(prompt_messages),
|
||||
tools=tools,
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_non_stream_plugin_result(
|
||||
model: str,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
result: Union[LLMResult, Iterator[LLMResultChunk]],
|
||||
) -> LLMResult:
|
||||
if isinstance(result, LLMResult):
|
||||
return result
|
||||
return _build_llm_result_from_first_chunk(model=model, prompt_messages=prompt_messages, chunks=result)
|
||||
|
||||
|
||||
def _increase_tool_call(
|
||||
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
|
||||
):
|
||||
@ -176,13 +40,42 @@ def _increase_tool_call(
|
||||
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
|
||||
"""
|
||||
|
||||
def get_tool_call(tool_call_id: str):
|
||||
"""
|
||||
Get or create a tool call by ID
|
||||
|
||||
:param tool_call_id: tool call ID
|
||||
:return: existing or new tool call
|
||||
"""
|
||||
if not tool_call_id:
|
||||
return existing_tools_calls[-1]
|
||||
|
||||
_tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
|
||||
if _tool_call is None:
|
||||
_tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||
)
|
||||
existing_tools_calls.append(_tool_call)
|
||||
|
||||
return _tool_call
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# generate ID for tool calls with function name but no ID to track them
|
||||
if new_tool_call.function.name and not new_tool_call.id:
|
||||
new_tool_call.id = _gen_tool_call_id()
|
||||
|
||||
tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id)
|
||||
_merge_tool_call_delta(tool_call, new_tool_call)
|
||||
# get tool call
|
||||
tool_call = get_tool_call(new_tool_call.id)
|
||||
# update tool call
|
||||
if new_tool_call.id:
|
||||
tool_call.id = new_tool_call.id
|
||||
if new_tool_call.type:
|
||||
tool_call.type = new_tool_call.type
|
||||
if new_tool_call.function.name:
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
|
||||
class LargeLanguageModel(AIModel):
|
||||
@ -248,7 +141,10 @@ class LargeLanguageModel(AIModel):
|
||||
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
|
||||
|
||||
try:
|
||||
result = _invoke_llm_via_plugin(
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
result = plugin_model_manager.invoke_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
@ -258,13 +154,38 @@ class LargeLanguageModel(AIModel):
|
||||
model_parameters=model_parameters,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
result = _normalize_non_stream_plugin_result(
|
||||
model=model, prompt_messages=prompt_messages, result=result
|
||||
content = ""
|
||||
content_list = []
|
||||
usage = LLMUsage.empty_usage()
|
||||
system_fingerprint = None
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
|
||||
for chunk in result:
|
||||
if isinstance(chunk.delta.message.content, str):
|
||||
content += chunk.delta.message.content
|
||||
elif isinstance(chunk.delta.message.content, list):
|
||||
content_list.extend(chunk.delta.message.content)
|
||||
if chunk.delta.message.tool_calls:
|
||||
_increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
|
||||
|
||||
usage = chunk.delta.usage or LLMUsage.empty_usage()
|
||||
system_fingerprint = chunk.system_fingerprint
|
||||
break
|
||||
|
||||
result = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=content or content_list,
|
||||
tool_calls=tools_calls,
|
||||
),
|
||||
usage=usage,
|
||||
system_fingerprint=system_fingerprint,
|
||||
)
|
||||
except Exception as e:
|
||||
self._trigger_invoke_error_callbacks(
|
||||
@ -504,21 +425,27 @@ class LargeLanguageModel(AIModel):
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
_run_callbacks(
|
||||
callbacks,
|
||||
event="on_before_invoke",
|
||||
invoke=lambda callback: callback.on_before_invoke(
|
||||
llm_instance=self,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback.on_before_invoke(
|
||||
llm_instance=self,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise e
|
||||
else:
|
||||
logger.warning(
|
||||
"Callback %s on_before_invoke failed with error %s", callback.__class__.__name__, e
|
||||
)
|
||||
|
||||
def _trigger_new_chunk_callbacks(
|
||||
self,
|
||||
@ -546,22 +473,26 @@ class LargeLanguageModel(AIModel):
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
_run_callbacks(
|
||||
callbacks,
|
||||
event="on_new_chunk",
|
||||
invoke=lambda callback: callback.on_new_chunk(
|
||||
llm_instance=self,
|
||||
chunk=chunk,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback.on_new_chunk(
|
||||
llm_instance=self,
|
||||
chunk=chunk,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise e
|
||||
else:
|
||||
logger.warning("Callback %s on_new_chunk failed with error %s", callback.__class__.__name__, e)
|
||||
|
||||
def _trigger_after_invoke_callbacks(
|
||||
self,
|
||||
@ -590,22 +521,28 @@ class LargeLanguageModel(AIModel):
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
_run_callbacks(
|
||||
callbacks,
|
||||
event="on_after_invoke",
|
||||
invoke=lambda callback: callback.on_after_invoke(
|
||||
llm_instance=self,
|
||||
result=result,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback.on_after_invoke(
|
||||
llm_instance=self,
|
||||
result=result,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise e
|
||||
else:
|
||||
logger.warning(
|
||||
"Callback %s on_after_invoke failed with error %s", callback.__class__.__name__, e
|
||||
)
|
||||
|
||||
def _trigger_invoke_error_callbacks(
|
||||
self,
|
||||
@ -634,19 +571,25 @@ class LargeLanguageModel(AIModel):
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
_run_callbacks(
|
||||
callbacks,
|
||||
event="on_invoke_error",
|
||||
invoke=lambda callback: callback.on_invoke_error(
|
||||
llm_instance=self,
|
||||
ex=ex,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback.on_invoke_error(
|
||||
llm_instance=self,
|
||||
ex=ex,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise e
|
||||
else:
|
||||
logger.warning(
|
||||
"Callback %s on_invoke_error failed with error %s", callback.__class__.__name__, e
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user