mirror of
https://github.com/langgenius/dify.git
synced 2026-04-23 04:06:13 +08:00
merge feat/plugins
This commit is contained in:
@ -105,6 +105,7 @@ class LLMResult(BaseModel):
|
||||
Model class for llm result.
|
||||
"""
|
||||
|
||||
id: Optional[str] = None
|
||||
model: str
|
||||
prompt_messages: list[PromptMessage]
|
||||
message: AssistantPromptMessage
|
||||
|
||||
@ -0,0 +1,9 @@
|
||||
- claude-3-5-sonnet-20241022
|
||||
- claude-3-5-sonnet-20240620
|
||||
- claude-3-haiku-20240307
|
||||
- claude-3-opus-20240229
|
||||
- claude-3-sonnet-20240229
|
||||
- claude-2.1
|
||||
- claude-instant-1.2
|
||||
- claude-2
|
||||
- claude-instant-1
|
||||
@ -0,0 +1,39 @@
|
||||
model: claude-3-5-sonnet-20241022
|
||||
label:
|
||||
en_US: claude-3-5-sonnet-20241022
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '3.00'
|
||||
output: '15.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,245 @@
|
||||
provider: azure_openai
|
||||
label:
|
||||
en_US: Azure OpenAI Service Model
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
background: "#E3F0FF"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API key from Azure
|
||||
zh_Hans: 从 Azure 获取 API Key
|
||||
url:
|
||||
en_US: https://azure.microsoft.com/en-us/products/ai-services/openai-service
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- speech2text
|
||||
- tts
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Deployment Name
|
||||
zh_Hans: 部署名称
|
||||
placeholder:
|
||||
en_US: Enter your Deployment Name here, matching the Azure deployment name.
|
||||
zh_Hans: 在此输入您的部署名称,与 Azure 部署名称匹配。
|
||||
credential_form_schemas:
|
||||
- variable: openai_api_base
|
||||
label:
|
||||
en_US: API Endpoint URL
|
||||
zh_Hans: API 域名
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: '在此输入您的 API 域名,如:https://example.com/xxx'
|
||||
en_US: 'Enter your API Endpoint, eg: https://example.com/xxx'
|
||||
- variable: openai_api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
zh_Hans: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API key here
|
||||
- variable: openai_api_version
|
||||
label:
|
||||
zh_Hans: API 版本
|
||||
en_US: API Version
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- label:
|
||||
en_US: 2024-10-01-preview
|
||||
value: 2024-10-01-preview
|
||||
- label:
|
||||
en_US: 2024-09-01-preview
|
||||
value: 2024-09-01-preview
|
||||
- label:
|
||||
en_US: 2024-08-01-preview
|
||||
value: 2024-08-01-preview
|
||||
- label:
|
||||
en_US: 2024-07-01-preview
|
||||
value: 2024-07-01-preview
|
||||
- label:
|
||||
en_US: 2024-05-01-preview
|
||||
value: 2024-05-01-preview
|
||||
- label:
|
||||
en_US: 2024-04-01-preview
|
||||
value: 2024-04-01-preview
|
||||
- label:
|
||||
en_US: 2024-03-01-preview
|
||||
value: 2024-03-01-preview
|
||||
- label:
|
||||
en_US: 2024-02-15-preview
|
||||
value: 2024-02-15-preview
|
||||
- label:
|
||||
en_US: 2023-12-01-preview
|
||||
value: 2023-12-01-preview
|
||||
- label:
|
||||
en_US: '2024-02-01'
|
||||
value: '2024-02-01'
|
||||
- label:
|
||||
en_US: '2024-06-01'
|
||||
value: '2024-06-01'
|
||||
placeholder:
|
||||
zh_Hans: 在此选择您的 API 版本
|
||||
en_US: Select your API Version here
|
||||
- variable: base_model_name
|
||||
label:
|
||||
en_US: Base Model
|
||||
zh_Hans: 基础模型
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- label:
|
||||
en_US: gpt-35-turbo
|
||||
value: gpt-35-turbo
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-35-turbo-0125
|
||||
value: gpt-35-turbo-0125
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-35-turbo-16k
|
||||
value: gpt-35-turbo-16k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4
|
||||
value: gpt-4
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-32k
|
||||
value: gpt-4-32k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: o1-mini
|
||||
value: o1-mini
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: o1-preview
|
||||
value: o1-preview
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4o-mini
|
||||
value: gpt-4o-mini
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4o-mini-2024-07-18
|
||||
value: gpt-4o-mini-2024-07-18
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4o
|
||||
value: gpt-4o
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4o-2024-05-13
|
||||
value: gpt-4o-2024-05-13
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4o-2024-08-06
|
||||
value: gpt-4o-2024-08-06
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-turbo
|
||||
value: gpt-4-turbo
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-turbo-2024-04-09
|
||||
value: gpt-4-turbo-2024-04-09
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-0125-preview
|
||||
value: gpt-4-0125-preview
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-1106-preview
|
||||
value: gpt-4-1106-preview
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-vision-preview
|
||||
value: gpt-4-vision-preview
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-35-turbo-instruct
|
||||
value: gpt-35-turbo-instruct
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: text-embedding-ada-002
|
||||
value: text-embedding-ada-002
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: text-embedding
|
||||
- label:
|
||||
en_US: text-embedding-3-small
|
||||
value: text-embedding-3-small
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: text-embedding
|
||||
- label:
|
||||
en_US: text-embedding-3-large
|
||||
value: text-embedding-3-large
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: text-embedding
|
||||
- label:
|
||||
en_US: whisper-1
|
||||
value: whisper-1
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: speech2text
|
||||
- label:
|
||||
en_US: tts-1
|
||||
value: tts-1
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: tts
|
||||
- label:
|
||||
en_US: tts-1-hd
|
||||
value: tts-1-hd
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: tts
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型版本
|
||||
en_US: Enter your model version
|
||||
764
api/core/model_runtime/model_providers/azure_openai/llm/llm.py
Normal file
764
api/core/model_runtime/model_providers/azure_openai/llm/llm.py
Normal file
@ -0,0 +1,764 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import tiktoken
|
||||
from openai import AzureOpenAI, Stream
|
||||
from openai.types import Completion
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageFunction,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
||||
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
base_model_name = self._get_base_model_name(credentials)
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
|
||||
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
return self._chat_generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
# text completion model
|
||||
return self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
base_model_name = self._get_base_model_name(credentials)
|
||||
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
if not model_entity:
|
||||
raise ValueError(f"Base Model Name {base_model_name} is invalid")
|
||||
model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
|
||||
|
||||
if model_mode == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
return self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||
else:
|
||||
# text completion model, do not support tool calling
|
||||
content = prompt_messages[0].content
|
||||
assert isinstance(content, str)
|
||||
return self._num_tokens_from_string(credentials, content)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
if "openai_api_base" not in credentials:
|
||||
raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required")
|
||||
|
||||
if "openai_api_key" not in credentials:
|
||||
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
|
||||
|
||||
if "base_model_name" not in credentials:
|
||||
raise CredentialsValidateFailedError("Base Model Name is required")
|
||||
|
||||
base_model_name = self._get_base_model_name(credentials)
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
|
||||
if not ai_model_entity:
|
||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
||||
|
||||
try:
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
if model.startswith("o1"):
|
||||
client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
model=model,
|
||||
temperature=1,
|
||||
max_completion_tokens=20,
|
||||
stream=False,
|
||||
)
|
||||
elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
model=model,
|
||||
temperature=0,
|
||||
max_tokens=20,
|
||||
stream=False,
|
||||
)
|
||||
else:
|
||||
# text completion model
|
||||
client.completions.create(
|
||||
prompt="ping",
|
||||
model=model,
|
||||
temperature=0,
|
||||
max_tokens=20,
|
||||
stream=False,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
base_model_name = self._get_base_model_name(credentials)
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
return ai_model_entity.entity if ai_model_entity else None
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
extra_model_kwargs = {}
|
||||
|
||||
if stop:
|
||||
extra_model_kwargs["stop"] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
# text completion model
|
||||
response = client.completions.create(
|
||||
prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage]
|
||||
):
|
||||
assistant_text = response.choices[0].text
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
|
||||
|
||||
# calculate num tokens
|
||||
if response.usage:
|
||||
# transform usage
|
||||
prompt_tokens = response.usage.prompt_tokens
|
||||
completion_tokens = response.usage.completion_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
content = prompt_messages[0].content
|
||||
assert isinstance(content, str)
|
||||
prompt_tokens = self._num_tokens_from_string(credentials, content)
|
||||
completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
model=response.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
full_text = ""
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0]
|
||||
|
||||
if delta.finish_reason is None and (delta.text is None or delta.text == ""):
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
text = delta.text or ""
|
||||
assistant_prompt_message = AssistantPromptMessage(content=text)
|
||||
|
||||
full_text += text
|
||||
|
||||
if delta.finish_reason is not None:
|
||||
# calculate num tokens
|
||||
if chunk.usage:
|
||||
# transform usage
|
||||
prompt_tokens = chunk.usage.prompt_tokens
|
||||
completion_tokens = chunk.usage.completion_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
content = prompt_messages[0].content
|
||||
assert isinstance(content, str)
|
||||
prompt_tokens = self._num_tokens_from_string(credentials, content)
|
||||
completion_tokens = self._num_tokens_from_string(credentials, full_text)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=chunk.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=delta.finish_reason,
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=chunk.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
|
||||
def _chat_generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
response_format = model_parameters.get("response_format")
|
||||
if response_format:
|
||||
if response_format == "json_schema":
|
||||
json_schema = model_parameters.get("json_schema")
|
||||
if not json_schema:
|
||||
raise ValueError("Must define JSON Schema when the response format is json_schema")
|
||||
try:
|
||||
schema = json.loads(json_schema)
|
||||
except:
|
||||
raise ValueError(f"not correct json_schema format: {json_schema}")
|
||||
model_parameters.pop("json_schema")
|
||||
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
|
||||
else:
|
||||
model_parameters["response_format"] = {"type": response_format}
|
||||
|
||||
extra_model_kwargs = {}
|
||||
|
||||
if tools:
|
||||
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
|
||||
|
||||
if stop:
|
||||
extra_model_kwargs["stop"] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
# clear illegal prompt messages
|
||||
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
|
||||
|
||||
block_as_stream = False
|
||||
if model.startswith("o1"):
|
||||
if stream:
|
||||
block_as_stream = True
|
||||
stream = False
|
||||
|
||||
if "stream_options" in extra_model_kwargs:
|
||||
del extra_model_kwargs["stream_options"]
|
||||
|
||||
if "stop" in extra_model_kwargs:
|
||||
del extra_model_kwargs["stop"]
|
||||
|
||||
# chat model
|
||||
response = client.chat.completions.create(
|
||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
||||
model=model,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
if block_as_stream:
|
||||
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
|
||||
|
||||
return block_result
|
||||
|
||||
def _handle_chat_block_as_stream_response(
|
||||
self,
|
||||
block_result: LLMResult,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
text = block_result.message.content
|
||||
text = cast(str, text)
|
||||
|
||||
if stop:
|
||||
text = self.enforce_stop_tokens(text, stop)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=block_result.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=block_result.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
finish_reason="stop",
|
||||
usage=block_result.usage,
|
||||
),
|
||||
)
|
||||
|
||||
def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Clear illegal prompt messages for OpenAI API
|
||||
|
||||
:param model: model name
|
||||
:param prompt_messages: prompt messages
|
||||
:return: cleaned prompt messages
|
||||
"""
|
||||
checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"]
|
||||
|
||||
if model in checklist:
|
||||
# count how many user messages are there
|
||||
user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
|
||||
if user_message_count > 1:
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
item.data
|
||||
if item.type == PromptMessageContentType.TEXT
|
||||
else "[IMAGE]"
|
||||
if item.type == PromptMessageContentType.IMAGE
|
||||
else ""
|
||||
for item in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
if model.startswith("o1"):
|
||||
system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
|
||||
if system_message_count > 0:
|
||||
new_prompt_messages = []
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, SystemPromptMessage):
|
||||
prompt_message = UserPromptMessage(
|
||||
content=prompt_message.content,
|
||||
name=prompt_message.name,
|
||||
)
|
||||
|
||||
new_prompt_messages.append(prompt_message)
|
||||
prompt_messages = new_prompt_messages
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: ChatCompletion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
):
|
||||
assistant_message = response.choices[0].message
|
||||
assistant_message_tool_calls = assistant_message.tool_calls
|
||||
|
||||
# extract tool calls from response
|
||||
tool_calls = []
|
||||
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
|
||||
|
||||
# calculate num tokens
|
||||
if response.usage:
|
||||
# transform usage
|
||||
prompt_tokens = response.usage.prompt_tokens
|
||||
completion_tokens = response.usage.completion_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||
completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
model=response.model or model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
):
|
||||
index = 0
|
||||
full_assistant_content = ""
|
||||
real_model = model
|
||||
system_fingerprint = None
|
||||
completion = ""
|
||||
tool_calls = []
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0]
|
||||
# NOTE: For fix https://github.com/langgenius/dify/issues/5790
|
||||
if delta.delta is None:
|
||||
continue
|
||||
|
||||
# extract tool calls from response
|
||||
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
|
||||
|
||||
# Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
|
||||
if delta.finish_reason is None and not delta.delta.content:
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
|
||||
|
||||
full_assistant_content += delta.delta.content or ""
|
||||
|
||||
real_model = chunk.model
|
||||
system_fingerprint = chunk.system_fingerprint
|
||||
completion += delta.delta.content or ""
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=real_model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
|
||||
index += 1
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||
|
||||
full_assistant_prompt_message = AssistantPromptMessage(content=completion)
|
||||
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=real_model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _update_tool_calls(
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall],
|
||||
tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]],
|
||||
) -> None:
|
||||
if tool_calls_response:
|
||||
for response_tool_call in tool_calls_response:
|
||||
if isinstance(response_tool_call, ChatCompletionMessageToolCall):
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.id, type=response_tool_call.type, function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
elif isinstance(response_tool_call, ChoiceDeltaToolCall):
|
||||
index = response_tool_call.index
|
||||
if index < len(tool_calls):
|
||||
tool_calls[index].id = response_tool_call.id or tool_calls[index].id
|
||||
tool_calls[index].type = response_tool_call.type or tool_calls[index].type
|
||||
if response_tool_call.function:
|
||||
tool_calls[index].function.name = (
|
||||
response_tool_call.function.name or tool_calls[index].function.name
|
||||
)
|
||||
tool_calls[index].function.arguments += response_tool_call.function.arguments or ""
|
||||
else:
|
||||
assert response_tool_call.id is not None
|
||||
assert response_tool_call.type is not None
|
||||
assert response_tool_call.function is not None
|
||||
assert response_tool_call.function.name is not None
|
||||
assert response_tool_call.function.arguments is not None
|
||||
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
|
||||
)
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.id, type=response_tool_call.type, function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
@staticmethod
|
||||
def _convert_prompt_message_to_dict(message: PromptMessage):
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
assert message.content is not None
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
# message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {
|
||||
"role": "tool",
|
||||
"name": message.name,
|
||||
"content": message.content,
|
||||
"tool_call_id": message.tool_call_id,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name:
|
||||
message_dict["name"] = message.name
|
||||
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_string(
|
||||
self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
num_tokens = len(encoding.encode(text))
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
model = credentials["base_model_name"]
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
|
||||
if model.startswith("gpt-35-turbo-0301"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
# TODO: The current token calculation method for the image type is not implemented,
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
|
||||
value = text
|
||||
|
||||
if key == "tool_calls":
|
||||
for tool_call in value:
|
||||
assert isinstance(tool_call, dict)
|
||||
for t_key, t_value in tool_call.items():
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
num_tokens += len(encoding.encode(t_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(str(value)))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
@staticmethod
|
||||
def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int:
|
||||
num_tokens = 0
|
||||
for tool in tools:
|
||||
num_tokens += len(encoding.encode("type"))
|
||||
num_tokens += len(encoding.encode("function"))
|
||||
|
||||
# calculate num tokens for function object
|
||||
num_tokens += len(encoding.encode("name"))
|
||||
num_tokens += len(encoding.encode(tool.name))
|
||||
num_tokens += len(encoding.encode("description"))
|
||||
num_tokens += len(encoding.encode(tool.description))
|
||||
parameters = tool.parameters
|
||||
num_tokens += len(encoding.encode("parameters"))
|
||||
if "title" in parameters:
|
||||
num_tokens += len(encoding.encode("title"))
|
||||
num_tokens += len(encoding.encode(parameters["title"]))
|
||||
num_tokens += len(encoding.encode("type"))
|
||||
num_tokens += len(encoding.encode(parameters["type"]))
|
||||
if "properties" in parameters:
|
||||
num_tokens += len(encoding.encode("properties"))
|
||||
for key, value in parameters["properties"].items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == "enum":
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if "required" in parameters:
|
||||
num_tokens += len(encoding.encode("required"))
|
||||
for required_field in parameters["required"]:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
return num_tokens
|
||||
|
||||
@staticmethod
|
||||
def _get_ai_model_entity(base_model_name: str, model: str):
|
||||
for ai_model_entity in LLM_BASE_MODELS:
|
||||
if ai_model_entity.base_model_name == base_model_name:
|
||||
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
|
||||
ai_model_entity_copy.entity.model = model
|
||||
ai_model_entity_copy.entity.label.en_US = model
|
||||
ai_model_entity_copy.entity.label.zh_Hans = model
|
||||
return ai_model_entity_copy
|
||||
|
||||
def _get_base_model_name(self, credentials: dict) -> str:
|
||||
base_model_name = credentials.get("base_model_name")
|
||||
if not base_model_name:
|
||||
raise ValueError("Base Model Name is required")
|
||||
return base_model_name
|
||||
@ -0,0 +1,60 @@
|
||||
model: anthropic.claude-3-5-sonnet-20241022-v2:0
|
||||
label:
|
||||
en_US: Claude 3.5 Sonnet V2
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,60 @@
|
||||
model: eu.anthropic.claude-3-5-sonnet-20241022-v2:0
|
||||
label:
|
||||
en_US: Claude 3.5 Sonnet V2(EU.Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,60 @@
|
||||
model: us.anthropic.claude-3-5-sonnet-20241022-v2:0
|
||||
label:
|
||||
en_US: Claude 3.5 Sonnet V2(US.Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 9.8 KiB |
@ -0,0 +1,3 @@
|
||||
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M25.132 24.3947C25.497 25.7527 25.8984 27.1413 26.3334 28.5834C26.7302 29.8992 25.5459 30.4167 25.0752 29.1758C24.571 27.8466 24.0885 26.523 23.6347 25.1729C21.065 26.4654 18.5025 27.5424 15.5961 28.7541C16.7581 33.0256 17.8309 36.5984 19.4952 39.9935C19.4953 39.9936 19.4953 39.9937 19.4954 39.9938C19.6631 39.9979 19.8313 40 20 40C31.0457 40 40 31.0457 40 20C40 16.0335 38.8453 12.3366 36.8537 9.22729C31.6585 9.69534 27.0513 10.4562 22.8185 11.406C22.8882 12.252 22.9677 13.0739 23.0555 13.855C23.3824 16.7604 23.9112 19.5281 24.6137 22.3836C27.0581 21.2848 29.084 20.3225 30.6816 19.522C32.2154 18.7535 33.6943 18.7062 31.2018 20.6594C29.0388 22.1602 27.0644 23.3566 25.132 24.3947ZM36.1559 8.20846C33.0001 3.89184 28.1561 0.887462 22.5955 0.166882C22.4257 2.86234 22.4785 6.26344 22.681 9.50447C26.7473 8.88859 31.1721 8.46032 36.1559 8.20846ZM19.9369 9.73661e-05C19.7594 2.92694 19.8384 6.65663 20.19 9.91293C17.3748 10.4109 14.7225 11.0064 12.1592 11.7038C12.0486 10.4257 11.9927 9.25764 11.9927 8.24178C11.9927 7.5054 11.3957 6.90844 10.6593 6.90844C9.92296 6.90844 9.32601 7.5054 9.32601 8.24178C9.32601 9.47868 9.42873 10.898 9.61402 12.438C8.33567 12.8278 7.07397 13.2443 5.81918 13.688C5.12493 13.9336 4.76118 14.6954 5.0067 15.3896C5.25223 16.0839 6.01406 16.4476 6.7083 16.2021C7.7931 15.8185 8.88482 15.4388 9.98927 15.0659C10.5222 18.3344 11.3344 21.9428 12.2703 25.4156C12.4336 26.0218 12.6062 26.6262 12.7863 27.2263C9.34168 28.4135 5.82612 29.3782 2.61128 29.8879C0.949407 26.9716 0 23.5967 0 20C0 8.97534 8.92023 0.0341108 19.9369 9.73661e-05ZM4.19152 32.2527C7.45069 36.4516 12.3458 39.3173 17.9204 39.8932C16.5916 37.455 14.9338 33.717 13.5405 29.5901C10.4404 30.7762 7.25883 31.6027 4.19152 32.2527ZM22.9735 23.1135C22.1479 20.41 21.4462 17.5441 20.9225 14.277C20.746 13.5841 20.5918 12.8035 20.4593 11.9636C17.6508 12.6606 14.9992 13.4372 12.4356 14.2598C12.8479 17.4766 13.5448 21.1334 14.5118 24.7218C14.662 25.2792 14.8081 25.8248 14.9514 26.3594L14.9516 26.3603L14.9524 26.3634L14.9526 26.3639L14.973 26.4401C16.1833 25.9872 17.3746 25.5123 18.53 25.0259C20.1235 24.3552 21.6051 23.7165 22.9735 23.1135Z" fill="#141519"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.2 KiB |
47
api/core/model_runtime/model_providers/gitee_ai/_common.py
Normal file
47
api/core/model_runtime/model_providers/gitee_ai/_common.py
Normal file
@ -0,0 +1,47 @@
|
||||
from dashscope.common.error import (
|
||||
AuthenticationError,
|
||||
InvalidParameter,
|
||||
RequestFailure,
|
||||
ServiceUnavailableError,
|
||||
UnsupportedHTTPMethod,
|
||||
UnsupportedModel,
|
||||
)
|
||||
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
|
||||
|
||||
class _CommonGiteeAI:
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
RequestFailure,
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
ServiceUnavailableError,
|
||||
],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [
|
||||
AuthenticationError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
InvalidParameter,
|
||||
UnsupportedModel,
|
||||
UnsupportedHTTPMethod,
|
||||
],
|
||||
}
|
||||
25
api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py
Normal file
25
api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py
Normal file
@ -0,0 +1,25 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GiteeAIProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
@ -0,0 +1,35 @@
|
||||
provider: gitee_ai
|
||||
label:
|
||||
en_US: Gitee AI
|
||||
zh_Hans: Gitee AI
|
||||
description:
|
||||
en_US: 快速体验大模型,领先探索 AI 开源世界
|
||||
zh_Hans: 快速体验大模型,领先探索 AI 开源世界
|
||||
icon_small:
|
||||
en_US: Gitee-AI-Logo.svg
|
||||
icon_large:
|
||||
en_US: Gitee-AI-Logo-full.svg
|
||||
help:
|
||||
title:
|
||||
en_US: Get your token from Gitee AI
|
||||
zh_Hans: 从 Gitee AI 获取 token
|
||||
url:
|
||||
en_US: https://ai.gitee.com/dashboard/settings/tokens
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
- speech2text
|
||||
- tts
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
@ -0,0 +1,105 @@
|
||||
model: Qwen2-72B-Instruct
|
||||
label:
|
||||
zh_Hans: Qwen2-72B-Instruct
|
||||
en_US: Qwen2-72B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 6400
|
||||
parameter_rules:
|
||||
- name: stream
|
||||
use_template: boolean
|
||||
label:
|
||||
en_US: "Stream"
|
||||
zh_Hans: "流式"
|
||||
type: boolean
|
||||
default: true
|
||||
required: true
|
||||
help:
|
||||
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
|
||||
zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"
|
||||
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
label:
|
||||
en_US: "Max Tokens"
|
||||
zh_Hans: "最大Token数"
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
|
||||
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
|
||||
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_k
|
||||
use_template: top_k
|
||||
label:
|
||||
en_US: "Top K"
|
||||
zh_Hans: "Top K"
|
||||
type: int
|
||||
default: 50
|
||||
min: 0
|
||||
max: 100
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
|
||||
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: -1.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
|
||||
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
@ -0,0 +1,105 @@
|
||||
model: Qwen2-7B-Instruct
|
||||
label:
|
||||
zh_Hans: Qwen2-7B-Instruct
|
||||
en_US: Qwen2-7B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: stream
|
||||
use_template: boolean
|
||||
label:
|
||||
en_US: "Stream"
|
||||
zh_Hans: "流式"
|
||||
type: boolean
|
||||
default: true
|
||||
required: true
|
||||
help:
|
||||
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
|
||||
zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"
|
||||
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
label:
|
||||
en_US: "Max Tokens"
|
||||
zh_Hans: "最大Token数"
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
|
||||
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
|
||||
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_k
|
||||
use_template: top_k
|
||||
label:
|
||||
en_US: "Top K"
|
||||
zh_Hans: "Top K"
|
||||
type: int
|
||||
default: 50
|
||||
min: 0
|
||||
max: 100
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
|
||||
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: -1.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
|
||||
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
@ -0,0 +1,105 @@
|
||||
model: Yi-1.5-34B-Chat
|
||||
label:
|
||||
zh_Hans: Yi-1.5-34B-Chat
|
||||
en_US: Yi-1.5-34B-Chat
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 4096
|
||||
parameter_rules:
|
||||
- name: stream
|
||||
use_template: boolean
|
||||
label:
|
||||
en_US: "Stream"
|
||||
zh_Hans: "流式"
|
||||
type: boolean
|
||||
default: true
|
||||
required: true
|
||||
help:
|
||||
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
|
||||
zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"
|
||||
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
label:
|
||||
en_US: "Max Tokens"
|
||||
zh_Hans: "最大Token数"
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
|
||||
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
|
||||
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_k
|
||||
use_template: top_k
|
||||
label:
|
||||
en_US: "Top K"
|
||||
zh_Hans: "Top K"
|
||||
type: int
|
||||
default: 50
|
||||
min: 0
|
||||
max: 100
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
|
||||
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: -1.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
|
||||
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
@ -0,0 +1,7 @@
|
||||
- Qwen2-7B-Instruct
|
||||
- Qwen2-72B-Instruct
|
||||
- Yi-1.5-34B-Chat
|
||||
- glm-4-9b-chat
|
||||
- deepseek-coder-33B-instruct-chat
|
||||
- deepseek-coder-33B-instruct-completions
|
||||
- codegeex4-all-9b
|
||||
@ -0,0 +1,105 @@
|
||||
model: codegeex4-all-9b
|
||||
label:
|
||||
zh_Hans: codegeex4-all-9b
|
||||
en_US: codegeex4-all-9b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 40960
|
||||
parameter_rules:
|
||||
- name: stream
|
||||
use_template: boolean
|
||||
label:
|
||||
en_US: "Stream"
|
||||
zh_Hans: "流式"
|
||||
type: boolean
|
||||
default: true
|
||||
required: true
|
||||
help:
|
||||
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
|
||||
zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"
|
||||
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
label:
|
||||
en_US: "Max Tokens"
|
||||
zh_Hans: "最大Token数"
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
|
||||
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
|
||||
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_k
|
||||
use_template: top_k
|
||||
label:
|
||||
en_US: "Top K"
|
||||
zh_Hans: "Top K"
|
||||
type: int
|
||||
default: 50
|
||||
min: 0
|
||||
max: 100
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
|
||||
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: -1.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
|
||||
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
@ -0,0 +1,105 @@
|
||||
model: deepseek-coder-33B-instruct-chat
|
||||
label:
|
||||
zh_Hans: deepseek-coder-33B-instruct-chat
|
||||
en_US: deepseek-coder-33B-instruct-chat
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 9000
|
||||
parameter_rules:
|
||||
- name: stream
|
||||
use_template: boolean
|
||||
label:
|
||||
en_US: "Stream"
|
||||
zh_Hans: "流式"
|
||||
type: boolean
|
||||
default: true
|
||||
required: true
|
||||
help:
|
||||
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
|
||||
zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"
|
||||
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
label:
|
||||
en_US: "Max Tokens"
|
||||
zh_Hans: "最大Token数"
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
|
||||
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
|
||||
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_k
|
||||
use_template: top_k
|
||||
label:
|
||||
en_US: "Top K"
|
||||
zh_Hans: "Top K"
|
||||
type: int
|
||||
default: 50
|
||||
min: 0
|
||||
max: 100
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
|
||||
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: -1.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
|
||||
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
@ -0,0 +1,91 @@
|
||||
model: deepseek-coder-33B-instruct-completions
|
||||
label:
|
||||
zh_Hans: deepseek-coder-33B-instruct-completions
|
||||
en_US: deepseek-coder-33B-instruct-completions
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 9000
|
||||
parameter_rules:
|
||||
- name: stream
|
||||
use_template: boolean
|
||||
label:
|
||||
en_US: "Stream"
|
||||
zh_Hans: "流式"
|
||||
type: boolean
|
||||
default: true
|
||||
required: true
|
||||
help:
|
||||
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
|
||||
zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"
|
||||
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
label:
|
||||
en_US: "Max Tokens"
|
||||
zh_Hans: "最大Token数"
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
|
||||
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
|
||||
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: -1.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
|
||||
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
@ -0,0 +1,105 @@
|
||||
model: glm-4-9b-chat
|
||||
label:
|
||||
zh_Hans: glm-4-9b-chat
|
||||
en_US: glm-4-9b-chat
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: stream
|
||||
use_template: boolean
|
||||
label:
|
||||
en_US: "Stream"
|
||||
zh_Hans: "流式"
|
||||
type: boolean
|
||||
default: true
|
||||
required: true
|
||||
help:
|
||||
en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process."
|
||||
zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。"
|
||||
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
label:
|
||||
en_US: "Max Tokens"
|
||||
zh_Hans: "最大Token数"
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
|
||||
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
|
||||
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_k
|
||||
use_template: top_k
|
||||
label:
|
||||
en_US: "Top K"
|
||||
zh_Hans: "Top K"
|
||||
type: int
|
||||
default: 50
|
||||
min: 0
|
||||
max: 100
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
|
||||
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: -1.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
|
||||
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
47
api/core/model_runtime/model_providers/gitee_ai/llm/llm.py
Normal file
47
api/core/model_runtime/model_providers/gitee_ai/llm/llm.py
Normal file
@ -0,0 +1,47 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
||||
|
||||
class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
MODEL_TO_IDENTITY: dict[str, str] = {
|
||||
"Yi-1.5-34B-Chat": "Yi-34B-Chat",
|
||||
"deepseek-coder-33B-instruct-completions": "deepseek-coder-33B-instruct",
|
||||
"deepseek-coder-33B-instruct-chat": "deepseek-coder-33B-instruct",
|
||||
}
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials, model, model_parameters)
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials, model, None)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict, model: str, model_parameters: dict) -> None:
|
||||
if model is None:
|
||||
model = "bge-large-zh-v1.5"
|
||||
|
||||
model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model)
|
||||
credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/"
|
||||
if model.endswith("completions"):
|
||||
credentials["mode"] = LLMMode.COMPLETION.value
|
||||
else:
|
||||
credentials["mode"] = LLMMode.CHAT.value
|
||||
@ -0,0 +1 @@
|
||||
- bge-reranker-v2-m3
|
||||
@ -0,0 +1,4 @@
|
||||
model: bge-reranker-v2-m3
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 1024
|
||||
128
api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py
Normal file
128
api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py
Normal file
@ -0,0 +1,128 @@
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
|
||||
|
||||
class GiteeAIRerankModel(RerankModel):
|
||||
"""
|
||||
Model class for rerank model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n documents to return
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
if len(docs) == 0:
|
||||
return RerankResult(model=model, docs=[])
|
||||
|
||||
base_url = credentials.get("base_url", "https://ai.gitee.com/api/serverless")
|
||||
base_url = base_url.removesuffix("/")
|
||||
|
||||
try:
|
||||
body = {"model": model, "query": query, "documents": docs}
|
||||
if top_n is not None:
|
||||
body["top_n"] = top_n
|
||||
response = httpx.post(
|
||||
f"{base_url}/{model}/rerank",
|
||||
json=body,
|
||||
headers={"Authorization": f"Bearer {credentials.get('api_key')}"},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
|
||||
rerank_documents = []
|
||||
for result in results["results"]:
|
||||
rerank_document = RerankDocument(
|
||||
index=result["index"],
|
||||
text=result["document"]["text"],
|
||||
score=result["relevance_score"],
|
||||
)
|
||||
if score_threshold is None or result["relevance_score"] >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.01,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [httpx.ConnectError],
|
||||
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||
InvokeBadRequestError: [httpx.RequestError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.RERANK,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||
)
|
||||
|
||||
return entity
|
||||
@ -0,0 +1,2 @@
|
||||
- whisper-base
|
||||
- whisper-large
|
||||
@ -0,0 +1,53 @@
|
||||
import os
|
||||
from typing import IO, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI
|
||||
|
||||
|
||||
class GiteeAISpeech2TextModel(_CommonGiteeAI, Speech2TextModel):
|
||||
"""
|
||||
Model class for OpenAI Compatible Speech to text model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
"""
|
||||
Invoke speech2text model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
# doc: https://ai.gitee.com/docs/openapi/serverless#tag/serverless/POST/{service}/speech-to-text
|
||||
|
||||
endpoint_url = f"https://ai.gitee.com/api/serverless/{model}/speech-to-text"
|
||||
files = [("file", file)]
|
||||
_, file_ext = os.path.splitext(file.name)
|
||||
headers = {"Content-Type": f"audio/{file_ext}", "Authorization": f"Bearer {credentials.get('api_key')}"}
|
||||
response = requests.post(endpoint_url, headers=headers, files=files)
|
||||
if response.status_code != 200:
|
||||
raise InvokeBadRequestError(response.text)
|
||||
response_data = response.json()
|
||||
return response_data["text"]
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
audio_file_path = self._get_demo_file_path()
|
||||
|
||||
with open(audio_file_path, "rb") as audio_file:
|
||||
self._invoke(model, credentials, audio_file)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
@ -0,0 +1,5 @@
|
||||
model: whisper-base
|
||||
model_type: speech2text
|
||||
model_properties:
|
||||
file_upload_limit: 1
|
||||
supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm
|
||||
@ -0,0 +1,5 @@
|
||||
model: whisper-large
|
||||
model_type: speech2text
|
||||
model_properties:
|
||||
file_upload_limit: 1
|
||||
supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm
|
||||
@ -0,0 +1,3 @@
|
||||
- bge-large-zh-v1.5
|
||||
- bge-small-zh-v1.5
|
||||
- bge-m3
|
||||
@ -0,0 +1,8 @@
|
||||
model: bge-large-zh-v1.5
|
||||
label:
|
||||
zh_Hans: bge-large-zh-v1.5
|
||||
en_US: bge-large-zh-v1.5
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 200000
|
||||
max_chunks: 20
|
||||
@ -0,0 +1,8 @@
|
||||
model: bge-m3
|
||||
label:
|
||||
zh_Hans: bge-m3
|
||||
en_US: bge-m3
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 200000
|
||||
max_chunks: 20
|
||||
@ -0,0 +1,8 @@
|
||||
model: bge-small-zh-v1.5
|
||||
label:
|
||||
zh_Hans: bge-small-zh-v1.5
|
||||
en_US: bge-small-zh-v1.5
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 200000
|
||||
max_chunks: 20
|
||||
@ -0,0 +1,31 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
|
||||
OAICompatEmbeddingModel,
|
||||
)
|
||||
|
||||
|
||||
class GiteeAIEmbeddingModel(OAICompatEmbeddingModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
self._add_custom_parameters(credentials, model)
|
||||
return super()._invoke(model, credentials, texts, user, input_type)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials, None)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict, model: str) -> None:
|
||||
if model is None:
|
||||
model = "bge-m3"
|
||||
|
||||
credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model}/v1/"
|
||||
@ -0,0 +1,11 @@
|
||||
model: ChatTTS
|
||||
model_type: tts
|
||||
model_properties:
|
||||
default_voice: 'default'
|
||||
voices:
|
||||
- mode: 'default'
|
||||
name: 'Default'
|
||||
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
|
||||
word_limit: 3500
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
@ -0,0 +1,11 @@
|
||||
model: FunAudioLLM-CosyVoice-300M
|
||||
model_type: tts
|
||||
model_properties:
|
||||
default_voice: 'default'
|
||||
voices:
|
||||
- mode: 'default'
|
||||
name: 'Default'
|
||||
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
|
||||
word_limit: 3500
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
@ -0,0 +1,4 @@
|
||||
- speecht5_tts
|
||||
- ChatTTS
|
||||
- fish-speech-1.2-sft
|
||||
- FunAudioLLM-CosyVoice-300M
|
||||
@ -0,0 +1,11 @@
|
||||
model: fish-speech-1.2-sft
|
||||
model_type: tts
|
||||
model_properties:
|
||||
default_voice: 'default'
|
||||
voices:
|
||||
- mode: 'default'
|
||||
name: 'Default'
|
||||
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
|
||||
word_limit: 3500
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
@ -0,0 +1,11 @@
|
||||
model: speecht5_tts
|
||||
model_type: tts
|
||||
model_properties:
|
||||
default_voice: 'default'
|
||||
voices:
|
||||
- mode: 'default'
|
||||
name: 'Default'
|
||||
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
|
||||
word_limit: 3500
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
79
api/core/model_runtime/model_providers/gitee_ai/tts/tts.py
Normal file
79
api/core/model_runtime/model_providers/gitee_ai/tts/tts.py
Normal file
@ -0,0 +1,79 @@
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI
|
||||
|
||||
|
||||
class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
|
||||
"""
|
||||
Model class for OpenAI Speech to text model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
|
||||
) -> any:
|
||||
"""
|
||||
_invoke text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param tenant_id: user tenant id
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: model timbre
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
validate credentials text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
try:
|
||||
self._tts_invoke_streaming(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
content_text="Hello Dify!",
|
||||
voice=self._get_model_default_voice(model, credentials),
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
|
||||
"""
|
||||
_tts_invoke_streaming text2speech model
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: model timbre
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
try:
|
||||
# doc: https://ai.gitee.com/docs/openapi/serverless#tag/serverless/POST/{service}/text-to-speech
|
||||
endpoint_url = "https://ai.gitee.com/api/serverless/" + model + "/text-to-speech"
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
api_key = credentials.get("api_key")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
payload = {"inputs": content_text}
|
||||
response = requests.post(endpoint_url, headers=headers, json=payload)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise InvokeBadRequestError(response.text)
|
||||
|
||||
data = response.content
|
||||
|
||||
for i in range(0, len(data), 1024):
|
||||
yield data[i : i + 1024]
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
450
api/core/model_runtime/model_providers/google/llm/llm.py
Normal file
450
api/core/model_runtime/model_providers/google/llm/llm.py
Normal file
@ -0,0 +1,450 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import google.ai.generativelanguage as glm
|
||||
import google.generativeai as genai
|
||||
import requests
|
||||
from google.api_core import exceptions
|
||||
from google.generativeai.client import _ClientManager
|
||||
from google.generativeai.types import ContentType, GenerateContentResponse
|
||||
from google.generativeai.types.content_types import to_part
|
||||
from PIL import Image
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
||||
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
||||
if you are not sure about the structure.
|
||||
|
||||
<instructions>
|
||||
{{instructions}}
|
||||
</instructions>
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# invoke model
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:md = genai.GenerativeModel(model)
|
||||
"""
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
||||
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Google model
|
||||
|
||||
:param messages: List of PromptMessage to combine.
|
||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||
"""
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
text = "".join(self._convert_one_message_to_text(message) for message in messages)
|
||||
|
||||
return text.rstrip()
|
||||
|
||||
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
|
||||
"""
|
||||
Convert tool messages to glm tools
|
||||
|
||||
:param tools: tool messages
|
||||
:return: glm tools
|
||||
"""
|
||||
function_declarations = []
|
||||
for tool in tools:
|
||||
properties = {}
|
||||
for key, value in tool.parameters.get("properties", {}).items():
|
||||
properties[key] = {
|
||||
"type_": glm.Type.STRING,
|
||||
"description": value.get("description", ""),
|
||||
"enum": value.get("enum", []),
|
||||
}
|
||||
|
||||
if properties:
|
||||
parameters = glm.Schema(
|
||||
type=glm.Type.OBJECT,
|
||||
properties=properties,
|
||||
required=tool.parameters.get("required", []),
|
||||
)
|
||||
else:
|
||||
parameters = None
|
||||
|
||||
function_declaration = glm.FunctionDeclaration(
|
||||
name=tool.name,
|
||||
parameters=parameters,
|
||||
description=tool.description,
|
||||
)
|
||||
function_declarations.append(function_declaration)
|
||||
|
||||
return glm.Tool(function_declarations=function_declarations)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
|
||||
try:
|
||||
ping_message = SystemPromptMessage(content="ping")
|
||||
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
|
||||
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials kwargs
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
config_kwargs = model_parameters.copy()
|
||||
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
|
||||
|
||||
if stop:
|
||||
config_kwargs["stop_sequences"] = stop
|
||||
|
||||
google_model = genai.GenerativeModel(model_name=model)
|
||||
|
||||
history = []
|
||||
|
||||
# hack for gemini-pro-vision, which currently does not support multi-turn chat
|
||||
if model == "gemini-pro-vision":
|
||||
last_msg = prompt_messages[-1]
|
||||
content = self._format_message_to_glm_content(last_msg)
|
||||
history.append(content)
|
||||
else:
|
||||
for msg in prompt_messages: # makes message roles strictly alternating
|
||||
content = self._format_message_to_glm_content(msg)
|
||||
if history and history[-1]["role"] == content["role"]:
|
||||
history[-1]["parts"].extend(content["parts"])
|
||||
else:
|
||||
history.append(content)
|
||||
|
||||
# Create a new ClientManager with tenant's API key
|
||||
new_client_manager = _ClientManager()
|
||||
new_client_manager.configure(api_key=credentials["google_api_key"])
|
||||
new_custom_client = new_client_manager.make_client("generative")
|
||||
|
||||
google_model._client = new_custom_client
|
||||
|
||||
response = google_model.generate_content(
|
||||
contents=history,
|
||||
generation_config=genai.types.GenerationConfig(**config_kwargs),
|
||||
stream=stream,
|
||||
tools=self._convert_tools_to_glm_tool(tools) if tools else None,
|
||||
request_options={"timeout": 600},
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response
|
||||
"""
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=response.text)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator result
|
||||
"""
|
||||
index = -1
|
||||
for chunk in response:
|
||||
for part in chunk.parts:
|
||||
assistant_prompt_message = AssistantPromptMessage(content="")
|
||||
|
||||
if part.text:
|
||||
assistant_prompt_message.content += part.text
|
||||
|
||||
if part.function_call:
|
||||
assistant_prompt_message.tool_calls = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=part.function_call.name,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=part.function_call.name,
|
||||
arguments=json.dumps(dict(part.function_call.args.items())),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
index += 1
|
||||
|
||||
if not response._done:
|
||||
# transform assistant message to prompt message
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
|
||||
)
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=str(chunk.candidates[0].finish_reason),
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
|
||||
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
|
||||
"""
|
||||
Convert a single message to a string.
|
||||
|
||||
:param message: PromptMessage to convert.
|
||||
:return: String representation of the message.
|
||||
"""
|
||||
human_prompt = "\n\nuser:"
|
||||
ai_prompt = "\n\nmodel:"
|
||||
|
||||
content = message.content
|
||||
if isinstance(content, list):
|
||||
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_text
|
||||
|
||||
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
|
||||
"""
|
||||
Format a single message into glm.Content for Google API
|
||||
|
||||
:param message: one PromptMessage
|
||||
:return: glm Content representation of message
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
glm_content = {"role": "user", "parts": []}
|
||||
if isinstance(message.content, str):
|
||||
glm_content["parts"].append(to_part(message.content))
|
||||
else:
|
||||
for c in message.content:
|
||||
if c.type == PromptMessageContentType.TEXT:
|
||||
glm_content["parts"].append(to_part(c.data))
|
||||
elif c.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, c)
|
||||
if message_content.data.startswith("data:"):
|
||||
metadata, base64_data = c.data.split(",", 1)
|
||||
mime_type = metadata.split(";", 1)[0].split(":")[1]
|
||||
else:
|
||||
# fetch image data from url
|
||||
try:
|
||||
image_content = requests.get(message_content.data).content
|
||||
with Image.open(io.BytesIO(image_content)) as img:
|
||||
mime_type = f"image/{img.format.lower()}"
|
||||
base64_data = base64.b64encode(image_content).decode("utf-8")
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
||||
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
|
||||
glm_content["parts"].append(blob)
|
||||
|
||||
return glm_content
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
glm_content = {"role": "model", "parts": []}
|
||||
if message.content:
|
||||
glm_content["parts"].append(to_part(message.content))
|
||||
if message.tool_calls:
|
||||
glm_content["parts"].append(
|
||||
to_part(
|
||||
glm.FunctionCall(
|
||||
name=message.tool_calls[0].function.name,
|
||||
args=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
)
|
||||
)
|
||||
return glm_content
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
return {"role": "user", "parts": [to_part(message.content)]}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
return {
|
||||
"role": "function",
|
||||
"parts": [
|
||||
glm.Part(
|
||||
function_response=glm.FunctionResponse(
|
||||
name=message.name, response={"response": message.content}
|
||||
)
|
||||
)
|
||||
],
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [exceptions.RetryError],
|
||||
InvokeServerUnavailableError: [
|
||||
exceptions.ServiceUnavailable,
|
||||
exceptions.InternalServerError,
|
||||
exceptions.BadGateway,
|
||||
exceptions.GatewayTimeout,
|
||||
exceptions.DeadlineExceeded,
|
||||
],
|
||||
InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests],
|
||||
InvokeAuthorizationError: [
|
||||
exceptions.Unauthenticated,
|
||||
exceptions.PermissionDenied,
|
||||
exceptions.Unauthenticated,
|
||||
exceptions.Forbidden,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
exceptions.BadRequest,
|
||||
exceptions.InvalidArgument,
|
||||
exceptions.FailedPrecondition,
|
||||
exceptions.OutOfRange,
|
||||
exceptions.NotFound,
|
||||
exceptions.MethodNotAllowed,
|
||||
exceptions.Conflict,
|
||||
exceptions.AlreadyExists,
|
||||
exceptions.Aborted,
|
||||
exceptions.LengthRequired,
|
||||
exceptions.PreconditionFailed,
|
||||
exceptions.RequestRangeNotSatisfiable,
|
||||
exceptions.Cancelled,
|
||||
],
|
||||
}
|
||||
330
api/core/model_runtime/model_providers/moonshot/llm/llm.py
Normal file
330
api/core/model_runtime/model_providers/moonshot/llm/llm.py
Normal file
@ -0,0 +1,330 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
||||
|
||||
class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
self._add_function_call(model, credentials)
|
||||
user = user[:32] if user else None
|
||||
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
|
||||
if "response_format" in model_parameters:
|
||||
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model, zh_Hans=model),
|
||||
model_type=ModelType.LLM,
|
||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
||||
if credentials.get("function_calling_type") == "tool_call"
|
||||
else [],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
use_template="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="max_tokens",
|
||||
use_template="max_tokens",
|
||||
default=512,
|
||||
min=1,
|
||||
max=int(credentials.get("max_tokens", 4096)),
|
||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
||||
type=ParameterType.INT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
use_template="top_p",
|
||||
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def _add_custom_parameters(self, credentials: dict) -> None:
|
||||
credentials["mode"] = "chat"
|
||||
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
||||
credentials["endpoint_url"] = "https://api.moonshot.cn/v1"
|
||||
|
||||
def _add_function_call(self, model: str, credentials: dict) -> None:
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection(
|
||||
model_schema.features or []
|
||||
):
|
||||
credentials["function_calling_type"] = "tool_call"
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for OpenAI API format
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(PromptMessageContent, message_content)
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
message_dict["tool_calls"] = []
|
||||
for function_call in message.tool_calls:
|
||||
message_dict["tool_calls"].append(
|
||||
{
|
||||
"id": function_call.id,
|
||||
"type": function_call.type,
|
||||
"function": {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments,
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name:
|
||||
message_dict["name"] = message.name
|
||||
|
||||
return message_dict
|
||||
|
||||
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
:param response_tool_calls: response tool calls
|
||||
:return: list of tool calls
|
||||
"""
|
||||
tool_calls = []
|
||||
if response_tool_calls:
|
||||
for response_tool_call in response_tool_calls:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call["function"]["name"]
|
||||
if response_tool_call.get("function", {}).get("name")
|
||||
else "",
|
||||
arguments=response_tool_call["function"]["arguments"]
|
||||
if response_tool_call.get("function", {}).get("arguments")
|
||||
else "",
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call["id"] if response_tool_call.get("id") else "",
|
||||
type=response_tool_call["type"] if response_tool_call.get("type") else "",
|
||||
function=function,
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param response: streamed response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
full_assistant_content = ""
|
||||
chunk_index = 0
|
||||
|
||||
def create_final_llm_result_chunk(
|
||||
index: int, message: AssistantPromptMessage, finish_reason: str
|
||||
) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
return LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
|
||||
)
|
||||
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
finish_reason = "Unknown"
|
||||
|
||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
||||
def get_tool_call(tool_name: str):
|
||||
if not tool_name:
|
||||
return tools_calls[-1]
|
||||
|
||||
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id="",
|
||||
type="",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
|
||||
)
|
||||
tools_calls.append(tool_call)
|
||||
|
||||
return tool_call
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# get tool call
|
||||
tool_call = get_tool_call(new_tool_call.function.name)
|
||||
# update tool call
|
||||
if new_tool_call.id:
|
||||
tool_call.id = new_tool_call.id
|
||||
if new_tool_call.type:
|
||||
tool_call.type = new_tool_call.type
|
||||
if new_tool_call.function.name:
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
|
||||
if chunk:
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
||||
chunk_json = None
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
# stream ended
|
||||
except json.JSONDecodeError as e:
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index + 1,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason="Non-JSON encountered.",
|
||||
)
|
||||
break
|
||||
if not chunk_json or len(chunk_json["choices"]) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk_json["choices"][0]
|
||||
finish_reason = chunk_json["choices"][0].get("finish_reason")
|
||||
chunk_index += 1
|
||||
|
||||
if "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
delta_content = delta.get("content")
|
||||
|
||||
assistant_message_tool_calls = delta.get("tool_calls", None)
|
||||
# assistant_message_function_call = delta.delta.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
if assistant_message_tool_calls:
|
||||
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
increase_tool_call(tool_calls)
|
||||
|
||||
if delta_content is None or delta_content == "":
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else []
|
||||
)
|
||||
|
||||
full_assistant_content += delta_content
|
||||
elif "text" in choice:
|
||||
choice_text = choice.get("text", "")
|
||||
if choice_text == "":
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
|
||||
full_assistant_content += choice_text
|
||||
else:
|
||||
continue
|
||||
|
||||
# check payload indicator for completion
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
if tools_calls:
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
|
||||
),
|
||||
)
|
||||
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason
|
||||
)
|
||||
@ -0,0 +1,847 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Union, cast
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageFunction,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
DefaultParameterName,
|
||||
FetchFrom,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
PriceConfig,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
"""
|
||||
Model class for OpenAI large language model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
# text completion model
|
||||
return self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model:
|
||||
:param credentials:
|
||||
:param prompt_messages:
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials using requests to ensure compatibility with all providers following
|
||||
OpenAI's API standard.
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
api_key = credentials.get("api_key")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
endpoint_url = credentials["endpoint_url"]
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
|
||||
# prepare the payload for a simple ping to the model
|
||||
data = {"model": model, "max_tokens": 5}
|
||||
|
||||
completion_type = LLMMode.value_of(credentials["mode"])
|
||||
|
||||
if completion_type is LLMMode.CHAT:
|
||||
data["messages"] = [
|
||||
{"role": "user", "content": "ping"},
|
||||
]
|
||||
endpoint_url = urljoin(endpoint_url, "chat/completions")
|
||||
elif completion_type is LLMMode.COMPLETION:
|
||||
data["prompt"] = "ping"
|
||||
endpoint_url = urljoin(endpoint_url, "completions")
|
||||
else:
|
||||
raise ValueError("Unsupported completion type for model configuration.")
|
||||
|
||||
# send a post request to validate the credentials
|
||||
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise CredentialsValidateFailedError(
|
||||
f"Credentials validation failed with status code {response.status_code}"
|
||||
)
|
||||
|
||||
try:
|
||||
json_result = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
|
||||
|
||||
if completion_type is LLMMode.CHAT and json_result.get("object", "") == "":
|
||||
json_result["object"] = "chat.completion"
|
||||
elif completion_type is LLMMode.COMPLETION and json_result.get("object", "") == "":
|
||||
json_result["object"] = "text_completion"
|
||||
|
||||
if completion_type is LLMMode.CHAT and (
|
||||
"object" not in json_result or json_result["object"] != "chat.completion"
|
||||
):
|
||||
raise CredentialsValidateFailedError(
|
||||
"Credentials validation failed: invalid response object, must be 'chat.completion'"
|
||||
)
|
||||
elif completion_type is LLMMode.COMPLETION and (
|
||||
"object" not in json_result or json_result["object"] != "text_completion"
|
||||
):
|
||||
raise CredentialsValidateFailedError(
|
||||
"Credentials validation failed: invalid response object, must be 'text_completion'"
|
||||
)
|
||||
except CredentialsValidateFailedError:
|
||||
raise
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}")
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
features = []
|
||||
|
||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
||||
if function_calling_type == "function_call":
|
||||
features.append(ModelFeature.TOOL_CALL)
|
||||
elif function_calling_type == "tool_call":
|
||||
features.append(ModelFeature.MULTI_TOOL_CALL)
|
||||
|
||||
stream_function_calling = credentials.get("stream_function_calling", "supported")
|
||||
if stream_function_calling == "supported":
|
||||
features.append(ModelFeature.STREAM_TOOL_CALL)
|
||||
|
||||
vision_support = credentials.get("vision_support", "not_support")
|
||||
if vision_support == "support":
|
||||
features.append(ModelFeature.VISION)
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
features=features,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")),
|
||||
ModelPropertyKey.MODE: credentials.get("mode"),
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TEMPERATURE.value,
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
help=I18nObject(
|
||||
en_US="Kernel sampling threshold. Used to determine the randomness of the results."
|
||||
"The higher the value, the stronger the randomness."
|
||||
"The higher the possibility of getting different answers to the same question.",
|
||||
zh_Hans="核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。",
|
||||
),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("temperature", 0.7)),
|
||||
min=0,
|
||||
max=2,
|
||||
precision=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TOP_P.value,
|
||||
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
||||
help=I18nObject(
|
||||
en_US="The probability threshold of the nucleus sampling method during the generation process."
|
||||
"The larger the value is, the higher the randomness of generation will be."
|
||||
"The smaller the value is, the higher the certainty of generation will be.",
|
||||
zh_Hans="生成过程中核采样方法概率阈值。取值越大,生成的随机性越高;取值越小,生成的确定性越高。",
|
||||
),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("top_p", 1)),
|
||||
min=0,
|
||||
max=1,
|
||||
precision=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.FREQUENCY_PENALTY.value,
|
||||
label=I18nObject(en_US="Frequency Penalty", zh_Hans="频率惩罚"),
|
||||
help=I18nObject(
|
||||
en_US="For controlling the repetition rate of words used by the model."
|
||||
"Increasing this can reduce the repetition of the same words in the model's output.",
|
||||
zh_Hans="用于控制模型已使用字词的重复率。 提高此项可以降低模型在输出中重复相同字词的重复度。",
|
||||
),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("frequency_penalty", 0)),
|
||||
min=-2,
|
||||
max=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.PRESENCE_PENALTY.value,
|
||||
label=I18nObject(en_US="Presence Penalty", zh_Hans="存在惩罚"),
|
||||
help=I18nObject(
|
||||
en_US="Used to control the repetition rate when generating models."
|
||||
"Increasing this can reduce the repetition rate of model generation.",
|
||||
zh_Hans="用于控制模型生成时的重复度。提高此项可以降低模型生成的重复度。",
|
||||
),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("presence_penalty", 0)),
|
||||
min=-2,
|
||||
max=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.MAX_TOKENS.value,
|
||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
||||
help=I18nObject(
|
||||
en_US="Maximum length of tokens for the model response.", zh_Hans="模型回答的tokens的最大长度。"
|
||||
),
|
||||
type=ParameterType.INT,
|
||||
default=512,
|
||||
min=1,
|
||||
max=int(credentials.get("max_tokens_to_sample", 4096)),
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get("input_price", 0)),
|
||||
output=Decimal(credentials.get("output_price", 0)),
|
||||
unit=Decimal(credentials.get("unit", 0)),
|
||||
currency=credentials.get("currency", "USD"),
|
||||
),
|
||||
)
|
||||
|
||||
if credentials["mode"] == "chat":
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
|
||||
elif credentials["mode"] == "completion":
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
|
||||
|
||||
return entity
|
||||
|
||||
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers
|
||||
# following OpenAI's API standard.
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke llm completion model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept-Charset": "utf-8",
|
||||
}
|
||||
extra_headers = credentials.get("extra_headers")
|
||||
if extra_headers is not None:
|
||||
headers = {
|
||||
**headers,
|
||||
**extra_headers,
|
||||
}
|
||||
|
||||
api_key = credentials.get("api_key")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
endpoint_url = credentials["endpoint_url"]
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
|
||||
data = {"model": model, "stream": stream, **model_parameters}
|
||||
|
||||
completion_type = LLMMode.value_of(credentials["mode"])
|
||||
|
||||
if completion_type is LLMMode.CHAT:
|
||||
endpoint_url = urljoin(endpoint_url, "chat/completions")
|
||||
data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
|
||||
elif completion_type is LLMMode.COMPLETION:
|
||||
endpoint_url = urljoin(endpoint_url, "completions")
|
||||
data["prompt"] = prompt_messages[0].content
|
||||
else:
|
||||
raise ValueError("Unsupported completion type for model configuration.")
|
||||
|
||||
# annotate tools with names, descriptions, etc.
|
||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
||||
formatted_tools = []
|
||||
if tools:
|
||||
if function_calling_type == "function_call":
|
||||
data["functions"] = [
|
||||
{"name": tool.name, "description": tool.description, "parameters": tool.parameters}
|
||||
for tool in tools
|
||||
]
|
||||
elif function_calling_type == "tool_call":
|
||||
data["tool_choice"] = "auto"
|
||||
|
||||
for tool in tools:
|
||||
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
|
||||
|
||||
data["tools"] = formatted_tools
|
||||
|
||||
if stop:
|
||||
data["stop"] = stop
|
||||
|
||||
if user:
|
||||
data["user"] = user
|
||||
|
||||
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream)
|
||||
|
||||
if response.encoding is None or response.encoding == "ISO-8859-1":
|
||||
response.encoding = "utf-8"
|
||||
|
||||
if response.status_code != 200:
|
||||
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param response: streamed response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
full_assistant_content = ""
|
||||
chunk_index = 0
|
||||
|
||||
def create_final_llm_result_chunk(
|
||||
id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
|
||||
) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = usage and usage.get("prompt_tokens")
|
||||
if prompt_tokens is None:
|
||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||
completion_tokens = usage and usage.get("completion_tokens")
|
||||
if completion_tokens is None:
|
||||
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
return LLMResultChunk(
|
||||
id=id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
|
||||
)
|
||||
|
||||
# delimiter for stream response, need unicode_escape
|
||||
import codecs
|
||||
|
||||
delimiter = credentials.get("stream_mode_delimiter", "\n\n")
|
||||
delimiter = codecs.decode(delimiter, "unicode_escape")
|
||||
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
|
||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
||||
def get_tool_call(tool_call_id: str):
|
||||
if not tool_call_id:
|
||||
return tools_calls[-1]
|
||||
|
||||
tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||
)
|
||||
tools_calls.append(tool_call)
|
||||
|
||||
return tool_call
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# get tool call
|
||||
tool_call = get_tool_call(new_tool_call.function.name)
|
||||
# update tool call
|
||||
if new_tool_call.id:
|
||||
tool_call.id = new_tool_call.id
|
||||
if new_tool_call.type:
|
||||
tool_call.type = new_tool_call.type
|
||||
if new_tool_call.function.name:
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
finish_reason = None # The default value of finish_reason is None
|
||||
message_id, usage = None, None
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
||||
chunk = chunk.strip()
|
||||
if chunk:
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
||||
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_json: dict = json.loads(decoded_chunk)
|
||||
# stream ended
|
||||
except json.JSONDecodeError as e:
|
||||
yield create_final_llm_result_chunk(
|
||||
id=message_id,
|
||||
index=chunk_index + 1,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason="Non-JSON encountered.",
|
||||
usage=usage,
|
||||
)
|
||||
break
|
||||
if chunk_json:
|
||||
if u := chunk_json.get("usage"):
|
||||
usage = u
|
||||
if not chunk_json or len(chunk_json["choices"]) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk_json["choices"][0]
|
||||
finish_reason = chunk_json["choices"][0].get("finish_reason")
|
||||
message_id = chunk_json.get("id")
|
||||
chunk_index += 1
|
||||
|
||||
if "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
delta_content = delta.get("content")
|
||||
|
||||
assistant_message_tool_calls = None
|
||||
|
||||
if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call":
|
||||
assistant_message_tool_calls = delta.get("tool_calls", None)
|
||||
elif (
|
||||
"function_call" in delta
|
||||
and credentials.get("function_calling_type", "no_call") == "function_call"
|
||||
):
|
||||
assistant_message_tool_calls = [
|
||||
{"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
|
||||
]
|
||||
|
||||
# assistant_message_function_call = delta.delta.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
if assistant_message_tool_calls:
|
||||
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
increase_tool_call(tool_calls)
|
||||
|
||||
if delta_content is None or delta_content == "":
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta_content,
|
||||
)
|
||||
|
||||
# reset tool calls
|
||||
tool_calls = []
|
||||
full_assistant_content += delta_content
|
||||
elif "text" in choice:
|
||||
choice_text = choice.get("text", "")
|
||||
if choice_text == "":
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
|
||||
full_assistant_content += choice_text
|
||||
else:
|
||||
continue
|
||||
|
||||
yield LLMResultChunk(
|
||||
id=message_id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
if tools_calls:
|
||||
yield LLMResultChunk(
|
||||
id=message_id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
|
||||
),
|
||||
)
|
||||
|
||||
yield create_final_llm_result_chunk(
|
||||
id=message_id,
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
response_json: dict = response.json()
|
||||
|
||||
completion_type = LLMMode.value_of(credentials["mode"])
|
||||
|
||||
output = response_json["choices"][0]
|
||||
message_id = response_json.get("id")
|
||||
|
||||
response_content = ""
|
||||
tool_calls = None
|
||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
||||
if completion_type is LLMMode.CHAT:
|
||||
response_content = output.get("message", {})["content"]
|
||||
if function_calling_type == "tool_call":
|
||||
tool_calls = output.get("message", {}).get("tool_calls")
|
||||
elif function_calling_type == "function_call":
|
||||
tool_calls = output.get("message", {}).get("function_call")
|
||||
|
||||
elif completion_type is LLMMode.COMPLETION:
|
||||
response_content = output["text"]
|
||||
|
||||
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])
|
||||
|
||||
if tool_calls:
|
||||
if function_calling_type == "tool_call":
|
||||
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
|
||||
elif function_calling_type == "function_call":
|
||||
assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
|
||||
|
||||
usage = response_json.get("usage")
|
||||
if usage:
|
||||
# transform usage
|
||||
prompt_tokens = usage["prompt_tokens"]
|
||||
completion_tokens = usage["completion_tokens"]
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||
completion_tokens = self._num_tokens_from_string(model, assistant_message.content)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
id=message_id,
|
||||
model=response_json["model"],
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_message,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for OpenAI API format
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(PromptMessageContent, message_content)
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
||||
if function_calling_type == "tool_call":
|
||||
message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls]
|
||||
elif function_calling_type == "function_call":
|
||||
function_call = message.tool_calls[0]
|
||||
message_dict["function_call"] = {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments,
|
||||
}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
||||
if function_calling_type == "tool_call":
|
||||
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
|
||||
elif function_calling_type == "function_call":
|
||||
message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name and message_dict.get("role", "") != "tool":
|
||||
message_dict["name"] = message.name
|
||||
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_string(
|
||||
self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Approximate num tokens for model with gpt2 tokenizer.
|
||||
|
||||
:param model: model name
|
||||
:param text: prompt text
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
if isinstance(text, str):
|
||||
full_text = text
|
||||
else:
|
||||
full_text = ""
|
||||
for message_content in text:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(PromptMessageContent, message_content)
|
||||
full_text += message_content.data
|
||||
|
||||
num_tokens = self._get_num_tokens_by_gpt2(full_text)
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
credentials: Optional[dict] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Approximate num tokens with GPT2 tokenizer.
|
||||
"""
|
||||
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m, credentials) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
# TODO: The current token calculation method for the image type is not implemented,
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
|
||||
value = text
|
||||
|
||||
if key == "tool_calls":
|
||||
for tool_call in value:
|
||||
for t_key, t_value in tool_call.items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(t_key)
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(f_key)
|
||||
num_tokens += self._get_num_tokens_by_gpt2(f_value)
|
||||
else:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(t_key)
|
||||
num_tokens += self._get_num_tokens_by_gpt2(t_value)
|
||||
else:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(str(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
|
||||
"""
|
||||
Calculate num tokens for tool calling with tiktoken package.
|
||||
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
num_tokens = 0
|
||||
for tool in tools:
|
||||
num_tokens += self._get_num_tokens_by_gpt2("type")
|
||||
num_tokens += self._get_num_tokens_by_gpt2("function")
|
||||
num_tokens += self._get_num_tokens_by_gpt2("function")
|
||||
|
||||
# calculate num tokens for function object
|
||||
num_tokens += self._get_num_tokens_by_gpt2("name")
|
||||
num_tokens += self._get_num_tokens_by_gpt2(tool.name)
|
||||
num_tokens += self._get_num_tokens_by_gpt2("description")
|
||||
num_tokens += self._get_num_tokens_by_gpt2(tool.description)
|
||||
parameters = tool.parameters
|
||||
num_tokens += self._get_num_tokens_by_gpt2("parameters")
|
||||
if "title" in parameters:
|
||||
num_tokens += self._get_num_tokens_by_gpt2("title")
|
||||
num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title"))
|
||||
num_tokens += self._get_num_tokens_by_gpt2("type")
|
||||
num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type"))
|
||||
if "properties" in parameters:
|
||||
num_tokens += self._get_num_tokens_by_gpt2("properties")
|
||||
for key, value in parameters.get("properties").items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(key)
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(field_key)
|
||||
if field_key == "enum":
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += self._get_num_tokens_by_gpt2(enum_field)
|
||||
else:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(field_key)
|
||||
num_tokens += self._get_num_tokens_by_gpt2(str(field_value))
|
||||
if "required" in parameters:
|
||||
num_tokens += self._get_num_tokens_by_gpt2("required")
|
||||
for required_field in parameters["required"]:
|
||||
num_tokens += 3
|
||||
num_tokens += self._get_num_tokens_by_gpt2(required_field)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
:param response_tool_calls: response tool calls
|
||||
:return: list of tool calls
|
||||
"""
|
||||
tool_calls = []
|
||||
if response_tool_calls:
|
||||
for response_tool_call in response_tool_calls:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.get("function", {}).get("name", ""),
|
||||
arguments=response_tool_call.get("function", {}).get("arguments", ""),
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.get("id", ""), type=response_tool_call.get("type", ""), function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _extract_response_function_call(self, response_function_call) -> AssistantPromptMessage.ToolCall:
|
||||
"""
|
||||
Extract function call from response
|
||||
|
||||
:param response_function_call: response function call
|
||||
:return: tool call
|
||||
"""
|
||||
tool_call = None
|
||||
if response_function_call:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_function_call.get("name", ""), arguments=response_function_call.get("arguments", "")
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_function_call.get("id", ""), type="function", function=function
|
||||
)
|
||||
|
||||
return tool_call
|
||||
@ -0,0 +1,55 @@
|
||||
model: claude-3-5-sonnet-v2@20241022
|
||||
label:
|
||||
en_US: Claude 3.5 Sonnet v2
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
Reference in New Issue
Block a user