mirror of
https://github.com/langgenius/dify.git
synced 2026-01-23 21:42:53 +08:00
Compare commits
145 Commits
feat/suppo
...
0.13.2
| Author | SHA1 | Date | |
|---|---|---|---|
| 0ff8bd2aa9 | |||
| 2866383228 | |||
| 00ac7edeb3 | |||
| 537068cfde | |||
| c3c6a48059 | |||
| 5c166b3f40 | |||
| 230fa3286b | |||
| 061c0b10fd | |||
| 32f8a98cf8 | |||
| 6c60ecb237 | |||
| c3fae5e801 | |||
| a594e256ae | |||
| 41d90c2408 | |||
| 7ff42b1b7a | |||
| 4d7cfd0de5 | |||
| 266d32bd77 | |||
| 7e1184c071 | |||
| 1ce51e57ab | |||
| 142b4fd699 | |||
| cc8feaa483 | |||
| d9d5d35a77 | |||
| 9277156b6c | |||
| 1490a19fa1 | |||
| 9b7adcd4d9 | |||
| a8d32f9964 | |||
| 5093337de1 | |||
| f54225568c | |||
| 255ff446ba | |||
| 9a0dc4bfdc | |||
| 9d975750bc | |||
| 7c979e6490 | |||
| d60ca1661c | |||
| bb62391a4c | |||
| 0b25c0b677 | |||
| a5d6082418 | |||
| 631cbcd781 | |||
| 20c4633d2a | |||
| 5669cac16d | |||
| 6180762160 | |||
| 376726cf90 | |||
| 284bb7ac71 | |||
| eca466bdaa | |||
| d56abec195 | |||
| 961e25f608 | |||
| 138bf698b0 | |||
| e5bb4cca12 | |||
| 5e2cb0e3a8 | |||
| 16a65cb367 | |||
| 1bae9b8ff7 | |||
| d7c1f43b49 | |||
| f933af9f57 | |||
| 91e1ff5e30 | |||
| 5908e10549 | |||
| 464e6354c5 | |||
| d470e55f8c | |||
| 98a1b01b0c | |||
| e240424be5 | |||
| 1cb5a12abb | |||
| ff2a4a6fcd | |||
| c58d2fce89 | |||
| 7a962b9f03 | |||
| a679079a1d | |||
| e39e776d03 | |||
| e135ffc2c1 | |||
| e79eac688a | |||
| 643a90c48d | |||
| 2a448a899d | |||
| 7b86f8f024 | |||
| e686f12317 | |||
| a86f1eca79 | |||
| 668c1c0792 | |||
| c4fad66f2a | |||
| 02572e8cca | |||
| 1d8385f7ac | |||
| f8c966c39c | |||
| 3c8efe7c0a | |||
| dbc10e0feb | |||
| 239bf97b47 | |||
| 858db2f239 | |||
| c34bdb74e6 | |||
| 9601102885 | |||
| 56c2d1cc55 | |||
| a67b0d4771 | |||
| ef204817ae | |||
| 9bc5bc2548 | |||
| fd4be36991 | |||
| 9b46b02717 | |||
| 3bc4dc58d7 | |||
| 594666eb61 | |||
| e80f41a701 | |||
| f9c2aa7689 | |||
| 9dd4bf5574 | |||
| 5a9b785773 | |||
| d96a28487a | |||
| 0554898b5d | |||
| 6f9ce6a199 | |||
| e3119112a6 | |||
| d3af0e9090 | |||
| 2feb44e2c5 | |||
| cc0b92bc75 | |||
| e576d32fb6 | |||
| 2d6865d421 | |||
| 0f1133729f | |||
| d7160ee563 | |||
| 18add94a31 | |||
| 18d3ffc194 | |||
| 0a30a5b077 | |||
| 9049dd7725 | |||
| 6f418da388 | |||
| 41c6bf5fe4 | |||
| 33d6d26bbf | |||
| 787285d58f | |||
| 40fc6f529e | |||
| baef18cedd | |||
| a918cea2fe | |||
| 9789905a1f | |||
| f458580dee | |||
| 223a30401c | |||
| 2927493cf3 | |||
| 79db920fa7 | |||
| b3d65cc7df | |||
| 208d6d6d94 | |||
| aa135a3780 | |||
| 044e7b63c2 | |||
| 5b7b328193 | |||
| 8d5a1be227 | |||
| 90d5765fb6 | |||
| 1db14793fa | |||
| cbb4e95928 | |||
| 20c091a5e7 | |||
| e9c098d024 | |||
| 9f75970347 | |||
| f1366e8e19 | |||
| 0f85e3557b | |||
| 17ee731546 | |||
| af2461cccc | |||
| 60c1549771 | |||
| ab6dcf7032 | |||
| 8aae235a71 | |||
| c032574491 | |||
| 1065917872 | |||
| 56e361ac44 | |||
| 2e00829b1e | |||
| 625aaceb00 | |||
| 98d85e6b74 |
13
.github/pull_request_template.md
vendored
13
.github/pull_request_template.md
vendored
@ -8,16 +8,9 @@ Please include a summary of the change and which issue is fixed. Please also inc
|
||||
|
||||
# Screenshots
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td>Before: </td>
|
||||
<td>After: </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>...</td>
|
||||
<td>...</td>
|
||||
</tr>
|
||||
</table>
|
||||
| Before | After |
|
||||
|--------|-------|
|
||||
| ... | ... |
|
||||
|
||||
# Checklist
|
||||
|
||||
|
||||
2
.github/workflows/db-migration-test.yml
vendored
2
.github/workflows/db-migration-test.yml
vendored
@ -48,6 +48,8 @@ jobs:
|
||||
cp .env.example .env
|
||||
|
||||
- name: Run DB Migration
|
||||
env:
|
||||
DEBUG: true
|
||||
run: |
|
||||
cd api
|
||||
poetry run python -m flask upgrade-db
|
||||
|
||||
@ -147,6 +147,13 @@ Deploy Dify to Cloud Platform with a single click using [terraform](https://www.
|
||||
##### Google Cloud
|
||||
- [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### Using AWS CDK for Deployment
|
||||
|
||||
Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/)
|
||||
|
||||
##### AWS
|
||||
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## Contributing
|
||||
|
||||
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
|
||||
14
README_AR.md
14
README_AR.md
@ -190,6 +190,13 @@ docker compose up -d
|
||||
##### Google Cloud
|
||||
- [Google Cloud Terraform بواسطة @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### استخدام AWS CDK للنشر
|
||||
|
||||
انشر Dify على AWS باستخدام [CDK](https://aws.amazon.com/cdk/)
|
||||
|
||||
##### AWS
|
||||
- [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## المساهمة
|
||||
|
||||
لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا.
|
||||
@ -222,3 +229,10 @@ docker compose up -d
|
||||
## الرخصة
|
||||
|
||||
هذا المستودع متاح تحت [رخصة البرنامج الحر Dify](LICENSE)، والتي تعتبر بشكل أساسي Apache 2.0 مع بعض القيود الإضافية.
|
||||
## الكشف عن الأمان
|
||||
|
||||
لحماية خصوصيتك، يرجى تجنب نشر مشكلات الأمان على GitHub. بدلاً من ذلك، أرسل أسئلتك إلى security@dify.ai وسنقدم لك إجابة أكثر تفصيلاً.
|
||||
|
||||
## الرخصة
|
||||
|
||||
هذا المستودع متاح تحت [رخصة البرنامج الحر Dify](LICENSE)، والتي تعتبر بشكل أساسي Apache 2.0 مع بعض القيود الإضافية.
|
||||
|
||||
@ -213,6 +213,13 @@ docker compose up -d
|
||||
##### Google Cloud
|
||||
- [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### 使用 AWS CDK 部署
|
||||
|
||||
使用 [CDK](https://aws.amazon.com/cdk/) 将 Dify 部署到 AWS
|
||||
|
||||
##### AWS
|
||||
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#langgenius/dify&Date)
|
||||
|
||||
14
README_ES.md
14
README_ES.md
@ -215,6 +215,13 @@ Despliega Dify en una plataforma en la nube con un solo clic utilizando [terrafo
|
||||
##### Google Cloud
|
||||
- [Google Cloud Terraform por @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### Usando AWS CDK para el Despliegue
|
||||
|
||||
Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/)
|
||||
|
||||
##### AWS
|
||||
- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## Contribuir
|
||||
|
||||
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
@ -248,3 +255,10 @@ Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En
|
||||
## Licencia
|
||||
|
||||
Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales.
|
||||
## Divulgación de Seguridad
|
||||
|
||||
Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En su lugar, envía tus preguntas a security@dify.ai y te proporcionaremos una respuesta más detallada.
|
||||
|
||||
## Licencia
|
||||
|
||||
Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales.
|
||||
|
||||
14
README_FR.md
14
README_FR.md
@ -213,6 +213,13 @@ Déployez Dify sur une plateforme cloud en un clic en utilisant [terraform](http
|
||||
##### Google Cloud
|
||||
- [Google Cloud Terraform par @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### Utilisation d'AWS CDK pour le déploiement
|
||||
|
||||
Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/)
|
||||
|
||||
##### AWS
|
||||
- [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## Contribuer
|
||||
|
||||
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
@ -246,3 +253,10 @@ Pour protéger votre vie privée, veuillez éviter de publier des problèmes de
|
||||
## Licence
|
||||
|
||||
Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement l'Apache 2.0 avec quelques restrictions supplémentaires.
|
||||
## Divulgation de sécurité
|
||||
|
||||
Pour protéger votre vie privée, veuillez éviter de publier des problèmes de sécurité sur GitHub. Au lieu de cela, envoyez vos questions à security@dify.ai et nous vous fournirons une réponse plus détaillée.
|
||||
|
||||
## Licence
|
||||
|
||||
Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement l'Apache 2.0 avec quelques restrictions supplémentaires.
|
||||
|
||||
@ -212,6 +212,13 @@ docker compose up -d
|
||||
##### Google Cloud
|
||||
- [@sotazumによるGoogle Cloud Terraform](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### AWS CDK を使用したデプロイ
|
||||
|
||||
[CDK](https://aws.amazon.com/cdk/) を使用して、DifyをAWSにデプロイします
|
||||
|
||||
##### AWS
|
||||
- [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## 貢献
|
||||
|
||||
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。
|
||||
|
||||
@ -213,6 +213,13 @@ wa'logh nIqHom neH ghun deployment toy'wI' [terraform](https://www.terraform.io/
|
||||
##### Google Cloud
|
||||
- [Google Cloud Terraform qachlot @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### AWS CDK atorlugh pilersitsineq
|
||||
|
||||
wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo'laH.
|
||||
|
||||
##### AWS
|
||||
- [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## Contributing
|
||||
|
||||
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
|
||||
@ -205,6 +205,13 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
|
||||
##### Google Cloud
|
||||
- [sotazum의 Google Cloud Terraform](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### AWS CDK를 사용한 배포
|
||||
|
||||
[CDK](https://aws.amazon.com/cdk/)를 사용하여 AWS에 Dify 배포
|
||||
|
||||
##### AWS
|
||||
- [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## 기여
|
||||
|
||||
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.
|
||||
|
||||
@ -211,6 +211,13 @@ Implante o Dify na Plataforma Cloud com um único clique usando [terraform](http
|
||||
##### Google Cloud
|
||||
- [Google Cloud Terraform por @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### Usando AWS CDK para Implantação
|
||||
|
||||
Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/)
|
||||
|
||||
##### AWS
|
||||
- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## Contribuindo
|
||||
|
||||
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
|
||||
@ -145,6 +145,13 @@ namestite Dify v Cloud Platform z enim klikom z uporabo [terraform](https://www.
|
||||
##### Google Cloud
|
||||
- [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### Uporaba AWS CDK za uvajanje
|
||||
|
||||
Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
|
||||
|
||||
##### AWS
|
||||
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## Prispevam
|
||||
|
||||
Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah.
|
||||
|
||||
@ -211,6 +211,13 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
|
||||
##### Google Cloud
|
||||
- [Google Cloud Terraform tarafından @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### AWS CDK ile Dağıtım
|
||||
|
||||
[CDK](https://aws.amazon.com/cdk/) kullanarak Dify'ı AWS'ye dağıtın
|
||||
|
||||
##### AWS
|
||||
- [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## Katkıda Bulunma
|
||||
|
||||
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz.
|
||||
|
||||
@ -207,6 +207,13 @@ Triển khai Dify lên nền tảng đám mây với một cú nhấp chuột b
|
||||
##### Google Cloud
|
||||
- [Google Cloud Terraform bởi @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### Sử dụng AWS CDK để Triển khai
|
||||
|
||||
Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/)
|
||||
|
||||
##### AWS
|
||||
- [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## Đóng góp
|
||||
|
||||
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi.
|
||||
|
||||
@ -329,6 +329,7 @@ NOTION_INTERNAL_SECRET=you-internal-secret
|
||||
ETL_TYPE=dify
|
||||
UNSTRUCTURED_API_URL=
|
||||
UNSTRUCTURED_API_KEY=
|
||||
SCARF_NO_ANALYTICS=true
|
||||
|
||||
#ssrf
|
||||
SSRF_PROXY_HTTP_URL=
|
||||
@ -382,7 +383,7 @@ LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
|
||||
LOG_TZ=UTC
|
||||
|
||||
# Indexing configuration
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
|
||||
|
||||
# Workflow runtime configuration
|
||||
WORKFLOW_MAX_EXECUTION_STEPS=500
|
||||
@ -410,4 +411,5 @@ POSITION_PROVIDER_EXCLUDES=
|
||||
# Reset password token expiry minutes
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
|
||||
|
||||
CREATE_TIDB_SERVICE_JOB_ENABLED=false
|
||||
CREATE_TIDB_SERVICE_JOB_ENABLED=false
|
||||
|
||||
|
||||
96
api/.ruff.toml
Normal file
96
api/.ruff.toml
Normal file
@ -0,0 +1,96 @@
|
||||
exclude = [
|
||||
"migrations/*",
|
||||
]
|
||||
line-length = 120
|
||||
|
||||
[format]
|
||||
quote-style = "double"
|
||||
|
||||
[lint]
|
||||
preview = true
|
||||
select = [
|
||||
"B", # flake8-bugbear rules
|
||||
"C4", # flake8-comprehensions
|
||||
"E", # pycodestyle E rules
|
||||
"F", # pyflakes rules
|
||||
"FURB", # refurb rules
|
||||
"I", # isort rules
|
||||
"N", # pep8-naming
|
||||
"PT", # flake8-pytest-style rules
|
||||
"PLC0208", # iteration-over-set
|
||||
"PLC2801", # unnecessary-dunder-call
|
||||
"PLC0414", # useless-import-alias
|
||||
"PLE0604", # invalid-all-object
|
||||
"PLE0605", # invalid-all-format
|
||||
"PLR0402", # manual-from-import
|
||||
"PLR1711", # useless-return
|
||||
"PLR1714", # repeated-equality-comparison
|
||||
"RUF013", # implicit-optional
|
||||
"RUF019", # unnecessary-key-check
|
||||
"RUF100", # unused-noqa
|
||||
"RUF101", # redirected-noqa
|
||||
"RUF200", # invalid-pyproject-toml
|
||||
"RUF022", # unsorted-dunder-all
|
||||
"S506", # unsafe-yaml-load
|
||||
"SIM", # flake8-simplify rules
|
||||
"TRY400", # error-instead-of-exception
|
||||
"TRY401", # verbose-log-message
|
||||
"UP", # pyupgrade rules
|
||||
"W191", # tab-indentation
|
||||
"W605", # invalid-escape-sequence
|
||||
]
|
||||
|
||||
ignore = [
|
||||
"E402", # module-import-not-at-top-of-file
|
||||
"E711", # none-comparison
|
||||
"E712", # true-false-comparison
|
||||
"E721", # type-comparison
|
||||
"E722", # bare-except
|
||||
"E731", # lambda-assignment
|
||||
"F821", # undefined-name
|
||||
"F841", # unused-variable
|
||||
"FURB113", # repeated-append
|
||||
"FURB152", # math-constant
|
||||
"UP007", # non-pep604-annotation
|
||||
"UP032", # f-string
|
||||
"B005", # strip-with-multi-characters
|
||||
"B006", # mutable-argument-default
|
||||
"B007", # unused-loop-control-variable
|
||||
"B026", # star-arg-unpacking-after-keyword-arg
|
||||
"B904", # raise-without-from-inside-except
|
||||
"B905", # zip-without-explicit-strict
|
||||
"N806", # non-lowercase-variable-in-function
|
||||
"N815", # mixed-case-variable-in-class-scope
|
||||
"PT011", # pytest-raises-too-broad
|
||||
"SIM102", # collapsible-if
|
||||
"SIM103", # needless-bool
|
||||
"SIM105", # suppressible-exception
|
||||
"SIM107", # return-in-try-except-finally
|
||||
"SIM108", # if-else-block-instead-of-if-exp
|
||||
"SIM113", # eumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
"SIM210", # if-expr-with-true-false
|
||||
"SIM300", # yoda-conditions,
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"__init__.py" = [
|
||||
"F401", # unused-import
|
||||
"F811", # redefined-while-unused
|
||||
]
|
||||
"configs/*" = [
|
||||
"N802", # invalid-function-name
|
||||
]
|
||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"F401", # unused-import
|
||||
]
|
||||
|
||||
[lint.pyflakes]
|
||||
extend-generics = [
|
||||
"_pytest.monkeypatch",
|
||||
"tests.integration_tests",
|
||||
]
|
||||
@ -55,7 +55,7 @@ RUN apt-get update \
|
||||
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
|
||||
&& apt-get update \
|
||||
# For Security
|
||||
&& apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-7 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
|
||||
&& apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-8 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
|
||||
# install a chinese font to support the use of tools like matplotlib
|
||||
&& apt-get install -y fonts-noto-cjk \
|
||||
&& apt-get autoremove -y \
|
||||
|
||||
108
api/app.py
108
api/app.py
@ -1,113 +1,13 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
python_version = sys.version_info
|
||||
if not ((3, 11) <= python_version < (3, 13)):
|
||||
print(f"Python 3.11 or 3.12 is required, current version is {python_version.major}.{python_version.minor}")
|
||||
raise SystemExit(1)
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
if not dify_config.DEBUG:
|
||||
from gevent import monkey
|
||||
|
||||
monkey.patch_all()
|
||||
|
||||
import grpc.experimental.gevent
|
||||
|
||||
grpc.experimental.gevent.init_gevent()
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
|
||||
from flask import Response
|
||||
|
||||
from app_factory import create_app
|
||||
from libs import threadings_utils, version_utils
|
||||
|
||||
# DO NOT REMOVE BELOW
|
||||
from events import event_handlers # noqa: F401
|
||||
from extensions.ext_database import db
|
||||
|
||||
# TODO: Find a way to avoid importing models here
|
||||
from models import account, dataset, model, source, task, tool, tools, web # noqa: F401
|
||||
|
||||
# DO NOT REMOVE ABOVE
|
||||
|
||||
|
||||
warnings.simplefilter("ignore", ResourceWarning)
|
||||
|
||||
os.environ["TZ"] = "UTC"
|
||||
# windows platform not support tzset
|
||||
if hasattr(time, "tzset"):
|
||||
time.tzset()
|
||||
|
||||
# preparation before creating app
|
||||
version_utils.check_supported_python_version()
|
||||
threadings_utils.apply_gevent_threading_patch()
|
||||
|
||||
# create app
|
||||
app = create_app()
|
||||
celery = app.extensions["celery"]
|
||||
|
||||
if dify_config.TESTING:
|
||||
print("App is running in TESTING mode")
|
||||
|
||||
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
"""Add Version headers to the response."""
|
||||
response.headers.add("X-Version", dify_config.CURRENT_VERSION)
|
||||
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/health")
|
||||
def health():
|
||||
return Response(
|
||||
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
|
||||
status=200,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@app.route("/threads")
|
||||
def threads():
|
||||
num_threads = threading.active_count()
|
||||
threads = threading.enumerate()
|
||||
|
||||
thread_list = []
|
||||
for thread in threads:
|
||||
thread_name = thread.name
|
||||
thread_id = thread.ident
|
||||
is_alive = thread.is_alive()
|
||||
|
||||
thread_list.append(
|
||||
{
|
||||
"name": thread_name,
|
||||
"id": thread_id,
|
||||
"is_alive": is_alive,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"pid": os.getpid(),
|
||||
"thread_num": num_threads,
|
||||
"threads": thread_list,
|
||||
}
|
||||
|
||||
|
||||
@app.route("/db-pool-stat")
|
||||
def pool_stat():
|
||||
engine = db.engine
|
||||
return {
|
||||
"pid": os.getpid(),
|
||||
"pool_size": engine.pool.size(),
|
||||
"checked_in_connections": engine.pool.checkedin(),
|
||||
"checked_out_connections": engine.pool.checkedout(),
|
||||
"overflow_connections": engine.pool.overflow(),
|
||||
"connection_timeout": engine.pool.timeout(),
|
||||
"recycle_time": db.engine.pool._recycle,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=5001)
|
||||
|
||||
@ -1,54 +1,15 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
if not dify_config.DEBUG:
|
||||
from gevent import monkey
|
||||
|
||||
monkey.patch_all()
|
||||
|
||||
import grpc.experimental.gevent
|
||||
|
||||
grpc.experimental.gevent.init_gevent()
|
||||
|
||||
import json
|
||||
|
||||
from flask import Flask, Response, request
|
||||
from flask_cors import CORS
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import contexts
|
||||
from commands import register_commands
|
||||
from configs import dify_config
|
||||
from extensions import (
|
||||
ext_celery,
|
||||
ext_code_based_extension,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_hosting_provider,
|
||||
ext_logging,
|
||||
ext_login,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_proxy_fix,
|
||||
ext_redis,
|
||||
ext_sentry,
|
||||
ext_storage,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
from libs.passport import PassportService
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
class DifyApp(Flask):
|
||||
pass
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
# ----------------------------
|
||||
def create_flask_app_with_configs() -> Flask:
|
||||
def create_flask_app_with_configs() -> DifyApp:
|
||||
"""
|
||||
create a raw flask app
|
||||
with configs loaded from .env file
|
||||
@ -68,111 +29,72 @@ def create_flask_app_with_configs() -> Flask:
|
||||
return dify_app
|
||||
|
||||
|
||||
def create_app() -> Flask:
|
||||
def create_app() -> DifyApp:
|
||||
start_time = time.perf_counter()
|
||||
app = create_flask_app_with_configs()
|
||||
app.secret_key = dify_config.SECRET_KEY
|
||||
initialize_extensions(app)
|
||||
register_blueprints(app)
|
||||
register_commands(app)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logging.info(f"Finished create_app ({round((end_time - start_time) * 1000, 2)} ms)")
|
||||
return app
|
||||
|
||||
|
||||
def initialize_extensions(app):
|
||||
# Since the application instance is now created, pass it to each Flask
|
||||
# extension instance to bind it to the Flask application instance (app)
|
||||
ext_logging.init_app(app)
|
||||
ext_compress.init_app(app)
|
||||
ext_code_based_extension.init()
|
||||
ext_database.init_app(app)
|
||||
ext_migrate.init(app, db)
|
||||
ext_redis.init_app(app)
|
||||
ext_storage.init_app(app)
|
||||
ext_celery.init_app(app)
|
||||
ext_login.init_app(app)
|
||||
ext_mail.init_app(app)
|
||||
ext_hosting_provider.init_app(app)
|
||||
ext_sentry.init_app(app)
|
||||
ext_proxy_fix.init_app(app)
|
||||
|
||||
|
||||
# Flask-Login configuration
|
||||
@login_manager.request_loader
|
||||
def load_user_from_request(request_from_flask_login):
|
||||
"""Load user based on the request."""
|
||||
if request.blueprint not in {"console", "inner_api"}:
|
||||
return None
|
||||
# Check if the user_id contains a dot, indicating the old format
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header:
|
||||
auth_token = request.args.get("_token")
|
||||
if not auth_token:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
else:
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
decoded = PassportService().verify(auth_token)
|
||||
user_id = decoded.get("user_id")
|
||||
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if logged_in_account:
|
||||
contexts.tenant_id.set(logged_in_account.current_tenant_id)
|
||||
return logged_in_account
|
||||
|
||||
|
||||
@login_manager.unauthorized_handler
|
||||
def unauthorized_handler():
|
||||
"""Handle unauthorized requests."""
|
||||
return Response(
|
||||
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
|
||||
status=401,
|
||||
content_type="application/json",
|
||||
def initialize_extensions(app: DifyApp):
|
||||
from extensions import (
|
||||
ext_app_metrics,
|
||||
ext_blueprints,
|
||||
ext_celery,
|
||||
ext_code_based_extension,
|
||||
ext_commands,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_hosting_provider,
|
||||
ext_import_modules,
|
||||
ext_logging,
|
||||
ext_login,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_proxy_fix,
|
||||
ext_redis,
|
||||
ext_sentry,
|
||||
ext_set_secretkey,
|
||||
ext_storage,
|
||||
ext_timezone,
|
||||
ext_warnings,
|
||||
)
|
||||
|
||||
extensions = [
|
||||
ext_timezone,
|
||||
ext_logging,
|
||||
ext_warnings,
|
||||
ext_import_modules,
|
||||
ext_set_secretkey,
|
||||
ext_compress,
|
||||
ext_code_based_extension,
|
||||
ext_database,
|
||||
ext_app_metrics,
|
||||
ext_migrate,
|
||||
ext_redis,
|
||||
ext_storage,
|
||||
ext_celery,
|
||||
ext_login,
|
||||
ext_mail,
|
||||
ext_hosting_provider,
|
||||
ext_sentry,
|
||||
ext_proxy_fix,
|
||||
ext_blueprints,
|
||||
ext_commands,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True
|
||||
if not is_enabled:
|
||||
if dify_config.DEBUG:
|
||||
logging.info(f"Skipped {short_name}")
|
||||
continue
|
||||
|
||||
# register blueprint routers
|
||||
def register_blueprints(app):
|
||||
from controllers.console import bp as console_app_bp
|
||||
from controllers.files import bp as files_bp
|
||||
from controllers.inner_api import bp as inner_api_bp
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.web import bp as web_bp
|
||||
|
||||
CORS(
|
||||
service_api_bp,
|
||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
CORS(
|
||||
web_bp,
|
||||
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
expose_headers=["X-Version", "X-Env"],
|
||||
)
|
||||
|
||||
app.register_blueprint(web_bp)
|
||||
|
||||
CORS(
|
||||
console_app_bp,
|
||||
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
expose_headers=["X-Version", "X-Env"],
|
||||
)
|
||||
|
||||
app.register_blueprint(console_app_bp)
|
||||
|
||||
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
|
||||
app.register_blueprint(files_bp)
|
||||
|
||||
app.register_blueprint(inner_api_bp)
|
||||
start_time = time.perf_counter()
|
||||
ext.init_app(app)
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)")
|
||||
|
||||
@ -259,7 +259,7 @@ def migrate_knowledge_vector_database():
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
upper_colletion_vector_types = {
|
||||
upper_collection_vector_types = {
|
||||
VectorType.MILVUS,
|
||||
VectorType.PGVECTOR,
|
||||
VectorType.RELYT,
|
||||
@ -267,7 +267,7 @@ def migrate_knowledge_vector_database():
|
||||
VectorType.ORACLE,
|
||||
VectorType.ELASTICSEARCH,
|
||||
}
|
||||
lower_colletion_vector_types = {
|
||||
lower_collection_vector_types = {
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.CHROMA,
|
||||
VectorType.MYSCALE,
|
||||
@ -307,7 +307,7 @@ def migrate_knowledge_vector_database():
|
||||
continue
|
||||
collection_name = ""
|
||||
dataset_id = dataset.id
|
||||
if vector_type in upper_colletion_vector_types:
|
||||
if vector_type in upper_collection_vector_types:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
elif vector_type == VectorType.QDRANT:
|
||||
if dataset.collection_binding_id:
|
||||
@ -323,7 +323,7 @@ def migrate_knowledge_vector_database():
|
||||
else:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
|
||||
elif vector_type in lower_colletion_vector_types:
|
||||
elif vector_type in lower_collection_vector_types:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
else:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
@ -640,15 +640,3 @@ where sites.id is null limit 1000"""
|
||||
break
|
||||
|
||||
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
app.cli.add_command(reset_encrypt_key_pair)
|
||||
app.cli.add_command(vdb_migrate)
|
||||
app.cli.add_command(convert_to_agent_apps)
|
||||
app.cli.add_command(add_qdrant_doc_id_index)
|
||||
app.cli.add_command(create_tenant)
|
||||
app.cli.add_command(upgrade_db)
|
||||
app.cli.add_command(fix_app_site_missing)
|
||||
|
||||
@ -17,11 +17,6 @@ class DeploymentConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
TESTING: bool = Field(
|
||||
description="Enable testing mode for running automated tests",
|
||||
default=False,
|
||||
)
|
||||
|
||||
EDITION: str = Field(
|
||||
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
|
||||
default="SELF_HOSTED",
|
||||
|
||||
@ -585,6 +585,11 @@ class RagEtlConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
SCARF_NO_ANALYTICS: Optional[str] = Field(
|
||||
description="This is about whether to disable Scarf analytics in Unstructured library.",
|
||||
default="false",
|
||||
)
|
||||
|
||||
|
||||
class DataSetConfig(BaseSettings):
|
||||
"""
|
||||
@ -640,7 +645,7 @@ class IndexingConfig(BaseSettings):
|
||||
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field(
|
||||
description="Maximum token length for text segmentation during indexing",
|
||||
default=1000,
|
||||
default=4000,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.11.2",
|
||||
default="0.13.2",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@ -18,6 +18,7 @@ language_timezone_mapping = {
|
||||
"tr-TR": "Europe/Istanbul",
|
||||
"fa-IR": "Asia/Tehran",
|
||||
"sl-SI": "Europe/Ljubljana",
|
||||
"th-TH": "Asia/Bangkok",
|
||||
}
|
||||
|
||||
languages = list(language_timezone_mapping.keys())
|
||||
|
||||
@ -190,7 +190,7 @@ class AppCopyApi(Resource):
|
||||
)
|
||||
session.commit()
|
||||
|
||||
stmt = select(App).where(App.id == result.app.id)
|
||||
stmt = select(App).where(App.id == result.app_id)
|
||||
app = session.scalar(stmt)
|
||||
|
||||
return app, 201
|
||||
|
||||
@ -100,11 +100,11 @@ class DraftWorkflowApi(Resource):
|
||||
try:
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables = [
|
||||
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
conversation_variables_list = args.get("conversation_variables") or []
|
||||
conversation_variables = [
|
||||
variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
@ -382,7 +382,7 @@ class DefaultBlockConfigApi(Resource):
|
||||
filters = None
|
||||
if args.get("q"):
|
||||
try:
|
||||
filters = json.loads(args.get("q"))
|
||||
filters = json.loads(args.get("q", ""))
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid filters")
|
||||
|
||||
|
||||
@ -34,7 +34,6 @@ class OAuthDataSource(Resource):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
print(vars(oauth_provider))
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
if dify_config.NOTION_INTEGRATION_TYPE == "internal":
|
||||
|
||||
@ -52,7 +52,6 @@ class OAuthLogin(Resource):
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_PROVIDERS.get(provider)
|
||||
print(vars(oauth_provider))
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
|
||||
|
||||
@ -106,6 +106,7 @@ class GetProcessRuleApi(Resource):
|
||||
# get default rules
|
||||
mode = DocumentService.DEFAULT_RULES["mode"]
|
||||
rules = DocumentService.DEFAULT_RULES["rules"]
|
||||
limits = DocumentService.DEFAULT_RULES["limits"]
|
||||
if document_id:
|
||||
# get the latest process rule
|
||||
document = Document.query.get_or_404(document_id)
|
||||
@ -132,7 +133,7 @@ class GetProcessRuleApi(Resource):
|
||||
mode = dataset_process_rule.mode
|
||||
rules = dataset_process_rule.rules_dict
|
||||
|
||||
return {"mode": mode, "rules": rules}
|
||||
return {"mode": mode, "rules": rules, "limits": limits}
|
||||
|
||||
|
||||
class DatasetDocumentListApi(Resource):
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, inputs, marshal_with, reqparse
|
||||
from sqlalchemy import and_
|
||||
@ -20,8 +21,17 @@ class InstalledAppsListApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(installed_app_list_fields)
|
||||
def get(self):
|
||||
app_id = request.args.get("app_id", default=None, type=str)
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
|
||||
|
||||
if app_id:
|
||||
installed_apps = (
|
||||
db.session.query(InstalledApp)
|
||||
.filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
|
||||
|
||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
||||
installed_apps = [
|
||||
|
||||
@ -368,6 +368,7 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||
description=args["description"],
|
||||
parameters=args["parameters"],
|
||||
privacy_policy=args["privacy_policy"],
|
||||
labels=args["labels"],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -48,7 +48,8 @@ class AppInfoApi(Resource):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
"""Get app information"""
|
||||
return {"name": app_model.name, "description": app_model.description}
|
||||
tags = [tag.name for tag in app_model.tags]
|
||||
return {"name": app_model.name, "description": app_model.description, "tags": tags}
|
||||
|
||||
|
||||
api.add_resource(AppParameterApi, "/parameters")
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
@ -36,7 +39,7 @@ class ModelConfigManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for model config
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-Sqlalchemy's strategy for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high concurrency requests due to multiple long-running tasks.
|
||||
|
||||
Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid deattach errors.
|
||||
Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detach errors.
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
@ -2,8 +2,8 @@ import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -23,6 +23,7 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity,
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.account import Account
|
||||
@ -33,37 +34,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
_dialogue_count: int
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -127,12 +108,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
|
||||
else self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager,
|
||||
@ -146,12 +129,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from=invoke_from,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
def single_iteration_generate(
|
||||
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
|
||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -180,7 +163,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
query="",
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
|
||||
@ -195,7 +178,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=None,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
@ -207,7 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Optional[Conversation] = None,
|
||||
stream: bool = True,
|
||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -231,6 +214,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
|
||||
# get conversation dialogue count
|
||||
self._dialogue_count = get_thread_messages_length(conversation.id)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
@ -301,6 +287,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
dialogue_count=self._dialogue_count,
|
||||
)
|
||||
|
||||
runner.run()
|
||||
@ -354,6 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream,
|
||||
dialogue_count=self._dialogue_count,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -39,12 +39,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
dialogue_count: int,
|
||||
) -> None:
|
||||
super().__init__(queue_manager)
|
||||
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.conversation = conversation
|
||||
self.message = message
|
||||
self._dialogue_count = dialogue_count
|
||||
|
||||
def run(self) -> None:
|
||||
app_config = self.application_generate_entity.app_config
|
||||
@ -122,19 +124,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
session.commit()
|
||||
|
||||
# Increment dialogue count.
|
||||
self.conversation.dialogue_count += 1
|
||||
|
||||
conversation_dialogue_count = self.conversation.dialogue_count
|
||||
db.session.commit()
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.QUERY: query,
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
|
||||
SystemVariableKey.APP_ID: app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
|
||||
|
||||
@ -88,6 +88,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
dialogue_count: int,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize AdvancedChatAppGenerateTaskPipeline.
|
||||
@ -98,6 +99,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
:param message: message
|
||||
:param user: user
|
||||
:param stream: stream
|
||||
:param dialogue_count: dialogue count
|
||||
"""
|
||||
super().__init__(application_generate_entity, queue_manager, user, stream)
|
||||
|
||||
@ -114,7 +116,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: conversation.dialogue_count,
|
||||
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
@ -125,6 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
self._conversation_name_generate_thread = None
|
||||
self._recorded_files: list[Mapping[str, Any]] = []
|
||||
self.total_tokens: int = 0
|
||||
|
||||
def process(self):
|
||||
"""
|
||||
@ -358,6 +361,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
# FIXME for issue #11221 quick fix maybe have a better solution
|
||||
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
@ -371,7 +376,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
conversation_id=self._conversation.id,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
@ -85,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict) -> dict:
|
||||
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict:
|
||||
"""
|
||||
Validate for agent chat app model config
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Union, overload
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -28,34 +28,15 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[dict, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[dict, Generator[dict, None, None]]:
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -65,7 +46,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
if not stream:
|
||||
if not streaming:
|
||||
raise ValueError("Agent Chat App does not support blocking mode")
|
||||
|
||||
if not args.get("query"):
|
||||
@ -96,7 +77,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# validate config
|
||||
override_model_config_dict = AgentChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=args["model_config"],
|
||||
)
|
||||
|
||||
# always enable retriever resource in debugger mode
|
||||
@ -134,12 +116,14 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
|
||||
else self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
call_depth=0,
|
||||
@ -180,7 +164,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -14,8 +14,10 @@ class AppGenerateResponseConverter(ABC):
|
||||
|
||||
@classmethod
|
||||
def convert(
|
||||
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
cls,
|
||||
response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]],
|
||||
invoke_from: InvokeFrom,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
@ -80,7 +82,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
for resource in metadata["retriever_resources"]:
|
||||
updated_resources.append(
|
||||
{
|
||||
"segment_id": resource["segment_id"],
|
||||
"segment_id": resource.get("segment_id", ""),
|
||||
"position": resource["position"],
|
||||
"document_name": resource["document_name"],
|
||||
"score": resource["score"],
|
||||
|
||||
@ -55,7 +55,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
streaming: bool = True,
|
||||
) -> Union[dict, Generator[str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
@ -132,7 +132,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
|
||||
else self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
@ -140,7 +142,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
# init generate records
|
||||
@ -177,7 +179,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
@ -50,7 +50,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
|
||||
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True
|
||||
) -> Union[dict, Generator[str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
@ -114,12 +114,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager,
|
||||
@ -158,7 +158,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -30,43 +30,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
):
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
# parse files
|
||||
@ -101,7 +76,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
),
|
||||
files=system_files,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
trace_manager=trace_manager,
|
||||
@ -115,7 +90,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
stream=stream,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
@ -127,20 +102,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
streaming: bool = True,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> dict[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param application_generate_entity: application generate entity
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
@ -169,14 +133,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def single_iteration_generate(
|
||||
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
|
||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -203,7 +173,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||
@ -218,7 +188,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
stream=stream,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
|
||||
@ -106,6 +106,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
self.total_tokens: int = 0
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
@ -319,6 +320,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
# FIXME for issue #11221 quick fix maybe have a better solution
|
||||
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
@ -332,7 +335,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
conversation_id=None,
|
||||
|
||||
@ -43,7 +43,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
@ -138,7 +138,8 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_type_classes_mapping[node_type]
|
||||
node_version = iteration_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
|
||||
@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
@ -113,18 +113,6 @@ class QueueIterationNextEvent(AppQueueEvent):
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
duration: Optional[float] = None
|
||||
|
||||
@field_validator("output", mode="before")
|
||||
@classmethod
|
||||
def set_output(cls, v):
|
||||
"""
|
||||
Set output
|
||||
"""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, int | float | str | bool | dict | list):
|
||||
return v
|
||||
raise ValueError("output must be a valid type")
|
||||
|
||||
|
||||
class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
"""
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping
|
||||
from datetime import timedelta
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -88,20 +88,17 @@ class RateLimit:
|
||||
def gen_request_key() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def generate(self, generator: Union[Generator, callable, dict], request_id: str):
|
||||
if isinstance(generator, dict):
|
||||
def generate(self, generator: Union[Generator[str, None, None], Mapping[str, Any]], request_id: str):
|
||||
if isinstance(generator, Mapping):
|
||||
return generator
|
||||
else:
|
||||
return RateLimitGenerator(self, generator, request_id)
|
||||
return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
|
||||
|
||||
|
||||
class RateLimitGenerator:
|
||||
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
|
||||
def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
|
||||
self.rate_limit = rate_limit
|
||||
if callable(generator):
|
||||
self.generator = generator()
|
||||
else:
|
||||
self.generator = generator
|
||||
self.generator = generator
|
||||
self.request_id = request_id
|
||||
self.closed = False
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -80,38 +81,38 @@ class WorkflowCycleManage:
|
||||
|
||||
inputs[f"sys.{key.value}"] = value
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(inputs)
|
||||
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
else WorkflowRunTriggeredFrom.APP_RUN
|
||||
)
|
||||
|
||||
# init workflow run
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
|
||||
if workflow_run_id:
|
||||
workflow_run.id = workflow_run_id
|
||||
workflow_run.tenant_id = self._workflow.tenant_id
|
||||
workflow_run.app_id = self._workflow.app_id
|
||||
workflow_run.sequence_number = new_sequence_number
|
||||
workflow_run.workflow_id = self._workflow.id
|
||||
workflow_run.type = self._workflow.type
|
||||
workflow_run.triggered_from = triggered_from.value
|
||||
workflow_run.version = self._workflow.version
|
||||
workflow_run.graph = self._workflow.graph
|
||||
workflow_run.inputs = json.dumps(inputs)
|
||||
workflow_run.status = WorkflowRunStatus.RUNNING.value
|
||||
workflow_run.created_by_role = (
|
||||
CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
|
||||
)
|
||||
workflow_run.created_by = self._user.id
|
||||
# handle special values
|
||||
inputs = WorkflowEntry.handle_special_values(inputs)
|
||||
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_run)
|
||||
db.session.close()
|
||||
# init workflow run
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = WorkflowRun()
|
||||
system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
|
||||
workflow_run.id = system_id or str(uuid4())
|
||||
workflow_run.tenant_id = self._workflow.tenant_id
|
||||
workflow_run.app_id = self._workflow.app_id
|
||||
workflow_run.sequence_number = new_sequence_number
|
||||
workflow_run.workflow_id = self._workflow.id
|
||||
workflow_run.type = self._workflow.type
|
||||
workflow_run.triggered_from = triggered_from.value
|
||||
workflow_run.version = self._workflow.version
|
||||
workflow_run.graph = self._workflow.graph
|
||||
workflow_run.inputs = json.dumps(inputs)
|
||||
workflow_run.status = WorkflowRunStatus.RUNNING
|
||||
workflow_run.created_by_role = (
|
||||
CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER
|
||||
)
|
||||
workflow_run.created_by = self._user.id
|
||||
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_run)
|
||||
session.commit()
|
||||
|
||||
return workflow_run
|
||||
|
||||
@ -339,7 +340,7 @@ class WorkflowCycleManage:
|
||||
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
|
||||
WorkflowNodeExecution.error: event.error,
|
||||
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
|
||||
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
|
||||
WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
|
||||
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
|
||||
WorkflowNodeExecution.finished_at: finished_at,
|
||||
WorkflowNodeExecution.elapsed_time: elapsed_time,
|
||||
|
||||
@ -7,13 +7,13 @@ from .models import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FILE_MODEL_IDENTITY",
|
||||
"ArrayFileAttribute",
|
||||
"File",
|
||||
"FileAttribute",
|
||||
"FileBelongsTo",
|
||||
"FileTransferMethod",
|
||||
"FileType",
|
||||
"FileUploadConfig",
|
||||
"FileTransferMethod",
|
||||
"FileBelongsTo",
|
||||
"File",
|
||||
"ImageConfig",
|
||||
"FileAttribute",
|
||||
"ArrayFileAttribute",
|
||||
"FILE_MODEL_IDENTITY",
|
||||
]
|
||||
|
||||
@ -53,8 +53,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
response = client.request(method=method, url=url, **kwargs)
|
||||
|
||||
if response.status_code not in STATUS_FORCELIST:
|
||||
if stream:
|
||||
return response.iter_bytes()
|
||||
return response
|
||||
else:
|
||||
logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")
|
||||
|
||||
@ -15,6 +15,5 @@ class SuggestedQuestionsAfterAnswerOutputParser:
|
||||
json_obj = json.loads(action_match.group(0).strip())
|
||||
else:
|
||||
json_obj = []
|
||||
print(f"Could not parse LLM output: {text}")
|
||||
|
||||
return json_obj
|
||||
|
||||
@ -91,7 +91,7 @@ class XinferenceProvider(Provider):
|
||||
"""
|
||||
```
|
||||
|
||||
也可以直接抛出对应Erros,并做如下定义,这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
|
||||
也可以直接抛出对应 Errors,并做如下定义,这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
|
||||
|
||||
```python
|
||||
@property
|
||||
|
||||
@ -18,25 +18,25 @@ from .message_entities import (
|
||||
from .model_entities import ModelPropertyKey
|
||||
|
||||
__all__ = [
|
||||
"ImagePromptMessageContent",
|
||||
"VideoPromptMessageContent",
|
||||
"PromptMessage",
|
||||
"PromptMessageRole",
|
||||
"LLMUsage",
|
||||
"ModelPropertyKey",
|
||||
"AssistantPromptMessage",
|
||||
"PromptMessage",
|
||||
"PromptMessageContent",
|
||||
"PromptMessageRole",
|
||||
"SystemPromptMessage",
|
||||
"TextPromptMessageContent",
|
||||
"UserPromptMessage",
|
||||
"PromptMessageTool",
|
||||
"ToolPromptMessage",
|
||||
"PromptMessageContentType",
|
||||
"AudioPromptMessageContent",
|
||||
"DocumentPromptMessageContent",
|
||||
"ImagePromptMessageContent",
|
||||
"LLMResult",
|
||||
"LLMResultChunk",
|
||||
"LLMResultChunkDelta",
|
||||
"AudioPromptMessageContent",
|
||||
"DocumentPromptMessageContent",
|
||||
"LLMUsage",
|
||||
"ModelPropertyKey",
|
||||
"PromptMessage",
|
||||
"PromptMessage",
|
||||
"PromptMessageContent",
|
||||
"PromptMessageContentType",
|
||||
"PromptMessageRole",
|
||||
"PromptMessageRole",
|
||||
"PromptMessageTool",
|
||||
"SystemPromptMessage",
|
||||
"TextPromptMessageContent",
|
||||
"ToolPromptMessage",
|
||||
"UserPromptMessage",
|
||||
"VideoPromptMessageContent",
|
||||
]
|
||||
|
||||
@ -453,7 +453,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return credentials_kwargs
|
||||
|
||||
def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
|
||||
def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> tuple[str, list[dict]]:
|
||||
"""
|
||||
Convert prompt messages to dict list and system
|
||||
"""
|
||||
@ -461,7 +461,15 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
first_loop = True
|
||||
for message in prompt_messages:
|
||||
if isinstance(message, SystemPromptMessage):
|
||||
message.content = message.content.strip()
|
||||
if isinstance(message.content, str):
|
||||
message.content = message.content.strip()
|
||||
elif isinstance(message.content, list):
|
||||
# System prompt only support text
|
||||
message.content = "".join(
|
||||
c.data.strip() for c in message.content if isinstance(c, TextPromptMessageContent)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown system prompt message content type {type(message.content)}")
|
||||
if first_loop:
|
||||
system = message.content
|
||||
first_loop = False
|
||||
@ -475,6 +483,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
# handle empty user prompt see #10013 #10520
|
||||
# responses, ignore user prompts containing only whitespace, the Claude API can't handle it.
|
||||
if not message.content.strip():
|
||||
continue
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
prompt_message_dicts.append(message_dict)
|
||||
else:
|
||||
|
||||
@ -779,7 +779,7 @@ LLM_BASE_MODELS = [
|
||||
name="frequency_penalty",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=16384),
|
||||
ParameterRule(
|
||||
name="seed",
|
||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||
|
||||
@ -598,6 +598,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
# message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
# fix azure when enable json schema cant process content = "" in assistant fix with None
|
||||
if not message.content:
|
||||
message_dict["content"] = None
|
||||
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
|
||||
@ -14,7 +14,7 @@ from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_M
|
||||
|
||||
class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
"""
|
||||
Model class for OpenAI Speech to text model.
|
||||
Model class for OpenAI text2speech model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
|
||||
@ -16,6 +16,7 @@ help:
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
|
||||
@ -0,0 +1,53 @@
|
||||
model: amazon.nova-lite-v1:0
|
||||
label:
|
||||
en_US: Nova Lite V1
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 300000
|
||||
parameter_rules:
|
||||
- name: max_new_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 5000
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.00006'
|
||||
output: '0.00024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,52 @@
|
||||
model: amazon.nova-micro-v1:0
|
||||
label:
|
||||
en_US: Nova Micro V1
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: max_new_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 5000
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.000035'
|
||||
output: '0.00014'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,53 @@
|
||||
model: amazon.nova-pro-v1:0
|
||||
label:
|
||||
en_US: Nova Pro V1
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 300000
|
||||
parameter_rules:
|
||||
- name: max_new_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 5000
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.0008'
|
||||
output: '0.0032'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -70,6 +70,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
{"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True},
|
||||
{"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False},
|
||||
{"prefix": "ai21.jamba-1-5", "support_system_prompts": True, "support_tool_use": False},
|
||||
{"prefix": "amazon.nova", "support_system_prompts": True, "support_tool_use": False},
|
||||
{"prefix": "us.amazon.nova", "support_system_prompts": True, "support_tool_use": False},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@ -194,6 +196,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
if model_info["support_tool_use"] and tools:
|
||||
parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools)
|
||||
try:
|
||||
# for issue #10976
|
||||
conversations_list = parameters["messages"]
|
||||
# if two consecutive user messages found, combine them into one message
|
||||
for i in range(len(conversations_list) - 2, -1, -1):
|
||||
if conversations_list[i]["role"] == conversations_list[i + 1]["role"]:
|
||||
conversations_list[i]["content"].extend(conversations_list.pop(i + 1)["content"])
|
||||
|
||||
if stream:
|
||||
response = bedrock_client.converse_stream(**parameters)
|
||||
return self._handle_converse_stream_response(
|
||||
|
||||
@ -0,0 +1,53 @@
|
||||
model: us.amazon.nova-lite-v1:0
|
||||
label:
|
||||
en_US: Nova Lite V1 (US.Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 300000
|
||||
parameter_rules:
|
||||
- name: max_new_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 5000
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.00006'
|
||||
output: '0.00024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,52 @@
|
||||
model: us.amazon.nova-micro-v1:0
|
||||
label:
|
||||
en_US: Nova Micro V1 (US.Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: max_new_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 5000
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.000035'
|
||||
output: '0.00014'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,53 @@
|
||||
model: us.amazon.nova-pro-v1:0
|
||||
label:
|
||||
en_US: Nova Pro V1 (US.Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 300000
|
||||
parameter_rules:
|
||||
- name: max_new_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 5000
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.0008'
|
||||
output: '0.0032'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,2 @@
|
||||
- amazon.rerank-v1
|
||||
- cohere.rerank-v3-5
|
||||
@ -0,0 +1,4 @@
|
||||
model: amazon.rerank-v1:0
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 5120
|
||||
@ -0,0 +1,4 @@
|
||||
model: cohere.rerank-v3-5:0
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 5120
|
||||
147
api/core/model_runtime/model_providers/bedrock/rerank/rerank.py
Normal file
147
api/core/model_runtime/model_providers/bedrock/rerank/rerank.py
Normal file
@ -0,0 +1,147 @@
|
||||
from typing import Optional
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
|
||||
|
||||
class BedrockRerankModel(RerankModel):
|
||||
"""
|
||||
Model class for Cohere rerank model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
|
||||
if len(docs) == 0:
|
||||
return RerankResult(model=model, docs=docs)
|
||||
|
||||
# initialize client
|
||||
client_config = Config(region_name=credentials["aws_region"])
|
||||
bedrock_runtime = boto3.client(
|
||||
service_name="bedrock-agent-runtime",
|
||||
config=client_config,
|
||||
aws_access_key_id=credentials.get("aws_access_key_id", ""),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key"),
|
||||
)
|
||||
queries = [{"type": "TEXT", "textQuery": {"text": query}}]
|
||||
text_sources = []
|
||||
for text in docs:
|
||||
text_sources.append(
|
||||
{
|
||||
"type": "INLINE",
|
||||
"inlineDocumentSource": {
|
||||
"type": "TEXT",
|
||||
"textDocument": {
|
||||
"text": text,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
modelId = model
|
||||
region = credentials["aws_region"]
|
||||
model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
|
||||
rerankingConfiguration = {
|
||||
"type": "BEDROCK_RERANKING_MODEL",
|
||||
"bedrockRerankingConfiguration": {
|
||||
"numberOfResults": top_n,
|
||||
"modelConfiguration": {
|
||||
"modelArn": model_package_arn,
|
||||
},
|
||||
},
|
||||
}
|
||||
response = bedrock_runtime.rerank(
|
||||
queries=queries, sources=text_sources, rerankingConfiguration=rerankingConfiguration
|
||||
)
|
||||
|
||||
rerank_documents = []
|
||||
for idx, result in enumerate(response["results"]):
|
||||
# format document
|
||||
index = result["index"]
|
||||
rerank_document = RerankDocument(
|
||||
index=index,
|
||||
text=docs[index],
|
||||
score=result["relevanceScore"],
|
||||
)
|
||||
|
||||
# score threshold check
|
||||
if score_threshold is not None:
|
||||
if rerank_document.score >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
else:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self.invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [],
|
||||
InvokeBadRequestError: [],
|
||||
}
|
||||
@ -2,3 +2,4 @@
|
||||
- rerank-english-v3.0
|
||||
- rerank-multilingual-v2.0
|
||||
- rerank-multilingual-v3.0
|
||||
- rerank-v3.5
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
model: rerank-v3.5
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 5120
|
||||
@ -32,12 +32,12 @@ class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials, model, None)
|
||||
self._add_custom_parameters(credentials, None)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
def _add_custom_parameters(self, credentials: dict, model: str, model_parameters: dict) -> None:
|
||||
def _add_custom_parameters(self, credentials: dict, model: Optional[str]) -> None:
|
||||
if model is None:
|
||||
model = "bge-large-zh-v1.5"
|
||||
model = "Qwen2-72B-Instruct"
|
||||
|
||||
model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model)
|
||||
credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/"
|
||||
@ -47,5 +47,7 @@ class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
credentials["mode"] = LLMMode.CHAT.value
|
||||
|
||||
schema = self.get_model_schema(model, credentials)
|
||||
assert schema is not None, f"Model schema not found for model {model}"
|
||||
assert schema.features is not None, f"Model features not found for model {model}"
|
||||
if ModelFeature.TOOL_CALL in schema.features or ModelFeature.MULTI_TOOL_CALL in schema.features:
|
||||
credentials["function_calling_type"] = "tool_call"
|
||||
|
||||
@ -122,7 +122,7 @@ class GiteeAIRerankModel(RerankModel):
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.RERANK,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))},
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
@ -10,7 +10,7 @@ from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI
|
||||
|
||||
class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
|
||||
"""
|
||||
Model class for OpenAI Speech to text model.
|
||||
Model class for OpenAI text2speech model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
|
||||
@ -0,0 +1,38 @@
|
||||
model: gemini-exp-1206
|
||||
label:
|
||||
en_US: Gemini exp 1206
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -254,8 +254,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
assistant_prompt_message = AssistantPromptMessage(content=response.text)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
if response.usage_metadata:
|
||||
prompt_tokens = response.usage_metadata.prompt_token_count
|
||||
completion_tokens = response.usage_metadata.candidates_token_count
|
||||
else:
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
@ -140,7 +140,7 @@ class GPUStackRerankModel(RerankModel):
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.RERANK,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))},
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
@ -128,7 +128,7 @@ class JinaRerankModel(RerankModel):
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.RERANK,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000))},
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
@ -193,7 +193,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000))},
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
@ -252,7 +252,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
||||
decoded_chunk = chunk.strip().removeprefix("data: ")
|
||||
chunk_json = None
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
|
||||
@ -181,9 +181,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
# prepare the payload for a simple ping to the model
|
||||
data = {"model": model, "stream": stream}
|
||||
|
||||
if "format" in model_parameters:
|
||||
data["format"] = model_parameters["format"]
|
||||
del model_parameters["format"]
|
||||
if format_schema := model_parameters.pop("format", None):
|
||||
try:
|
||||
data["format"] = format_schema if format_schema == "json" else json.loads(format_schema)
|
||||
except json.JSONDecodeError as e:
|
||||
raise InvokeBadRequestError(f"Invalid format schema: {str(e)}")
|
||||
|
||||
if "keep_alive" in model_parameters:
|
||||
data["keep_alive"] = model_parameters["keep_alive"]
|
||||
@ -733,12 +735,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
ParameterRule(
|
||||
name="format",
|
||||
label=I18nObject(en_US="Format", zh_Hans="返回格式"),
|
||||
type=ParameterType.STRING,
|
||||
type=ParameterType.TEXT,
|
||||
default="json",
|
||||
help=I18nObject(
|
||||
en_US="the format to return a response in. Currently the only accepted value is json.",
|
||||
zh_Hans="返回响应的格式。目前唯一接受的值是json。",
|
||||
en_US="the format to return a response in. Format can be `json` or a JSON schema.",
|
||||
zh_Hans="返回响应的格式。目前接受的值是字符串`json`或JSON schema.",
|
||||
),
|
||||
options=["json"],
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
|
||||
@ -139,7 +139,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
},
|
||||
parameter_rules=[],
|
||||
|
||||
@ -943,6 +943,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
if isinstance(message.content, list):
|
||||
text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
|
||||
message.content = "".join(c.data for c in text_contents)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
|
||||
@ -11,7 +11,7 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI
|
||||
|
||||
class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
"""
|
||||
Model class for OpenAI Speech to text model.
|
||||
Model class for OpenAI text2speech model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
|
||||
@ -462,7 +462,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
||||
decoded_chunk = chunk.strip().removeprefix("data: ")
|
||||
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
|
||||
continue
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ supported_model_types:
|
||||
- text-embedding
|
||||
- speech2text
|
||||
- rerank
|
||||
- tts
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
@ -67,7 +68,7 @@ model_credential_schema:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
type: text-input
|
||||
default: '4096'
|
||||
default: "4096"
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
@ -80,7 +81,7 @@ model_credential_schema:
|
||||
- variable: __model_type
|
||||
value: text-embedding
|
||||
type: text-input
|
||||
default: '4096'
|
||||
default: "4096"
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
@ -93,7 +94,7 @@ model_credential_schema:
|
||||
- variable: __model_type
|
||||
value: rerank
|
||||
type: text-input
|
||||
default: '4096'
|
||||
default: "4096"
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
@ -104,7 +105,7 @@ model_credential_schema:
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
default: '4096'
|
||||
default: "4096"
|
||||
type: text-input
|
||||
- variable: function_calling_type
|
||||
show_on:
|
||||
@ -174,3 +175,19 @@ model_credential_schema:
|
||||
value: llm
|
||||
default: '\n\n'
|
||||
type: text-input
|
||||
- variable: voices
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: tts
|
||||
label:
|
||||
en_US: Available Voices (comma-separated)
|
||||
zh_Hans: 可用声音(用英文逗号分隔)
|
||||
type: text-input
|
||||
required: false
|
||||
default: "alloy"
|
||||
placeholder:
|
||||
en_US: "alloy,echo,fable,onyx,nova,shimmer"
|
||||
zh_Hans: "alloy,echo,fable,onyx,nova,shimmer"
|
||||
help:
|
||||
en_US: "List voice names separated by commas. First voice will be used as default."
|
||||
zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。"
|
||||
|
||||
@ -139,13 +139,17 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
endpoint_url = credentials.get("endpoint_url")
|
||||
endpoint_url = credentials.get("endpoint_url", "")
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
|
||||
endpoint_url = urljoin(endpoint_url, "embeddings")
|
||||
|
||||
payload = {"input": "ping", "model": model}
|
||||
# For nvidia models, the "input_type":"query" need in the payload
|
||||
# more to check issue #11193 or NvidiaTextEmbeddingModel
|
||||
if model.startswith("nvidia/"):
|
||||
payload["input_type"] = "query"
|
||||
|
||||
response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
||||
|
||||
@ -176,7 +180,7 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
},
|
||||
parameter_rules=[],
|
||||
|
||||
@ -0,0 +1,145 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
|
||||
|
||||
class OAICompatText2SpeechModel(_CommonOaiApiCompat, TTSModel):
|
||||
"""
|
||||
Model class for OpenAI-compatible text2speech model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
tenant_id: str,
|
||||
credentials: dict,
|
||||
content_text: str,
|
||||
voice: str,
|
||||
user: Optional[str] = None,
|
||||
) -> Iterable[bytes]:
|
||||
"""
|
||||
Invoke TTS model
|
||||
|
||||
:param model: model name
|
||||
:param tenant_id: user tenant id
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: model voice/speaker
|
||||
:param user: unique user id
|
||||
:return: audio data as bytes iterator
|
||||
"""
|
||||
# Set up headers with authentication if provided
|
||||
headers = {}
|
||||
if api_key := credentials.get("api_key"):
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
# Construct endpoint URL
|
||||
endpoint_url = credentials.get("endpoint_url")
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
endpoint_url = urljoin(endpoint_url, "audio/speech")
|
||||
|
||||
# Get audio format from model properties
|
||||
audio_format = self._get_model_audio_type(model, credentials)
|
||||
|
||||
# Split text into chunks if needed based on word limit
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
sentences = self._split_text_into_sentences(content_text, word_limit)
|
||||
|
||||
for sentence in sentences:
|
||||
# Prepare request payload
|
||||
payload = {"model": model, "input": sentence, "voice": voice, "response_format": audio_format}
|
||||
|
||||
# Make POST request
|
||||
response = requests.post(endpoint_url, headers=headers, json=payload, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise InvokeBadRequestError(response.text)
|
||||
|
||||
# Stream the audio data
|
||||
for chunk in response.iter_content(chunk_size=4096):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# Get default voice for validation
|
||||
voice = self._get_model_default_voice(model, credentials)
|
||||
|
||||
# Test with a simple text
|
||||
next(
|
||||
self._invoke(
|
||||
model=model, tenant_id="validate", credentials=credentials, content_text="Test.", voice=voice
|
||||
)
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema
|
||||
"""
|
||||
# Parse voices from comma-separated string
|
||||
voice_names = credentials.get("voices", "alloy").strip().split(",")
|
||||
voices = []
|
||||
|
||||
for voice in voice_names:
|
||||
voice = voice.strip()
|
||||
if not voice:
|
||||
continue
|
||||
|
||||
# Use en-US for all voices
|
||||
voices.append(
|
||||
{
|
||||
"name": voice,
|
||||
"mode": voice,
|
||||
"language": "en-US",
|
||||
}
|
||||
)
|
||||
|
||||
# If no voices provided or all voices were empty strings, use 'alloy' as default
|
||||
if not voices:
|
||||
voices = [{"name": "Alloy", "mode": "alloy", "language": "en-US"}]
|
||||
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.TTS,
|
||||
model_properties={
|
||||
ModelPropertyKey.AUDIO_TYPE: credentials.get("audio_type", "mp3"),
|
||||
ModelPropertyKey.WORD_LIMIT: int(credentials.get("word_limit", 4096)),
|
||||
ModelPropertyKey.DEFAULT_VOICE: voices[0]["mode"],
|
||||
ModelPropertyKey.VOICES: voices,
|
||||
},
|
||||
)
|
||||
|
||||
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
|
||||
"""
|
||||
Override base get_tts_model_voices to handle customizable voices
|
||||
"""
|
||||
model_schema = self.get_customizable_model_schema(model, credentials)
|
||||
|
||||
if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
|
||||
raise ValueError("this model does not support voice")
|
||||
|
||||
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
|
||||
|
||||
# Always return all voices regardless of language
|
||||
return [{"name": d["name"], "value": d["mode"]} for d in voices]
|
||||
@ -182,7 +182,7 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
},
|
||||
parameter_rules=[],
|
||||
|
||||
@ -250,7 +250,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
||||
decoded_chunk = chunk.strip().removeprefix("data: ")
|
||||
chunk_json = None
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
|
||||
@ -104,13 +104,14 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
# use Anthropic official SDK references
|
||||
# - https://github.com/anthropics/anthropic-sdk-python
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
service_account_key = credentials.get("vertex_service_account_key", "")
|
||||
project_id = credentials["vertex_project_id"]
|
||||
SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
|
||||
token = ""
|
||||
|
||||
# get access token from service account credential
|
||||
if service_account_info:
|
||||
if service_account_key:
|
||||
service_account_info = json.loads(base64.b64decode(service_account_key))
|
||||
credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
|
||||
request = google.auth.transport.requests.Request()
|
||||
credentials.refresh(request)
|
||||
@ -478,10 +479,11 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
||||
if stop:
|
||||
config_kwargs["stop_sequences"] = stop
|
||||
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
service_account_key = credentials.get("vertex_service_account_key", "")
|
||||
project_id = credentials["vertex_project_id"]
|
||||
location = credentials["vertex_location"]
|
||||
if service_account_info:
|
||||
if service_account_key:
|
||||
service_account_info = json.loads(base64.b64decode(service_account_key))
|
||||
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
|
||||
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
|
||||
else:
|
||||
|
||||
@ -48,10 +48,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
service_account_key = credentials.get("vertex_service_account_key", "")
|
||||
project_id = credentials["vertex_project_id"]
|
||||
location = credentials["vertex_location"]
|
||||
if service_account_info:
|
||||
if service_account_key:
|
||||
service_account_info = json.loads(base64.b64decode(service_account_key))
|
||||
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
|
||||
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
|
||||
else:
|
||||
@ -100,10 +101,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
service_account_key = credentials.get("vertex_service_account_key", "")
|
||||
project_id = credentials["vertex_project_id"]
|
||||
location = credentials["vertex_location"]
|
||||
if service_account_info:
|
||||
if service_account_key:
|
||||
service_account_info = json.loads(base64.b64decode(service_account_key))
|
||||
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
|
||||
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
|
||||
else:
|
||||
@ -173,7 +175,7 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
},
|
||||
parameter_rules=[],
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .common import ChatRole
|
||||
from .maas import MaasError, MaasService
|
||||
|
||||
__all__ = ["MaasService", "ChatRole", "MaasError"]
|
||||
__all__ = ["ChatRole", "MaasError", "MaasService"]
|
||||
|
||||
@ -166,7 +166,7 @@ class VoyageTextEmbeddingModel(TextEmbeddingModel):
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))},
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
@ -17,7 +17,13 @@ class WenxinRerank(_CommonWenxin):
|
||||
def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None):
|
||||
access_token = self._get_access_token()
|
||||
url = f"{self.api_bases[model]}?access_token={access_token}"
|
||||
|
||||
# For issue #11252
|
||||
# for wenxin Rerank model top_n length should be equal or less than docs length
|
||||
if top_n is not None and top_n > len(docs):
|
||||
top_n = len(docs)
|
||||
# for wenxin Rerank model, query should not be an empty string
|
||||
if query == "":
|
||||
query = " " # FIXME: this is a workaround for wenxin rerank model for better user experience.
|
||||
try:
|
||||
response = httpx.post(
|
||||
url,
|
||||
@ -25,7 +31,11 @@ class WenxinRerank(_CommonWenxin):
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
data = response.json()
|
||||
# wenxin error handling
|
||||
if "error_code" in data:
|
||||
raise InternalServerError(data["error_msg"])
|
||||
return data
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise InternalServerError(str(e))
|
||||
|
||||
@ -69,6 +79,9 @@ class WenxinRerankModel(RerankModel):
|
||||
results = wenxin_rerank.rerank(model, query, docs, top_n)
|
||||
|
||||
rerank_documents = []
|
||||
if "results" not in results:
|
||||
raise ValueError("results key not found in response")
|
||||
|
||||
for result in results["results"]:
|
||||
index = result["index"]
|
||||
if "document" in result:
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
model: grok-beta
|
||||
label:
|
||||
en_US: Grok beta
|
||||
en_US: Grok Beta
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
|
||||
@ -0,0 +1,64 @@
|
||||
model: grok-vision-beta
|
||||
label:
|
||||
en_US: Grok Vision Beta
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
|
||||
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
@ -35,3 +35,5 @@ class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
|
||||
credentials["mode"] = LLMMode.CHAT.value
|
||||
credentials["function_calling_type"] = "tool_call"
|
||||
credentials["stream_function_calling"] = "support"
|
||||
credentials["vision_support"] = "support"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user