chore(api/core): apply ruff reformatting (#7624)

Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@ -18,12 +18,21 @@ class Callback:
Base class for callbacks.
Only for LLM.
"""
raise_error: bool = False
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
def on_before_invoke(
self,
llm_instance: AIModel,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
"""
Before invoke callback
@ -39,10 +48,19 @@ class Callback:
"""
raise NotImplementedError()
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None):
def on_new_chunk(
self,
llm_instance: AIModel,
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
"""
On new chunk callback
@ -59,10 +77,19 @@ class Callback:
"""
raise NotImplementedError()
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
def on_after_invoke(
self,
llm_instance: AIModel,
result: LLMResult,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
"""
After invoke callback
@ -79,10 +106,19 @@ class Callback:
"""
raise NotImplementedError()
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
def on_invoke_error(
self,
llm_instance: AIModel,
ex: Exception,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
"""
Invoke error callback
@ -99,9 +135,7 @@ class Callback:
"""
raise NotImplementedError()
def print_text(
self, text: str, color: Optional[str] = None, end: str = ""
) -> None:
def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None:
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(text_to_print, end=end)
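
For reference, a consumer of this interface subclasses Callback and overrides the four hooks. A minimal sketch (the callback module path is assumed; ChunkCountingCallback is a hypothetical example, not part of this commit):

from core.model_runtime.callbacks.base_callback import Callback  # module path assumed

class ChunkCountingCallback(Callback):
    # hypothetical subclass: counts streamed chunks per invocation
    def __init__(self) -> None:
        self.chunks = 0

    def on_before_invoke(self, llm_instance, model, credentials, prompt_messages, model_parameters,
                         tools=None, stop=None, stream=True, user=None) -> None:
        self.chunks = 0  # reset per invocation

    def on_new_chunk(self, llm_instance, chunk, model, credentials, prompt_messages, model_parameters,
                     tools=None, stop=None, stream=True, user=None):
        self.chunks += 1

    def on_after_invoke(self, llm_instance, result, model, credentials, prompt_messages, model_parameters,
                        tools=None, stop=None, stream=True, user=None) -> None:
        self.print_text(f"received {self.chunks} chunks\n", color="blue")

    def on_invoke_error(self, llm_instance, ex, model, credentials, prompt_messages, model_parameters,
                        tools=None, stop=None, stream=True, user=None) -> None:
        self.print_text(f"failed after {self.chunks} chunks: {ex}\n", color="red")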

View File

@ -10,11 +10,20 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class LoggingCallback(Callback):
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
def on_before_invoke(
self,
llm_instance: AIModel,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
"""
Before invoke callback
@ -28,40 +37,49 @@ class LoggingCallback(Callback):
:param stream: is stream response
:param user: unique user id
"""
self.print_text("\n[on_llm_before_invoke]\n", color='blue')
self.print_text(f"Model: {model}\n", color='blue')
self.print_text("Parameters:\n", color='blue')
self.print_text("\n[on_llm_before_invoke]\n", color="blue")
self.print_text(f"Model: {model}\n", color="blue")
self.print_text("Parameters:\n", color="blue")
for key, value in model_parameters.items():
self.print_text(f"\t{key}: {value}\n", color='blue')
self.print_text(f"\t{key}: {value}\n", color="blue")
if stop:
self.print_text(f"\tstop: {stop}\n", color='blue')
self.print_text(f"\tstop: {stop}\n", color="blue")
if tools:
self.print_text("\tTools:\n", color='blue')
self.print_text("\tTools:\n", color="blue")
for tool in tools:
self.print_text(f"\t\t{tool.name}\n", color='blue')
self.print_text(f"\t\t{tool.name}\n", color="blue")
self.print_text(f"Stream: {stream}\n", color='blue')
self.print_text(f"Stream: {stream}\n", color="blue")
if user:
self.print_text(f"User: {user}\n", color='blue')
self.print_text(f"User: {user}\n", color="blue")
self.print_text("Prompt messages:\n", color='blue')
self.print_text("Prompt messages:\n", color="blue")
for prompt_message in prompt_messages:
if prompt_message.name:
self.print_text(f"\tname: {prompt_message.name}\n", color='blue')
self.print_text(f"\tname: {prompt_message.name}\n", color="blue")
self.print_text(f"\trole: {prompt_message.role.value}\n", color='blue')
self.print_text(f"\tcontent: {prompt_message.content}\n", color='blue')
self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue")
self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue")
if stream:
self.print_text("\n[on_llm_new_chunk]")
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None):
def on_new_chunk(
self,
llm_instance: AIModel,
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
"""
On new chunk callback
@ -79,10 +97,19 @@ class LoggingCallback(Callback):
sys.stdout.write(chunk.delta.message.content)
sys.stdout.flush()
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
def on_after_invoke(
self,
llm_instance: AIModel,
result: LLMResult,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
"""
After invoke callback
@ -97,24 +124,33 @@ class LoggingCallback(Callback):
:param stream: is stream response
:param user: unique user id
"""
self.print_text("\n[on_llm_after_invoke]\n", color='yellow')
self.print_text(f"Content: {result.message.content}\n", color='yellow')
self.print_text("\n[on_llm_after_invoke]\n", color="yellow")
self.print_text(f"Content: {result.message.content}\n", color="yellow")
if result.message.tool_calls:
self.print_text("Tool calls:\n", color='yellow')
self.print_text("Tool calls:\n", color="yellow")
for tool_call in result.message.tool_calls:
self.print_text(f"\t{tool_call.id}\n", color='yellow')
self.print_text(f"\t{tool_call.function.name}\n", color='yellow')
self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color='yellow')
self.print_text(f"\t{tool_call.id}\n", color="yellow")
self.print_text(f"\t{tool_call.function.name}\n", color="yellow")
self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow")
self.print_text(f"Model: {result.model}\n", color='yellow')
self.print_text(f"Usage: {result.usage}\n", color='yellow')
self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color='yellow')
self.print_text(f"Model: {result.model}\n", color="yellow")
self.print_text(f"Usage: {result.usage}\n", color="yellow")
self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow")
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
def on_invoke_error(
self,
llm_instance: AIModel,
ex: Exception,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
"""
Invoke error callback
@ -129,5 +165,5 @@ class LoggingCallback(Callback):
:param stream: is stream response
:param user: unique user id
"""
self.print_text("\n[on_llm_invoke_error]\n", color='red')
self.print_text("\n[on_llm_invoke_error]\n", color="red")
logger.exception(ex)
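
As the LargeLanguageModel.invoke changes later in this commit show, this logger is appended automatically when the DEBUG environment variable is "true"; it can also be passed explicitly. A sketch:

import os

os.environ["DEBUG"] = "true"  # invoke() will append LoggingCallback() by itself

# or opt in per call (other invoke arguments omitted for brevity):
# llm.invoke(..., callbacks=[LoggingCallback()])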

View File

@ -7,6 +7,7 @@ class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
zh_Hans: Optional[str] = None
en_US: str

View File

@ -2,123 +2,123 @@ from core.model_runtime.entities.model_entities import DefaultParameterName
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
DefaultParameterName.TEMPERATURE: {
'label': {
'en_US': 'Temperature',
'zh_Hans': '温度',
"label": {
"en_US": "Temperature",
"zh_Hans": "温度",
},
'type': 'float',
'help': {
'en_US': 'Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.',
'zh_Hans': '温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。',
"type": "float",
"help": {
"en_US": "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.",
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。",
},
'required': False,
'default': 0.0,
'min': 0.0,
'max': 1.0,
'precision': 2,
"required": False,
"default": 0.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.TOP_P: {
'label': {
'en_US': 'Top P',
'zh_Hans': 'Top P',
"label": {
"en_US": "Top P",
"zh_Hans": "Top P",
},
'type': 'float',
'help': {
'en_US': 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.',
'zh_Hans': '通过核心采样控制多样性0.5表示考虑了一半的所有可能性加权选项。',
"type": "float",
"help": {
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.",
"zh_Hans": "通过核心采样控制多样性0.5表示考虑了一半的所有可能性加权选项。",
},
'required': False,
'default': 1.0,
'min': 0.0,
'max': 1.0,
'precision': 2,
"required": False,
"default": 1.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.TOP_K: {
'label': {
'en_US': 'Top K',
'zh_Hans': 'Top K',
"label": {
"en_US": "Top K",
"zh_Hans": "Top K",
},
'type': 'int',
'help': {
'en_US': 'Limits the number of tokens to consider for each step by keeping only the k most likely tokens.',
'zh_Hans': '通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。',
"type": "int",
"help": {
"en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.",
"zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。",
},
'required': False,
'default': 50,
'min': 1,
'max': 100,
'precision': 0,
"required": False,
"default": 50,
"min": 1,
"max": 100,
"precision": 0,
},
DefaultParameterName.PRESENCE_PENALTY: {
'label': {
'en_US': 'Presence Penalty',
'zh_Hans': '存在惩罚',
"label": {
"en_US": "Presence Penalty",
"zh_Hans": "存在惩罚",
},
'type': 'float',
'help': {
'en_US': 'Applies a penalty to the log-probability of tokens already in the text.',
'zh_Hans': '对文本中已有的标记的对数概率施加惩罚。',
"type": "float",
"help": {
"en_US": "Applies a penalty to the log-probability of tokens already in the text.",
"zh_Hans": "对文本中已有的标记的对数概率施加惩罚。",
},
'required': False,
'default': 0.0,
'min': 0.0,
'max': 1.0,
'precision': 2,
"required": False,
"default": 0.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.FREQUENCY_PENALTY: {
'label': {
'en_US': 'Frequency Penalty',
'zh_Hans': '频率惩罚',
"label": {
"en_US": "Frequency Penalty",
"zh_Hans": "频率惩罚",
},
'type': 'float',
'help': {
'en_US': 'Applies a penalty to the log-probability of tokens that appear in the text.',
'zh_Hans': '对文本中出现的标记的对数概率施加惩罚。',
"type": "float",
"help": {
"en_US": "Applies a penalty to the log-probability of tokens that appear in the text.",
"zh_Hans": "对文本中出现的标记的对数概率施加惩罚。",
},
'required': False,
'default': 0.0,
'min': 0.0,
'max': 1.0,
'precision': 2,
"required": False,
"default": 0.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.MAX_TOKENS: {
'label': {
'en_US': 'Max Tokens',
'zh_Hans': '最大标记',
"label": {
"en_US": "Max Tokens",
"zh_Hans": "最大标记",
},
'type': 'int',
'help': {
'en_US': 'Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.',
'zh_Hans': '指定生成结果长度的上限。如果生成结果截断,可以调大该参数。',
"type": "int",
"help": {
"en_US": "Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.",
"zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
},
'required': False,
'default': 64,
'min': 1,
'max': 2048,
'precision': 0,
"required": False,
"default": 64,
"min": 1,
"max": 2048,
"precision": 0,
},
DefaultParameterName.RESPONSE_FORMAT: {
'label': {
'en_US': 'Response Format',
'zh_Hans': '回复格式',
"label": {
"en_US": "Response Format",
"zh_Hans": "回复格式",
},
'type': 'string',
'help': {
'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.',
'zh_Hans': '设置一个返回格式确保llm的输出尽可能是有效的代码块如JSON、XML等',
"type": "string",
"help": {
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.",
"zh_Hans": "设置一个返回格式确保llm的输出尽可能是有效的代码块如JSON、XML等",
},
'required': False,
'options': ['JSON', 'XML'],
"required": False,
"options": ["JSON", "XML"],
},
DefaultParameterName.JSON_SCHEMA: {
'label': {
'en_US': 'JSON Schema',
"label": {
"en_US": "JSON Schema",
},
'type': 'text',
'help': {
'en_US': 'Set a response json schema will ensure LLM to adhere it.',
'zh_Hans': '设置返回的json schemallm将按照它返回',
"type": "text",
"help": {
"en_US": "Set a response json schema will ensure LLM to adhere it.",
"zh_Hans": "设置返回的json schemallm将按照它返回",
},
'required': False,
"required": False,
},
}
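
These templates are consumed by AIModel.get_predefined_model_schemas (changed later in this commit): a YAML parameter_rule carrying use_template is overlaid on the matching template, with the YAML values winning. A minimal sketch of that merge (the defaults module path is assumed):

from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE  # module path assumed
from core.model_runtime.entities.model_entities import DefaultParameterName

yaml_rule = {"name": "temperature", "use_template": "temperature", "max": 2.0}  # illustrative

template = PARAMETER_RULE_TEMPLATE[DefaultParameterName.value_of(yaml_rule["use_template"])]
merged = template.copy()
merged.update(yaml_rule)

print(merged["max"])  # 2.0, overridden by the YAML rule
print(merged["min"])  # 0.0, inherited from the template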

View File

@ -12,11 +12,12 @@ class LLMMode(Enum):
"""
Enum class for large language model mode.
"""
COMPLETION = "completion"
CHAT = "chat"
@classmethod
def value_of(cls, value: str) -> 'LLMMode':
def value_of(cls, value: str) -> "LLMMode":
"""
Get value of given mode.
@ -26,13 +27,14 @@ class LLMMode(Enum):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
raise ValueError(f"invalid mode value {value}")
class LLMUsage(ModelUsage):
"""
Model class for llm usage.
"""
prompt_tokens: int
prompt_unit_price: Decimal
prompt_price_unit: Decimal
@ -50,20 +52,20 @@ class LLMUsage(ModelUsage):
def empty_usage(cls):
return cls(
prompt_tokens=0,
prompt_unit_price=Decimal('0.0'),
prompt_price_unit=Decimal('0.0'),
prompt_price=Decimal('0.0'),
prompt_unit_price=Decimal("0.0"),
prompt_price_unit=Decimal("0.0"),
prompt_price=Decimal("0.0"),
completion_tokens=0,
completion_unit_price=Decimal('0.0'),
completion_price_unit=Decimal('0.0'),
completion_price=Decimal('0.0'),
completion_unit_price=Decimal("0.0"),
completion_price_unit=Decimal("0.0"),
completion_price=Decimal("0.0"),
total_tokens=0,
total_price=Decimal('0.0'),
currency='USD',
latency=0.0
total_price=Decimal("0.0"),
currency="USD",
latency=0.0,
)
def plus(self, other: 'LLMUsage') -> 'LLMUsage':
def plus(self, other: "LLMUsage") -> "LLMUsage":
"""
Add two LLMUsage instances together.
@ -85,10 +87,10 @@ class LLMUsage(ModelUsage):
total_tokens=self.total_tokens + other.total_tokens,
total_price=self.total_price + other.total_price,
currency=other.currency,
latency=self.latency + other.latency
latency=self.latency + other.latency,
)
def __add__(self, other: 'LLMUsage') -> 'LLMUsage':
def __add__(self, other: "LLMUsage") -> "LLMUsage":
"""
Overload the + operator to add two LLMUsage instances.
@ -97,10 +99,12 @@ class LLMUsage(ModelUsage):
"""
return self.plus(other)
class LLMResult(BaseModel):
"""
Model class for llm result.
"""
model: str
prompt_messages: list[PromptMessage]
message: AssistantPromptMessage
@ -112,6 +116,7 @@ class LLMResultChunkDelta(BaseModel):
"""
Model class for llm result chunk delta.
"""
index: int
message: AssistantPromptMessage
usage: Optional[LLMUsage] = None
@ -122,6 +127,7 @@ class LLMResultChunk(BaseModel):
"""
Model class for llm result chunk.
"""
model: str
prompt_messages: list[PromptMessage]
system_fingerprint: Optional[str] = None
@ -132,4 +138,5 @@ class NumTokensResult(PriceInfo):
"""
Model class for number of tokens result.
"""
tokens: int
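
The plus/__add__ pair above makes usage accumulation a plain sum, which is how streaming code can fold per-chunk usage into a running total. A usage sketch:

from core.model_runtime.entities.llm_entities import LLMUsage

total = LLMUsage.empty_usage()
for chunk_usage in (LLMUsage.empty_usage(), LLMUsage.empty_usage()):  # stand-ins for real per-chunk usage
    total = total + chunk_usage  # delegates to plus()

print(total.total_tokens, total.currency)  # 0 USD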

View File

@ -9,13 +9,14 @@ class PromptMessageRole(Enum):
"""
Enum class for prompt message.
"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
@classmethod
def value_of(cls, value: str) -> 'PromptMessageRole':
def value_of(cls, value: str) -> "PromptMessageRole":
"""
Get value of given mode.
@ -25,13 +26,14 @@ class PromptMessageRole(Enum):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid prompt message type value {value}')
raise ValueError(f"invalid prompt message type value {value}")
class PromptMessageTool(BaseModel):
"""
Model class for prompt message tool.
"""
name: str
description: str
parameters: dict
@ -41,7 +43,8 @@ class PromptMessageFunction(BaseModel):
"""
Model class for prompt message function.
"""
type: str = 'function'
type: str = "function"
function: PromptMessageTool
@ -49,14 +52,16 @@ class PromptMessageContentType(Enum):
"""
Enum class for prompt message content type.
"""
TEXT = 'text'
IMAGE = 'image'
TEXT = "text"
IMAGE = "image"
class PromptMessageContent(BaseModel):
"""
Model class for prompt message content.
"""
type: PromptMessageContentType
data: str
@ -65,6 +70,7 @@ class TextPromptMessageContent(PromptMessageContent):
"""
Model class for text prompt message content.
"""
type: PromptMessageContentType = PromptMessageContentType.TEXT
@ -72,9 +78,10 @@ class ImagePromptMessageContent(PromptMessageContent):
"""
Model class for image prompt message content.
"""
class DETAIL(Enum):
LOW = 'low'
HIGH = 'high'
LOW = "low"
HIGH = "high"
type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW
@ -84,6 +91,7 @@ class PromptMessage(ABC, BaseModel):
"""
Model class for prompt message.
"""
role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None
name: Optional[str] = None
@ -101,6 +109,7 @@ class UserPromptMessage(PromptMessage):
"""
Model class for user prompt message.
"""
role: PromptMessageRole = PromptMessageRole.USER
@ -108,14 +117,17 @@ class AssistantPromptMessage(PromptMessage):
"""
Model class for assistant prompt message.
"""
class ToolCall(BaseModel):
"""
Model class for assistant prompt message tool call.
"""
class ToolCallFunction(BaseModel):
"""
Model class for assistant prompt message tool call function.
"""
name: str
arguments: str
@ -123,7 +135,7 @@ class AssistantPromptMessage(PromptMessage):
type: str
function: ToolCallFunction
@field_validator('id', mode='before')
@field_validator("id", mode="before")
@classmethod
def transform_id_to_str(cls, value) -> str:
if not isinstance(value, str):
@ -145,10 +157,12 @@ class AssistantPromptMessage(PromptMessage):
return True
class SystemPromptMessage(PromptMessage):
"""
Model class for system prompt message.
"""
role: PromptMessageRole = PromptMessageRole.SYSTEM
@ -156,6 +170,7 @@ class ToolPromptMessage(PromptMessage):
"""
Model class for tool prompt message.
"""
role: PromptMessageRole = PromptMessageRole.TOOL
tool_call_id: str
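
Assembled together, the entities above describe a full multimodal conversation. A sketch (field values are illustrative):

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)

messages = [
    SystemPromptMessage(content="You are a helpful assistant."),
    UserPromptMessage(
        content=[
            TextPromptMessageContent(data="What is in this image?"),
            ImagePromptMessageContent(
                data="data:image/png;base64,AAAA",  # placeholder payload
                detail=ImagePromptMessageContent.DETAIL.HIGH,
            ),
        ]
    ),
    AssistantPromptMessage(content="It looks like a diagram."),
]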

View File

@ -11,6 +11,7 @@ class ModelType(Enum):
"""
Enum class for model type.
"""
LLM = "llm"
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
@ -26,22 +27,22 @@ class ModelType(Enum):
:return: model type
"""
if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value:
if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value:
return cls.LLM
elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value:
elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value:
return cls.TEXT_EMBEDDING
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value:
return cls.RERANK
elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value:
return cls.SPEECH2TEXT
elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
elif origin_model_type == "tts" or origin_model_type == cls.TTS.value:
return cls.TTS
elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value:
return cls.TEXT2IMG
elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION
else:
raise ValueError(f'invalid origin model type {origin_model_type}')
raise ValueError(f"invalid origin model type {origin_model_type}")
def to_origin_model_type(self) -> str:
"""
@ -50,26 +51,28 @@ class ModelType(Enum):
:return: origin model type
"""
if self == self.LLM:
return 'text-generation'
return "text-generation"
elif self == self.TEXT_EMBEDDING:
return 'embeddings'
return "embeddings"
elif self == self.RERANK:
return 'reranking'
return "reranking"
elif self == self.SPEECH2TEXT:
return 'speech2text'
return "speech2text"
elif self == self.TTS:
return 'tts'
return "tts"
elif self == self.MODERATION:
return 'moderation'
return "moderation"
elif self == self.TEXT2IMG:
return 'text2img'
return "text2img"
else:
raise ValueError(f'invalid model type {self}')
raise ValueError(f"invalid model type {self}")
class FetchFrom(Enum):
"""
Enum class for fetch from.
"""
PREDEFINED_MODEL = "predefined-model"
CUSTOMIZABLE_MODEL = "customizable-model"
@ -78,6 +81,7 @@ class ModelFeature(Enum):
"""
Enum class for llm feature.
"""
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
@ -89,6 +93,7 @@ class DefaultParameterName(str, Enum):
"""
Enum class for parameter template variable.
"""
TEMPERATURE = "temperature"
TOP_P = "top_p"
TOP_K = "top_k"
@ -99,7 +104,7 @@ class DefaultParameterName(str, Enum):
JSON_SCHEMA = "json_schema"
@classmethod
def value_of(cls, value: Any) -> 'DefaultParameterName':
def value_of(cls, value: Any) -> "DefaultParameterName":
"""
Get parameter name from value.
@ -109,13 +114,14 @@ class DefaultParameterName(str, Enum):
for name in cls:
if name.value == value:
return name
raise ValueError(f'invalid parameter name {value}')
raise ValueError(f"invalid parameter name {value}")
class ParameterType(Enum):
"""
Enum class for parameter type.
"""
FLOAT = "float"
INT = "int"
STRING = "string"
@ -127,6 +133,7 @@ class ModelPropertyKey(Enum):
"""
Enum class for model property key.
"""
MODE = "mode"
CONTEXT_SIZE = "context_size"
MAX_CHUNKS = "max_chunks"
@ -144,6 +151,7 @@ class ProviderModel(BaseModel):
"""
Model class for provider model.
"""
model: str
label: I18nObject
model_type: ModelType
@ -158,6 +166,7 @@ class ParameterRule(BaseModel):
"""
Model class for parameter rule.
"""
name: str
use_template: Optional[str] = None
label: I18nObject
@ -175,6 +184,7 @@ class PriceConfig(BaseModel):
"""
Model class for pricing info.
"""
input: Decimal
output: Optional[Decimal] = None
unit: Decimal
@ -185,6 +195,7 @@ class AIModelEntity(ProviderModel):
"""
Model class for AI model.
"""
parameter_rules: list[ParameterRule] = []
pricing: Optional[PriceConfig] = None
@ -197,6 +208,7 @@ class PriceType(Enum):
"""
Enum class for price type.
"""
INPUT = "input"
OUTPUT = "output"
@ -205,6 +217,7 @@ class PriceInfo(BaseModel):
"""
Model class for price info.
"""
unit_price: Decimal
unit: Decimal
total_amount: Decimal
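
Note that value_of accepts both the legacy origin aliases and the enum values, so the conversion round-trips:

from core.model_runtime.entities.model_entities import ModelType

mt = ModelType.value_of("embeddings")  # legacy alias
assert mt is ModelType.TEXT_EMBEDDING
assert mt.to_origin_model_type() == "embeddings"
assert ModelType.value_of("text-embedding") is mt  # enum value works too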

View File

@ -12,6 +12,7 @@ class ConfigurateMethod(Enum):
"""
Enum class for configurate method of provider model.
"""
PREDEFINED_MODEL = "predefined-model"
CUSTOMIZABLE_MODEL = "customizable-model"
@ -20,6 +21,7 @@ class FormType(Enum):
"""
Enum class for form type.
"""
TEXT_INPUT = "text-input"
SECRET_INPUT = "secret-input"
SELECT = "select"
@ -31,6 +33,7 @@ class FormShowOnObject(BaseModel):
"""
Model class for form show on.
"""
variable: str
value: str
@ -39,6 +42,7 @@ class FormOption(BaseModel):
"""
Model class for form option.
"""
label: I18nObject
value: str
show_on: list[FormShowOnObject] = []
@ -46,15 +50,14 @@ class FormOption(BaseModel):
def __init__(self, **data):
super().__init__(**data)
if not self.label:
self.label = I18nObject(
en_US=self.value
)
self.label = I18nObject(en_US=self.value)
class CredentialFormSchema(BaseModel):
"""
Model class for credential form schema.
"""
variable: str
label: I18nObject
type: FormType
@ -70,6 +73,7 @@ class ProviderCredentialSchema(BaseModel):
"""
Model class for provider credential schema.
"""
credential_form_schemas: list[CredentialFormSchema]
@ -82,6 +86,7 @@ class ModelCredentialSchema(BaseModel):
"""
Model class for model credential schema.
"""
model: FieldModelSchema
credential_form_schemas: list[CredentialFormSchema]
@ -90,6 +95,7 @@ class SimpleProviderEntity(BaseModel):
"""
Simple model class for provider.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
@ -102,6 +108,7 @@ class ProviderHelpEntity(BaseModel):
"""
Model class for provider help.
"""
title: I18nObject
url: I18nObject
@ -110,6 +117,7 @@ class ProviderEntity(BaseModel):
"""
Model class for provider.
"""
provider: str
label: I18nObject
description: Optional[I18nObject] = None
@ -138,7 +146,7 @@ class ProviderEntity(BaseModel):
icon_small=self.icon_small,
icon_large=self.icon_large,
supported_model_types=self.supported_model_types,
models=self.models
models=self.models,
)
@ -146,5 +154,6 @@ class ProviderConfig(BaseModel):
"""
Model class for provider config.
"""
provider: str
credentials: dict

View File

@ -5,6 +5,7 @@ class RerankDocument(BaseModel):
"""
Model class for rerank document.
"""
index: int
text: str
score: float
@ -14,5 +15,6 @@ class RerankResult(BaseModel):
"""
Model class for rerank result.
"""
model: str
docs: list[RerankDocument]

View File

@ -9,6 +9,7 @@ class EmbeddingUsage(ModelUsage):
"""
Model class for embedding usage.
"""
tokens: int
total_tokens: int
unit_price: Decimal
@ -22,7 +23,7 @@ class TextEmbeddingResult(BaseModel):
"""
Model class for text embedding result.
"""
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage

View File

@ -3,6 +3,7 @@ from typing import Optional
class InvokeError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None:
@ -14,24 +15,29 @@ class InvokeError(Exception):
class InvokeConnectionError(InvokeError):
"""Raised when the Invoke returns connection error."""
description = "Connection Error"
class InvokeServerUnavailableError(InvokeError):
"""Raised when the Invoke returns server unavailable error."""
description = "Server Unavailable Error"
class InvokeRateLimitError(InvokeError):
"""Raised when the Invoke returns rate limit error."""
description = "Rate Limit Error"
class InvokeAuthorizationError(InvokeError):
"""Raised when the Invoke returns authorization error."""
description = "Incorrect model credentials provided, please check and try again. "
class InvokeBadRequestError(InvokeError):
"""Raised when the Invoke returns bad request."""
description = "Bad Request Error"

View File

@ -2,4 +2,5 @@ class CredentialsValidateFailedError(Exception):
"""
Credentials validate failed error
"""
pass

View File

@ -66,12 +66,14 @@ class AIModel(ABC):
:param error: model invoke error
:return: unified error
"""
provider_name = self.__class__.__module__.split('.')[-3]
provider_name = self.__class__.__module__.split(".")[-3]
for invoke_error, model_errors in self._invoke_error_mapping.items():
if isinstance(error, tuple(model_errors)):
if invoke_error == InvokeAuthorizationError:
return invoke_error(description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. ")
return invoke_error(
description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. "
)
return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")
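
Concretely, each provider declares which of its native exceptions map onto each unified class, and _transform_invoke_error walks that mapping. A hypothetical provider sketch (names are illustrative, not from this commit):

import requests

from core.model_runtime.errors.invoke import (
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
)

class MyProviderModel:  # a real provider would subclass LargeLanguageModel
    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        return {
            InvokeConnectionError: [requests.exceptions.ConnectionError],
            InvokeRateLimitError: [requests.exceptions.RetryError],
            InvokeBadRequestError: [ValueError, KeyError],
        }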
@ -115,7 +117,7 @@ class AIModel(ABC):
if not price_config:
raise ValueError(f"Price config not found for model {model}")
total_amount = tokens * unit_price * price_config.unit
total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)
return PriceInfo(
unit_price=unit_price,
@ -136,24 +138,26 @@ class AIModel(ABC):
model_schemas = []
# get module name
model_type = self.__class__.__module__.split('.')[-1]
model_type = self.__class__.__module__.split(".")[-1]
# get provider name
provider_name = self.__class__.__module__.split('.')[-3]
provider_name = self.__class__.__module__.split(".")[-3]
# get the path of current classes
current_path = os.path.abspath(__file__)
# get parent path of the current path
provider_model_type_path = os.path.join(os.path.dirname(os.path.dirname(current_path)), provider_name, model_type)
provider_model_type_path = os.path.join(
os.path.dirname(os.path.dirname(current_path)), provider_name, model_type
)
# get all yaml files path under provider_model_type_path that do not start with __
model_schema_yaml_paths = [
os.path.join(provider_model_type_path, model_schema_yaml)
for model_schema_yaml in os.listdir(provider_model_type_path)
if not model_schema_yaml.startswith('__')
and not model_schema_yaml.startswith('_')
if not model_schema_yaml.startswith("__")
and not model_schema_yaml.startswith("_")
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
and model_schema_yaml.endswith('.yaml')
and model_schema_yaml.endswith(".yaml")
]
# get _position.yaml file path
@ -165,10 +169,10 @@ class AIModel(ABC):
yaml_data = load_yaml_file(model_schema_yaml_path)
new_parameter_rules = []
for parameter_rule in yaml_data.get('parameter_rules', []):
if 'use_template' in parameter_rule:
for parameter_rule in yaml_data.get("parameter_rules", []):
if "use_template" in parameter_rule:
try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule['use_template'])
default_parameter_name = DefaultParameterName.value_of(parameter_rule["use_template"])
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
copy_default_parameter_rule = default_parameter_rule.copy()
copy_default_parameter_rule.update(parameter_rule)
@ -176,31 +180,26 @@ class AIModel(ABC):
except ValueError:
pass
if 'label' not in parameter_rule:
parameter_rule['label'] = {
'zh_Hans': parameter_rule['name'],
'en_US': parameter_rule['name']
}
if "label" not in parameter_rule:
parameter_rule["label"] = {"zh_Hans": parameter_rule["name"], "en_US": parameter_rule["name"]}
new_parameter_rules.append(parameter_rule)
yaml_data['parameter_rules'] = new_parameter_rules
yaml_data["parameter_rules"] = new_parameter_rules
if 'label' not in yaml_data:
yaml_data['label'] = {
'zh_Hans': yaml_data['model'],
'en_US': yaml_data['model']
}
if "label" not in yaml_data:
yaml_data["label"] = {"zh_Hans": yaml_data["model"], "en_US": yaml_data["model"]}
yaml_data['fetch_from'] = FetchFrom.PREDEFINED_MODEL.value
yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value
try:
# yaml_data to entity
model_schema = AIModelEntity(**yaml_data)
except Exception as e:
model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml")
raise Exception(f'Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:'
f' {str(e)}')
raise Exception(
f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:" f" {str(e)}"
)
# cache model schema
model_schemas.append(model_schema)
@ -235,7 +234,9 @@ class AIModel(ABC):
return None
def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
def get_customizable_model_schema_from_credentials(
self, model: str, credentials: Mapping
) -> Optional[AIModelEntity]:
"""
Get customizable model schema from credentials
@ -261,19 +262,19 @@ class AIModel(ABC):
try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
if not parameter_rule.max and 'max' in default_parameter_rule:
parameter_rule.max = default_parameter_rule['max']
if not parameter_rule.min and 'min' in default_parameter_rule:
parameter_rule.min = default_parameter_rule['min']
if not parameter_rule.default and 'default' in default_parameter_rule:
parameter_rule.default = default_parameter_rule['default']
if not parameter_rule.precision and 'precision' in default_parameter_rule:
parameter_rule.precision = default_parameter_rule['precision']
if not parameter_rule.required and 'required' in default_parameter_rule:
parameter_rule.required = default_parameter_rule['required']
if not parameter_rule.help and 'help' in default_parameter_rule:
if not parameter_rule.max and "max" in default_parameter_rule:
parameter_rule.max = default_parameter_rule["max"]
if not parameter_rule.min and "min" in default_parameter_rule:
parameter_rule.min = default_parameter_rule["min"]
if not parameter_rule.default and "default" in default_parameter_rule:
parameter_rule.default = default_parameter_rule["default"]
if not parameter_rule.precision and "precision" in default_parameter_rule:
parameter_rule.precision = default_parameter_rule["precision"]
if not parameter_rule.required and "required" in default_parameter_rule:
parameter_rule.required = default_parameter_rule["required"]
if not parameter_rule.help and "help" in default_parameter_rule:
parameter_rule.help = I18nObject(
en_US=default_parameter_rule['help']['en_US'],
en_US=default_parameter_rule["help"]["en_US"],
)
if (
parameter_rule.help

View File

@ -35,16 +35,24 @@ class LargeLanguageModel(AIModel):
"""
Model class for large language model.
"""
model_type: ModelType = ModelType.LLM
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-> Union[LLMResult, Generator]:
def invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -69,7 +77,7 @@ class LargeLanguageModel(AIModel):
callbacks = callbacks or []
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
callbacks.append(LoggingCallback())
# trigger before invoke callbacks
@ -82,7 +90,7 @@ class LargeLanguageModel(AIModel):
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
try:
@ -96,7 +104,7 @@ class LargeLanguageModel(AIModel):
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
else:
result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
@ -111,7 +119,7 @@ class LargeLanguageModel(AIModel):
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
raise self._transform_invoke_error(e)
@ -127,7 +135,7 @@ class LargeLanguageModel(AIModel):
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
elif isinstance(result, LLMResult):
self._trigger_after_invoke_callbacks(
@ -140,15 +148,23 @@ class LargeLanguageModel(AIModel):
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
return result
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]:
def _code_block_mode_wrapper(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper, ensure the response is a code block with output markdown quote
@ -183,7 +199,7 @@ if you are not sure about the structure.
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
model_parameters.pop("response_format")
@ -195,15 +211,16 @@ if you are not sure about the structure.
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", str(prompt_messages[0].content))
content=block_prompts.replace("{{instructions}}", str(prompt_messages[0].content))
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", f"Please output a valid {code_block} object.")
))
prompt_messages.insert(
0,
SystemPromptMessage(
content=block_prompts.replace("{{instructions}}", f"Please output a valid {code_block} object.")
),
)
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# add ```JSON\n to the last text message
@ -216,9 +233,7 @@ if you are not sure about the structure.
break
else:
# append a user message
prompt_messages.append(UserPromptMessage(
content=f"```{code_block}\n"
))
prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n"))
response = self._invoke(
model=model,
@ -228,33 +243,30 @@ if you are not sure about the structure.
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
if isinstance(response, Generator):
first_chunk = next(response)
def new_generator():
yield first_chunk
yield from response
if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
return self._code_block_mode_stream_processor_with_backtick(
model=model,
prompt_messages=prompt_messages,
input_generator=new_generator()
model=model, prompt_messages=prompt_messages, input_generator=new_generator()
)
else:
return self._code_block_mode_stream_processor(
model=model,
prompt_messages=prompt_messages,
input_generator=new_generator()
model=model, prompt_messages=prompt_messages, input_generator=new_generator()
)
return response
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]:
def _code_block_mode_stream_processor(
self, model: str, prompt_messages: list[PromptMessage], input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]:
"""
Code block mode stream processor, ensure the response is a code block with output markdown quote
@ -303,16 +315,13 @@ if you are not sure about the structure.
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=new_piece,
tool_calls=[]
),
)
message=AssistantPromptMessage(content=new_piece, tool_calls=[]),
),
)
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
input_generator: Generator[LLMResultChunk, None, None]) \
-> Generator[LLMResultChunk, None, None]:
def _code_block_mode_stream_processor_with_backtick(
self, model: str, prompt_messages: list, input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]:
"""
Code block mode stream processor, ensure the response is a code block with output markdown quote.
This version skips the language identifier that follows the opening triple backticks.
@ -378,18 +387,23 @@ if you are not sure about the structure.
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=new_piece,
tool_calls=[]
),
)
message=AssistantPromptMessage(content=new_piece, tool_calls=[]),
),
)
def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator:
def _invoke_result_generator(
self,
model: str,
result: Generator,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Generator:
"""
Invoke result generator
@ -397,9 +411,7 @@ if you are not sure about the structure.
:return: result generator
"""
callbacks = callbacks or []
prompt_message = AssistantPromptMessage(
content=""
)
prompt_message = AssistantPromptMessage(content="")
usage = None
system_fingerprint = None
real_model = model
@ -418,7 +430,7 @@ if you are not sure about the structure.
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
prompt_message.content += chunk.delta.message.content
@ -438,7 +450,7 @@ if you are not sure about the structure.
prompt_messages=prompt_messages,
message=prompt_message,
usage=usage if usage else LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint
system_fingerprint=system_fingerprint,
),
credentials=credentials,
prompt_messages=prompt_messages,
@ -447,15 +459,21 @@ if you are not sure about the structure.
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
callbacks=callbacks,
)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -472,8 +490,13 @@ if you are not sure about the structure.
raise NotImplementedError
@abstractmethod
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -519,7 +542,9 @@ if you are not sure about the structure.
return mode
def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
def _calc_response_usage(
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
) -> LLMUsage:
"""
Calculate response usage
@ -539,10 +564,7 @@ if you are not sure about the structure.
# get completion price info
completion_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.OUTPUT,
tokens=completion_tokens
model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens
)
# transform usage
@ -558,16 +580,23 @@ if you are not sure about the structure.
total_tokens=prompt_tokens + completion_tokens,
total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
currency=prompt_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
def _trigger_before_invoke_callbacks(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
def _trigger_before_invoke_callbacks(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> None:
"""
Trigger before invoke callbacks
@ -593,7 +622,7 @@ if you are not sure about the structure.
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
except Exception as e:
if callback.raise_error:
@ -601,11 +630,19 @@ if you are not sure about the structure.
else:
logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}")
def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
def _trigger_new_chunk_callbacks(
self,
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> None:
"""
Trigger new chunk callbacks
@ -632,7 +669,7 @@ if you are not sure about the structure.
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
except Exception as e:
if callback.raise_error:
@ -640,11 +677,19 @@ if you are not sure about the structure.
else:
logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}")
def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
def _trigger_after_invoke_callbacks(
self,
model: str,
result: LLMResult,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> None:
"""
Trigger after invoke callbacks
@ -672,7 +717,7 @@ if you are not sure about the structure.
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
except Exception as e:
if callback.raise_error:
@ -680,11 +725,19 @@ if you are not sure about the structure.
else:
logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}")
def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
def _trigger_invoke_error_callbacks(
self,
model: str,
ex: Exception,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> None:
"""
Trigger invoke error callbacks
@ -712,7 +765,7 @@ if you are not sure about the structure.
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
except Exception as e:
if callback.raise_error:
@ -758,11 +811,13 @@ if you are not sure about the structure.
# validate parameter value range
if parameter_rule.min is not None and parameter_value < parameter_rule.min:
raise ValueError(
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.")
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
)
if parameter_rule.max is not None and parameter_value > parameter_rule.max:
raise ValueError(
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
)
elif parameter_rule.type == ParameterType.FLOAT:
if not isinstance(parameter_value, float | int):
raise ValueError(f"Model Parameter {parameter_name} should be float.")
@ -775,16 +830,19 @@ if you are not sure about the structure.
else:
if parameter_value != round(parameter_value, parameter_rule.precision):
raise ValueError(
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places.")
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places."
)
# validate parameter value range
if parameter_rule.min is not None and parameter_value < parameter_rule.min:
raise ValueError(
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.")
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
)
if parameter_rule.max is not None and parameter_value > parameter_rule.max:
raise ValueError(
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
)
elif parameter_rule.type == ParameterType.BOOLEAN:
if not isinstance(parameter_value, bool):
raise ValueError(f"Model Parameter {parameter_name} should be bool.")

View File

@ -29,32 +29,32 @@ class ModelProvider(ABC):
def get_provider_schema(self) -> ProviderEntity:
"""
Get provider schema
:return: provider schema
"""
if self.provider_schema:
return self.provider_schema
# get dirname of the current path
provider_name = self.__class__.__module__.split('.')[-1]
provider_name = self.__class__.__module__.split(".")[-1]
# get the path of the model_provider classes
base_path = os.path.abspath(__file__)
current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
# read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_path = os.path.join(current_path, f"{provider_name}.yaml")
yaml_data = load_yaml_file(yaml_path)
try:
# yaml_data to entity
provider_schema = ProviderEntity(**yaml_data)
except Exception as e:
raise Exception(f'Invalid provider schema for {provider_name}: {str(e)}')
raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}")
# cache schema
self.provider_schema = provider_schema
return provider_schema
def models(self, model_type: ModelType) -> list[AIModelEntity]:
@ -92,15 +92,15 @@ class ModelProvider(ABC):
# get the path of the model type classes
base_path = os.path.abspath(__file__)
model_type_name = model_type.value.replace('-', '_')
model_type_name = model_type.value.replace("-", "_")
model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name)
model_type_py_path = os.path.join(model_type_path, f'{model_type_name}.py')
model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py")
if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path):
raise Exception(f'Invalid model type {model_type} for provider {provider_name}')
raise Exception(f"Invalid model type {model_type} for provider {provider_name}")
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
parent_module = ".".join(self.__class__.__module__.split(".")[:-1])
mod = import_module_from_source(
module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
)

View File

@ -12,14 +12,13 @@ class ModerationModel(AIModel):
"""
Model class for moderation model.
"""
model_type: ModelType = ModelType.MODERATION
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict,
text: str, user: Optional[str] = None) \
-> bool:
def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
"""
Invoke moderation model
@ -37,9 +36,7 @@ class ModerationModel(AIModel):
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
text: str, user: Optional[str] = None) \
-> bool:
def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
"""
Invoke large language model
@ -50,4 +47,3 @@ class ModerationModel(AIModel):
:return: false if text is safe, true otherwise
"""
raise NotImplementedError

View File

@ -11,12 +11,19 @@ class RerankModel(AIModel):
"""
Base Model class for rerank model.
"""
model_type: ModelType = ModelType.RERANK
def invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
def invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
@ -37,10 +44,16 @@ class RerankModel(AIModel):
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model

View File

@ -12,14 +12,13 @@ class Speech2TextModel(AIModel):
"""
Model class for speech2text model.
"""
model_type: ModelType = ModelType.SPEECH2TEXT
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke large language model
@ -35,9 +34,7 @@ class Speech2TextModel(AIModel):
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke large language model
@ -59,4 +56,4 @@ class Speech2TextModel(AIModel):
current_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the path to the audio file
return os.path.join(current_dir, 'audio.mp3')
return os.path.join(current_dir, "audio.mp3")

View File

@ -11,14 +11,15 @@ class Text2ImageModel(AIModel):
"""
Model class for text2img model.
"""
model_type: ModelType = ModelType.TEXT2IMG
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
def invoke(
self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None
) -> list[IO[bytes]]:
"""
Invoke Text2Image model
@ -36,9 +37,9 @@ class Text2ImageModel(AIModel):
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
def _invoke(
self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None
) -> list[IO[bytes]]:
"""
Invoke Text2Image model

View File

@ -13,14 +13,15 @@ class TextEmbeddingModel(AIModel):
"""
Model class for text embedding model.
"""
model_type: ModelType = ModelType.TEXT_EMBEDDING
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -38,9 +39,9 @@ class TextEmbeddingModel(AIModel):
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model

View File

@ -7,27 +7,28 @@ from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
_tokenizer = None
_lock = Lock()
class GPT2Tokenizer:
@staticmethod
def _get_num_tokens_by_gpt2(text: str) -> int:
"""
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text, verbose=False)
return len(tokens)
@staticmethod
def get_num_tokens(text: str) -> int:
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), 'gpt2')
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
return _tokenizer
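Note: `get_encoder` above is a lazy, lock-guarded singleton; the first caller pays the tokenizer-loading cost and every later caller reuses the cached instance. A minimal standalone sketch of the same pattern, with a hypothetical loader standing in for `TransformerGPT2Tokenizer.from_pretrained`:

import time
from threading import Lock

_resource = None
_lock = Lock()

def _load_expensive_resource() -> dict:
    # hypothetical stand-in for TransformerGPT2Tokenizer.from_pretrained(...)
    time.sleep(0.1)
    return {"loaded": True}

def get_resource() -> dict:
    global _resource
    with _lock:
        if _resource is None:  # only the first call loads; later calls reuse
            _resource = _load_expensive_resource()
        return _resource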

View File

@ -15,13 +15,15 @@ class TTSModel(AIModel):
"""
Model class for TTS model.
"""
model_type: ModelType = ModelType.TTS
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None):
def invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
):
"""
Invoke TTS model
@ -35,14 +37,21 @@ class TTSModel(AIModel):
:return: translated audio file
"""
try:
return self._invoke(model=model, credentials=credentials, user=user,
content_text=content_text, voice=voice, tenant_id=tenant_id)
return self._invoke(
model=model,
credentials=credentials,
user=user,
content_text=content_text,
voice=voice,
tenant_id=tenant_id,
)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None):
def _invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
):
"""
Invoke TTS model
@ -71,10 +80,13 @@ class TTSModel(AIModel):
if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
if language:
return [{'name': d['name'], 'value': d['mode']} for d in voices if
language and language in d.get('language')]
return [
{"name": d["name"], "value": d["mode"]}
for d in voices
if language and language in d.get("language")
]
else:
return [{'name': d['name'], 'value': d['mode']} for d in voices]
return [{"name": d["name"], "value": d["mode"]} for d in voices]
def _get_model_default_voice(self, model: str, credentials: dict) -> any:
"""
@ -123,23 +135,23 @@ class TTSModel(AIModel):
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
@staticmethod
def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'):
def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
match = re.compile(pattern)
tx = match.finditer(org_text)
start = 0
result = []
one_sentence = ''
one_sentence = ""
for i in tx:
end = i.regs[0][1]
tmp = org_text[start:end]
if len(one_sentence + tmp) > max_length:
result.append(one_sentence)
one_sentence = ''
one_sentence = ""
one_sentence += tmp
start = end
last_sens = org_text[start:]
if last_sens:
one_sentence += last_sens
if one_sentence != '':
if one_sentence != "":
result.append(one_sentence)
return result
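For reference, here is the splitter above restated as a standalone function with a small worked example: it packs whole sentences into chunks and starts a new chunk only when appending the next sentence would exceed max_length.

import re

def split_text_into_sentences(org_text: str, max_length: int = 2000, pattern: str = r"[。.!?]") -> list[str]:
    # same algorithm as TTSModel._split_text_into_sentences above
    result = []
    one_sentence = ""
    start = 0
    for i in re.finditer(pattern, org_text):
        end = i.regs[0][1]
        tmp = org_text[start:end]
        if len(one_sentence + tmp) > max_length:
            result.append(one_sentence)
            one_sentence = ""
        one_sentence += tmp
        start = end
    last_sens = org_text[start:]
    if last_sens:
        one_sentence += last_sens
    if one_sentence != "":
        result.append(one_sentence)
    return result

print(split_text_into_sentences("One. Two! Three? Four", max_length=10))
# -> ['One. Two!', ' Three? Four']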

View File

@ -20,12 +20,9 @@ class AnthropicProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)
# Use `claude-3-opus-20240229` model for validation.
model_instance.validate_credentials(
model='claude-3-opus-20240229',
credentials=credentials
)
model_instance.validate_credentials(model="claude-3-opus-20240229", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -55,11 +55,17 @@ if you are not sure about the structure.
class AnthropicLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -76,10 +82,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
# invoke model
return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _chat_generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model
@ -96,41 +109,39 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
credentials_kwargs = self._to_credential_kwargs(credentials)
# transform model parameters from completion api of anthropic to chat api
if 'max_tokens_to_sample' in model_parameters:
model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample')
if "max_tokens_to_sample" in model_parameters:
model_parameters["max_tokens"] = model_parameters.pop("max_tokens_to_sample")
# init model client
client = Anthropic(**credentials_kwargs)
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop
extra_model_kwargs["stop_sequences"] = stop
if user:
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
extra_model_kwargs["metadata"] = completion_create_params.Metadata(user_id=user)
system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)
if system:
extra_model_kwargs['system'] = system
extra_model_kwargs["system"] = system
# Add the new header for claude-3-5-sonnet-20240620 model
extra_headers = {}
if model == "claude-3-5-sonnet-20240620":
if model_parameters.get('max_tokens') > 4096:
if model_parameters.get("max_tokens") > 4096:
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
if tools:
extra_model_kwargs['tools'] = [
self._transform_tool_prompt(tool) for tool in tools
]
extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
response = client.beta.tools.messages.create(
model=model,
messages=prompt_message_dicts,
stream=stream,
extra_headers=extra_headers,
**model_parameters,
**extra_model_kwargs
**extra_model_kwargs,
)
else:
# chat model
@ -140,22 +151,30 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
stream=stream,
extra_headers=extra_headers,
**model_parameters,
**extra_model_kwargs
**extra_model_kwargs,
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
def _code_block_mode_wrapper(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if model_parameters.get('response_format'):
if model_parameters.get("response_format"):
stop = stop or []
# chat model
self._transform_chat_json_prompts(
@ -167,24 +186,27 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
response_format=model_parameters["response_format"],
)
model_parameters.pop('response_format')
model_parameters.pop("response_format")
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict:
return {
'name': tool.name,
'description': tool.description,
'input_schema': tool.parameters
}
return {"name": tool.name, "description": tool.description, "input_schema": tool.parameters}
def _transform_chat_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
def _transform_chat_json_prompts(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
response_format: str = "JSON",
) -> None:
"""
Transform json prompts
"""
@ -197,22 +219,30 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
"{{block}}", response_format
)
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
prompt_messages.insert(
0,
SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace(
"{{instructions}}", f"Please output a valid {response_format} object."
).replace("{{block}}", response_format)
),
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
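The AssistantPromptMessage(content=f"\n```{response_format}") appended in both branches is a prefill: the request is sent with a partially written assistant turn, so the model continues from inside the fenced block instead of emitting free-form prose. An abbreviated, illustrative shape of the message list after the transform (contents shortened):

messages = [
    {"role": "system", "content": "...ANTHROPIC_BLOCK_MODE_PROMPT with {{instructions}}/{{block}} filled in..."},
    {"role": "user", "content": "Give me the user profile."},
    # prefilled assistant turn: the completion starts inside the ```JSON block
    {"role": "assistant", "content": "\n```JSON"},
]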
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -228,9 +258,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
tokens = client.count_tokens(prompt)
tool_call_inner_prompts_tokens_map = {
'claude-3-opus-20240229': 395,
'claude-3-haiku-20240307': 264,
'claude-3-sonnet-20240229': 159
"claude-3-opus-20240229": 395,
"claude-3-haiku-20240307": 264,
"claude-3-sonnet-20240229": 159,
}
if model in tool_call_inner_prompts_tokens_map and tools:
@ -257,13 +287,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"temperature": 0,
"max_tokens": 20,
},
stream=False
stream=False,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage],
prompt_messages: list[PromptMessage]) -> LLMResult:
def _handle_chat_generate_response(
self,
model: str,
credentials: dict,
response: Union[Message, ToolsBetaMessage],
prompt_messages: list[PromptMessage],
) -> LLMResult:
"""
Handle llm chat response
@ -274,22 +309,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return: llm response
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content='',
tool_calls=[]
)
assistant_prompt_message = AssistantPromptMessage(content="", tool_calls=[])
for content in response.content:
if content.type == 'text':
if content.type == "text":
assistant_prompt_message.content += content.text
elif content.type == 'tool_use':
elif content.type == "tool_use":
tool_call = AssistantPromptMessage.ToolCall(
id=content.id,
type='function',
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=content.name,
arguments=json.dumps(content.input)
)
name=content.name, arguments=json.dumps(content.input)
),
)
assistant_prompt_message.tool_calls.append(tool_call)
@ -308,17 +339,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
# transform response
response = LLMResult(
model=response.model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage
)
return response
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
response: Stream[MessageStreamEvent],
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_chat_generate_stream_response(
self, model: str, credentials: dict, response: Stream[MessageStreamEvent], prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm chat stream response
@ -327,7 +355,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ''
full_assistant_content = ""
return_model = None
input_tokens = 0
output_tokens = 0
@ -338,24 +366,23 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
for chunk in response:
if isinstance(chunk, MessageStartEvent):
if hasattr(chunk, 'content_block'):
if hasattr(chunk, "content_block"):
content_block = chunk.content_block
if isinstance(content_block, dict):
if content_block.get('type') == 'tool_use':
if content_block.get("type") == "tool_use":
tool_call = AssistantPromptMessage.ToolCall(
id=content_block.get('id'),
type='function',
id=content_block.get("id"),
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=content_block.get('name'),
arguments=''
)
name=content_block.get("name"), arguments=""
),
)
tool_calls.append(tool_call)
elif hasattr(chunk, 'delta'):
elif hasattr(chunk, "delta"):
delta = chunk.delta
if isinstance(delta, dict) and len(tool_calls) > 0:
if delta.get('type') == 'input_json_delta':
tool_calls[-1].function.arguments += delta.get('partial_json', '')
if delta.get("type") == "input_json_delta":
tool_calls[-1].function.arguments += delta.get("partial_json", "")
elif chunk.message:
return_model = chunk.message.model
input_tokens = chunk.message.usage.input_tokens
@ -369,29 +396,24 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
# transform empty tool call arguments to {}
for tool_call in tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = '{}'
tool_call.function.arguments = "{}"
yield LLMResultChunk(
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index + 1,
message=AssistantPromptMessage(
content='',
tool_calls=tool_calls
),
message=AssistantPromptMessage(content="", tool_calls=tool_calls),
finish_reason=finish_reason,
usage=usage
)
usage=usage,
),
)
elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else ''
chunk_text = chunk.delta.text if chunk.delta.text else ""
full_assistant_content += chunk_text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text
)
assistant_prompt_message = AssistantPromptMessage(content=chunk_text)
index = chunk.index
@ -401,7 +423,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(
index=chunk.index,
message=assistant_prompt_message,
)
),
)
def _to_credential_kwargs(self, credentials: dict) -> dict:
@ -412,14 +434,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return:
"""
credentials_kwargs = {
"api_key": credentials['anthropic_api_key'],
"api_key": credentials["anthropic_api_key"],
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"max_retries": 1,
}
if credentials.get('anthropic_api_url'):
credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/')
credentials_kwargs['base_url'] = credentials['anthropic_api_url']
if credentials.get("anthropic_api_url"):
credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip("/")
credentials_kwargs["base_url"] = credentials["anthropic_api_url"]
return credentials_kwargs
@ -452,10 +474,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
@ -465,25 +484,25 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode('utf-8')
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
raise ValueError(
f"Failed to fetch image data from url {message_content.data}, {ex}"
)
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
raise ValueError(f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp")
raise ValueError(
f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp"
)
sub_message_dict = {
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data
}
"source": {"type": "base64", "media_type": mime_type, "data": base64_data},
}
sub_messages.append(sub_message_dict)
prompt_message_dicts.append({"role": "user", "content": sub_messages})
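A standalone restatement of the URL branch above (the helper name is ours, for illustration): download the image, sniff its real format with Pillow, and base64-encode the raw bytes into an Anthropic-style image part.

import base64
import io

import requests
from PIL import Image

def image_url_to_base64_part(url: str) -> dict:
    image_bytes = requests.get(url).content
    # detect the actual format from the bytes rather than trusting the URL
    with Image.open(io.BytesIO(image_bytes)) as img:
        mime_type = f"image/{img.format.lower()}"
    if mime_type not in ("image/jpeg", "image/png", "image/gif", "image/webp"):
        raise ValueError(f"Unsupported image type {mime_type}")
    return {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": mime_type,
            "data": base64.b64encode(image_bytes).decode("utf-8"),
        },
    }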
@ -492,34 +511,28 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
content = []
if message.tool_calls:
for tool_call in message.tool_calls:
content.append({
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"input": json.loads(tool_call.function.arguments)
})
content.append(
{
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"input": json.loads(tool_call.function.arguments),
}
)
if message.content:
content.append({
"type": "text",
"text": message.content
})
content.append({"type": "text", "text": message.content})
if prompt_message_dicts[-1]["role"] == "assistant":
prompt_message_dicts[-1]["content"].extend(content)
else:
prompt_message_dicts.append({
"role": "assistant",
"content": content
})
prompt_message_dicts.append({"role": "assistant", "content": content})
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "user",
"content": [{
"type": "tool_result",
"tool_use_id": message.tool_call_id,
"content": message.content
}]
"content": [
{"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content}
],
}
prompt_message_dicts.append(message_dict)
else:
@ -576,16 +589,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
if not messages:
return ''
return ""
messages = messages.copy() # don't mutate the original list
if not isinstance(messages[-1], AssistantPromptMessage):
messages.append(AssistantPromptMessage(content=""))
text = "".join(
self._convert_one_message_to_text(message)
for message in messages
)
text = "".join(self._convert_one_message_to_text(message) for message in messages)
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
@ -601,24 +611,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
anthropic.APIConnectionError,
anthropic.APITimeoutError
],
InvokeServerUnavailableError: [
anthropic.InternalServerError
],
InvokeRateLimitError: [
anthropic.RateLimitError
],
InvokeAuthorizationError: [
anthropic.AuthenticationError,
anthropic.PermissionDeniedError
],
InvokeConnectionError: [anthropic.APIConnectionError, anthropic.APITimeoutError],
InvokeServerUnavailableError: [anthropic.InternalServerError],
InvokeRateLimitError: [anthropic.RateLimitError],
InvokeAuthorizationError: [anthropic.AuthenticationError, anthropic.PermissionDeniedError],
InvokeBadRequestError: [
anthropic.BadRequestError,
anthropic.NotFoundError,
anthropic.UnprocessableEntityError,
anthropic.APIError
]
anthropic.APIError,
],
}

View File

@ -15,10 +15,10 @@ from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPEN
class _CommonAzureOpenAI:
@staticmethod
def _to_credential_kwargs(credentials: dict) -> dict:
api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION)
api_version = credentials.get("openai_api_version", AZURE_OPENAI_API_VERSION)
credentials_kwargs = {
"api_key": credentials['openai_api_key'],
"azure_endpoint": credentials['openai_api_base'],
"api_key": credentials["openai_api_key"],
"azure_endpoint": credentials["openai_api_base"],
"api_version": api_version,
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"max_retries": 1,
@ -29,24 +29,14 @@ class _CommonAzureOpenAI:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
openai.APIConnectionError,
openai.APITimeoutError
],
InvokeServerUnavailableError: [
openai.InternalServerError
],
InvokeRateLimitError: [
openai.RateLimitError
],
InvokeAuthorizationError: [
openai.AuthenticationError,
openai.PermissionDeniedError
],
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
InvokeServerUnavailableError: [openai.InternalServerError],
InvokeRateLimitError: [openai.RateLimitError],
InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
InvokeBadRequestError: [
openai.BadRequestError,
openai.NotFoundError,
openai.UnprocessableEntityError,
openai.APIError
]
openai.APIError,
],
}

View File

@ -6,6 +6,5 @@ logger = logging.getLogger(__name__)
class AzureOpenAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -34,16 +34,20 @@ logger = logging.getLogger(__name__)
class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
base_model_name = credentials.get('base_model_name')
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError('Base Model Name is required')
raise ValueError("Base Model Name is required")
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
@ -56,7 +60,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
else:
# text completion model
@ -67,7 +71,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=user
user=user,
)
def get_num_tokens(
@ -75,14 +79,14 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
base_model_name = credentials.get('base_model_name')
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError('Base Model Name is required')
raise ValueError("Base Model Name is required")
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not model_entity:
raise ValueError(f'Base Model Name {base_model_name} is invalid')
raise ValueError(f"Base Model Name {base_model_name} is invalid")
model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
if model_mode == LLMMode.CHAT.value:
@ -92,21 +96,21 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
# text completion model, do not support tool calling
content = prompt_messages[0].content
assert isinstance(content, str)
return self._num_tokens_from_string(credentials,content)
return self._num_tokens_from_string(credentials, content)
def validate_credentials(self, model: str, credentials: dict) -> None:
if 'openai_api_base' not in credentials:
raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')
if "openai_api_base" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required")
if 'openai_api_key' not in credentials:
raise CredentialsValidateFailedError('Azure OpenAI API key is required')
if "openai_api_key" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
if 'base_model_name' not in credentials:
raise CredentialsValidateFailedError('Base Model Name is required')
if "base_model_name" not in credentials:
raise CredentialsValidateFailedError("Base Model Name is required")
base_model_name = credentials.get('base_model_name')
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise CredentialsValidateFailedError('Base Model Name is required')
raise CredentialsValidateFailedError("Base Model Name is required")
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not ai_model_entity:
@ -118,7 +122,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
client.chat.completions.create(
messages=[{"role": "user", "content": 'ping'}],
messages=[{"role": "user", "content": "ping"}],
model=model,
temperature=0,
max_tokens=20,
@ -127,7 +131,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
else:
# text completion model
client.completions.create(
prompt='ping',
prompt="ping",
model=model,
temperature=0,
max_tokens=20,
@ -137,33 +141,35 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
base_model_name = credentials.get('base_model_name')
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError('Base Model Name is required')
raise ValueError("Base Model Name is required")
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
return ai_model_entity.entity if ai_model_entity else None
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop'] = stop
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs['user'] = user
extra_model_kwargs["user"] = user
# text completion model
response = client.completions.create(
prompt=prompt_messages[0].content,
model=model,
stream=stream,
**model_parameters,
**extra_model_kwargs
prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs
)
if stream:
@ -172,15 +178,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(
self, model: str, credentials: dict, response: Completion,
prompt_messages: list[PromptMessage]
self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage]
):
assistant_text = response.choices[0].text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
# calculate num tokens
if response.usage:
@ -209,24 +212,21 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return result
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: Stream[Completion],
prompt_messages: list[PromptMessage]
self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage]
) -> Generator:
full_text = ''
full_text = ""
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.text is None or delta.text == ''):
if delta.finish_reason is None and (delta.text is None or delta.text == ""):
continue
# transform assistant message to prompt message
text = delta.text if delta.text else ''
assistant_prompt_message = AssistantPromptMessage(
content=text
)
text = delta.text if delta.text else ""
assistant_prompt_message = AssistantPromptMessage(content=text)
full_text += text
@ -254,8 +254,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage
)
usage=usage,
),
)
else:
yield LLMResultChunk(
@ -265,14 +265,20 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
)
),
)
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _chat_generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
response_format = model_parameters.get("response_format")
@ -293,7 +299,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
extra_model_kwargs = {}
if tools:
extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
# extra_model_kwargs['functions'] = [{
# "name": tool.name,
# "description": tool.description,
@ -301,10 +307,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
# } for tool in tools]
if stop:
extra_model_kwargs['stop'] = stop
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs['user'] = user
extra_model_kwargs["user"] = user
# chat model
messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
@ -322,9 +328,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
def _handle_chat_generate_response(
self, model: str, credentials: dict, response: ChatCompletion,
self,
model: str,
credentials: dict,
response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None
tools: Optional[list[PromptMessageTool]] = None,
):
assistant_message = response.choices[0].message
assistant_message_tool_calls = assistant_message.tool_calls
@ -334,10 +343,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_message.content,
tool_calls=tool_calls
)
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
# calculate num tokens
if response.usage:
@ -369,13 +375,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None
tools: Optional[list[PromptMessageTool]] = None,
):
index = 0
full_assistant_content = ''
full_assistant_content = ""
real_model = model
system_fingerprint = None
completion = ''
completion = ""
tool_calls = []
for chunk in response:
if len(chunk.choices) == 0:
@ -386,7 +392,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
if delta.delta is None:
continue
# extract tool calls from response
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
@ -396,15 +401,14 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else '',
tool_calls=tool_calls
content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
)
full_assistant_content += delta.delta.content if delta.delta.content else ''
full_assistant_content += delta.delta.content if delta.delta.content else ""
real_model = chunk.model
system_fingerprint = chunk.system_fingerprint
completion += delta.delta.content if delta.delta.content else ''
completion += delta.delta.content if delta.delta.content else ""
yield LLMResultChunk(
model=real_model,
@ -413,7 +417,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
)
),
)
index += 1
@ -421,9 +425,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
full_assistant_prompt_message = AssistantPromptMessage(
content=completion
)
full_assistant_prompt_message = AssistantPromptMessage(content=completion)
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
# transform usage
@ -434,27 +436,24 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=''),
finish_reason='stop',
usage=usage
)
index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage
),
)
@staticmethod
def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None:
def _update_tool_calls(
tool_calls: list[AssistantPromptMessage.ToolCall],
tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]],
) -> None:
if tool_calls_response:
for response_tool_call in tool_calls_response:
if isinstance(response_tool_call, ChatCompletionMessageToolCall):
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name,
arguments=response_tool_call.function.arguments
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id,
type=response_tool_call.type,
function=function
id=response_tool_call.id, type=response_tool_call.type, function=function
)
tool_calls.append(tool_call)
elif isinstance(response_tool_call, ChoiceDeltaToolCall):
@ -463,8 +462,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
tool_calls[index].id = response_tool_call.id or tool_calls[index].id
tool_calls[index].type = response_tool_call.type or tool_calls[index].type
if response_tool_call.function:
tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name
tool_calls[index].function.arguments += response_tool_call.function.arguments or ''
tool_calls[index].function.name = (
response_tool_call.function.name or tool_calls[index].function.name
)
tool_calls[index].function.arguments += response_tool_call.function.arguments or ""
else:
assert response_tool_call.id is not None
assert response_tool_call.type is not None
@ -473,13 +474,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
assert response_tool_call.function.arguments is not None
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name,
arguments=response_tool_call.function.arguments
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id,
type=response_tool_call.type,
function=function
id=response_tool_call.id, type=response_tool_call.type, function=function
)
tool_calls.append(tool_call)
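The ChoiceDeltaToolCall branch above merges streamed fragments by index: the id, type, and function name arrive once, while the JSON arguments string arrives in pieces and is concatenated. A toy illustration of that merging rule, with plain dicts standing in for the SDK types:

def merge_tool_call_delta(tool_calls: list[dict], delta: dict) -> None:
    index = delta["index"]
    if index < len(tool_calls):
        # later fragments only extend the arguments string
        tool_calls[index]["arguments"] += delta.get("arguments") or ""
    else:
        tool_calls.append(
            {"id": delta["id"], "name": delta["name"], "arguments": delta.get("arguments") or ""}
        )

calls: list[dict] = []
merge_tool_call_delta(calls, {"index": 0, "id": "call_1", "name": "get_weather", "arguments": '{"city": '})
merge_tool_call_delta(calls, {"index": 0, "arguments": '"Paris"}'})
print(calls)  # [{'id': 'call_1', 'name': 'get_weather', 'arguments': '{"city": "Paris"}'}]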
@ -495,19 +493,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
"detail": message_content.detail.value
}
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
@ -525,7 +517,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
"role": "tool",
"name": message.name,
"content": message.content,
"tool_call_id": message.tool_call_id
"tool_call_id": message.tool_call_id,
}
else:
raise ValueError(f"Got unknown type {message}")
@ -535,10 +527,11 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return message_dict
def _num_tokens_from_string(self, credentials: dict, text: str,
tools: Optional[list[PromptMessageTool]] = None) -> int:
def _num_tokens_from_string(
self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None
) -> int:
try:
encoding = tiktoken.encoding_for_model(credentials['base_model_name'])
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
@ -550,14 +543,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return num_tokens
def _num_tokens_from_messages(
self, credentials: dict, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None
self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
model = credentials['base_model_name']
model = credentials["base_model_name"]
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
@ -591,10 +583,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
# which need to download the image and then get the resolution for calculation,
# and will increase the request delay
if isinstance(value, list):
text = ''
text = ""
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
@ -626,40 +618,39 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
@staticmethod
def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int:
num_tokens = 0
for tool in tools:
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode('function'))
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode("function"))
# calculate num tokens for function object
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode("name"))
num_tokens += len(encoding.encode(tool.name))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode("description"))
num_tokens += len(encoding.encode(tool.description))
parameters = tool.parameters
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters['title']))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters['type']))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters['properties'].items():
num_tokens += len(encoding.encode("parameters"))
if "title" in parameters:
num_tokens += len(encoding.encode("title"))
num_tokens += len(encoding.encode(parameters["title"]))
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode(parameters["type"]))
if "properties" in parameters:
num_tokens += len(encoding.encode("properties"))
for key, value in parameters["properties"].items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
if "required" in parameters:
num_tokens += len(encoding.encode("required"))
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))

View File

@ -15,9 +15,7 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
Model class for OpenAI Speech to text model.
"""
def _invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke speech2text model
@ -40,7 +38,7 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
try:
audio_file_path = self._get_demo_file_path()
with open(audio_file_path, 'rb') as audio_file:
with open(audio_file_path, "rb") as audio_file:
self._speech2text_invoke(model, credentials, audio_file)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -65,10 +63,9 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
return response.text
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
return ai_model_entity.entity
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:

View File

@ -16,19 +16,18 @@ from core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_
class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
base_model_name = credentials['base_model_name']
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
base_model_name = credentials["base_model_name"]
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
extra_model_kwargs = {}
if user:
extra_model_kwargs['user'] = user
extra_model_kwargs["user"] = user
extra_model_kwargs['encoding_format'] = 'base64'
extra_model_kwargs["encoding_format"] = "base64"
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
@ -44,11 +43,9 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
enc = tiktoken.get_encoding("cl100k_base")
for i, text in enumerate(texts):
token = enc.encode(
text
)
token = enc.encode(text)
for j in range(0, len(token), context_size):
tokens += [token[j: j + context_size]]
tokens += [token[j : j + context_size]]
indices += [i]
batched_embeddings = []
@ -56,10 +53,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
for i in _iter:
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
client=client,
texts=tokens[i: i + max_chunks],
extra_model_kwargs=extra_model_kwargs
model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs
)
used_tokens += embedding_used_tokens
@ -75,10 +69,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
_result = results[i]
if len(_result) == 0:
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
client=client,
texts="",
extra_model_kwargs=extra_model_kwargs
model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs
)
used_tokens += embedding_used_tokens
@ -88,24 +79,16 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
embeddings[i] = (average / np.linalg.norm(average)).tolist()
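This is the tail of the long-input handling: a text longer than the context window was split into token chunks and embedded chunk by chunk, and the chunk vectors are then combined into a single unit-length embedding (typically an average weighted by chunk token counts; only the normalization is visible in this hunk). A small sketch of that combination step, with hypothetical numbers:

import numpy as np

chunk_embeddings = np.array([[0.6, 0.8], [1.0, 0.0]])
chunk_lengths = np.array([300, 100])  # hypothetical token counts per chunk

average = np.average(chunk_embeddings, axis=0, weights=chunk_lengths)
unit_embedding = (average / np.linalg.norm(average)).tolist()  # re-normalize to unit length
print(unit_embedding)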
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=used_tokens
)
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
return TextEmbeddingResult(
embeddings=embeddings,
usage=usage,
model=base_model_name
)
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=base_model_name)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
if len(texts) == 0:
return 0
try:
enc = tiktoken.encoding_for_model(credentials['base_model_name'])
enc = tiktoken.encoding_for_model(credentials["base_model_name"])
except KeyError:
enc = tiktoken.get_encoding("cl100k_base")
@ -118,57 +101,52 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
return total_num_tokens
def validate_credentials(self, model: str, credentials: dict) -> None:
if 'openai_api_base' not in credentials:
raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')
if "openai_api_base" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required")
if 'openai_api_key' not in credentials:
raise CredentialsValidateFailedError('Azure OpenAI API key is required')
if "openai_api_key" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
if 'base_model_name' not in credentials:
raise CredentialsValidateFailedError('Base Model Name is required')
if "base_model_name" not in credentials:
raise CredentialsValidateFailedError("Base Model Name is required")
if not self._get_ai_model_entity(credentials['base_model_name'], model):
if not self._get_ai_model_entity(credentials["base_model_name"], model):
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
try:
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
self._embedding_invoke(
model=model,
client=client,
texts=['ping'],
extra_model_kwargs={}
)
self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={})
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
return ai_model_entity.entity
@staticmethod
def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str],
extra_model_kwargs: dict) -> tuple[list[list[float]], int]:
def _embedding_invoke(
model: str, client: AzureOpenAI, texts: Union[list[str], str], extra_model_kwargs: dict
) -> tuple[list[list[float]], int]:
response = client.embeddings.create(
input=texts,
model=model,
**extra_model_kwargs,
)
if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64':
if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64":
# decode base64 embedding
return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data],
response.usage.total_tokens)
return (
[list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data],
response.usage.total_tokens,
)
return [data.embedding for data in response.data], response.usage.total_tokens
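With encoding_format="base64" the API returns each embedding as a base64 string of packed float32 values, which the np.frombuffer line above unpacks. A tiny round-trip sketch of that decoding:

import base64

import numpy as np

vector = [0.25, -1.0, 3.5]
encoded = base64.b64encode(np.array(vector, dtype="float32").tobytes()).decode()
decoded = list(np.frombuffer(base64.b64decode(encoded), dtype="float32"))
assert [float(x) for x in decoded] == vector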
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
@ -179,7 +157,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage

View File

@ -17,8 +17,9 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
Model class for OpenAI Speech to text model.
"""
def _invoke(self, model: str, tenant_id: str, credentials: dict,
content_text: str, voice: str, user: Optional[str] = None) -> any:
def _invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
) -> any:
"""
_invoke text2speech model
@ -30,13 +31,12 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
:param user: unique user id
:return: text translated to audio file
"""
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
if not voice or voice not in [
d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials)
]:
voice = self._get_model_default_voice(model, credentials)
return self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
voice=voice)
return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -50,14 +50,13 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
self._tts_invoke_streaming(
model=model,
credentials=credentials,
content_text='Hello Dify!',
content_text="Hello Dify!",
voice=self._get_model_default_voice(model, credentials),
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
voice: str) -> any:
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
"""
_tts_invoke_streaming text2speech model
:param model: model name
@ -75,23 +74,29 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
if len(content_text) > max_length:
sentences = self._split_text_into_sentences(content_text, max_length=max_length)
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
response_format="mp3",
input=sentences[i], voice=voice) for i in range(len(sentences))]
futures = [
executor.submit(
client.audio.speech.with_streaming_response.create,
model=model,
response_format="mp3",
input=sentences[i],
voice=voice,
)
for i in range(len(sentences))
]
for index, future in enumerate(futures):
yield from future.result().__enter__().iter_bytes(1024)
else:
response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
response_format="mp3",
input=content_text.strip())
response = client.audio.speech.with_streaming_response.create(
model=model, voice=voice, response_format="mp3", input=content_text.strip()
)
yield from response.__enter__().iter_bytes(1024)
except Exception as ex:
raise InvokeBadRequestError(str(ex))
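The long-text path above fans sentence synthesis out to a thread pool but drains the futures in submission order, so synthesis overlaps while the audio bytes still stream in sentence order. A toy model of that pattern (the synthesize stub stands in for the real TTS call):

import concurrent.futures
import time

def synthesize(sentence: str) -> bytes:
    # stand-in for client.audio.speech.with_streaming_response.create(...)
    time.sleep(0.05)
    return sentence.encode() + b"|"

def stream_sentences(sentences: list[str]):
    with concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) as executor:
        futures = [executor.submit(synthesize, s) for s in sentences]
        for future in futures:  # submission order, not completion order
            yield future.result()

print(b"".join(stream_sentences(["Hello.", "World."])))  # b'Hello.|World.|'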
def _process_sentence(self, sentence: str, model: str,
voice, credentials: dict):
def _process_sentence(self, sentence: str, model: str, voice, credentials: dict):
"""
_tts_invoke openai text2speech model api
@ -108,10 +113,9 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
return response.read()
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
return ai_model_entity.entity
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None:
for ai_model_entity in TTS_BASE_MODELS:
@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
logger = logging.getLogger(__name__)
class BaichuanProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
@ -19,12 +20,9 @@ class BaichuanProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)
# Use the `baichuan2-turbo` model to validate credentials.
model_instance.validate_credentials(
model='baichuan2-turbo',
credentials=credentials
)
model_instance.validate_credentials(model="baichuan2-turbo", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex
@ -4,17 +4,17 @@ import re
class BaichuanTokenizer:
@classmethod
def count_chinese_characters(cls, text: str) -> int:
return len(re.findall(r'[\u4e00-\u9fa5]', text))
return len(re.findall(r"[\u4e00-\u9fa5]", text))
@classmethod
def count_english_vocabularies(cls, text: str) -> int:
# remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc.
text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
text = re.sub(r"[^a-zA-Z0-9\s]", "", text)
# count the number of words not characters
return len(text.split())
@classmethod
def _get_num_tokens(cls, text: str) -> int:
# tokens = number of Chinese characters + number of English words * 1.3 (rough estimate only; the API's actual count may differ)
# https://platform.baichuan-ai.com/docs/text-Embedding
return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
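A quick standalone check of the estimate above, using the same two regexes (self-contained, for illustration only):

import re

def estimate_tokens(text: str) -> int:
    # one token per Chinese character, ~1.3 tokens per English word
    chinese = len(re.findall(r"[\u4e00-\u9fa5]", text))
    words = len(re.sub(r"[^a-zA-Z0-9\s]", "", text).split())
    return int(chinese + words * 1.3)

print(estimate_tokens("你好 world"))  # 2 + 1 * 1.3 -> 3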
@ -94,7 +94,6 @@ class BaichuanModel:
timeout: int,
tools: Optional[list[PromptMessageTool]] = None,
) -> Union[Iterator, dict]:
if model in self._model_mapping.keys():
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
else:
@ -120,9 +119,7 @@ class BaichuanModel:
err = resp["error"]["type"]
msg = resp["error"]["message"]
except Exception as e:
raise InternalServerError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
if err == "invalid_api_key":
raise InvalidAPIKeyError(msg)
@ -1,17 +1,22 @@
class InvalidAuthenticationError(Exception):
pass
class InvalidAPIKeyError(Exception):
pass
class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalance(Exception):
pass
class InternalServerError(Exception):
pass
class BadRequestError(Exception):
pass
pass
@ -38,17 +38,16 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
class BaichuanLanguageModel(LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
return self._generate(
model=model,
@ -60,17 +59,17 @@ class BaichuanLanguageModel(LargeLanguageModel):
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
return self._num_tokens_from_messages(prompt_messages)
def _num_tokens_from_messages(
self,
messages: list[PromptMessage],
self,
messages: list[PromptMessage],
) -> int:
"""Calculate num tokens for baichuan model"""
@ -111,18 +110,13 @@ class BaichuanLanguageModel(LargeLanguageModel):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [tool_call.dict() for tool_call in
message.tool_calls]
message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls]
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id
}
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
else:
raise ValueError(f"Unknown message type {type(message)}")
@ -146,15 +140,14 @@ class BaichuanLanguageModel(LargeLanguageModel):
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stream: bool = True,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stream: bool = True,
) -> LLMResult | Generator:
instance = BaichuanModel(api_key=credentials["api_key"])
messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
@ -169,23 +162,19 @@ class BaichuanLanguageModel(LargeLanguageModel):
)
if stream:
return self._handle_chat_generate_stream_response(
model, prompt_messages, credentials, response
)
return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
return self._handle_chat_generate_response(
model, prompt_messages, credentials, response
)
return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
def _handle_chat_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: dict,
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: dict,
) -> LLMResult:
choices = response.get("choices", [])
assistant_message = AssistantPromptMessage(content='', tool_calls=[])
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if choices and choices[0]["finish_reason"] == "tool_calls":
for choice in choices:
for tool_call in choice["message"]["tool_calls"]:
@ -194,7 +183,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
type=tool_call.get("type", ""),
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call.get("function", {}).get("name", ""),
arguments=tool_call.get("function", {}).get("arguments", "")
arguments=tool_call.get("function", {}).get("arguments", ""),
),
)
assistant_message.tool_calls.append(tool)
@ -228,11 +217,11 @@ class BaichuanLanguageModel(LargeLanguageModel):
)
def _handle_chat_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Iterator,
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Iterator,
) -> Generator:
for line in response:
if not line:
@ -260,9 +249,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=choice["delta"]["content"], tool_calls=[]
),
message=AssistantPromptMessage(content=choice["delta"]["content"], tool_calls=[]),
finish_reason=stop_reason,
),
)
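A hedged sketch of consuming the generator that _invoke returns when stream=True; the credential value is a placeholder and the model-runtime imports (UserPromptMessage, etc.) are omitted:

llm = BaichuanLanguageModel()
for chunk in llm._invoke(
    model="baichuan2-turbo",
    credentials={"api_key": "sk-..."},  # placeholder
    prompt_messages=[UserPromptMessage(content="Hello")],
    model_parameters={"temperature": 0.7},
    stream=True,
):
    print(chunk.delta.message.content, end="", flush=True)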
@ -31,11 +31,12 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for BaiChuan text embedding model.
"""
api_base: str = 'http://api.baichuan-ai.com/v1/embeddings'
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
api_base: str = "http://api.baichuan-ai.com/v1/embeddings"
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -45,28 +46,23 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
:param user: unique user id
:return: embeddings result
"""
api_key = credentials['api_key']
if model != 'baichuan-text-embedding':
raise ValueError('Invalid model name')
api_key = credentials["api_key"]
if model != "baichuan-text-embedding":
raise ValueError("Invalid model name")
if not api_key:
raise CredentialsValidateFailedError('api_key is required')
raise CredentialsValidateFailedError("api_key is required")
# split into chunks of batch size 16
chunks = []
for i in range(0, len(texts), 16):
chunks.append(texts[i:i + 16])
chunks.append(texts[i : i + 16])
embeddings = []
token_usage = 0
for chunk in chunks:
# embedding chunk
chunk_embeddings, chunk_usage = self.embedding(
model=model,
api_key=api_key,
texts=chunk,
user=user
)
chunk_embeddings, chunk_usage = self.embedding(model=model, api_key=api_key, texts=chunk, user=user)
embeddings.extend(chunk_embeddings)
token_usage += chunk_usage
@ -74,17 +70,14 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
model=model,
credentials=credentials,
tokens=token_usage
)
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
)
return result
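The batch-of-16 chunking used in _invoke above, shown standalone:

texts = [f"doc {i}" for i in range(40)]
chunks = [texts[i : i + 16] for i in range(0, len(texts), 16)]
print([len(c) for c in chunks])  # [16, 16, 8]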
def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \
-> tuple[list[list[float]], int]:
def embedding(
self, model: str, api_key, texts: list[str], user: Optional[str] = None
) -> tuple[list[list[float]], int]:
"""
Embed given texts
@ -95,56 +88,47 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
:return: embeddings result
"""
url = self.api_base
headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
}
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
data = {
'model': 'Baichuan-Text-Embedding',
'input': texts
}
data = {"model": "Baichuan-Text-Embedding", "input": texts}
try:
response = post(url, headers=headers, data=dumps(data))
except Exception as e:
raise InvokeConnectionError(str(e))
if response.status_code != 200:
try:
resp = response.json()
# try to parse error message
err = resp['error']['code']
msg = resp['error']['message']
err = resp["error"]["code"]
msg = resp["error"]["message"]
except Exception as e:
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
if err == 'invalid_api_key':
if err == "invalid_api_key":
raise InvalidAPIKeyError(msg)
elif err == 'insufficient_quota':
elif err == "insufficient_quota":
raise InsufficientAccountBalance(msg)
elif err == 'invalid_authentication':
raise InvalidAuthenticationError(msg)
elif err and 'rate' in err:
elif err == "invalid_authentication":
raise InvalidAuthenticationError(msg)
elif err and "rate" in err:
raise RateLimitReachedError(msg)
elif err and 'internal' in err:
elif err and "internal" in err:
raise InternalServerError(msg)
elif err == 'api_key_empty':
elif err == "api_key_empty":
raise InvalidAPIKeyError(msg)
else:
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
try:
resp = response.json()
embeddings = resp['data']
usage = resp['usage']
embeddings = resp["data"]
usage = resp["usage"]
except Exception as e:
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
return [
data['embedding'] for data in embeddings
], usage['total_tokens']
return [data["embedding"] for data in embeddings], usage["total_tokens"]
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
@ -170,32 +154,24 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
self._invoke(model=model, credentials=credentials, texts=["ping"])
except InvalidAPIKeyError:
raise CredentialsValidateFailedError('Invalid api key')
raise CredentialsValidateFailedError("Invalid api key")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeConnectionError: [],
InvokeServerUnavailableError: [InternalServerError],
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
InvokeBadRequestError: [BadRequestError, KeyError],
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
@ -207,10 +183,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
@ -221,7 +194,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
logger = logging.getLogger(__name__)
class BedrockProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
@ -19,13 +20,10 @@ class BedrockProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)
# Use `amazon.titan-text-lite-v1` model by default for validating credentials
model_for_validation = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1')
model_instance.validate_credentials(
model=model_for_validation,
credentials=credentials
)
model_for_validation = credentials.get("model_for_validation", "amazon.titan-text-lite-v1")
model_instance.validate_credentials(model=model_for_validation, credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex
@ -45,36 +45,42 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
logger = logging.getLogger(__name__)
class BedrockLargeLanguageModel(LargeLanguageModel):
class BedrockLargeLanguageModel(LargeLanguageModel):
# please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
# TODO: There is an invoke issue (context limit) with the Cohere models; they will be added after it is fixed.
CONVERSE_API_ENABLED_MODEL_INFO=[
{'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False},
{'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False},
{'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True},
{'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False},
{'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
{'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
{'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
{'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True},
{'prefix': 'cohere.command-r', 'support_system_prompts': True, 'support_tool_use': True},
{'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False}
CONVERSE_API_ENABLED_MODEL_INFO = [
{"prefix": "anthropic.claude-v2", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "anthropic.claude-v1", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False},
{"prefix": "mistral.mixtral-8x7b-instruct", "support_system_prompts": False, "support_tool_use": False},
{"prefix": "mistral.mistral-large", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "mistral.mistral-small", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False},
]
@staticmethod
def _find_model_info(model_id):
for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO:
if model_id.startswith(model['prefix']):
if model_id.startswith(model["prefix"]):
return model
logger.info(f"current model id: {model_id} did not support by Converse API")
return None
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -88,17 +94,28 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
model_info= BedrockLargeLanguageModel._find_model_info(model)
model_info = BedrockLargeLanguageModel._find_model_info(model)
if model_info:
model_info['model'] = model
model_info["model"] = model
# invoke models via boto3 converse API
return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
return self._generate_with_converse(
model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools
)
# invoke other models via boto3 client
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
def _generate_with_converse(
self,
model_info: dict,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
tools: Optional[list[PromptMessageTool]] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model with converse API
@ -110,35 +127,39 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param stream: is stream response
:return: full response or stream response chunk generator result
"""
bedrock_client = boto3.client(service_name='bedrock-runtime',
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
region_name=credentials["aws_region"])
bedrock_client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
region_name=credentials["aws_region"],
)
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
parameters = {
'modelId': model_info['model'],
'messages': prompt_message_dicts,
'inferenceConfig': inference_config,
'additionalModelRequestFields': additional_model_fields,
"modelId": model_info["model"],
"messages": prompt_message_dicts,
"inferenceConfig": inference_config,
"additionalModelRequestFields": additional_model_fields,
}
if model_info['support_system_prompts'] and system and len(system) > 0:
parameters['system'] = system
if model_info["support_system_prompts"] and system and len(system) > 0:
parameters["system"] = system
if model_info['support_tool_use'] and tools:
parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
if model_info["support_tool_use"] and tools:
parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools)
try:
if stream:
response = bedrock_client.converse_stream(**parameters)
return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages)
return self._handle_converse_stream_response(
model_info["model"], credentials, response, prompt_messages
)
else:
response = bedrock_client.converse(**parameters)
return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages)
return self._handle_converse_response(model_info["model"], credentials, response, prompt_messages)
except ClientError as ex:
error_code = ex.response['Error']['Code']
error_code = ex.response["Error"]["Code"]
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
raise self._map_client_to_invoke_error(error_code, full_error_msg)
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
@ -149,8 +170,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
except Exception as ex:
raise InvokeError(str(ex))
def _handle_converse_response(self, model: str, credentials: dict, response: dict,
prompt_messages: list[PromptMessage]) -> LLMResult:
def _handle_converse_response(
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm chat response
@ -160,36 +183,30 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages
:return: full response chunk generator result
"""
response_content = response['output']['message']['content']
response_content = response["output"]["message"]["content"]
# transform assistant message to prompt message
if response['stopReason'] == 'tool_use':
if response["stopReason"] == "tool_use":
tool_calls = []
text, tool_use = self._extract_tool_use(response_content)
tool_call = AssistantPromptMessage.ToolCall(
id=tool_use['toolUseId'],
type='function',
id=tool_use["toolUseId"],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_use['name'],
arguments=json.dumps(tool_use['input'])
)
name=tool_use["name"], arguments=json.dumps(tool_use["input"])
),
)
tool_calls.append(tool_call)
assistant_prompt_message = AssistantPromptMessage(
content=text,
tool_calls=tool_calls
)
assistant_prompt_message = AssistantPromptMessage(content=text, tool_calls=tool_calls)
else:
assistant_prompt_message = AssistantPromptMessage(
content=response_content[0]['text']
)
assistant_prompt_message = AssistantPromptMessage(content=response_content[0]["text"])
# calculate num tokens
if response['usage']:
if response["usage"]:
# transform usage
prompt_tokens = response['usage']['inputTokens']
completion_tokens = response['usage']['outputTokens']
prompt_tokens = response["usage"]["inputTokens"]
completion_tokens = response["usage"]["outputTokens"]
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@ -206,20 +223,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
)
return result
def _extract_tool_use(self, content:dict)-> tuple[str, dict]:
def _extract_tool_use(self, content: dict) -> tuple[str, dict]:
tool_use = {}
text = ''
text = ""
for item in content:
if 'toolUse' in item:
tool_use = item['toolUse']
elif 'text' in item:
text = item['text']
if "toolUse" in item:
tool_use = item["toolUse"]
elif "text" in item:
text = item["text"]
else:
raise ValueError(f"Got unknown item: {item}")
return text, tool_use
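An illustrative call to _extract_tool_use (the tool name and input here are invented for the example):

content = [
    {"text": "Let me check the weather."},
    {"toolUse": {"toolUseId": "t-1", "name": "get_weather", "input": {"city": "Paris"}}},
]
text, tool_use = self._extract_tool_use(content)
# text == "Let me check the weather."
# tool_use["name"] == "get_weather"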
def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
prompt_messages: list[PromptMessage], ) -> Generator:
def _handle_converse_stream_response(
self,
model: str,
credentials: dict,
response: dict,
prompt_messages: list[PromptMessage],
) -> Generator:
"""
Handle llm chat stream response
@ -231,7 +253,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
"""
try:
full_assistant_content = ''
full_assistant_content = ""
return_model = None
input_tokens = 0
output_tokens = 0
@ -240,87 +262,85 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_use = {}
for chunk in response['stream']:
if 'messageStart' in chunk:
for chunk in response["stream"]:
if "messageStart" in chunk:
return_model = model
elif 'messageStop' in chunk:
finish_reason = chunk['messageStop']['stopReason']
elif 'contentBlockStart' in chunk:
tool = chunk['contentBlockStart']['start']['toolUse']
tool_use['toolUseId'] = tool['toolUseId']
tool_use['name'] = tool['name']
elif 'metadata' in chunk:
input_tokens = chunk['metadata']['usage']['inputTokens']
output_tokens = chunk['metadata']['usage']['outputTokens']
elif "messageStop" in chunk:
finish_reason = chunk["messageStop"]["stopReason"]
elif "contentBlockStart" in chunk:
tool = chunk["contentBlockStart"]["start"]["toolUse"]
tool_use["toolUseId"] = tool["toolUseId"]
tool_use["name"] = tool["name"]
elif "metadata" in chunk:
input_tokens = chunk["metadata"]["usage"]["inputTokens"]
output_tokens = chunk["metadata"]["usage"]["outputTokens"]
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
yield LLMResultChunk(
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(
content='',
tool_calls=tool_calls
),
message=AssistantPromptMessage(content="", tool_calls=tool_calls),
finish_reason=finish_reason,
usage=usage
)
usage=usage,
),
)
elif 'contentBlockDelta' in chunk:
delta = chunk['contentBlockDelta']['delta']
if 'text' in delta:
chunk_text = delta['text'] if delta['text'] else ''
elif "contentBlockDelta" in chunk:
delta = chunk["contentBlockDelta"]["delta"]
if "text" in delta:
chunk_text = delta["text"] if delta["text"] else ""
full_assistant_content += chunk_text
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text if chunk_text else '',
content=chunk_text if chunk_text else "",
)
index = chunk['contentBlockDelta']['contentBlockIndex']
index = chunk["contentBlockDelta"]["contentBlockIndex"]
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index+1,
index=index + 1,
message=assistant_prompt_message,
)
),
)
elif 'toolUse' in delta:
if 'input' not in tool_use:
tool_use['input'] = ''
tool_use['input'] += delta['toolUse']['input']
elif 'contentBlockStop' in chunk:
if 'input' in tool_use:
elif "toolUse" in delta:
if "input" not in tool_use:
tool_use["input"] = ""
tool_use["input"] += delta["toolUse"]["input"]
elif "contentBlockStop" in chunk:
if "input" in tool_use:
tool_call = AssistantPromptMessage.ToolCall(
id=tool_use['toolUseId'],
type='function',
id=tool_use["toolUseId"],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_use['name'],
arguments=tool_use['input']
)
name=tool_use["name"], arguments=tool_use["input"]
),
)
tool_calls.append(tool_call)
tool_use = {}
except Exception as ex:
raise InvokeError(str(ex))
def _convert_converse_api_model_parameters(self, model_parameters: dict, stop: Optional[list[str]] = None) -> tuple[dict, dict]:
def _convert_converse_api_model_parameters(
self, model_parameters: dict, stop: Optional[list[str]] = None
) -> tuple[dict, dict]:
inference_config = {}
additional_model_fields = {}
if 'max_tokens' in model_parameters:
inference_config['maxTokens'] = model_parameters['max_tokens']
if "max_tokens" in model_parameters:
inference_config["maxTokens"] = model_parameters["max_tokens"]
if 'temperature' in model_parameters:
inference_config['temperature'] = model_parameters['temperature']
if 'top_p' in model_parameters:
inference_config['topP'] = model_parameters['temperature']
if "temperature" in model_parameters:
inference_config["temperature"] = model_parameters["temperature"]
if "top_p" in model_parameters:
inference_config["topP"] = model_parameters["temperature"]
if stop:
inference_config['stopSequences'] = stop
if 'top_k' in model_parameters:
additional_model_fields['top_k'] = model_parameters['top_k']
inference_config["stopSequences"] = stop
if "top_k" in model_parameters:
additional_model_fields["top_k"] = model_parameters["top_k"]
return inference_config, additional_model_fields
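A worked example of the mapping above (with topP taken from top_p): camelCase keys land in inferenceConfig, while top_k goes to additionalModelRequestFields:

cfg, extra = self._convert_converse_api_model_parameters(
    {"max_tokens": 256, "temperature": 0.5, "top_p": 0.9, "top_k": 40},
    stop=["\n\nHuman:"],
)
# cfg   == {"maxTokens": 256, "temperature": 0.5, "topP": 0.9,
#           "stopSequences": ["\n\nHuman:"]}
# extra == {"top_k": 40}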
def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
@ -332,7 +352,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
prompt_message_dicts = []
for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
message.content=message.content.strip()
message.content = message.content.strip()
system.append({"text": message.content})
else:
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
@ -349,15 +369,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
"toolSpec": {
"name": tool.name,
"description": tool.description,
"inputSchema": {
"json": tool.parameters
}
"inputSchema": {"json": tool.parameters},
}
}
)
tool_config["tools"] = configs
return tool_config
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict
@ -365,15 +383,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": [{'text': message.content}]}
message_dict = {"role": "user", "content": [{"text": message.content}]}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {
"text": message_content.data
}
sub_message_dict = {"text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
@ -384,7 +400,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
image_content = requests.get(url).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode('utf-8')
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
@ -394,16 +410,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
image_content = base64.b64decode(base64_data)
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
raise ValueError(f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp")
raise ValueError(
f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp"
)
sub_message_dict = {
"image": {
"format": mime_type.replace('image/', ''),
"source": {
"bytes": image_content
}
}
"image": {"format": mime_type.replace("image/", ""), "source": {"bytes": image_content}}
}
sub_messages.append(sub_message_dict)
@ -412,36 +425,46 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
message = cast(AssistantPromptMessage, message)
if message.tool_calls:
message_dict = {
"role": "assistant", "content":[{
"toolUse": {
"toolUseId": message.tool_calls[0].id,
"name": message.tool_calls[0].function.name,
"input": json.loads(message.tool_calls[0].function.arguments)
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": message.tool_calls[0].id,
"name": message.tool_calls[0].function.name,
"input": json.loads(message.tool_calls[0].function.arguments),
}
}
}]
],
}
else:
message_dict = {"role": "assistant", "content": [{'text': message.content}]}
message_dict = {"role": "assistant", "content": [{"text": message.content}]}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = [{'text': message.content}]
message_dict = [{"text": message.content}]
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "user",
"content": [{
"toolResult": {
"toolUseId": message.tool_call_id,
"content": [{"json": {"text": message.content}}]
}
}]
"content": [
{
"toolResult": {
"toolUseId": message.tool_call_id,
"content": [{"json": {"text": message.content}}],
}
}
],
}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
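For a plain-text user message, the conversion above produces the Converse API shape directly:

msg = UserPromptMessage(content="hi")
self._convert_prompt_message_to_dict(msg)
# -> {"role": "user", "content": [{"text": "hi"}]}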
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage] | str,
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -451,15 +474,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param tools: tools for tool calling
:return:
"""
prefix = model.split('.')[0]
model_name = model.split('.')[1]
prefix = model.split(".")[0]
model_name = model.split(".")[1]
if isinstance(prompt_messages, str):
prompt = prompt_messages
else:
prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)
return self._get_num_tokens_by_gpt2(prompt)
def validate_credentials(self, model: str, credentials: dict) -> None:
@ -482,24 +504,28 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
"topP": 0.9,
"maxTokens": 32,
}
try:
ping_message = UserPromptMessage(content="ping")
self._invoke(model=model,
credentials=credentials,
prompt_messages=[ping_message],
model_parameters=required_params,
stream=False)
self._invoke(
model=model,
credentials=credentials,
prompt_messages=[ping_message],
model_parameters=required_params,
stream=False,
)
except ClientError as ex:
error_code = ex.response['Error']['Code']
error_code = ex.response["Error"]["Code"]
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg)))
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str:
def _convert_one_message_to_text(
self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None
) -> str:
"""
Convert a single message to a string.
@ -514,7 +540,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
if isinstance(message, UserPromptMessage):
body = content
if (isinstance(content, list)):
if isinstance(content, list):
body = "".join([c.data for c in content if c.type == PromptMessageContentType.TEXT])
message_text = f"{human_prompt_prefix} {body} {human_prompt_postfix}"
elif isinstance(message, AssistantPromptMessage):
@ -528,7 +554,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return message_text
def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str:
def _convert_messages_to_prompt(
self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None
) -> str:
"""
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
@ -537,27 +565,31 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
if not messages:
return ''
return ""
messages = messages.copy() # don't mutate the original list
if not isinstance(messages[-1], AssistantPromptMessage):
messages.append(AssistantPromptMessage(content=""))
text = "".join(
self._convert_one_message_to_text(message, model_prefix, model_name)
for message in messages
)
text = "".join(self._convert_one_message_to_text(message, model_prefix, model_name) for message in messages)
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
def _create_payload(
self,
model: str,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
):
"""
Create payload for bedrock api call depending on model provider
"""
payload = {}
model_prefix = model.split('.')[0]
model_name = model.split('.')[1]
model_prefix = model.split(".")[0]
model_name = model.split(".")[1]
if model_prefix == "ai21":
payload["temperature"] = model_parameters.get("temperature")
@ -571,21 +603,27 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")}
if model_parameters.get("countPenalty"):
payload["countPenalty"] = {model_parameters.get("countPenalty")}
elif model_prefix == "cohere":
payload = { **model_parameters }
payload = {**model_parameters}
payload["prompt"] = prompt_messages[0].content
payload["stream"] = stream
else:
raise ValueError(f"Got unknown model prefix {model_prefix}")
return payload
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -598,18 +636,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
client_config = Config(
region_name=credentials["aws_region"]
)
client_config = Config(region_name=credentials["aws_region"])
runtime_client = boto3.client(
service_name='bedrock-runtime',
service_name="bedrock-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key")
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
model_prefix = model.split('.')[0]
model_prefix = model.split(".")[0]
payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream)
# need a workaround for ai21 models, which don't support streaming
@ -619,18 +655,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
invoke = runtime_client.invoke_model
try:
body_jsonstr=json.dumps(payload)
response = invoke(
modelId=model,
contentType="application/json",
accept= "*/*",
body=body_jsonstr
)
body_jsonstr = json.dumps(payload)
response = invoke(modelId=model, contentType="application/json", accept="*/*", body=body_jsonstr)
except ClientError as ex:
error_code = ex.response['Error']['Code']
error_code = ex.response["Error"]["Code"]
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
raise self._map_client_to_invoke_error(error_code, full_error_msg)
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
raise InvokeConnectionError(str(ex))
@ -639,15 +670,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
except Exception as ex:
raise InvokeError(str(ex))
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, response: dict,
prompt_messages: list[PromptMessage]) -> LLMResult:
def _handle_generate_response(
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response
@ -657,7 +688,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response
"""
response_body = json.loads(response.get('body').read().decode('utf-8'))
response_body = json.loads(response.get("body").read().decode("utf-8"))
finish_reason = response_body.get("error")
@ -665,25 +696,23 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
raise InvokeError(finish_reason)
# get output text and calculate num tokens based on model / provider
model_prefix = model.split('.')[0]
model_prefix = model.split(".")[0]
if model_prefix == "ai21":
output = response_body.get('completions')[0].get('data').get('text')
output = response_body.get("completions")[0].get("data").get("text")
prompt_tokens = len(response_body.get("prompt").get("tokens"))
completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens"))
elif model_prefix == "cohere":
output = response_body.get("generations")[0].get("text")
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
completion_tokens = self.get_num_tokens(model, credentials, output if output else "")
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
# construct assistant message from output
assistant_prompt_message = AssistantPromptMessage(
content=output
)
assistant_prompt_message = AssistantPromptMessage(content=output)
# calculate usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@ -698,8 +727,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return result
def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict,
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
@ -709,65 +739,59 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
model_prefix = model.split('.')[0]
model_prefix = model.split(".")[0]
if model_prefix == "ai21":
response_body = json.loads(response.get('body').read().decode('utf-8'))
response_body = json.loads(response.get("body").read().decode("utf-8"))
content = response_body.get('completions')[0].get('data').get('text')
finish_reason = response_body.get('completions')[0].get('finish_reason')
content = response_body.get("completions")[0].get("data").get("text")
finish_reason = response_body.get("completions")[0].get("finish_reason")
prompt_tokens = len(response_body.get("prompt").get("tokens"))
completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens"))
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=content),
finish_reason=finish_reason,
usage=usage
)
)
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0, message=AssistantPromptMessage(content=content), finish_reason=finish_reason, usage=usage
),
)
return
stream = response.get('body')
stream = response.get("body")
if not stream:
raise InvokeError('No response body')
raise InvokeError("No response body")
index = -1
for event in stream:
chunk = event.get('chunk')
chunk = event.get("chunk")
if not chunk:
exception_name = next(iter(event))
full_ex_msg = f"{exception_name}: {event[exception_name]['message']}"
raise self._map_client_to_invoke_error(exception_name, full_ex_msg)
payload = json.loads(chunk.get('bytes').decode())
payload = json.loads(chunk.get("bytes").decode())
model_prefix = model.split('.')[0]
model_prefix = model.split(".")[0]
if model_prefix == "cohere":
content_delta = payload.get("text")
finish_reason = payload.get("finish_reason")
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response")
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content = content_delta if content_delta else '',
content=content_delta if content_delta else "",
)
index += 1
if not finish_reason:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
)
else:
@ -777,18 +801,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
finish_reason=finish_reason,
usage=usage
)
index=index, message=assistant_prompt_message, finish_reason=finish_reason, usage=usage
),
)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
@ -804,9 +825,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: []
InvokeBadRequestError: [],
}
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
"""
Map client error to invoke error
@ -822,7 +843,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return InvokeBadRequestError(error_msg)
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
return InvokeRateLimitError(error_msg)
elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
elif error_code in [
"ModelTimeoutException",
"ModelErrorException",
"InternalServerException",
"ModelNotReadyException",
]:
return InvokeServerUnavailableError(error_msg)
elif error_code == "ModelStreamErrorException":
return InvokeConnectionError(error_msg)
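A sketch of how the mapping classifies AWS error codes (codes taken from the branches above; the assertions are illustrative only):

err = self._map_client_to_invoke_error("ThrottlingException", "Rate exceeded")
assert isinstance(err, InvokeRateLimitError)
err = self._map_client_to_invoke_error("ModelTimeoutException", "Timed out")
assert isinstance(err, InvokeServerUnavailableError)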
@ -27,12 +27,11 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
logger = logging.getLogger(__name__)
class BedrockTextEmbeddingModel(TextEmbeddingModel):
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -42,67 +41,56 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
:param user: unique user id
:return: embeddings result
"""
client_config = Config(
region_name=credentials["aws_region"]
)
client_config = Config(region_name=credentials["aws_region"])
bedrock_runtime = boto3.client(
service_name='bedrock-runtime',
service_name="bedrock-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key")
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
embeddings = []
token_usage = 0
model_prefix = model.split('.')[0]
if model_prefix == "amazon" :
model_prefix = model.split(".")[0]
if model_prefix == "amazon":
for text in texts:
body = {
"inputText": text,
"inputText": text,
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend([response_body.get('embedding')])
token_usage += response_body.get('inputTextTokenCount')
logger.warning(f'Total Tokens: {token_usage}')
embeddings.extend([response_body.get("embedding")])
token_usage += response_body.get("inputTextTokenCount")
logger.warning(f"Total Tokens: {token_usage}")
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
model=model,
credentials=credentials,
tokens=token_usage
)
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
)
return result
if model_prefix == "cohere" :
input_type = 'search_document' if len(texts) > 1 else 'search_query'
if model_prefix == "cohere":
input_type = "search_document" if len(texts) > 1 else "search_query"
for text in texts:
body = {
"texts": [text],
"input_type": input_type,
"texts": [text],
"input_type": input_type,
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend(response_body.get('embeddings'))
embeddings.extend(response_body.get("embeddings"))
token_usage += len(text)
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
model=model,
credentials=credentials,
tokens=token_usage
)
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
)
return result
#others
# others
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -125,7 +113,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
:param credentials: model credentials
:return:
"""
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
@ -141,19 +129,25 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: []
InvokeBadRequestError: [],
}
def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
def _create_payload(
self,
model_prefix: str,
texts: list[str],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
):
"""
Create payload for bedrock api call depending on model provider
"""
payload = {}
if model_prefix == "amazon":
payload['inputText'] = texts
payload["inputText"] = texts
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
@ -165,10 +159,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
@ -179,7 +170,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
@ -199,31 +190,37 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
return InvokeBadRequestError(error_msg)
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
return InvokeRateLimitError(error_msg)
elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
elif error_code in [
"ModelTimeoutException",
"ModelErrorException",
"InternalServerException",
"ModelNotReadyException",
]:
return InvokeServerUnavailableError(error_msg)
elif error_code == "ModelStreamErrorException":
return InvokeConnectionError(error_msg)
return InvokeError(error_msg)
def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ):
accept = 'application/json'
content_type = 'application/json'
def _invoke_bedrock_embedding(
self,
model: str,
bedrock_runtime,
body: dict,
):
accept = "application/json"
content_type = "application/json"
try:
response = bedrock_runtime.invoke_model(
body=json.dumps(body),
modelId=model,
accept=accept,
contentType=content_type
body=json.dumps(body), modelId=model, accept=accept, contentType=content_type
)
response_body = json.loads(response.get('body').read().decode('utf-8'))
response_body = json.loads(response.get("body").read().decode("utf-8"))
return response_body
except ClientError as ex:
error_code = ex.response['Error']['Code']
error_code = ex.response["Error"]["Code"]
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
raise self._map_client_to_invoke_error(error_code, full_error_msg)
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
raise InvokeConnectionError(str(ex))
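For reference, the per-provider request bodies assembled in _invoke above, both sent through _invoke_bedrock_embedding:

amazon_body = {"inputText": "hello world"}
cohere_body = {"texts": ["hello world"], "input_type": "search_query"}
# either body is serialized and sent as:
#   bedrock_runtime.invoke_model(body=json.dumps(body), modelId=model,
#                                accept="application/json", contentType="application/json")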
@ -20,12 +20,9 @@ class ChatGLMProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)
# Use the `chatglm3-6b` model to validate credentials.
model_instance.validate_credentials(
model='chatglm3-6b',
credentials=credentials
)
model_instance.validate_credentials(model="chatglm3-6b", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex
@ -43,12 +43,19 @@ from core.model_runtime.utils import helper
logger = logging.getLogger(__name__)
class ChatGLMLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
"""
Invoke large language model
@ -71,11 +78,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -96,11 +108,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, prompt_messages=[
UserPromptMessage(content="ping"),
], model_parameters={
"max_tokens": 16,
})
self._invoke(
model=model,
credentials=credentials,
prompt_messages=[
UserPromptMessage(content="ping"),
],
model_parameters={
"max_tokens": 16,
},
)
except Exception as e:
raise CredentialsValidateFailedError(str(e))
@ -124,24 +141,24 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
ConflictError,
NotFoundError,
UnprocessableEntityError,
PermissionDeniedError
PermissionDeniedError,
],
InvokeRateLimitError: [
RateLimitError
],
InvokeAuthorizationError: [
AuthenticationError
],
InvokeBadRequestError: [
ValueError
]
InvokeRateLimitError: [RateLimitError],
InvokeAuthorizationError: [AuthenticationError],
InvokeBadRequestError: [ValueError],
}
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
"""
Invoke large language model
@ -163,35 +180,31 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop'] = stop
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs['user'] = user
extra_model_kwargs["user"] = user
if tools and len(tools) > 0:
extra_model_kwargs['functions'] = [
helper.dump_model(tool) for tool in tools
]
extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools]
result = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
model=model,
stream=stream,
**model_parameters,
**extra_model_kwargs
**extra_model_kwargs,
)
if stream:
return self._handle_chat_generate_stream_response(
model=model, credentials=credentials, response=result, tools=tools,
prompt_messages=prompt_messages
model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages
)
return self._handle_chat_generate_response(
model=model, credentials=credentials, response=result, tools=tools,
prompt_messages=prompt_messages
model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages
)
def _check_chatglm_parameters(self, model: str, model_parameters: dict, tools: list[PromptMessageTool]) -> None:
if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0:
raise InvokeBadRequestError("ChatGLM2 does not support function calling")
@ -212,7 +225,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments
"arguments": message.tool_calls[0].function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
@ -223,12 +236,12 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
message_dict = {"role": "function", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
def _extract_response_tool_calls(self,
response_function_calls: list[FunctionCall]) \
-> list[AssistantPromptMessage.ToolCall]:
def _extract_response_tool_calls(
self, response_function_calls: list[FunctionCall]
) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
@ -239,19 +252,14 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
if response_function_calls:
for response_tool_call in response_function_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.name,
arguments=response_tool_call.arguments
name=response_tool_call.name, arguments=response_tool_call.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=0,
type='function',
function=function
)
tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function)
tool_calls.append(tool_call)
return tool_calls
def _to_client_kwargs(self, credentials: dict) -> dict:
"""
Convert invoke kwargs to client kwargs
@ -265,17 +273,20 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1",
"base_url": str(URL(credentials['api_base']) / 'v1')
"base_url": str(URL(credentials["api_base"]) / "v1"),
}
return client_kwargs
def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) \
-> Generator:
full_response = ''
def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> Generator:
full_response = ""
for chunk in response:
if len(chunk.choices) == 0:
@ -283,9 +294,9 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""):
continue
# check if there is a tool call in the response
function_calls = None
if delta.delta.function_call:
@ -295,23 +306,25 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else '',
tool_calls=assistant_message_tool_calls
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
)
if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response,
tool_calls=assistant_message_tool_calls
content=full_response, tool_calls=assistant_message_tool_calls
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
@ -320,7 +333,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage
usage=usage,
),
)
else:
@ -335,11 +348,15 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
)
full_response += delta.delta.content
def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) \
-> LLMResult:
def _handle_chat_generate_response(
self,
model: str,
credentials: dict,
response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> LLMResult:
"""
Handle llm chat response
@ -359,15 +376,14 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else [])
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_message.content,
tool_calls=tool_calls
)
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
usage = self._calc_response_usage(
model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
)
response = LLMResult(
model=model,
@ -378,7 +394,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
)
return response
def _num_tokens_from_string(self, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Calculate num tokens for text completion model with tiktoken package.
@ -395,17 +411,19 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
return num_tokens
def _num_tokens_from_messages(self, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def _num_tokens_from_messages(
self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer.
It's too complex to calculate num tokens for chatglm2 and chatglm3 with the ChatGLM tokenizer,
so as a temporary solution we use the GPT2 tokenizer instead.
"""
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
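
As a rough illustration of the GPT2 fallback described above, a self-contained sketch using the Hugging Face tokenizer as a stand-in for the runtime's _get_num_tokens_by_gpt2 helper (the per-message constants mirror the code):

from transformers import GPT2TokenizerFast

_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

def approx_num_tokens(messages: list[dict]) -> int:
    tokens_per_message = 3  # framing overhead per message, as above
    tokens_per_name = 1
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(_tokenizer.encode(str(value)))
            if key == "name":
                num_tokens += tokens_per_name
    return num_tokens
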
@ -414,10 +432,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ''
text = ""
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
if key == "function_call":
@ -452,36 +470,37 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
:param tools: tools for tool calling
:return: number of tokens
"""
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
num_tokens = 0
for tool in tools:
# calculate num tokens for function object
num_tokens += tokens('name')
num_tokens += tokens("name")
num_tokens += tokens(tool.name)
num_tokens += tokens('description')
num_tokens += tokens("description")
num_tokens += tokens(tool.description)
parameters = tool.parameters
num_tokens += tokens('parameters')
num_tokens += tokens('type')
num_tokens += tokens("parameters")
num_tokens += tokens("type")
num_tokens += tokens(parameters.get("type"))
if 'properties' in parameters:
num_tokens += tokens('properties')
for key, value in parameters.get('properties').items():
if "properties" in parameters:
num_tokens += tokens("properties")
for key, value in parameters.get("properties").items():
num_tokens += tokens(key)
for field_key, field_value in value.items():
num_tokens += tokens(field_key)
if field_key == 'enum':
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += tokens(enum_field)
else:
num_tokens += tokens(field_key)
num_tokens += tokens(str(field_value))
if 'required' in parameters:
num_tokens += tokens('required')
for required_field in parameters['required']:
if "required" in parameters:
num_tokens += tokens("required")
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += tokens(required_field)
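
Worked example for the property loop above, on an illustrative schema {"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}: it counts tokens("unit"); then for the "type" field it counts tokens("type") twice plus tokens("string") (the field key is counted once before the branch and again inside the else); and for the "enum" field it counts tokens("enum") plus, per option, a fixed 3 tokens and the option's own tokens, i.e. 3 + tokens("celsius") and 3 + tokens("fahrenheit").
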

View File

@ -20,12 +20,9 @@ class CohereProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.RERANK)
# Use the `rerank-english-v2.0` model for validation.
model_instance.validate_credentials(
model='rerank-english-v2.0',
credentials=credentials
)
model_instance.validate_credentials(model="rerank-english-v2.0", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -55,11 +55,17 @@ class CohereLargeLanguageModel(LargeLanguageModel):
Model class for Cohere large language model.
"""
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -85,7 +91,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
else:
return self._generate(
@ -95,11 +101,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=user
user=user,
)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -136,30 +147,37 @@ class CohereLargeLanguageModel(LargeLanguageModel):
self._chat_generate(
model=model,
credentials=credentials,
prompt_messages=[UserPromptMessage(content='ping')],
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={
'max_tokens': 20,
'temperature': 0,
"max_tokens": 20,
"temperature": 0,
},
stream=False
stream=False,
)
else:
self._generate(
model=model,
credentials=credentials,
prompt_messages=[UserPromptMessage(content='ping')],
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={
'max_tokens': 20,
'temperature': 0,
"max_tokens": 20,
"temperature": 0,
},
stream=False
stream=False,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke llm model
@ -173,17 +191,17 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# initialize client
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
if stop:
model_parameters['end_sequences'] = stop
model_parameters["end_sequences"] = stop
if stream:
response = client.generate_stream(
prompt=prompt_messages[0].content,
model=model,
**model_parameters,
request_options=RequestOptions(max_retries=0)
request_options=RequestOptions(max_retries=0),
)
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
@ -192,14 +210,14 @@ class CohereLargeLanguageModel(LargeLanguageModel):
prompt=prompt_messages[0].content,
model=model,
**model_parameters,
request_options=RequestOptions(max_retries=0)
request_options=RequestOptions(max_retries=0),
)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, response: Generation,
prompt_messages: list[PromptMessage]) \
-> LLMResult:
def _handle_generate_response(
self, model: str, credentials: dict, response: Generation, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response
@ -212,9 +230,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
assistant_text = response.generations[0].text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
# calculate num tokens
prompt_tokens = int(response.meta.billed_units.input_tokens)
@ -225,17 +241,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):
# transform response
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage
)
return response
def _handle_generate_stream_response(self, model: str, credentials: dict,
response: Iterator[GenerateStreamedResponse],
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(
self,
model: str,
credentials: dict,
response: Iterator[GenerateStreamedResponse],
prompt_messages: list[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@ -245,7 +262,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: llm response chunk generator
"""
index = 1
full_assistant_content = ''
full_assistant_content = ""
for chunk in response:
if isinstance(chunk, GenerateStreamedResponse_TextGeneration):
chunk = cast(GenerateStreamedResponse_TextGeneration, chunk)
@ -255,9 +272,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=text
)
assistant_prompt_message = AssistantPromptMessage(content=text)
full_assistant_content += text
@ -267,7 +282,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
)
),
)
index += 1
@ -277,9 +292,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
completion_tokens = self._num_tokens_from_messages(
model,
credentials,
[AssistantPromptMessage(content=full_assistant_content)]
model, credentials, [AssistantPromptMessage(content=full_assistant_content)]
)
# transform usage
@ -290,20 +303,27 @@ class CohereLargeLanguageModel(LargeLanguageModel):
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=''),
message=AssistantPromptMessage(content=""),
finish_reason=chunk.finish_reason,
usage=usage
)
usage=usage,
),
)
break
elif isinstance(chunk, GenerateStreamedResponse_StreamError):
chunk = cast(GenerateStreamedResponse_StreamError, chunk)
raise InvokeBadRequestError(chunk.err)
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _chat_generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model
@ -318,27 +338,28 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# initialize client
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
if stop:
model_parameters['stop_sequences'] = stop
model_parameters["stop_sequences"] = stop
if tools:
if len(tools) == 1:
raise ValueError("Cohere tool call requires at least two tools to be specified.")
model_parameters['tools'] = self._convert_tools(tools)
model_parameters["tools"] = self._convert_tools(tools)
message, chat_histories, tool_results \
= self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
message, chat_histories, tool_results = self._convert_prompt_messages_to_message_and_chat_histories(
prompt_messages
)
if tool_results:
model_parameters['tool_results'] = tool_results
model_parameters["tool_results"] = tool_results
# chat model
real_model = model
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
real_model = model.removesuffix('-chat')
real_model = model.removesuffix("-chat")
if stream:
response = client.chat_stream(
@ -346,7 +367,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
chat_history=chat_histories,
model=real_model,
**model_parameters,
request_options=RequestOptions(max_retries=0)
request_options=RequestOptions(max_retries=0),
)
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
@ -356,14 +377,14 @@ class CohereLargeLanguageModel(LargeLanguageModel):
chat_history=chat_histories,
model=real_model,
**model_parameters,
request_options=RequestOptions(max_retries=0)
request_options=RequestOptions(max_retries=0),
)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse,
prompt_messages: list[PromptMessage]) \
-> LLMResult:
def _handle_chat_generate_response(
self, model: str, credentials: dict, response: NonStreamedChatResponse, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm chat response
@ -380,19 +401,15 @@ class CohereLargeLanguageModel(LargeLanguageModel):
for cohere_tool_call in response.tool_calls:
tool_call = AssistantPromptMessage.ToolCall(
id=cohere_tool_call.name,
type='function',
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=cohere_tool_call.name,
arguments=json.dumps(cohere_tool_call.parameters)
)
name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters)
),
)
tool_calls.append(tool_call)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text,
tool_calls=tool_calls
)
assistant_prompt_message = AssistantPromptMessage(content=assistant_text, tool_calls=tool_calls)
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
@ -403,17 +420,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):
# transform response
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage
)
return response
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
response: Iterator[StreamedChatResponse],
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: Iterator[StreamedChatResponse],
prompt_messages: list[PromptMessage],
) -> Generator:
"""
Handle llm chat stream response
@ -423,17 +441,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: llm response chunk generator
"""
def final_response(full_text: str,
tool_calls: list[AssistantPromptMessage.ToolCall],
index: int,
finish_reason: Optional[str] = None) -> LLMResultChunk:
def final_response(
full_text: str,
tool_calls: list[AssistantPromptMessage.ToolCall],
index: int,
finish_reason: Optional[str] = None,
) -> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
full_assistant_prompt_message = AssistantPromptMessage(
content=full_text,
tool_calls=tool_calls
)
full_assistant_prompt_message = AssistantPromptMessage(content=full_text, tool_calls=tool_calls)
completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
# transform usage
@ -444,14 +461,14 @@ class CohereLargeLanguageModel(LargeLanguageModel):
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content='', tool_calls=tool_calls),
message=AssistantPromptMessage(content="", tool_calls=tool_calls),
finish_reason=finish_reason,
usage=usage
)
usage=usage,
),
)
index = 1
full_assistant_content = ''
full_assistant_content = ""
tool_calls = []
for chunk in response:
if isinstance(chunk, StreamedChatResponse_TextGeneration):
@ -462,9 +479,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=text
)
assistant_prompt_message = AssistantPromptMessage(content=text)
full_assistant_content += text
@ -474,7 +489,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
)
),
)
index += 1
@ -484,11 +499,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
for cohere_tool_call in chunk.tool_calls:
tool_call = AssistantPromptMessage.ToolCall(
id=cohere_tool_call.name,
type='function',
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=cohere_tool_call.name,
arguments=json.dumps(cohere_tool_call.parameters)
)
name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters)
),
)
tool_calls.append(tool_call)
elif isinstance(chunk, StreamedChatResponse_StreamEnd):
@ -496,8 +510,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason)
index += 1
def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
-> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
def _convert_prompt_messages_to_message_and_chat_histories(
self, prompt_messages: list[PromptMessage]
) -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
"""
Convert prompt messages to message and chat histories
:param prompt_messages: prompt messages
@ -510,13 +525,14 @@ class CohereLargeLanguageModel(LargeLanguageModel):
prompt_message = cast(AssistantPromptMessage, prompt_message)
if prompt_message.tool_calls:
for tool_call in prompt_message.tool_calls:
latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem(
call=ToolCall(
name=tool_call.function.name,
parameters=json.loads(tool_call.function.arguments)
),
outputs=[]
))
latest_tool_call_n_outputs.append(
ChatStreamRequestToolResultsItem(
call=ToolCall(
name=tool_call.function.name, parameters=json.loads(tool_call.function.arguments)
),
outputs=[],
)
)
else:
cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
if cohere_prompt_message:
@ -529,12 +545,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
if tool_call_n_outputs.call.name == prompt_message.tool_call_id:
latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem(
call=ToolCall(
name=tool_call_n_outputs.call.name,
parameters=tool_call_n_outputs.call.parameters
name=tool_call_n_outputs.call.name, parameters=tool_call_n_outputs.call.parameters
),
outputs=[{
"result": prompt_message.content
}]
outputs=[{"result": prompt_message.content}],
)
break
i += 1
@ -556,7 +569,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
latest_message = chat_histories.pop()
message = latest_message.message
else:
raise ValueError('Prompt messages is empty')
raise ValueError("Prompt messages is empty")
return message, chat_histories, latest_tool_call_n_outputs
@ -569,7 +582,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
if isinstance(message.content, str):
chat_message = ChatMessage(role="USER", message=message.content)
else:
sub_message_text = ''
sub_message_text = ""
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
@ -597,8 +610,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
"""
cohere_tools = []
for tool in tools:
properties = tool.parameters['properties']
required_properties = tool.parameters['required']
properties = tool.parameters["properties"]
required_properties = tool.parameters["required"]
parameter_definitions = {}
for p_key, p_val in properties.items():
@ -606,21 +619,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
if p_key in required_properties:
required = True
desc = p_val['description']
if 'enum' in p_val:
desc += (f"; Only accepts one of the following predefined options: "
f"[{', '.join(p_val['enum'])}]")
desc = p_val["description"]
if "enum" in p_val:
desc += f"; Only accepts one of the following predefined options: " f"[{', '.join(p_val['enum'])}]"
parameter_definitions[p_key] = ToolParameterDefinitionsValue(
description=desc,
type=p_val['type'],
required=required
description=desc, type=p_val["type"], required=required
)
cohere_tool = Tool(
name=tool.name,
description=tool.description,
parameter_definitions=parameter_definitions
name=tool.name, description=tool.description, parameter_definitions=parameter_definitions
)
cohere_tools.append(cohere_tool)
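
For a concrete sense of the conversion above, a hypothetical parameter schema and the Cohere definition it would produce:

tool_parameters = {
    "type": "object",
    "properties": {
        "unit": {
            "type": "string",
            "description": "Temperature unit",
            "enum": ["celsius", "fahrenheit"],
        }
    },
    "required": ["unit"],
}
# The loop above would emit the equivalent of:
# ToolParameterDefinitionsValue(
#     description="Temperature unit; Only accepts one of the following "
#     "predefined options: [celsius, fahrenheit]",
#     type="string",
#     required=True,
# )
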
@ -637,12 +645,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: number of tokens
"""
# initialize client
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
response = client.tokenize(
text=text,
model=model
)
response = client.tokenize(text=text, model=model)
return len(response.tokens)
@ -658,30 +663,30 @@ class CohereLargeLanguageModel(LargeLanguageModel):
real_model = model
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
real_model = model.removesuffix('-chat')
real_model = model.removesuffix("-chat")
return self._num_tokens_from_string(real_model, credentials, message_str)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
Cohere supports fine-tuning of their models. This method returns the schema of the base model
but renamed to the fine-tuned model name.
Cohere supports fine-tuning of their models. This method returns the schema of the base model
but renamed to the fine-tuned model name.
:param model: model name
:param credentials: credentials
:param model: model name
:param credentials: credentials
:return: model schema
:return: model schema
"""
# get model schema
models = self.predefined_models()
model_map = {model.model: model for model in models}
mode = credentials.get('mode')
mode = credentials.get("mode")
if mode == 'chat':
base_model_schema = model_map['command-light-chat']
if mode == "chat":
base_model_schema = model_map["command-light-chat"]
else:
base_model_schema = model_map['command-light']
base_model_schema = model_map["command-light"]
base_model_schema = cast(AIModelEntity, base_model_schema)
@ -691,16 +696,13 @@ class CohereLargeLanguageModel(LargeLanguageModel):
entity = AIModelEntity(
model=model,
label=I18nObject(
zh_Hans=model,
en_US=model
),
label=I18nObject(zh_Hans=model, en_US=model),
model_type=ModelType.LLM,
features=list(base_model_schema_features),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties=dict(base_model_schema_model_properties.items()),
parameter_rules=list(base_model_schema_parameters_rules),
pricing=base_model_schema.pricing
pricing=base_model_schema.pricing,
)
return entity
@ -716,22 +718,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
cohere.errors.service_unavailable_error.ServiceUnavailableError
],
InvokeServerUnavailableError: [
cohere.errors.internal_server_error.InternalServerError
],
InvokeRateLimitError: [
cohere.errors.too_many_requests_error.TooManyRequestsError
],
InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError],
InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError],
InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError],
InvokeAuthorizationError: [
cohere.errors.unauthorized_error.UnauthorizedError,
cohere.errors.forbidden_error.ForbiddenError
cohere.errors.forbidden_error.ForbiddenError,
],
InvokeBadRequestError: [
cohere.core.api_error.ApiError,
cohere.errors.bad_request_error.BadRequestError,
cohere.errors.not_found_error.NotFoundError,
]
],
}

View File

@ -21,10 +21,16 @@ class CohereRerankModel(RerankModel):
Model class for Cohere rerank model.
"""
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
@ -38,20 +44,17 @@ class CohereRerankModel(RerankModel):
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(
model=model,
docs=docs
)
return RerankResult(model=model, docs=docs)
# initialize client
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
response = client.rerank(
query=query,
documents=docs,
model=model,
top_n=top_n,
return_documents=True,
request_options=RequestOptions(max_retries=0)
request_options=RequestOptions(max_retries=0),
)
rerank_documents = []
@ -70,10 +73,7 @@ class CohereRerankModel(RerankModel):
else:
rerank_documents.append(rerank_document)
return RerankResult(
model=model,
docs=rerank_documents
)
return RerankResult(model=model, docs=rerank_documents)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -94,7 +94,7 @@ class CohereRerankModel(RerankModel):
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -110,22 +110,16 @@ class CohereRerankModel(RerankModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
cohere.errors.service_unavailable_error.ServiceUnavailableError
],
InvokeServerUnavailableError: [
cohere.errors.internal_server_error.InternalServerError
],
InvokeRateLimitError: [
cohere.errors.too_many_requests_error.TooManyRequestsError
],
InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError],
InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError],
InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError],
InvokeAuthorizationError: [
cohere.errors.unauthorized_error.UnauthorizedError,
cohere.errors.forbidden_error.ForbiddenError
cohere.errors.forbidden_error.ForbiddenError,
],
InvokeBadRequestError: [
cohere.core.api_error.ApiError,
cohere.errors.bad_request_error.BadRequestError,
cohere.errors.not_found_error.NotFoundError,
]
],
}

View File

@ -24,9 +24,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
Model class for Cohere text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -46,14 +46,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
used_tokens = 0
for i, text in enumerate(texts):
tokenize_response = self._tokenize(
model=model,
credentials=credentials,
text=text
)
tokenize_response = self._tokenize(model=model, credentials=credentials, text=text)
for j in range(0, len(tokenize_response), context_size):
tokens += [tokenize_response[j: j + context_size]]
tokens += [tokenize_response[j : j + context_size]]
indices += [i]
batched_embeddings = []
@ -62,9 +58,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
for i in _iter:
# call embedding model
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
credentials=credentials,
texts=["".join(token) for token in tokens[i: i + max_chunks]]
model=model, credentials=credentials, texts=["".join(token) for token in tokens[i : i + max_chunks]]
)
used_tokens += embedding_used_tokens
@ -80,9 +74,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
_result = results[i]
if len(_result) == 0:
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
credentials=credentials,
texts=[" "]
model=model, credentials=credentials, texts=[" "]
)
used_tokens += embedding_used_tokens
@ -92,17 +84,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
embeddings[i] = (average / np.linalg.norm(average)).tolist()
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=used_tokens
)
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
return TextEmbeddingResult(
embeddings=embeddings,
usage=usage,
model=model
)
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
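
The _invoke flow above splits over-long inputs into context-size token chunks, embeds each chunk, then averages and L2-normalizes. A condensed sketch under stated assumptions (embed_fn and tokenize_fn stand in for the client calls; whether the average is token-weighted is not visible in this hunk, so a plain mean is used):

import numpy as np

def embed_long_text(embed_fn, tokenize_fn, text: str, context_size: int) -> list[float]:
    tokens = tokenize_fn(text)
    chunks = [tokens[j : j + context_size] for j in range(0, len(tokens), context_size)]
    chunk_embeddings = [np.array(embed_fn("".join(chunk))) for chunk in chunks]
    average = np.mean(chunk_embeddings, axis=0)
    return (average / np.linalg.norm(average)).tolist()  # unit-length result, as above
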
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
@ -116,14 +100,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
if len(texts) == 0:
return 0
full_text = ' '.join(texts)
full_text = " ".join(texts)
try:
response = self._tokenize(
model=model,
credentials=credentials,
text=full_text
)
response = self._tokenize(model=model, credentials=credentials, text=full_text)
except Exception as e:
raise self._transform_invoke_error(e)
@ -141,14 +121,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
return []
# initialize client
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
response = client.tokenize(
text=text,
model=model,
offline=False,
request_options=RequestOptions(max_retries=0)
)
response = client.tokenize(text=text, model=model, offline=False, request_options=RequestOptions(max_retries=0))
return response.token_strings
@ -162,11 +137,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
"""
try:
# call embedding model
self._embedding_invoke(
model=model,
credentials=credentials,
texts=['ping']
)
self._embedding_invoke(model=model, credentials=credentials, texts=["ping"])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -180,14 +151,14 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
:return: embeddings and used tokens
"""
# initialize client
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
# call embedding model
response = client.embed(
texts=texts,
model=model,
input_type='search_document' if len(texts) > 1 else 'search_query',
request_options=RequestOptions(max_retries=1)
input_type="search_document" if len(texts) > 1 else "search_query",
request_options=RequestOptions(max_retries=1),
)
return response.embeddings, int(response.meta.billed_units.input_tokens)
@ -203,10 +174,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
@ -217,7 +185,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
@ -233,22 +201,16 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
cohere.errors.service_unavailable_error.ServiceUnavailableError
],
InvokeServerUnavailableError: [
cohere.errors.internal_server_error.InternalServerError
],
InvokeRateLimitError: [
cohere.errors.too_many_requests_error.TooManyRequestsError
],
InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError],
InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError],
InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError],
InvokeAuthorizationError: [
cohere.errors.unauthorized_error.UnauthorizedError,
cohere.errors.forbidden_error.ForbiddenError
cohere.errors.forbidden_error.ForbiddenError,
],
InvokeBadRequestError: [
cohere.core.api_error.ApiError,
cohere.errors.bad_request_error.BadRequestError,
cohere.errors.not_found_error.NotFoundError,
]
],
}

View File

@ -7,9 +7,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
logger = logging.getLogger(__name__)
class DeepSeekProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -22,12 +20,9 @@ class DeepSeekProvider(ModelProvider):
# Use the `deepseek-chat` model for validation,
# no matter what model you pass in (text completion model or chat model)
model_instance.validate_credentials(
model='deepseek-chat',
credentials=credentials
)
model_instance.validate_credentials(model="deepseek-chat", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -13,12 +13,17 @@ from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguag
class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
@ -27,10 +32,8 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
# refactored from openai model runtime; use cl100k_base to calculate the token count
def _num_tokens_from_string(self, model: str, text: str,
tools: Optional[list[PromptMessageTool]] = None) -> int:
def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Calculate num tokens for text completion model with tiktoken package.
@ -48,8 +51,9 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
return num_tokens
# refactored from openai model runtime; use cl100k_base to calculate the token count
def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def _num_tokens_from_messages(
self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
@ -69,10 +73,10 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
# which need to download the image and then get the resolution for calculation,
# and will increase the request delay
if isinstance(value, list):
text = ''
text = ""
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
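
The cl100k_base comments above boil down to the following (requires the tiktoken package; the sample string is illustrative):

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = len(encoding.encode("Hello, DeepSeek!"))
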
@ -103,11 +107,10 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['openai_api_key']=credentials['api_key']
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
credentials['openai_api_base']='https://api.deepseek.com'
credentials["mode"] = "chat"
credentials["openai_api_key"] = credentials["api_key"]
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
credentials["openai_api_base"] = "https://api.deepseek.com"
else:
parsed_url = urlparse(credentials['endpoint_url'])
credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}"
parsed_url = urlparse(credentials["endpoint_url"])
credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}"

View File

@ -1,4 +1,4 @@
import logging
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -18,11 +18,9 @@ class FishAudioProvider(ModelProvider):
"""
try:
model_instance = self.get_model_instance(ModelType.TTS)
model_instance.validate_credentials(
credentials=credentials
)
model_instance.validate_credentials(credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional
import httpx
@ -12,9 +12,7 @@ class FishAudioText2SpeechModel(TTSModel):
Model class for Fish.audio Text to Speech model.
"""
def get_tts_model_voices(
self, model: str, credentials: dict, language: Optional[str] = None
) -> list:
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
api_base = credentials.get("api_base", "https://api.fish.audio")
api_key = credentials.get("api_key")
use_public_models = credentials.get("use_public_models", "false") == "true"
@ -68,9 +66,7 @@ class FishAudioText2SpeechModel(TTSModel):
voice=voice,
)
def validate_credentials(
self, credentials: dict, user: Optional[str] = None
) -> None:
def validate_credentials(self, credentials: dict, user: Optional[str] = None) -> None:
"""
Validate credentials for text2speech model
@ -91,9 +87,7 @@ class FishAudioText2SpeechModel(TTSModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke_streaming(
self, model: str, credentials: dict, content_text: str, voice: str
) -> any:
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
"""
Invoke streaming text2speech model
:param model: model name
@ -106,12 +100,10 @@ class FishAudioText2SpeechModel(TTSModel):
try:
word_limit = self._get_model_word_limit(model, credentials)
if len(content_text) > word_limit:
sentences = self._split_text_into_sentences(
content_text, max_length=word_limit
)
sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
else:
sentences = [content_text.strip()]
for i in range(len(sentences)):
yield from self._tts_invoke_streaming_sentence(
credentials=credentials, content_text=sentences[i], voice=voice
@ -120,9 +112,7 @@ class FishAudioText2SpeechModel(TTSModel):
except Exception as ex:
raise InvokeBadRequestError(str(ex))
def _tts_invoke_streaming_sentence(
self, credentials: dict, content_text: str, voice: Optional[str] = None
) -> any:
def _tts_invoke_streaming_sentence(self, credentials: dict, content_text: str, voice: Optional[str] = None) -> any:
"""
Invoke streaming text2speech model
@ -141,20 +131,14 @@ class FishAudioText2SpeechModel(TTSModel):
with httpx.stream(
"POST",
api_url + "/v1/tts",
json={
"text": content_text,
"reference_id": voice,
"latency": latency
},
json={"text": content_text, "reference_id": voice, "latency": latency},
headers={
"Authorization": f"Bearer {api_key}",
},
timeout=None,
) as response:
if response.status_code != 200:
raise InvokeBadRequestError(
f"Error: {response.status_code} - {response.text}"
)
raise InvokeBadRequestError(f"Error: {response.status_code} - {response.text}")
yield from response.iter_bytes()
@property
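
A self-contained sketch of the streaming call above (the endpoint and payload shape are taken from the code; the latency value and function name are assumptions):

import httpx

def stream_tts(api_url: str, api_key: str, text: str, voice: str):
    with httpx.stream(
        "POST",
        api_url + "/v1/tts",
        json={"text": text, "reference_id": voice, "latency": "normal"},
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=None,
    ) as response:
        if response.status_code != 200:
            raise RuntimeError(f"Error: {response.status_code}")
        yield from response.iter_bytes()  # raw audio bytes as they arrive
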

View File

@ -20,12 +20,9 @@ class GoogleProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)
# Use the `gemini-pro` model for validation.
model_instance.validate_credentials(
model='gemini-pro',
credentials=credentials
)
model_instance.validate_credentials(model="gemini-pro", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -49,12 +49,17 @@ if you are not sure about the structure.
class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -70,9 +75,14 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -85,7 +95,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Google model
@ -95,13 +105,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
"""
messages = messages.copy() # don't mutate the original list
text = "".join(
self._convert_one_message_to_text(message)
for message in messages
)
text = "".join(self._convert_one_message_to_text(message) for message in messages)
return text.rstrip()
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
"""
Convert tool messages to glm tools
@ -117,14 +124,16 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
type=glm.Type.OBJECT,
properties={
key: {
'type_': value.get('type', 'string').upper(),
'description': value.get('description', ''),
'enum': value.get('enum', [])
} for key, value in tool.parameters.get('properties', {}).items()
"type_": value.get("type", "string").upper(),
"description": value.get("description", ""),
"enum": value.get("enum", []),
}
for key, value in tool.parameters.get("properties", {}).items()
},
required=tool.parameters.get('required', [])
required=tool.parameters.get("required", []),
),
) for tool in tools
)
for tool in tools
]
)
@ -136,20 +145,25 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:param credentials: model credentials
:return:
"""
try:
ping_message = SystemPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None
) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -163,14 +177,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
config_kwargs = model_parameters.copy()
config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None)
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
if stop:
config_kwargs["stop_sequences"] = stop
google_model = genai.GenerativeModel(
model_name=model
)
google_model = genai.GenerativeModel(model_name=model)
history = []
@ -180,7 +192,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
for msg in prompt_messages: # makes message roles strictly alternating
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
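
The merge above enforces the strict role alternation Gemini expects. Distilled into a standalone sketch:

def merge_same_role(contents: list[dict]) -> list[dict]:
    history: list[dict] = []
    for content in contents:  # each item: {"role": ..., "parts": [...]}
        if history and history[-1]["role"] == content["role"]:
            history[-1]["parts"].extend(content["parts"])  # fold into the previous turn
        else:
            history.append(content)
    return history

# merge_same_role([{"role": "user", "parts": ["a"]}, {"role": "user", "parts": ["b"]}])
# -> [{"role": "user", "parts": ["a", "b"]}]
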
@ -194,7 +206,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
google_model._client = new_custom_client
safety_settings={
safety_settings = {
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
@ -203,13 +215,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(
**config_kwargs
),
generation_config=genai.types.GenerationConfig(**config_kwargs),
stream=stream,
safety_settings=safety_settings,
tools=self._convert_tools_to_glm_tool(tools) if tools else None,
request_options={"timeout": 600}
request_options={"timeout": 600},
)
if stream:
@ -217,8 +227,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, response: GenerateContentResponse,
prompt_messages: list[PromptMessage]) -> LLMResult:
def _handle_generate_response(
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response
@ -229,9 +240,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: llm response
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.text
)
assistant_prompt_message = AssistantPromptMessage(content=response.text)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@ -250,8 +259,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return result
def _handle_generate_stream_response(self, model: str, credentials: dict, response: GenerateContentResponse,
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
@ -264,9 +274,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
index = -1
for chunk in response:
for part in chunk.parts:
assistant_prompt_message = AssistantPromptMessage(
content=''
)
assistant_prompt_message = AssistantPromptMessage(content="")
if part.text:
assistant_prompt_message.content += part.text
@ -275,36 +283,31 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
assistant_prompt_message.tool_calls = [
AssistantPromptMessage.ToolCall(
id=part.function_call.name,
type='function',
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=part.function_call.name,
arguments=json.dumps(dict(part.function_call.args.items()))
)
arguments=json.dumps(dict(part.function_call.args.items())),
),
)
]
index += 1
if not response._done:
# transform assistant message to prompt message
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
)
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
@ -312,8 +315,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
index=index,
message=assistant_prompt_message,
finish_reason=str(chunk.candidates[0].finish_reason),
usage=usage
)
usage=usage,
),
)
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
@ -328,9 +331,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
content = message.content
if isinstance(content, list):
content = "".join(
c.data for c in content if c.type != PromptMessageContentType.IMAGE
)
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
@ -353,65 +354,61 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: glm Content representation of message
"""
if isinstance(message, UserPromptMessage):
glm_content = {
"role": "user",
"parts": []
}
if (isinstance(message.content, str)):
glm_content['parts'].append(to_part(message.content))
glm_content = {"role": "user", "parts": []}
if isinstance(message.content, str):
glm_content["parts"].append(to_part(message.content))
else:
for c in message.content:
if c.type == PromptMessageContentType.TEXT:
glm_content['parts'].append(to_part(c.data))
glm_content["parts"].append(to_part(c.data))
elif c.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, c)
if message_content.data.startswith("data:"):
metadata, base64_data = c.data.split(',', 1)
mime_type = metadata.split(';', 1)[0].split(':')[1]
metadata, base64_data = c.data.split(",", 1)
mime_type = metadata.split(";", 1)[0].split(":")[1]
else:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode('utf-8')
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data":{"mime_type":mime_type,"data":base64_data}}
glm_content['parts'].append(blob)
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
glm_content["parts"].append(blob)
return glm_content
elif isinstance(message, AssistantPromptMessage):
glm_content = {
"role": "model",
"parts": []
}
glm_content = {"role": "model", "parts": []}
if message.content:
glm_content['parts'].append(to_part(message.content))
glm_content["parts"].append(to_part(message.content))
if message.tool_calls:
glm_content["parts"].append(to_part(glm.FunctionCall(
name=message.tool_calls[0].function.name,
args=json.loads(message.tool_calls[0].function.arguments),
)))
glm_content["parts"].append(
to_part(
glm.FunctionCall(
name=message.tool_calls[0].function.name,
args=json.loads(message.tool_calls[0].function.arguments),
)
)
)
return glm_content
elif isinstance(message, SystemPromptMessage):
return {
"role": "user",
"parts": [to_part(message.content)]
}
return {"role": "user", "parts": [to_part(message.content)]}
elif isinstance(message, ToolPromptMessage):
return {
"role": "function",
"parts": [glm.Part(function_response=glm.FunctionResponse(
name=message.name,
response={
"response": message.content
}
))]
"parts": [
glm.Part(
function_response=glm.FunctionResponse(
name=message.name, response={"response": message.content}
)
)
],
}
else:
raise ValueError(f"Got unknown type {message}")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
@ -423,25 +420,20 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
exceptions.RetryError
],
InvokeConnectionError: [exceptions.RetryError],
InvokeServerUnavailableError: [
exceptions.ServiceUnavailable,
exceptions.InternalServerError,
exceptions.BadGateway,
exceptions.GatewayTimeout,
exceptions.DeadlineExceeded
],
InvokeRateLimitError: [
exceptions.ResourceExhausted,
exceptions.TooManyRequests
exceptions.DeadlineExceeded,
],
InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests],
InvokeAuthorizationError: [
exceptions.Unauthenticated,
exceptions.PermissionDenied,
exceptions.Unauthenticated,
exceptions.Forbidden
exceptions.Forbidden,
],
InvokeBadRequestError: [
exceptions.BadRequest,
@ -457,5 +449,5 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
exceptions.PreconditionFailed,
exceptions.RequestRangeNotSatisfiable,
exceptions.Cancelled,
]
],
}
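For context on how a mapping of this shape is consumed, a minimal sketch; `translate_error` is a hypothetical helper, not code from this commit:

def translate_error(ex: Exception, mapping: dict[type[Exception], list[type[Exception]]]) -> Exception:
    # Return the first runtime error type whose source-exception list matches.
    for invoke_error, source_errors in mapping.items():
        if isinstance(ex, tuple(source_errors)):
            return invoke_error(str(ex))
    return ex  # unmapped exceptions pass through unchanged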

View File

@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
logger = logging.getLogger(__name__)
class GroqProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -18,12 +18,9 @@ class GroqProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='llama3-8b-8192',
credentials=credentials
)
model_instance.validate_credentials(model="llama3-8b-8192", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -7,11 +7,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
@ -21,6 +27,5 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.groq.com/openai/v1'
credentials["mode"] = "chat"
credentials["endpoint_url"] = "https://api.groq.com/openai/v1"

View File

@ -4,12 +4,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
class _CommonHuggingfaceHub:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeBadRequestError: [
HfHubHTTPError,
BadRequestError
]
}
return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]}

View File

@ -6,6 +6,5 @@ logger = logging.getLogger(__name__)
class HuggingfaceHubProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -29,16 +29,23 @@ from core.model_runtime.model_providers.huggingface_hub._common import _CommonHu
class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel):
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
client = InferenceClient(token=credentials['huggingfacehub_api_token'])
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
model = credentials["huggingfacehub_endpoint_url"]
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
model = credentials['huggingfacehub_endpoint_url']
if 'baichuan' in model.lower():
if "baichuan" in model.lower():
stream = False
response = client.text_generation(
@ -47,98 +54,100 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
stream=stream,
model=model,
stop_sequences=stop,
**model_parameters)
**model_parameters,
)
if stream:
return self._handle_generate_stream_response(model, credentials, prompt_messages, response)
return self._handle_generate_response(model, credentials, prompt_messages, response)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)
def validate_credentials(self, model: str, credentials: dict) -> None:
try:
if 'huggingfacehub_api_type' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.')
if "huggingfacehub_api_type" not in credentials:
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
if credentials['huggingfacehub_api_type'] not in ('inference_endpoints', 'hosted_inference_api'):
raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.')
if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"):
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
if 'huggingfacehub_api_token' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Access Token must be provided.')
if "huggingfacehub_api_token" not in credentials:
raise CredentialsValidateFailedError("Huggingface Hub Access Token must be provided.")
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
if 'huggingfacehub_endpoint_url' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.')
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
if "huggingfacehub_endpoint_url" not in credentials:
raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.")
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.')
elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
credentials['task_type'] = self._get_hosted_model_task_type(credentials['huggingfacehub_api_token'],
model)
if "task_type" not in credentials:
raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.")
elif credentials["huggingfacehub_api_type"] == "hosted_inference_api":
credentials["task_type"] = self._get_hosted_model_task_type(
credentials["huggingfacehub_api_token"], model
)
if credentials['task_type'] not in ("text2text-generation", "text-generation"):
raise CredentialsValidateFailedError('Huggingface Hub Task Type must be one of text2text-generation, '
'text-generation.')
if credentials["task_type"] not in ("text2text-generation", "text-generation"):
raise CredentialsValidateFailedError(
"Huggingface Hub Task Type must be one of text2text-generation, " "text-generation."
)
client = InferenceClient(token=credentials['huggingfacehub_api_token'])
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
model = credentials['huggingfacehub_endpoint_url']
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
model = credentials["huggingfacehub_endpoint_url"]
try:
client.text_generation(
prompt='Who are you?',
stream=True,
model=model)
client.text_generation(prompt="Who are you?", stream=True, model=model)
except BadRequestError as e:
raise CredentialsValidateFailedError('Only available for models running on with the `text-generation-inference`. '
'To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.')
raise CredentialsValidateFailedError(
"Only available for models served with `text-generation-inference`. "
"To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference."
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={
ModelPropertyKey.MODE: LLMMode.COMPLETION.value
},
parameter_rules=self._get_customizable_model_parameter_rules()
model_properties={ModelPropertyKey.MODE: LLMMode.COMPLETION.value},
parameter_rules=self._get_customizable_model_parameter_rules(),
)
return entity
@staticmethod
def _get_customizable_model_parameter_rules() -> list[ParameterRule]:
temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(
DefaultParameterName.TEMPERATURE).copy()
temperature_rule_dict['name'] = 'temperature'
temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TEMPERATURE).copy()
temperature_rule_dict["name"] = "temperature"
temperature_rule = ParameterRule(**temperature_rule_dict)
temperature_rule.default = 0.5
top_p_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TOP_P).copy()
top_p_rule_dict['name'] = 'top_p'
top_p_rule_dict["name"] = "top_p"
top_p_rule = ParameterRule(**top_p_rule_dict)
top_p_rule.default = 0.5
top_k_rule = ParameterRule(
name='top_k',
name="top_k",
label={
'en_US': 'Top K',
'zh_Hans': 'Top K',
"en_US": "Top K",
"zh_Hans": "Top K",
},
type='int',
type="int",
help={
'en_US': 'The number of highest probability vocabulary tokens to keep for top-k-filtering.',
'zh_Hans': '保留的最高概率词汇标记的数量。',
"en_US": "The number of highest probability vocabulary tokens to keep for top-k-filtering.",
"zh_Hans": "保留的最高概率词汇标记的数量。",
},
required=False,
default=2,
@ -148,15 +157,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
)
max_new_tokens = ParameterRule(
name='max_new_tokens',
name="max_new_tokens",
label={
'en_US': 'Max New Tokens',
'zh_Hans': '最大新标记',
"en_US": "Max New Tokens",
"zh_Hans": "最大新标记",
},
type='int',
type="int",
help={
'en_US': 'Maximum number of generated tokens.',
'zh_Hans': '生成的标记的最大数量。',
"en_US": "Maximum number of generated tokens.",
"zh_Hans": "生成的标记的最大数量。",
},
required=False,
default=20,
@ -166,30 +175,30 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
)
seed = ParameterRule(
name='seed',
name="seed",
label={
'en_US': 'Random sampling seed',
'zh_Hans': '随机采样种子',
"en_US": "Random sampling seed",
"zh_Hans": "随机采样种子",
},
type='int',
type="int",
help={
'en_US': 'Random sampling seed.',
'zh_Hans': '随机采样种子。',
"en_US": "Random sampling seed.",
"zh_Hans": "随机采样种子。",
},
required=False,
precision=0,
)
repetition_penalty = ParameterRule(
name='repetition_penalty',
name="repetition_penalty",
label={
'en_US': 'Repetition Penalty',
'zh_Hans': '重复惩罚',
"en_US": "Repetition Penalty",
"zh_Hans": "重复惩罚",
},
type='float',
type="float",
help={
'en_US': 'The parameter for repetition penalty. 1.0 means no penalty.',
'zh_Hans': '重复惩罚的参数。1.0 表示没有惩罚。',
"en_US": "The parameter for repetition penalty. 1.0 means no penalty.",
"zh_Hans": "重复惩罚的参数。1.0 表示没有惩罚。",
},
required=False,
precision=1,
@ -197,11 +206,9 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
return [temperature_rule, top_k_rule, top_p_rule, max_new_tokens, seed, repetition_penalty]
def _handle_generate_stream_response(self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
response: Generator) -> Generator:
def _handle_generate_stream_response(
self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: Generator
) -> Generator:
index = -1
for chunk in response:
# skip special tokens
@ -210,9 +217,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
index += 1
assistant_prompt_message = AssistantPromptMessage(
content=chunk.token.text
)
assistant_prompt_message = AssistantPromptMessage(content=chunk.token.text)
if chunk.details:
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@ -240,15 +245,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
),
)
def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult:
def _handle_generate_response(
self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any
) -> LLMResult:
if isinstance(response, str):
content = response
else:
content = response.generated_text
assistant_prompt_message = AssistantPromptMessage(
content=content
)
assistant_prompt_message = AssistantPromptMessage(content=content)
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
@ -270,15 +275,14 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
try:
if not model_info:
raise ValueError(f'Model {model_name} not found.')
raise ValueError(f"Model {model_name} not found.")
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
if "inference" in model_info.cardData and not model_info.cardData["inference"]:
raise ValueError(f"Inference API has been turned off for this model {model_name}.")
valid_tasks = ("text2text-generation", "text-generation")
if model_info.pipeline_tag not in valid_tasks:
raise ValueError(f"Model {model_name} is not a valid task, "
f"must be one of {valid_tasks}.")
raise ValueError(f"Model {model_name} is not a valid task, " f"must be one of {valid_tasks}.")
except Exception as e:
raise CredentialsValidateFailedError(f"{str(e)}")
@ -287,10 +291,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
messages = messages.copy() # don't mutate the original list
text = "".join(
self._convert_one_message_to_text(message)
for message in messages
)
text = "".join(self._convert_one_message_to_text(message) for message in messages)
return text.rstrip()
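A minimal standalone sketch of the `InferenceClient.text_generation` streaming call used in `_invoke` above; the token and endpoint URL are placeholders:

from huggingface_hub import InferenceClient

client = InferenceClient(token="hf_...")  # placeholder token

# Stream tokens from a TGI-backed endpoint, stopping on the given sequences,
# as the handler above does before wrapping chunks in LLMResultChunk.
for token_text in client.text_generation(
    "Who are you?",
    stream=True,
    model="https://my-endpoint.endpoints.huggingface.cloud",  # placeholder
    stop_sequences=["\n\n"],
):
    print(token_text, end="")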

View File

@ -13,40 +13,30 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub
HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/'
HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/"
class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel):
def _invoke(self, model: str, credentials: dict, texts: list[str],
user: Optional[str] = None) -> TextEmbeddingResult:
client = InferenceClient(token=credentials['huggingfacehub_api_token'])
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
execute_model = model
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
execute_model = credentials['huggingfacehub_endpoint_url']
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
execute_model = credentials["huggingfacehub_endpoint_url"]
output = client.post(
json={
"inputs": texts,
"options": {
"wait_for_model": False,
"use_cache": False
}
},
model=execute_model)
json={"inputs": texts, "options": {"wait_for_model": False, "use_cache": False}}, model=execute_model
)
embeddings = json.loads(output.decode())
tokens = self.get_num_tokens(model, credentials, texts)
usage = self._calc_response_usage(model, credentials, tokens)
return TextEmbeddingResult(
embeddings=self._mean_pooling(embeddings),
usage=usage,
model=model
)
return TextEmbeddingResult(embeddings=self._mean_pooling(embeddings), usage=usage, model=model)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
num_tokens = 0
@ -56,52 +46,48 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
def validate_credentials(self, model: str, credentials: dict) -> None:
try:
if 'huggingfacehub_api_type' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.')
if "huggingfacehub_api_type" not in credentials:
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
if 'huggingfacehub_api_token' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub API Token must be provided.')
if "huggingfacehub_api_token" not in credentials:
raise CredentialsValidateFailedError("Huggingface Hub API Token must be provided.")
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
if 'huggingface_namespace' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub User Name / Organization Name must be provided.')
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
if "huggingface_namespace" not in credentials:
raise CredentialsValidateFailedError(
"Huggingface Hub User Name / Organization Name must be provided."
)
if 'huggingfacehub_endpoint_url' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.')
if "huggingfacehub_endpoint_url" not in credentials:
raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.")
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.')
if "task_type" not in credentials:
raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.")
if credentials['task_type'] != 'feature-extraction':
raise CredentialsValidateFailedError('Huggingface Hub Task Type is invalid.')
if credentials["task_type"] != "feature-extraction":
raise CredentialsValidateFailedError("Huggingface Hub Task Type is invalid.")
self._check_endpoint_url_model_repository_name(credentials, model)
model = credentials['huggingfacehub_endpoint_url']
model = credentials["huggingfacehub_endpoint_url"]
elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
self._check_hosted_model_task_type(credentials['huggingfacehub_api_token'],
model)
elif credentials["huggingfacehub_api_type"] == "hosted_inference_api":
self._check_hosted_model_task_type(credentials["huggingfacehub_api_token"], model)
else:
raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.')
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
client = InferenceClient(token=credentials['huggingfacehub_api_token'])
client.feature_extraction(text='hello world', model=model)
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
client.feature_extraction(text="hello world", model=model)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
'context_size': 10000,
'max_chunks': 1
}
model_properties={"context_size": 10000, "max_chunks": 1},
)
return entity
@ -128,24 +114,20 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
try:
if not model_info:
raise ValueError(f'Model {model_name} not found.')
raise ValueError(f"Model {model_name} not found.")
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
if "inference" in model_info.cardData and not model_info.cardData["inference"]:
raise ValueError(f"Inference API has been turned off for this model {model_name}.")
valid_tasks = "feature-extraction"
if model_info.pipeline_tag not in valid_tasks:
raise ValueError(f"Model {model_name} is not a valid task, "
f"must be one of {valid_tasks}.")
raise ValueError(f"Model {model_name} is not a valid task, " f"must be one of {valid_tasks}.")
except Exception as e:
raise CredentialsValidateFailedError(f"{str(e)}")
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
@ -156,7 +138,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
@ -166,25 +148,26 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
try:
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
headers = {
'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}',
'Content-Type': 'application/json'
"Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}',
"Content-Type": "application/json",
}
response = requests.get(url=url, headers=headers)
if response.status_code != 200:
raise ValueError('User Name or Organization Name is invalid.')
raise ValueError("User Name or Organization Name is invalid.")
model_repository_name = ''
model_repository_name = ""
for item in response.json().get("items", []):
if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']:
if item.get("status", {}).get("url") == credentials["huggingfacehub_endpoint_url"]:
model_repository_name = item.get("model", {}).get("repository")
break
if model_repository_name != model_name:
raise ValueError(
f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.')
f"Model Name {model_name} is invalid. Please check it on the inference endpoints console."
)
except Exception as e:
raise ValueError(str(e))

View File

@ -6,6 +6,5 @@ logger = logging.getLogger(__name__)
class HuggingfaceTeiProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -47,29 +47,29 @@ class HuggingfaceTeiRerankModel(RerankModel):
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])
server_url = credentials['server_url']
server_url = credentials["server_url"]
if server_url.endswith('/'):
if server_url.endswith("/"):
server_url = server_url[:-1]
try:
results = TeiHelper.invoke_rerank(server_url, query, docs)
rerank_documents = []
for result in results:
rerank_document = RerankDocument(
index=result['index'],
text=result['text'],
score=result['score'],
index=result["index"],
text=result["text"],
score=result["score"],
)
if score_threshold is None or result['score'] >= score_threshold:
if score_threshold is None or result["score"] >= score_threshold:
rerank_documents.append(rerank_document)
if top_n is not None and len(rerank_documents) >= top_n:
break
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -80,21 +80,21 @@ class HuggingfaceTeiRerankModel(RerankModel):
:return:
"""
try:
server_url = credentials['server_url']
server_url = credentials["server_url"]
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
if extra_args.model_type != 'reranker':
raise CredentialsValidateFailedError('Current model is not a rerank model')
if extra_args.model_type != "reranker":
raise CredentialsValidateFailedError("Current model is not a rerank model")
credentials['context_size'] = extra_args.max_input_length
credentials["context_size"] = extra_args.max_input_length
self.invoke(
model=model,
credentials=credentials,
query='Whose kasumi',
query="Whose kasumi",
docs=[
'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ',
'and she leads a team named PopiParty.',
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
"and she leads a team named PopiParty.",
],
score_threshold=0.8,
)
@ -129,7 +129,7 @@ class HuggingfaceTeiRerankModel(RerankModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
},
parameter_rules=[],
)

View File

@ -31,16 +31,16 @@ class TeiHelper:
with cache_lock:
if model_name not in cache:
cache[model_name] = {
'expires': time() + 300,
'value': TeiHelper._get_tei_extra_parameter(server_url),
"expires": time() + 300,
"value": TeiHelper._get_tei_extra_parameter(server_url),
}
return cache[model_name]['value']
return cache[model_name]["value"]
@staticmethod
def _clean_cache() -> None:
try:
with cache_lock:
expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
expired_keys = [model_uid for model_uid, model in cache.items() if model["expires"] < time()]
for model_uid in expired_keys:
del cache[model_uid]
except RuntimeError as e:
@ -52,40 +52,38 @@ class TeiHelper:
get tei model extra parameter like model_type, max_input_length, max_batch_requests
"""
url = str(URL(server_url) / 'info')
url = str(URL(server_url) / "info")
# this method is surrounded by a lock, and default requests may hang forever, so we just set an Adapter with max_retries=3
session = Session()
session.mount('http://', HTTPAdapter(max_retries=3))
session.mount('https://', HTTPAdapter(max_retries=3))
session.mount("http://", HTTPAdapter(max_retries=3))
session.mount("https://", HTTPAdapter(max_retries=3))
try:
response = session.get(url, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}')
raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
if response.status_code != 200:
raise RuntimeError(
f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}'
f"get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}"
)
response_json = response.json()
model_type = response_json.get('model_type', {})
model_type = response_json.get("model_type", {})
if len(model_type.keys()) < 1:
raise RuntimeError('model_type is empty')
raise RuntimeError("model_type is empty")
model_type = list(model_type.keys())[0]
if model_type not in ['embedding', 'reranker']:
raise RuntimeError(f'invalid model_type: {model_type}')
max_input_length = response_json.get('max_input_length', 512)
max_client_batch_size = response_json.get('max_client_batch_size', 1)
if model_type not in ["embedding", "reranker"]:
raise RuntimeError(f"invalid model_type: {model_type}")
max_input_length = response_json.get("max_input_length", 512)
max_client_batch_size = response_json.get("max_client_batch_size", 1)
return TeiModelExtraParameter(
model_type=model_type,
max_input_length=max_input_length,
max_client_batch_size=max_client_batch_size
model_type=model_type, max_input_length=max_input_length, max_client_batch_size=max_client_batch_size
)
@staticmethod
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
"""
@ -116,12 +114,12 @@ class TeiHelper:
:param texts: texts to tokenize
"""
resp = httpx.post(
f'{server_url}/tokenize',
json={'inputs': texts},
f"{server_url}/tokenize",
json={"inputs": texts},
)
resp.raise_for_status()
return resp.json()
@staticmethod
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
"""
@ -149,8 +147,8 @@ class TeiHelper:
"""
# Use OpenAI compatible API here, which has usage tracking
resp = httpx.post(
f'{server_url}/v1/embeddings',
json={'input': texts},
f"{server_url}/v1/embeddings",
json={"input": texts},
)
resp.raise_for_status()
return resp.json()
@ -173,11 +171,11 @@ class TeiHelper:
:param texts: texts to rerank
:param candidates: candidates to rerank
"""
params = {'query': query, 'texts': docs, 'return_text': True}
params = {"query": query, "texts": docs, "return_text": True}
response = httpx.post(
server_url + '/rerank',
server_url + "/rerank",
json=params,
)
response.raise_for_status()
return response.json()
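A self-contained sketch of the `/info` lookup above, assuming the same bounded-retry setup; `fetch_tei_info` is an illustrative name:

import requests
from requests.adapters import HTTPAdapter

def fetch_tei_info(server_url: str) -> dict:
    # Bounded retries so a hung TEI server cannot block the cache lock forever.
    session = requests.Session()
    session.mount("http://", HTTPAdapter(max_retries=3))
    session.mount("https://", HTTPAdapter(max_retries=3))
    response = session.get(f"{server_url.rstrip('/')}/info", timeout=10)
    response.raise_for_status()
    return response.json()  # includes model_type, max_input_length, max_client_batch_size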

View File

@ -40,12 +40,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
:param user: unique user id
:return: embeddings result
"""
server_url = credentials['server_url']
server_url = credentials["server_url"]
if server_url.endswith('/'):
if server_url.endswith("/"):
server_url = server_url[:-1]
# get model properties
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
@ -58,7 +57,6 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
# Check if the number of tokens is larger than the context size
num_tokens = len(tokenize_result)
@ -66,20 +64,22 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
# Find the best cutoff point
pre_special_token_count = 0
for token in tokenize_result:
if token['special']:
if token["special"]:
pre_special_token_count += 1
else:
break
rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count
rest_special_token_count = (
len([token for token in tokenize_result if token["special"]]) - pre_special_token_count
)
# Calculate the cutoff point, leaving 20 tokens of extra space to avoid exceeding the limit
token_cutoff = context_size - rest_special_token_count - 20
# Find the cutoff index
cutpoint_token = tokenize_result[token_cutoff]
cutoff = cutpoint_token['start']
cutoff = cutpoint_token["start"]
inputs.append(text[0: cutoff])
inputs.append(text[0:cutoff])
else:
inputs.append(text)
indices += [i]
@ -92,12 +92,12 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
for i in _iter:
iter_texts = inputs[i : i + max_chunks]
results = TeiHelper.invoke_embeddings(server_url, iter_texts)
embeddings = results['data']
embeddings = [embedding['embedding'] for embedding in embeddings]
embeddings = results["data"]
embeddings = [embedding["embedding"] for embedding in embeddings]
batched_embeddings.extend(embeddings)
usage = results['usage']
used_tokens += usage['total_tokens']
usage = results["usage"]
used_tokens += usage["total_tokens"]
except RuntimeError as e:
raise InvokeServerUnavailableError(str(e))
@ -117,9 +117,9 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
num_tokens = 0
server_url = credentials['server_url']
server_url = credentials["server_url"]
if server_url.endswith('/'):
if server_url.endswith("/"):
server_url = server_url[:-1]
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
@ -135,15 +135,15 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
server_url = credentials['server_url']
server_url = credentials["server_url"]
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
print(extra_args)
if extra_args.model_type != 'embedding':
raise CredentialsValidateFailedError('Current model is not a embedding model')
if extra_args.model_type != "embedding":
raise CredentialsValidateFailedError("Current model is not a embedding model")
credentials['context_size'] = extra_args.max_input_length
credentials['max_chunks'] = extra_args.max_client_batch_size
self._invoke(model=model, credentials=credentials, texts=['ping'])
credentials["context_size"] = extra_args.max_input_length
credentials["max_chunks"] = extra_args.max_client_batch_size
self._invoke(model=model, credentials=credentials, texts=["ping"])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -195,8 +195,8 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
ModelPropertyKey.MAX_CHUNKS: int(credentials.get("max_chunks", 1)),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
},
parameter_rules=[],
)
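The truncation logic above, restated as a standalone sketch over TEI `/tokenize` output (token dicts carry `special` and `start` fields); `truncate_to_context` is an illustrative name:

def truncate_to_context(text: str, tokens: list[dict], context_size: int) -> str:
    if len(tokens) < context_size:
        return text
    leading_special = 0
    for token in tokens:
        if token["special"]:
            leading_special += 1
        else:
            break
    rest_special = sum(1 for token in tokens if token["special"]) - leading_special
    # leave 20 tokens of headroom, as above
    token_cutoff = context_size - rest_special - 20
    return text[: tokens[token_cutoff]["start"]]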

View File

@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
logger = logging.getLogger(__name__)
class HunyuanProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -19,12 +19,9 @@ class HunyuanProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)
# Use `hunyuan-standard` model for validate,
model_instance.validate_credentials(
model='hunyuan-standard',
credentials=credentials
)
model_instance.validate_credentials(model="hunyuan-standard", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -23,21 +23,27 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
logger = logging.getLogger(__name__)
class HunyuanLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
client = self._setup_hunyuan_client(credentials)
request = models.ChatCompletionsRequest()
messages_dict = self._convert_prompt_messages_to_dicts(prompt_messages)
custom_parameters = {
'Temperature': model_parameters.get('temperature', 0.0),
'TopP': model_parameters.get('top_p', 1.0),
'EnableEnhancement': model_parameters.get('enable_enhance', True)
"Temperature": model_parameters.get("temperature", 0.0),
"TopP": model_parameters.get("top_p", 1.0),
"EnableEnhancement": model_parameters.get("enable_enhance", True),
}
params = {
@ -47,16 +53,19 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
**custom_parameters,
}
# add Tools and ToolChoice
if (tools and len(tools) > 0):
params['ToolChoice'] = "auto"
params['Tools'] = [{
"Type": "function",
"Function": {
"Name": tool.name,
"Description": tool.description,
"Parameters": json.dumps(tool.parameters)
if tools and len(tools) > 0:
params["ToolChoice"] = "auto"
params["Tools"] = [
{
"Type": "function",
"Function": {
"Name": tool.name,
"Description": tool.description,
"Parameters": json.dumps(tool.parameters),
},
}
} for tool in tools]
for tool in tools
]
request.from_json_string(json.dumps(params))
response = client.ChatCompletions(request)
@ -76,22 +85,19 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
req = models.ChatCompletionsRequest()
params = {
"Model": model,
"Messages": [{
"Role": "user",
"Content": "hello"
}],
"Messages": [{"Role": "user", "Content": "hello"}],
"TopP": 1,
"Temperature": 0,
"Stream": False
"Stream": False,
}
req.from_json_string(json.dumps(params))
client.ChatCompletions(req)
except Exception as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
def _setup_hunyuan_client(self, credentials):
secret_id = credentials['secret_id']
secret_key = credentials['secret_key']
secret_id = credentials["secret_id"]
secret_key = credentials["secret_key"]
cred = credential.Credential(secret_id, secret_key)
httpProfile = HttpProfile()
httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
@ -106,92 +112,96 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
for message in prompt_messages:
if isinstance(message, AssistantPromptMessage):
tool_calls = message.tool_calls
if (tool_calls and len(tool_calls) > 0):
if tool_calls and len(tool_calls) > 0:
dict_tool_calls = [
{
"Id": tool_call.id,
"Type": tool_call.type,
"Function": {
"Name": tool_call.function.name,
"Arguments": tool_call.function.arguments if (tool_call.function.arguments == "") else "{}"
}
} for tool_call in tool_calls]
dict_list.append({
"Role": message.role.value,
# fix set content = "" while tool_call request
# fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time.
"Content": " ", # message.content if (message.content is not None) else "",
"ToolCalls": dict_tool_calls
})
"Arguments": tool_call.function.arguments
if (tool_call.function.arguments == "")
else "{}",
},
}
for tool_call in tool_calls
]
dict_list.append(
{
"Role": message.role.value,
# fix set content = "" while tool_call request
# fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time.
"Content": " ", # message.content if (message.content is not None) else "",
"ToolCalls": dict_tool_calls,
}
)
else:
dict_list.append({ "Role": message.role.value, "Content": message.content })
dict_list.append({"Role": message.role.value, "Content": message.content})
elif isinstance(message, ToolPromptMessage):
tool_execute_result = { "result": message.content }
content =json.dumps(tool_execute_result, ensure_ascii=False)
dict_list.append({ "Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id })
tool_execute_result = {"result": message.content}
content = json.dumps(tool_execute_result, ensure_ascii=False)
dict_list.append({"Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id})
else:
dict_list.append({ "Role": message.role.value, "Content": message.content })
dict_list.append({"Role": message.role.value, "Content": message.content})
return dict_list
def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp):
tool_call = None
tool_calls = []
for index, event in enumerate(resp):
logging.debug("_handle_stream_chat_response, event: %s", event)
data_str = event['data']
data_str = event["data"]
data = json.loads(data_str)
choices = data.get('Choices', [])
choices = data.get("Choices", [])
if not choices:
continue
choice = choices[0]
delta = choice.get('Delta', {})
message_content = delta.get('Content', '')
finish_reason = choice.get('FinishReason', '')
delta = choice.get("Delta", {})
message_content = delta.get("Content", "")
finish_reason = choice.get("FinishReason", "")
usage = data.get('Usage', {})
prompt_tokens = usage.get('PromptTokens', 0)
completion_tokens = usage.get('CompletionTokens', 0)
usage = data.get("Usage", {})
prompt_tokens = usage.get("PromptTokens", 0)
completion_tokens = usage.get("CompletionTokens", 0)
response_tool_calls = delta.get('ToolCalls')
if (response_tool_calls is not None):
response_tool_calls = delta.get("ToolCalls")
if response_tool_calls is not None:
new_tool_calls = self._extract_response_tool_calls(response_tool_calls)
if (len(new_tool_calls) > 0):
if len(new_tool_calls) > 0:
new_tool_call = new_tool_calls[0]
if (tool_call is None): tool_call = new_tool_call
elif (tool_call.id != new_tool_call.id):
if tool_call is None:
tool_call = new_tool_call
elif tool_call.id != new_tool_call.id:
tool_calls.append(tool_call)
tool_call = new_tool_call
else:
tool_call.function.name += new_tool_call.function.name
tool_call.function.arguments += new_tool_call.function.arguments
if (tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0):
if tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0:
tool_calls.append(tool_call)
tool_call = None
assistant_prompt_message = AssistantPromptMessage(
content=message_content,
tool_calls=[]
)
assistant_prompt_message = AssistantPromptMessage(content=message_content, tool_calls=[])
# rewrite content = "" while tool_call to avoid show content on web page
if (len(tool_calls) > 0): assistant_prompt_message.content = ""
if len(tool_calls) > 0:
assistant_prompt_message.content = ""
# add tool_calls to assistant_prompt_message
if (finish_reason == 'tool_calls'):
if finish_reason == "tool_calls":
assistant_prompt_message.tool_calls = tool_calls
tool_call = None
tool_calls = []
if (len(finish_reason) > 0):
if len(finish_reason) > 0:
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
delta_chunk = LLMResultChunkDelta(
index=index,
role=delta.get('Role', 'assistant'),
role=delta.get("Role", "assistant"),
message=assistant_prompt_message,
usage=usage,
finish_reason=finish_reason,
@ -212,8 +222,9 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
)
def _handle_chat_response(self, credentials, model, prompt_messages, response):
usage = self._calc_response_usage(model, credentials, response.Usage.PromptTokens,
response.Usage.CompletionTokens)
usage = self._calc_response_usage(
model, credentials, response.Usage.PromptTokens, response.Usage.CompletionTokens
)
assistant_prompt_message = AssistantPromptMessage()
assistant_prompt_message.content = response.Choices[0].Message.Content
result = LLMResult(
@ -225,8 +236,13 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
return result
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
if len(prompt_messages) == 0:
return 0
prompt = self._convert_messages_to_prompt(prompt_messages)
@ -241,10 +257,7 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
"""
messages = messages.copy() # don't mutate the original list
text = "".join(
self._convert_one_message_to_text(message)
for message in messages
)
text = "".join(self._convert_one_message_to_text(message) for message in messages)
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
@ -287,10 +300,8 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
return {
InvokeError: [TencentCloudSDKException],
}
def _extract_response_tool_calls(self,
response_tool_calls: list[dict]) \
-> list[AssistantPromptMessage.ToolCall]:
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
@ -300,17 +311,14 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
response_function = response_tool_call.get('Function', {})
response_function = response_tool_call.get("Function", {})
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function.get('Name', ''),
arguments=response_function.get('Arguments', '')
name=response_function.get("Name", ""), arguments=response_function.get("Arguments", "")
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.get('Id', 0),
type='function',
function=function
id=response_tool_call.get("Id", 0), type="function", function=function
)
tool_calls.append(tool_call)
return tool_calls
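The stream handler above concatenates partial tool calls sharing an Id; a simplified dict-based sketch of that merge, without the SDK types:

def merge_tool_call_fragments(fragments: list[dict]) -> list[dict]:
    # Deltas may carry partial Function.Name / Function.Arguments strings;
    # pieces with the same Id are concatenated in arrival order.
    merged: dict[str, dict] = {}
    for fragment in fragments:
        call = merged.setdefault(
            fragment["Id"], {"Id": fragment["Id"], "Name": "", "Arguments": ""}
        )
        call["Name"] += fragment["Function"].get("Name", "")
        call["Arguments"] += fragment["Function"].get("Arguments", "")
    return list(merged.values())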

View File

@ -19,14 +19,15 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
logger = logging.getLogger(__name__)
class HunyuanTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Hunyuan text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -37,9 +38,9 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
:return: embeddings result
"""
if model != 'hunyuan-embedding':
raise ValueError('Invalid model name')
if model != "hunyuan-embedding":
raise ValueError("Invalid model name")
client = self._setup_hunyuan_client(credentials)
embeddings = []
@ -47,9 +48,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
for input in texts:
request = models.GetEmbeddingRequest()
params = {
"Input": input
}
params = {"Input": input}
request.from_json_string(json.dumps(params))
response = client.GetEmbedding(request)
usage = response.Usage.TotalTokens
@ -60,11 +59,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
model=model,
credentials=credentials,
tokens=token_usage
)
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
)
return result
@ -79,22 +74,19 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
req = models.ChatCompletionsRequest()
params = {
"Model": model,
"Messages": [{
"Role": "user",
"Content": "hello"
}],
"Messages": [{"Role": "user", "Content": "hello"}],
"TopP": 1,
"Temperature": 0,
"Stream": False
"Stream": False,
}
req.from_json_string(json.dumps(params))
client.ChatCompletions(req)
except Exception as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
def _setup_hunyuan_client(self, credentials):
secret_id = credentials['secret_id']
secret_key = credentials['secret_key']
secret_id = credentials["secret_id"]
secret_key = credentials["secret_key"]
cred = credential.Credential(secret_id, secret_key)
httpProfile = HttpProfile()
httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
@ -102,7 +94,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
clientProfile.httpProfile = httpProfile
client = hunyuan_client.HunyuanClient(cred, "", clientProfile)
return client
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
@ -114,10 +106,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
@ -128,11 +117,11 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
@ -146,7 +135,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
return {
InvokeError: [TencentCloudSDKException],
}
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -170,4 +159,4 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
# response = client.GetTokenCount(request)
# num_tokens += response.TokenCount
return num_tokens

View File

@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)
class JinaProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -21,12 +20,9 @@ class JinaProvider(ModelProvider):
# Use `jina-embeddings-v2-base-en` model for validate,
# no matter what model you pass in, text completion model or chat model
model_instance.validate_credentials(
model='jina-embeddings-v2-base-en',
credentials=credentials
)
model_instance.validate_credentials(model="jina-embeddings-v2-base-en", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -22,9 +22,16 @@ class JinaRerankModel(RerankModel):
Model class for Jina rerank model.
"""
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) -> RerankResult:
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
@ -40,37 +47,32 @@ class JinaRerankModel(RerankModel):
if len(docs) == 0:
return RerankResult(model=model, docs=[])
base_url = credentials.get('base_url', 'https://api.jina.ai/v1')
if base_url.endswith('/'):
base_url = credentials.get("base_url", "https://api.jina.ai/v1")
if base_url.endswith("/"):
base_url = base_url[:-1]
try:
response = httpx.post(
base_url + '/rerank',
json={
"model": model,
"query": query,
"documents": docs,
"top_n": top_n
},
headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
base_url + "/rerank",
json={"model": model, "query": query, "documents": docs, "top_n": top_n},
headers={"Authorization": f"Bearer {credentials.get('api_key')}"},
)
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results['results']:
for result in results["results"]:
rerank_document = RerankDocument(
index=result['index'],
text=result['document']['text'],
score=result['relevance_score'],
index=result["index"],
text=result["document"]["text"],
score=result["relevance_score"],
)
if score_threshold is None or result['relevance_score'] >= score_threshold:
if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -81,7 +83,6 @@ class JinaRerankModel(RerankModel):
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
@ -92,7 +93,7 @@ class JinaRerankModel(RerankModel):
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -105,23 +106,21 @@ class JinaRerankModel(RerankModel):
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size'))
}
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
)
return entity
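The `/rerank` request in `_invoke` above, as a standalone sketch; the API key and the model name `jina-reranker-v1-base-en` are placeholders:

import httpx

def jina_rerank(api_key: str, query: str, docs: list[str], top_n: int = 3) -> list[dict]:
    response = httpx.post(
        "https://api.jina.ai/v1/rerank",
        json={"model": "jina-reranker-v1-base-en", "query": query, "documents": docs, "top_n": top_n},
        headers={"Authorization": f"Bearer {api_key}"},
    )
    response.raise_for_status()
    # each result carries index, document.text and relevance_score
    return [
        {"text": r["document"]["text"], "score": r["relevance_score"]}
        for r in response.json()["results"]
    ]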

View File

@ -14,19 +14,19 @@ class JinaTokenizer:
with cls._lock:
if cls._tokenizer is None:
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
gpt2_tokenizer_path = join(dirname(base_path), "tokenizer")
cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
return cls._tokenizer
@classmethod
def _get_num_tokens_by_jina_base(cls, text: str) -> int:
"""
use jina tokenizer to get num tokens
"""
tokenizer = cls._get_tokenizer()
tokens = tokenizer.encode(text)
return len(tokens)
@classmethod
def get_num_tokens(cls, text: str) -> int:
return cls._get_num_tokens_by_jina_base(text)
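Reduced to a standalone sketch, the class loads the tokenizer files bundled next to the module and counts encoded tokens:

from os.path import abspath, dirname, join

from transformers import AutoTokenizer

tokenizer_path = join(dirname(abspath(__file__)), "tokenizer")  # bundled files
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
print(len(tokenizer.encode("Hello, world!")))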

View File

@ -24,11 +24,12 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Jina text embedding model.
"""
api_base: str = 'https://api.jina.ai/v1'
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
api_base: str = "https://api.jina.ai/v1"
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -38,29 +39,23 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
:param user: unique user id
:return: embeddings result
"""
api_key = credentials['api_key']
api_key = credentials["api_key"]
if not api_key:
raise CredentialsValidateFailedError('api_key is required')
raise CredentialsValidateFailedError("api_key is required")
base_url = credentials.get('base_url', self.api_base)
if base_url.endswith('/'):
base_url = credentials.get("base_url", self.api_base)
if base_url.endswith("/"):
base_url = base_url[:-1]
url = base_url + '/embeddings'
headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
}
url = base_url + "/embeddings"
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
def transform_jina_input_text(model, text):
if model == 'jina-clip-v1':
if model == "jina-clip-v1":
return {"text": text}
return text
data = {
'model': model,
'input': [transform_jina_input_text(model, text) for text in texts]
}
data = {"model": model, "input": [transform_jina_input_text(model, text) for text in texts]}
try:
response = post(url, headers=headers, data=dumps(data))
@ -70,7 +65,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
if response.status_code != 200:
try:
resp = response.json()
msg = resp['detail']
msg = resp["detail"]
if response.status_code == 401:
raise InvokeAuthorizationError(msg)
elif response.status_code == 429:
@ -81,25 +76,20 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
raise InvokeBadRequestError(msg)
except JSONDecodeError as e:
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}")
f"Failed to convert response to json: {e} with text: {response.text}"
)
try:
resp = response.json()
embeddings = resp['data']
usage = resp['usage']
embeddings = resp["data"]
usage = resp["usage"]
except Exception as e:
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}")
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
usage = self._calc_response_usage(
model=model, credentials=credentials, tokens=usage['total_tokens'])
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
result = TextEmbeddingResult(
model=model,
embeddings=[[
float(data) for data in x['embedding']
] for x in embeddings],
usage=usage
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
)
return result
@ -128,30 +118,18 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
self._invoke(model=model, credentials=credentials, texts=["ping"])
except Exception as e:
raise CredentialsValidateFailedError(
f'Credentials validation failed: {e}')
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError,
InvokeBadRequestError
]
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
@ -165,10 +143,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
@ -179,24 +154,21 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(
credentials.get('context_size'))
}
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
)
return entity
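
A small sketch of the request body _invoke assembles, isolating the jina-clip-v1 special case; the second model name is only an example:

from json import dumps

def build_payload(model: str, texts: list[str]) -> str:
    def transform(text: str):
        # jina-clip-v1 is multimodal, so plain strings are wrapped as {"text": ...}
        return {"text": text} if model == "jina-clip-v1" else text
    return dumps({"model": model, "input": [transform(t) for t in texts]})

print(build_payload("jina-clip-v1", ["hello"]))                # input: [{"text": "hello"}]
print(build_payload("jina-embeddings-v2-base-en", ["hello"]))  # input: ["hello"]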

View File

@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
logger = logging.getLogger(__name__)
class LeptonAIProvider(ModelProvider):
class LeptonAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -18,12 +18,9 @@ class LeptonAIProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='llama2-7b',
credentials=credentials
)
model_instance.validate_credentials(model="llama2-7b", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -8,18 +8,25 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
MODEL_PREFIX_MAP = {
'llama2-7b': 'llama2-7b',
'gemma-7b': 'gemma-7b',
'mistral-7b': 'mistral-7b',
'mixtral-8x7b': 'mixtral-8x7b',
'llama3-70b': 'llama3-70b',
'llama2-13b': 'llama2-13b',
}
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"llama2-7b": "llama2-7b",
"gemma-7b": "gemma-7b",
"mistral-7b": "mistral-7b",
"mixtral-8x7b": "mixtral-8x7b",
"llama3-70b": "llama3-70b",
"llama2-13b": "llama2-13b",
}
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials, model)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
@ -29,6 +36,5 @@ class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
@classmethod
def _add_custom_parameters(cls, credentials: dict, model: str) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = f'https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1'
credentials["mode"] = "chat"
credentials["endpoint_url"] = f"https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1"

View File

@ -52,29 +52,48 @@ from core.model_runtime.utils import helper
class LocalAILanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
# tools are not supported yet
return self._num_tokens_from_messages(prompt_messages, tools=tools)
def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int:
"""
Calculate num tokens for baichuan model
LocalAI does not supports
Calculate num tokens for LocalAI model
LocalAI does not expose its own tokenizer, so the count is approximated
"""
def tokens(text: str):
"""
We could not determine which tokenizer to use, cause the model is customized.
So we use gpt2 tokenizer to calculate the num tokens for convenience.
We could not determine which tokenizer to use because the model is customized,
so we use the gpt2 tokenizer to calculate the num tokens for convenience.
"""
return self._get_num_tokens_by_gpt2(text)
@ -87,10 +106,10 @@ class LocalAILanguageModel(LargeLanguageModel):
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ''
text = ""
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
@ -142,30 +161,30 @@ class LocalAILanguageModel(LargeLanguageModel):
num_tokens = 0
for tool in tools:
# calculate num tokens for function object
num_tokens += tokens('name')
num_tokens += tokens("name")
num_tokens += tokens(tool.name)
num_tokens += tokens('description')
num_tokens += tokens("description")
num_tokens += tokens(tool.description)
parameters = tool.parameters
num_tokens += tokens('parameters')
num_tokens += tokens('type')
num_tokens += tokens("parameters")
num_tokens += tokens("type")
num_tokens += tokens(parameters.get("type"))
if 'properties' in parameters:
num_tokens += tokens('properties')
for key, value in parameters.get('properties').items():
if "properties" in parameters:
num_tokens += tokens("properties")
for key, value in parameters.get("properties").items():
num_tokens += tokens(key)
for field_key, field_value in value.items():
num_tokens += tokens(field_key)
if field_key == 'enum':
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += tokens(enum_field)
else:
num_tokens += tokens(field_key)
num_tokens += tokens(str(field_value))
if 'required' in parameters:
num_tokens += tokens('required')
for required_field in parameters['required']:
if "required" in parameters:
num_tokens += tokens("required")
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += tokens(required_field)
@ -180,102 +199,104 @@ class LocalAILanguageModel(LargeLanguageModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, prompt_messages=[
UserPromptMessage(content='ping')
], model_parameters={
'max_tokens': 10,
}, stop=[], stream=False)
self._invoke(
model=model,
credentials=credentials,
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={
"max_tokens": 10,
},
stop=[],
stream=False,
)
except Exception as ex:
raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')
raise CredentialsValidateFailedError(f"Invalid credentials {str(ex)}")
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
completion_model = None
if credentials['completion_type'] == 'chat_completion':
if credentials["completion_type"] == "chat_completion":
completion_model = LLMMode.CHAT.value
elif credentials['completion_type'] == 'completion':
elif credentials["completion_type"] == "completion":
completion_model = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
rules = [
ParameterRule(
name='temperature',
name="temperature",
type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度',
en_US='Temperature'
)
use_template="temperature",
label=I18nObject(zh_Hans="温度", en_US="Temperature"),
),
ParameterRule(
name='top_p',
name="top_p",
type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P',
en_US='Top P'
)
use_template="top_p",
label=I18nObject(zh_Hans="Top P", en_US="Top P"),
),
ParameterRule(
name='max_tokens',
name="max_tokens",
type=ParameterType.INT,
use_template='max_tokens',
use_template="max_tokens",
min=1,
max=2048,
default=512,
label=I18nObject(
zh_Hans='最大生成长度',
en_US='Max Tokens'
)
)
label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"),
),
]
model_properties = {
ModelPropertyKey.MODE: completion_model,
} if completion_model else {}
model_properties = (
{
ModelPropertyKey.MODE: completion_model,
}
if completion_model
else {}
)
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get("context_size", "2048"))
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties=model_properties,
parameter_rules=rules
parameter_rules=rules,
)
return entity
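
A sketch of how the schema above is derived from the credentials dict; the plain-string keys stand in for the ModelPropertyKey members used in the real code:

def model_properties_from(credentials: dict) -> dict:
    mode = {"chat_completion": "chat", "completion": "completion"}.get(credentials["completion_type"])
    if mode is None:
        raise ValueError(f"Unknown completion type {credentials['completion_type']}")
    return {"mode": mode, "context_size": int(credentials.get("context_size", "2048"))}

print(model_properties_from({"completion_type": "chat_completion"}))
# {'mode': 'chat', 'context_size': 2048}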
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
kwargs = self._to_client_kwargs(credentials)
# init model client
client = OpenAI(**kwargs)
model_name = model
completion_type = credentials['completion_type']
completion_type = credentials["completion_type"]
extra_model_kwargs = {
"timeout": 60,
}
if stop:
extra_model_kwargs['stop'] = stop
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs['user'] = user
extra_model_kwargs["user"] = user
if tools and len(tools) > 0:
extra_model_kwargs['functions'] = [
helper.dump_model(tool) for tool in tools
]
extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools]
if completion_type == 'chat_completion':
if completion_type == "chat_completion":
result = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
model=model_name,
@ -283,36 +304,32 @@ class LocalAILanguageModel(LargeLanguageModel):
**model_parameters,
**extra_model_kwargs,
)
elif completion_type == 'completion':
elif completion_type == "completion":
result = client.completions.create(
prompt=self._convert_prompt_message_to_completion_prompts(prompt_messages),
model=model,
stream=stream,
**model_parameters,
**extra_model_kwargs
**extra_model_kwargs,
)
else:
raise ValueError(f"Unknown completion type {completion_type}")
if stream:
if completion_type == 'completion':
if completion_type == "completion":
return self._handle_completion_generate_stream_response(
model=model, credentials=credentials, response=result, tools=tools,
prompt_messages=prompt_messages
model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages
)
return self._handle_chat_generate_stream_response(
model=model, credentials=credentials, response=result, tools=tools,
prompt_messages=prompt_messages
model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages
)
if completion_type == 'completion':
if completion_type == "completion":
return self._handle_completion_generate_response(
model=model, credentials=credentials, response=result,
prompt_messages=prompt_messages
model=model, credentials=credentials, response=result, prompt_messages=prompt_messages
)
return self._handle_chat_generate_response(
model=model, credentials=credentials, response=result, tools=tools,
prompt_messages=prompt_messages
model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages
)
def _to_client_kwargs(self, credentials: dict) -> dict:
@ -322,13 +339,13 @@ class LocalAILanguageModel(LargeLanguageModel):
:param credentials: credentials dict
:return: client kwargs
"""
if not credentials['server_url'].endswith('/'):
credentials['server_url'] += '/'
if not credentials["server_url"].endswith("/"):
credentials["server_url"] += "/"
client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1",
"base_url": str(URL(credentials['server_url']) / 'v1'),
"base_url": str(URL(credentials["server_url"]) / "v1"),
}
return client_kwargs
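
A usage sketch of the client construction _to_client_kwargs feeds; the server URL is a placeholder for a local deployment:

from httpx import Timeout
from openai import OpenAI
from yarl import URL

server_url = "http://localhost:8080/"  # hypothetical LocalAI instance
client = OpenAI(
    timeout=Timeout(315.0, read=300.0, write=10.0, connect=5.0),
    api_key="1",  # dummy value, as hardcoded above
    base_url=str(URL(server_url) / "v1"),
)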
@ -349,7 +366,7 @@ class LocalAILanguageModel(LargeLanguageModel):
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments
"arguments": message.tool_calls[0].function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
@ -359,11 +376,7 @@ class LocalAILanguageModel(LargeLanguageModel):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "user",
"content": [{
"type": "tool_result",
"tool_use_id": message.tool_call_id,
"content": message.content
}]
"content": [{"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content}],
}
else:
raise ValueError(f"Unknown message type {type(message)}")
@ -374,27 +387,29 @@ class LocalAILanguageModel(LargeLanguageModel):
"""
Convert PromptMessage to completion prompts
"""
prompts = ''
prompts = ""
for message in messages:
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
prompts += f'{message.content}\n'
prompts += f"{message.content}\n"
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
prompts += f'{message.content}\n'
prompts += f"{message.content}\n"
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
prompts += f'{message.content}\n'
prompts += f"{message.content}\n"
else:
raise ValueError(f"Unknown message type {type(message)}")
return prompts
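
The conversion above flattens every message to its content followed by a newline, dropping roles; for example:

contents = ["You are terse.", "Hi", "Hello!"]  # system, user, assistant contents
prompts = "".join(f"{c}\n" for c in contents)
assert prompts == "You are terse.\nHi\nHello!\n"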
def _handle_completion_generate_response(self, model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Completion,
) -> LLMResult:
def _handle_completion_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Completion,
) -> LLMResult:
"""
Handle llm chat response
@ -411,18 +426,16 @@ class LocalAILanguageModel(LargeLanguageModel):
assistant_message = response.choices[0].text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_message,
tool_calls=[]
)
assistant_prompt_message = AssistantPromptMessage(content=assistant_message, tool_calls=[])
prompt_tokens = self._get_num_tokens_by_gpt2(
self._convert_prompt_message_to_completion_prompts(prompt_messages)
)
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[])
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens)
usage = self._calc_response_usage(
model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
)
response = LLMResult(
model=model,
@ -434,11 +447,14 @@ class LocalAILanguageModel(LargeLanguageModel):
return response
def _handle_chat_generate_response(self, model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: ChatCompletion,
tools: list[PromptMessageTool]) -> LLMResult:
def _handle_chat_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: ChatCompletion,
tools: list[PromptMessageTool],
) -> LLMResult:
"""
Handle llm chat response
@ -459,16 +475,14 @@ class LocalAILanguageModel(LargeLanguageModel):
tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else [])
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_message.content,
tool_calls=tool_calls
)
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens)
usage = self._calc_response_usage(
model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
)
response = LLMResult(
model=model,
@ -480,12 +494,15 @@ class LocalAILanguageModel(LargeLanguageModel):
return response
def _handle_completion_generate_stream_response(self, model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Stream[Completion],
tools: list[PromptMessageTool]) -> Generator:
full_response = ''
def _handle_completion_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Stream[Completion],
tools: list[PromptMessageTool],
) -> Generator:
full_response = ""
for chunk in response:
if len(chunk.choices) == 0:
@ -494,17 +511,11 @@ class LocalAILanguageModel(LargeLanguageModel):
delta = chunk.choices[0]
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.text if delta.text else '',
tool_calls=[]
)
assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response,
tool_calls=[]
)
temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[])
prompt_tokens = self._get_num_tokens_by_gpt2(
self._convert_prompt_message_to_completion_prompts(prompt_messages)
@ -512,8 +523,12 @@ class LocalAILanguageModel(LargeLanguageModel):
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
yield LLMResultChunk(
model=model,
@ -523,7 +538,7 @@ class LocalAILanguageModel(LargeLanguageModel):
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage
usage=usage,
),
)
else:
@ -539,12 +554,15 @@ class LocalAILanguageModel(LargeLanguageModel):
full_response += delta.text
def _handle_chat_generate_stream_response(self, model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Stream[ChatCompletionChunk],
tools: list[PromptMessageTool]) -> Generator:
full_response = ''
def _handle_chat_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Stream[ChatCompletionChunk],
tools: list[PromptMessageTool],
) -> Generator:
full_response = ""
for chunk in response:
if len(chunk.choices) == 0:
@ -552,7 +570,7 @@ class LocalAILanguageModel(LargeLanguageModel):
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""):
continue
# check if there is a tool call in the response
@ -564,22 +582,24 @@ class LocalAILanguageModel(LargeLanguageModel):
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else '',
tool_calls=assistant_message_tool_calls
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
)
if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response,
tool_calls=assistant_message_tool_calls
content=full_response, tool_calls=assistant_message_tool_calls
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
yield LLMResultChunk(
model=model,
@ -589,7 +609,7 @@ class LocalAILanguageModel(LargeLanguageModel):
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage
usage=usage,
),
)
else:
@ -605,9 +625,9 @@ class LocalAILanguageModel(LargeLanguageModel):
full_response += delta.delta.content
def _extract_response_tool_calls(self,
response_function_calls: list[FunctionCall]) \
-> list[AssistantPromptMessage.ToolCall]:
def _extract_response_tool_calls(
self, response_function_calls: list[FunctionCall]
) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
@ -618,15 +638,10 @@ class LocalAILanguageModel(LargeLanguageModel):
if response_function_calls:
for response_tool_call in response_function_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.name,
arguments=response_tool_call.arguments
name=response_tool_call.name, arguments=response_tool_call.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=0,
type='function',
function=function
)
tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function)
tool_calls.append(tool_call)
return tool_calls
@ -651,15 +666,9 @@ class LocalAILanguageModel(LargeLanguageModel):
ConflictError,
NotFoundError,
UnprocessableEntityError,
PermissionDeniedError
PermissionDeniedError,
],
InvokeRateLimitError: [
RateLimitError
],
InvokeAuthorizationError: [
AuthenticationError
],
InvokeBadRequestError: [
ValueError
]
InvokeRateLimitError: [RateLimitError],
InvokeAuthorizationError: [AuthenticationError],
InvokeBadRequestError: [ValueError],
}

View File

@ -6,6 +6,5 @@ logger = logging.getLogger(__name__)
class LocalAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass
pass

View File

@ -25,9 +25,16 @@ class LocalaiRerankModel(RerankModel):
The LocalAI rerank API is compatible with the Jina rerank API, so the JinaRerankModel implementation is reused here.
"""
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) -> RerankResult:
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
@ -43,45 +50,37 @@ class LocalaiRerankModel(RerankModel):
if len(docs) == 0:
return RerankResult(model=model, docs=[])
server_url = credentials['server_url']
server_url = credentials["server_url"]
model_name = model
if not server_url:
raise CredentialsValidateFailedError('server_url is required')
if not model_name:
raise CredentialsValidateFailedError('model_name is required')
url = server_url
headers = {
'Authorization': f"Bearer {credentials.get('api_key')}",
'Content-Type': 'application/json'
}
data = {
"model": model_name,
"query": query,
"documents": docs,
"top_n": top_n
}
if not server_url:
raise CredentialsValidateFailedError("server_url is required")
if not model_name:
raise CredentialsValidateFailedError("model_name is required")
url = server_url
headers = {"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"}
data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n}
try:
response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10)
response.raise_for_status()
response = post(str(URL(url) / "rerank"), headers=headers, data=dumps(data), timeout=10)
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results['results']:
for result in results["results"]:
rerank_document = RerankDocument(
index=result['index'],
text=result['document']['text'],
score=result['relevance_score'],
index=result["index"],
text=result["document"]["text"],
score=result["relevance_score"],
)
if score_threshold is None or result['relevance_score'] >= score_threshold:
if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -92,7 +91,6 @@ class LocalaiRerankModel(RerankModel):
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
@ -103,7 +101,7 @@ class LocalaiRerankModel(RerankModel):
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -116,21 +114,21 @@ class LocalaiRerankModel(RerankModel):
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError]
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={}
model_properties={},
)
return entity
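
A usage sketch of the rerank endpoint this model posts to; the server URL, API key, and model name are placeholders:

from json import dumps
from httpx import post
from yarl import URL

server_url = "http://localhost:8080"  # hypothetical LocalAI instance
body = {"model": "my-reranker", "query": "What is Saipan?", "documents": ["doc a", "doc b"], "top_n": 2}
headers = {"Authorization": "Bearer placeholder", "Content-Type": "application/json"}
response = post(str(URL(server_url) / "rerank"), headers=headers, data=dumps(body), timeout=10)
response.raise_for_status()
for result in response.json()["results"]:
    print(result["index"], result["relevance_score"])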

View File

@ -32,8 +32,8 @@ class LocalAISpeech2text(Speech2TextModel):
:param user: unique user id
:return: text for given audio file
"""
url = str(URL(credentials['server_url']) / "v1/audio/transcriptions")
url = str(URL(credentials["server_url"]) / "v1/audio/transcriptions")
data = {"model": model}
files = {"file": file}
@ -42,7 +42,7 @@ class LocalAISpeech2text(Speech2TextModel):
prepared_request = session.prepare_request(request)
response = session.send(prepared_request)
if 'error' in response.json():
if "error" in response.json():
raise InvokeServerUnavailableError("Empty response")
return response.json()["text"]
@ -58,7 +58,7 @@ class LocalAISpeech2text(Speech2TextModel):
try:
audio_file_path = self._get_demo_file_path()
with open(audio_file_path, 'rb') as audio_file:
with open(audio_file_path, "rb") as audio_file:
self._invoke(model, credentials, audio_file)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -66,36 +66,24 @@ class LocalAISpeech2text(Speech2TextModel):
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError
],
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [InvokeBadRequestError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.SPEECH2TEXT,
model_properties={},
parameter_rules=[]
parameter_rules=[],
)
return entity
return entity
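
A simplified usage sketch of the transcription call _invoke performs, using requests directly rather than a prepared request; the URL, model name, and audio file are placeholders:

import requests
from yarl import URL

url = str(URL("http://localhost:8080/") / "v1/audio/transcriptions")  # hypothetical server
with open("sample.mp3", "rb") as audio_file:  # hypothetical file
    response = requests.post(url, data={"model": "whisper-1"}, files={"file": audio_file})
payload = response.json()
if "error" in payload:
    raise RuntimeError(payload["error"])
print(payload["text"])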

View File

@ -24,9 +24,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Jina text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -37,39 +38,33 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
:return: embeddings result
"""
if len(texts) != 1:
raise InvokeBadRequestError('Only one text is supported')
raise InvokeBadRequestError("Only one text is supported")
server_url = credentials['server_url']
server_url = credentials["server_url"]
model_name = model
if not server_url:
raise CredentialsValidateFailedError('server_url is required')
raise CredentialsValidateFailedError("server_url is required")
if not model_name:
raise CredentialsValidateFailedError('model_name is required')
url = server_url
headers = {
'Authorization': 'Bearer 123',
'Content-Type': 'application/json'
}
raise CredentialsValidateFailedError("model_name is required")
data = {
'model': model_name,
'input': texts[0]
}
url = server_url
headers = {"Authorization": "Bearer 123", "Content-Type": "application/json"}
data = {"model": model_name, "input": texts[0]}
try:
response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10)
response = post(str(URL(url) / "embeddings"), headers=headers, data=dumps(data), timeout=10)
except Exception as e:
raise InvokeConnectionError(str(e))
if response.status_code != 200:
try:
resp = response.json()
code = resp['error']['code']
msg = resp['error']['message']
code = resp["error"]["code"]
msg = resp["error"]["message"]
if code == 500:
raise InvokeServerUnavailableError(msg)
if response.status_code == 401:
raise InvokeAuthorizationError(msg)
elif response.status_code == 429:
@ -79,23 +74,21 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
else:
raise InvokeError(msg)
except JSONDecodeError as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
try:
resp = response.json()
embeddings = resp['data']
usage = resp['usage']
embeddings = resp["data"]
usage = resp["usage"]
except Exception as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens'])
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
result = TextEmbeddingResult(
model=model,
embeddings=[[
float(data) for data in x['embedding']
] for x in embeddings],
usage=usage
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
)
return result
@ -114,7 +107,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
# use GPT2Tokenizer to get num tokens
num_tokens += self._get_num_tokens_by_gpt2(text)
return num_tokens
def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get customizable model schema
@ -130,10 +123,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
features=[],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512")),
ModelPropertyKey.MAX_CHUNKS: 1,
},
parameter_rules=[]
parameter_rules=[],
)
def validate_credentials(self, model: str, credentials: dict) -> None:
@ -145,32 +138,22 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
self._invoke(model=model, credentials=credentials, texts=["ping"])
except InvokeAuthorizationError:
raise CredentialsValidateFailedError('Invalid credentials')
raise CredentialsValidateFailedError("Invalid credentials")
except InvokeConnectionError as e:
raise CredentialsValidateFailedError(f'Invalid credentials: {e}')
raise CredentialsValidateFailedError(f"Invalid credentials: {e}")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError],
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
@ -182,10 +165,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
@ -196,7 +176,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
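
A sketch of the bookkeeping _calc_response_usage performs; the pricing figures are illustrative, and the real unit price and unit come from get_price:

import time
from decimal import Decimal

started_at = time.perf_counter()  # set before the request is sent
tokens = 42
unit_price = Decimal("0.1")       # illustrative price per price-unit of tokens
price_unit = Decimal("0.001")     # i.e. the price is quoted per 1,000 tokens
total_price = unit_price * price_unit * tokens
latency = time.perf_counter() - started_at
print(tokens, total_price, latency)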

View File

@ -17,42 +17,48 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
class MinimaxChatCompletion:
"""
Minimax Chat Completion API
Minimax Chat Completion API
"""
def generate(self, model: str, api_key: str, group_id: str,
prompt_messages: list[MinimaxMessage], model_parameters: dict,
tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
def generate(
self,
model: str,
api_key: str,
group_id: str,
prompt_messages: list[MinimaxMessage],
model_parameters: dict,
tools: list[dict[str, Any]],
stop: list[str] | None,
stream: bool,
user: str,
) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
"""
generate chat completion
generate chat completion
"""
if not api_key or not group_id:
raise InvalidAPIKeyError('Invalid API key or group ID')
url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}'
raise InvalidAPIKeyError("Invalid API key or group ID")
url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}"
extra_kwargs = {}
if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int:
extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens']
if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]
if 'temperature' in model_parameters and type(model_parameters['temperature']) == float:
extra_kwargs['temperature'] = model_parameters['temperature']
if "temperature" in model_parameters and type(model_parameters["temperature"]) == float:
extra_kwargs["temperature"] = model_parameters["temperature"]
if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
extra_kwargs['top_p'] = model_parameters['top_p']
if "top_p" in model_parameters and type(model_parameters["top_p"]) == float:
extra_kwargs["top_p"] = model_parameters["top_p"]
prompt = '你是一个什么都懂的专家'
prompt = "你是一个什么都懂的专家"
role_meta = {
'user_name': '',
'bot_name': '专家'
}
role_meta = {"user_name": "", "bot_name": "专家"}
# check if there is a system message
if len(prompt_messages) == 0:
raise BadRequestError('At least one message is required')
raise BadRequestError("At least one message is required")
if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
if prompt_messages[0].content:
prompt = prompt_messages[0].content
@ -60,40 +66,39 @@ class MinimaxChatCompletion:
# check if there is a user message
if len(prompt_messages) == 0:
raise BadRequestError('At least one user message is required')
messages = [{
'sender_type': message.role,
'text': message.content,
} for message in prompt_messages]
raise BadRequestError("At least one user message is required")
headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
}
messages = [
{
"sender_type": message.role,
"text": message.content,
}
for message in prompt_messages
]
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
body = {
'model': model,
'messages': messages,
'prompt': prompt,
'role_meta': role_meta,
'stream': stream,
**extra_kwargs
"model": model,
"messages": messages,
"prompt": prompt,
"role_meta": role_meta,
"stream": stream,
**extra_kwargs,
}
try:
response = post(
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
except Exception as e:
raise InternalServerError(e)
if response.status_code != 200:
raise InternalServerError(response.text)
if stream:
return self._handle_stream_chat_generate_response(response)
return self._handle_chat_generate_response(response)
def _handle_error(self, code: int, msg: str):
if code == 1000 or code == 1001 or code == 1013 or code == 1027:
raise InternalServerError(msg)
@ -110,65 +115,52 @@ class MinimaxChatCompletion:
def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage:
"""
handle chat generate response
handle chat generate response
"""
response = response.json()
if 'base_resp' in response and response['base_resp']['status_code'] != 0:
code = response['base_resp']['status_code']
msg = response['base_resp']['status_msg']
if "base_resp" in response and response["base_resp"]["status_code"] != 0:
code = response["base_resp"]["status_code"]
msg = response["base_resp"]["status_msg"]
self._handle_error(code, msg)
message = MinimaxMessage(
content=response['reply'],
role=MinimaxMessage.Role.ASSISTANT.value
)
message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message.usage = {
'prompt_tokens': 0,
'completion_tokens': response['usage']['total_tokens'],
'total_tokens': response['usage']['total_tokens']
"prompt_tokens": 0,
"completion_tokens": response["usage"]["total_tokens"],
"total_tokens": response["usage"]["total_tokens"],
}
message.stop_reason = response['choices'][0]['finish_reason']
message.stop_reason = response["choices"][0]["finish_reason"]
return message
def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
"""
handle stream chat generate response
handle stream chat generate response
"""
for line in response.iter_lines():
if not line:
continue
line: str = line.decode('utf-8')
if line.startswith('data: '):
line: str = line.decode("utf-8")
if line.startswith("data: "):
line = line[6:].strip()
data = loads(line)
if 'base_resp' in data and data['base_resp']['status_code'] != 0:
code = data['base_resp']['status_code']
msg = data['base_resp']['status_msg']
if "base_resp" in data and data["base_resp"]["status_code"] != 0:
code = data["base_resp"]["status_code"]
msg = data["base_resp"]["status_msg"]
self._handle_error(code, msg)
if data['reply']:
total_tokens = data['usage']['total_tokens']
message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
content=''
)
message.usage = {
'prompt_tokens': 0,
'completion_tokens': total_tokens,
'total_tokens': total_tokens
}
message.stop_reason = data['choices'][0]['finish_reason']
if data["reply"]:
total_tokens = data["usage"]["total_tokens"]
message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="")
message.usage = {"prompt_tokens": 0, "completion_tokens": total_tokens, "total_tokens": total_tokens}
message.stop_reason = data["choices"][0]["finish_reason"]
yield message
return
choices = data.get('choices', [])
choices = data.get("choices", [])
if len(choices) == 0:
continue
for choice in choices:
message = choice['delta']
yield MinimaxMessage(
content=message,
role=MinimaxMessage.Role.ASSISTANT.value
)
message = choice["delta"]
yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value)
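
The stream handler above consumes server-sent events, where each payload line is "data: " plus a JSON document. A parsing sketch with a fabricated chunk:

from json import loads

raw_line = b'data: {"reply": "", "choices": [{"delta": "Hello"}]}'  # fabricated example chunk
line = raw_line.decode("utf-8")
if line.startswith("data: "):
    data = loads(line[6:].strip())
    for choice in data.get("choices", []):
        print(choice["delta"])  # the streamed text fragment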

View File

@ -17,86 +17,83 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
class MinimaxChatCompletionPro:
"""
Minimax Chat Completion Pro API, supports function calling
however, we do not have enough time and energy to implement it, but the parameters are reserved
Minimax Chat Completion Pro API, supports function calling
the function-calling path is not fully implemented yet, but the parameters are reserved
"""
def generate(self, model: str, api_key: str, group_id: str,
prompt_messages: list[MinimaxMessage], model_parameters: dict,
tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
def generate(
self,
model: str,
api_key: str,
group_id: str,
prompt_messages: list[MinimaxMessage],
model_parameters: dict,
tools: list[dict[str, Any]],
stop: list[str] | None,
stream: bool,
user: str,
) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
"""
generate chat completion
generate chat completion
"""
if not api_key or not group_id:
raise InvalidAPIKeyError('Invalid API key or group ID')
raise InvalidAPIKeyError("Invalid API key or group ID")
url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}'
url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}"
extra_kwargs = {}
if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int:
extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens']
if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]
if 'temperature' in model_parameters and type(model_parameters['temperature']) == float:
extra_kwargs['temperature'] = model_parameters['temperature']
if "temperature" in model_parameters and type(model_parameters["temperature"]) == float:
extra_kwargs["temperature"] = model_parameters["temperature"]
if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
extra_kwargs['top_p'] = model_parameters['top_p']
if "top_p" in model_parameters and type(model_parameters["top_p"]) == float:
extra_kwargs["top_p"] = model_parameters["top_p"]
if 'mask_sensitive_info' in model_parameters and type(model_parameters['mask_sensitive_info']) == bool:
extra_kwargs['mask_sensitive_info'] = model_parameters['mask_sensitive_info']
if model_parameters.get('plugin_web_search'):
extra_kwargs['plugins'] = [
'plugin_web_search'
]
if "mask_sensitive_info" in model_parameters and type(model_parameters["mask_sensitive_info"]) == bool:
extra_kwargs["mask_sensitive_info"] = model_parameters["mask_sensitive_info"]
bot_setting = {
'bot_name': '专家',
'content': '你是一个什么都懂的专家'
}
if model_parameters.get("plugin_web_search"):
extra_kwargs["plugins"] = ["plugin_web_search"]
reply_constraints = {
'sender_type': 'BOT',
'sender_name': '专家'
}
bot_setting = {"bot_name": "专家", "content": "你是一个什么都懂的专家"}
reply_constraints = {"sender_type": "BOT", "sender_name": "专家"}
# check if there is a system message
if len(prompt_messages) == 0:
raise BadRequestError('At least one message is required')
raise BadRequestError("At least one message is required")
if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
if prompt_messages[0].content:
bot_setting['content'] = prompt_messages[0].content
bot_setting["content"] = prompt_messages[0].content
prompt_messages = prompt_messages[1:]
# check if there is a user message
if len(prompt_messages) == 0:
raise BadRequestError('At least one user message is required')
raise BadRequestError("At least one user message is required")
messages = [message.to_dict() for message in prompt_messages]
headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
}
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
body = {
'model': model,
'messages': messages,
'bot_setting': [bot_setting],
'reply_constraints': reply_constraints,
'stream': stream,
**extra_kwargs
"model": model,
"messages": messages,
"bot_setting": [bot_setting],
"reply_constraints": reply_constraints,
"stream": stream,
**extra_kwargs,
}
if tools:
body['functions'] = tools
body['function_call'] = {'type': 'auto'}
body["functions"] = tools
body["function_call"] = {"type": "auto"}
try:
response = post(
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
except Exception as e:
raise InternalServerError(e)
@ -123,78 +120,72 @@ class MinimaxChatCompletionPro:
def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage:
"""
handle chat generate response
handle chat generate response
"""
response = response.json()
if 'base_resp' in response and response['base_resp']['status_code'] != 0:
code = response['base_resp']['status_code']
msg = response['base_resp']['status_msg']
if "base_resp" in response and response["base_resp"]["status_code"] != 0:
code = response["base_resp"]["status_code"]
msg = response["base_resp"]["status_msg"]
self._handle_error(code, msg)
message = MinimaxMessage(
content=response['reply'],
role=MinimaxMessage.Role.ASSISTANT.value
)
message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message.usage = {
'prompt_tokens': 0,
'completion_tokens': response['usage']['total_tokens'],
'total_tokens': response['usage']['total_tokens']
"prompt_tokens": 0,
"completion_tokens": response["usage"]["total_tokens"],
"total_tokens": response["usage"]["total_tokens"],
}
message.stop_reason = response['choices'][0]['finish_reason']
message.stop_reason = response["choices"][0]["finish_reason"]
return message
def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
"""
handle stream chat generate response
handle stream chat generate response
"""
for line in response.iter_lines():
if not line:
continue
line: str = line.decode('utf-8')
if line.startswith('data: '):
line: str = line.decode("utf-8")
if line.startswith("data: "):
line = line[6:].strip()
data = loads(line)
if 'base_resp' in data and data['base_resp']['status_code'] != 0:
code = data['base_resp']['status_code']
msg = data['base_resp']['status_msg']
if "base_resp" in data and data["base_resp"]["status_code"] != 0:
code = data["base_resp"]["status_code"]
msg = data["base_resp"]["status_msg"]
self._handle_error(code, msg)
# final chunk
if data['reply'] or data.get('usage'):
total_tokens = data['usage']['total_tokens']
minimax_message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
content=''
)
if data["reply"] or data.get("usage"):
total_tokens = data["usage"]["total_tokens"]
minimax_message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="")
minimax_message.usage = {
'prompt_tokens': 0,
'completion_tokens': total_tokens,
'total_tokens': total_tokens
"prompt_tokens": 0,
"completion_tokens": total_tokens,
"total_tokens": total_tokens,
}
minimax_message.stop_reason = data['choices'][0]['finish_reason']
minimax_message.stop_reason = data["choices"][0]["finish_reason"]
choices = data.get('choices', [])
choices = data.get("choices", [])
if len(choices) > 0:
for choice in choices:
message = choice['messages'][0]
message = choice["messages"][0]
# append function_call message
if 'function_call' in message:
function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
function_call_message.function_call = message['function_call']
if "function_call" in message:
function_call_message = MinimaxMessage(content="", role=MinimaxMessage.Role.ASSISTANT.value)
function_call_message.function_call = message["function_call"]
yield function_call_message
yield minimax_message
return
# partial chunk
choices = data.get('choices', [])
choices = data.get("choices", [])
if len(choices) == 0:
continue
for choice in choices:
message = choice['messages'][0]
message = choice["messages"][0]
# append text message
if 'text' in message:
minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value)
if "text" in message:
minimax_message = MinimaxMessage(content=message["text"], role=MinimaxMessage.Role.ASSISTANT.value)
yield minimax_message
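
For orientation, a sketch of the request body the pro endpoint receives, with illustrative values; the message shape comes from MinimaxMessage.to_dict() and is assumed here:

body = {
    "model": "abab5.5-chat",
    "messages": [{"sender_type": "USER", "sender_name": "user", "text": "ping"}],  # assumed shape
    "bot_setting": [{"bot_name": "专家", "content": "你是一个什么都懂的专家"}],
    "reply_constraints": {"sender_type": "BOT", "sender_name": "专家"},
    "stream": False,
    "tokens_to_generate": 10,  # mapped from max_tokens in extra_kwargs
}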

View File

@ -1,17 +1,22 @@
class InvalidAuthenticationError(Exception):
pass
class InvalidAPIKeyError(Exception):
pass
class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalanceError(Exception):
pass
class InternalServerError(Exception):
pass
class BadRequestError(Exception):
pass
pass
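
A sketch of how a dispatcher pairs Minimax status codes with these exceptions; only the codes visible in this diff are covered, and the fallback is illustrative:

def handle_error_sketch(code: int, msg: str) -> None:
    if code in (1000, 1001, 1013, 1027):  # codes mapped to server errors in _handle_error
        raise InternalServerError(msg)
    raise BadRequestError(msg)  # illustrative fallback, not the real mapping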

View File

@ -34,18 +34,25 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
class MinimaxLargeLanguageModel(LargeLanguageModel):
model_apis = {
'abab6.5s-chat': MinimaxChatCompletionPro,
'abab6.5-chat': MinimaxChatCompletionPro,
'abab6-chat': MinimaxChatCompletionPro,
'abab5.5s-chat': MinimaxChatCompletionPro,
'abab5.5-chat': MinimaxChatCompletionPro,
'abab5-chat': MinimaxChatCompletion
"abab6.5s-chat": MinimaxChatCompletionPro,
"abab6.5-chat": MinimaxChatCompletionPro,
"abab6-chat": MinimaxChatCompletionPro,
"abab5.5s-chat": MinimaxChatCompletionPro,
"abab5.5-chat": MinimaxChatCompletionPro,
"abab5-chat": MinimaxChatCompletion,
}
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
@ -53,82 +60,97 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
Validate credentials for Minimax model
"""
if model not in self.model_apis:
raise CredentialsValidateFailedError(f'Invalid model: {model}')
raise CredentialsValidateFailedError(f"Invalid model: {model}")
if not credentials.get('minimax_api_key'):
raise CredentialsValidateFailedError('Invalid API key')
if not credentials.get("minimax_api_key"):
raise CredentialsValidateFailedError("Invalid API key")
if not credentials.get("minimax_group_id"):
raise CredentialsValidateFailedError("Invalid group ID")
if not credentials.get('minimax_group_id'):
raise CredentialsValidateFailedError('Invalid group ID')
# ping
instance = MinimaxChatCompletionPro()
try:
instance.generate(
model=model, api_key=credentials['minimax_api_key'], group_id=credentials['minimax_group_id'],
prompt_messages=[
MinimaxMessage(content='ping', role='USER')
],
model=model,
api_key=credentials["minimax_api_key"],
group_id=credentials["minimax_group_id"],
prompt_messages=[MinimaxMessage(content="ping", role="USER")],
model_parameters={},
tools=[], stop=[],
tools=[],
stop=[],
stream=False,
user=''
user="",
)
except (InvalidAuthenticationError, InsufficientAccountBalanceError) as e:
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
return self._num_tokens_from_messages(prompt_messages, tools)
def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int:
"""
Calculate num tokens for minimax model
Calculate num tokens for minimax model
not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way
to calculate the num tokens, so we use str() to convert the prompt to string
Unlike ChatGLM, Minimax has a special prompt structure; we could not find a proper way
to calculate the num tokens, so we use str() to convert the prompt to a string
Minimax does not provide their own tokenizer of adab5.5 and abab5 model
therefore, we use gpt2 tokenizer instead
Minimax does not provide its own tokenizer for the abab5.5 and abab5 models,
so we use the gpt2 tokenizer instead
"""
messages_dict = [self._convert_prompt_message_to_minimax_message(m).to_dict() for m in messages]
return self._get_num_tokens_by_gpt2(str(messages_dict))
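
A sketch of that approximation using the Hugging Face gpt2 tokenizer in place of the internal _get_num_tokens_by_gpt2 helper:

from transformers import GPT2Tokenizer  # heavyweight; for illustration only

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
messages_dict = [{"sender_type": "USER", "text": "ping"}]
print(len(tokenizer.encode(str(messages_dict))))  # token count over the str() form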
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
"""
use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface
use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface
"""
client: MinimaxChatCompletionPro = self.model_apis[model]()
if tools:
tools = [{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
} for tool in tools]
tools = [
{"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools
]
response = client.generate(
model=model,
api_key=credentials['minimax_api_key'],
group_id=credentials['minimax_group_id'],
api_key=credentials["minimax_api_key"],
group_id=credentials["minimax_group_id"],
prompt_messages=[self._convert_prompt_message_to_minimax_message(message) for message in prompt_messages],
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
user=user,
)
if stream:
return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response)
return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response)
return self._handle_chat_generate_stream_response(
model=model, prompt_messages=prompt_messages, credentials=credentials, response=response
)
return self._handle_chat_generate_response(
model=model, prompt_messages=prompt_messages, credentials=credentials, response=response
)
def _convert_prompt_message_to_minimax_message(self, prompt_message: PromptMessage) -> MinimaxMessage:
"""
convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface
convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface
"""
if isinstance(prompt_message, SystemPromptMessage):
return MinimaxMessage(role=MinimaxMessage.Role.SYSTEM.value, content=prompt_message.content)
@ -136,26 +158,27 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
content=''
)
message.function_call={
'name': prompt_message.tool_calls[0].function.name,
'arguments': prompt_message.tool_calls[0].function.arguments
message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="")
message.function_call = {
"name": prompt_message.tool_calls[0].function.name,
"arguments": prompt_message.tool_calls[0].function.arguments,
}
return message
return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
elif isinstance(prompt_message, ToolPromptMessage):
return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
else:
raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')
raise NotImplementedError(f"Prompt message type {type(prompt_message)} is not supported")
def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage) -> LLMResult:
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=response.usage['prompt_tokens'],
completion_tokens=response.usage['completion_tokens']
)
def _handle_chat_generate_response(
self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage
) -> LLMResult:
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=response.usage["prompt_tokens"],
completion_tokens=response.usage["completion_tokens"],
)
return LLMResult(
model=model,
prompt_messages=prompt_messages,
@ -166,31 +189,33 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
usage=usage,
)
def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage],
credentials: dict, response: Generator[MinimaxMessage, None, None]) \
-> Generator[LLMResultChunk, None, None]:
def _handle_chat_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Generator[MinimaxMessage, None, None],
) -> Generator[LLMResultChunk, None, None]:
for message in response:
if message.usage:
usage = self._calc_response_usage(
model=model, credentials=credentials,
prompt_tokens=message.usage['prompt_tokens'],
completion_tokens=message.usage['completion_tokens']
model=model,
credentials=credentials,
prompt_tokens=message.usage["prompt_tokens"],
completion_tokens=message.usage["completion_tokens"],
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=message.content,
tool_calls=[]
),
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
usage=usage,
finish_reason=message.stop_reason if message.stop_reason else None,
),
)
elif message.function_call:
if 'name' not in message.function_call or 'arguments' not in message.function_call:
if "name" not in message.function_call or "arguments" not in message.function_call:
continue
yield LLMResultChunk(
@ -199,15 +224,16 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content='',
tool_calls=[AssistantPromptMessage.ToolCall(
id='',
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=message.function_call['name'],
arguments=message.function_call['arguments']
content="",
tool_calls=[
AssistantPromptMessage.ToolCall(
id="",
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=message.function_call["name"], arguments=message.function_call["arguments"]
),
)
)]
],
),
),
)
@ -217,10 +243,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=message.content,
tool_calls=[]
),
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
finish_reason=message.stop_reason if message.stop_reason else None,
),
)
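For orientation, a hedged sketch of draining this stream from the caller side, assuming the public invoke wrapper (the credentials are placeholders):

llm = MinimaxLargeLanguageModel()
text = ""
for chunk in llm.invoke(
    model="abab5.5-chat",
    credentials={"minimax_api_key": "...", "minimax_group_id": "..."},
    prompt_messages=[UserPromptMessage(content="ping")],
    model_parameters={},
    stream=True,
):
    # content chunks stream text; the final chunk also carries usage
    text += chunk.delta.message.content
    if chunk.delta.usage:
        print("tokens:", chunk.delta.usage.prompt_tokens, chunk.delta.usage.completion_tokens)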
@ -236,22 +259,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeConnectionError: [],
InvokeServerUnavailableError: [InternalServerError],
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalanceError,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
InvokeBadRequestError: [BadRequestError, KeyError],
}

View File

@ -4,32 +4,27 @@ from typing import Any
class MinimaxMessage:
class Role(Enum):
USER = 'USER'
ASSISTANT = 'BOT'
SYSTEM = 'SYSTEM'
FUNCTION = 'FUNCTION'
USER = "USER"
ASSISTANT = "BOT"
SYSTEM = "SYSTEM"
FUNCTION = "FUNCTION"
role: str = Role.USER.value
content: str
usage: dict[str, int] = None
stop_reason: str = ''
stop_reason: str = ""
function_call: dict[str, Any] = None
def to_dict(self) -> dict[str, Any]:
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
return {
'sender_type': 'BOT',
'sender_name': '专家',
'text': '',
'function_call': self.function_call
}
return {"sender_type": "BOT", "sender_name": "专家", "text": "", "function_call": self.function_call}
return {
'sender_type': self.role,
'sender_name': '' if self.role == 'USER' else '专家',
'text': self.content,
"sender_type": self.role,
"sender_name": "" if self.role == "USER" else "专家",
"text": self.content,
}
def __init__(self, content: str, role: str = 'USER') -> None:
def __init__(self, content: str, role: str = "USER") -> None:
self.content = content
self.role = role
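The two to_dict shapes, traced directly from the code above (the function name is made up for illustration):

MinimaxMessage(content="hi").to_dict()
# -> {"sender_type": "USER", "sender_name": "", "text": "hi"}

m = MinimaxMessage(content="ignored", role=MinimaxMessage.Role.ASSISTANT.value)
m.function_call = {"name": "get_weather", "arguments": "{}"}
m.to_dict()
# -> {"sender_type": "BOT", "sender_name": "专家", "text": "", "function_call": {"name": "get_weather", "arguments": "{}"}}
# note the assistant's text content is dropped once a function_call is present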

View File

@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
logger = logging.getLogger(__name__)
class MinimaxProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
@ -19,12 +20,9 @@ class MinimaxProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)
# Use `abab5.5-chat` model for validation.
model_instance.validate_credentials(
model='abab5.5-chat',
credentials=credentials
)
model_instance.validate_credentials(model="abab5.5-chat", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise CredentialsValidateFailedError(f'{ex}')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise CredentialsValidateFailedError(f"{ex}")

View File

@ -30,11 +30,12 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Minimax text embedding model.
"""
api_base: str = 'https://api.minimax.chat/v1/embeddings'
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
api_base: str = "https://api.minimax.chat/v1/embeddings"
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -44,54 +45,43 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
:param user: unique user id
:return: embeddings result
"""
api_key = credentials['minimax_api_key']
group_id = credentials['minimax_group_id']
if model != 'embo-01':
raise ValueError('Invalid model name')
api_key = credentials["minimax_api_key"]
group_id = credentials["minimax_group_id"]
if model != "embo-01":
raise ValueError("Invalid model name")
if not api_key:
raise CredentialsValidateFailedError('api_key is required')
url = f'{self.api_base}?GroupId={group_id}'
headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
}
raise CredentialsValidateFailedError("api_key is required")
url = f"{self.api_base}?GroupId={group_id}"
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
data = {
'model': 'embo-01',
'texts': texts,
'type': 'db'
}
data = {"model": "embo-01", "texts": texts, "type": "db"}
try:
response = post(url, headers=headers, data=dumps(data))
except Exception as e:
raise InvokeConnectionError(str(e))
if response.status_code != 200:
raise InvokeServerUnavailableError(response.text)
try:
resp = response.json()
# check if there is an error
if resp['base_resp']['status_code'] != 0:
code = resp['base_resp']['status_code']
msg = resp['base_resp']['status_msg']
if resp["base_resp"]["status_code"] != 0:
code = resp["base_resp"]["status_code"]
msg = resp["base_resp"]["status_msg"]
self._handle_error(code, msg)
embeddings = resp['vectors']
total_tokens = resp['total_tokens']
embeddings = resp["vectors"]
total_tokens = resp["total_tokens"]
except InvalidAuthenticationError:
raise InvalidAPIKeyError('Invalid api key')
raise InvalidAPIKeyError("Invalid api key")
except KeyError as e:
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens)
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=usage
)
result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage)
return result
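Stripped of the runtime plumbing, the call above reduces to roughly this plain-requests sketch (URL, payload keys, and response keys are taken from the code; the API key and group id are placeholders):

import requests

resp = requests.post(
    "https://api.minimax.chat/v1/embeddings?GroupId=YOUR_GROUP_ID",
    headers={"Authorization": "Bearer YOUR_API_KEY", "Content-Type": "application/json"},
    json={"model": "embo-01", "texts": ["ping"], "type": "db"},
).json()
# base_resp.status_code == 0 signals success; "vectors" holds one embedding per input text
if resp["base_resp"]["status_code"] == 0:
    vectors, total_tokens = resp["vectors"], resp["total_tokens"]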
@ -119,9 +109,9 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
self._invoke(model=model, credentials=credentials, texts=["ping"])
except InvalidAPIKeyError:
raise CredentialsValidateFailedError('Invalid api key')
raise CredentialsValidateFailedError("Invalid api key")
def _handle_error(self, code: int, msg: str):
if code == 1000 or code == 1001:
@ -148,25 +138,17 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeConnectionError: [],
InvokeServerUnavailableError: [InternalServerError],
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalanceError,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
InvokeBadRequestError: [BadRequestError, KeyError],
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
@ -178,10 +160,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
@ -192,7 +171,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage

View File

@ -7,14 +7,19 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
# mistral does not support user/stop arguments
stop = []
user = None
@ -27,5 +32,5 @@ class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.mistral.ai/v1'
credentials["mode"] = "chat"
credentials["endpoint_url"] = "https://api.mistral.ai/v1"

View File

@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)
class MistralAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -19,12 +18,9 @@ class MistralAIProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='open-mistral-7b',
credentials=credentials
)
model_instance.validate_credentials(model="open-mistral-7b", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -30,11 +30,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
self._add_function_call(model, credentials)
user = user[:32] if user else None
@ -49,50 +55,50 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
model=model,
label=I18nObject(en_US=model, zh_Hans=model),
model_type=ModelType.LLM,
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
if credentials.get('function_calling_type') == 'tool_call'
else [],
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
if credentials.get("function_calling_type") == "tool_call"
else [],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
ModelPropertyKey.MODE: LLMMode.CHAT.value,
},
parameter_rules=[
ParameterRule(
name='temperature',
use_template='temperature',
label=I18nObject(en_US='Temperature', zh_Hans='温度'),
name="temperature",
use_template="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=ParameterType.FLOAT,
),
ParameterRule(
name='max_tokens',
use_template='max_tokens',
name="max_tokens",
use_template="max_tokens",
default=512,
min=1,
max=int(credentials.get('max_tokens', 4096)),
label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
max=int(credentials.get("max_tokens", 4096)),
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
type=ParameterType.INT,
),
ParameterRule(
name='top_p',
use_template='top_p',
label=I18nObject(en_US='Top P', zh_Hans='Top P'),
name="top_p",
use_template="top_p",
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
type=ParameterType.FLOAT,
),
]
],
)
def _add_custom_parameters(self, credentials: dict) -> None:
credentials['mode'] = 'chat'
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
credentials["mode"] = "chat"
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
credentials["endpoint_url"] = "https://api.moonshot.cn/v1"
def _add_function_call(self, model: str, credentials: dict) -> None:
model_schema = self.get_model_schema(model, credentials)
if model_schema and {
ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
}.intersection(model_schema.features or []):
credentials['function_calling_type'] = 'tool_call'
if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection(
model_schema.features or []
):
credentials["function_calling_type"] = "tool_call"
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
"""
@ -107,19 +113,13 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
"detail": message_content.detail.value
}
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
@ -129,14 +129,16 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
if message.tool_calls:
message_dict["tool_calls"] = []
for function_call in message.tool_calls:
message_dict["tool_calls"].append({
"id": function_call.id,
"type": function_call.type,
"function": {
"name": function_call.function.name,
"arguments": function_call.function.arguments
message_dict["tool_calls"].append(
{
"id": function_call.id,
"type": function_call.type,
"function": {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
},
}
})
)
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
@ -162,21 +164,26 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
name=response_tool_call["function"]["name"]
if response_tool_call.get("function", {}).get("name")
else "",
arguments=response_tool_call["function"]["arguments"]
if response_tool_call.get("function", {}).get("arguments")
else "",
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call["id"] if response_tool_call.get("id") else "",
type=response_tool_call["type"] if response_tool_call.get("type") else "",
function=function
function=function,
)
tool_calls.append(tool_call)
return tool_calls
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
@ -186,11 +193,12 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ''
full_assistant_content = ""
chunk_index = 0
def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
-> LLMResultChunk:
def create_final_llm_result_chunk(
index: int, message: AssistantPromptMessage, finish_reason: str
) -> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
@ -201,12 +209,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
return LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=message,
finish_reason=finish_reason,
usage=usage
)
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
)
tools_calls: list[AssistantPromptMessage.ToolCall] = []
@ -220,9 +223,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id='',
type='',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
id="",
type="",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
)
tools_calls.append(tool_call)
@ -244,9 +247,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
if chunk:
# ignore sse comments
if chunk.startswith(':'):
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
@ -255,21 +258,21 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
yield create_final_llm_result_chunk(
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
finish_reason="Non-JSON encountered.",
)
break
if not chunk_json or len(chunk_json['choices']) == 0:
if not chunk_json or len(chunk_json["choices"]) == 0:
continue
choice = chunk_json['choices'][0]
finish_reason = chunk_json['choices'][0].get('finish_reason')
choice = chunk_json["choices"][0]
finish_reason = chunk_json["choices"][0].get("finish_reason")
chunk_index += 1
if 'delta' in choice:
delta = choice['delta']
delta_content = delta.get('content')
if "delta" in choice:
delta = choice["delta"]
delta_content = delta.get("content")
assistant_message_tool_calls = delta.get('tool_calls', None)
assistant_message_tool_calls = delta.get("tool_calls", None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
@ -277,19 +280,18 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls)
if delta_content is None or delta_content == '':
if delta_content is None or delta_content == "":
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta_content,
tool_calls=tool_calls if assistant_message_tool_calls else []
content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta_content
elif 'text' in choice:
choice_text = choice.get('text', '')
if choice_text == '':
elif "text" in choice:
choice_text = choice.get("text", "")
if choice_text == "":
continue
# transform assistant message to prompt message
@ -305,26 +307,21 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
)
),
)
chunk_index += 1
if tools_calls:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=AssistantPromptMessage(
tool_calls=tools_calls,
content=""
),
)
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
),
)
yield create_final_llm_result_chunk(
index=chunk_index,
message=AssistantPromptMessage(content=""),
finish_reason=finish_reason
)
index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason
)
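A side note on the SSE handling above, as a hedged sketch of the per-line protocol the loop assumes: comment lines start with ":", data lines carry a "data: " prefix (str.removeprefix is used here because lstrip("data: ") strips a character set rather than a literal prefix), and OpenAI-style streams terminate with "[DONE]":

import json

def parse_sse_line(line: str) -> dict | None:
    # ":" opens an SSE comment/keep-alive; everything else of interest is "data: <json>"
    if not line or line.startswith(":"):
        return None
    payload = line.strip().removeprefix("data: ").strip()
    if payload == "[DONE]":
        return None
    return json.loads(payload)

print(parse_sse_line('data: {"choices": [{"delta": {"content": "hi"}}]}'))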

View File

@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)
class MoonshotProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -19,12 +18,9 @@ class MoonshotProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='moonshot-v1-8k',
credentials=credentials
)
model_instance.validate_credentials(model="moonshot-v1-8k", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -8,20 +8,25 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _update_endpoint_url(self, credentials: dict):
credentials['endpoint_url'] = "https://api.novita.ai/v3/openai"
credentials['extra_headers'] = { 'X-Novita-Source': 'dify.ai' }
credentials["endpoint_url"] = "https://api.novita.ai/v3/openai"
credentials["extra_headers"] = {"X-Novita-Source": "dify.ai"}
return credentials
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
self._add_custom_parameters(credentials, model)
@ -29,21 +34,36 @@ class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel):
@classmethod
def _add_custom_parameters(cls, credentials: dict, model: str) -> None:
credentials['mode'] = 'chat'
credentials["mode"] = "chat"
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
return super()._generate(
model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user
)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
return super().get_customizable_model_schema(model, cred_with_endpoint)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools)

View File

@ -20,12 +20,9 @@ class NovitaProvider(ModelProvider):
# Use `meta-llama/llama-3-8b-instruct` model for validation,
# no matter what model you pass in (text completion model or chat model).
model_instance.validate_credentials(
model='meta-llama/llama-3-8b-instruct',
credentials=credentials
)
model_instance.validate_credentials(model="meta-llama/llama-3-8b-instruct", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -21,31 +21,36 @@ from core.model_runtime.utils import helper
class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
MODEL_SUFFIX_MAP = {
'fuyu-8b': 'vlm/adept/fuyu-8b',
'mistralai/mistral-large': '',
'mistralai/mixtral-8x7b-instruct-v0.1': '',
'mistralai/mixtral-8x22b-instruct-v0.1': '',
'google/gemma-7b': '',
'google/codegemma-7b': '',
'snowflake/arctic':'',
'meta/llama2-70b': '',
'meta/llama3-8b-instruct': '',
'meta/llama3-70b-instruct': '',
'meta/llama-3.1-8b-instruct': '',
'meta/llama-3.1-70b-instruct': '',
'meta/llama-3.1-405b-instruct': '',
'google/recurrentgemma-2b': '',
'nvidia/nemotron-4-340b-instruct': '',
'microsoft/phi-3-medium-128k-instruct':'',
'microsoft/phi-3-mini-128k-instruct':''
"fuyu-8b": "vlm/adept/fuyu-8b",
"mistralai/mistral-large": "",
"mistralai/mixtral-8x7b-instruct-v0.1": "",
"mistralai/mixtral-8x22b-instruct-v0.1": "",
"google/gemma-7b": "",
"google/codegemma-7b": "",
"snowflake/arctic": "",
"meta/llama2-70b": "",
"meta/llama3-8b-instruct": "",
"meta/llama3-70b-instruct": "",
"meta/llama-3.1-8b-instruct": "",
"meta/llama-3.1-70b-instruct": "",
"meta/llama-3.1-405b-instruct": "",
"google/recurrentgemma-2b": "",
"nvidia/nemotron-4-340b-instruct": "",
"microsoft/phi-3-medium-128k-instruct": "",
"microsoft/phi-3-mini-128k-instruct": "",
}
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials, model)
prompt_messages = self._transform_prompt_messages(prompt_messages)
stop = []
@ -60,16 +65,14 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
for i, p in enumerate(prompt_messages):
if isinstance(p, UserPromptMessage) and isinstance(p.content, list):
content = p.content
content_text = ''
content_text = ""
for prompt_content in content:
if prompt_content.type == PromptMessageContentType.TEXT:
content_text += prompt_content.data
else:
content_text += f' <img src="{prompt_content.data}" />'
prompt_message = UserPromptMessage(
content=content_text
)
prompt_message = UserPromptMessage(content=content_text)
prompt_messages[i] = prompt_message
return prompt_messages
@ -78,15 +81,15 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
self._validate_credentials(model, credentials)
def _add_custom_parameters(self, credentials: dict, model: str) -> None:
credentials['mode'] = 'chat'
if self.MODEL_SUFFIX_MAP[model]:
credentials['server_url'] = f'https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}'
credentials.pop('endpoint_url')
else:
credentials['endpoint_url'] = 'https://integrate.api.nvidia.com/v1'
credentials["mode"] = "chat"
credentials['stream_mode_delimiter'] = '\n'
if self.MODEL_SUFFIX_MAP[model]:
credentials["server_url"] = f"https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}"
credentials.pop("endpoint_url")
else:
credentials["endpoint_url"] = "https://integrate.api.nvidia.com/v1"
credentials["stream_mode_delimiter"] = "\n"
def _validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -97,72 +100,67 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
:return:
"""
try:
headers = {
'Content-Type': 'application/json'
}
headers = {"Content-Type": "application/json"}
api_key = credentials.get('api_key')
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get('endpoint_url')
if endpoint_url and not endpoint_url.endswith('/'):
endpoint_url += '/'
server_url = credentials.get('server_url')
endpoint_url = credentials.get("endpoint_url")
if endpoint_url and not endpoint_url.endswith("/"):
endpoint_url += "/"
server_url = credentials.get("server_url")
# prepare the payload for a simple ping to the model
data = {
'model': model,
'max_tokens': 5
}
data = {"model": model, "max_tokens": 5}
completion_type = LLMMode.value_of(credentials['mode'])
completion_type = LLMMode.value_of(credentials["mode"])
if completion_type is LLMMode.CHAT:
data['messages'] = [
{
"role": "user",
"content": "ping"
},
data["messages"] = [
{"role": "user", "content": "ping"},
]
if 'endpoint_url' in credentials:
endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions')
elif 'server_url' in credentials:
if "endpoint_url" in credentials:
endpoint_url = str(URL(endpoint_url) / "chat" / "completions")
elif "server_url" in credentials:
endpoint_url = server_url
elif completion_type is LLMMode.COMPLETION:
data['prompt'] = 'ping'
if 'endpoint_url' in credentials:
endpoint_url = str(URL(endpoint_url) / 'completions')
elif 'server_url' in credentials:
data["prompt"] = "ping"
if "endpoint_url" in credentials:
endpoint_url = str(URL(endpoint_url) / "completions")
elif "server_url" in credentials:
endpoint_url = server_url
else:
raise ValueError("Unsupported completion type for model configuration.")
# send a post request to validate the credentials
response = requests.post(
endpoint_url,
headers=headers,
json=data,
timeout=(10, 300)
)
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300))
if response.status_code != 200:
raise CredentialsValidateFailedError(
f'Credentials validation failed with status code {response.status_code}')
f"Credentials validation failed with status code {response.status_code}"
)
try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error')
raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
except CredentialsValidateFailedError:
raise
except Exception as ex:
raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}")
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, \
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke llm completion model
@ -176,57 +174,51 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
:return: full response or stream response chunk generator result
"""
headers = {
'Content-Type': 'application/json',
'Accept-Charset': 'utf-8',
"Content-Type": "application/json",
"Accept-Charset": "utf-8",
}
api_key = credentials.get('api_key')
api_key = credentials.get("api_key")
if api_key:
headers['Authorization'] = f'Bearer {api_key}'
headers["Authorization"] = f"Bearer {api_key}"
if stream:
headers['Accept'] = 'text/event-stream'
headers["Accept"] = "text/event-stream"
endpoint_url = credentials.get('endpoint_url')
if endpoint_url and not endpoint_url.endswith('/'):
endpoint_url += '/'
server_url = credentials.get('server_url')
endpoint_url = credentials.get("endpoint_url")
if endpoint_url and not endpoint_url.endswith("/"):
endpoint_url += "/"
server_url = credentials.get("server_url")
data = {
"model": model,
"stream": stream,
**model_parameters
}
data = {"model": model, "stream": stream, **model_parameters}
completion_type = LLMMode.value_of(credentials['mode'])
completion_type = LLMMode.value_of(credentials["mode"])
if completion_type is LLMMode.CHAT:
if 'endpoint_url' in credentials:
endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions')
elif 'server_url' in credentials:
if "endpoint_url" in credentials:
endpoint_url = str(URL(endpoint_url) / "chat" / "completions")
elif "server_url" in credentials:
endpoint_url = server_url
data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
elif completion_type is LLMMode.COMPLETION:
data['prompt'] = 'ping'
if 'endpoint_url' in credentials:
endpoint_url = str(URL(endpoint_url) / 'completions')
elif 'server_url' in credentials:
data["prompt"] = "ping"
if "endpoint_url" in credentials:
endpoint_url = str(URL(endpoint_url) / "completions")
elif "server_url" in credentials:
endpoint_url = server_url
else:
raise ValueError("Unsupported completion type for model configuration.")
# annotate tools with names, descriptions, etc.
function_calling_type = credentials.get('function_calling_type', 'no_call')
function_calling_type = credentials.get("function_calling_type", "no_call")
formatted_tools = []
if tools:
if function_calling_type == 'function_call':
data['functions'] = [{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
} for tool in tools]
elif function_calling_type == 'tool_call':
if function_calling_type == "function_call":
data["functions"] = [
{"name": tool.name, "description": tool.description, "parameters": tool.parameters}
for tool in tools
]
elif function_calling_type == "tool_call":
data["tool_choice"] = "auto"
for tool in tools:
@ -240,16 +232,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
if user:
data["user"] = user
response = requests.post(
endpoint_url,
headers=headers,
json=data,
timeout=(10, 300),
stream=stream
)
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream)
if response.encoding is None or response.encoding == 'ISO-8859-1':
response.encoding = 'utf-8'
if response.encoding is None or response.encoding == "ISO-8859-1":
response.encoding = "utf-8"
if not response.ok:
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")

View File

@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)
class MistralAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -19,12 +18,9 @@ class MistralAIProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='mistralai/mixtral-8x7b-instruct-v0.1',
credentials=credentials
)
model_instance.validate_credentials(model="mistralai/mixtral-8x7b-instruct-v0.1", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -22,11 +22,18 @@ class NvidiaRerankModel(RerankModel):
"""
def _sigmoid(self, logit: float) -> float:
return 1/(1+exp(-logit))
return 1 / (1 + exp(-logit))
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) -> RerankResult:
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
@ -60,9 +67,9 @@ class NvidiaRerankModel(RerankModel):
results = response.json()
rerank_documents = []
for result in results['rankings']:
index = result['index']
logit = result['logit']
for result in results["rankings"]:
index = result["index"]
logit = result["logit"]
rerank_document = RerankDocument(
index=index,
text=docs[index],
@ -110,5 +117,5 @@ class NvidiaRerankModel(RerankModel):
InvokeServerUnavailableError: [requests.HTTPError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [requests.HTTPError],
InvokeBadRequestError: [requests.RequestException]
InvokeBadRequestError: [requests.RequestException],
}

View File

@ -22,12 +22,13 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Nvidia text embedding model.
"""
api_base: str = 'https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings'
models: list[str] = ['NV-Embed-QA']
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
api_base: str = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
models: list[str] = ["NV-Embed-QA"]
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@ -37,32 +38,25 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
:param user: unique user id
:return: embeddings result
"""
api_key = credentials['api_key']
api_key = credentials["api_key"]
if model not in self.models:
raise InvokeBadRequestError('Invalid model name')
raise InvokeBadRequestError("Invalid model name")
if not api_key:
raise CredentialsValidateFailedError('api_key is required')
raise CredentialsValidateFailedError("api_key is required")
url = self.api_base
headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
}
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
data = {
'model': model,
'input': texts[0],
'input_type': 'query'
}
data = {"model": model, "input": texts[0], "input_type": "query"}
try:
response = post(url, headers=headers, data=dumps(data))
except Exception as e:
raise InvokeConnectionError(str(e))
if response.status_code != 200:
try:
resp = response.json()
msg = resp['detail']
msg = resp["detail"]
if response.status_code == 401:
raise InvokeAuthorizationError(msg)
elif response.status_code == 429:
@ -72,23 +66,21 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
else:
raise InvokeError(msg)
except JSONDecodeError as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
try:
resp = response.json()
embeddings = resp['data']
usage = resp['usage']
embeddings = resp["data"]
usage = resp["usage"]
except Exception as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens'])
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
result = TextEmbeddingResult(
model=model,
embeddings=[[
float(data) for data in x['embedding']
] for x in embeddings],
usage=usage
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
)
return result
@ -117,30 +109,20 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
self._invoke(model=model, credentials=credentials, texts=["ping"])
except InvokeAuthorizationError:
raise CredentialsValidateFailedError('Invalid api key')
raise CredentialsValidateFailedError("Invalid api key")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError],
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
@ -152,10 +134,7 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
@ -166,7 +145,7 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage

View File

@ -9,4 +9,5 @@ class NVIDIANIMProvider(OAIAPICompatLargeLanguageModel):
"""
Model class for NVIDIA NIM large language model.
"""
pass

View File

@ -6,6 +6,5 @@ logger = logging.getLogger(__name__)
class NVIDIANIMProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -33,31 +33,29 @@ logger = logging.getLogger(__name__)
request_template = {
"compartmentId": "",
"servingMode": {
"modelId": "cohere.command-r-plus",
"servingType": "ON_DEMAND"
},
"servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"},
"chatRequest": {
"apiFormat": "COHERE",
#"preambleOverride": "You are a helpful assistant.",
#"message": "Hello!",
#"chatHistory": [],
# "preambleOverride": "You are a helpful assistant.",
# "message": "Hello!",
# "chatHistory": [],
"maxTokens": 600,
"isStream": False,
"frequencyPenalty": 0,
"presencePenalty": 0,
"temperature": 1,
"topP": 0.75
}
"topP": 0.75,
},
}
oci_config_template = {
"user": "",
"fingerprint": "",
"tenancy": "",
"region": "",
"compartment_id": "",
"key_content": ""
}
"user": "",
"fingerprint": "",
"tenancy": "",
"region": "",
"compartment_id": "",
"key_content": "",
}
class OCILargeLanguageModel(LargeLanguageModel):
# https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm
@ -100,11 +98,17 @@ class OCILargeLanguageModel(LargeLanguageModel):
return False
return feature["system"]
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -118,22 +122,27 @@ class OCILargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
#print("model"+"*"*20)
#print(model)
#print("credentials"+"*"*20)
#print(credentials)
#print("model_parameters"+"*"*20)
#print(model_parameters)
#print("prompt_messages"+"*"*200)
#print(prompt_messages)
#print("tools"+"*"*20)
#print(tools)
# print("model"+"*"*20)
# print(model)
# print("credentials"+"*"*20)
# print(credentials)
# print("model_parameters"+"*"*20)
# print(model_parameters)
# print("prompt_messages"+"*"*200)
# print(prompt_messages)
# print("tools"+"*"*20)
# print(tools)
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -147,8 +156,13 @@ class OCILargeLanguageModel(LargeLanguageModel):
return self._get_num_tokens_by_gpt2(prompt)
def get_num_characters(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_characters(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of characters for given prompt messages
@ -169,10 +183,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
"""
messages = messages.copy() # don't mutate the original list
text = "".join(
self._convert_one_message_to_text(message)
for message in messages
)
text = "".join(self._convert_one_message_to_text(message) for message in messages)
return text.rstrip()
@ -192,11 +203,17 @@ class OCILargeLanguageModel(LargeLanguageModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None
) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -218,10 +235,12 @@ class OCILargeLanguageModel(LargeLanguageModel):
# ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat
oci_config = copy.deepcopy(oci_config_template)
if "oci_config_content" in credentials:
oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8")
config_items = oci_config_content.split("/")
if len(config_items) != 5:
raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
raise CredentialsValidateFailedError(
"oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))"
)
oci_config["user"] = config_items[0]
oci_config["fingerprint"] = config_items[1]
oci_config["tenancy"] = config_items[2]
@ -230,12 +249,12 @@ class OCILargeLanguageModel(LargeLanguageModel):
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials")
if "oci_key_content" in credentials:
oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8")
oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
else:
raise CredentialsValidateFailedError("need to set oci_key_content in credentials")
#oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
# oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
compartment_id = oci_config["compartment_id"]
client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
# call embedding model
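The oci_config_content format demanded by the error message above can be produced like so (all OCIDs and the fingerprint are placeholders):

import base64

fields = "/".join([
    "ocid1.user.oc1..example",         # user_ocid
    "aa:bb:cc:dd",                     # fingerprint
    "ocid1.tenancy.oc1..example",      # tenancy_ocid
    "us-chicago-1",                    # region
    "ocid1.compartment.oc1..example",  # compartment_ocid
])
oci_config_content = base64.b64encode(fields.encode("utf-8")).decode("utf-8")
# the decode side splits on "/" and requires exactly these 5 items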
@ -245,9 +264,9 @@ class OCILargeLanguageModel(LargeLanguageModel):
chat_history = []
system_prompts = []
#if "meta.llama" in model:
# if "meta.llama" in model:
# request_args["chatRequest"]["apiFormat"] = "GENERIC"
request_args["chatRequest"]["maxTokens"] = model_parameters.pop('maxTokens', 600)
request_args["chatRequest"]["maxTokens"] = model_parameters.pop("maxTokens", 600)
request_args["chatRequest"].update(model_parameters)
frequency_penalty = model_parameters.get("frequencyPenalty", 0)
presence_penalty = model_parameters.get("presencePenalty", 0)
@ -267,7 +286,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
if not valid_value:
raise InvokeBadRequestError("Does not support function calling")
if model.startswith("cohere"):
#print("run cohere " * 10)
# print("run cohere " * 10)
for message in prompt_messages[:-1]:
text = ""
if isinstance(message.content, str):
@ -279,37 +298,37 @@ class OCILargeLanguageModel(LargeLanguageModel):
if isinstance(message, SystemPromptMessage):
if isinstance(message.content, str):
system_prompts.append(message.content)
args = {"apiFormat": "COHERE",
"preambleOverride": ' '.join(system_prompts),
"message": prompt_messages[-1].content,
"chatHistory": chat_history, }
args = {
"apiFormat": "COHERE",
"preambleOverride": " ".join(system_prompts),
"message": prompt_messages[-1].content,
"chatHistory": chat_history,
}
request_args["chatRequest"].update(args)
elif model.startswith("meta"):
#print("run meta " * 10)
# print("run meta " * 10)
meta_messages = []
for message in prompt_messages:
text = message.content
meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]})
args = {"apiFormat": "GENERIC",
"messages": meta_messages,
"numGenerations": 1,
"topK": -1}
args = {"apiFormat": "GENERIC", "messages": meta_messages, "numGenerations": 1, "topK": -1}
request_args["chatRequest"].update(args)
if stream:
request_args["chatRequest"]["isStream"] = True
#print("final request" + "|" * 20)
#print(request_args)
# print("final request" + "|" * 20)
# print(request_args)
response = client.chat(request_args)
#print(vars(response))
# print(vars(response))
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
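For reference, a sketch of the two chatRequest shapes this method assembles; the GENERIC message structure is taken verbatim from the meta branch above, while the chatHistory entry shape for COHERE is an assumption based on OCI's Cohere format:

cohere_chat_request = {
    "apiFormat": "COHERE",
    "preambleOverride": "You are a helpful assistant.",
    "message": "latest user message",
    "chatHistory": [{"role": "USER", "message": "hi"}],  # assumed entry shape
    "maxTokens": 600,
    "isStream": False,
}
meta_chat_request = {
    "apiFormat": "GENERIC",
    "messages": [{"role": "USER", "content": [{"type": "TEXT", "text": "latest user message"}]}],
    "numGenerations": 1,
    "topK": -1,
}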
def _handle_generate_response(self, model: str, credentials: dict, response: BaseChatResponse,
prompt_messages: list[PromptMessage]) -> LLMResult:
def _handle_generate_response(
self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response
@@ -320,9 +339,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
:return: llm response
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.data.chat_response.text
)
assistant_prompt_message = AssistantPromptMessage(content=response.data.chat_response.text)
# calculate num tokens
prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
@@ -341,8 +358,9 @@ class OCILargeLanguageModel(LargeLanguageModel):
return result
def _handle_generate_stream_response(self, model: str, credentials: dict, response: BaseChatResponse,
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
@@ -356,14 +374,12 @@ class OCILargeLanguageModel(LargeLanguageModel):
events = response.data.events()
for stream in events:
chunk = json.loads(stream.data)
#print(chunk)
#chunk: {'apiFormat': 'COHERE', 'text': 'Hello'}
# print(chunk)
# chunk: {'apiFormat': 'COHERE', 'text': 'Hello'}
#for chunk in response:
#for part in chunk.parts:
#if part.function_call:
# for chunk in response:
# for part in chunk.parts:
# if part.function_call:
# assistant_prompt_message.tool_calls = [
# AssistantPromptMessage.ToolCall(
# id=part.function_call.name,
@@ -376,9 +392,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
# ]
if "finishReason" not in chunk:
assistant_prompt_message = AssistantPromptMessage(
content=''
)
assistant_prompt_message = AssistantPromptMessage(content="")
if model.startswith("cohere"):
if chunk["text"]:
assistant_prompt_message.content += chunk["text"]
@@ -389,10 +403,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
)
else:
# calculate num tokens
@@ -409,8 +420,8 @@ class OCILargeLanguageModel(LargeLanguageModel):
index=index,
message=assistant_prompt_message,
finish_reason=str(chunk["finishReason"]),
usage=usage
)
usage=usage,
),
)
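As the branches above imply, intermediate chunks carry only text while the final chunk adds finish_reason and usage. A consumption sketch, assuming finish_reason defaults to None on intermediate chunks:

def collect_stream(chunks):
    # chunks: the generator returned by _handle_generate_stream_response
    pieces = []
    for chunk in chunks:
        pieces.append(chunk.delta.message.content)
        if chunk.delta.finish_reason is not None:
            return "".join(pieces), chunk.delta.finish_reason, chunk.delta.usage
    return "".join(pieces), None, None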
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
@@ -425,9 +436,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
content = message.content
if isinstance(content, list):
content = "".join(
c.data for c in content if c.type != PromptMessageContentType.IMAGE
)
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
@@ -457,5 +466,5 @@ class OCILargeLanguageModel(LargeLanguageModel):
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: []
InvokeBadRequestError: [],
}

View File

@@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)
class OCIGENAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@@ -21,14 +20,9 @@ class OCIGENAIProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)
# Use `cohere.command-r-plus` model for validation.
model_instance.validate_credentials(
model='cohere.command-r-plus',
credentials=credentials
)
model_instance.validate_credentials(model="cohere.command-r-plus", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@@ -21,29 +21,28 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
request_template = {
"compartmentId": "",
"servingMode": {
"modelId": "cohere.embed-english-light-v3.0",
"servingType": "ON_DEMAND"
},
"servingMode": {"modelId": "cohere.embed-english-light-v3.0", "servingType": "ON_DEMAND"},
"truncate": "NONE",
"inputs": [""]
"inputs": [""],
}
oci_config_template = {
"user": "",
"fingerprint": "",
"tenancy": "",
"region": "",
"compartment_id": "",
"key_content": ""
}
"user": "",
"fingerprint": "",
"tenancy": "",
"region": "",
"compartment_id": "",
"key_content": "",
}
class OCITextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Cohere text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
@@ -62,14 +61,13 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
used_tokens = 0
for i, text in enumerate(texts):
# Here token count is only an approximation based on the GPT2 tokenizer
num_tokens = self._get_num_tokens_by_gpt2(text)
if num_tokens >= context_size:
cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
# if num tokens is larger than context length, only use the start
inputs.append(text[0: cutoff])
inputs.append(text[0:cutoff])
else:
inputs.append(text)
indices += [i]
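One detail worth noting in the truncation above: np.floor makes the multiplier an integer, so the cutoff is all-or-nothing rather than proportional. Worked through with plain arithmetic (values are illustrative):

import numpy as np

text = "x" * 2048
int(len(text) * np.floor(512 / 512))   # tokens == context: 2048, whole text kept
int(len(text) * np.floor(512 / 1024))  # tokens > context: 0, text reduced to ""
# A proportional cut would presumably read int(len(text) * 512 / 1024) -> 1024.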
@@ -80,26 +78,16 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
for i in _iter:
# call embedding model
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
credentials=credentials,
texts=inputs[i: i + max_chunks]
model=model, credentials=credentials, texts=inputs[i : i + max_chunks]
)
used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=used_tokens
)
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
return TextEmbeddingResult(
embeddings=batched_embeddings,
usage=usage,
model=model
)
return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
@@ -125,6 +113,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
for text in texts:
characters += len(text)
return characters
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
@@ -135,11 +124,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
"""
try:
# call embedding model
self._embedding_invoke(
model=model,
credentials=credentials,
texts=['ping']
)
self._embedding_invoke(model=model, credentials=credentials, texts=["ping"])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@@ -157,10 +142,12 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
# initialize client
oci_config = copy.deepcopy(oci_config_template)
if "oci_config_content" in credentials:
oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8")
config_items = oci_config_content.split("/")
if len(config_items) != 5:
raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
raise CredentialsValidateFailedError(
"oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))"
)
oci_config["user"] = config_items[0]
oci_config["fingerprint"] = config_items[1]
oci_config["tenancy"] = config_items[2]
@@ -169,7 +156,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials")
if "oci_key_content" in credentials:
oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8")
oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
else:
raise CredentialsValidateFailedError("need to set oci_key_content in credentials")
@@ -195,10 +182,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
@@ -209,7 +193,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)
return usage
@@ -224,19 +208,9 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError],
}
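This mapping is the unified-error table: each key is the normalized invoke error surfaced to callers, each value the concrete exception types folded into it. A sketch of the kind of dispatch the runtime presumably performs with it (the mechanism itself is not shown in this diff):

def transform_invoke_error(error_mapping: dict, ex: Exception) -> Exception:
    # Return the first unified invoke error whose listed exception types match ex.
    for invoke_error, concrete_types in error_mapping.items():
        if any(isinstance(ex, t) for t in concrete_types):
            return invoke_error(str(ex))
    return ex  # unmapped exceptions propagate unchanged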

View File

@@ -121,9 +121,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
text = ""
for message_content in first_prompt_message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(
TextPromptMessageContent, message_content
)
message_content = cast(TextPromptMessageContent, message_content)
text = message_content.data
break
return self._get_num_tokens_by_gpt2(text)
@@ -145,13 +143,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
stream=False,
)
except InvokeError as ex:
raise CredentialsValidateFailedError(
f"An error occurred during credentials validation: {ex.description}"
)
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}")
except Exception as ex:
raise CredentialsValidateFailedError(
f"An error occurred during credentials validation: {str(ex)}"
)
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}")
def _generate(
self,
@@ -201,9 +195,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, "api/chat")
data["messages"] = [
self._convert_prompt_message_to_dict(m) for m in prompt_messages
]
data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
else:
endpoint_url = urljoin(endpoint_url, "api/generate")
first_prompt_message = prompt_messages[0]
@@ -216,14 +208,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
images = []
for message_content in first_prompt_message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(
TextPromptMessageContent, message_content
)
message_content = cast(TextPromptMessageContent, message_content)
text = message_content.data
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(
ImagePromptMessageContent, message_content
)
message_content = cast(ImagePromptMessageContent, message_content)
image_data = re.sub(
r"^data:image\/[a-zA-Z]+;base64,",
"",
@@ -235,24 +223,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
data["images"] = images
# send a post request to validate the credentials
response = requests.post(
endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream
)
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream)
response.encoding = "utf-8"
if response.status_code != 200:
raise InvokeError(
f"API request failed with status code {response.status_code}: {response.text}"
)
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
if stream:
return self._handle_generate_stream_response(
model, credentials, completion_type, response, prompt_messages
)
return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
return self._handle_generate_response(
model, credentials, completion_type, response, prompt_messages
)
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
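Concretely, the two endpoints selected above receive payloads of roughly the following shape. This is a sketch assembled from the fields set in this hunk; the model name is a placeholder and the stream flag is assumed to be set elsewhere in _generate:

chat_payload = {  # POST {endpoint_url}/api/chat
    "model": "llama3",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}

generate_payload = {  # POST {endpoint_url}/api/generate
    "model": "llama3",
    "prompt": "Hello",
    "images": ["<base64 image data>"],  # only present for multimodal prompts
    "stream": True,
}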
def _handle_generate_response(
self,
@@ -292,9 +272,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)
# transform usage
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
@@ -335,9 +313,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
completion_tokens = self._get_num_tokens_by_gpt2(full_text)
# transform usage
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
return LLMResultChunk(
model=model,
@@ -394,15 +370,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
completion_tokens = chunk_json["eval_count"]
else:
# calculate num tokens
prompt_tokens = self._get_num_tokens_by_gpt2(
prompt_messages[0].content
)
prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
completion_tokens = self._get_num_tokens_by_gpt2(full_text)
# transform usage
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=chunk_json["model"],
@@ -439,17 +411,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
images = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(
TextPromptMessageContent, message_content
)
message_content = cast(TextPromptMessageContent, message_content)
text = message_content.data
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(
ImagePromptMessageContent, message_content
)
image_data = re.sub(
r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data
)
message_content = cast(ImagePromptMessageContent, message_content)
image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data)
images.append(image_data)
message_dict = {"role": "user", "content": text, "images": images}
@@ -479,9 +445,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
return num_tokens
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> AIModelEntity:
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
Get customizable model schema.
@@ -502,9 +466,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: credentials.get("mode"),
ModelPropertyKey.CONTEXT_SIZE: int(
credentials.get("context_size", 4096)
),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
},
parameter_rules=[
ParameterRule(
@@ -568,9 +530,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
en_US="Maximum number of tokens to predict when generating text. "
"(Default: 128, -1 = infinite generation, -2 = fill context)"
),
default=(
512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128
),
default=(512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128),
min=-2,
max=int(credentials.get("max_tokens", 4096)),
),
@@ -612,22 +572,23 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
label=I18nObject(en_US="Size of context window"),
type=ParameterType.INT,
help=I18nObject(
en_US="Sets the size of the context window used to generate the next token. "
"(Default: 2048)"
en_US="Sets the size of the context window used to generate the next token. " "(Default: 2048)"
),
default=2048,
min=1,
),
ParameterRule(
name='num_gpu',
name="num_gpu",
label=I18nObject(en_US="GPU Layers"),
type=ParameterType.INT,
help=I18nObject(en_US="The number of layers to offload to the GPU(s). "
"On macOS it defaults to 1 to enable metal support, 0 to disable."
"As long as a model fits into one gpu it stays in one. "
"It does not set the number of GPU(s). "),
help=I18nObject(
en_US="The number of layers to offload to the GPU(s). "
"On macOS it defaults to 1 to enable metal support, 0 to disable."
"As long as a model fits into one gpu it stays in one. "
"It does not set the number of GPU(s). "
),
min=-1,
default=1
default=1,
),
ParameterRule(
name="num_thread",
@@ -688,8 +649,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
label=I18nObject(en_US="Format"),
type=ParameterType.STRING,
help=I18nObject(
en_US="the format to return a response in."
" Currently the only accepted value is json."
en_US="the format to return a response in." " Currently the only accepted value is json."
),
options=["json"],
),
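Taken together, rules such as num_ctx, num_gpu, and format map onto an Ollama request along these lines. A sketch assuming the standard Ollama REST options field; values are the defaults declared above and the model name is a placeholder:

request_body = {
    "model": "llama3",
    "prompt": "Return a JSON object describing the weather.",
    "format": "json",  # only accepted value, per the rule above
    "options": {
        "num_ctx": 2048,  # size of context window (default above)
        "num_gpu": 1,  # layers offloaded to the GPU
    },
}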

Some files were not shown because too many files have changed in this diff.