Compare commits

..

68 Commits
bug1 ... 1.8.1

Author SHA1 Message Date
c7700ac176 chore(docker): bump version (#25092)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-03 20:25:44 +08:00
d011ddfc64 chore(version): bump version to 1.8.1 (#25060) 2025-09-03 18:54:07 +08:00
67cc70ad61 fix: model credential name (#25081)
Co-authored-by: hjlarry <hjlarry@163.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-03 18:23:57 +08:00
a384ae9140 Fix advanced chat workflow event handler signature mismatch (#25078)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-03 16:22:13 +08:00
a7627882a7 fix: Boolean type control is not displayed (#25031)
Co-authored-by: WTW0313 <twwu@dify.ai>
2025-09-03 15:39:09 +08:00
8eae7a95be Hotfix translation error (#25035) 2025-09-03 15:23:04 +08:00
dabf266048 Fix: handle 204 No Content response in MCP client (#25040) 2025-09-03 15:22:42 +08:00
462e764a3c typevar example (#25064)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-03 14:54:38 +08:00
0e8a37dca8 chore: translate i18n files (#25061)
Co-authored-by: zxhlyh <16177003+zxhlyh@users.noreply.github.com>
2025-09-03 14:48:53 +08:00
bffbe54120 fix: Solve the problem of opening remarks appearing in the chat cont… (#25067) 2025-09-03 14:48:30 +08:00
b673560b92 feat: improve multi model credentials (#25009)
Co-authored-by: Claude <noreply@anthropic.com>
2025-09-03 13:52:31 +08:00
9e125e2029 Refactor/model credential (#24994) 2025-09-03 13:36:59 +08:00
b88146c443 chore: consolidate type checking in style workflow (#25053) 2025-09-03 13:34:43 +08:00
c40cb7fd59 [Chore/Refactor] Update .gitignore to exclude pyrightconfig.json while preserving api/pyrightconfig.json (#25055) 2025-09-03 13:34:07 +08:00
9d5956cef8 [Chore/Refactor] Switch from MyPy to Basedpyright for type checking (#25047)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-03 11:52:26 +08:00
1fff4620e6 clean console apis and rag cleans. (#25042)
Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-03 11:25:18 +08:00
c3820f55f4 chore: translate Chinese comments to English in ClickZetta Volume storage module (#25037) 2025-09-03 10:57:58 +08:00
60c5bdd62f fix: remove redundant z-index from Field component (#25034) 2025-09-03 10:39:07 +08:00
5092e5f631 fix: workflow not published (#25030) 2025-09-03 10:07:31 +08:00
c0bd35594e feat: add test containers based tests for tools manage service (#25028) 2025-09-03 09:20:16 +08:00
bc9efa7ea8 Refactor: use DatasourceType.XX.value instead of hardcoded (#25015)
Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-03 08:56:48 +08:00
f540d0b747 chore: remove ty type checker from reformat script and pre-commit hooks (#25021) 2025-09-03 08:56:23 +08:00
7bcaa513fa chore: remove duplicate test helper classes from api root directory (#25024) 2025-09-03 08:56:00 +08:00
d33dfee8a3 fix: EndUser is not bound to a Session (#25010) 2025-09-02 21:37:21 +08:00
b5216df4fe fix: xxx is not bound to a Session (#24966) 2025-09-02 21:37:06 +08:00
25a11bfafc Export DSL from history (#24939)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 21:36:52 +08:00
8fcc864fb7 Post fix of #23224 (#25007) 2025-09-02 20:59:08 +08:00
ed5ed0306e minor fix: fix the check of subscription capacity limit (#24991)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 19:14:30 +08:00
a418c43d32 example add more type check (#24999)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 19:13:43 +08:00
5aa8c9c8df fix: refresh UI after user profile change (#24998) 2025-09-02 18:57:35 +08:00
32972b45db fix: remove unnecessary modal visibility toggle on error in name save (#25001) 2025-09-02 18:57:24 +08:00
af351b1723 fix: ensure the modal closed by level (#24984) 2025-09-02 17:06:10 +08:00
af88266212 chore: run ty check CI action only when api code changed (#24986) 2025-09-02 16:59:11 +08:00
b14119b531 feat: add development environment setup commands to Makefile (#24976) 2025-09-02 16:24:21 +08:00
68c75f221b fix: workflow log status filter add parial success status (#24977) 2025-09-02 16:24:03 +08:00
7b379e2a61 chore: apply ty checks on api code with script and ci action (#24653) 2025-09-02 16:05:13 +08:00
c373b734bc feat: make secretInput type field prevent browser auto-fill (#24971) 2025-09-02 16:04:12 +08:00
2ac8f8003f refactor: update radio component to handle boolean values instead of numeric (#24956) 2025-09-02 15:11:42 +08:00
d6b3df8f6f fix: API Key Authorization Configuration Model Form render default value (#24963) 2025-09-02 14:52:05 +08:00
deea07e905 make clean() function in index_processor_base abstractmethod (#24959)
Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 14:48:45 +08:00
0caa94bd1c fix: add Indonesian (id-ID) language support and improve language selector (#24951) 2025-09-02 14:44:59 +08:00
a32dde5428 Fix: Resolve workflow_node_execution primary key conflicts with UUID v7 (#24643)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 14:18:29 +08:00
067b0d07c4 Fix: ensure InstalledApp deletion uses model instances instead of Row (#24942)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 11:59:38 +08:00
044f96bd93 feat: LLM prompt Jinja2 template now support more variables (#24944) 2025-09-02 11:59:31 +08:00
ca96350707 chore: optimize SQL queries that perform partial full table scans (#24786) 2025-09-02 11:46:11 +08:00
be3af1e234 Migrate SQLAlchemy from 1.x to 2.0 with automated and manual adjustments (#23224)
Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 10:30:19 +08:00
2e89d29c87 chore: translate i18n files (#24934)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-09-02 10:16:14 +08:00
e4eb9f7c55 fix(i18n): align zh-Hant indexMethodEconomyTip with zh-Hans (#24933) 2025-09-02 09:57:39 +08:00
znn
dd6547de06 downvote with reason (#24922) 2025-09-02 09:57:04 +08:00
84d09b8b8a fix: API key input uses password type and no autocomplete (#24864)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-09-02 09:37:24 +08:00
2c462154f7 fix: email input cannot scroll (#24930) 2025-09-02 09:35:53 +08:00
b810efdb3f Feature add test containers tool transform service (#24927) 2025-09-02 09:30:55 +08:00
ae04ccc445 fix: npx typo error (#24929) 2025-09-02 09:20:51 +08:00
f7ac1192ae replace the secret field from obfuscated to full-masked value (#24800)
Co-authored-by: charles liu <dearcharles.liu@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-02 09:19:20 +08:00
e048588a88 fix: remove duplicated code (#24893) 2025-09-02 08:58:31 +08:00
2042353526 fix:score threshold (#24897) 2025-09-02 08:58:14 +08:00
9486715929 FEAT: Tencent Vector optimize BM25 initialization to reduce loading time (#24915)
Co-authored-by: wlleiiwang <wlleiiwang@tencent.com>
2025-09-01 21:08:41 +08:00
64319c0d56 fix close session twice. (#24917)
Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-01 21:08:01 +08:00
acd209a890 fix: prevent database connection leaks in chatflow mode by using Session-managed queries (#24656)
Co-authored-by: 王锶奇 <wangsiqi2@tal.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-01 18:22:42 +08:00
bd482eb8ef fix wrong filter handle for saved messages (#24891)
Co-authored-by: zhuqingchao <zhuqingchao@xiaomi.com>
2025-09-01 16:32:08 +08:00
5b3cc560d5 fix:hard-coded top-k fallback issue. (#24879) 2025-09-01 15:46:37 +08:00
d41d4deaac example enum to StrEnum (#24877)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-01 15:40:26 +08:00
208ce4e774 CI: add TS indentation check via esLint (#24810) 2025-09-01 15:31:59 +08:00
414ee51975 fix: add missing form for boolean types (#24812)
Signed-off-by: jingfelix <jingfelix@outlook.com>
2025-09-01 15:21:36 +08:00
d5a521eef2 fix: Fix database connection leak in EasyUIBasedGenerateTaskPipeline (#24815) 2025-09-01 14:48:56 +08:00
1b401063e8 chore: pnpx deprecation (#24868) 2025-09-01 14:45:44 +08:00
60d9d0584a refactor: migrate marketplace.py from requests to httpx (#24015) 2025-09-01 14:28:21 +08:00
ffba341258 [CHORE]: remove redundant-cast (#24807) 2025-09-01 14:05:32 +08:00
427 changed files with 12886 additions and 4665 deletions

View File

@ -42,11 +42,7 @@ jobs:
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run ty check
run: |
cd api
uv add --dev ty
uv run ty check || true
- name: Run pyrefly check
run: |
cd api
@ -66,15 +62,6 @@ jobs:
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
- name: MyPy Cache
uses: actions/cache@v4
with:
path: api/.mypy_cache
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
- name: Run MyPy Checks
run: dev/mypy-check
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env

View File

@ -44,6 +44,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv sync --project api --dev
- name: Run Basedpyright Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: dev/basedpyright-check
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
@ -89,7 +93,9 @@ jobs:
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run lint
run: |
pnpm run lint
pnpm run eslint
docker-compose-template:
name: Docker Compose Template

5
.gitignore vendored
View File

@ -123,10 +123,12 @@ venv.bak/
# mkdocs documentation
/site
# mypy
# type checking
.mypy_cache/
.dmypy.json
dmypy.json
pyrightconfig.json
!api/pyrightconfig.json
# Pyre type checker
.pyre/
@ -195,7 +197,6 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json.template
!.vscode/README.md
pyrightconfig.json
api/.vscode
# vscode Code History Extension
.history

View File

@ -32,7 +32,7 @@ uv run --project api pytest tests/integration_tests/ # Integration tests
./dev/reformat # Run all formatters and linters
uv run --project api ruff check --fix ./ # Fix linting issues
uv run --project api ruff format ./ # Format code
uv run --project api mypy . # Type checking
uv run --directory api basedpyright # Type checking
```
### Frontend (Web)

View File

@ -4,6 +4,48 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
VERSION=latest
# Backend Development Environment Setup
.PHONY: dev-setup prepare-docker prepare-web prepare-api
# Default dev setup target
dev-setup: prepare-docker prepare-web prepare-api
@echo "✅ Backend development environment setup complete!"
# Step 1: Prepare Docker middleware
prepare-docker:
@echo "🐳 Setting up Docker middleware..."
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
@echo "✅ Docker middleware started"
# Step 2: Prepare web environment
prepare-web:
@echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install
@cd web && pnpm build
@echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment
prepare-api:
@echo "🔧 Setting up API environment..."
@cp -n api/.env.example api/.env 2>/dev/null || echo "API .env already exists"
@cd api && uv sync --dev
@cd api && uv run flask db upgrade
@echo "✅ API environment prepared (not started)"
# Clean dev environment
dev-clean:
@echo "⚠️ Stopping Docker containers..."
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
@echo "🗑️ Removing volumes..."
@rm -rf docker/volumes/db
@rm -rf docker/volumes/redis
@rm -rf docker/volumes/plugin_daemon
@rm -rf docker/volumes/weaviate
@rm -rf api/storage
@echo "✅ Cleanup complete"
# Build Docker images
build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@ -39,5 +81,21 @@ build-push-web: build-web push-web
build-push-all: build-all push-all
@echo "All Docker images have been built and pushed."
# Help target
help:
@echo "Development Setup Targets:"
@echo " make dev-setup - Run all setup steps for backend dev environment"
@echo " make prepare-docker - Set up Docker middleware"
@echo " make prepare-web - Set up web environment"
@echo " make prepare-api - Set up API environment"
@echo " make dev-clean - Stop Docker middleware containers"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
@echo " make build-api - Build API Docker image"
@echo " make build-all - Build all Docker images"
@echo " make push-all - Push all Docker images"
@echo " make build-push-all - Build and push all Docker images"
# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help

View File

@ -108,5 +108,5 @@ uv run celery -A app.celery beat
../dev/reformat # Run all formatters and linters
uv run ruff check --fix ./ # Fix linting issues
uv run ruff format ./ # Format code
uv run mypy . # Type checking
uv run basedpyright . # Type checking
```

View File

@ -1,11 +0,0 @@
from tests.integration_tests.utils.parent_class import ParentClass
class ChildClass(ParentClass):
"""Test child class for module import helper tests"""
def __init__(self, name):
super().__init__(name)
def get_name(self):
return f"Child: {self.name}"

View File

@ -571,7 +571,7 @@ def old_metadata_migration():
for document in documents:
if document.doc_metadata:
doc_metadata = document.doc_metadata
for key, value in doc_metadata.items():
for key in doc_metadata:
for field in BuiltInField:
if field.value == key:
break

View File

@ -29,7 +29,7 @@ class NacosSettingsSource(RemoteSettingsSource):
try:
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
self.remote_configs = self._parse_config(content)
except Exception as e:
except Exception:
logger.exception("[get-access-token] exception occurred")
raise

View File

@ -27,7 +27,7 @@ class NacosHttpClient:
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status()
return response.text
except requests.exceptions.RequestException as e:
except requests.RequestException as e:
return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers, params, module="config"):
@ -77,6 +77,6 @@ class NacosHttpClient:
self.token = response_data.get("accessToken")
self.token_ttl = response_data.get("tokenTtl", 18000)
self.token_expire_time = current_time + self.token_ttl - 10
except Exception as e:
except Exception:
logger.exception("[get-access-token] exception occur")
raise

View File

@ -19,6 +19,7 @@ language_timezone_mapping = {
"fa-IR": "Asia/Tehran",
"sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
}
languages = list(language_timezone_mapping.keys())

View File

@ -130,15 +130,19 @@ class InsertExploreAppApi(Resource):
app.is_public = False
with Session(db.engine) as session:
installed_apps = session.execute(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
installed_apps = (
session.execute(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
)
).all()
.scalars()
.all()
)
for installed_app in installed_apps:
db.session.delete(installed_app)
for installed_app in installed_apps:
session.delete(installed_app)
db.session.delete(recommended_app)
db.session.commit()

View File

@ -84,7 +84,7 @@ class BaseApiKeyListResource(Resource):
flask_restx.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
custom="max_keys_exceeded",
)
key = ApiToken.generate_api_key(self.token_prefix, 24)

View File

@ -237,9 +237,14 @@ class AppExportApi(Resource):
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
parser.add_argument("workflow_id", type=str, location="args")
args = parser.parse_args()
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
return {
"data": AppDslService.export_dsl(
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
)
}
class AppNameApi(Resource):

View File

@ -526,7 +526,7 @@ class PublishedWorkflowApi(Resource):
)
app_model.workflow_id = workflow.id
db.session.commit()
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
workflow_created_at = TimestampField().format(workflow.created_at)

View File

@ -27,7 +27,9 @@ class WorkflowAppLogApi(Resource):
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)

View File

@ -81,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400
try:
oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e:
except requests.HTTPError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)
@ -104,7 +104,7 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400
try:
oauth_provider.sync_data_source(binding_id)
except requests.exceptions.HTTPError as e:
except requests.HTTPError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)

View File

@ -130,7 +130,7 @@ class ResetPasswordSendEmailApi(Resource):
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
@ -162,7 +162,7 @@ class EmailCodeLoginSendEmailApi(Resource):
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
@ -200,7 +200,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args["token"])
try:
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
@ -223,7 +223,7 @@ class EmailCodeLoginApi(Resource):
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()

View File

@ -80,7 +80,7 @@ class OAuthCallback(Resource):
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.RequestException as e:
except requests.RequestException as e:
error_text = e.response.text if e.response else str(e)
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400

View File

@ -10,6 +10,7 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
@ -214,7 +215,7 @@ class DataSourceNotionApi(Resource):
workspace_id = notion_info["workspace_id"]
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],

View File

@ -22,6 +22,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@ -422,7 +423,9 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "notion_import":
@ -431,7 +434,7 @@ class DatasetIndexingEstimateApi(Resource):
workspace_id = notion_info["workspace_id"]
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
@ -445,7 +448,7 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],

View File

@ -40,6 +40,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.document_fields import (
@ -354,9 +355,6 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
knowledge_config = KnowledgeConfig(**args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
@ -428,7 +426,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file, document_model=document.doc_form
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
)
indexing_runner = IndexingRunner()
@ -488,13 +486,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import":
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
@ -506,7 +504,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],

View File

@ -61,7 +61,6 @@ class ConversationApi(InstalledAppResource):
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}, 204

View File

@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()

View File

@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource):
model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
config_from=args.get("config_from", ""),
)
if args.get("config_from", "") == "predefined-model":
@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource):
choices=[mt.value for mt in ModelType],
location="json",
)
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource):
)
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()

View File

@ -55,7 +55,7 @@ class AudioApi(Resource):
file = request.files["file"]
try:
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:

View File

@ -59,7 +59,7 @@ class FilePreviewApi(Resource):
args = file_preview_parser.parse_args()
# Validate file ownership and get file objects
message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
# Get file content generator
try:

View File

@ -410,7 +410,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
documents, _ = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,

View File

@ -1,7 +1,7 @@
import time
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from enum import StrEnum, auto
from functools import wraps
from typing import Optional
@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
class WhereisUserArg(Enum):
class WhereisUserArg(StrEnum):
"""
Enum for whereis_user_arg.
"""
QUERY = "query"
JSON = "json"
FORM = "form"
QUERY = auto()
JSON = auto()
FORM = auto()
class FetchUserArg(BaseModel):
@ -291,27 +291,28 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
if not user_id:
user_id = "DEFAULT-USER"
end_user = (
db.session.query(EndUser)
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
with Session(db.engine, expire_on_commit=False) as session:
end_user = (
session.query(EndUser)
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
)
.first()
)
.first()
)
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id,
)
db.session.add(end_user)
db.session.commit()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id,
)
session.add(end_user)
session.commit()
return end_user

View File

@ -73,8 +73,6 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, end_user)
return {"result": "success"}, 204

View File

@ -4,6 +4,7 @@ from functools import wraps
from flask import request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@ -49,18 +50,19 @@ def decode_jwt_token():
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
app_model = db.session.scalar(select(App).where(App.id == app_id))
site = db.session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
raise BadRequest("Site URL is no longer valid.")
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound()
with Session(db.engine, expire_on_commit=False) as session:
app_model = session.scalar(select(App).where(App.id == app_id))
site = session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
raise BadRequest("Site URL is no longer valid.")
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound()
# for enterprise webapp auth
app_web_auth_enabled = False

View File

@ -334,7 +334,8 @@ class BaseAgentRunner(AppRunner):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
agent_thought = db.session.scalar(stmt)
if not agent_thought:
raise ValueError("agent thought not found")
@ -492,7 +493,8 @@ class BaseAgentRunner(AppRunner):
return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
stmt = select(MessageFile).where(MessageFile.message_id == message.id)
files = db.session.scalars(stmt).all()
if not files:
return UserPromptMessage(content=message.query)
if message.app_model_config:

View File

@ -26,7 +26,7 @@ class MoreLikeThisConfigManager:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError as e:
except ValidationError:
raise ValueError(
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
)

View File

@ -450,6 +450,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
db.session.refresh(workflow)
db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,

View File

@ -72,7 +72,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
if not app_record:
raise ValueError("App not found")
@ -140,7 +142,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables),
conversation_variables=conversation_variables,
)
# init graph

View File

@ -118,7 +118,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else:
response_chunk.update(sub_stream_response.to_dict())

View File

@ -73,7 +73,6 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
@ -311,13 +310,8 @@ class AdvancedChatAppGenerateTaskPipeline:
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline._error_to_stream_response(err)
def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs
) -> Generator[StreamResponse, None, None]:
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
# Override graph runtime state - this is a side effect but necessary
graph_runtime_state = event.graph_runtime_state
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
@ -338,15 +332,14 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_retry_resp:
yield node_retry_resp
@ -380,13 +373,12 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
@ -939,10 +931,6 @@ class AdvancedChatAppGenerateTaskPipeline:
self._task_state.metadata.usage = usage
else:
self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(app_stmt)
if not app_record:
raise ValueError("App not found")
@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).where(Message.id == message.id).first()
msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
if message_result is None:
raise ValueError("Message not found")
db.session.close()

View File

@ -1,7 +1,7 @@
import queue
import time
from abc import abstractmethod
from enum import Enum
from enum import IntEnum, auto
from typing import Any, Optional
from sqlalchemy.orm import DeclarativeMeta
@ -19,9 +19,9 @@ from core.app.entities.queue_entities import (
from extensions.ext_redis import redis_client
class PublishFrom(Enum):
APPLICATION_MANAGER = 1
TASK_PIPELINE = 2
class PublishFrom(IntEnum):
APPLICATION_MANAGER = auto()
TASK_PIPELINE = auto()
class AppQueueManager:
@ -159,7 +159,7 @@ class AppQueueManager:
def _check_for_sqlalchemy_models(self, data: Any):
# from entity to dict or list
if isinstance(data, dict):
for key, value in data.items():
for value in data.values():
self._check_for_sqlalchemy_models(value)
elif isinstance(data, list):
for item in data:

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")

View File

@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from sqlalchemy import select
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
message = (
db.session.query(Message)
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
stmt = select(Message).where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
message = db.session.scalar(stmt)
if not message:
raise MessageNotExistsError()

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")

View File

@ -3,6 +3,9 @@ import logging
from collections.abc import Generator
from typing import Optional, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -83,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
if conversation:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
stmt = select(AppModelConfig).where(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
app_model_config = db.session.scalar(stmt)
if not app_model_config:
raise AppModelConfigBrokenError()
@ -253,7 +255,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
with Session(db.engine, expire_on_commit=False) as session:
conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))
if not conversation:
raise ConversationNotExistsError("Conversation not exists")
@ -266,7 +269,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
message = db.session.query(Message).where(Message.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message = session.scalar(select(Message).where(Message.id == message_id))
if message is None:
raise MessageNotExistsError("Message not exists")

View File

@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk

View File

@ -300,16 +300,15 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
@ -11,7 +11,7 @@ from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
class InvokeFrom(Enum):
class InvokeFrom(StrEnum):
"""
Invoke From.
"""

View File

@ -1,6 +1,8 @@
import logging
from typing import Optional
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
@ -25,9 +27,8 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from
:return:
"""
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
)
stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
annotation_setting = db.session.scalar(stmt)
if not annotation_setting:
return None

View File

@ -96,7 +96,11 @@ class RateLimit:
if isinstance(generator, Mapping):
return generator
else:
return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
return RateLimitGenerator(
rate_limit=self,
generator=generator, # ty: ignore [invalid-argument-type]
request_id=request_id,
)
class RateLimitGenerator:

View File

@ -50,7 +50,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
err = e
err = e # ty: ignore [invalid-assignment]
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))

View File

@ -472,9 +472,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param event: agent thought event
:return:
"""
agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: Optional[MessageAgentThought] = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
if agent_thought:
return AgentThoughtStreamResponse(

View File

@ -3,6 +3,8 @@ from threading import Thread
from typing import Optional, Union
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import (
@ -84,7 +86,8 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation = db.session.scalar(stmt)
if not conversation:
return
@ -98,7 +101,7 @@ class MessageCycleManager:
try:
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
conversation.name = name
except Exception as e:
except Exception:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
pass
@ -143,7 +146,8 @@ class MessageCycleManager:
:param event: event
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None:
# get tool file id
@ -183,7 +187,8 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(

View File

@ -1,6 +1,8 @@
import logging
from collections.abc import Sequence
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler:
for document in documents:
if document.metadata is not None:
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id)
dataset_document = db.session.scalar(dataset_document_stmt)
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
@ -57,17 +60,14 @@ class DatasetIndexToolCallbackHandler:
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
_ = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(

View File

@ -1,5 +1,6 @@
import json
import logging
import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel):
with Session(db.engine) as new_session:
return _validate(new_session)
def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
def _generate_provider_credential_name(self, session) -> str:
"""
Generate a unique credential name for provider.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
),
)
def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
"""
Generate a unique credential name for custom model.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
)
def _generate_next_api_key_name(self, session, query_factory) -> str:
"""
Generate next available API KEY name by finding the highest numbered suffix.
:param session: database session
:param query_factory: function that returns the SQLAlchemy query
:return: next available API KEY name
"""
try:
stmt = query_factory()
credential_records = session.execute(stmt).scalars().all()
if not credential_records:
return "API KEY 1"
# Extract numbers from API KEY pattern using list comprehension
pattern = re.compile(r"^API KEY\s+(\d+)$")
numbers = [
int(match.group(1))
for cr in credential_records
if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
]
# Return next sequential number
next_number = max(numbers, default=0) + 1
return f"API KEY {next_number}"
except Exception as e:
logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1"
def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
"""
Add custom provider credentials.
:param credentials: provider credentials
@ -351,8 +410,11 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
if credential_name:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
provider_record = self._get_provider_record(session)
@ -395,7 +457,7 @@ class ProviderConfiguration(BaseModel):
self,
credentials: dict,
credential_id: str,
credential_name: str,
credential_name: str | None,
) -> None:
"""
update a saved provider credential (by credential_id).
@ -406,7 +468,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
@ -428,9 +490,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()
if provider_record and provider_record.credential_id == credential_id:
@ -532,13 +594,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)
# Check if this is the currently active credential
provider_record = self._get_provider_record(session)
@ -822,7 +878,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
) -> None:
"""
Create a custom model credential.
@ -833,10 +889,15 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
if credential_name:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=session
)
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session
@ -880,7 +941,7 @@ class ProviderConfiguration(BaseModel):
raise
def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
) -> None:
"""
Update a custom model credential.
@ -893,7 +954,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
@ -925,8 +986,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()
if provider_model_record and provider_model_record.credential_id == credential_id:
@ -982,12 +1044,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)
# Check if this is the currently active credential
provider_model_record = self._get_custom_model_record(model_type, model, session=session)
@ -1054,6 +1111,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
is_valid=True,
credential_id=credential_id,
)
else:
@ -1605,11 +1663,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "custom_model"
]
if len(provider_model_lb_configs) > 1:
load_balancing_enabled = True
if any(config.name == "__delete__" for config in provider_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# when the user enable load_balancing but available configs are less than 2 display warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
provider_models.append(
ModelWithProviderEntity(
@ -1631,6 +1687,8 @@ class ProviderConfiguration(BaseModel):
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type not in model_types:
continue
if model_configuration.unadded_to_model_list:
continue
if model and model != model_configuration.model:
continue
try:
@ -1663,11 +1721,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "provider"
]
if len(custom_model_lb_configs) > 1:
load_balancing_enabled = True
if any(config.name == "__delete__" for config in custom_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# when the user enable load_balancing but available configs are less than 2 display warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
status = ModelStatus.CREDENTIAL_REMOVED

View File

@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel):
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_model_credentials: list[CredentialConfiguration] = []
unadded_to_model_list: Optional[bool] = False
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
class UnaddedModelConfiguration(BaseModel):
"""
Model class for provider unadded model configuration.
"""
model: str
model_type: ModelType
class CustomConfiguration(BaseModel):
"""
Model class for provider custom configuration.
@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel):
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []
can_added_models: list[UnaddedModelConfiguration] = []
class ModelLoadBalancingConfiguration(BaseModel):
@ -144,6 +155,7 @@ class ModelSettings(BaseModel):
model: str
model_type: ModelType
enabled: bool = True
load_balancing_enabled: bool = False
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
# pydantic configs

View File

@ -43,9 +43,9 @@ class APIBasedExtensionRequestor:
timeout=self.timeout,
proxies=proxies,
)
except requests.exceptions.Timeout:
except requests.Timeout:
raise ValueError("request timeout")
except requests.exceptions.ConnectionError:
except requests.ConnectionError:
raise ValueError("request connection error")
if response.status_code != 200:

View File

@ -91,7 +91,7 @@ class Extensible:
# Find extension class
extension_class = None
for name, obj in vars(mod).items():
for obj in vars(mod).values():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
extension_class = obj
break
@ -123,7 +123,7 @@ class Extensible:
)
)
except Exception as e:
except Exception:
logger.exception("Error scanning extensions")
raise

View File

@ -41,9 +41,3 @@ class Extension:
assert module_extension.extension_class is not None
t: type = module_extension.extension_class
return t
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
module_extension = self.module_extension(module, extension_name)
form_schema = module_extension.form_schema
# TODO validate form_schema

View File

@ -1,5 +1,7 @@
from typing import Optional
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool
from core.helper import encrypter
@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
api_based_extension = db.session.scalar(stmt)
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool):
raise ValueError(f"config is required, config: {self.config}")
api_based_extension_id = self.config.get("api_based_extension_id")
assert api_based_extension_id is not None, "api_based_extension_id is required"
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id
)
api_based_extension = db.session.scalar(stmt)
if not api_based_extension:
raise ValueError(

View File

@ -22,7 +22,6 @@ class ExternalDataToolFactory:
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
# FIXME mypy issue here, figure out how to fix it
extension_class.validate_config(tenant_id, config) # type: ignore

View File

@ -3,7 +3,7 @@ import base64
from libs import rsa
def obfuscated_token(token: str):
def obfuscated_token(token: str) -> str:
if not token:
return token
if len(token) <= 8:
@ -11,6 +11,10 @@ def obfuscated_token(token: str):
return token[:6] + "*" * 12 + token[-2:]
def full_mask_token(token_length=20):
return "*" * token_length
def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
from models.engine import db

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence
import requests
import httpx
from yarl import URL
from configs import dify_config
@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
@ -36,13 +36,13 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration(**plugin))
except Exception as e:
except Exception:
pass
return result
@ -50,5 +50,5 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status()

View File

@ -47,7 +47,7 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
def load_single_subclass_from_source(
*, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
*, module_name: str, script_path: str, parent_type: type, use_lazy_loader: bool = False
) -> type:
"""
Load a single subclass from the source

View File

@ -1,7 +1,7 @@
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any
from typing import TypeVar
from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file
@ -72,11 +72,14 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
return position_map
T = TypeVar("T")
def is_filtered(
include_set: set[str],
exclude_set: set[str],
data: Any,
name_func: Callable[[Any], str],
data: T,
name_func: Callable[[T], str],
) -> bool:
"""
Check if the object should be filtered out.
@ -103,9 +106,9 @@ def is_filtered(
def sort_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> list[Any]:
data: list[T],
name_func: Callable[[T], str],
):
"""
Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end.
@ -122,9 +125,9 @@ def sort_by_position_map(
def sort_to_dict_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
data: list[T],
name_func: Callable[[T], str],
):
"""
Sort the objects into a ordered dict by the position map.
If the name of the object is not in the position map, it will be put at the end.
@ -134,4 +137,4 @@ def sort_to_dict_by_position_map(
:return: an OrderedDict with the sorted pairs of name and object
"""
sorted_items = sort_by_position_map(position_map, data, name_func)
return OrderedDict([(name_func(item), item) for item in sorted_items])
return OrderedDict((name_func(item), item) for item in sorted_items)

View File

@ -5,9 +5,10 @@ import re
import threading
import time
import uuid
from typing import Any, Optional, cast
from typing import Any, Optional
from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
@ -18,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
@ -56,13 +58,11 @@ class IndexingRunner:
if not dataset:
raise ValueError("no dataset found")
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
stmt = select(DatasetProcessRule).where(
DatasetProcessRule.id == dataset_document.dataset_process_rule_id
)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
@ -123,11 +123,8 @@ class IndexingRunner:
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
@ -208,7 +205,6 @@ class IndexingRunner:
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# build index
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
@ -310,7 +306,8 @@ class IndexingRunner:
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
stmt = select(UploadFile).where(UploadFile.id == upload_file_id)
image_file = db.session.scalar(stmt)
if image_file is None:
continue
try:
@ -339,14 +336,14 @@ class IndexingRunner:
if dataset_document.data_source_type == "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
file_detail = (
db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
)
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()
if file_detail:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "notion_import":
@ -357,7 +354,7 @@ class IndexingRunner:
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
@ -377,7 +374,7 @@ class IndexingRunner:
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
@ -400,7 +397,6 @@ class IndexingRunner:
)
# replace doc id to document model id
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs:
if text_doc.metadata is not None:
text_doc.metadata["document_id"] = dataset_document.id

View File

@ -56,11 +56,8 @@ class LLMGenerator:
prompts = [UserPromptMessage(content=prompt)]
with measure_time() as timer:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
answer = cast(str, response.message.content)
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
@ -69,7 +66,7 @@ class LLMGenerator:
try:
result_dict = json.loads(cleaned_answer)
answer = result_dict["Your Output"]
except json.JSONDecodeError as e:
except json.JSONDecodeError:
logger.exception("Failed to generate name after answer, use query instead")
answer = query
name = answer.strip()
@ -113,13 +110,10 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
)
text_content = response.message.get_text_content()
@ -162,11 +156,8 @@ class LLMGenerator:
)
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
rule_config["prompt"] = cast(str, response.message.content)
@ -212,11 +203,8 @@ class LLMGenerator:
try:
try:
# the first step to generate the task prompt
prompt_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
prompt_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
except InvokeError as e:
error = str(e)
@ -248,11 +236,8 @@ class LLMGenerator:
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
try:
parameter_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
),
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
except InvokeError as e:
@ -260,11 +245,8 @@ class LLMGenerator:
error_step = "generate variables"
try:
statement_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
),
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
rule_config["opening_statement"] = cast(str, statement_content.message.content)
except InvokeError as e:
@ -307,11 +289,8 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = model_config.get("completion_params", {})
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_code = cast(str, response.message.content)
@ -338,13 +317,10 @@ class LLMGenerator:
prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
)
answer = cast(str, response.message.content)
@ -367,11 +343,8 @@ class LLMGenerator:
model_parameters = model_config.get("model_parameters", {})
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
raw_content = response.message.content
@ -555,11 +528,8 @@ class LLMGenerator:
model_parameters = {"temperature": 0.4}
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_raw = cast(str, response.message.content)

View File

@ -101,7 +101,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
"""Check if the server supports OAuth 2.0 Resource Discovery."""
b_scheme, b_netloc, b_path, b_params, b_query, b_fragment = urlparse(server_url, "", True)
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
if b_query:
url_for_resource_discovery += f"?{b_query}"
@ -117,7 +117,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
else:
return False, ""
return False, ""
except httpx.RequestError as e:
except httpx.RequestError:
# Not support resource discovery, fall back to well-known OAuth metadata
return False, ""

View File

@ -246,6 +246,10 @@ class StreamableHTTPTransport:
logger.debug("Received 202 Accepted")
return
if response.status_code == 204:
logger.debug("Received 204 No Content")
return
if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
self._send_session_terminated_error(

View File

@ -2,7 +2,7 @@ import logging
from collections.abc import Callable
from contextlib import AbstractContextManager, ExitStack
from types import TracebackType
from typing import Any, Optional, cast
from typing import Any, Optional
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
@ -116,8 +116,7 @@ class MCPClient:
self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session)
session.initialize()
self._session.initialize()
return
except MCPAuthError:

View File

@ -110,9 +110,9 @@ class TokenBufferMemory:
else:
message_limit = 500
stmt = stmt.limit(message_limit)
msg_limit_stmt = stmt.limit(message_limit)
messages = db.session.scalars(stmt).all()
messages = db.session.scalars(msg_limit_stmt).all()
# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message

View File

@ -158,8 +158,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
@ -188,8 +186,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
int,
self._round_robin_invoke(
@ -214,8 +210,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
TextEmbeddingResult,
self._round_robin_invoke(
@ -237,8 +231,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
list[int],
self._round_robin_invoke(
@ -269,8 +261,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return cast(
RerankResult,
self._round_robin_invoke(
@ -295,8 +285,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return cast(
bool,
self._round_robin_invoke(
@ -318,8 +306,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return cast(
str,
self._round_robin_invoke(
@ -343,8 +329,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return cast(
Iterable[bytes],
self._round_robin_invoke(
@ -404,8 +388,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language
)

View File

@ -1,6 +1,7 @@
from typing import Optional
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token
@ -87,10 +88,9 @@ class ApiModeration(Moderation):
@staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
extension = db.session.scalar(stmt)
return extension

View File

@ -20,7 +20,6 @@ class ModerationFactory:
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
# FIXME: mypy error, try to fix it instead of using type: ignore
extension_class.validate_config(tenant_id, config) # type: ignore

View File

@ -135,7 +135,7 @@ class OutputModeration(BaseModel):
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
except Exception:
logger.exception("Moderation Output error, app_id: %s", app_id)
return None

View File

@ -5,6 +5,7 @@ from typing import Optional
from urllib.parse import urljoin
from opentelemetry.trace import Link, Status, StatusCode
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@ -263,15 +264,15 @@ class AliyunDataTrace(BaseTraceInstance):
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).where(App.id == app_id).first()
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first()
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (

View File

@ -72,7 +72,7 @@ class TraceClient:
else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False
except requests.exceptions.RequestException as e:
except requests.RequestException as e:
logger.debug("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}")

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.ops.entities.config_entity import BaseTracingConfig
@ -44,14 +45,15 @@ class BaseTraceInstance(ABC):
"""
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app = session.query(App).where(App.id == app_id).first()
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first()
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")

View File

@ -226,9 +226,9 @@ class OpsTraceManager:
if not trace_config_data:
return None
# decrypt_token
app = db.session.query(App).where(App.id == app_id).first()
stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
if not app:
raise ValueError("App not found")
@ -295,20 +295,19 @@ class OpsTraceManager:
@classmethod
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
message_data = db.session.query(Message).where(Message.id == message_id).first()
message_stmt = select(Message).where(Message.id == message_id)
message_data = db.session.scalar(message_stmt)
if not message_data:
return None
conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation_data = db.session.scalar(conversation_stmt)
if not conversation_data:
return None
if conversation_data.app_model_config_id:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation_data.app_model_config_id)
.first()
)
config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
app_model_config = db.session.scalar(config_stmt)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
app_model_config = conversation_data.override_model_configs
@ -850,7 +849,7 @@ class TraceQueueManager:
if self.trace_instance:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception as e:
except Exception:
logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
finally:
self.start_timer()
@ -869,7 +868,7 @@ class TraceQueueManager:
tasks = self.collect_tasks()
if tasks:
self.send_to_celery(tasks)
except Exception as e:
except Exception:
logger.exception("Error processing trace tasks")
def start_timer(self):

View File

@ -1,6 +1,8 @@
from collections.abc import Generator, Mapping
from typing import Optional, Union
from sqlalchemy import select
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -192,10 +194,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
get the user by user id
"""
user = db.session.query(EndUser).where(EndUser.id == user_id).first()
stmt = select(EndUser).where(EndUser.id == user_id)
user = db.session.scalar(stmt)
if not user:
user = db.session.query(Account).where(Account.id == user_id).first()
stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(stmt)
if not user:
raise ValueError("user not found")

View File

@ -64,7 +64,7 @@ class BasePluginClient:
response = requests.request(
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
)
except requests.exceptions.ConnectionError:
except requests.ConnectionError:
logger.exception("Request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import TypeVar, Union, cast
from typing import TypeVar, Union
from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
@ -85,7 +85,7 @@ def merge_blob_chunks(
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]),
meta=resp.meta,
)
yield cast(MessageType, merged_message)
yield merged_message
# Clean up the buffer
del files[chunk_id]
else:

View File

@ -87,7 +87,6 @@ class PromptMessageUtil:
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)

View File

@ -1,6 +1,7 @@
import contextlib
import json
from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError
from typing import Any, Optional, cast
@ -22,6 +23,7 @@ from core.entities.provider_entities import (
QuotaConfiguration,
QuotaUnit,
SystemConfiguration,
UnaddedModelConfiguration,
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
@ -154,8 +156,8 @@ class ProviderManager:
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):
@ -276,15 +278,11 @@ class ProviderManager:
:param model_type: model type
:return:
"""
# Get the corresponding TenantDefaultModel record
default_model = (
db.session.query(TenantDefaultModel)
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
.first()
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
default_model = db.session.scalar(stmt)
# If it does not exist, get the first available provider model from get_configurations
# and update the TenantDefaultModel record
@ -367,16 +365,11 @@ class ProviderManager:
model_names = [model.model for model in available_models]
if model not in model_names:
raise ValueError(f"Model {model} does not exist.")
# Get the list of available models from get_configurations and check if it is LLM
default_model = (
db.session.query(TenantDefaultModel)
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
.first()
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
default_model = db.session.scalar(stmt)
# create or update TenantDefaultModel record
if default_model:
@ -546,6 +539,23 @@ class ProviderManager:
for credential in available_credentials
]
@staticmethod
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
"""
Get all the credentials records from ProviderModelCredential by provider_name
:param tenant_id: workspace id
:param provider_name: provider name
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
)
all_credentials = session.scalars(stmt).all()
return all_credentials
@staticmethod
def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
@ -598,16 +608,13 @@ class ProviderManager:
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError:
db.session.rollback()
existed_provider_record = (
db.session.query(Provider)
.where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
)
.first()
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
continue
@ -635,6 +642,44 @@ class ProviderManager:
:param provider_model_records: provider model records
:return:
"""
# Get custom provider configuration
custom_provider_configuration = self._get_custom_provider_configuration(
tenant_id, provider_entity, provider_records
)
# Get all model credentials once
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)
# Get custom models which have not been added to the model list yet
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
# Get custom model configurations
custom_model_configurations = self._get_custom_model_configurations(
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
)
can_added_models = [
UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models
]
return CustomConfiguration(
provider=custom_provider_configuration,
models=custom_model_configurations,
can_added_models=can_added_models,
)
def _get_custom_provider_configuration(
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
) -> CustomProviderConfiguration | None:
"""Get custom provider configuration."""
# Find custom provider record (non-system)
custom_provider_record = next(
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
)
if not custom_provider_record:
return None
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
@ -642,113 +687,98 @@ class ProviderManager:
else []
)
# Get custom provider record
custom_provider_record = None
for provider_record in provider_records:
if provider_record.provider_type == ProviderType.SYSTEM.value:
continue
# Get and decrypt provider credentials
provider_credentials = self._get_and_decrypt_credentials(
tenant_id=tenant_id,
record_id=custom_provider_record.id,
encrypted_config=custom_provider_record.encrypted_config,
secret_variables=provider_credential_secret_variables,
cache_type=ProviderCredentialsCacheType.PROVIDER,
is_provider=True,
)
custom_provider_record = provider_record
return CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get custom provider credentials
custom_provider_configuration = None
if custom_provider_record:
provider_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=custom_provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
)
def _get_can_added_models(
self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential]
) -> list[dict]:
"""Get the custom models and credentials from enterprise version which haven't add to the model list"""
existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records}
# Get cached provider credentials
cached_provider_credentials = provider_credentials_cache.get()
# Get not added custom models credentials
not_added_custom_models_credentials = [
credential
for credential in all_model_credentials
if (credential.model_name, credential.model_type) not in existing_model_set
]
if not cached_provider_credentials:
try:
# fix origin data
if custom_provider_record.encrypted_config is None:
provider_credentials = {}
elif not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
# Group credentials by model
model_to_credentials = defaultdict(list)
for credential in not_added_custom_models_credentials:
model_to_credentials[(credential.model_name, credential.model_type)].append(credential)
# Get decoding rsa key and cipher for decrypting credentials
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
return [
{
"model": model_key[0],
"model_type": ModelType.value_of(model_key[1]),
"available_model_credentials": [
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
for cred in creds
],
}
for model_key, creds in model_to_credentials.items()
]
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
with contextlib.suppress(ValueError):
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable) or "", # type: ignore
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# cache provider credentials
provider_credentials_cache.set(credentials=provider_credentials)
else:
provider_credentials = cached_provider_credentials
custom_provider_configuration = CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get provider model credential secret variables
def _get_custom_model_configurations(
self,
tenant_id: str,
provider_entity: ProviderEntity,
provider_model_records: list[ProviderModel],
can_added_models: list[dict],
all_model_credentials: Sequence[ProviderModelCredential],
) -> list[CustomModelConfiguration]:
"""Get custom model configurations."""
# Get model credential secret variables
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema
else []
)
# Get custom provider model credentials
# Create credentials lookup for efficient access
credentials_map = defaultdict(list)
for credential in all_model_credentials:
credentials_map[(credential.model_name, credential.model_type)].append(credential)
custom_model_configurations = []
# Process existing model records
for provider_model_record in provider_model_records:
available_model_credentials = self.get_provider_model_available_credentials(
tenant_id,
provider_model_record.provider_name,
provider_model_record.model_name,
provider_model_record.model_type,
# Use pre-fetched credentials instead of individual database calls
available_model_credentials = [
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
for cred in credentials_map.get(
(provider_model_record.model_name, provider_model_record.model_type), []
)
]
# Get and decrypt model credentials
provider_model_credentials = self._get_and_decrypt_credentials(
tenant_id=tenant_id,
record_id=provider_model_record.id,
encrypted_config=provider_model_record.encrypted_config,
secret_variables=model_credential_secret_variables,
cache_type=ProviderCredentialsCacheType.MODEL,
is_provider=False,
)
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
)
# Get cached provider model credentials
cached_provider_model_credentials = provider_model_credentials_cache.get()
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
try:
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
except JSONDecodeError:
continue
# Get decoding rsa key and cipher for decrypting credentials
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in model_credential_secret_variables:
if variable in provider_model_credentials:
with contextlib.suppress(ValueError):
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# cache provider model credentials
provider_model_credentials_cache.set(credentials=provider_model_credentials)
else:
provider_model_credentials = cached_provider_model_credentials
custom_model_configurations.append(
CustomModelConfiguration(
model=provider_model_record.model_name,
@ -760,7 +790,71 @@ class ProviderManager:
)
)
return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations)
# Add models that can be added
for model in can_added_models:
custom_model_configurations.append(
CustomModelConfiguration(
model=model["model"],
model_type=model["model_type"],
credentials=None,
current_credential_id=None,
current_credential_name=None,
available_model_credentials=model["available_model_credentials"],
unadded_to_model_list=True,
)
)
return custom_model_configurations
def _get_and_decrypt_credentials(
self,
tenant_id: str,
record_id: str,
encrypted_config: str | None,
secret_variables: list[str],
cache_type: ProviderCredentialsCacheType,
is_provider: bool = False,
) -> dict:
"""Get and decrypt credentials with caching."""
credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=record_id,
cache_type=cache_type,
)
# Try to get from cache first
cached_credentials = credentials_cache.get()
if cached_credentials:
return cached_credentials
# Parse encrypted config
if not encrypted_config:
return {}
if is_provider and not encrypted_config.startswith("{"):
return {"openai_api_key": encrypted_config}
try:
credentials = cast(dict, json.loads(encrypted_config))
except JSONDecodeError:
return {}
# Decrypt secret variables
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in secret_variables:
if variable in credentials:
with contextlib.suppress(ValueError):
credentials[variable] = encrypter.decrypt_token_with_decoding(
credentials.get(variable) or "",
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# Cache the decrypted credentials
credentials_cache.set(credentials=credentials)
return credentials
def _to_system_configuration(
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
@ -968,18 +1062,6 @@ class ProviderManager:
load_balancing_model_config.model_name == provider_model_setting.model_name
and load_balancing_model_config.model_type == provider_model_setting.model_type
):
if load_balancing_model_config.name == "__delete__":
# to calculate current model whether has invalidate lb configs
load_balancing_configs.append(
ModelLoadBalancingConfiguration(
id=load_balancing_model_config.id,
name=load_balancing_model_config.name,
credentials={},
credential_source_type=load_balancing_model_config.credential_source_type,
)
)
continue
if not load_balancing_model_config.enabled:
continue
@ -1045,6 +1127,7 @@ class ProviderManager:
model=provider_model_setting.model_name,
model_type=ModelType.value_of(provider_model_setting.model_type),
enabled=provider_model_setting.enabled,
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
)
)

View File

@ -3,6 +3,7 @@ from typing import Any, Optional
import orjson
from pydantic import BaseModel
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
@ -211,11 +212,10 @@ class Jieba(BaseKeyword):
return sorted_chunk_indices[:k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
)
document_segment = db.session.scalar(stmt)
if document_segment:
document_segment.keywords = keywords
db.session.add(document_segment)

View File

@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
@ -127,7 +128,8 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
return []
metadata_condition = (
@ -316,10 +318,8 @@ class RetrievalService:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")
child_chunk = (
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
)
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = db.session.scalar(child_chunk_stmt)
if not child_chunk:
continue
@ -378,17 +378,13 @@ class RetrievalService:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
segment = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
if not segment:
continue

View File

@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
vector=None, # ty: ignore [invalid-argument-type]
content=None, # ty: ignore [invalid-argument-type]
top_k=1,
filter=f"ref_doc_id='{id}'",
)
@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
@ -249,7 +249,7 @@ class AnalyticdbVectorOpenAPI:
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
content=None, # ty: ignore [invalid-argument-type]
top_k=kwargs.get("top_k", 4),
filter=where_clause,
)
@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
vector=None, # ty: ignore [invalid-argument-type]
content=query,
top_k=kwargs.get("top_k", 4),
filter=where_clause,

View File

@ -228,7 +228,7 @@ class AnalyticdbVectorBySql:
)
documents = []
for record in cur:
id, vector, score, page_content, metadata = record
_, vector, score, page_content, metadata = record
if score >= score_threshold:
metadata["score"] = score
doc = Document(
@ -260,7 +260,7 @@ class AnalyticdbVectorBySql:
)
documents = []
for record in cur:
id, vector, page_content, metadata, score = record
_, vector, page_content, metadata, score = record
metadata["score"] = score
doc = Document(
page_content=page_content,

View File

@ -12,7 +12,7 @@ import clickzetta # type: ignore
from pydantic import BaseModel, model_validator
if TYPE_CHECKING:
from clickzetta import Connection
from clickzetta.connector.v0.connection import Connection # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -701,7 +701,7 @@ class ClickzettaVector(BaseVector):
len(data_rows),
vector_dimension,
)
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
except (RuntimeError, ValueError, TypeError, ConnectionError):
logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
logger.exception("SQL template: %s", insert_sql)
logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")
@ -787,7 +787,7 @@ class ClickzettaVector(BaseVector):
document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {})
_ = kwargs.get("filter", {})
# Build filter clause
filter_clauses = []
@ -879,7 +879,7 @@ class ClickzettaVector(BaseVector):
document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {})
_ = kwargs.get("filter", {})
# Build filter clause
filter_clauses = []
@ -938,7 +938,7 @@ class ClickzettaVector(BaseVector):
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, TypeError) as e:
except (json.JSONDecodeError, TypeError):
logger.exception("JSON parsing failed")
# Fallback: extract document_id with regex
@ -956,7 +956,7 @@ class ClickzettaVector(BaseVector):
metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores
doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc)
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
except (RuntimeError, ValueError, TypeError, ConnectionError):
logger.exception("Full-text search failed")
# Fallback to LIKE search if full-text search fails
return self._search_by_like(query, **kwargs)
@ -978,7 +978,7 @@ class ClickzettaVector(BaseVector):
document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {})
_ = kwargs.get("filter", {})
# Build filter clause
filter_clauses = []

View File

@ -212,10 +212,10 @@ class CouchbaseVector(BaseVector):
documents_to_insert = [
{"text": text, "embedding": vector, "metadata": metadata}
for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
for _, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
]
for doc, id in zip(documents_to_insert, uuids):
result = self._scope.collection(self._collection_name).upsert(id, doc)
_ = self._scope.collection(self._collection_name).upsert(id, doc)
doc_ids.extend(uuids)
@ -241,7 +241,7 @@ class CouchbaseVector(BaseVector):
"""
try:
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
except Exception as e:
except Exception:
logger.exception("Failed to delete documents, ids: %s", ids)
def delete_by_document_id(self, document_id: str):
@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # ty: ignore [too-many-positional-arguments]
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)

View File

@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
if not client.ping():
raise ConnectionError("Failed to connect to Elasticsearch")
except requests.exceptions.ConnectionError as e:
except requests.ConnectionError as e:
raise ConnectionError(f"Vector database connection error: {str(e)}")
except Exception as e:
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")

View File

@ -99,7 +99,7 @@ class MatrixoneVector(BaseVector):
return client
try:
client.create_full_text_index()
except Exception as e:
except Exception:
logger.exception("Failed to create full text index")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
return client

View File

@ -376,7 +376,12 @@ class MilvusVector(BaseVector):
if config.token:
client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
else:
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
client = MilvusClient(
uri=config.uri,
user=config.user or "",
password=config.password or "",
db_name=config.database,
)
return client

View File

@ -197,7 +197,7 @@ class OpenSearchVector(BaseVector):
try:
response = self._client.search(index=self._collection_name.lower(), body=query)
except Exception as e:
except Exception:
logger.exception("Error executing vector search, query: %s", query)
raise

View File

@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union
import qdrant_client
from flask import current_app
@ -18,6 +18,7 @@ from qdrant_client.http.models import (
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -426,7 +427,6 @@ class QdrantVector(BaseVector):
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod
@ -446,11 +446,8 @@ class QdrantVector(BaseVector):
class QdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id)
dataset_collection_binding = db.session.scalars(stmt).one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:

View File

@ -71,7 +71,7 @@ class TableStoreVector(BaseVector):
table_result = result.get_result_by_table(self._table_name)
for item in table_result:
if item.is_ok and item.row:
kv = {k: v for k, v, t in item.row.attribute_columns}
kv = {k: v for k, v, _ in item.row.attribute_columns}
docs.append(
Document(
page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value])

View File

@ -39,6 +39,9 @@ class TencentConfig(BaseModel):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
bm25 = BM25Encoder.default("zh")
class TencentVector(BaseVector):
field_id: str = "id"
field_vector: str = "vector"
@ -53,7 +56,6 @@ class TencentVector(BaseVector):
self._dimension = 1024
self._init_database()
self._load_collection()
self._bm25 = BM25Encoder.default("zh")
def _load_collection(self):
"""
@ -186,7 +188,7 @@ class TencentVector(BaseVector):
metadata=metadata,
)
if self._enable_hybrid_search:
doc.__dict__["sparse_vector"] = self._bm25.encode_texts(texts[i])
doc.__dict__["sparse_vector"] = bm25.encode_texts(texts[i])
docs.append(doc)
self._client.upsert(
database_name=self._client_config.database,
@ -264,7 +266,7 @@ class TencentVector(BaseVector):
match=[
KeywordSearch(
field_name="sparse_vector",
data=self._bm25.encode_queries(query),
data=bm25.encode_queries(query),
),
],
rerank=WeightedRerank(

View File

@ -20,6 +20,7 @@ from qdrant_client.http.models import (
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -416,16 +417,12 @@ class TidbOnQdrantVector(BaseVector):
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.where(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

View File

@ -3,6 +3,8 @@ import time
from abc import ABC, abstractmethod
from typing import Any, Optional
from sqlalchemy import select
from configs import dify_config
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@ -45,11 +47,10 @@ class Vector:
vector_type = self._dataset.index_struct_dict["type"]
else:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
stmt = select(Whitelist).where(
Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db"
)
whitelist = db.session.scalars(stmt).one_or_none()
if whitelist:
vector_type = VectorType.TIDB_ON_QDRANT

View File

@ -32,9 +32,9 @@ class VikingDBConfig(BaseModel):
scheme: str
connection_timeout: int
socket_timeout: int
index_type: str = IndexType.HNSW
distance: str = DistanceType.L2
quant: str = QuantType.Float
index_type: str = str(IndexType.HNSW)
distance: str = str(DistanceType.L2)
quant: str = str(QuantType.Float)
class VikingDBVector(BaseVector):

View File

@ -37,22 +37,22 @@ class WeaviateVector(BaseVector):
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
auth_config = weaviate.AuthApiKey(api_key=config.api_key or "")
weaviate.connect.connection.has_grpc = False
weaviate.connect.connection.has_grpc = False # ty: ignore [unresolved-attribute]
# Fix to minimize the performance impact of the deprecation check in weaviate-client 3.24.0,
# by changing the connection timeout to pypi.org from 1 second to 0.001 seconds.
# TODO: This can be removed once weaviate-client is updated to 3.26.7 or higher,
# which does not contain the deprecation check.
if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"):
weaviate.connect.connection.PYPI_TIMEOUT = 0.001
if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"): # ty: ignore [unresolved-attribute]
weaviate.connect.connection.PYPI_TIMEOUT = 0.001 # ty: ignore [unresolved-attribute]
try:
client = weaviate.Client(
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
)
except requests.exceptions.ConnectionError:
except requests.ConnectionError:
raise ConnectionError("Vector database connection error")
client.batch.configure(

View File

@ -1,7 +1,7 @@
from collections.abc import Sequence
from typing import Any, Optional
from sqlalchemy import func
from sqlalchemy import func, select
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@ -41,9 +41,8 @@ class DatasetDocumentStore:
@property
def docs(self) -> dict[str, Document]:
document_segments = (
db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
)
stmt = select(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id)
document_segments = db.session.scalars(stmt).all()
output = {}
for document_segment in document_segments:
@ -228,10 +227,9 @@ class DatasetDocumentStore:
return data
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.first()
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id
)
document_segment = db.session.scalar(stmt)
return document_segment

Some files were not shown because too many files have changed in this diff Show More