Mirror of https://github.com/langgenius/dify.git
Synced 2026-01-24 05:46:13 +08:00

Compare commits
40 commits
| SHA1 |
|---|
| 80a322aaa2 |
| 82f7875a52 |
| 4637ddaa7f |
| 8d2269f762 |
| 5f03e66489 |
| a9c1f1a041 |
| 49cee773c5 |
| c78828ab7c |
| e90d3c29ab |
| 153807f243 |
| 5db0b56c5b |
| 404db1ae5b |
| 02c4b1af71 |
| aa11659062 |
| d4985fb3aa |
| 8815511ccb |
| 40fb4d16ef |
| c69f5b07ba |
| 56c90e212a |
| 0f14873255 |
| 0bb7569d46 |
| ec57922bb6 |
| 781d294f49 |
| f515af2232 |
| fe8191b899 |
| 4d2cd6703b |
| 292220c596 |
| 53f37a6704 |
| 75c1a82556 |
| c5b3777d93 |
| 678bbf8fe8 |
| 342607f4a4 |
| 5f4cdd66fa |
| 91942e37ff |
| 60913970dc |
| 82c42b9ec5 |
| 2a3d8c25bc |
| cee0c51dbb |
| fdbbdb706f |
| f6dfe23cf8 |
@@ -411,7 +411,8 @@ def migrate_knowledge_vector_database():
            try:
                click.echo(
                    click.style(
-                        f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
+                        f"Start to created vector index with {len(documents)} documents of {segments_count}"
+                        f" segments for dataset {dataset.id}.",
                        fg="green",
                    )
                )
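The hunk above only rewraps an over-long log line: adjacent string literals (including f-strings) inside parentheses are concatenated at compile time, so the emitted message is byte-for-byte identical. A minimal sketch with illustrative values (the variable names here are placeholders, not from the diff):

```python
documents_count = 12
segments_count = 340
dataset_id = "abc-123"

# Adjacent f-string literals inside parentheses are joined at compile
# time, so this is a single string despite the line break.
message = (
    f"Start to created vector index with {documents_count} documents of {segments_count}"
    f" segments for dataset {dataset_id}."
)
assert "340 segments" in message
print(message)
```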
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

    CURRENT_VERSION: str = Field(
        description="Dify version",
-        default="0.8.0",
+        default="0.8.2",
    )

    COMMIT_SHA: str = Field(
@@ -60,23 +60,15 @@ class InsertExploreAppListApi(Resource):

        site = app.site
        if not site:
-            desc = args["desc"] if args["desc"] else ""
-            copy_right = args["copyright"] if args["copyright"] else ""
-            privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
-            custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
+            desc = args["desc"] or ""
+            copy_right = args["copyright"] or ""
+            privacy_policy = args["privacy_policy"] or ""
+            custom_disclaimer = args["custom_disclaimer"] or ""
        else:
-            desc = site.description if site.description else args["desc"] if args["desc"] else ""
-            copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
-            privacy_policy = (
-                site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
-            )
-            custom_disclaimer = (
-                site.custom_disclaimer
-                if site.custom_disclaimer
-                else args["custom_disclaimer"]
-                if args["custom_disclaimer"]
-                else ""
-            )
+            desc = site.description or args["desc"] or ""
+            copy_right = site.copyright or args["copyright"] or ""
+            privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
+            custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""

        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()

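Many hunks in this compare replace the verbose `x if x else y` conditional with the equivalent `x or y`. The two forms behave identically: `or` returns its left operand when it is truthy and the right operand otherwise. The only caveat is that `or` falls through on *any* falsy value (`""`, `0`, `None`, empty collections), which is exactly the intent for these optional string fields. A small sketch of the equivalence:

```python
def pick(primary, fallback):
    # Equivalent to: primary if primary else fallback
    return primary or fallback

assert pick("custom", "default") == "custom"
assert pick("", "default") == "default"    # empty string is falsy
assert pick(None, "default") == "default"  # None is falsy
# Caveat: 0 is also falsy, so `or` is only safe when every falsy
# value should be treated as "missing".
assert pick(0, 42) == 42
```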
@@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource):
    def post(self, resource_id):
        resource_id = str(resource_id)
        _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
-        if not current_user.is_admin_or_owner:
+        if not current_user.is_editor:
            raise Forbidden()

        current_key_count = (
@@ -99,14 +99,10 @@ class ChatMessageTextApi(Resource):
            and app_model.workflow.features_dict
        ):
            text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
        else:
            try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
            except Exception:
                voice = None
        response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
@@ -20,7 +20,7 @@ from fields.conversation_fields import (
    conversation_pagination_fields,
    conversation_with_summary_pagination_fields,
 )
-from libs.helper import datetime_string
+from libs.helper import DatetimeString
 from libs.login import login_required
 from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation

@@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("keyword", type=str, location="args")
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        parser.add_argument(
            "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
        )
@@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("keyword", type=str, location="args")
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        parser.add_argument(
            "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
        )
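The rename from `datetime_string` to `DatetimeString` (and similarly `str_len` to `StrLen` later in this compare) suggests the validator moved from a factory function to a callable class; flask-restful's reqparse accepts any callable as `type`. A hedged sketch of what such a class might look like — the real implementation lives in `libs/helper` and may differ:

```python
from datetime import datetime

class DatetimeString:
    """Callable validator: accepts a string matching the given format."""

    def __init__(self, fmt: str):
        self.fmt = fmt

    def __call__(self, value: str) -> str:
        try:
            datetime.strptime(value, self.fmt)
        except ValueError:
            raise ValueError(f"{value!r} is not a valid datetime for format {self.fmt!r}")
        return value

# Usage mirrors the diff:
validate = DatetimeString("%Y-%m-%d %H:%M")
assert validate("2024-09-10 12:30") == "2024-09-10 12:30"
```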
@@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
-from libs.helper import datetime_string
+from libs.helper import DatetimeString
 from libs.login import login_required
 from models.model import AppMode

@@ -25,14 +25,17 @@ class DailyMessageStatistic(Resource):
        account = current_user

        parser = reqparse.RequestParser()
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        args = parser.parse_args()

-        sql_query = """
-        SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count
-        FROM messages where app_id = :app_id
-        """
+        sql_query = """SELECT
+    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+    COUNT(*) AS message_count
+FROM
+    messages
+WHERE
+    app_id = :app_id"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id}

        timezone = pytz.timezone(account.timezone)
@@ -45,7 +48,7 @@ class DailyMessageStatistic(Resource):
            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at >= :start"
+            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
@@ -55,10 +58,10 @@ class DailyMessageStatistic(Resource):
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at < :end"
+            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

-        sql_query += " GROUP BY date order by date"
+        sql_query += " GROUP BY date ORDER BY date"

        response_data = []
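The statistics hunks in this compare only reformat SQL: keywords are uppercased and each clause gets its own line, while the bound parameters (`:tz`, `:app_id`, `:start`, `:end`) are untouched, so query behavior is identical. A condensed sketch of the append-and-bind pattern these handlers use (SQLAlchemy assumed, as in the surrounding code; the values here are placeholders):

```python
from sqlalchemy import text

sql_query = """SELECT
    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
    COUNT(*) AS message_count
FROM
    messages
WHERE
    app_id = :app_id"""
arg_dict = {"tz": "UTC", "app_id": "some-app-id"}

# Optional filters are appended as extra predicates, never interpolated,
# so user-supplied values always travel through bound parameters.
start_utc = None  # a datetime in practice
if start_utc:
    sql_query += " AND created_at >= :start"
    arg_dict["start"] = start_utc

sql_query += " GROUP BY date ORDER BY date"
stmt = text(sql_query)  # executed later, e.g. db.session.execute(stmt, arg_dict)
```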
@@ -79,14 +82,17 @@ class DailyConversationStatistic(Resource):
        account = current_user

        parser = reqparse.RequestParser()
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        args = parser.parse_args()

-        sql_query = """
-        SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
-        FROM messages where app_id = :app_id
-        """
+        sql_query = """SELECT
+    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+    COUNT(DISTINCT messages.conversation_id) AS conversation_count
+FROM
+    messages
+WHERE
+    app_id = :app_id"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id}

        timezone = pytz.timezone(account.timezone)
@@ -99,7 +105,7 @@ class DailyConversationStatistic(Resource):
            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at >= :start"
+            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
@@ -109,10 +115,10 @@ class DailyConversationStatistic(Resource):
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at < :end"
+            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

-        sql_query += " GROUP BY date order by date"
+        sql_query += " GROUP BY date ORDER BY date"

        response_data = []
@@ -133,14 +139,17 @@ class DailyTerminalsStatistic(Resource):
        account = current_user

        parser = reqparse.RequestParser()
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        args = parser.parse_args()

-        sql_query = """
-        SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
-        FROM messages where app_id = :app_id
-        """
+        sql_query = """SELECT
+    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+    COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
+FROM
+    messages
+WHERE
+    app_id = :app_id"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id}

        timezone = pytz.timezone(account.timezone)
@@ -153,7 +162,7 @@ class DailyTerminalsStatistic(Resource):
            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at >= :start"
+            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
@@ -163,10 +172,10 @@ class DailyTerminalsStatistic(Resource):
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at < :end"
+            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

-        sql_query += " GROUP BY date order by date"
+        sql_query += " GROUP BY date ORDER BY date"

        response_data = []
@@ -187,16 +196,18 @@ class DailyTokenCostStatistic(Resource):
        account = current_user

        parser = reqparse.RequestParser()
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        args = parser.parse_args()

-        sql_query = """
-        SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-            (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
-            sum(total_price) as total_price
-        FROM messages where app_id = :app_id
-        """
+        sql_query = """SELECT
+    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+    (SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
+    SUM(total_price) AS total_price
+FROM
+    messages
+WHERE
+    app_id = :app_id"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id}

        timezone = pytz.timezone(account.timezone)
@@ -209,7 +220,7 @@ class DailyTokenCostStatistic(Resource):
            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at >= :start"
+            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
@@ -219,10 +230,10 @@ class DailyTokenCostStatistic(Resource):
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at < :end"
+            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

-        sql_query += " GROUP BY date order by date"
+        sql_query += " GROUP BY date ORDER BY date"

        response_data = []
@@ -245,16 +256,26 @@ class AverageSessionInteractionStatistic(Resource):
        account = current_user

        parser = reqparse.RequestParser()
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        args = parser.parse_args()

-        sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-AVG(subquery.message_count) AS interactions
-FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
-    FROM conversations c
-    JOIN messages m ON c.id = m.conversation_id
-    WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
+        sql_query = """SELECT
+    DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+    AVG(subquery.message_count) AS interactions
+FROM
+    (
+        SELECT
+            m.conversation_id,
+            COUNT(m.id) AS message_count
+        FROM
+            conversations c
+        JOIN
+            messages m
+            ON c.id = m.conversation_id
+        WHERE
+            c.override_model_configs IS NULL
+            AND c.app_id = :app_id"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id}

        timezone = pytz.timezone(account.timezone)
@@ -267,7 +288,7 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and c.created_at >= :start"
+            sql_query += " AND c.created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
@@ -277,14 +298,19 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and c.created_at < :end"
+            sql_query += " AND c.created_at < :end"
            arg_dict["end"] = end_datetime_utc

        sql_query += """
-            GROUP BY m.conversation_id) subquery
-        LEFT JOIN conversations c on c.id=subquery.conversation_id
-        GROUP BY date
-        ORDER BY date"""
+        GROUP BY m.conversation_id
+    ) subquery
+LEFT JOIN
+    conversations c
+    ON c.id = subquery.conversation_id
+GROUP BY
+    date
+ORDER BY
+    date"""

        response_data = []
@@ -307,17 +333,21 @@ class UserSatisfactionRateStatistic(Resource):
        account = current_user

        parser = reqparse.RequestParser()
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        args = parser.parse_args()

-        sql_query = """
-        SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-            COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
-        FROM messages m
-        LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
-        WHERE m.app_id = :app_id
-        """
+        sql_query = """SELECT
+    DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+    COUNT(m.id) AS message_count,
+    COUNT(mf.id) AS feedback_count
+FROM
+    messages m
+LEFT JOIN
+    message_feedbacks mf
+    ON mf.message_id=m.id AND mf.rating='like'
+WHERE
+    m.app_id = :app_id"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id}

        timezone = pytz.timezone(account.timezone)
@@ -330,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and m.created_at >= :start"
+            sql_query += " AND m.created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
@@ -340,10 +370,10 @@ class UserSatisfactionRateStatistic(Resource):
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and m.created_at < :end"
+            sql_query += " AND m.created_at < :end"
            arg_dict["end"] = end_datetime_utc

-        sql_query += " GROUP BY date order by date"
+        sql_query += " GROUP BY date ORDER BY date"

        response_data = []
@@ -369,16 +399,17 @@ class AverageResponseTimeStatistic(Resource):
        account = current_user

        parser = reqparse.RequestParser()
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        args = parser.parse_args()

-        sql_query = """
-        SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-            AVG(provider_response_latency) as latency
-        FROM messages
-        WHERE app_id = :app_id
-        """
+        sql_query = """SELECT
+    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+    AVG(provider_response_latency) AS latency
+FROM
+    messages
+WHERE
+    app_id = :app_id"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id}

        timezone = pytz.timezone(account.timezone)
@@ -391,7 +422,7 @@ class AverageResponseTimeStatistic(Resource):
            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at >= :start"
+            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
@@ -401,10 +432,10 @@ class AverageResponseTimeStatistic(Resource):
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at < :end"
+            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

-        sql_query += " GROUP BY date order by date"
+        sql_query += " GROUP BY date ORDER BY date"

        response_data = []
@@ -425,17 +456,20 @@ class TokensPerSecondStatistic(Resource):
        account = current_user

        parser = reqparse.RequestParser()
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        args = parser.parse_args()

-        sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-CASE
+        sql_query = """SELECT
+    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+    CASE
        WHEN SUM(provider_response_latency) = 0 THEN 0
        ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
    END as tokens_per_second
-FROM messages
-WHERE app_id = :app_id"""
+FROM
+    messages
+WHERE
+    app_id = :app_id"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id}

        timezone = pytz.timezone(account.timezone)
@@ -448,7 +482,7 @@ WHERE app_id = :app_id"""
            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at >= :start"
+            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
@@ -458,10 +492,10 @@ WHERE app_id = :app_id"""
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at < :end"
+            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

-        sql_query += " GROUP BY date order by date"
+        sql_query += " GROUP BY date ORDER BY date"

        response_data = []
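The `TokensPerSecondStatistic` query guards its division with a `CASE WHEN SUM(provider_response_latency) = 0 THEN 0` branch, so days with no recorded latency yield 0 instead of a division-by-zero error. The same guard expressed in plain Python, purely for illustration:

```python
def tokens_per_second(answer_tokens: int, latency_seconds: float) -> float:
    # Mirrors the SQL CASE: a zero denominator yields 0 instead of an error.
    if latency_seconds == 0:
        return 0.0
    return answer_tokens / latency_seconds

assert tokens_per_second(1200, 4.0) == 300.0
assert tokens_per_second(1200, 0) == 0.0
```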
@@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
-from libs.helper import datetime_string
+from libs.helper import DatetimeString
 from libs.login import login_required
 from models.model import AppMode
 from models.workflow import WorkflowRunTriggeredFrom
@@ -26,16 +26,18 @@ class WorkflowDailyRunsStatistic(Resource):
        account = current_user

        parser = reqparse.RequestParser()
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        args = parser.parse_args()

-        sql_query = """
-        SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
-        FROM workflow_runs
-        WHERE app_id = :app_id
-            AND triggered_from = :triggered_from
-        """
+        sql_query = """SELECT
+    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+    COUNT(id) AS runs
+FROM
+    workflow_runs
+WHERE
+    app_id = :app_id
+    AND triggered_from = :triggered_from"""
        arg_dict = {
            "tz": account.timezone,
            "app_id": app_model.id,
@@ -52,7 +54,7 @@ class WorkflowDailyRunsStatistic(Resource):
            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at >= :start"
+            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
@@ -62,10 +64,10 @@ class WorkflowDailyRunsStatistic(Resource):
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at < :end"
+            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

-        sql_query += " GROUP BY date order by date"
+        sql_query += " GROUP BY date ORDER BY date"

        response_data = []
@@ -86,16 +88,18 @@ class WorkflowDailyTerminalsStatistic(Resource):
        account = current_user

        parser = reqparse.RequestParser()
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        args = parser.parse_args()

-        sql_query = """
-        SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
-        FROM workflow_runs
-        WHERE app_id = :app_id
-            AND triggered_from = :triggered_from
-        """
+        sql_query = """SELECT
+    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+    COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
+FROM
+    workflow_runs
+WHERE
+    app_id = :app_id
+    AND triggered_from = :triggered_from"""
        arg_dict = {
            "tz": account.timezone,
            "app_id": app_model.id,
@@ -112,7 +116,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at >= :start"
+            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
@@ -122,10 +126,10 @@ class WorkflowDailyTerminalsStatistic(Resource):
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at < :end"
+            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

-        sql_query += " GROUP BY date order by date"
+        sql_query += " GROUP BY date ORDER BY date"

        response_data = []
@@ -146,18 +150,18 @@ class WorkflowDailyTokenCostStatistic(Resource):
        account = current_user

        parser = reqparse.RequestParser()
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        args = parser.parse_args()

-        sql_query = """
-        SELECT
-            date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-            SUM(workflow_runs.total_tokens) as token_count
-        FROM workflow_runs
-        WHERE app_id = :app_id
-            AND triggered_from = :triggered_from
-        """
+        sql_query = """SELECT
+    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+    SUM(workflow_runs.total_tokens) AS token_count
+FROM
+    workflow_runs
+WHERE
+    app_id = :app_id
+    AND triggered_from = :triggered_from"""
        arg_dict = {
            "tz": account.timezone,
            "app_id": app_model.id,
@@ -174,7 +178,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at >= :start"
+            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
@@ -184,10 +188,10 @@ class WorkflowDailyTokenCostStatistic(Resource):
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query += " and created_at < :end"
+            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

-        sql_query += " GROUP BY date order by date"
+        sql_query += " GROUP BY date ORDER BY date"

        response_data = []
@@ -213,27 +217,31 @@ class WorkflowAverageAppInteractionStatistic(Resource):
        account = current_user

        parser = reqparse.RequestParser()
-        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        args = parser.parse_args()

-        sql_query = """
-        SELECT
-            AVG(sub.interactions) as interactions,
-            sub.date
-        FROM
-            (SELECT
-                date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-                c.created_by,
-                COUNT(c.id) AS interactions
-            FROM workflow_runs c
-            WHERE c.app_id = :app_id
-                AND c.triggered_from = :triggered_from
-                {{start}}
-                {{end}}
-            GROUP BY date, c.created_by) sub
-        GROUP BY sub.date
-        """
+        sql_query = """SELECT
+    AVG(sub.interactions) AS interactions,
+    sub.date
+FROM
+    (
+        SELECT
+            DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+            c.created_by,
+            COUNT(c.id) AS interactions
+        FROM
+            workflow_runs c
+        WHERE
+            c.app_id = :app_id
+            AND c.triggered_from = :triggered_from
+            {{start}}
+            {{end}}
+        GROUP BY
+            date, c.created_by
+    ) sub
+GROUP BY
+    sub.date"""
        arg_dict = {
            "tz": account.timezone,
            "app_id": app_model.id,
@@ -262,7 +270,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

-            sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
+            sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
            arg_dict["end"] = end_datetime_utc
        else:
            sql_query = sql_query.replace("{{end}}", "")
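Unlike the other handlers, `WorkflowAverageAppInteractionStatistic` cannot simply append its date filters, because they belong inside the subquery rather than at the end of the statement; the query therefore carries `{{start}}` and `{{end}}` placeholders that are replaced with either a bound-parameter predicate or an empty string. A condensed sketch of the pattern (the query string is abbreviated):

```python
sql_query = "SELECT ... WHERE c.app_id = :app_id {{start}} {{end}} GROUP BY date"
arg_dict = {"app_id": "some-app-id"}

end_utc = None  # a datetime in practice
if end_utc:
    # The placeholder expands to a predicate with a bound :end parameter.
    sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
    arg_dict["end"] = end_utc
else:
    # No filter requested: the placeholder disappears entirely.
    sql_query = sql_query.replace("{{end}}", "")
sql_query = sql_query.replace("{{start}}", "")  # same treatment for :start

print(sql_query)
```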
@@ -8,7 +8,7 @@ from constants.languages import supported_language
 from controllers.console import api
 from controllers.console.error import AlreadyActivateError
 from extensions.ext_database import db
-from libs.helper import email, str_len, timezone
+from libs.helper import StrLen, email, timezone
 from libs.password import hash_password, valid_password
 from models.account import AccountStatus
 from services.account_service import RegisterService
@@ -37,7 +37,7 @@ class ActivateApi(Resource):
        parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
        parser.add_argument("email", type=email, required=False, nullable=True, location="json")
        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
+        parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
        parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
        parser.add_argument(
            "interface_language", type=supported_language, required=True, nullable=False, location="json"
@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):

    if not account:
        # Create account
-        account_name = user_info.name if user_info.name else "Dify"
+        account_name = user_info.name or "Dify"
        account = RegisterService.register(
            email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
        )
@@ -550,12 +550,7 @@ class DatasetApiBaseUrlApi(Resource):
    @login_required
    @account_initialization_required
    def get(self):
-        return {
-            "api_base_url": (
-                dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
-            )
-            + "/v1"
-        }
+        return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}


 class DatasetRetrievalSettingApi(Resource):
@@ -86,14 +86,10 @@ class ChatTextApi(InstalledAppResource):
            and app_model.workflow.features_dict
        ):
            text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
        else:
            try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
            except Exception:
                voice = None
        response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
@@ -4,7 +4,7 @@ from flask import session
 from flask_restful import Resource, reqparse

 from configs import dify_config
-from libs.helper import str_len
+from libs.helper import StrLen
 from models.model import DifySetup
 from services.account_service import TenantService

@@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
            raise AlreadySetupError()

        parser = reqparse.RequestParser()
-        parser.add_argument("password", type=str_len(30), required=True, location="json")
+        parser.add_argument("password", type=StrLen(30), required=True, location="json")
        input_password = parser.parse_args()["password"]

        if input_password != os.environ.get("INIT_PASSWORD"):
@@ -4,7 +4,7 @@ from flask import request
 from flask_restful import Resource, reqparse

 from configs import dify_config
-from libs.helper import email, get_remote_ip, str_len
+from libs.helper import StrLen, email, get_remote_ip
 from libs.password import valid_password
 from models.model import DifySetup
 from services.account_service import RegisterService, TenantService
@@ -40,7 +40,7 @@ class SetupApi(Resource):

        parser = reqparse.RequestParser()
        parser.add_argument("email", type=email, required=True, location="json")
-        parser.add_argument("name", type=str_len(30), required=True, location="json")
+        parser.add_argument("name", type=StrLen(30), required=True, location="json")
        parser.add_argument("password", type=valid_password, required=True, location="json")
        args = parser.parse_args()

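`str_len` to `StrLen` is the same factory-function-to-callable-class move as `DatetimeString`. A hedged sketch of such a length validator; the actual code in `libs/helper` may differ:

```python
class StrLen:
    """Callable validator: accepts a string no longer than max_length."""

    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(self, value: str) -> str:
        if len(value) > self.max_length:
            raise ValueError(f"must be at most {self.max_length} characters")
        return value

# Usage mirrors the diff: parser.add_argument("name", type=StrLen(30), ...)
name_validator = StrLen(30)
assert name_validator("Alice") == "Alice"
```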
@@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):

        return ApiToolManageService.test_api_tool_preview(
            current_user.current_tenant_id,
-            args["provider_name"] if args["provider_name"] else "",
+            args["provider_name"] or "",
            args["tool_name"],
            args["credentials"],
            args["parameters"],
@@ -64,7 +64,8 @@ def cloud_edition_billing_resource_check(resource: str):
            elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
                abort(403, "The capacity of the vector space has reached the limit of your subscription.")
            elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
-                # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
+                # The api of file upload is used in the multiple places,
+                # so we need to check the source of the request from datasets
                source = request.args.get("source")
                if source == "datasets":
                    abort(403, "The number of documents has reached the limit of your subscription.")
@@ -84,14 +84,10 @@ class TextApi(Resource):
            and app_model.workflow.features_dict
        ):
            text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
        else:
            try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
            except Exception:
                voice = None
        response = AudioService.transcript_tts(
@@ -1,6 +1,7 @@
 import logging

 from flask_restful import Resource, fields, marshal_with, reqparse
+from flask_restful.inputs import int_range
 from werkzeug.exceptions import InternalServerError

 from controllers.service_api import api
@@ -22,10 +23,12 @@ from core.errors.error import (
 )
 from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
+from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
 from libs import helper
 from models.model import App, AppMode, EndUser
 from models.workflow import WorkflowRun
 from services.app_generate_service import AppGenerateService
+from services.workflow_app_service import WorkflowAppService

 logger = logging.getLogger(__name__)

@@ -113,6 +116,30 @@ class WorkflowTaskStopApi(Resource):
        return {"result": "success"}


+class WorkflowAppLogApi(Resource):
+    @validate_app_token
+    @marshal_with(workflow_app_log_pagination_fields)
+    def get(self, app_model: App):
+        """
+        Get workflow app logs
+        """
+        parser = reqparse.RequestParser()
+        parser.add_argument("keyword", type=str, location="args")
+        parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
+        parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
+        args = parser.parse_args()
+
+        # get paginate workflow app logs
+        workflow_app_service = WorkflowAppService()
+        workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
+            app_model=app_model, args=args
+        )
+
+        return workflow_app_log_pagination
+
+
 api.add_resource(WorkflowRunApi, "/workflows/run")
 api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>")
 api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
+api.add_resource(WorkflowAppLogApi, "/workflows/logs")
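The new `WorkflowAppLogApi` exposes paginated workflow run logs on the service API at `GET /workflows/logs`, filtered by `keyword`, `status`, `page`, and `limit`. A hypothetical client call — the base URL, token, and response field name are placeholders, not confirmed by the diff:

```python
import requests

resp = requests.get(
    "https://api.example.com/v1/workflows/logs",  # placeholder base URL
    headers={"Authorization": "Bearer app-your-token"},  # placeholder token
    params={"status": "failed", "page": 1, "limit": 20},
    timeout=10,
)
resp.raise_for_status()
# The pagination payload is assumed to carry its entries under "data".
for log in resp.json().get("data", []):
    print(log)
```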
@@ -83,14 +83,10 @@ class TextApi(WebApiResource):
            and app_model.workflow.features_dict
        ):
            text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
        else:
            try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
            except Exception:
                voice = None

@@ -80,7 +80,8 @@ def _validate_web_sso_token(decoded, system_features, app_code):
        if not source or source != "sso":
            raise WebSSOAuthRequiredError()

-    # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login
+    # Check if SSO is not enforced for web, and if the token source is SSO,
+    # raise an error and redirect to normal passport login
    if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
        source = decoded.get("token_source")
        if source and source == "sso":
@@ -256,7 +256,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
-                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
@@ -298,7 +298,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
-                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
@@ -41,7 +41,8 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
 {{historic_messages}}
 Question: {{query}}
 {{agent_scratchpad}}
-Thought:"""
+Thought:"""  # noqa: E501
+

 ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
 Thought:"""
@@ -86,7 +87,8 @@ Action:
 ```

 Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
-"""
+"""  # noqa: E501
+

 ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
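Appending `# noqa: E501` tells flake8 (and ruff) to skip the line-too-long check for that one line only, which is the right tool here: these long lines are prompt-template text whose wording should not be rewrapped just to satisfy the linter. For example:

```python
# The trailing marker suppresses only E501 on this one line;
# every other lint check still applies.
PROMPT = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary."""  # noqa: E501
```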
@@ -92,7 +92,7 @@ class VariableEntityType(str, Enum):
    SELECT = "select"
    PARAGRAPH = "paragraph"
    NUMBER = "number"
-    EXTERNAL_DATA_TOOL = "external-data-tool"
+    EXTERNAL_DATA_TOOL = "external_data_tool"


 class VariableEntity(BaseModel):
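Note that this hunk changes the enum's *value*, not just its name: any serialized configuration that stored the old hyphenated string no longer matches the member. A quick illustration of why that matters for a `str`-valued enum:

```python
from enum import Enum

class VariableEntityType(str, Enum):
    EXTERNAL_DATA_TOOL = "external_data_tool"

# str-valued enum members compare equal to their raw value...
assert VariableEntityType.EXTERNAL_DATA_TOOL == "external_data_tool"
# ...so data persisted with the old value no longer round-trips.
assert VariableEntityType.EXTERNAL_DATA_TOOL != "external-data-tool"
```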
@@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
 from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
 from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
@@ -293,7 +293,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            )

            runner.run()
-        except GenerateTaskStoppedException:
+        except GenerateTaskStoppedError:
            pass
        except InvokeAuthorizationError:
            queue_manager.publish_error(
@@ -349,7 +349,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            return generate_task_pipeline.process()
        except ValueError as e:
            if e.args[0] == "I/O operation on closed file.":  # ignore this error
-                raise GenerateTaskStoppedException()
+                raise GenerateTaskStoppedError()
            else:
                logger.exception(e)
                raise e
@@ -21,7 +21,7 @@ class AudioTrunk:
        self.status = status


-def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
+def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
    if not text_content or text_content.isspace():
        return
    return model_instance.invoke_tts(
@@ -81,7 +81,7 @@ class AppGeneratorTTSPublisher:
            if message is None:
                if self.msg_text and len(self.msg_text.strip()) > 0:
                    futures_result = self.executor.submit(
-                        _invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
+                        _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
                    )
                    future_queue.put(futures_result)
                break
@@ -97,7 +97,7 @@ class AppGeneratorTTSPublisher:
                    self.MAX_SENTENCE += 1
                    text_content = "".join(sentence_arr)
                    futures_result = self.executor.submit(
-                        _invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
+                        _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
                    )
                    future_queue.put(futures_result)
                    if text_tmp:
@@ -110,7 +110,7 @@ class AppGeneratorTTSPublisher:
                break
        future_queue.put(None)

-    def checkAndGetAudio(self) -> AudioTrunk | None:
+    def check_and_get_audio(self) -> AudioTrunk | None:
        try:
            if self._last_audio_event and self._last_audio_event.status == "finish":
                if self.executor:
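The renames above are cosmetic (camelCase to snake_case), and they leave the publisher's structure intact: sentences are batched, each batch is submitted to a thread pool, and the resulting futures are pushed onto a queue that the consumer drains in submission order. A stripped-down sketch of that producer/consumer shape, with the TTS call stubbed out:

```python
import queue
from concurrent.futures import Future, ThreadPoolExecutor

def _invoice_tts(text_content: str) -> bytes:
    # Stand-in for model_instance.invoke_tts(...); returns fake audio bytes.
    return text_content.encode("utf-8")

executor = ThreadPoolExecutor(max_workers=3)
future_queue: "queue.Queue[Future | None]" = queue.Queue()

for sentence in ["Hello there.", "How are you?"]:
    # submit() returns immediately; queuing the future preserves
    # submission order even though synthesis runs concurrently.
    future_queue.put(executor.submit(_invoice_tts, sentence))
future_queue.put(None)  # sentinel: no more work

while (future := future_queue.get()) is not None:
    print(future.result())
executor.shutdown()
```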
@@ -19,7 +19,7 @@ from core.app.entities.queue_entities import (
    QueueStopEvent,
    QueueTextChunkEvent,
 )
-from core.moderation.base import ModerationException
+from core.moderation.base import ModerationError
 from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.node_entities import UserFrom
 from core.workflow.entities.variable_pool import VariablePool
@@ -217,7 +217,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
                query=query,
                message_id=message_id,
            )
-        except ModerationException as e:
+        except ModerationError as e:
            self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
            return True

@@ -179,10 +179,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
            stream_response=stream_response,
        )

-    def _listenAudioMsg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher, task_id: str):
        if not publisher:
            return None
-        audio_msg: AudioTrunk = publisher.checkAndGetAudio()
+        audio_msg: AudioTrunk = publisher.check_and_get_audio()
        if audio_msg and audio_msg.status != "finish":
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None
@@ -204,7 +204,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
            while True:
-                audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
+                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
                if audio_response:
                    yield audio_response
                else:
@@ -217,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
            try:
                if not tts_publisher:
                    break
-                audio_trunk = tts_publisher.checkAndGetAudio()
+                audio_trunk = tts_publisher.check_and_get_audio()
                if audio_trunk is None:
                    # release cpu
                    # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@@ -451,7 +451,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                    tts_publisher.publish(message=queue_message)

                self._task_state.answer += delta_text
-                yield self._message_to_stream_response(delta_text, self._message.id)
+                yield self._message_to_stream_response(
+                    answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector
+                )
            elif isinstance(event, QueueMessageReplaceEvent):
                # published by moderation
                yield self._message_replace_to_stream_response(answer=event.text)
@@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
 from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
 from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
@@ -205,7 +205,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
                conversation=conversation,
                message=message,
            )
-        except GenerateTaskStoppedException:
+        except GenerateTaskStoppedError:
            pass
        except InvokeAuthorizationError:
            queue_manager.publish_error(
@@ -15,7 +15,7 @@ from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
 from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.moderation.base import ModerationException
+from core.moderation.base import ModerationError
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
 from extensions.ext_database import db
 from models.model import App, Conversation, Message, MessageAgentThought
@@ -103,7 +103,7 @@ class AgentChatAppRunner(AppRunner):
                query=query,
                message_id=message.id,
            )
-        except ModerationException as e:
+        except ModerationError as e:
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
@@ -171,5 +171,5 @@ class AppQueueManager:
        )


-class GenerateTaskStoppedException(Exception):
+class GenerateTaskStoppedError(Exception):
    pass
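The sweep from `GenerateTaskStoppedException` to `GenerateTaskStoppedError` (and from `ModerationException` to `ModerationError`) follows the PEP 8 convention that concrete exception classes end in `Error`. The mechanics are a plain rename; callers that catch the class by name behave identically:

```python
class GenerateTaskStoppedError(Exception):
    """Raised when a generate task is stopped mid-flight."""

try:
    raise GenerateTaskStoppedError()
except GenerateTaskStoppedError:
    pass  # callers that swallowed the old exception behave the same way
```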
@@ -161,7 +161,7 @@ class AppRunner:
            app_mode=AppMode.value_of(app_record.mode),
            prompt_template_entity=prompt_template_entity,
            inputs=inputs,
-            query=query if query else "",
+            query=query or "",
            files=files,
            context=context,
            memory=memory,
@@ -189,7 +189,7 @@ class AppRunner:
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs=inputs,
-            query=query if query else "",
+            query=query or "",
            files=files,
            context=context,
            memory_config=memory_config,
@@ -238,7 +238,7 @@ class AppRunner:
                        model=app_generate_entity.model_conf.model,
                        prompt_messages=prompt_messages,
                        message=AssistantPromptMessage(content=text),
-                        usage=usage if usage else LLMUsage.empty_usage(),
+                        usage=usage or LLMUsage.empty_usage(),
                    ),
                ),
                PublishFrom.APPLICATION_MANAGER,
@@ -351,7 +351,7 @@ class AppRunner:
            tenant_id=tenant_id,
            app_config=app_generate_entity.app_config,
            inputs=inputs,
-            query=query if query else "",
+            query=query or "",
            message_id=message_id,
            trace_manager=app_generate_entity.trace_manager,
        )
@@ -10,7 +10,7 @@ from pydantic import ValidationError

 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
 from core.app.apps.chat.app_config_manager import ChatAppConfigManager
 from core.app.apps.chat.app_runner import ChatAppRunner
 from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
@@ -205,7 +205,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
                conversation=conversation,
                message=message,
            )
-        except GenerateTaskStoppedException:
+        except GenerateTaskStoppedError:
            pass
        except InvokeAuthorizationError:
            queue_manager.publish_error(
@@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
-from core.moderation.base import ModerationException
+from core.moderation.base import ModerationError
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
 from models.model import App, Conversation, Message
@@ -98,7 +98,7 @@ class ChatAppRunner(AppRunner):
                query=query,
                message_id=message.id,
            )
-        except ModerationException as e:
+        except ModerationError as e:
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
@@ -10,7 +10,7 @@ from pydantic import ValidationError

 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
 from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
 from core.app.apps.completion.app_runner import CompletionAppRunner
 from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
@@ -185,7 +185,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
                queue_manager=queue_manager,
                message=message,
            )
-        except GenerateTaskStoppedException:
+        except GenerateTaskStoppedError:
            pass
        except InvokeAuthorizationError:
            queue_manager.publish_error(
@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Message
|
||||
@ -79,7 +79,7 @@ class CompletionAppRunner(AppRunner):
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
except ModerationError as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
|
||||
@ -8,7 +8,7 @@ from sqlalchemy import and_
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
@ -77,7 +77,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@ -53,4 +53,4 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
|
||||
@ -12,7 +12,7 @@ from pydantic import ValidationError
|
||||
import contexts
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
@ -253,7 +253,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
@ -302,7 +302,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@ -39,4 +39,4 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
|
||||
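The `*Exception` to `*Error` renames above (and in later hunks) follow the PEP 8 convention that concrete exception classes carry the suffix "Error". A minimal sketch of the pattern, using a name from these commits (body illustrative):

    class GenerateTaskStoppedError(Exception):
        # concrete exceptions are suffixed with "Error" per PEP 8
        pass

    try:
        raise GenerateTaskStoppedError()
    except GenerateTaskStoppedError:
        pass  # call sites change only the name they raise and catch
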
@ -162,10 +162,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)

def _listenAudioMsg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
@ -187,7 +187,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@ -199,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
try:
if not tts_publisher:
break
audio_trunk = tts_publisher.checkAndGetAudio()
audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@ -376,7 +376,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
tts_publisher.publish(message=queue_message)

self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(delta_text)
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
else:
continue

@ -412,14 +414,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
db.session.commit()
db.session.close()

def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse:
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
) -> TextChunkStreamResponse:
"""
Handle completed event.
:param text: text
:return:
"""
response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text)
task_id=self._application_generate_entity.task_id,
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
)

return response

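`_listenAudioMsg` and `checkAndGetAudio` become `_listen_audio_msg` and `check_and_get_audio`, bringing the method names in line with PEP 8 snake_case; only the names change, the polling loop that drains the publisher is untouched. A minimal sketch (class name from these hunks, body illustrative):

    class AppGeneratorTTSPublisher:
        def check_and_get_audio(self):  # formerly checkAndGetAudio
            return None  # placeholder body for illustration
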
@ -84,10 +84,12 @@ class WorkflowLoggingCallback(WorkflowCallback):
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="green"
f"Inputs: " f"{jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color="green",
)
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
f"Process Data: "
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="green",
)
self.print_text(
@ -114,14 +116,17 @@ class WorkflowLoggingCallback(WorkflowCallback):
node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color="red")
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="red"
)
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
f"Inputs: " f"" f"{jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color="red",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color="red"
f"Process Data: "
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="red",
)
self.print_text(
f"Outputs: " f"{jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="red",
)

def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:

@ -90,6 +90,7 @@ class MessageStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE
id: str
answer: str
from_variable_selector: Optional[list[str]] = None


class MessageAudioStreamResponse(StreamResponse):
@ -479,6 +480,7 @@ class TextChunkStreamResponse(StreamResponse):
"""

text: str
from_variable_selector: Optional[list[str]] = None

event: StreamEvent = StreamEvent.TEXT_CHUNK
data: Data

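Declaring `from_variable_selector` as `Optional[...] = None` keeps both stream-response models backward compatible: callers that never pass the field still validate. A minimal pydantic-style sketch (class pared down for illustration):

    from typing import Optional
    from pydantic import BaseModel

    class Data(BaseModel):
        text: str
        from_variable_selector: Optional[list[str]] = None  # new field, defaulted

    Data(text="hello")  # old call sites keep working
    Data(text="hello", from_variable_selector=["node", "output"])
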
@ -15,6 +15,7 @@ class Segment(BaseModel):
value: Any

@field_validator("value_type")
@classmethod
def validate_value_type(cls, value):
"""
This validator checks if the provided value is equal to the default value of the 'value_type' field.

@ -65,7 +65,7 @@ class BasedGenerateTaskPipeline:

if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
elif isinstance(e, InvokeError | ValueError):
err = e
else:
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))

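`isinstance(e, InvokeError | ValueError)` relies on `isinstance` accepting PEP 604 union syntax (Python 3.10+), collapsing the two checks joined by `or`. A small self-contained sketch, with `InvokeError` stubbed for illustration:

    class InvokeError(Exception):
        pass

    e = ValueError("boom")
    # equivalent checks on Python 3.10+:
    assert isinstance(e, InvokeError | ValueError)
    assert isinstance(e, (InvokeError, ValueError))
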
@ -201,10 +201,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
stream_response=stream_response,
)

def _listenAudioMsg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher, task_id: str):
if publisher is None:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
@ -225,7 +225,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id)
audio_response = self._listen_audio_msg(publisher, task_id)
if audio_response:
yield audio_response
else:
@ -237,7 +237,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
if publisher is None:
break
audio = publisher.checkAndGetAudio()
audio = publisher.check_and_get_audio()
if audio is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)

@ -153,14 +153,21 @@ class MessageCycleManage:

return None

def _message_to_stream_response(self, answer: str, message_id: str) -> MessageStreamResponse:
def _message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
) -> MessageStreamResponse:
"""
Message to stream response.
:param answer: answer
:param message_id: message id
:return:
"""
return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer)
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
)

def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
"""

@ -67,7 +67,7 @@ class DatasetIndexToolCallbackHandler:
data_source_type=item.get("data_source_type"),
segment_id=item.get("segment_id"),
score=item.get("score") if "score" in item else None,
hit_count=item.get("hit_count") if "hit_count" else None,
hit_count=item.get("hit_count") if "hit_count" in item else None,
word_count=item.get("word_count") if "word_count" in item else None,
segment_position=item.get("segment_position") if "segment_position" in item else None,
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,

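The `hit_count` change fixes a dead condition: the old guard tested the string literal `"hit_count"` itself, which is always truthy, instead of membership in `item`. Behavior happened to coincide because `dict.get` already returns None for missing keys, but the intent is now explicit and matches the sibling fields. Illustrative check:

    item = {"score": 0.9}
    old = item.get("hit_count") if "hit_count" else None  # literal is always truthy
    new = item.get("hit_count") if "hit_count" in item else None
    assert old is None and new is None  # equal here; the new form tests the dict
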
@ -3,6 +3,7 @@ import importlib.util
import json
import logging
import os
from pathlib import Path
from typing import Any, Optional

from pydantic import BaseModel
@ -63,8 +64,7 @@ class Extensible:

builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path):
with open(builtin_file_path, encoding="utf-8") as f:
position = int(f.read().strip())
position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
position_map[extension_name] = position

if (extension_name + ".py") not in file_names:

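`Path.read_text` folds the open/read pair into one call and closes the file itself. A minimal sketch of the idiom (file name illustrative):

    from pathlib import Path

    path = Path("__builtin__")
    path.write_text("3\n", encoding="utf-8")
    position = int(path.read_text(encoding="utf-8").strip())  # no explicit open/close
    assert position == 3
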
@ -188,7 +188,8 @@ class MessageFileParser:
def _check_image_remote_url(self, url):
try:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/91.0.4472.124 Safari/537.36"
}

def is_s3_presigned_url(url):

@ -16,7 +16,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
logger = logging.getLogger(__name__)


class CodeExecutionException(Exception):
class CodeExecutionError(Exception):
pass


@ -86,15 +86,16 @@ class CodeExecutor:
),
)
if response.status_code == 503:
raise CodeExecutionException("Code execution service is unavailable")
raise CodeExecutionError("Code execution service is unavailable")
elif response.status_code != 200:
raise Exception(
f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
f"Failed to execute code, got status code {response.status_code},"
f" please check if the sandbox service is running"
)
except CodeExecutionException as e:
except CodeExecutionError as e:
raise e
except Exception as e:
raise CodeExecutionException(
raise CodeExecutionError(
"Failed to execute code, which is likely a network issue,"
" please check if the sandbox service is running."
f" ( Error: {str(e)} )"
@ -103,15 +104,15 @@ class CodeExecutor:
try:
response = response.json()
except:
raise CodeExecutionException("Failed to parse response")
raise CodeExecutionError("Failed to parse response")

if (code := response.get("code")) != 0:
raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}")
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}")

response = CodeExecutionResponse(**response)

if response.data.error:
raise CodeExecutionException(response.data.error)
raise CodeExecutionError(response.data.error)

return response.data.stdout or ""

@ -126,13 +127,13 @@ class CodeExecutor:
"""
template_transformer = cls.code_template_transformers.get(language)
if not template_transformer:
raise CodeExecutionException(f"Unsupported language {language}")
raise CodeExecutionError(f"Unsupported language {language}")

runner, preload = template_transformer.transform_caller(code, inputs)

try:
response = cls.execute_code(language, preload, runner)
except CodeExecutionException as e:
except CodeExecutionError as e:
raise e

return template_transformer.transform_response(response)

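Several of these hunks split long messages across lines by relying on Python's compile-time concatenation of adjacent string (and f-string) literals, which keeps lines under the length limit without changing the resulting value. Illustrative:

    status_code = 503
    msg = (
        f"Failed to execute code, got status code {status_code},"
        f" please check if the sandbox service is running"
    )
    assert "503, please check" in msg
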
@ -14,7 +14,10 @@ class ToolParameterCache:
def __init__(
self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
):
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}"
self.cache_key = (
f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
f":identity_id:{identity_id}"
)

def get(self) -> Optional[dict]:
"""

@ -78,8 +78,8 @@ class IndexingRunner:
dataset_document=dataset_document,
documents=documents,
)
except DocumentIsPausedException:
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
except DocumentIsPausedError:
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
@ -134,8 +134,8 @@ class IndexingRunner:
self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
except DocumentIsPausedError:
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
@ -192,8 +192,8 @@ class IndexingRunner:
self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
except DocumentIsPausedError:
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
@ -756,7 +756,7 @@ class IndexingRunner:
indexing_cache_key = "document_{}_is_paused".format(document_id)
result = redis_client.get(indexing_cache_key)
if result:
raise DocumentIsPausedException()
raise DocumentIsPausedError()

@staticmethod
def _update_document_index_status(
@ -767,10 +767,10 @@ class IndexingRunner:
"""
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedException()
raise DocumentIsPausedError()
document = DatasetDocument.query.filter_by(id=document_id).first()
if not document:
raise DocumentIsDeletedPausedException()
raise DocumentIsDeletedPausedError()

update_params = {DatasetDocument.indexing_status: after_indexing_status}

@ -875,9 +875,9 @@ class IndexingRunner:
pass


class DocumentIsPausedException(Exception):
class DocumentIsPausedError(Exception):
pass


class DocumentIsDeletedPausedException(Exception):
class DocumentIsDeletedPausedError(Exception):
pass

@ -1,2 +1,2 @@
class OutputParserException(Exception):
class OutputParserError(Exception):
pass

@ -1,6 +1,6 @@
from typing import Any

from core.llm_generator.output_parser.errors import OutputParserException
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.prompts import (
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
@ -29,4 +29,4 @@ class RuleConfigGeneratorOutputParser:
raise ValueError("Expected 'opening_statement' to be a str.")
return parsed
except Exception as e:
raise OutputParserException(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")
raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")

@ -59,24 +59,27 @@ User Input: yo, 你今天咋样?
}

User Input:
"""
""" # noqa: E501

SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response"
"(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
"The output must be an array in JSON format following the specified schema:\n"
'["question1","question2","question3"]\n'
)

GENERATOR_QA_PROMPT = (
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge in the long text. Please think step by step."
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
" in the long text. Please think step by step."
"Step 1: Understand and summarize the main content of this text.\n"
"Step 2: What key information or concepts are mentioned in this text?\n"
"Step 3: Decompose or combine multiple pieces of information and concepts.\n"
"Step 4: Generate questions and answers based on these key information and concepts.\n"
"<Constraints> The questions should be clear and detailed, and the answers should be detailed and complete. "
"You must answer in {language}, in a style that is clear and detailed in {language}. No language other than {language} should be used. \n"
"You must answer in {language}, in a style that is clear and detailed in {language}."
" No language other than {language} should be used. \n"
"<Format> Use the following format: Q1:\nA1:\nQ2:\nA2:...\n"
"<QA Pairs>"
)
@ -94,7 +97,7 @@ Based on task description, please create a well-structured prompt template that
- Use the same language as task description.
- Output in ``` xml ``` and start with <instruction>
Please generate the full prompt template with at least 300 words and output only the prompt template.
"""
""" # noqa: E501

RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """
Here is a task description for which I would like you to create a high-quality prompt template for:
@ -109,7 +112,7 @@ Based on task description, please create a well-structured prompt template that
- Use the same language as task description.
- Output in ``` xml ``` and start with <instruction>
Please generate the full prompt template and output only the prompt template.
"""
""" # noqa: E501

RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE = """
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
@ -134,7 +137,7 @@ Inside <text></text> XML tags, there is a text that I should extract parameters

### Answer
I should always output a valid list. Output nothing other than the list of variable_name. Output an empty list if there is no variable name in input text.
"""
""" # noqa: E501

RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE = """
<instruction>
@ -150,4 +153,4 @@ Welcome! I'm here to assist you with any questions or issues you might have with
Here is the task description: {{INPUT_TEXT}}

You just need to generate the output
"""
""" # noqa: E501

@ -39,7 +39,7 @@ class TokenBufferMemory:
)

if message_limit and message_limit > 0:
message_limit = message_limit if message_limit <= 500 else 500
message_limit = min(message_limit, 500)
else:
message_limit = 500

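`min(message_limit, 500)` expresses the upper clamp directly instead of through a conditional expression. Equivalence check (illustrative values):

    for message_limit in (10, 500, 9999):
        assert min(message_limit, 500) == (message_limit if message_limit <= 500 else 500)
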
@ -8,8 +8,11 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
},
"type": "float",
"help": {
"en_US": "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.",
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。",
"en_US": "Controls randomness. Lower temperature results in less random completions."
" As the temperature approaches zero, the model will become deterministic and repetitive."
" Higher temperature results in more random completions.",
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。"
"较高的温度会导致更多的随机完成。",
},
"required": False,
"default": 0.0,
@ -24,7 +27,8 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
},
"type": "float",
"help": {
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.",
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options"
" are considered.",
"zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。",
},
"required": False,
@ -88,7 +92,8 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
},
"type": "int",
"help": {
"en_US": "Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.",
"en_US": "Specifies the upper limit on the length of generated results."
" If the generated results are truncated, you can increase this parameter.",
"zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
},
"required": False,
@ -104,7 +109,8 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
},
"type": "string",
"help": {
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.",
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible,"
" such as JSON, XML, etc.",
"zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等",
},
"required": False,

@ -72,7 +72,9 @@ class AIModel(ABC):
if isinstance(error, tuple(model_errors)):
if invoke_error == InvokeAuthorizationError:
return invoke_error(
description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. "
description=(
f"[{provider_name}] Incorrect model credentials provided, please check and try again."
)
)

return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")

@ -187,7 +187,7 @@ if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
""" # noqa: E501

code_block = model_parameters.get("response_format", "")
if not code_block:
@ -449,7 +449,7 @@ if you are not sure about the structure.
model=real_model,
prompt_messages=prompt_messages,
message=prompt_message,
usage=usage if usage else LLMUsage.empty_usage(),
usage=usage or LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint,
),
credentials=credentials,
@ -830,7 +830,8 @@ if you are not sure about the structure.
else:
if parameter_value != round(parameter_value, parameter_rule.precision):
raise ValueError(
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places."
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision}"
f" decimal places."
)

# validate parameter value range

@ -51,7 +51,7 @@ if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
""" # noqa: E501


class AnthropicLargeLanguageModel(LargeLanguageModel):
@ -409,7 +409,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
),
)
elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else ""
chunk_text = chunk.delta.text or ""
full_assistant_content += chunk_text

# transform assistant message to prompt message

@ -213,7 +213,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
model=real_model,
prompt_messages=prompt_messages,
message=prompt_message,
usage=usage if usage else LLMUsage.empty_usage(),
usage=usage or LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint,
),
credentials=credentials,

@ -16,6 +16,15 @@ from core.model_runtime.entities.model_entities import (

AZURE_OPENAI_API_VERSION = "2024-02-15-preview"

AZURE_DEFAULT_PARAM_SEED_HELP = I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,"
"您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically,"
" such that repeated requests with the same seed and parameters should return the same result."
" Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter"
" to monitor changes in the backend.",
)


def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
rule = ParameterRule(
@ -229,10 +238,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -297,10 +303,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -365,10 +368,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -433,10 +433,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -502,10 +499,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -571,10 +565,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -650,10 +641,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -719,10 +707,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -788,10 +773,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -867,10 +849,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -936,10 +915,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
@ -1000,10 +976,7 @@ LLM_BASE_MODELS = [
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=I18nObject(
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
),
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,

@ -53,6 +53,12 @@ model_credential_schema:
type: select
required: true
options:
- label:
en_US: 2024-08-01-preview
value: 2024-08-01-preview
- label:
en_US: 2024-07-01-preview
value: 2024-07-01-preview
- label:
en_US: 2024-05-01-preview
value: 2024-05-01-preview

@ -225,7 +225,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
continue

# transform assistant message to prompt message
text = delta.text if delta.text else ""
text = delta.text or ""
assistant_prompt_message = AssistantPromptMessage(content=text)

full_text += text
@ -400,15 +400,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
continue

# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
)
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)

full_assistant_content += delta.delta.content if delta.delta.content else ""
full_assistant_content += delta.delta.content or ""

real_model = chunk.model
system_fingerprint = chunk.system_fingerprint
completion += delta.delta.content if delta.delta.content else ""
completion += delta.delta.content or ""

yield LLMResultChunk(
model=real_model,
@ -420,7 +418,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
),
)

index += 0
index += 1

# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)

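The `index += 0` line was a no-op that left the chunk counter stuck at its initial value, so every streamed chunk carried the same index; `index += 1` makes it advance per chunk. Illustrative:

    index = 0
    for _ in range(3):
        index += 1  # previously `index += 0`, which never advanced
    assert index == 3
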
@ -84,7 +84,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
)
for i in range(len(sentences))
]
for index, future in enumerate(futures):
for future in futures:
yield from future.result().__enter__().iter_bytes(1024)

else:

@ -33,7 +33,7 @@ parameter_rules:
- name: res_format
label:
zh_Hans: 回复格式
en_US: response format
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式

@ -33,7 +33,7 @@ parameter_rules:
- name: res_format
label:
zh_Hans: 回复格式
en_US: response format
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式

@ -33,7 +33,7 @@ parameter_rules:
- name: res_format
label:
zh_Hans: 回复格式
en_US: response format
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式

@ -15,6 +15,7 @@ class BaichuanTokenizer:

@classmethod
def _get_num_tokens(cls, text: str) -> int:
# tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return)
# tokens = number of Chinese characters + number of English words * 1.3
# (for estimation only, subject to actual return)
# https://platform.baichuan-ai.com/docs/text-Embedding
return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)

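The tokenizer comment describes the estimate the method implements: one token per Chinese character plus roughly 1.3 tokens per English word, truncated to an int. A worked check of the arithmetic (counts chosen for illustration):

    chinese_chars, english_words = 4, 3
    assert int(chinese_chars + english_words * 1.3) == 7  # 4 + 3.9 truncates to 7
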
@ -7,7 +7,7 @@ from requests import post
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
@ -45,7 +45,7 @@ class BaichuanModel:
|
||||
parameters: dict[str, Any],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> dict[str, Any]:
|
||||
if model in self._model_mapping.keys():
|
||||
if model in self._model_mapping:
|
||||
# the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
|
||||
# we need to rename it to res_format to get its value
|
||||
if parameters.get("res_format") == "json_object":
|
||||
@ -94,7 +94,7 @@ class BaichuanModel:
|
||||
timeout: int,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> Union[Iterator, dict]:
|
||||
if model in self._model_mapping.keys():
|
||||
if model in self._model_mapping:
|
||||
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
|
||||
else:
|
||||
raise BadRequestError(f"Unknown model: {model}")
|
||||
@ -124,7 +124,7 @@ class BaichuanModel:
|
||||
if err == "invalid_api_key":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
elif err == "insufficient_quota":
|
||||
raise InsufficientAccountBalance(msg)
|
||||
raise InsufficientAccountBalanceError(msg)
|
||||
elif err == "invalid_authentication":
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif err == "invalid_request_error":
|
||||
|
||||
@ -10,7 +10,7 @@ class RateLimitReachedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InsufficientAccountBalance(Exception):
|
||||
class InsufficientAccountBalanceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import B
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
@ -289,7 +289,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
|
||||
@ -19,7 +19,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
@ -109,7 +109,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
if err == "invalid_api_key":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
elif err == "insufficient_quota":
|
||||
raise InsufficientAccountBalance(msg)
|
||||
raise InsufficientAccountBalanceError(msg)
|
||||
elif err == "invalid_authentication":
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif err and "rate" in err:
|
||||
@ -166,7 +166,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
|
||||
@ -52,6 +52,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.00025'
|
||||
output: '0.00125'
|
||||
|
||||
@ -52,6 +52,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.015'
|
||||
output: '0.075'
|
||||
|
||||
@ -51,6 +51,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
|
||||
@ -51,6 +51,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
|
||||
@ -45,6 +45,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
|
||||
@ -45,6 +45,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
|
||||
@ -0,0 +1,59 @@
|
||||
model: eu.anthropic.claude-3-haiku-20240307-v1:0
|
||||
label:
|
||||
en_US: Claude 3 Haiku(Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
# docs: https://docs.anthropic.com/claude/docs/system-prompts
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.00025'
|
||||
output: '0.00125'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,58 @@
|
||||
model: eu.anthropic.claude-3-5-sonnet-20240620-v1:0
|
||||
label:
|
||||
en_US: Claude 3.5 Sonnet(Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,58 @@
model: eu.anthropic.claude-3-sonnet-20240229-v1:0
label:
  en_US: Claude 3 Sonnet(Cross Region Inference)
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
  - name: temperature
    use_template: temperature
    required: false
    type: float
    default: 1
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 生成内容的随机性。
      en_US: The amount of randomness injected into the response.
  - name: top_p
    required: false
    type: float
    default: 0.999
    min: 0.000
    max: 1.000
    help:
      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
  - name: top_k
    required: false
    type: int
    default: 0
    min: 0
    # note: the AWS docs are incorrect here; the actual max value is 500
    max: 500
    help:
      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.003'
  output: '0.015'
  unit: '0.001'
  currency: USD
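For context, parameter_rules blocks like the one above are what the model runtime validates user settings against. A minimal sketch of how such a rule could be applied, assuming PyYAML is available; clamp_parameter and the file name are illustrative, not Dify's actual API:

# a minimal sketch, assuming PyYAML; clamp_parameter is a hypothetical helper, not Dify's API
import yaml

def clamp_parameter(rules: list[dict], name: str, value: float) -> float:
    # look up the rule by name and clamp the value into [min, max]
    rule = next(r for r in rules if r["name"] == name)
    return max(rule["min"], min(rule["max"], value))

with open("eu.anthropic.claude-3-sonnet-20240229-v1.0.yaml") as f:  # assumed file name
    config = yaml.safe_load(f)

print(clamp_parameter(config["parameter_rules"], "top_k", 9999))  # -> 500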
@ -20,6 +20,7 @@ from botocore.exceptions import (
from PIL.Image import Image

# local import
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
@ -44,6 +45,14 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""  # noqa: E501


class BedrockLargeLanguageModel(LargeLanguageModel):
@ -52,6 +61,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
    CONVERSE_API_ENABLED_MODEL_INFO = [
        {"prefix": "anthropic.claude-v2", "support_system_prompts": True, "support_tool_use": False},
        {"prefix": "anthropic.claude-v1", "support_system_prompts": True, "support_tool_use": False},
        {"prefix": "us.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
        {"prefix": "eu.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
        {"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
        {"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False},
        {"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False},
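This list is consulted by prefix matching against the incoming model ID; the next hunk shows the fallback logging when nothing matches. A rough sketch of that lookup, with get_model_info as an assumed name rather than Dify's exact implementation:

# rough sketch of the prefix lookup; the function name is assumed
from typing import Optional

def get_model_info(model_id: str, model_info_list: list[dict]) -> Optional[dict]:
    for model_info in model_info_list:
        # the first matching prefix wins
        if model_id.startswith(model_info["prefix"]):
            return model_info
    return None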
@ -70,6 +81,40 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        logger.info(f"current model id: {model_id} is not supported by the Converse API")
        return None

    def _code_block_mode_wrapper(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: list[Callback] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Code block mode wrapper for invoking the large language model
        """
        if model_parameters.get("response_format"):
            stop = stop or []
            if "```\n" not in stop:
                stop.append("```\n")
            if "\n```" not in stop:
                stop.append("\n```")
            response_format = model_parameters.pop("response_format")
            format_prompt = SystemPromptMessage(
                content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
                    "{{block}}", response_format
                )
            )
            if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
                prompt_messages[0] = format_prompt
            else:
                prompt_messages.insert(0, format_prompt)
            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
        return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def _invoke(
        self,
        model: str,
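The wrapper relies on two tricks: prefilling the assistant turn with a "```{format}" opener so the model continues inside a code block, and registering the closing fence as a stop sequence so generation halts at the end of the block. A standalone illustration of the post-processing this enables, using a hypothetical completion string rather than Dify code:

# the assistant turn was prefilled with "\n```json", so the completion text is
# just the block body, and generation stopped at the fence registered as a stop sequence
import json

completion = '{"answer": "42"}\n'  # hypothetical model output
payload = json.loads(completion)  # -> {'answer': '42'}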
@ -288,10 +333,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
            elif "contentBlockDelta" in chunk:
                delta = chunk["contentBlockDelta"]["delta"]
                if "text" in delta:
                    chunk_text = delta["text"] if delta["text"] else ""
                    chunk_text = delta["text"] or ""
                    full_assistant_content += chunk_text
                    assistant_prompt_message = AssistantPromptMessage(
                        content=chunk_text if chunk_text else "",
                        content=chunk_text or "",
                    )
                    index = chunk["contentBlockDelta"]["contentBlockIndex"]
                    yield LLMResultChunk(
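This hunk (and several below) replaces the `x if x else ""` pattern with the shorter, equivalent `x or ""`; both return the fallback for any falsy value such as None or "". A two-line demonstration of the equivalence:

delta_text = None
assert (delta_text if delta_text else "") == (delta_text or "") == ""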
@ -498,7 +543,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                "max_tokens": 32,
            }
        elif "ai21" in model:
            # ValidationException: Malformed input request: #/temperature: expected type: Number, found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, please reformat your input and try again.
            # ValidationException: Malformed input request: #/temperature: expected type: Number,
            # found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null,
            # please reformat your input and try again.
            required_params = {
                "temperature": 0.7,
                "topP": 0.9,
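The quoted ValidationException explains why the defaults exist: ai21 models reject null temperature/topP/maxTokens, so required values are filled in before the request. A hedged sketch of the merge; the 0.7 and 0.9 values come from the hunk, while the maxTokens default is cut off there, so 512 below is purely illustrative:

# merge required defaults under any user-supplied, non-null parameters;
# maxTokens=512 is illustrative only, the real default is truncated in the hunk above
required_params = {"temperature": 0.7, "topP": 0.9, "maxTokens": 512}
user_params = {"temperature": None, "topP": 0.95}
merged = {**required_params, **{k: v for k, v in user_params.items() if v is not None}}
# -> {'temperature': 0.7, 'topP': 0.95, 'maxTokens': 512}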
@ -706,7 +753,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        elif model_prefix == "cohere":
            output = response_body.get("generations")[0].get("text")
            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
            completion_tokens = self.get_num_tokens(model, credentials, output if output else "")
            completion_tokens = self.get_num_tokens(model, credentials, output or "")

        else:
            raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@ -783,7 +830,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):

                # transform assistant message to prompt message
                assistant_prompt_message = AssistantPromptMessage(
                    content=content_delta if content_delta else "",
                    content=content_delta or "",
                )
                index += 1

@ -0,0 +1,59 @@
model: us.anthropic.claude-3-haiku-20240307-v1:0
label:
  en_US: Claude 3 Haiku(Cross Region Inference)
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
  # docs: https://docs.anthropic.com/claude/docs/system-prompts
  - name: temperature
    use_template: temperature
    required: false
    type: float
    default: 1
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 生成内容的随机性。
      en_US: The amount of randomness injected into the response.
  - name: top_p
    required: false
    type: float
    default: 0.999
    min: 0.000
    max: 1.000
    help:
      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
  - name: top_k
    required: false
    type: int
    default: 0
    min: 0
    # note: the AWS docs are incorrect here; the actual max value is 500
    max: 500
    help:
      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.00025'
  output: '0.00125'
  unit: '0.001'
  currency: USD
@ -0,0 +1,59 @@
model: us.anthropic.claude-3-opus-20240229-v1:0
label:
  en_US: Claude 3 Opus(Cross Region Inference)
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
  # docs: https://docs.anthropic.com/claude/docs/system-prompts
  - name: temperature
    use_template: temperature
    required: false
    type: float
    default: 1
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 生成内容的随机性。
      en_US: The amount of randomness injected into the response.
  - name: top_p
    required: false
    type: float
    default: 0.999
    min: 0.000
    max: 1.000
    help:
      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
  - name: top_k
    required: false
    type: int
    default: 0
    min: 0
    # note: the AWS docs are incorrect here; the actual max value is 500
    max: 500
    help:
      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.015'
  output: '0.075'
  unit: '0.001'
  currency: USD
@ -0,0 +1,58 @@
model: us.anthropic.claude-3-5-sonnet-20240620-v1:0
label:
  en_US: Claude 3.5 Sonnet(Cross Region Inference)
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
  - name: temperature
    use_template: temperature
    required: false
    type: float
    default: 1
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 生成内容的随机性。
      en_US: The amount of randomness injected into the response.
  - name: top_p
    required: false
    type: float
    default: 0.999
    min: 0.000
    max: 1.000
    help:
      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
  - name: top_k
    required: false
    type: int
    default: 0
    min: 0
    # note: the AWS docs are incorrect here; the actual max value is 500
    max: 500
    help:
      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.003'
  output: '0.015'
  unit: '0.001'
  currency: USD
@ -0,0 +1,58 @@
model: us.anthropic.claude-3-sonnet-20240229-v1:0
label:
  en_US: Claude 3 Sonnet(Cross Region Inference)
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
  - name: temperature
    use_template: temperature
    required: false
    type: float
    default: 1
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 生成内容的随机性。
      en_US: The amount of randomness injected into the response.
  - name: top_p
    required: false
    type: float
    default: 0.999
    min: 0.000
    max: 1.000
    help:
      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
  - name: top_k
    required: false
    type: int
    default: 0
    min: 0
    # note: the AWS docs are incorrect here; the actual max value is 500
    max: 500
    help:
      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.003'
  output: '0.015'
  unit: '0.001'
  currency: USD
@ -302,11 +302,11 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
                if delta.delta.function_call:
                    function_calls = [delta.delta.function_call]

                assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
                assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])

                # transform assistant message to prompt message
                assistant_prompt_message = AssistantPromptMessage(
                    content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
                    content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
                )

                if delta.finish_reason is not None:
@ -62,7 +62,7 @@ parameter_rules:
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
      en_US: Response Format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
@ -45,7 +45,7 @@ if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
"""  # noqa: E501


class GoogleLargeLanguageModel(LargeLanguageModel):
@ -337,9 +337,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, AssistantPromptMessage):
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, ToolPromptMessage):
        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
            message_text = f"{human_prompt} {content}"
        else:
            raise ValueError(f"Got unknown type {message}")
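The collapsed branches use PEP 604 union syntax, which isinstance accepts on Python 3.10 and later; `isinstance(x, A | B)` is equivalent to the tuple form `isinstance(x, (A, B))`:

# Python 3.10+: int | str is a types.UnionType that isinstance accepts directly
assert isinstance("hi", int | str)
assert isinstance("hi", (int, str))  # the pre-3.10 spelling of the same check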
@ -54,7 +54,8 @@ class TeiHelper:

        url = str(URL(server_url) / "info")

        # this method is surrounded by a lock, and default requests may hang forever, so we just set an Adapter with max_retries=3
        # this method is surrounded by a lock, and default requests may hang forever,
        # so we just set an Adapter with max_retries=3
        session = Session()
        session.mount("http://", HTTPAdapter(max_retries=3))
        session.mount("https://", HTTPAdapter(max_retries=3))
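Note that an integer max_retries on HTTPAdapter only retries failed connection attempts; it does not bound a request that connects and then hangs mid-read, so an explicit timeout is the usual companion. A small standalone sketch, with a placeholder URL:

from requests import Session
from requests.adapters import HTTPAdapter

session = Session()
adapter = HTTPAdapter(max_retries=3)  # retries connection failures up to 3 times
session.mount("http://", adapter)
session.mount("https://", adapter)
# placeholder URL; the (connect, read) timeout bounds hangs that max_retries alone does not
resp = session.get("http://localhost:8080/info", timeout=(3, 10))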
@ -131,7 +131,8 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
                {
                    "Role": message.role.value,
                    # fix: do not send content = "" for tool_call requests
                    # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time.
                    # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter
                    # message:Messages Content and Contents not allowed empty at the same time.
                    "Content": " ",  # message.content if (message.content is not None) else "",
                    "ToolCalls": dict_tool_calls,
                }
|
||||
delta = chunk.choices[0]
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
|
||||
assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])
|
||||
|
||||
if delta.finish_reason is not None:
|
||||
# temp_assistant_prompt_message is used to calculate usage
|
||||
@ -578,11 +578,11 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
if delta.delta.function_call:
|
||||
function_calls = [delta.delta.function_call]
|
||||
|
||||
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
|
||||
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
|
||||
content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
|
||||
)
|
||||
|
||||
if delta.finish_reason is not None:
|
||||
|
||||
@ -211,7 +211,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
||||
usage=usage,
|
||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
||||
finish_reason=message.stop_reason or None,
|
||||
),
|
||||
)
|
||||
elif message.function_call:
|
||||
@ -244,7 +244,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
||||
finish_reason=message.stop_reason or None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ parameter_rules:
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
|
||||
@ -24,7 +24,7 @@ parameter_rules:
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
|
||||
@ -24,7 +24,7 @@ parameter_rules:
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
|
||||
@ -93,7 +93,8 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):

    def _validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard.
        Validate model credentials using requests to ensure compatibility with all providers following
        OpenAI's API standard.

        :param model: model name
        :param credentials: model credentials
@ -239,7 +239,8 @@ class OCILargeLanguageModel(LargeLanguageModel):
        config_items = oci_config_content.split("/")
        if len(config_items) != 5:
            raise CredentialsValidateFailedError(
                "oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))"
                "oci_config_content should be base64.b64encode("
                "'user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))"
            )
        oci_config["user"] = config_items[0]
        oci_config["fingerprint"] = config_items[1]
@ -442,9 +443,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, AssistantPromptMessage):
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, ToolPromptMessage):
        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
            message_text = f"{human_prompt} {content}"
        else:
            raise ValueError(f"Got unknown type {message}")
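Mirroring the validation above, the credential is five '/'-joined fields, base64-encoded. A round-trip sketch with placeholder field values:

import base64

# placeholder values; real OCIDs come from the OCI console
parts = "user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid"
oci_config_content = base64.b64encode(parts.encode("utf-8")).decode("utf-8")

config_items = base64.b64decode(oci_config_content).decode("utf-8").split("/")
assert len(config_items) == 5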
Some files were not shown because too many files have changed in this diff.