Mirror of https://github.com/langgenius/dify.git
Synced 2026-02-16 08:15:18 +08:00

Compare commits: feat/suppo... → 0.12.0

33 Commits
| SHA1 |
|---|
| 625aaceb00 |
| 98d85e6b74 |
| 319d49084b |
| eb542067af |
| 04b9a2c605 |
| 8028e75fbb |
| 3eb51d85da |
| 79a35c2fe6 |
| 2dd4c34423 |
| 684f6b2299 |
| b791a80b75 |
| 13006f94e2 |
| 41772c325f |
| a4fc057a1c |
| aae29e72ae |
| 87c831e5dd |
| 40a5f1c80a |
| 04f1e18342 |
| 365a40d11f |
| 60b5dac3ab |
| 8565c18e84 |
| 03ba4bc760 |
| ae3a2cb272 |
| 6c8e208ef3 |
| 0181f1c08c |
| 7f00c5a02e |
| d0648e27e2 |
| 31348af2e3 |
| 096c0ad564 |
| 16c41585e1 |
| 566ab9261d |
| 1cdadfdece |
| 448a19bf54 |
@@ -1,5 +1,5 @@
-FROM mcr.microsoft.com/devcontainers/python:3.10
+FROM mcr.microsoft.com/devcontainers/python:3.12
 
 # [Optional] Uncomment this section to install additional OS packages.
 # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
 #     && apt-get -y install --no-install-recommends <your-package-list-here>
@@ -1,7 +1,7 @@
 // For format details, see https://aka.ms/devcontainer.json. For config options, see the
 // README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
 {
-    "name": "Python 3.10",
+    "name": "Python 3.12",
     "build": {
         "context": "..",
         "dockerfile": "Dockerfile"
.github/actions/setup-poetry/action.yml (vendored, 2 changes)
@@ -4,7 +4,7 @@ inputs:
   python-version:
     description: Python version to use and the Poetry installed with
     required: true
-    default: '3.10'
+    default: '3.11'
   poetry-version:
     description: Poetry version to set up
     required: true
.github/workflows/api-tests.yml (vendored, 1 change)
@@ -20,7 +20,6 @@ jobs:
     strategy:
       matrix:
         python-version:
-          - "3.10"
           - "3.11"
           - "3.12"
.github/workflows/vdb-tests.yml (vendored, 3 changes)
@@ -8,6 +8,8 @@ on:
       - api/core/rag/datasource/**
      - docker/**
      - .github/workflows/vdb-tests.yml
+      - api/poetry.lock
+      - api/pyproject.toml
 
 concurrency:
   group: vdb-tests-${{ github.head_ref || github.run_id }}
@@ -20,7 +22,6 @@ jobs:
     strategy:
       matrix:
         python-version:
-          - "3.10"
           - "3.11"
           - "3.12"
@@ -71,7 +71,7 @@ Dify 依赖以下工具和库:
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+- [Python](https://www.python.org/) version 3.11.x or 3.12.x
 
 ### 4. 安装
@@ -74,7 +74,7 @@ Dify を構築するには次の依存関係が必要です。それらがシス
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+- [Python](https://www.python.org/) version 3.11.x or 3.12.x
 
 ### 4. インストール
@@ -73,7 +73,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) phiên bản 3.10.x
+- [Python](https://www.python.org/) phiên bản 3.11.x hoặc 3.12.x
 
 ### 4. Cài đặt
@@ -153,4 +153,4 @@ Và thế là xong! Khi PR của bạn được merge, bạn sẽ được giớ
 
 ## Nhận trợ giúp
 
-Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng.
+Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng.
@@ -1,6 +1,11 @@
 import os
 import sys
 
+python_version = sys.version_info
+if not ((3, 11) <= python_version < (3, 13)):
+    print(f"Python 3.11 or 3.12 is required, current version is {python_version.major}.{python_version.minor}")
+    raise SystemExit(1)
+
 from configs import dify_config
 
 if not dify_config.DEBUG:
@@ -30,9 +35,6 @@ from models import account, dataset, model, source, task, tool, tools, web  # no
 
 # DO NOT REMOVE ABOVE
 
-if sys.version_info[:2] == (3, 10):
-    print("Warning: Python 3.10 will not be supported in the next version.")
-
 warnings.simplefilter("ignore", ResourceWarning)
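The new guard works because `sys.version_info` is a named tuple that compares element-wise against plain tuples, so a single chained comparison pins the interpreter to the half-open window [3.11, 3.13). A standalone sketch of the same pattern:

```python
import sys

# sys.version_info behaves like a tuple (major, minor, micro, ...), and
# tuple comparison is element-wise, so one chained expression covers the
# whole supported window: 3.11.0 <= version < 3.13.0.
if not ((3, 11) <= sys.version_info < (3, 13)):
    raise SystemExit(
        f"Python 3.11 or 3.12 is required, "
        f"found {sys.version_info.major}.{sys.version_info.minor}"
    )
print("interpreter version OK")
```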
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.11.2",
+        default="0.12.0",
     )
 
     COMMIT_SHA: str = Field(
@@ -62,7 +62,6 @@ from .datasets import (
     external,
     hit_testing,
     website,
-    fta_test,
 )
 
 # Import explore controllers
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 import pytz
 from flask_login import current_user
@@ -314,7 +314,7 @@ def _get_conversation(app_model, conversation_id):
         raise NotFound("Conversation Not Exists.")
 
     if not conversation.read_at:
-        conversation.read_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
         conversation.read_account_id = current_user.id
         db.session.commit()
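This hunk sets the pattern for most of the diffs below: `timezone.utc` becomes the `datetime.UTC` alias that was added in Python 3.11, which is why the swap rides along with the drop of 3.10 support. The two spellings are the same object, so behavior is unchanged; a quick sketch:

```python
from datetime import UTC, datetime, timezone

# datetime.UTC (new in Python 3.11) is literally timezone.utc under another name.
assert UTC is timezone.utc

# The codebase stores naive UTC timestamps, hence the .replace(tzinfo=None)
# that follows every datetime.now(UTC) call in these hunks.
naive_utc = datetime.now(UTC).replace(tzinfo=None)
print(naive_utc.isoformat())
```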
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
@@ -75,7 +75,7 @@ class AppSite(Resource):
             setattr(site, attr_name, value)
 
         site.updated_by = current_user.id
-        site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        site.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return site
@@ -99,7 +99,7 @@ class AppSiteAccessTokenReset(Resource):
 
         site.code = Site.generate_code(16)
         site.updated_by = current_user.id
-        site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        site.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return site
@@ -65,7 +65,7 @@ class ActivateApi(Resource):
         account.timezone = args["timezone"]
         account.interface_theme = "light"
         account.status = AccountStatus.ACTIVE.value
-        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.commit()
 
         token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
@@ -1,5 +1,5 @@
 import logging
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Optional
 
 import requests
@@ -106,7 +106,7 @@ class OAuthCallback(Resource):
 
         if account.status == AccountStatus.PENDING.value:
             account.status = AccountStatus.ACTIVE.value
-            account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
         try:
@@ -83,7 +83,7 @@ class DataSourceApi(Resource):
         if action == "enable":
             if data_source_binding.disabled:
                 data_source_binding.disabled = False
-                data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                 db.session.add(data_source_binding)
                 db.session.commit()
             else:
@@ -92,7 +92,7 @@ class DataSourceApi(Resource):
         if action == "disable":
             if not data_source_binding.disabled:
                 data_source_binding.disabled = True
-                data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                 db.session.add(data_source_binding)
                 db.session.commit()
             else:
@@ -1,6 +1,6 @@
 import logging
 from argparse import ArgumentTypeError
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask import request
 from flask_login import current_user
@@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
             raise InvalidActionError("Document not in indexing state.")
 
         document.paused_by = current_user.id
-        document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        document.paused_at = datetime.now(UTC).replace(tzinfo=None)
         document.is_paused = True
         db.session.commit()
 
@@ -745,7 +745,7 @@ class DocumentMetadataApi(DocumentResource):
             document.doc_metadata[key] = value
 
         document.doc_type = doc_type
-        document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        document.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return {"result": "success", "message": "Document metadata updated."}, 200
@@ -787,7 +787,7 @@ class DocumentStatusApi(DocumentResource):
             document.enabled = True
             document.disabled_at = None
             document.disabled_by = None
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             # Set cache to prevent indexing the same document multiple times
@@ -804,9 +804,9 @@ class DocumentStatusApi(DocumentResource):
                 raise InvalidActionError("Document already disabled.")
 
             document.enabled = False
-            document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
             document.disabled_by = current_user.id
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             # Set cache to prevent indexing the same document multiple times
@@ -821,9 +821,9 @@ class DocumentStatusApi(DocumentResource):
                 raise InvalidActionError("Document already archived.")
 
             document.archived = True
-            document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.archived_at = datetime.now(UTC).replace(tzinfo=None)
             document.archived_by = current_user.id
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             if document.enabled:
@@ -840,7 +840,7 @@ class DocumentStatusApi(DocumentResource):
             document.archived = False
             document.archived_at = None
             document.archived_by = None
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             # Set cache to prevent indexing the same document multiple times
@@ -1,5 +1,5 @@
 import uuid
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 import pandas as pd
 from flask import request
@@ -188,7 +188,7 @@ class DatasetDocumentSegmentApi(Resource):
             raise InvalidActionError("Segment is already disabled.")
 
         segment.enabled = False
-        segment.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
         segment.disabled_by = current_user.id
         db.session.commit()
@@ -1,145 +0,0 @@
-import json
-
-import requests
-from flask import Response
-from flask_restful import Resource, reqparse
-from sqlalchemy import text
-
-from controllers.console import api
-from extensions.ext_database import db
-from extensions.ext_storage import storage
-from models.fta import ComponentFailure, ComponentFailureStats
-
-
-class FATTestApi(Resource):
-    def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("log_process_data", nullable=False, required=True, type=str, location="args")
-        args = parser.parse_args()
-        print(args["log_process_data"])
-        # Extract the JSON string from the text field
-        json_str = args["log_process_data"].strip("```json\\n").strip("```").strip().replace("\\n", "")
-        log_data = json.loads(json_str)
-        db.session.query(ComponentFailure).delete()
-        for data in log_data:
-            if not isinstance(data, dict):
-                raise TypeError("Data must be a dictionary.")
-
-            required_keys = {"Date", "Component", "FailureMode", "Cause", "RepairAction", "Technician"}
-            if not required_keys.issubset(data.keys()):
-                raise ValueError(f"Data dictionary must contain the following keys: {required_keys}")
-
-            try:
-                # Clear existing stats
-                component_failure = ComponentFailure(
-                    Date=data["Date"],
-                    Component=data["Component"],
-                    FailureMode=data["FailureMode"],
-                    Cause=data["Cause"],
-                    RepairAction=data["RepairAction"],
-                    Technician=data["Technician"],
-                )
-                db.session.add(component_failure)
-                db.session.commit()
-            except Exception as e:
-                print(e)
-        # Clear existing stats
-        db.session.query(ComponentFailureStats).delete()
-
-        # Insert calculated statistics
-        try:
-            db.session.execute(
-                text("""
-                INSERT INTO component_failure_stats ("Component", "FailureMode", "Cause", "PossibleAction", "Probability", "MTBF")
-                SELECT
-                    cf."Component",
-                    cf."FailureMode",
-                    cf."Cause",
-                    cf."RepairAction" as "PossibleAction",
-                    COUNT(*) * 1.0 / (SELECT COUNT(*) FROM component_failure WHERE "Component" = cf."Component") AS "Probability",
-                    COALESCE(AVG(EXTRACT(EPOCH FROM (next_failure_date::timestamp - cf."Date"::timestamp)) / 86400.0), 0) AS "MTBF"
-                FROM (
-                    SELECT
-                        "Component",
-                        "FailureMode",
-                        "Cause",
-                        "RepairAction",
-                        "Date",
-                        LEAD("Date") OVER (PARTITION BY "Component", "FailureMode", "Cause" ORDER BY "Date") AS next_failure_date
-                    FROM
-                        component_failure
-                ) cf
-                GROUP BY
-                    cf."Component", cf."FailureMode", cf."Cause", cf."RepairAction";
-                """)
-            )
-            db.session.commit()
-        except Exception as e:
-            db.session.rollback()
-            print(f"Error during stats calculation: {e}")
-        # output format
-        # [
-        #     (17, 'Hydraulic system', 'Leak', 'Hose rupture', 'Replaced hydraulic hose', 0.3333333333333333, None),
-        #     (18, 'Hydraulic system', 'Leak', 'Seal Wear', 'Replaced the faulty seal', 0.3333333333333333, None),
-        #     (19, 'Hydraulic system', 'Pressure drop', 'Fluid leak', 'Replaced hydraulic fluid and seals', 0.3333333333333333, None)
-        # ]
-
-        component_failure_stats = db.session.query(ComponentFailureStats).all()
-        # Convert stats to list of tuples format
-        stats_list = []
-        for stat in component_failure_stats:
-            stats_list.append(
-                (
-                    stat.StatID,
-                    stat.Component,
-                    stat.FailureMode,
-                    stat.Cause,
-                    stat.PossibleAction,
-                    stat.Probability,
-                    stat.MTBF,
-                )
-            )
-        return {"data": stats_list}, 200
-
-
-# generate-fault-tree
-class GenerateFaultTreeApi(Resource):
-    def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("llm_text", nullable=False, required=True, type=str, location="args")
-        args = parser.parse_args()
-        entities = args["llm_text"].replace("```", "").replace("\\n", "\n")
-        print(entities)
-        request_data = {"fault_tree_text": entities}
-        url = "https://fta.cognitech-dev.live/generate-fault-tree"
-        headers = {"accept": "application/json", "Content-Type": "application/json"}
-
-        response = requests.post(url, json=request_data, headers=headers)
-        print(response.json())
-        return {"data": response.json()}, 200
-
-
-class ExtractSVGApi(Resource):
-    def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("svg_text", nullable=False, required=True, type=str, location="args")
-        args = parser.parse_args()
-        # svg_text = ''.join(args["svg_text"].splitlines())
-        svg_text = args["svg_text"].replace("\n", "")
-        svg_text = svg_text.replace('"', '"')
-        print(svg_text)
-        svg_text_json = json.loads(svg_text)
-        svg_content = svg_text_json.get("data").get("svg_content")[0]
-        svg_content = svg_content.replace("\n", "").replace('"', '"')
-        file_key = "fta_svg/" + "fat.svg"
-        if storage.exists(file_key):
-            storage.delete(file_key)
-        storage.save(file_key, svg_content.encode("utf-8"))
-        generator = storage.load(file_key, stream=True)
-
-        return Response(generator, mimetype="image/svg+xml")
-
-
-api.add_resource(FATTestApi, "/fta/db-handler")
-api.add_resource(GenerateFaultTreeApi, "/fta/generate-fault-tree")
-api.add_resource(ExtractSVGApi, "/fta/extract-svg")
@@ -1,5 +1,5 @@
 import logging
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import reqparse
@@ -46,7 +46,7 @@ class CompletionApi(InstalledAppResource):
         streaming = args["response_mode"] == "streaming"
         args["auto_generate_name"] = False
 
-        installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         try:
@@ -106,7 +106,7 @@ class ChatApi(InstalledAppResource):
 
         args["auto_generate_name"] = False
 
-        installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         try:
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import Resource, inputs, marshal_with, reqparse
@@ -81,7 +81,7 @@ class InstalledAppsListApi(Resource):
                 tenant_id=current_tenant_id,
                 app_owner_tenant_id=app.tenant_id,
                 is_pinned=False,
-                last_used_at=datetime.now(timezone.utc).replace(tzinfo=None),
+                last_used_at=datetime.now(UTC).replace(tzinfo=None),
             )
             db.session.add(new_installed_app)
             db.session.commit()
@@ -60,7 +60,7 @@ class AccountInitApi(Resource):
                 raise InvalidInvitationCodeError()
 
             invitation_code.status = "used"
-            invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             invitation_code.used_by_tenant_id = account.current_tenant_id
             invitation_code.used_by_account_id = account.id
 
@@ -68,7 +68,7 @@ class AccountInitApi(Resource):
         account.timezone = args["timezone"]
         account.interface_theme = "light"
         account.status = "active"
-        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.commit()
 
         return {"result": "success"}
@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from enum import Enum
 from functools import wraps
 from typing import Optional
@@ -198,7 +198,7 @@ def validate_and_get_api_token(scope=None):
     if not api_token:
         raise Unauthorized("Access token is invalid")
 
-    api_token.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
+    api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
     db.session.commit()
 
     return api_token
@@ -2,7 +2,7 @@ import json
 import logging
 import uuid
 from collections.abc import Mapping, Sequence
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Optional, Union, cast
 
 from core.agent.entities import AgentEntity, AgentToolEntity
@@ -412,7 +412,7 @@ class BaseAgentRunner(AppRunner):
             .first()
         )
 
-        db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
         db.session.commit()
         db.session.close()
@@ -1,3 +1,4 @@
+import uuid
 from typing import Optional
 
 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
@@ -1,4 +1,5 @@
 from core.app.app_config.entities import (
+    AdvancedChatMessageEntity,
     AdvancedChatPromptTemplateEntity,
     AdvancedCompletionPromptTemplateEntity,
     PromptTemplateEntity,
@@ -25,7 +26,9 @@ class PromptTemplateConfigManager:
             chat_prompt_messages = []
             for message in chat_prompt_config.get("prompt", []):
                 chat_prompt_messages.append(
-                    {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
+                    AdvancedChatMessageEntity(
+                        **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
+                    )
                 )
 
             advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Optional
 
 from pydantic import BaseModel, Field, field_validator
@@ -88,7 +88,7 @@ class PromptTemplateEntity(BaseModel):
     advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
 
 
-class VariableEntityType(str, Enum):
+class VariableEntityType(StrEnum):
     TEXT_INPUT = "text-input"
     SELECT = "select"
     PARAGRAPH = "paragraph"
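The `str, Enum` to `StrEnum` rewrites repeated across the hunks below are also enabled by the new 3.11 baseline: `enum.StrEnum` appeared in Python 3.11, and its members are real strings, so comparisons and JSON serialization are unaffected. The one observable difference is `str()`, shown in this sketch:

```python
from enum import Enum, StrEnum

class OldStyle(str, Enum):   # pre-3.11 idiom
    IMAGE = "image"

class NewStyle(StrEnum):     # 3.11+ replacement
    IMAGE = "image"

# Both are genuine str instances, so equality and serialization match.
assert OldStyle.IMAGE == "image" and NewStyle.IMAGE == "image"

# str() differs: StrEnum yields the value, the mixin yields the member name.
print(str(OldStyle.IMAGE))  # "OldStyle.IMAGE"
print(str(NewStyle.IMAGE))  # "image"
```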
@@ -127,7 +127,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             conversation_id=conversation.id if conversation else None,
             inputs=conversation.inputs
             if conversation
-            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
@@ -134,7 +134,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             conversation_id=conversation.id if conversation else None,
             inputs=conversation.inputs
             if conversation
-            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional
 
 from core.app.app_config.entities import VariableEntityType
@@ -6,7 +6,7 @@ from core.file import File, FileUploadConfig
 from factories import file_factory
 
 if TYPE_CHECKING:
-    from core.app.app_config.entities import AppConfig, VariableEntity
+    from core.app.app_config.entities import VariableEntity
 
 
 class BaseAppGenerator:
@@ -14,23 +14,23 @@ class BaseAppGenerator:
         self,
         *,
         user_inputs: Optional[Mapping[str, Any]],
-        app_config: "AppConfig",
+        variables: Sequence["VariableEntity"],
+        tenant_id: str,
     ) -> Mapping[str, Any]:
         user_inputs = user_inputs or {}
         # Filter input variables from form configuration, handle required fields, default values, and option values
-        variables = app_config.variables
         user_inputs = {
             var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
             for var in variables
         }
         user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
         # Convert files in inputs to File
-        entity_dictionary = {item.variable: item for item in app_config.variables}
+        entity_dictionary = {item.variable: item for item in variables}
         # Convert single file to File
         files_inputs = {
             k: file_factory.build_from_mapping(
                 mapping=v,
-                tenant_id=app_config.tenant_id,
+                tenant_id=tenant_id,
                 config=FileUploadConfig(
                     allowed_file_types=entity_dictionary[k].allowed_file_types,
                     allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
@@ -44,7 +44,7 @@ class BaseAppGenerator:
         file_list_inputs = {
             k: file_factory.build_from_mappings(
                 mappings=v,
-                tenant_id=app_config.tenant_id,
+                tenant_id=tenant_id,
                 config=FileUploadConfig(
                     allowed_file_types=entity_dictionary[k].allowed_file_types,
                     allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
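This hunk narrows `_prepare_user_inputs` from the whole `AppConfig` to the two things it actually read, `variables` and `tenant_id`; the generator call sites in the surrounding hunks pass them explicitly, which frees workflow apps to source the tenant from `app_model` instead. A minimal sketch of the shape of the refactor (simplified names, not the actual Dify classes):

```python
from collections.abc import Mapping, Sequence
from typing import Any

# Before: the helper took a whole config object and reached into it.
def prepare_before(user_inputs: Mapping[str, Any], app_config: Any) -> dict:
    return {v.variable: user_inputs.get(v.variable) for v in app_config.variables}

# After: the helper declares exactly what it depends on.
# (In the real code, tenant_id is also threaded into the file_factory calls.)
def prepare_after(
    user_inputs: Mapping[str, Any],
    variables: Sequence[Any],
    tenant_id: str,
) -> dict:
    return {v.variable: user_inputs.get(v.variable) for v in variables}
```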
@@ -132,7 +132,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             conversation_id=conversation.id if conversation else None,
             inputs=conversation.inputs
             if conversation
-            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
@@ -113,7 +113,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             app_config=app_config,
             model_conf=ModelConfigConverter.convert(app_config),
             file_upload_config=file_extra_config,
-            inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            inputs=self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id
+            ),
             query=query,
             files=file_objs,
             user_id=user.id,
@@ -1,7 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Optional, Union
 
 from sqlalchemy import and_
@@ -200,7 +200,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             db.session.commit()
             db.session.refresh(conversation)
         else:
-            conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
         message = Message(
@@ -96,7 +96,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
             task_id=str(uuid.uuid4()),
             app_config=app_config,
             file_upload_config=file_extra_config,
-            inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            inputs=self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+            ),
             files=system_files,
             user_id=user.id,
             stream=stream,
@@ -43,7 +43,6 @@ from core.workflow.graph_engine.entities.event import (
 )
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.nodes import NodeType
-from core.workflow.nodes.iteration import IterationNodeData
 from core.workflow.nodes.node_mapping import node_type_classes_mapping
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
@@ -160,8 +159,6 @@ class WorkflowBasedAppRunner(AppRunner):
             user_inputs=user_inputs,
             variable_pool=variable_pool,
             tenant_id=workflow.tenant_id,
-            node_type=node_type,
-            node_data=IterationNodeData(**iteration_node_config.get("data", {})),
         )
 
         return graph, variable_pool
@@ -1,5 +1,5 @@
 from datetime import datetime
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Optional
 
 from pydantic import BaseModel, field_validator
@@ -11,7 +11,7 @@ from core.workflow.nodes import NodeType
 from core.workflow.nodes.base import BaseNodeData
 
 
-class QueueEvent(str, Enum):
+class QueueEvent(StrEnum):
     """
     QueueEvent enum
     """
@@ -1,8 +1,9 @@
 import json
 import time
 from collections.abc import Mapping, Sequence
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Any, Optional, Union, cast
+from uuid import uuid4
 
 from sqlalchemy.orm import Session
 
@@ -80,38 +81,38 @@ class WorkflowCycleManage:
 
             inputs[f"sys.{key.value}"] = value
 
-        inputs = WorkflowEntry.handle_special_values(inputs)
-
         triggered_from = (
             WorkflowRunTriggeredFrom.DEBUGGING
             if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
             else WorkflowRunTriggeredFrom.APP_RUN
         )
 
-        # init workflow run
-        workflow_run = WorkflowRun()
-        workflow_run_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
-        if workflow_run_id:
-            workflow_run.id = workflow_run_id
-        workflow_run.tenant_id = self._workflow.tenant_id
-        workflow_run.app_id = self._workflow.app_id
-        workflow_run.sequence_number = new_sequence_number
-        workflow_run.workflow_id = self._workflow.id
-        workflow_run.type = self._workflow.type
-        workflow_run.triggered_from = triggered_from.value
-        workflow_run.version = self._workflow.version
-        workflow_run.graph = self._workflow.graph
-        workflow_run.inputs = json.dumps(inputs)
-        workflow_run.status = WorkflowRunStatus.RUNNING.value
-        workflow_run.created_by_role = (
-            CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
-        )
-        workflow_run.created_by = self._user.id
+        # handle special values
+        inputs = WorkflowEntry.handle_special_values(inputs)
 
-        db.session.add(workflow_run)
-        db.session.commit()
-        db.session.refresh(workflow_run)
-        db.session.close()
+        # init workflow run
+        with Session(db.engine, expire_on_commit=False) as session:
+            workflow_run = WorkflowRun()
+            system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
+            workflow_run.id = system_id or str(uuid4())
+            workflow_run.tenant_id = self._workflow.tenant_id
+            workflow_run.app_id = self._workflow.app_id
+            workflow_run.sequence_number = new_sequence_number
+            workflow_run.workflow_id = self._workflow.id
+            workflow_run.type = self._workflow.type
+            workflow_run.triggered_from = triggered_from.value
+            workflow_run.version = self._workflow.version
+            workflow_run.graph = self._workflow.graph
+            workflow_run.inputs = json.dumps(inputs)
+            workflow_run.status = WorkflowRunStatus.RUNNING
+            workflow_run.created_by_role = (
+                CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER
+            )
+            workflow_run.created_by = self._user.id
+            workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
+
+            session.add(workflow_run)
+            session.commit()
 
         return workflow_run
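The run-creation path moves from mutating the global `db.session` (with an explicit `refresh` and `close`) to a scoped `Session(db.engine, expire_on_commit=False)`. With `expire_on_commit=False`, the committed object's attributes stay loaded after the session exits, so the returned `WorkflowRun` can be read without a refresh. A self-contained sketch with a placeholder model:

```python
from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Run(Base):  # placeholder standing in for WorkflowRun
    __tablename__ = "runs"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# expire_on_commit=False keeps attribute values populated after commit,
# so the object can still be read once the session is closed.
with Session(engine, expire_on_commit=False) as session:
    run = Run(status="running")
    session.add(run)
    session.commit()

print(run.status)  # fine; with the default expire_on_commit=True the
                   # expired attribute could not be loaded while detached
```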
@@ -144,7 +145,7 @@ class WorkflowCycleManage:
         workflow_run.elapsed_time = time.perf_counter() - start_at
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
-        workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
 
         db.session.commit()
         db.session.refresh(workflow_run)
@@ -191,7 +192,7 @@ class WorkflowCycleManage:
         workflow_run.elapsed_time = time.perf_counter() - start_at
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
-        workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
 
         db.session.commit()
 
@@ -211,7 +212,7 @@ class WorkflowCycleManage:
         for workflow_node_execution in running_workflow_node_executions:
             workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
             workflow_node_execution.error = error
-            workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
             workflow_node_execution.elapsed_time = (
                 workflow_node_execution.finished_at - workflow_node_execution.created_at
             ).total_seconds()
@@ -262,7 +263,7 @@ class WorkflowCycleManage:
                 NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
             }
         )
-        workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
 
         session.add(workflow_node_execution)
         session.commit()
@@ -285,7 +286,7 @@ class WorkflowCycleManage:
         execution_metadata = (
             json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
         )
-        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        finished_at = datetime.now(UTC).replace(tzinfo=None)
         elapsed_time = (finished_at - event.start_at).total_seconds()
 
         db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
@@ -329,7 +330,7 @@ class WorkflowCycleManage:
         inputs = WorkflowEntry.handle_special_values(event.inputs)
         process_data = WorkflowEntry.handle_special_values(event.process_data)
         outputs = WorkflowEntry.handle_special_values(event.outputs)
-        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        finished_at = datetime.now(UTC).replace(tzinfo=None)
         elapsed_time = (finished_at - event.start_at).total_seconds()
         execution_metadata = (
             json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
@@ -657,7 +658,7 @@ class WorkflowCycleManage:
             if event.error is None
             else WorkflowNodeExecutionStatus.FAILED,
             error=None,
-            elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
+            elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
             total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
             execution_metadata=event.metadata,
             finished_at=int(time.time()),
@@ -240,7 +240,7 @@ class ProviderConfiguration(BaseModel):
         if provider_record:
             provider_record.encrypted_config = json.dumps(credentials)
             provider_record.is_valid = True
-            provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             provider_record = Provider(
@@ -394,7 +394,7 @@ class ProviderConfiguration(BaseModel):
         if provider_model_record:
             provider_model_record.encrypted_config = json.dumps(credentials)
             provider_model_record.is_valid = True
-            provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             provider_model_record = ProviderModel(
@@ -468,7 +468,7 @@ class ProviderConfiguration(BaseModel):
 
         if model_setting:
             model_setting.enabled = True
-            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             model_setting = ProviderModelSetting(
@@ -503,7 +503,7 @@ class ProviderConfiguration(BaseModel):
 
         if model_setting:
             model_setting.enabled = False
-            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             model_setting = ProviderModelSetting(
@@ -570,7 +570,7 @@ class ProviderConfiguration(BaseModel):
 
         if model_setting:
             model_setting.load_balancing_enabled = True
-            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             model_setting = ProviderModelSetting(
@@ -605,7 +605,7 @@ class ProviderConfiguration(BaseModel):
 
         if model_setting:
             model_setting.load_balancing_enabled = False
-            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             model_setting = ProviderModelSetting(
@@ -1,7 +1,7 @@
-from enum import Enum
+from enum import StrEnum
 
 
-class FileType(str, Enum):
+class FileType(StrEnum):
     IMAGE = "image"
     DOCUMENT = "document"
     AUDIO = "audio"
@@ -16,7 +16,7 @@ class FileType(str, Enum):
         raise ValueError(f"No matching enum found for value '{value}'")
 
 
-class FileTransferMethod(str, Enum):
+class FileTransferMethod(StrEnum):
     REMOTE_URL = "remote_url"
     LOCAL_FILE = "local_file"
     TOOL_FILE = "tool_file"
@@ -29,7 +29,7 @@ class FileTransferMethod(str, Enum):
         raise ValueError(f"No matching enum found for value '{value}'")
 
 
-class FileBelongsTo(str, Enum):
+class FileBelongsTo(StrEnum):
     USER = "user"
     ASSISTANT = "assistant"
 
@@ -41,7 +41,7 @@ class FileBelongsTo(str, Enum):
         raise ValueError(f"No matching enum found for value '{value}'")
 
 
-class FileAttribute(str, Enum):
+class FileAttribute(StrEnum):
     TYPE = "type"
     SIZE = "size"
     NAME = "name"
@@ -51,5 +51,5 @@ class FileAttribute(str, Enum):
     EXTENSION = "extension"
 
 
-class ArrayFileAttribute(str, Enum):
+class ArrayFileAttribute(StrEnum):
     LENGTH = "length"
@@ -1,6 +1,4 @@
 import base64
-import tempfile
-from pathlib import Path
 
 from configs import dify_config
 from core.file import file_repository
@@ -20,38 +18,6 @@ from .models import File, FileTransferMethod, FileType
 from .tool_file_parser import ToolFileParser
 
 
-def download_to_target_path(f: File, temp_dir: str, /):
-    if f.transfer_method == FileTransferMethod.TOOL_FILE:
-        tool_file = file_repository.get_tool_file(session=db.session(), file=f)
-        suffix = Path(tool_file.file_key).suffix
-        target_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-        _download_file_to_target_path(tool_file.file_key, target_path)
-        return target_path
-    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
-        upload_file = file_repository.get_upload_file(session=db.session(), file=f)
-        suffix = Path(upload_file.key).suffix
-        target_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-        _download_file_to_target_path(upload_file.key, target_path)
-        return target_path
-    else:
-        raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
-
-
-def _download_file_to_target_path(path: str, target_path: str, /):
-    """
-    Download and return the contents of a file as bytes.
-
-    This function loads the file from storage and ensures it's in bytes format.
-
-    Args:
-        path (str): The path to the file in storage.
-        target_path (str): The path to the target file.
-    Raises:
-        ValueError: If the loaded file is not a bytes object.
-    """
-    storage.download(path, target_path)
-
-
 def get_attr(*, file: File, attr: FileAttribute):
     match attr:
         case FileAttribute.TYPE:
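The deleted `download_to_target_path` helper minted temp file names with `tempfile._get_candidate_names()`, a private CPython detail with no stability guarantee. If equivalent behavior were ever needed again, the public API covers it; a sketch using `tempfile.mkstemp` that mirrors the removed suffix handling (the function name here is hypothetical):

```python
import os
import tempfile

def make_target_path(temp_dir: str, original_key: str) -> str:
    """Create a uniquely named file in temp_dir, preserving the extension."""
    suffix = os.path.splitext(original_key)[1]
    fd, target_path = tempfile.mkstemp(suffix=suffix, dir=temp_dir)
    os.close(fd)  # callers only need the path; the download writes to it
    return target_path

with tempfile.TemporaryDirectory() as d:
    print(make_target_path(d, "report.pdf"))  # e.g. <d>/tmpXXXXXXXX.pdf
```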
@@ -1,6 +1,6 @@
 import logging
 from collections.abc import Mapping
-from enum import Enum
+from enum import StrEnum
 from threading import Lock
 from typing import Any, Optional
 
@@ -31,7 +31,7 @@ class CodeExecutionResponse(BaseModel):
     data: Data
 
 
-class CodeLanguage(str, Enum):
+class CodeLanguage(StrEnum):
     PYTHON3 = "python3"
     JINJA2 = "jinja2"
     JAVASCRIPT = "javascript"
@@ -86,7 +86,7 @@ class IndexingRunner:
         except ProviderTokenNotInitError as e:
             dataset_document.indexing_status = "error"
             dataset_document.error = str(e.description)
-            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         except ObjectDeletedError:
             logging.warning("Document deleted, document id: {}".format(dataset_document.id))
@@ -94,7 +94,7 @@ class IndexingRunner:
             logging.exception("consume document failed")
             dataset_document.indexing_status = "error"
             dataset_document.error = str(e)
-            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
 
     def run_in_splitting_status(self, dataset_document: DatasetDocument):
@@ -142,13 +142,13 @@ class IndexingRunner:
         except ProviderTokenNotInitError as e:
             dataset_document.indexing_status = "error"
             dataset_document.error = str(e.description)
-            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         except Exception as e:
             logging.exception("consume document failed")
             dataset_document.indexing_status = "error"
             dataset_document.error = str(e)
-            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
 
     def run_in_indexing_status(self, dataset_document: DatasetDocument):
@@ -200,13 +200,13 @@ class IndexingRunner:
         except ProviderTokenNotInitError as e:
             dataset_document.indexing_status = "error"
             dataset_document.error = str(e.description)
-            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         except Exception as e:
             logging.exception("consume document failed")
             dataset_document.indexing_status = "error"
             dataset_document.error = str(e)
-            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
 
     def indexing_estimate(
@@ -372,7 +372,7 @@ class IndexingRunner:
             after_indexing_status="splitting",
             extra_update_params={
                 DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
-                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
             },
         )
 
@@ -464,7 +464,7 @@ class IndexingRunner:
         doc_store.add_documents(documents)
 
         # update document status to indexing
-        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         self._update_document_index_status(
             document_id=dataset_document.id,
             after_indexing_status="indexing",
@@ -479,7 +479,7 @@ class IndexingRunner:
             dataset_document_id=dataset_document.id,
             update_params={
                 DocumentSegment.status: "indexing",
-                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            },
        )
 
@@ -680,7 +680,7 @@ class IndexingRunner:
             after_indexing_status="completed",
             extra_update_params={
                 DatasetDocument.tokens: tokens,
-                DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                 DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
                 DatasetDocument.error: None,
             },
@@ -705,7 +705,7 @@ class IndexingRunner:
             {
                 DocumentSegment.status: "completed",
                 DocumentSegment.enabled: True,
-                DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
             }
         )
 
@@ -738,7 +738,7 @@ class IndexingRunner:
             {
                 DocumentSegment.status: "completed",
                 DocumentSegment.enabled: True,
-                DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
             }
        )
 
@@ -849,7 +849,7 @@ class IndexingRunner:
         doc_store.add_documents(documents)
 
         # update document status to indexing
-        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         self._update_document_index_status(
             document_id=dataset_document.id,
             after_indexing_status="indexing",
@@ -864,7 +864,7 @@ class IndexingRunner:
             dataset_document_id=dataset_document.id,
             update_params={
                 DocumentSegment.status: "indexing",
-                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
             },
         )
         pass
@@ -1,6 +1,6 @@
 from abc import ABC
 from collections.abc import Sequence
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Literal, Optional
 
 from pydantic import BaseModel, Field, field_validator
@@ -49,7 +49,7 @@ class PromptMessageFunction(BaseModel):
     function: PromptMessageTool
 
 
-class PromptMessageContentType(str, Enum):
+class PromptMessageContentType(StrEnum):
     """
     Enum class for prompt message content type.
     """
@@ -95,7 +95,7 @@ class ImagePromptMessageContent(PromptMessageContent):
     Model class for image prompt message content.
     """
 
-    class DETAIL(str, Enum):
+    class DETAIL(StrEnum):
         LOW = "low"
         HIGH = "high"
@@ -1,5 +1,5 @@
 from decimal import Decimal
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Optional
 
 from pydantic import BaseModel, ConfigDict
@@ -92,7 +92,7 @@ class ModelFeature(Enum):
     AUDIO = "audio"
 
 
-class DefaultParameterName(str, Enum):
+class DefaultParameterName(StrEnum):
     """
     Enum class for parameter template variable.
     """
@@ -15,9 +15,9 @@ parameter_rules:
     use_template: max_tokens
     required: true
     type: int
-    default: 4096
+    default: 8192
     min: 1
-    max: 4096
+    max: 8192
     help:
       zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
       en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
@@ -16,9 +16,9 @@ parameter_rules:
     use_template: max_tokens
     required: true
     type: int
-    default: 4096
+    default: 8192
     min: 1
-    max: 4096
+    max: 8192
     help:
       zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
       en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
@@ -5,6 +5,7 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
   - multi-tool-call
   - stream-tool-call
 model_properties:
@@ -72,7 +73,7 @@ parameter_rules:
     - text
     - json_object
 pricing:
-  input: '1'
-  output: '2'
-  unit: '0.000001'
+  input: "1"
+  output: "2"
+  unit: "0.000001"
   currency: RMB
@@ -5,6 +5,7 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
   - multi-tool-call
   - stream-tool-call
 model_properties:
@ -1,18 +1,17 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import tiktoken
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(
        self,
        model: str,
@ -25,92 +24,15 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)

        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    # refactored from openai model runtime, use cl100k_base for calculate token number
    def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Calculate num tokens for text completion model with tiktoken package.

        :param model: model name
        :param text: prompt text
        :param tools: tools for tool calling
        :return: number of tokens
        """
        encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = len(encoding.encode(text))

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    # refactored from openai model runtime, use cl100k_base for calculate token number
    def _num_tokens_from_messages(
        self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
    ) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        encoding = tiktoken.get_encoding("cl100k_base")
        tokens_per_message = 3
        tokens_per_name = 1

        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Cast str(value) in case the message value is not a string
                # This occurs with function messages
                # TODO: The current token calculation method for the image type is not implemented,
                # which need to download the image and then get the resolution for calculation,
                # and will increase the request delay
                if isinstance(value, list):
                    text = ""
                    for item in value:
                        if isinstance(item, dict) and item["type"] == "text":
                            text += item["text"]

                    value = text

                if key == "tool_calls":
                    for tool_call in value:
                        for t_key, t_value in tool_call.items():
                            num_tokens += len(encoding.encode(t_key))
                            if t_key == "function":
                                for f_key, f_value in t_value.items():
                                    num_tokens += len(encoding.encode(f_key))
                                    num_tokens += len(encoding.encode(f_value))
                            else:
                                num_tokens += len(encoding.encode(t_key))
                                num_tokens += len(encoding.encode(t_value))
                else:
                    num_tokens += len(encoding.encode(str(value)))

                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        credentials["mode"] = "chat"
        credentials["openai_api_key"] = credentials["api_key"]
        if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
            credentials["openai_api_base"] = "https://api.deepseek.com"
        else:
            parsed_url = urlparse(credentials["endpoint_url"])
            credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}"
    def _add_custom_parameters(credentials) -> None:
        credentials["endpoint_url"] = str(URL(credentials.get("endpoint_url", "https://api.deepseek.com")))
        credentials["mode"] = LLMMode.CHAT.value
        credentials["function_calling_type"] = "tool_call"
        credentials["stream_function_calling"] = "support"
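The rewrite above drops the local tiktoken-based token counting and the custom OpenAI parameter plumbing, delegating everything to the OpenAI-API-compatible runtime; only a few defaults are injected into the credentials first. A minimal sketch of what those defaults resolve to, assuming the yarl URL helper this runtime uses elsewhere (the api_key value is a placeholder):

    from yarl import URL

    credentials = {"api_key": "sk-placeholder"}
    # Default the endpoint when none is configured, then declare chat mode and
    # tool-call support before delegating to the OpenAI-API-compatible base class.
    credentials["endpoint_url"] = str(URL(credentials.get("endpoint_url", "https://api.deepseek.com")))
    credentials["mode"] = "chat"  # LLMMode.CHAT.value
    credentials["function_calling_type"] = "tool_call"
    credentials["stream_function_calling"] = "support"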
@ -18,7 +18,8 @@ class FishAudioProvider(ModelProvider):
        """
        try:
            model_instance = self.get_model_instance(ModelType.TTS)
            model_instance.validate_credentials(credentials=credentials)
            # FIXME fish tts do not have model for now, so set it to empty string instead
            model_instance.validate_credentials(model="", credentials=credentials)
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:

@ -66,7 +66,7 @@ class FishAudioText2SpeechModel(TTSModel):
            voice=voice,
        )

    def validate_credentials(self, credentials: dict, user: Optional[str] = None) -> None:
    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
        """
        Validate credentials for text2speech model

@ -76,7 +76,7 @@ class FishAudioText2SpeechModel(TTSModel):

        try:
            self.get_tts_model_voices(
                None,
                "",
                credentials={
                    "api_key": credentials["api_key"],
                    "api_base": credentials["api_base"],

@ -34,3 +34,11 @@ model_credential_schema:
    placeholder:
      zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
      en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
  - variable: api_key
    label:
      en_US: API Key
    type: secret-input
    required: false
    placeholder:
      zh_Hans: 在此输入您的 API Key
      en_US: Enter your API Key

@ -51,8 +51,13 @@ class HuggingfaceTeiRerankModel(RerankModel):

        server_url = server_url.removesuffix("/")

        headers = {"Content-Type": "application/json"}
        api_key = credentials.get("api_key")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        try:
            results = TeiHelper.invoke_rerank(server_url, query, docs)
            results = TeiHelper.invoke_rerank(server_url, query, docs, headers)

            rerank_documents = []
            for result in results:
@ -80,7 +85,11 @@ class HuggingfaceTeiRerankModel(RerankModel):
        """
        try:
            server_url = credentials["server_url"]
            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
            headers = {"Content-Type": "application/json"}
            api_key = credentials.get("api_key")
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"
            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
            if extra_args.model_type != "reranker":
                raise CredentialsValidateFailedError("Current model is not a rerank model")

@ -26,13 +26,15 @@ cache_lock = Lock()

class TeiHelper:
    @staticmethod
    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
    def get_tei_extra_parameter(
        server_url: str, model_name: str, headers: Optional[dict] = None
    ) -> TeiModelExtraParameter:
        TeiHelper._clean_cache()
        with cache_lock:
            if model_name not in cache:
                cache[model_name] = {
                    "expires": time() + 300,
                    "value": TeiHelper._get_tei_extra_parameter(server_url),
                    "value": TeiHelper._get_tei_extra_parameter(server_url, headers),
                }
            return cache[model_name]["value"]

@ -47,7 +49,7 @@ class TeiHelper:
            pass

    @staticmethod
    def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
    def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter:
        """
        get tei model extra parameter like model_type, max_input_length, max_batch_requests
        """
@ -61,7 +63,7 @@ class TeiHelper:
        session.mount("https://", HTTPAdapter(max_retries=3))

        try:
            response = session.get(url, timeout=10)
            response = session.get(url, headers=headers, timeout=10)
        except (MissingSchema, ConnectionError, Timeout) as e:
            raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
        if response.status_code != 200:
@ -86,7 +88,7 @@ class TeiHelper:
        )

    @staticmethod
    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
    def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
        """
        Invoke tokenize endpoint

@ -114,15 +116,15 @@
        :param server_url: server url
        :param texts: texts to tokenize
        """
        resp = httpx.post(
            f"{server_url}/tokenize",
            json={"inputs": texts},
        )
        url = f"{server_url}/tokenize"
        json_data = {"inputs": texts}
        resp = httpx.post(url, json=json_data, headers=headers)

        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
    def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
        """
        Invoke embeddings endpoint

@ -147,15 +149,14 @@ class TeiHelper:
        :param texts: texts to embed
        """
        # Use OpenAI compatible API here, which has usage tracking
        resp = httpx.post(
            f"{server_url}/v1/embeddings",
            json={"input": texts},
        )
        url = f"{server_url}/v1/embeddings"
        json_data = {"input": texts}
        resp = httpx.post(url, json=json_data, headers=headers)
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
    def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
        """
        Invoke rerank endpoint

@ -173,10 +174,7 @@ class TeiHelper:
        :param candidates: candidates to rerank
        """
        params = {"query": query, "texts": docs, "return_text": True}

        response = httpx.post(
            server_url + "/rerank",
            json=params,
        )
        url = f"{server_url}/rerank"
        response = httpx.post(url, json=params, headers=headers)
        response.raise_for_status()
        return response.json()
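Taken together, these TeiHelper changes thread an optional headers dict through every endpoint call so a Text Embedding Inference server behind authentication can be used. A hedged usage sketch drawn from the same pattern (URL and key are placeholders):

    headers = {"Content-Type": "application/json"}
    api_key = credentials.get("api_key")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    # Both calls now forward the Authorization header to the TEI server.
    tokens = TeiHelper.invoke_tokenize("http://192.168.1.100:8080", ["hello"], headers)
    embeddings = TeiHelper.invoke_embeddings("http://192.168.1.100:8080", ["hello"], headers)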
@ -51,6 +51,10 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):

        server_url = server_url.removesuffix("/")

        headers = {"Content-Type": "application/json"}
        api_key = credentials["api_key"]
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        # get model properties
        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)
@ -60,7 +64,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
        used_tokens = 0

        # get tokenized results from TEI
        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers)

        for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
            # Check if the number of tokens is larger than the context size
@ -97,7 +101,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
        used_tokens = 0
        for i in _iter:
            iter_texts = inputs[i : i + max_chunks]
            results = TeiHelper.invoke_embeddings(server_url, iter_texts)
            results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers)
            embeddings = results["data"]
            embeddings = [embedding["embedding"] for embedding in embeddings]
            batched_embeddings.extend(embeddings)
@ -127,7 +131,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):

        server_url = server_url.removesuffix("/")

        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
        headers = {
            "Authorization": f"Bearer {credentials.get('api_key')}",
        }

        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers)
        num_tokens = sum(len(tokens) for tokens in batch_tokens)
        return num_tokens

@ -141,7 +149,14 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
        """
        try:
            server_url = credentials["server_url"]
            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
            headers = {"Content-Type": "application/json"}

            api_key = credentials.get("api_key")

            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
            print(extra_args)
            if extra_args.model_type != "embedding":
                raise CredentialsValidateFailedError("Current model is not a embedding model")

@ -24,4 +24,3 @@
- meta-llama/Meta-Llama-3.1-8B-Instruct
- google/gemma-2-27b-it
- google/gemma-2-9b-it
- deepseek-ai/DeepSeek-V2-Chat

@ -18,6 +18,7 @@ supported_model_types:
  - text-embedding
  - rerank
  - speech2text
  - tts
configurate_methods:
  - predefined-model
  - customizable-model

@ -0,0 +1,37 @@
model: fishaudio/fish-speech-1.4
model_type: tts
model_properties:
  default_voice: 'fishaudio/fish-speech-1.4:alex'
  voices:
    - mode: "fishaudio/fish-speech-1.4:alex"
      name: "Alex(男声)"
      language: [ "zh-Hans", "en-US" ]
    - mode: "fishaudio/fish-speech-1.4:benjamin"
      name: "Benjamin(男声)"
      language: [ "zh-Hans", "en-US" ]
    - mode: "fishaudio/fish-speech-1.4:charles"
      name: "Charles(男声)"
      language: [ "zh-Hans", "en-US" ]
    - mode: "fishaudio/fish-speech-1.4:david"
      name: "David(男声)"
      language: [ "zh-Hans", "en-US" ]
    - mode: "fishaudio/fish-speech-1.4:anna"
      name: "Anna(女声)"
      language: [ "zh-Hans", "en-US" ]
    - mode: "fishaudio/fish-speech-1.4:bella"
      name: "Bella(女声)"
      language: [ "zh-Hans", "en-US" ]
    - mode: "fishaudio/fish-speech-1.4:claire"
      name: "Claire(女声)"
      language: [ "zh-Hans", "en-US" ]
    - mode: "fishaudio/fish-speech-1.4:diana"
      name: "Diana(女声)"
      language: [ "zh-Hans", "en-US" ]
  audio_type: 'mp3'
  max_workers: 5
# stream: false
pricing:
  input: '0.015'
  output: '0'
  unit: '0.001'
  currency: RMB

api/core/model_runtime/model_providers/siliconflow/tts/tts.py (new file, 105 lines)
@ -0,0 +1,105 @@
import concurrent.futures
from typing import Any, Optional

from openai import OpenAI

from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.openai._common import _CommonOpenAI


class SiliconFlowText2SpeechModel(_CommonOpenAI, TTSModel):
    """
    Model class for SiliconFlow Text to Speech model.
    """

    def _invoke(
        self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
    ) -> Any:
        """
        _invoke text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :param user: unique user id
        :return: text translated to audio file
        """
        if not voice or voice not in [
            d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials)
        ]:
            voice = self._get_model_default_voice(model, credentials)
        # if streaming:
        return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice)

    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
        """
        validate credentials text2speech model

        :param model: model name
        :param credentials: model credentials
        :param user: unique user id
        :return: text translated to audio file
        """
        try:
            self._tts_invoke_streaming(
                model=model,
                credentials=credentials,
                content_text="Hello SiliconFlow!",
                voice=self._get_model_default_voice(model, credentials),
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any:
        """
        _tts_invoke_streaming text2speech model

        :param model: model name
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :return: text translated to audio file
        """
        try:
            # doc: https://docs.siliconflow.cn/capabilities/text-to-speech
            self._add_custom_parameters(credentials)
            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = OpenAI(**credentials_kwargs)
            model_support_voice = [
                x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials)
            ]
            if not voice or voice not in model_support_voice:
                voice = self._get_model_default_voice(model, credentials)
            if len(content_text) > 4096:
                sentences = self._split_text_into_sentences(content_text, max_length=4096)
                executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
                futures = [
                    executor.submit(
                        client.audio.speech.with_streaming_response.create,
                        model=model,
                        response_format="mp3",
                        input=sentences[i],
                        voice=voice,
                    )
                    for i in range(len(sentences))
                ]
                for future in futures:
                    yield from future.result().__enter__().iter_bytes(1024)  # noqa:PLC2801

            else:
                response = client.audio.speech.with_streaming_response.create(
                    model=model, voice=voice, response_format="mp3", input=content_text.strip()
                )

                yield from response.__enter__().iter_bytes(1024)  # noqa:PLC2801
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

    @classmethod
    def _add_custom_parameters(cls, credentials: dict) -> None:
        credentials["openai_api_base"] = "https://api.siliconflow.cn"
        credentials["openai_api_key"] = credentials["api_key"]
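Since _tts_invoke_streaming is a generator of mp3 byte chunks, a caller can stream them straight to a sink. A hedged usage sketch, assuming an already-initialized model instance (file name, tenant id, and credentials below are illustrative placeholders, not part of the diff):

    # assuming `tts` is an initialized SiliconFlowText2SpeechModel instance
    chunks = tts._invoke(
        model="fishaudio/fish-speech-1.4",
        tenant_id="tenant-placeholder",
        credentials={"api_key": "sk-placeholder"},
        content_text="Hello SiliconFlow!",
        voice="fishaudio/fish-speech-1.4:alex",
    )
    with open("speech.mp3", "wb") as f:
        for chunk in chunks:  # each chunk is a 1024-byte mp3 slice
            f.write(chunk)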
@ -63,6 +63,9 @@ from core.model_runtime.model_providers.xinference.xinference_helper import (
)
from core.model_runtime.utils import helper

DEFAULT_MAX_RETRIES = 3
DEFAULT_INVOKE_TIMEOUT = 60


class XinferenceAILargeLanguageModel(LargeLanguageModel):
    def _invoke(
@ -315,7 +318,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ToolPromptMessage):
            message = cast(ToolPromptMessage, message)
            message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
            message_dict = {
                "tool_call_id": message.tool_call_id,
                "role": "tool",
                "content": message.content,
                "name": message.name,
            }
        else:
            raise ValueError(f"Unknown message type {type(message)}")

@ -466,8 +474,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
        client = OpenAI(
            base_url=f'{credentials["server_url"]}/v1',
            api_key=api_key,
            max_retries=3,
            timeout=60,
            max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES),
            timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT),
        )

        xinference_client = Client(

@ -56,3 +56,23 @@ model_credential_schema:
    placeholder:
      zh_Hans: 在此输入您的API密钥
      en_US: Enter the api key
  - variable: invoke_timeout
    label:
      zh_Hans: 调用超时时间 (单位:秒)
      en_US: invoke timeout (unit:second)
    type: text-input
    required: true
    default: '60'
    placeholder:
      zh_Hans: 在此输入调用超时时间
      en_US: Enter invoke timeout value
  - variable: max_retries
    label:
      zh_Hans: 调用重试次数
      en_US: max retries
    type: text-input
    required: true
    default: '3'
    placeholder:
      zh_Hans: 在此输入调用重试次数
      en_US: Enter max retries

@ -1,5 +1,5 @@
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any, Optional, Union

from pydantic import BaseModel, ConfigDict, field_validator
@ -122,7 +122,7 @@ trace_info_info_map = {
}


class TraceTaskName(str, Enum):
class TraceTaskName(StrEnum):
    CONVERSATION_TRACE = "conversation"
    WORKFLOW_TRACE = "workflow"
    MESSAGE_TRACE = "message"
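The `(str, Enum)` → `StrEnum` swap seen here repeats across the entity and enum modules that follow; it tracks the Python floor moving to 3.11/3.12 in this release. `enum.StrEnum` members are genuine `str` instances, so membership tests and JSON serialization are unchanged, while `str()` now yields the value rather than the member name. A minimal sketch of the difference, assuming Python 3.11+ (class names are illustrative):

    from enum import Enum, StrEnum

    class OldMode(str, Enum):
        CHAT = "chat"

    class NewMode(StrEnum):
        CHAT = "chat"

    assert OldMode.CHAT == "chat" and NewMode.CHAT == "chat"  # both compare as strings
    assert str(NewMode.CHAT) == "chat"          # StrEnum stringifies to the value
    assert str(OldMode.CHAT) == "OldMode.CHAT"  # the mixin form does not on 3.11+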
@ -1,5 +1,5 @@
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, field_validator
@ -39,7 +39,7 @@ def validate_input_output(v, field_name):
    return v


class LevelEnum(str, Enum):
class LevelEnum(StrEnum):
    DEBUG = "DEBUG"
    WARNING = "WARNING"
    ERROR = "ERROR"
@ -178,7 +178,7 @@ class LangfuseSpan(BaseModel):
        return validate_input_output(v, field_name)


class UnitEnum(str, Enum):
class UnitEnum(StrEnum):
    CHARACTERS = "CHARACTERS"
    TOKENS = "TOKENS"
    SECONDS = "SECONDS"

@ -1,5 +1,5 @@
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Any, Optional, Union

from pydantic import BaseModel, Field, field_validator
@ -8,7 +8,7 @@ from pydantic_core.core_schema import ValidationInfo
from core.ops.utils import replace_text_with_content


class LangSmithRunType(str, Enum):
class LangSmithRunType(StrEnum):
    tool = "tool"
    chain = "chain"
    llm = "llm"

@ -23,7 +23,7 @@ if TYPE_CHECKING:
    from core.file.models import File


class ModelMode(str, enum.Enum):
class ModelMode(enum.StrEnum):
    COMPLETION = "completion"
    CHAT = "chat"


@ -1,5 +1,5 @@
from enum import Enum
from enum import StrEnum


class KeyWordType(str, Enum):
class KeyWordType(StrEnum):
    JIEBA = "jieba"

@ -1,7 +1,7 @@
from enum import Enum
from enum import StrEnum


class VectorType(str, Enum):
class VectorType(StrEnum):
    ANALYTICDB = "analyticdb"
    CHROMA = "chroma"
    MILVUS = "milvus"

@ -114,10 +114,10 @@ class WordExtractor(BaseExtractor):
            mime_type=mime_type or "",
            created_by=self.user_id,
            created_by_role=CreatedByRole.ACCOUNT,
            created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
            created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            used=True,
            used_by=self.user_id,
            used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
            used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
        )

        db.session.add(upload_file)
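The timezone.utc → UTC swaps in this and the following files are behavior-preserving: datetime.UTC, added in Python 3.11, is an alias for timezone.utc. A one-line check of the pattern used throughout:

    from datetime import UTC, datetime, timezone

    assert UTC is timezone.utc  # same singleton object on Python 3.11+
    naive_utc_now = datetime.now(UTC).replace(tzinfo=None)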
@ -1,6 +1,6 @@
from enum import Enum
from enum import StrEnum


class RerankMode(str, Enum):
class RerankMode(StrEnum):
    RERANKING_MODEL = "reranking_model"
    WEIGHTED_SCORE = "weighted_score"

@ -1,4 +1,4 @@
from enum import Enum
from enum import Enum, StrEnum
from typing import Any, Optional, Union, cast

from pydantic import BaseModel, Field, field_validator
@ -137,7 +137,7 @@ class ToolParameterOption(BaseModel):


class ToolParameter(BaseModel):
    class ToolParameterType(str, Enum):
    class ToolParameterType(StrEnum):
        STRING = "string"
        NUMBER = "number"
        BOOLEAN = "boolean"

Binary file not shown. (Before: 4.3 KiB)
@ -1,8 +0,0 @@
from typing import Any

from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class FileExtractorProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        pass

@ -1,15 +0,0 @@
identity:
  author: Jyong
  name: file_extractor
  label:
    en_US: File Extractor
    zh_Hans: 文件提取
    pt_BR: File Extractor
  description:
    en_US: Extract text from file
    zh_Hans: 从文件中提取文本
    pt_BR: Extract text from file
  icon: icon.png
  tags:
    - utilities
    - productivity

@ -1,45 +0,0 @@
import tempfile
from typing import Any, Union

from core.file.enums import FileType
from core.file.file_manager import download_to_target_path
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError
from core.tools.tool.builtin_tool import BuiltinTool


class FileExtractorTool(BuiltinTool):
    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        # image file for workflow mode
        file = tool_parameters.get("text_file")
        if file and file.type != FileType.DOCUMENT:
            raise ToolParameterValidationError("Not a valid document")

        if file:
            with tempfile.TemporaryDirectory() as temp_dir:
                file_path = download_to_target_path(file, temp_dir)
                extractor = TextExtractor(file_path, autodetect_encoding=True)
                documents = extractor.extract()
                character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                    chunk_size=tool_parameters.get("max_token", 500),
                    chunk_overlap=0,
                    fixed_separator=tool_parameters.get("separator", "\n\n"),
                    separators=["\n\n", "。", ". ", " ", ""],
                    embedding_model_instance=None,
                )
                chunks = character_splitter.split_documents(documents)

                content = "\n".join([chunk.page_content for chunk in chunks])
                return self.create_text_message(content)

        else:
            raise ToolParameterValidationError("Please provide either file")

@ -1,49 +0,0 @@
identity:
  name: text extractor
  author: Jyong
  label:
    en_US: Text extractor
    zh_Hans: Text 文本解析
  description:
    en_US: Extract content from text file and support split to chunks by split characters and token length
    zh_Hans: 支持从文本文件中提取内容并支持通过分割字符和令牌长度分割成块
    pt_BR: Extract content from text file and support split to chunks by split characters and token length
description:
  human:
    en_US: Text extractor is a text extract tool
    zh_Hans: Text extractor 是一个文本提取工具
    pt_BR: Text extractor is a text extract tool
  llm: Text extractor is a tool used to extract text file
parameters:
  - name: text_file
    type: file
    label:
      en_US: Text file
    human_description:
      en_US: The text file to be extracted.
      zh_Hans: 要提取的 text 文档。
    llm_description: you should not input this parameter. just input the image_id.
    form: llm
  - name: separator
    type: string
    required: false
    label:
      en_US: split character
      zh_Hans: 分隔符号
    human_description:
      en_US: Text content split character
      zh_Hans: 用于文档分隔的符号
    llm_description: it is used for split content to chunks
    form: form
  - name: max_token
    type: number
    required: false
    label:
      en_US: Maximum chunk length
      zh_Hans: 最大分段长度
    human_description:
      en_US: Maximum chunk length
      zh_Hans: 最大分段长度
    llm_description: it is used for limit chunk's max length
    form: form

@ -69,14 +69,16 @@ class GitlabFilesTool(BuiltinTool):
                        self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
                    )
                else:  # It's a file
                    encoded_item_path = urllib.parse.quote(item_path, safe="")
                    if is_repository:
                        file_url = (
                            f"{domain}/api/v4/projects/{encoded_identifier}/repository/files"
                            f"/{item_path}/raw?ref={branch}"
                            f"/{encoded_item_path}/raw?ref={branch}"
                        )
                    else:
                        file_url = (
                            f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
                            f"{domain}/api/v4/projects/{project_id}/repository/files"
                            f"{encoded_item_path}/raw?ref={branch}"
                        )

                    file_response = requests.get(file_url, headers=headers)

@ -149,7 +149,7 @@ class SlidesGeneratorTool(BuiltinTool):
            presentation_bytes = await self._fetch_presentation(session, download_url)

            return [
                self.create_text_message("Presentation generated successfully"),
                self.create_text_message(download_url),
                self.create_blob_message(
                    blob=presentation_bytes,
                    meta={"mime_type": "application/vnd.openxmlformats-officedocument.presentationml.presentation"},

@ -1,4 +1,4 @@
from datetime import datetime, timezone
from datetime import UTC, datetime
from typing import Any, Union

from pytz import timezone as pytz_timezone
@ -20,7 +20,7 @@ class CurrentTimeTool(BuiltinTool):
        tz = tool_parameters.get("timezone", "UTC")
        fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z"
        if tz == "UTC":
            return self.create_text_message(f"{datetime.now(timezone.utc).strftime(fm)}")
            return self.create_text_message(f"{datetime.now(UTC).strftime(fm)}")

        try:
            tz = pytz_timezone(tz)

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from copy import deepcopy
from enum import Enum
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union

from pydantic import BaseModel, ConfigDict, field_validator
@ -62,7 +62,7 @@ class Tool(BaseModel, ABC):
    def __init__(self, **data: Any):
        super().__init__(**data)

    class VariableKey(str, Enum):
    class VariableKey(StrEnum):
        IMAGE = "image"
        DOCUMENT = "document"
        VIDEO = "video"

@ -1,7 +1,7 @@
import json
from collections.abc import Mapping
from copy import deepcopy
from datetime import datetime, timezone
from datetime import UTC, datetime
from mimetypes import guess_type
from typing import Any, Optional, Union

@ -61,7 +61,12 @@ class ToolEngine:
                if parameters and len(parameters) == 1:
                    tool_parameters = {parameters[0].name: tool_parameters}
                else:
                    raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
                    try:
                        tool_parameters = json.loads(tool_parameters)
                    except Exception as e:
                        pass
                    if not isinstance(tool_parameters, dict):
                        raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")

        # invoke the tool
        try:
@ -158,7 +163,7 @@ class ToolEngine:
        """
        Invoke the tool with the given arguments.
        """
        started_at = datetime.now(timezone.utc)
        started_at = datetime.now(UTC)
        meta = ToolInvokeMeta(
            time_cost=0.0,
            error=None,
@ -176,7 +181,7 @@ class ToolEngine:
            meta.error = str(e)
            raise ToolEngineInvokeError(meta)
        finally:
            ended_at = datetime.now(timezone.utc)
            ended_at = datetime.now(UTC)
            meta.time_cost = (ended_at - started_at).total_seconds()

        return meta, response
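The new ToolEngine branch means a JSON-encoded string of arguments no longer fails outright: it is parsed first and only rejected if the result is still not a dict. An illustrative trace of the same logic (the values are made up):

    import json

    tool_parameters = '{"query": "dify", "limit": 5}'  # e.g. produced by an agent
    try:
        tool_parameters = json.loads(tool_parameters)
    except Exception:
        pass  # leave the string as-is; the dict check below rejects it
    if not isinstance(tool_parameters, dict):
        raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
    # tool_parameters is now {"query": "dify", "limit": 5}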
@ -1,7 +1,7 @@
from enum import Enum
from enum import StrEnum


class SegmentType(str, Enum):
class SegmentType(StrEnum):
    NONE = "none"
    NUMBER = "number"
    STRING = "string"

@ -1,5 +1,5 @@
from collections.abc import Mapping
from enum import Enum
from enum import StrEnum
from typing import Any, Optional

from pydantic import BaseModel
@ -8,7 +8,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from models.workflow import WorkflowNodeExecutionStatus


class NodeRunMetadataKey(str, Enum):
class NodeRunMetadataKey(StrEnum):
    """
    Node Run Metadata Key.
    """
@ -36,7 +36,7 @@ class NodeRunResult(BaseModel):

    inputs: Optional[Mapping[str, Any]] = None  # node inputs
    process_data: Optional[dict[str, Any]] = None  # process data
    outputs: Optional[dict[str, Any]] = None  # node outputs
    outputs: Optional[Mapping[str, Any]] = None  # node outputs
    metadata: Optional[dict[NodeRunMetadataKey, Any]] = None  # node metadata
    llm_usage: Optional[LLMUsage] = None  # llm usage


@ -1,7 +1,7 @@
from enum import Enum
from enum import StrEnum


class SystemVariableKey(str, Enum):
class SystemVariableKey(StrEnum):
    """
    System Variables.
    """

@ -1,5 +1,5 @@
import uuid
from datetime import datetime, timezone
from datetime import UTC, datetime
from enum import Enum
from typing import Optional

@ -63,7 +63,7 @@ class RouteNodeState(BaseModel):
            raise Exception(f"Invalid route status {run_result.status}")

        self.node_run_result = run_result
        self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
        self.finished_at = datetime.now(UTC).replace(tzinfo=None)


class RuntimeRouteState(BaseModel):
@ -81,7 +81,7 @@ class RuntimeRouteState(BaseModel):

        :param node_id: node id
        """
        state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))
        state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
        self.node_state_mapping[state.id] = state
        return state


@ -1,7 +1,7 @@
from enum import Enum
from enum import StrEnum


class NodeType(str, Enum):
class NodeType(StrEnum):
    START = "start"
    END = "end"
    ANSWER = "answer"

@ -108,7 +108,7 @@ class Executor:
                self.content = self.variable_pool.convert_template(data[0].value).text
            case "json":
                json_string = self.variable_pool.convert_template(data[0].value).text
                json_object = json.loads(json_string)
                json_object = json.loads(json_string, strict=False)
                self.json = json_object
                # self.json = self._parse_object_contains_variables(json_object)
            case "binary":

@ -1,4 +1,4 @@
from enum import Enum
from enum import StrEnum
from typing import Any, Optional

from pydantic import Field
@ -6,7 +6,7 @@ from pydantic import Field
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData


class ErrorHandleMode(str, Enum):
class ErrorHandleMode(StrEnum):
    TERMINATED = "terminated"
    CONTINUE_ON_ERROR = "continue-on-error"
    REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"

@ -2,7 +2,7 @@ import logging
import uuid
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
from datetime import datetime, timezone
from datetime import UTC, datetime
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast

@ -135,7 +135,7 @@ class IterationNode(BaseNode[IterationNodeData]):
            thread_pool_id=self.thread_pool_id,
        )

        start_at = datetime.now(timezone.utc).replace(tzinfo=None)
        start_at = datetime.now(UTC).replace(tzinfo=None)

        yield IterationRunStartedEvent(
            iteration_id=self.id,
@ -367,7 +367,7 @@ class IterationNode(BaseNode[IterationNodeData]):
        """
        run single iteration
        """
        iter_start_at = datetime.now(timezone.utc).replace(tzinfo=None)
        iter_start_at = datetime.now(UTC).replace(tzinfo=None)

        try:
            rst = graph_engine.run()
@ -440,7 +440,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                variable_pool.add([self.node_id, "index"], next_index)
                if next_index < len(iterator_list_value):
                    variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
                duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
                duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
                iter_run_map[iteration_run_id] = duration
                yield IterationRunNextEvent(
                    iteration_id=self.id,
@ -461,7 +461,7 @@ class IterationNode(BaseNode[IterationNodeData]):

                if next_index < len(iterator_list_value):
                    variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
                duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
                duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
                iter_run_map[iteration_run_id] = duration
                yield IterationRunNextEvent(
                    iteration_id=self.id,
@ -503,7 +503,7 @@ class IterationNode(BaseNode[IterationNodeData]):

            if next_index < len(iterator_list_value):
                variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
            duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
            duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
            iter_run_map[iteration_run_id] = duration
            yield IterationRunNextEvent(
                iteration_id=self.id,

@ -38,6 +38,7 @@ from core.variables import (
    ObjectSegment,
    StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
@ -133,11 +134,15 @@ class LLMNode(BaseNode[LLMNodeData]):
        # fetch memory
        memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)

        # fetch prompt messages
        query = None
        if self.node_data.memory:
            query = self.node_data.memory.query_prompt_template
        else:
            query = None
        if query is None and (
            query_variable := self.graph_runtime_state.variable_pool.get(
                (SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)
            )
        ):
            query = query_variable.text

        prompt_messages, stop = self._fetch_prompt_messages(
            user_query=query,

@ -250,9 +250,8 @@ class ToolNode(BaseNode[ToolNodeData]):
                f"{message.message}"
                if message.type == ToolInvokeMessage.MessageType.TEXT
                else f"Link: {message.message}"
                if message.type == ToolInvokeMessage.MessageType.LINK
                else ""
                for message in tool_response
                if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}
            ]
        )


@ -1,11 +1,11 @@
from collections.abc import Sequence
from enum import Enum
from enum import StrEnum
from typing import Optional

from core.workflow.nodes.base import BaseNodeData


class WriteMode(str, Enum):
class WriteMode(StrEnum):
    OVER_WRITE = "over-write"
    APPEND = "append"
    CLEAR = "clear"

@ -5,10 +5,9 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast

from configs import dify_config
from core.app.app_config.entities import FileUploadConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File, FileTransferMethod, ImageConfig
from core.file.models import File
from core.workflow.callbacks import WorkflowCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
@ -18,9 +17,8 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode, BaseNodeData
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.llm import LLMNodeData
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from factories import file_factory
from models.enums import UserFrom
@ -115,7 +113,12 @@ class WorkflowEntry:

    @classmethod
    def single_step_run(
        cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict
        cls,
        *,
        workflow: Workflow,
        node_id: str,
        user_id: str,
        user_inputs: dict,
    ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
        """
        Single step run workflow node
@ -135,13 +138,9 @@ class WorkflowEntry:
            raise ValueError("nodes not found in workflow graph")

        # fetch node config from node id
        node_config = None
        for node in nodes:
            if node.get("id") == node_id:
                node_config = node
                break

        if not node_config:
        try:
            node_config = next(filter(lambda node: node["id"] == node_id, nodes))
        except StopIteration:
            raise ValueError("node id not found in workflow graph")

        # Get node class
@ -153,11 +152,7 @@ class WorkflowEntry:
            raise ValueError(f"Node class not found for node type {node_type}")

        # init variable pool
        variable_pool = VariablePool(
            system_variables={},
            user_inputs={},
            environment_variables=workflow.environment_variables,
        )
        variable_pool = VariablePool(environment_variables=workflow.environment_variables)

        # init graph
        graph = Graph.init(graph_config=workflow.graph_dict)
@ -183,28 +178,24 @@ class WorkflowEntry:

        try:
            # variable selector to variable mapping
            try:
                variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
                    graph_config=workflow.graph_dict, config=node_config
                )
            except NotImplementedError:
                variable_mapping = {}

            cls.mapping_user_inputs_to_variable_pool(
                variable_mapping=variable_mapping,
                user_inputs=user_inputs,
                variable_pool=variable_pool,
                tenant_id=workflow.tenant_id,
                node_type=node_type,
                node_data=node_instance.node_data,
            variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
                graph_config=workflow.graph_dict, config=node_config
            )
        except NotImplementedError:
            variable_mapping = {}

        cls.mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=user_inputs,
            variable_pool=variable_pool,
            tenant_id=workflow.tenant_id,
        )
        try:
            # run node
            generator = node_instance.run()

            return node_instance, generator
        except Exception as e:
            raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
        return node_instance, generator

    @staticmethod
    def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
@ -231,12 +222,11 @@ class WorkflowEntry:
    @classmethod
    def mapping_user_inputs_to_variable_pool(
        cls,
        *,
        variable_mapping: Mapping[str, Sequence[str]],
        user_inputs: dict,
        variable_pool: VariablePool,
        tenant_id: str,
        node_type: NodeType,
        node_data: BaseNodeData,
    ) -> None:
        for node_variable, variable_selector in variable_mapping.items():
            # fetch node id and variable key from node_variable
@ -254,40 +244,21 @@ class WorkflowEntry:
            # fetch variable node id from variable selector
            variable_node_id = variable_selector[0]
            variable_key_list = variable_selector[1:]
            variable_key_list = cast(list[str], variable_key_list)
            variable_key_list = list(variable_key_list)

            # get input value
            input_value = user_inputs.get(node_variable)
            if not input_value:
                input_value = user_inputs.get(node_variable_key)

            # FIXME: temp fix for image type
            if node_type == NodeType.LLM:
                new_value = []
                if isinstance(input_value, list):
                    node_data = cast(LLMNodeData, node_data)

                    detail = node_data.vision.configs.detail if node_data.vision.configs else None

                    for item in input_value:
                        if isinstance(item, dict) and "type" in item and item["type"] == "image":
                            transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
                            mapping = {
                                "id": item.get("id"),
                                "transfer_method": transfer_method,
                                "upload_file_id": item.get("upload_file_id"),
                                "url": item.get("url"),
                            }
                            config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None)
                            file = file_factory.build_from_mapping(
                                mapping=mapping,
                                tenant_id=tenant_id,
                                config=config,
                            )
                            new_value.append(file)

                if new_value:
                    input_value = new_value
            if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value:
                input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id)
            if (
                isinstance(input_value, list)
                and all(isinstance(item, dict) for item in input_value)
                and all("type" in item and "transfer_method" in item for item in input_value)
            ):
                input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)

            # append variable and value to variable pool
            variable_pool.add([variable_node_id] + variable_key_list, input_value)
@ -33,7 +33,7 @@ def handle(sender, **kwargs):
            raise NotFound("Document not found")

        document.indexing_status = "parsing"
        document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        documents.append(document)
        db.session.add(document)
    db.session.commit()

@ -1,4 +1,4 @@
from datetime import datetime, timezone
from datetime import UTC, datetime

from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from events.message_event import message_was_created
@ -17,5 +17,5 @@ def handle(sender, **kwargs):
    db.session.query(Provider).filter(
        Provider.tenant_id == application_generate_entity.app_config.tenant_id,
        Provider.provider_name == application_generate_entity.model_conf.provider,
    ).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)})
    ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
    db.session.commit()

@ -1,5 +1,5 @@
from collections.abc import Generator
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta

from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas

@ -67,7 +67,7 @@ class AzureBlobStorage(BaseStorage):
            account_key=self.account_key,
            resource_types=ResourceTypes(service=True, container=True, object=True),
            permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
            expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1),
            expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
        )
        redis_client.set(cache_key, sas_token, ex=3000)
        return BlobServiceClient(account_url=self.account_url, credential=sas_token)

@ -1,7 +1,7 @@
from enum import Enum
from enum import StrEnum


class StorageType(str, Enum):
class StorageType(StrEnum):
    ALIYUN_OSS = "aliyun-oss"
    AZURE_BLOB = "azure-blob"
    BAIDU_OBS = "baidu-obs"

@ -1,10 +1,11 @@
import mimetypes
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from typing import Any, cast

import httpx
from sqlalchemy import select

from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig
from core.helper import ssrf_proxy
from extensions.ext_database import db
@ -71,7 +72,12 @@ def build_from_mapping(
        transfer_method=transfer_method,
    )

    if not _is_file_valid_with_config(file=file, config=config):
    if not _is_file_valid_with_config(
        input_file_type=mapping.get("type", FileType.CUSTOM),
        file_extension=file.extension,
        file_transfer_method=file.transfer_method,
        config=config,
    ):
        raise ValueError(f"File validation failed for file: {file.filename}")

    return file
@ -80,12 +86,9 @@ def build_from_mapping(
def build_from_mappings(
    *,
    mappings: Sequence[Mapping[str, Any]],
    config: FileUploadConfig | None,
    config: FileUploadConfig | None = None,
    tenant_id: str,
) -> Sequence[File]:
    if not config:
        return []

    files = [
        build_from_mapping(
            mapping=mapping,
@ -96,13 +99,14 @@ def build_from_mappings(
    ]

    if (
        config
        # If image config is set.
        config.image_config
        and config.image_config
        # And the number of image files exceeds the maximum limit
        and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits
    ):
        raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")
    if config.number_limits and len(files) > config.number_limits:
    if config and config.number_limits and len(files) > config.number_limits:
        raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")

    return files
@ -114,17 +118,18 @@ def _build_from_local_file(
    tenant_id: str,
    transfer_method: FileTransferMethod,
) -> File:
    file_type = FileType.value_of(mapping.get("type"))
    stmt = select(UploadFile).where(
        UploadFile.id == mapping.get("upload_file_id"),
        UploadFile.tenant_id == tenant_id,
    )

    row = db.session.scalar(stmt)

    if row is None:
        raise ValueError("Invalid upload file")

    file_type = FileType(mapping.get("type"))
    file_type = _standardize_file_type(file_type, extension="." + row.extension, mime_type=row.mime_type)

    return File(
        id=mapping.get("id"),
        filename=row.name,
@ -152,11 +157,14 @@ def _build_from_remote_url(
    mime_type, filename, file_size = _get_remote_file_info(url)
    extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"

    file_type = FileType(mapping.get("type"))
    file_type = _standardize_file_type(file_type, extension=extension, mime_type=mime_type)

    return File(
        id=mapping.get("id"),
        filename=filename,
        tenant_id=tenant_id,
        type=FileType.value_of(mapping.get("type")),
        type=file_type,
        transfer_method=transfer_method,
        remote_url=url,
        mime_type=mime_type,
@ -171,6 +179,7 @@ def _get_remote_file_info(url: str):
    mime_type = mimetypes.guess_type(filename)[0] or ""

    resp = ssrf_proxy.head(url, follow_redirects=True)
    resp = cast(httpx.Response, resp)
    if resp.status_code == httpx.codes.OK:
        if content_disposition := resp.headers.get("Content-Disposition"):
            filename = str(content_disposition.split("filename=")[-1].strip('"'))
@ -180,20 +189,6 @@ def _get_remote_file_info(url: str):
    return mime_type, filename, file_size


def _get_file_type_by_mimetype(mime_type: str) -> FileType:
    if "image" in mime_type:
        file_type = FileType.IMAGE
    elif "video" in mime_type:
        file_type = FileType.VIDEO
    elif "audio" in mime_type:
        file_type = FileType.AUDIO
    elif "text" in mime_type or "pdf" in mime_type:
        file_type = FileType.DOCUMENT
    else:
        file_type = FileType.CUSTOM
    return file_type


def _build_from_tool_file(
    *,
    mapping: Mapping[str, Any],
@ -213,7 +208,8 @@ def _build_from_tool_file(
        raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")

    extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
    file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype))
    file_type = FileType(mapping.get("type"))
    file_type = _standardize_file_type(file_type, extension=extension, mime_type=tool_file.mimetype)

    return File(
        id=mapping.get("id"),
@ -229,18 +225,72 @@ def _build_from_tool_file(
    )


def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool:
    if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM:
def _is_file_valid_with_config(
    *,
    input_file_type: str,
    file_extension: str,
    file_transfer_method: FileTransferMethod,
    config: FileUploadConfig,
) -> bool:
    if (
        config.allowed_file_types
        and input_file_type not in config.allowed_file_types
        and input_file_type != FileType.CUSTOM
    ):
        return False

    if config.allowed_file_extensions and file.extension not in config.allowed_file_extensions:
    if (
        input_file_type == FileType.CUSTOM
        and config.allowed_file_extensions is not None
        and file_extension not in config.allowed_file_extensions
    ):
        return False

    if config.allowed_file_upload_methods and file.transfer_method not in config.allowed_file_upload_methods:
    if config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods:
        return False

    if file.type == FileType.IMAGE and config.image_config:
        if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods:
    if input_file_type == FileType.IMAGE and config.image_config:
        if config.image_config.transfer_methods and file_transfer_method not in config.image_config.transfer_methods:
            return False

    return True


def _standardize_file_type(file_type: FileType, /, *, extension: str = "", mime_type: str = "") -> FileType:
    """
    If custom type, try to guess the file type by extension and mime_type.
    """
    if file_type != FileType.CUSTOM:
        return FileType(file_type)
    guessed_type = None
    if extension:
        guessed_type = _get_file_type_by_extension(extension)
    if guessed_type is None and mime_type:
        guessed_type = _get_file_type_by_mimetype(mime_type)
    return guessed_type or FileType.CUSTOM


def _get_file_type_by_extension(extension: str) -> FileType | None:
    extension = extension.lstrip(".")
    if extension in IMAGE_EXTENSIONS:
        return FileType.IMAGE
    elif extension in VIDEO_EXTENSIONS:
        return FileType.VIDEO
    elif extension in AUDIO_EXTENSIONS:
        return FileType.AUDIO
    elif extension in DOCUMENT_EXTENSIONS:
        return FileType.DOCUMENT


def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
    if "image" in mime_type:
        file_type = FileType.IMAGE
    elif "video" in mime_type:
        file_type = FileType.VIDEO
    elif "audio" in mime_type:
        file_type = FileType.AUDIO
    elif "text" in mime_type or "pdf" in mime_type:
        file_type = FileType.DOCUMENT
    else:
        file_type = FileType.CUSTOM
    return file_type
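With _standardize_file_type in place, a file submitted as "custom" is promoted to a concrete type whenever its extension or mime type is recognizable. An illustrative call, assuming png is registered in the IMAGE_EXTENSIONS table imported from constants:

    file_type = _standardize_file_type(FileType.CUSTOM, extension=".png", mime_type="image/png")
    assert file_type == FileType.IMAGE
    # An unknown extension and mime type falls back to FileType.CUSTOM instead.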
@ -70,7 +70,7 @@ class NotionOAuth(OAuthDataSource):
        if data_source_binding:
            data_source_binding.source_info = source_info
            data_source_binding.disabled = False
            data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            new_data_source_binding = DataSourceOauthBinding(
@ -106,7 +106,7 @@ class NotionOAuth(OAuthDataSource):
        if data_source_binding:
            data_source_binding.source_info = source_info
            data_source_binding.disabled = False
            data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            new_data_source_binding = DataSourceOauthBinding(
@ -141,7 +141,7 @@ class NotionOAuth(OAuthDataSource):
            }
            data_source_binding.source_info = new_source_info
            data_source_binding.disabled = False
            data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        else:
            raise ValueError("Data source binding not found")

@ -8,7 +8,7 @@ from extensions.ext_database import db
from .types import StringUUID


class AccountStatus(str, enum.Enum):
class AccountStatus(enum.StrEnum):
    PENDING = "pending"
    UNINITIALIZED = "uninitialized"
    ACTIVE = "active"
@ -108,6 +108,10 @@ class Account(UserMixin, db.Model):
    def is_admin_or_owner(self):
        return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)

    @property
    def is_admin(self):
        return TenantAccountRole.is_admin_role(self._current_tenant.current_role)

    @property
    def is_editor(self):
        return TenantAccountRole.is_editing_role(self._current_tenant.current_role)
@ -121,12 +125,12 @@ class Account(UserMixin, db.Model):
        return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR


class TenantStatus(str, enum.Enum):
class TenantStatus(enum.StrEnum):
    NORMAL = "normal"
    ARCHIVE = "archive"


class TenantAccountRole(str, enum.Enum):
class TenantAccountRole(enum.StrEnum):
    OWNER = "owner"
    ADMIN = "admin"
    EDITOR = "editor"
@ -147,6 +151,10 @@ class TenantAccountRole(str, enum.Enum):
    def is_privileged_role(role: str) -> bool:
        return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}

    @staticmethod
    def is_admin_role(role: str) -> bool:
        return role and role == TenantAccountRole.ADMIN

    @staticmethod
    def is_non_owner_role(role: str) -> bool:
        return role and role in {
Some files were not shown because too many files have changed in this diff.