Mirror of https://github.com/langgenius/dify.git, synced 2026-04-30 07:28:05 +08:00
chore(api/core): apply ruff reformatting (#7624)
@@ -18,12 +18,21 @@ class Callback:
Base class for callbacks.
Only for LLM.
"""
+
raise_error: bool = False

- def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None) -> None:
+ def on_before_invoke(
+ self,
+ llm_instance: AIModel,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ) -> None:
"""
Before invoke callback

@@ -39,10 +48,19 @@ class Callback:
"""
raise NotImplementedError()

- def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None):
+ def on_new_chunk(
+ self,
+ llm_instance: AIModel,
+ chunk: LLMResultChunk,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ):
"""
On new chunk callback

@@ -59,10 +77,19 @@ class Callback:
"""
raise NotImplementedError()

- def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None) -> None:
+ def on_after_invoke(
+ self,
+ llm_instance: AIModel,
+ result: LLMResult,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ) -> None:
"""
After invoke callback

@@ -79,10 +106,19 @@ class Callback:
"""
raise NotImplementedError()

- def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None) -> None:
+ def on_invoke_error(
+ self,
+ llm_instance: AIModel,
+ ex: Exception,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ) -> None:
"""
Invoke error callback

@@ -99,9 +135,7 @@ class Callback:
"""
raise NotImplementedError()

- def print_text(
- self, text: str, color: Optional[str] = None, end: str = ""
- ) -> None:
+ def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None:
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(text_to_print, end=end)
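For orientation, a minimal sketch of a custom callback written against the reformatted signature above. It is illustrative only: the base_callback and message_entities import paths are assumptions (only the ai_model import appears in the hunk context below), and a real callback would normally override the other hooks as well.

import logging
from typing import Optional

from core.model_runtime.callbacks.base_callback import Callback  # assumed module path
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool  # assumed module path
from core.model_runtime.model_providers.__base.ai_model import AIModel

logger = logging.getLogger(__name__)


class AuditCallback(Callback):
    """Hypothetical callback that logs which model is about to be invoked."""

    raise_error = False  # failures in this hook are logged as warnings, not re-raised

    def on_before_invoke(
        self,
        llm_instance: AIModel,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> None:
        # Same parameter order as the ruff-formatted base signature above.
        logger.info("invoking %s with %d prompt messages", model, len(prompt_messages))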
@@ -10,11 +10,20 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel

logger = logging.getLogger(__name__)


class LoggingCallback(Callback):
- def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None) -> None:
+ def on_before_invoke(
+ self,
+ llm_instance: AIModel,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ) -> None:
"""
Before invoke callback

@@ -28,40 +37,49 @@ class LoggingCallback(Callback):
:param stream: is stream response
:param user: unique user id
"""
- self.print_text("\n[on_llm_before_invoke]\n", color='blue')
- self.print_text(f"Model: {model}\n", color='blue')
- self.print_text("Parameters:\n", color='blue')
+ self.print_text("\n[on_llm_before_invoke]\n", color="blue")
+ self.print_text(f"Model: {model}\n", color="blue")
+ self.print_text("Parameters:\n", color="blue")
for key, value in model_parameters.items():
- self.print_text(f"\t{key}: {value}\n", color='blue')
+ self.print_text(f"\t{key}: {value}\n", color="blue")

if stop:
- self.print_text(f"\tstop: {stop}\n", color='blue')
+ self.print_text(f"\tstop: {stop}\n", color="blue")

if tools:
- self.print_text("\tTools:\n", color='blue')
+ self.print_text("\tTools:\n", color="blue")
for tool in tools:
- self.print_text(f"\t\t{tool.name}\n", color='blue')
+ self.print_text(f"\t\t{tool.name}\n", color="blue")

- self.print_text(f"Stream: {stream}\n", color='blue')
+ self.print_text(f"Stream: {stream}\n", color="blue")

if user:
- self.print_text(f"User: {user}\n", color='blue')
+ self.print_text(f"User: {user}\n", color="blue")

- self.print_text("Prompt messages:\n", color='blue')
+ self.print_text("Prompt messages:\n", color="blue")
for prompt_message in prompt_messages:
if prompt_message.name:
- self.print_text(f"\tname: {prompt_message.name}\n", color='blue')
+ self.print_text(f"\tname: {prompt_message.name}\n", color="blue")

- self.print_text(f"\trole: {prompt_message.role.value}\n", color='blue')
- self.print_text(f"\tcontent: {prompt_message.content}\n", color='blue')
+ self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue")
+ self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue")

if stream:
self.print_text("\n[on_llm_new_chunk]")

- def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None):
+ def on_new_chunk(
+ self,
+ llm_instance: AIModel,
+ chunk: LLMResultChunk,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ):
"""
On new chunk callback

@@ -79,10 +97,19 @@ class LoggingCallback(Callback):
sys.stdout.write(chunk.delta.message.content)
sys.stdout.flush()

- def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None) -> None:
+ def on_after_invoke(
+ self,
+ llm_instance: AIModel,
+ result: LLMResult,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ) -> None:
"""
After invoke callback

@@ -97,24 +124,33 @@ class LoggingCallback(Callback):
:param stream: is stream response
:param user: unique user id
"""
- self.print_text("\n[on_llm_after_invoke]\n", color='yellow')
- self.print_text(f"Content: {result.message.content}\n", color='yellow')
+ self.print_text("\n[on_llm_after_invoke]\n", color="yellow")
+ self.print_text(f"Content: {result.message.content}\n", color="yellow")

if result.message.tool_calls:
- self.print_text("Tool calls:\n", color='yellow')
+ self.print_text("Tool calls:\n", color="yellow")
for tool_call in result.message.tool_calls:
- self.print_text(f"\t{tool_call.id}\n", color='yellow')
- self.print_text(f"\t{tool_call.function.name}\n", color='yellow')
- self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color='yellow')
+ self.print_text(f"\t{tool_call.id}\n", color="yellow")
+ self.print_text(f"\t{tool_call.function.name}\n", color="yellow")
+ self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow")

- self.print_text(f"Model: {result.model}\n", color='yellow')
- self.print_text(f"Usage: {result.usage}\n", color='yellow')
- self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color='yellow')
+ self.print_text(f"Model: {result.model}\n", color="yellow")
+ self.print_text(f"Usage: {result.usage}\n", color="yellow")
+ self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow")

- def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None) -> None:
+ def on_invoke_error(
+ self,
+ llm_instance: AIModel,
+ ex: Exception,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ) -> None:
"""
Invoke error callback

@@ -129,5 +165,5 @@ class LoggingCallback(Callback):
:param stream: is stream response
:param user: unique user id
"""
- self.print_text("\n[on_llm_invoke_error]\n", color='red')
+ self.print_text("\n[on_llm_invoke_error]\n", color="red")
logger.exception(ex)
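LoggingCallback is the built-in implementation of the hooks above. Per the LargeLanguageModel hunks further down, invoke() accepts callbacks: Optional[list[Callback]] = None and appends a LoggingCallback automatically when the DEBUG environment variable is "true". A hedged sketch of a call site follows; llm, credentials, and the model name are placeholders, not values taken from the diff.

import os

from core.model_runtime.entities.message_entities import UserPromptMessage  # assumed module path

os.environ["DEBUG"] = "true"  # makes invoke() add a LoggingCallback on top of any explicit callbacks

# `llm` stands in for a LargeLanguageModel subclass obtained from the provider layer,
# and `credentials` for that provider's credential dict.
result = llm.invoke(
    model="my-model",  # placeholder model name
    credentials=credentials,
    prompt_messages=[UserPromptMessage(content="Hello")],
    model_parameters={"temperature": 0.0},
    stream=False,
    callbacks=[AuditCallback()],  # the hypothetical callback sketched earlier
)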
@@ -7,6 +7,7 @@ class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
+
zh_Hans: Optional[str] = None
en_US: str
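A small usage note on this entity: en_US is the only required field, which is how the label and help dictionaries in the parameter-rule defaults below are written. A sketch, with the module path assumed:

from core.model_runtime.entities.common_entities import I18nObject  # assumed module path

label = I18nObject(en_US="Temperature", zh_Hans="温度")  # zh_Hans is optional and defaults to None
print(label.en_US)  # Temperature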
@ -2,123 +2,123 @@ from core.model_runtime.entities.model_entities import DefaultParameterName
|
||||
|
||||
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
||||
DefaultParameterName.TEMPERATURE: {
|
||||
'label': {
|
||||
'en_US': 'Temperature',
|
||||
'zh_Hans': '温度',
|
||||
"label": {
|
||||
"en_US": "Temperature",
|
||||
"zh_Hans": "温度",
|
||||
},
|
||||
'type': 'float',
|
||||
'help': {
|
||||
'en_US': 'Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.',
|
||||
'zh_Hans': '温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。',
|
||||
"type": "float",
|
||||
"help": {
|
||||
"en_US": "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.",
|
||||
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。",
|
||||
},
|
||||
'required': False,
|
||||
'default': 0.0,
|
||||
'min': 0.0,
|
||||
'max': 1.0,
|
||||
'precision': 2,
|
||||
"required": False,
|
||||
"default": 0.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"precision": 2,
|
||||
},
|
||||
DefaultParameterName.TOP_P: {
|
||||
'label': {
|
||||
'en_US': 'Top P',
|
||||
'zh_Hans': 'Top P',
|
||||
"label": {
|
||||
"en_US": "Top P",
|
||||
"zh_Hans": "Top P",
|
||||
},
|
||||
'type': 'float',
|
||||
'help': {
|
||||
'en_US': 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.',
|
||||
'zh_Hans': '通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。',
|
||||
"type": "float",
|
||||
"help": {
|
||||
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.",
|
||||
"zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。",
|
||||
},
|
||||
'required': False,
|
||||
'default': 1.0,
|
||||
'min': 0.0,
|
||||
'max': 1.0,
|
||||
'precision': 2,
|
||||
"required": False,
|
||||
"default": 1.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"precision": 2,
|
||||
},
|
||||
DefaultParameterName.TOP_K: {
|
||||
'label': {
|
||||
'en_US': 'Top K',
|
||||
'zh_Hans': 'Top K',
|
||||
"label": {
|
||||
"en_US": "Top K",
|
||||
"zh_Hans": "Top K",
|
||||
},
|
||||
'type': 'int',
|
||||
'help': {
|
||||
'en_US': 'Limits the number of tokens to consider for each step by keeping only the k most likely tokens.',
|
||||
'zh_Hans': '通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。',
|
||||
"type": "int",
|
||||
"help": {
|
||||
"en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.",
|
||||
"zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。",
|
||||
},
|
||||
'required': False,
|
||||
'default': 50,
|
||||
'min': 1,
|
||||
'max': 100,
|
||||
'precision': 0,
|
||||
"required": False,
|
||||
"default": 50,
|
||||
"min": 1,
|
||||
"max": 100,
|
||||
"precision": 0,
|
||||
},
|
||||
DefaultParameterName.PRESENCE_PENALTY: {
|
||||
'label': {
|
||||
'en_US': 'Presence Penalty',
|
||||
'zh_Hans': '存在惩罚',
|
||||
"label": {
|
||||
"en_US": "Presence Penalty",
|
||||
"zh_Hans": "存在惩罚",
|
||||
},
|
||||
'type': 'float',
|
||||
'help': {
|
||||
'en_US': 'Applies a penalty to the log-probability of tokens already in the text.',
|
||||
'zh_Hans': '对文本中已有的标记的对数概率施加惩罚。',
|
||||
"type": "float",
|
||||
"help": {
|
||||
"en_US": "Applies a penalty to the log-probability of tokens already in the text.",
|
||||
"zh_Hans": "对文本中已有的标记的对数概率施加惩罚。",
|
||||
},
|
||||
'required': False,
|
||||
'default': 0.0,
|
||||
'min': 0.0,
|
||||
'max': 1.0,
|
||||
'precision': 2,
|
||||
"required": False,
|
||||
"default": 0.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"precision": 2,
|
||||
},
|
||||
DefaultParameterName.FREQUENCY_PENALTY: {
|
||||
'label': {
|
||||
'en_US': 'Frequency Penalty',
|
||||
'zh_Hans': '频率惩罚',
|
||||
"label": {
|
||||
"en_US": "Frequency Penalty",
|
||||
"zh_Hans": "频率惩罚",
|
||||
},
|
||||
'type': 'float',
|
||||
'help': {
|
||||
'en_US': 'Applies a penalty to the log-probability of tokens that appear in the text.',
|
||||
'zh_Hans': '对文本中出现的标记的对数概率施加惩罚。',
|
||||
"type": "float",
|
||||
"help": {
|
||||
"en_US": "Applies a penalty to the log-probability of tokens that appear in the text.",
|
||||
"zh_Hans": "对文本中出现的标记的对数概率施加惩罚。",
|
||||
},
|
||||
'required': False,
|
||||
'default': 0.0,
|
||||
'min': 0.0,
|
||||
'max': 1.0,
|
||||
'precision': 2,
|
||||
"required": False,
|
||||
"default": 0.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"precision": 2,
|
||||
},
|
||||
DefaultParameterName.MAX_TOKENS: {
|
||||
'label': {
|
||||
'en_US': 'Max Tokens',
|
||||
'zh_Hans': '最大标记',
|
||||
"label": {
|
||||
"en_US": "Max Tokens",
|
||||
"zh_Hans": "最大标记",
|
||||
},
|
||||
'type': 'int',
|
||||
'help': {
|
||||
'en_US': 'Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.',
|
||||
'zh_Hans': '指定生成结果长度的上限。如果生成结果截断,可以调大该参数。',
|
||||
"type": "int",
|
||||
"help": {
|
||||
"en_US": "Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.",
|
||||
"zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
|
||||
},
|
||||
'required': False,
|
||||
'default': 64,
|
||||
'min': 1,
|
||||
'max': 2048,
|
||||
'precision': 0,
|
||||
"required": False,
|
||||
"default": 64,
|
||||
"min": 1,
|
||||
"max": 2048,
|
||||
"precision": 0,
|
||||
},
|
||||
DefaultParameterName.RESPONSE_FORMAT: {
|
||||
'label': {
|
||||
'en_US': 'Response Format',
|
||||
'zh_Hans': '回复格式',
|
||||
"label": {
|
||||
"en_US": "Response Format",
|
||||
"zh_Hans": "回复格式",
|
||||
},
|
||||
'type': 'string',
|
||||
'help': {
|
||||
'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.',
|
||||
'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等',
|
||||
"type": "string",
|
||||
"help": {
|
||||
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.",
|
||||
"zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等",
|
||||
},
|
||||
'required': False,
|
||||
'options': ['JSON', 'XML'],
|
||||
"required": False,
|
||||
"options": ["JSON", "XML"],
|
||||
},
|
||||
DefaultParameterName.JSON_SCHEMA: {
|
||||
'label': {
|
||||
'en_US': 'JSON Schema',
|
||||
"label": {
|
||||
"en_US": "JSON Schema",
|
||||
},
|
||||
'type': 'text',
|
||||
'help': {
|
||||
'en_US': 'Set a response json schema will ensure LLM to adhere it.',
|
||||
'zh_Hans': '设置返回的json schema,llm将按照它返回',
|
||||
"type": "text",
|
||||
"help": {
|
||||
"en_US": "Set a response json schema will ensure LLM to adhere it.",
|
||||
"zh_Hans": "设置返回的json schema,llm将按照它返回",
|
||||
},
|
||||
'required': False,
|
||||
"required": False,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -12,11 +12,12 @@ class LLMMode(Enum):
"""
Enum class for large language model mode.
"""
+
COMPLETION = "completion"
CHAT = "chat"

@classmethod
- def value_of(cls, value: str) -> 'LLMMode':
+ def value_of(cls, value: str) -> "LLMMode":
"""
Get value of given mode.

@@ -26,13 +27,14 @@ class LLMMode(Enum):
for mode in cls:
if mode.value == value:
return mode
- raise ValueError(f'invalid mode value {value}')
+ raise ValueError(f"invalid mode value {value}")


class LLMUsage(ModelUsage):
"""
Model class for llm usage.
"""
+
prompt_tokens: int
prompt_unit_price: Decimal
prompt_price_unit: Decimal

@@ -50,20 +52,20 @@ class LLMUsage(ModelUsage):
def empty_usage(cls):
return cls(
prompt_tokens=0,
- prompt_unit_price=Decimal('0.0'),
- prompt_price_unit=Decimal('0.0'),
- prompt_price=Decimal('0.0'),
+ prompt_unit_price=Decimal("0.0"),
+ prompt_price_unit=Decimal("0.0"),
+ prompt_price=Decimal("0.0"),
completion_tokens=0,
- completion_unit_price=Decimal('0.0'),
- completion_price_unit=Decimal('0.0'),
- completion_price=Decimal('0.0'),
+ completion_unit_price=Decimal("0.0"),
+ completion_price_unit=Decimal("0.0"),
+ completion_price=Decimal("0.0"),
total_tokens=0,
- total_price=Decimal('0.0'),
- currency='USD',
- latency=0.0
+ total_price=Decimal("0.0"),
+ currency="USD",
+ latency=0.0,
)

- def plus(self, other: 'LLMUsage') -> 'LLMUsage':
+ def plus(self, other: "LLMUsage") -> "LLMUsage":
"""
Add two LLMUsage instances together.

@@ -85,10 +87,10 @@ class LLMUsage(ModelUsage):
total_tokens=self.total_tokens + other.total_tokens,
total_price=self.total_price + other.total_price,
currency=other.currency,
- latency=self.latency + other.latency
+ latency=self.latency + other.latency,
)

- def __add__(self, other: 'LLMUsage') -> 'LLMUsage':
+ def __add__(self, other: "LLMUsage") -> "LLMUsage":
"""
Overload the + operator to add two LLMUsage instances.

@@ -97,10 +99,12 @@ class LLMUsage(ModelUsage):
"""
return self.plus(other)


class LLMResult(BaseModel):
"""
Model class for llm result.
"""
+
model: str
prompt_messages: list[PromptMessage]
message: AssistantPromptMessage

@@ -112,6 +116,7 @@ class LLMResultChunkDelta(BaseModel):
"""
Model class for llm result chunk delta.
"""
+
index: int
message: AssistantPromptMessage
usage: Optional[LLMUsage] = None

@@ -122,6 +127,7 @@ class LLMResultChunk(BaseModel):
"""
Model class for llm result chunk.
"""
+
model: str
prompt_messages: list[PromptMessage]
system_fingerprint: Optional[str] = None

@@ -132,4 +138,5 @@ class NumTokensResult(PriceInfo):
"""
Model class for number of tokens result.
"""
+
tokens: int
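Since LLMUsage.__add__ simply delegates to plus(), per-chunk usage objects can be accumulated with the + operator. A hedged sketch, with the module path assumed:

from decimal import Decimal

from core.model_runtime.entities.llm_entities import LLMUsage  # assumed module path

# Illustrative only: combine two empty usage records.
total = LLMUsage.empty_usage() + LLMUsage.empty_usage()
assert total.total_tokens == 0
assert total.total_price == Decimal("0.0")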
@ -9,13 +9,14 @@ class PromptMessageRole(Enum):
|
||||
"""
|
||||
Enum class for prompt message.
|
||||
"""
|
||||
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'PromptMessageRole':
|
||||
def value_of(cls, value: str) -> "PromptMessageRole":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@ -25,13 +26,14 @@ class PromptMessageRole(Enum):
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid prompt message type value {value}')
|
||||
raise ValueError(f"invalid prompt message type value {value}")
|
||||
|
||||
|
||||
class PromptMessageTool(BaseModel):
|
||||
"""
|
||||
Model class for prompt message tool.
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict
|
||||
@ -41,7 +43,8 @@ class PromptMessageFunction(BaseModel):
|
||||
"""
|
||||
Model class for prompt message function.
|
||||
"""
|
||||
type: str = 'function'
|
||||
|
||||
type: str = "function"
|
||||
function: PromptMessageTool
|
||||
|
||||
|
||||
@ -49,14 +52,16 @@ class PromptMessageContentType(Enum):
|
||||
"""
|
||||
Enum class for prompt message content type.
|
||||
"""
|
||||
TEXT = 'text'
|
||||
IMAGE = 'image'
|
||||
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
|
||||
|
||||
class PromptMessageContent(BaseModel):
|
||||
"""
|
||||
Model class for prompt message content.
|
||||
"""
|
||||
|
||||
type: PromptMessageContentType
|
||||
data: str
|
||||
|
||||
@ -65,6 +70,7 @@ class TextPromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for text prompt message content.
|
||||
"""
|
||||
|
||||
type: PromptMessageContentType = PromptMessageContentType.TEXT
|
||||
|
||||
|
||||
@ -72,9 +78,10 @@ class ImagePromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for image prompt message content.
|
||||
"""
|
||||
|
||||
class DETAIL(Enum):
|
||||
LOW = 'low'
|
||||
HIGH = 'high'
|
||||
LOW = "low"
|
||||
HIGH = "high"
|
||||
|
||||
type: PromptMessageContentType = PromptMessageContentType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
@ -84,6 +91,7 @@ class PromptMessage(ABC, BaseModel):
|
||||
"""
|
||||
Model class for prompt message.
|
||||
"""
|
||||
|
||||
role: PromptMessageRole
|
||||
content: Optional[str | list[PromptMessageContent]] = None
|
||||
name: Optional[str] = None
|
||||
@ -101,6 +109,7 @@ class UserPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for user prompt message.
|
||||
"""
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.USER
|
||||
|
||||
|
||||
@ -108,14 +117,17 @@ class AssistantPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for assistant prompt message.
|
||||
"""
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call.
|
||||
"""
|
||||
|
||||
class ToolCallFunction(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call function.
|
||||
"""
|
||||
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
@ -123,7 +135,7 @@ class AssistantPromptMessage(PromptMessage):
|
||||
type: str
|
||||
function: ToolCallFunction
|
||||
|
||||
@field_validator('id', mode='before')
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def transform_id_to_str(cls, value) -> str:
|
||||
if not isinstance(value, str):
|
||||
@ -145,10 +157,12 @@ class AssistantPromptMessage(PromptMessage):
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class SystemPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for system prompt message.
|
||||
"""
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.SYSTEM
|
||||
|
||||
|
||||
@ -156,6 +170,7 @@ class ToolPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for tool prompt message.
|
||||
"""
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.TOOL
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ class ModelType(Enum):
|
||||
"""
|
||||
Enum class for model type.
|
||||
"""
|
||||
|
||||
LLM = "llm"
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = "rerank"
|
||||
@ -26,22 +27,22 @@ class ModelType(Enum):
|
||||
|
||||
:return: model type
|
||||
"""
|
||||
if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value:
|
||||
if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value:
|
||||
return cls.LLM
|
||||
elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value:
|
||||
elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value:
|
||||
return cls.TEXT_EMBEDDING
|
||||
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
|
||||
elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value:
|
||||
return cls.RERANK
|
||||
elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
|
||||
elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value:
|
||||
return cls.SPEECH2TEXT
|
||||
elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
|
||||
elif origin_model_type == "tts" or origin_model_type == cls.TTS.value:
|
||||
return cls.TTS
|
||||
elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
|
||||
elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value:
|
||||
return cls.TEXT2IMG
|
||||
elif origin_model_type == cls.MODERATION.value:
|
||||
return cls.MODERATION
|
||||
else:
|
||||
raise ValueError(f'invalid origin model type {origin_model_type}')
|
||||
raise ValueError(f"invalid origin model type {origin_model_type}")
|
||||
|
||||
def to_origin_model_type(self) -> str:
|
||||
"""
|
||||
@ -50,26 +51,28 @@ class ModelType(Enum):
|
||||
:return: origin model type
|
||||
"""
|
||||
if self == self.LLM:
|
||||
return 'text-generation'
|
||||
return "text-generation"
|
||||
elif self == self.TEXT_EMBEDDING:
|
||||
return 'embeddings'
|
||||
return "embeddings"
|
||||
elif self == self.RERANK:
|
||||
return 'reranking'
|
||||
return "reranking"
|
||||
elif self == self.SPEECH2TEXT:
|
||||
return 'speech2text'
|
||||
return "speech2text"
|
||||
elif self == self.TTS:
|
||||
return 'tts'
|
||||
return "tts"
|
||||
elif self == self.MODERATION:
|
||||
return 'moderation'
|
||||
return "moderation"
|
||||
elif self == self.TEXT2IMG:
|
||||
return 'text2img'
|
||||
return "text2img"
|
||||
else:
|
||||
raise ValueError(f'invalid model type {self}')
|
||||
raise ValueError(f"invalid model type {self}")
|
||||
|
||||
|
||||
class FetchFrom(Enum):
|
||||
"""
|
||||
Enum class for fetch from.
|
||||
"""
|
||||
|
||||
PREDEFINED_MODEL = "predefined-model"
|
||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
@ -78,6 +81,7 @@ class ModelFeature(Enum):
|
||||
"""
|
||||
Enum class for llm feature.
|
||||
"""
|
||||
|
||||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
@ -89,6 +93,7 @@ class DefaultParameterName(str, Enum):
|
||||
"""
|
||||
Enum class for parameter template variable.
|
||||
"""
|
||||
|
||||
TEMPERATURE = "temperature"
|
||||
TOP_P = "top_p"
|
||||
TOP_K = "top_k"
|
||||
@ -99,7 +104,7 @@ class DefaultParameterName(str, Enum):
|
||||
JSON_SCHEMA = "json_schema"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: Any) -> 'DefaultParameterName':
|
||||
def value_of(cls, value: Any) -> "DefaultParameterName":
|
||||
"""
|
||||
Get parameter name from value.
|
||||
|
||||
@ -109,13 +114,14 @@ class DefaultParameterName(str, Enum):
|
||||
for name in cls:
|
||||
if name.value == value:
|
||||
return name
|
||||
raise ValueError(f'invalid parameter name {value}')
|
||||
raise ValueError(f"invalid parameter name {value}")
|
||||
|
||||
|
||||
class ParameterType(Enum):
|
||||
"""
|
||||
Enum class for parameter type.
|
||||
"""
|
||||
|
||||
FLOAT = "float"
|
||||
INT = "int"
|
||||
STRING = "string"
|
||||
@ -127,6 +133,7 @@ class ModelPropertyKey(Enum):
|
||||
"""
|
||||
Enum class for model property key.
|
||||
"""
|
||||
|
||||
MODE = "mode"
|
||||
CONTEXT_SIZE = "context_size"
|
||||
MAX_CHUNKS = "max_chunks"
|
||||
@ -144,6 +151,7 @@ class ProviderModel(BaseModel):
|
||||
"""
|
||||
Model class for provider model.
|
||||
"""
|
||||
|
||||
model: str
|
||||
label: I18nObject
|
||||
model_type: ModelType
|
||||
@ -158,6 +166,7 @@ class ParameterRule(BaseModel):
|
||||
"""
|
||||
Model class for parameter rule.
|
||||
"""
|
||||
|
||||
name: str
|
||||
use_template: Optional[str] = None
|
||||
label: I18nObject
|
||||
@ -175,6 +184,7 @@ class PriceConfig(BaseModel):
|
||||
"""
|
||||
Model class for pricing info.
|
||||
"""
|
||||
|
||||
input: Decimal
|
||||
output: Optional[Decimal] = None
|
||||
unit: Decimal
|
||||
@ -185,6 +195,7 @@ class AIModelEntity(ProviderModel):
|
||||
"""
|
||||
Model class for AI model.
|
||||
"""
|
||||
|
||||
parameter_rules: list[ParameterRule] = []
|
||||
pricing: Optional[PriceConfig] = None
|
||||
|
||||
@ -197,6 +208,7 @@ class PriceType(Enum):
|
||||
"""
|
||||
Enum class for price type.
|
||||
"""
|
||||
|
||||
INPUT = "input"
|
||||
OUTPUT = "output"
|
||||
|
||||
@ -205,6 +217,7 @@ class PriceInfo(BaseModel):
|
||||
"""
|
||||
Model class for price info.
|
||||
"""
|
||||
|
||||
unit_price: Decimal
|
||||
unit: Decimal
|
||||
total_amount: Decimal
|
||||
|
||||
@ -12,6 +12,7 @@ class ConfigurateMethod(Enum):
|
||||
"""
|
||||
Enum class for configurate method of provider model.
|
||||
"""
|
||||
|
||||
PREDEFINED_MODEL = "predefined-model"
|
||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
@ -20,6 +21,7 @@ class FormType(Enum):
|
||||
"""
|
||||
Enum class for form type.
|
||||
"""
|
||||
|
||||
TEXT_INPUT = "text-input"
|
||||
SECRET_INPUT = "secret-input"
|
||||
SELECT = "select"
|
||||
@ -31,6 +33,7 @@ class FormShowOnObject(BaseModel):
|
||||
"""
|
||||
Model class for form show on.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value: str
|
||||
|
||||
@ -39,6 +42,7 @@ class FormOption(BaseModel):
|
||||
"""
|
||||
Model class for form option.
|
||||
"""
|
||||
|
||||
label: I18nObject
|
||||
value: str
|
||||
show_on: list[FormShowOnObject] = []
|
||||
@ -46,15 +50,14 @@ class FormOption(BaseModel):
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
if not self.label:
|
||||
self.label = I18nObject(
|
||||
en_US=self.value
|
||||
)
|
||||
self.label = I18nObject(en_US=self.value)
|
||||
|
||||
|
||||
class CredentialFormSchema(BaseModel):
|
||||
"""
|
||||
Model class for credential form schema.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
label: I18nObject
|
||||
type: FormType
|
||||
@ -70,6 +73,7 @@ class ProviderCredentialSchema(BaseModel):
|
||||
"""
|
||||
Model class for provider credential schema.
|
||||
"""
|
||||
|
||||
credential_form_schemas: list[CredentialFormSchema]
|
||||
|
||||
|
||||
@ -82,6 +86,7 @@ class ModelCredentialSchema(BaseModel):
|
||||
"""
|
||||
Model class for model credential schema.
|
||||
"""
|
||||
|
||||
model: FieldModelSchema
|
||||
credential_form_schemas: list[CredentialFormSchema]
|
||||
|
||||
@ -90,6 +95,7 @@ class SimpleProviderEntity(BaseModel):
|
||||
"""
|
||||
Simple model class for provider.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
@ -102,6 +108,7 @@ class ProviderHelpEntity(BaseModel):
|
||||
"""
|
||||
Model class for provider help.
|
||||
"""
|
||||
|
||||
title: I18nObject
|
||||
url: I18nObject
|
||||
|
||||
@ -110,6 +117,7 @@ class ProviderEntity(BaseModel):
|
||||
"""
|
||||
Model class for provider.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
label: I18nObject
|
||||
description: Optional[I18nObject] = None
|
||||
@ -138,7 +146,7 @@ class ProviderEntity(BaseModel):
|
||||
icon_small=self.icon_small,
|
||||
icon_large=self.icon_large,
|
||||
supported_model_types=self.supported_model_types,
|
||||
models=self.models
|
||||
models=self.models,
|
||||
)
|
||||
|
||||
|
||||
@ -146,5 +154,6 @@ class ProviderConfig(BaseModel):
|
||||
"""
|
||||
Model class for provider config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
credentials: dict
|
||||
|
||||
@ -5,6 +5,7 @@ class RerankDocument(BaseModel):
|
||||
"""
|
||||
Model class for rerank document.
|
||||
"""
|
||||
|
||||
index: int
|
||||
text: str
|
||||
score: float
|
||||
@ -14,5 +15,6 @@ class RerankResult(BaseModel):
|
||||
"""
|
||||
Model class for rerank result.
|
||||
"""
|
||||
|
||||
model: str
|
||||
docs: list[RerankDocument]
|
||||
|
||||
@ -9,6 +9,7 @@ class EmbeddingUsage(ModelUsage):
|
||||
"""
|
||||
Model class for embedding usage.
|
||||
"""
|
||||
|
||||
tokens: int
|
||||
total_tokens: int
|
||||
unit_price: Decimal
|
||||
@ -22,7 +23,7 @@ class TextEmbeddingResult(BaseModel):
|
||||
"""
|
||||
Model class for text embedding result.
|
||||
"""
|
||||
|
||||
model: str
|
||||
embeddings: list[list[float]]
|
||||
usage: EmbeddingUsage
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Optional


class InvokeError(Exception):
"""Base class for all LLM exceptions."""
+
description: Optional[str] = None

def __init__(self, description: Optional[str] = None) -> None:

@@ -14,24 +15,29 @@ class InvokeError(Exception):

class InvokeConnectionError(InvokeError):
"""Raised when the Invoke returns connection error."""
+
description = "Connection Error"


class InvokeServerUnavailableError(InvokeError):
"""Raised when the Invoke returns server unavailable error."""
+
description = "Server Unavailable Error"


class InvokeRateLimitError(InvokeError):
"""Raised when the Invoke returns rate limit error."""
+
description = "Rate Limit Error"


class InvokeAuthorizationError(InvokeError):
"""Raised when the Invoke returns authorization error."""
+
description = "Incorrect model credentials provided, please check and try again. "


class InvokeBadRequestError(InvokeError):
"""Raised when the Invoke returns bad request."""
+
description = "Bad Request Error"

@@ -2,4 +2,5 @@ class CredentialsValidateFailedError(Exception):
"""
Credentials validate failed error
"""
+
pass
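These unified exception types are what the AIModel hunk that follows maps raw provider errors onto: it walks self._invoke_error_mapping and wraps the first matching exception. A hedged sketch of what such a mapping can look like for an HTTP-based provider; the requests exception types are an assumption, and the real mapping depends on the SDK in use:

from requests.exceptions import ConnectionError as RequestsConnectionError, Timeout

from core.model_runtime.errors.invoke import (  # assumed module path
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeServerUnavailableError,
)

# Hypothetical mapping: unified InvokeError subclass -> provider/SDK exception types.
_invoke_error_mapping = {
    InvokeConnectionError: [RequestsConnectionError, Timeout],
    InvokeServerUnavailableError: [],  # fill in the SDK's 5xx exception types
    InvokeBadRequestError: [ValueError],
}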
@ -66,12 +66,14 @@ class AIModel(ABC):
|
||||
:param error: model invoke error
|
||||
:return: unified error
|
||||
"""
|
||||
provider_name = self.__class__.__module__.split('.')[-3]
|
||||
provider_name = self.__class__.__module__.split(".")[-3]
|
||||
|
||||
for invoke_error, model_errors in self._invoke_error_mapping.items():
|
||||
if isinstance(error, tuple(model_errors)):
|
||||
if invoke_error == InvokeAuthorizationError:
|
||||
return invoke_error(description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. ")
|
||||
return invoke_error(
|
||||
description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. "
|
||||
)
|
||||
|
||||
return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")
|
||||
|
||||
@ -115,7 +117,7 @@ class AIModel(ABC):
|
||||
if not price_config:
|
||||
raise ValueError(f"Price config not found for model {model}")
|
||||
total_amount = tokens * unit_price * price_config.unit
|
||||
total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
return PriceInfo(
|
||||
unit_price=unit_price,
|
||||
@ -136,24 +138,26 @@ class AIModel(ABC):
|
||||
model_schemas = []
|
||||
|
||||
# get module name
|
||||
model_type = self.__class__.__module__.split('.')[-1]
|
||||
model_type = self.__class__.__module__.split(".")[-1]
|
||||
|
||||
# get provider name
|
||||
provider_name = self.__class__.__module__.split('.')[-3]
|
||||
provider_name = self.__class__.__module__.split(".")[-3]
|
||||
|
||||
# get the path of current classes
|
||||
current_path = os.path.abspath(__file__)
|
||||
# get parent path of the current path
|
||||
provider_model_type_path = os.path.join(os.path.dirname(os.path.dirname(current_path)), provider_name, model_type)
|
||||
provider_model_type_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(current_path)), provider_name, model_type
|
||||
)
|
||||
|
||||
# get all yaml files path under provider_model_type_path that do not start with __
|
||||
model_schema_yaml_paths = [
|
||||
os.path.join(provider_model_type_path, model_schema_yaml)
|
||||
for model_schema_yaml in os.listdir(provider_model_type_path)
|
||||
if not model_schema_yaml.startswith('__')
|
||||
and not model_schema_yaml.startswith('_')
|
||||
if not model_schema_yaml.startswith("__")
|
||||
and not model_schema_yaml.startswith("_")
|
||||
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
|
||||
and model_schema_yaml.endswith('.yaml')
|
||||
and model_schema_yaml.endswith(".yaml")
|
||||
]
|
||||
|
||||
# get _position.yaml file path
|
||||
@ -165,10 +169,10 @@ class AIModel(ABC):
|
||||
yaml_data = load_yaml_file(model_schema_yaml_path)
|
||||
|
||||
new_parameter_rules = []
|
||||
for parameter_rule in yaml_data.get('parameter_rules', []):
|
||||
if 'use_template' in parameter_rule:
|
||||
for parameter_rule in yaml_data.get("parameter_rules", []):
|
||||
if "use_template" in parameter_rule:
|
||||
try:
|
||||
default_parameter_name = DefaultParameterName.value_of(parameter_rule['use_template'])
|
||||
default_parameter_name = DefaultParameterName.value_of(parameter_rule["use_template"])
|
||||
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
|
||||
copy_default_parameter_rule = default_parameter_rule.copy()
|
||||
copy_default_parameter_rule.update(parameter_rule)
|
||||
@ -176,31 +180,26 @@ class AIModel(ABC):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if 'label' not in parameter_rule:
|
||||
parameter_rule['label'] = {
|
||||
'zh_Hans': parameter_rule['name'],
|
||||
'en_US': parameter_rule['name']
|
||||
}
|
||||
if "label" not in parameter_rule:
|
||||
parameter_rule["label"] = {"zh_Hans": parameter_rule["name"], "en_US": parameter_rule["name"]}
|
||||
|
||||
new_parameter_rules.append(parameter_rule)
|
||||
|
||||
yaml_data['parameter_rules'] = new_parameter_rules
|
||||
yaml_data["parameter_rules"] = new_parameter_rules
|
||||
|
||||
if 'label' not in yaml_data:
|
||||
yaml_data['label'] = {
|
||||
'zh_Hans': yaml_data['model'],
|
||||
'en_US': yaml_data['model']
|
||||
}
|
||||
if "label" not in yaml_data:
|
||||
yaml_data["label"] = {"zh_Hans": yaml_data["model"], "en_US": yaml_data["model"]}
|
||||
|
||||
yaml_data['fetch_from'] = FetchFrom.PREDEFINED_MODEL.value
|
||||
yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value
|
||||
|
||||
try:
|
||||
# yaml_data to entity
|
||||
model_schema = AIModelEntity(**yaml_data)
|
||||
except Exception as e:
|
||||
model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml")
|
||||
raise Exception(f'Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:'
|
||||
f' {str(e)}')
|
||||
raise Exception(
|
||||
f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:" f" {str(e)}"
|
||||
)
|
||||
|
||||
# cache model schema
|
||||
model_schemas.append(model_schema)
|
||||
@ -235,7 +234,9 @@ class AIModel(ABC):
|
||||
|
||||
return None
|
||||
|
||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
|
||||
def get_customizable_model_schema_from_credentials(
|
||||
self, model: str, credentials: Mapping
|
||||
) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema from credentials
|
||||
|
||||
@ -261,19 +262,19 @@ class AIModel(ABC):
|
||||
try:
|
||||
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
|
||||
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
|
||||
if not parameter_rule.max and 'max' in default_parameter_rule:
|
||||
parameter_rule.max = default_parameter_rule['max']
|
||||
if not parameter_rule.min and 'min' in default_parameter_rule:
|
||||
parameter_rule.min = default_parameter_rule['min']
|
||||
if not parameter_rule.default and 'default' in default_parameter_rule:
|
||||
parameter_rule.default = default_parameter_rule['default']
|
||||
if not parameter_rule.precision and 'precision' in default_parameter_rule:
|
||||
parameter_rule.precision = default_parameter_rule['precision']
|
||||
if not parameter_rule.required and 'required' in default_parameter_rule:
|
||||
parameter_rule.required = default_parameter_rule['required']
|
||||
if not parameter_rule.help and 'help' in default_parameter_rule:
|
||||
if not parameter_rule.max and "max" in default_parameter_rule:
|
||||
parameter_rule.max = default_parameter_rule["max"]
|
||||
if not parameter_rule.min and "min" in default_parameter_rule:
|
||||
parameter_rule.min = default_parameter_rule["min"]
|
||||
if not parameter_rule.default and "default" in default_parameter_rule:
|
||||
parameter_rule.default = default_parameter_rule["default"]
|
||||
if not parameter_rule.precision and "precision" in default_parameter_rule:
|
||||
parameter_rule.precision = default_parameter_rule["precision"]
|
||||
if not parameter_rule.required and "required" in default_parameter_rule:
|
||||
parameter_rule.required = default_parameter_rule["required"]
|
||||
if not parameter_rule.help and "help" in default_parameter_rule:
|
||||
parameter_rule.help = I18nObject(
|
||||
en_US=default_parameter_rule['help']['en_US'],
|
||||
en_US=default_parameter_rule["help"]["en_US"],
|
||||
)
|
||||
if (
|
||||
parameter_rule.help
|
||||
|
||||
@ -35,16 +35,24 @@ class LargeLanguageModel(AIModel):
|
||||
"""
|
||||
Model class for large language model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.LLM
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
def invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -69,7 +77,7 @@ class LargeLanguageModel(AIModel):
|
||||
|
||||
callbacks = callbacks or []
|
||||
|
||||
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
|
||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
||||
callbacks.append(LoggingCallback())
|
||||
|
||||
# trigger before invoke callbacks
|
||||
@ -82,7 +90,7 @@ class LargeLanguageModel(AIModel):
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
try:
|
||||
@ -96,7 +104,7 @@ class LargeLanguageModel(AIModel):
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
callbacks=callbacks,
|
||||
)
|
||||
else:
|
||||
result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
@ -111,7 +119,7 @@ class LargeLanguageModel(AIModel):
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
raise self._transform_invoke_error(e)
|
||||
@ -127,7 +135,7 @@ class LargeLanguageModel(AIModel):
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
callbacks=callbacks,
|
||||
)
|
||||
elif isinstance(result, LLMResult):
|
||||
self._trigger_after_invoke_callbacks(
|
||||
@ -140,15 +148,23 @@ class LargeLanguageModel(AIModel):
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]:
|
||||
def _code_block_mode_wrapper(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Code block mode wrapper, ensure the response is a code block with output markdown quote
|
||||
|
||||
@ -183,7 +199,7 @@ if you are not sure about the structure.
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
|
||||
model_parameters.pop("response_format")
|
||||
@ -195,15 +211,16 @@ if you are not sure about the structure.
|
||||
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
# override the system message
|
||||
prompt_messages[0] = SystemPromptMessage(
|
||||
content=block_prompts
|
||||
.replace("{{instructions}}", str(prompt_messages[0].content))
|
||||
content=block_prompts.replace("{{instructions}}", str(prompt_messages[0].content))
|
||||
)
|
||||
else:
|
||||
# insert the system message
|
||||
prompt_messages.insert(0, SystemPromptMessage(
|
||||
content=block_prompts
|
||||
.replace("{{instructions}}", f"Please output a valid {code_block} object.")
|
||||
))
|
||||
prompt_messages.insert(
|
||||
0,
|
||||
SystemPromptMessage(
|
||||
content=block_prompts.replace("{{instructions}}", f"Please output a valid {code_block} object.")
|
||||
),
|
||||
)
|
||||
|
||||
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
|
||||
# add ```JSON\n to the last text message
|
||||
@ -216,9 +233,7 @@ if you are not sure about the structure.
|
||||
break
|
||||
else:
|
||||
# append a user message
|
||||
prompt_messages.append(UserPromptMessage(
|
||||
content=f"```{code_block}\n"
|
||||
))
|
||||
prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n"))
|
||||
|
||||
response = self._invoke(
|
||||
model=model,
|
||||
@ -228,33 +243,30 @@ if you are not sure about the structure.
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
|
||||
if isinstance(response, Generator):
|
||||
first_chunk = next(response)
|
||||
|
||||
def new_generator():
|
||||
yield first_chunk
|
||||
yield from response
|
||||
|
||||
if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
|
||||
return self._code_block_mode_stream_processor_with_backtick(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
input_generator=new_generator()
|
||||
model=model, prompt_messages=prompt_messages, input_generator=new_generator()
|
||||
)
|
||||
else:
|
||||
return self._code_block_mode_stream_processor(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
input_generator=new_generator()
|
||||
model=model, prompt_messages=prompt_messages, input_generator=new_generator()
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
|
||||
input_generator: Generator[LLMResultChunk, None, None]
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
def _code_block_mode_stream_processor(
|
||||
self, model: str, prompt_messages: list[PromptMessage], input_generator: Generator[LLMResultChunk, None, None]
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Code block mode stream processor, ensure the response is a code block with output markdown quote
|
||||
|
||||
@ -303,16 +315,13 @@ if you are not sure about the structure.
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=new_piece,
|
||||
tool_calls=[]
|
||||
),
|
||||
)
|
||||
message=AssistantPromptMessage(content=new_piece, tool_calls=[]),
|
||||
),
|
||||
)
|
||||
|
||||
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
|
||||
input_generator: Generator[LLMResultChunk, None, None]) \
|
||||
-> Generator[LLMResultChunk, None, None]:
|
||||
def _code_block_mode_stream_processor_with_backtick(
|
||||
self, model: str, prompt_messages: list, input_generator: Generator[LLMResultChunk, None, None]
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Code block mode stream processor, ensure the response is a code block with output markdown quote.
|
||||
This version skips the language identifier that follows the opening triple backticks.
|
||||
@ -378,18 +387,23 @@ if you are not sure about the structure.
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=new_piece,
|
||||
tool_calls=[]
|
||||
),
|
||||
)
|
||||
message=AssistantPromptMessage(content=new_piece, tool_calls=[]),
|
||||
),
|
||||
)
|
||||
|
||||
def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator:
|
||||
def _invoke_result_generator(
|
||||
self,
|
||||
model: str,
|
||||
result: Generator,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Generator:
|
||||
"""
|
||||
Invoke result generator
|
||||
|
||||
@ -397,9 +411,7 @@ if you are not sure about the structure.
|
||||
:return: result generator
|
||||
"""
|
||||
callbacks = callbacks or []
|
||||
prompt_message = AssistantPromptMessage(
|
||||
content=""
|
||||
)
|
||||
prompt_message = AssistantPromptMessage(content="")
|
||||
usage = None
|
||||
system_fingerprint = None
|
||||
real_model = model
|
||||
@ -418,7 +430,7 @@ if you are not sure about the structure.
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
prompt_message.content += chunk.delta.message.content
|
||||
@ -438,7 +450,7 @@ if you are not sure about the structure.
|
||||
prompt_messages=prompt_messages,
|
||||
message=prompt_message,
|
||||
usage=usage if usage else LLMUsage.empty_usage(),
|
||||
system_fingerprint=system_fingerprint
|
||||
system_fingerprint=system_fingerprint,
|
||||
),
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
@ -447,15 +459,21 @@ if you are not sure about the structure.
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -472,8 +490,13 @@ if you are not sure about the structure.
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -519,7 +542,9 @@ if you are not sure about the structure.
|
||||
|
||||
return mode
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
|
||||
def _calc_response_usage(
|
||||
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
|
||||
) -> LLMUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
@ -539,10 +564,7 @@ if you are not sure about the structure.
|
||||
|
||||
# get completion price info
|
||||
completion_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.OUTPUT,
|
||||
tokens=completion_tokens
|
||||
model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
@ -558,16 +580,23 @@ if you are not sure about the structure.
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
|
||||
currency=prompt_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def _trigger_before_invoke_callbacks(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
|
||||
def _trigger_before_invoke_callbacks(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Trigger before invoke callbacks
|
||||
|
||||
@ -593,7 +622,7 @@ if you are not sure about the structure.
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
@ -601,11 +630,19 @@ if you are not sure about the structure.
|
||||
else:
|
||||
logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}")
|
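A sketch of how a concrete callback plugs into this dispatch loop: each callback runs in isolation, and only callbacks that opt in via `raise_error` abort the invocation. `PrintingCallback` and `trigger_before_invoke` are illustrative stand-ins that mirror the raise/`logger.warning` handling shown above.

import logging

logger = logging.getLogger(__name__)


class PrintingCallback:
    raise_error = False  # set True to propagate failures to the caller

    def on_before_invoke(self, model: str, **kwargs) -> None:
        print(f"about to invoke {model}")


def trigger_before_invoke(callbacks, model: str, **kwargs) -> None:
    for callback in callbacks or []:
        try:
            callback.on_before_invoke(model=model, **kwargs)
        except Exception as e:
            if callback.raise_error:
                raise
            logger.warning("Callback %s failed: %s", callback.__class__.__name__, e)


trigger_before_invoke([PrintingCallback()], model="gpt-x", stream=True)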
||||
|
||||
def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
|
||||
def _trigger_new_chunk_callbacks(
|
||||
self,
|
||||
chunk: LLMResultChunk,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Trigger new chunk callbacks
|
||||
|
||||
@ -632,7 +669,7 @@ if you are not sure about the structure.
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
@ -640,11 +677,19 @@ if you are not sure about the structure.
|
||||
else:
|
||||
logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}")
|
||||
|
||||
def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
|
||||
def _trigger_after_invoke_callbacks(
|
||||
self,
|
||||
model: str,
|
||||
result: LLMResult,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Trigger after invoke callbacks
|
||||
|
||||
@ -672,7 +717,7 @@ if you are not sure about the structure.
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
@ -680,11 +725,19 @@ if you are not sure about the structure.
|
||||
else:
|
||||
logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}")
|
||||
|
||||
def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
|
||||
def _trigger_invoke_error_callbacks(
|
||||
self,
|
||||
model: str,
|
||||
ex: Exception,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Trigger invoke error callbacks
|
||||
|
||||
@ -712,7 +765,7 @@ if you are not sure about the structure.
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
@ -758,11 +811,13 @@ if you are not sure about the structure.
|
||||
# validate parameter value range
|
||||
if parameter_rule.min is not None and parameter_value < parameter_rule.min:
|
||||
raise ValueError(
|
||||
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.")
|
||||
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
|
||||
)
|
||||
|
||||
if parameter_rule.max is not None and parameter_value > parameter_rule.max:
|
||||
raise ValueError(
|
||||
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
|
||||
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
|
||||
)
|
||||
elif parameter_rule.type == ParameterType.FLOAT:
|
||||
if not isinstance(parameter_value, float | int):
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be float.")
|
||||
@ -775,16 +830,19 @@ if you are not sure about the structure.
|
||||
else:
|
||||
if parameter_value != round(parameter_value, parameter_rule.precision):
|
||||
raise ValueError(
|
||||
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places.")
|
||||
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places."
|
||||
)
|
||||
|
||||
# validate parameter value range
|
||||
if parameter_rule.min is not None and parameter_value < parameter_rule.min:
|
||||
raise ValueError(
|
||||
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.")
|
||||
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
|
||||
)
|
||||
|
||||
if parameter_rule.max is not None and parameter_value > parameter_rule.max:
|
||||
raise ValueError(
|
||||
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
|
||||
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
|
||||
)
|
||||
elif parameter_rule.type == ParameterType.BOOLEAN:
|
||||
if not isinstance(parameter_value, bool):
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be bool.")
|
||||
|
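The same numeric checks restated as one standalone helper, to make the rule ordering explicit (type, then precision, then bounds). The helper name and parameter names are invented for illustration.

def validate_numeric_parameter(name: str, value, *, minimum=None, maximum=None, precision=None) -> None:
    if not isinstance(value, (int, float)):
        raise ValueError(f"Model Parameter {name} should be numeric.")
    if precision is not None and value != round(value, precision):
        raise ValueError(f"Model Parameter {name} should be round to {precision} decimal places.")
    if minimum is not None and value < minimum:
        raise ValueError(f"Model Parameter {name} should be greater than or equal to {minimum}.")
    if maximum is not None and value > maximum:
        raise ValueError(f"Model Parameter {name} should be less than or equal to {maximum}.")


validate_numeric_parameter("temperature", 0.7, minimum=0, maximum=2, precision=2)  # passes
# validate_numeric_parameter("top_p", 1.5, maximum=1)  # would raise ValueError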
||||
@ -29,32 +29,32 @@ class ModelProvider(ABC):
|
||||
def get_provider_schema(self) -> ProviderEntity:
|
||||
"""
|
||||
Get provider schema
|
||||
|
||||
|
||||
:return: provider schema
|
||||
"""
|
||||
if self.provider_schema:
|
||||
return self.provider_schema
|
||||
|
||||
|
||||
# get dirname of the current path
|
||||
provider_name = self.__class__.__module__.split('.')[-1]
|
||||
provider_name = self.__class__.__module__.split(".")[-1]
|
||||
|
||||
# get the path of the model_provider classes
|
||||
base_path = os.path.abspath(__file__)
|
||||
current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
|
||||
|
||||
|
||||
# read provider schema from yaml file
|
||||
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
|
||||
yaml_path = os.path.join(current_path, f"{provider_name}.yaml")
|
||||
yaml_data = load_yaml_file(yaml_path)
|
||||
|
||||
|
||||
try:
|
||||
# yaml_data to entity
|
||||
provider_schema = ProviderEntity(**yaml_data)
|
||||
except Exception as e:
|
||||
raise Exception(f'Invalid provider schema for {provider_name}: {str(e)}')
|
||||
raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}")
|
||||
|
||||
# cache schema
|
||||
self.provider_schema = provider_schema
|
||||
|
||||
|
||||
return provider_schema
|
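The YAML-to-entity step in isolation, with a made-up provider snippet and a small pydantic model standing in for the real ProviderEntity.

import yaml
from pydantic import BaseModel


class MiniProviderSchema(BaseModel):  # stand-in for ProviderEntity
    provider: str
    label: dict
    supported_model_types: list[str]


yaml_text = """
provider: example
label:
  en_US: Example
supported_model_types:
  - llm
"""

try:
    schema = MiniProviderSchema(**yaml.safe_load(yaml_text))
except Exception as e:
    raise Exception(f"Invalid provider schema for example: {e}")

print(schema.provider)  # example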
||||
|
||||
def models(self, model_type: ModelType) -> list[AIModelEntity]:
|
||||
@ -92,15 +92,15 @@ class ModelProvider(ABC):
|
||||
|
||||
# get the path of the model type classes
|
||||
base_path = os.path.abspath(__file__)
|
||||
model_type_name = model_type.value.replace('-', '_')
|
||||
model_type_name = model_type.value.replace("-", "_")
|
||||
model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name)
|
||||
model_type_py_path = os.path.join(model_type_path, f'{model_type_name}.py')
|
||||
model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py")
|
||||
|
||||
if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path):
|
||||
raise Exception(f'Invalid model type {model_type} for provider {provider_name}')
|
||||
raise Exception(f"Invalid model type {model_type} for provider {provider_name}")
|
||||
|
||||
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
|
||||
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
|
||||
parent_module = ".".join(self.__class__.__module__.split(".")[:-1])
|
||||
mod = import_module_from_source(
|
||||
module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
|
||||
)
|
||||
|
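`import_module_from_source` is a project helper; the plain-stdlib shape of the same idea (load a .py file by path, then collect subclasses of a base class) is roughly the sketch below. The function name is invented.

import importlib.util
import inspect


def load_classes_from_file(module_name: str, py_file_path: str, base_class: type) -> list[type]:
    spec = importlib.util.spec_from_file_location(module_name, py_file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # executes the file, so only load trusted paths
    return [
        obj
        for _, obj in inspect.getmembers(module, inspect.isclass)
        if issubclass(obj, base_class) and obj is not base_class
    ]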
||||
@ -12,14 +12,13 @@ class ModerationModel(AIModel):
|
||||
"""
|
||||
Model class for moderation model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.MODERATION
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
text: str, user: Optional[str] = None) \
|
||||
-> bool:
|
||||
def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
@ -37,9 +36,7 @@ class ModerationModel(AIModel):
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
text: str, user: Optional[str] = None) \
|
||||
-> bool:
|
||||
def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -50,4 +47,3 @@ class ModerationModel(AIModel):
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -11,12 +11,19 @@ class RerankModel(AIModel):
|
||||
"""
|
||||
Base Model class for rerank model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.RERANK
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
def invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
@ -37,10 +44,16 @@ class RerankModel(AIModel):
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
|
||||
@ -12,14 +12,13 @@ class Speech2TextModel(AIModel):
|
||||
"""
|
||||
Model class for speech2text model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.SPEECH2TEXT
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -35,9 +34,7 @@ class Speech2TextModel(AIModel):
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -59,4 +56,4 @@ class Speech2TextModel(AIModel):
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Construct the path to the audio file
|
||||
return os.path.join(current_dir, 'audio.mp3')
|
||||
return os.path.join(current_dir, "audio.mp3")
|
||||
|
||||
@ -11,14 +11,15 @@ class Text2ImageModel(AIModel):
|
||||
"""
|
||||
Model class for text2img model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.TEXT2IMG
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict, prompt: str,
|
||||
model_parameters: dict, user: Optional[str] = None) \
|
||||
-> list[IO[bytes]]:
|
||||
def invoke(
|
||||
self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None
|
||||
) -> list[IO[bytes]]:
|
||||
"""
|
||||
Invoke Text2Image model
|
||||
|
||||
@ -36,9 +37,9 @@ class Text2ImageModel(AIModel):
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict, prompt: str,
|
||||
model_parameters: dict, user: Optional[str] = None) \
|
||||
-> list[IO[bytes]]:
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None
|
||||
) -> list[IO[bytes]]:
|
||||
"""
|
||||
Invoke Text2Image model
|
||||
|
||||
|
||||
@ -13,14 +13,15 @@ class TextEmbeddingModel(AIModel):
|
||||
"""
|
||||
Model class for text embedding model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.TEXT_EMBEDDING
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
def invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -38,9 +39,9 @@ class TextEmbeddingModel(AIModel):
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
||||
@ -7,27 +7,28 @@ from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
|
||||
_tokenizer = None
|
||||
_lock = Lock()
|
||||
|
||||
|
||||
class GPT2Tokenizer:
|
||||
@staticmethod
|
||||
def _get_num_tokens_by_gpt2(text: str) -> int:
|
||||
"""
|
||||
use gpt2 tokenizer to get num tokens
|
||||
use gpt2 tokenizer to get num tokens
|
||||
"""
|
||||
_tokenizer = GPT2Tokenizer.get_encoder()
|
||||
tokens = _tokenizer.encode(text, verbose=False)
|
||||
return len(tokens)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_num_tokens(text: str) -> int:
|
||||
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_encoder() -> Any:
|
||||
global _tokenizer, _lock
|
||||
with _lock:
|
||||
if _tokenizer is None:
|
||||
base_path = abspath(__file__)
|
||||
gpt2_tokenizer_path = join(dirname(base_path), 'gpt2')
|
||||
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
|
||||
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||
|
||||
return _tokenizer
|
||||
return _tokenizer
|
||||
|
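The pattern above (a lazily initialised module-level singleton guarded by a Lock) in isolation; `expensive_load` is a stand-in for `TransformerGPT2Tokenizer.from_pretrained`.

from threading import Lock

_instance = None
_lock = Lock()


def expensive_load() -> dict:
    # placeholder for the costly construction step, e.g. loading tokenizer files
    return {"loaded": True}


def get_instance() -> dict:
    global _instance
    with _lock:
        if _instance is None:
            _instance = expensive_load()
        return _instance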
||||
@ -15,13 +15,15 @@ class TTSModel(AIModel):
|
||||
"""
|
||||
Model class for TTS model.
|
||||
"""
|
||||
|
||||
model_type: ModelType = ModelType.TTS
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
|
||||
user: Optional[str] = None):
|
||||
def invoke(
|
||||
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -35,14 +37,21 @@ class TTSModel(AIModel):
|
||||
:return: translated audio file
|
||||
"""
|
||||
try:
|
||||
return self._invoke(model=model, credentials=credentials, user=user,
|
||||
content_text=content_text, voice=voice, tenant_id=tenant_id)
|
||||
return self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
user=user,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
|
||||
user: Optional[str] = None):
|
||||
def _invoke(
|
||||
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -71,10 +80,13 @@ class TTSModel(AIModel):
|
||||
if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
|
||||
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
|
||||
if language:
|
||||
return [{'name': d['name'], 'value': d['mode']} for d in voices if
|
||||
language and language in d.get('language')]
|
||||
return [
|
||||
{"name": d["name"], "value": d["mode"]}
|
||||
for d in voices
|
||||
if language and language in d.get("language")
|
||||
]
|
||||
else:
|
||||
return [{'name': d['name'], 'value': d['mode']} for d in voices]
|
||||
return [{"name": d["name"], "value": d["mode"]} for d in voices]
|
||||
|
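The voice filtering above with invented sample data, showing the dict shape it assumes (`name`, `mode`, and a `language` list per voice).

voices = [
    {"name": "Alloy", "mode": "alloy", "language": ["en-US", "en-GB"]},
    {"name": "Sol", "mode": "sol", "language": ["es-ES"]},
]

language = "en-US"
filtered = [
    {"name": d["name"], "value": d["mode"]}
    for d in voices
    if language and language in d.get("language")
]
print(filtered)  # [{'name': 'Alloy', 'value': 'alloy'}]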
||||
def _get_model_default_voice(self, model: str, credentials: dict) -> any:
|
||||
"""
|
||||
@ -123,23 +135,23 @@ class TTSModel(AIModel):
|
||||
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
|
||||
|
||||
@staticmethod
|
||||
def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'):
|
||||
def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
|
||||
match = re.compile(pattern)
|
||||
tx = match.finditer(org_text)
|
||||
start = 0
|
||||
result = []
|
||||
one_sentence = ''
|
||||
one_sentence = ""
|
||||
for i in tx:
|
||||
end = i.regs[0][1]
|
||||
tmp = org_text[start:end]
|
||||
if len(one_sentence + tmp) > max_length:
|
||||
result.append(one_sentence)
|
||||
one_sentence = ''
|
||||
one_sentence = ""
|
||||
one_sentence += tmp
|
||||
start = end
|
||||
last_sens = org_text[start:]
|
||||
if last_sens:
|
||||
one_sentence += last_sens
|
||||
if one_sentence != '':
|
||||
if one_sentence != "":
|
||||
result.append(one_sentence)
|
||||
return result
|
||||
|
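A standalone copy of the splitting logic with a usage example; `re.finditer`/`m.end()` is equivalent to the `match.finditer`/`i.regs[0][1]` form above, and `max_length` is reduced from the 2000 default so the example actually splits.

import re


def split_text(org_text: str, max_length: int = 20, pattern: str = r"[。.!?]") -> list[str]:
    result, one_sentence, start = [], "", 0
    for m in re.finditer(pattern, org_text):
        end = m.end()
        tmp = org_text[start:end]
        if len(one_sentence + tmp) > max_length:
            result.append(one_sentence)
            one_sentence = ""
        one_sentence += tmp
        start = end
    if org_text[start:]:
        one_sentence += org_text[start:]
    if one_sentence != "":
        result.append(one_sentence)
    return result


print(split_text("One. Two. Three is a bit longer. Four."))
# ['One. Two.', ' Three is a bit longer.', ' Four.']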
||||
@ -20,12 +20,9 @@ class AnthropicProvider(ModelProvider):
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `claude-3-opus-20240229` model for validate,
|
||||
model_instance.validate_credentials(
|
||||
model='claude-3-opus-20240229',
|
||||
credentials=credentials
|
||||
)
|
||||
model_instance.validate_credentials(model="claude-3-opus-20240229", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
||||
@ -55,11 +55,17 @@ if you are not sure about the structure.
|
||||
|
||||
|
||||
class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -76,10 +82,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
# invoke model
|
||||
return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def _chat_generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
def _chat_generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke llm chat model
|
||||
|
||||
@ -96,41 +109,39 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
|
||||
# transform model parameters from completion api of anthropic to chat api
|
||||
if 'max_tokens_to_sample' in model_parameters:
|
||||
model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample')
|
||||
if "max_tokens_to_sample" in model_parameters:
|
||||
model_parameters["max_tokens"] = model_parameters.pop("max_tokens_to_sample")
|
||||
|
||||
# init model client
|
||||
client = Anthropic(**credentials_kwargs)
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if stop:
|
||||
extra_model_kwargs['stop_sequences'] = stop
|
||||
extra_model_kwargs["stop_sequences"] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
|
||||
extra_model_kwargs["metadata"] = completion_create_params.Metadata(user_id=user)
|
||||
|
||||
system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)
|
||||
|
||||
if system:
|
||||
extra_model_kwargs['system'] = system
|
||||
extra_model_kwargs["system"] = system
|
||||
|
||||
# Add the new header for claude-3-5-sonnet-20240620 model
|
||||
extra_headers = {}
|
||||
if model == "claude-3-5-sonnet-20240620":
|
||||
if model_parameters.get('max_tokens') > 4096:
|
||||
if model_parameters.get("max_tokens") > 4096:
|
||||
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
|
||||
|
||||
if tools:
|
||||
extra_model_kwargs['tools'] = [
|
||||
self._transform_tool_prompt(tool) for tool in tools
|
||||
]
|
||||
extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
|
||||
response = client.beta.tools.messages.create(
|
||||
model=model,
|
||||
messages=prompt_message_dicts,
|
||||
stream=stream,
|
||||
extra_headers=extra_headers,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
else:
|
||||
# chat model
|
||||
@ -140,22 +151,30 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
stream=stream,
|
||||
extra_headers=extra_headers,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
|
||||
|
||||
def _code_block_mode_wrapper(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Code block mode wrapper for invoking large language model
|
||||
"""
|
||||
if model_parameters.get('response_format'):
|
||||
if model_parameters.get("response_format"):
|
||||
stop = stop or []
|
||||
# chat model
|
||||
self._transform_chat_json_prompts(
|
||||
@ -167,24 +186,27 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
response_format=model_parameters['response_format']
|
||||
response_format=model_parameters["response_format"],
|
||||
)
|
||||
model_parameters.pop('response_format')
|
||||
model_parameters.pop("response_format")
|
||||
|
||||
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict:
|
||||
return {
|
||||
'name': tool.name,
|
||||
'description': tool.description,
|
||||
'input_schema': tool.parameters
|
||||
}
|
||||
return {"name": tool.name, "description": tool.description, "input_schema": tool.parameters}
|
||||
|
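The tool mapping above in isolation: a Dify tool description reshaped into the name/description/input_schema dict used for Anthropic tool calling. `SimpleTool` is a stand-in dataclass for PromptMessageTool and the weather tool is invented.

from dataclasses import dataclass, field


@dataclass
class SimpleTool:  # stand-in for PromptMessageTool
    name: str
    description: str
    parameters: dict = field(default_factory=dict)


def to_anthropic_tool(tool: SimpleTool) -> dict:
    return {"name": tool.name, "description": tool.description, "input_schema": tool.parameters}


weather = SimpleTool(
    name="get_weather",
    description="Look up the current weather for a city.",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
)
print(to_anthropic_tool(weather)["name"])  # get_weather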
||||
def _transform_chat_json_prompts(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
|
||||
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
|
||||
-> None:
|
||||
def _transform_chat_json_prompts(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
response_format: str = "JSON",
|
||||
) -> None:
|
||||
"""
|
||||
Transform json prompts
|
||||
"""
|
||||
@ -197,22 +219,30 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
# override the system message
|
||||
prompt_messages[0] = SystemPromptMessage(
|
||||
content=ANTHROPIC_BLOCK_MODE_PROMPT
|
||||
.replace("{{instructions}}", prompt_messages[0].content)
|
||||
.replace("{{block}}", response_format)
|
||||
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
|
||||
"{{block}}", response_format
|
||||
)
|
||||
)
|
||||
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
|
||||
else:
|
||||
# insert the system message
|
||||
prompt_messages.insert(0, SystemPromptMessage(
|
||||
content=ANTHROPIC_BLOCK_MODE_PROMPT
|
||||
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
|
||||
.replace("{{block}}", response_format)
|
||||
))
|
||||
prompt_messages.insert(
|
||||
0,
|
||||
SystemPromptMessage(
|
||||
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace(
|
||||
"{{instructions}}", f"Please output a valid {response_format} object."
|
||||
).replace("{{block}}", response_format)
|
||||
),
|
||||
)
|
||||
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -228,9 +258,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
tokens = client.count_tokens(prompt)
|
||||
|
||||
tool_call_inner_prompts_tokens_map = {
|
||||
'claude-3-opus-20240229': 395,
|
||||
'claude-3-haiku-20240307': 264,
|
||||
'claude-3-sonnet-20240229': 159
|
||||
"claude-3-opus-20240229": 395,
|
||||
"claude-3-haiku-20240307": 264,
|
||||
"claude-3-sonnet-20240229": 159,
|
||||
}
|
||||
|
||||
if model in tool_call_inner_prompts_tokens_map and tools:
|
||||
@ -257,13 +287,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
"temperature": 0,
|
||||
"max_tokens": 20,
|
||||
},
|
||||
stream=False
|
||||
stream=False,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage],
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Union[Message, ToolsBetaMessage],
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
@ -274,22 +309,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
:return: llm response
|
||||
"""
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content='',
|
||||
tool_calls=[]
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
|
||||
for content in response.content:
|
||||
if content.type == 'text':
|
||||
if content.type == "text":
|
||||
assistant_prompt_message.content += content.text
|
||||
elif content.type == 'tool_use':
|
||||
elif content.type == "tool_use":
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=content.id,
|
||||
type='function',
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=content.name,
|
||||
arguments=json.dumps(content.input)
|
||||
)
|
||||
name=content.name, arguments=json.dumps(content.input)
|
||||
),
|
||||
)
|
||||
assistant_prompt_message.tool_calls.append(tool_call)
|
||||
|
||||
@ -308,17 +339,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
# transform response
|
||||
response = LLMResult(
|
||||
model=response.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage
|
||||
model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
|
||||
response: Stream[MessageStreamEvent],
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
def _handle_chat_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: Stream[MessageStreamEvent], prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm chat stream response
|
||||
|
||||
@ -327,7 +355,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
full_assistant_content = ''
|
||||
full_assistant_content = ""
|
||||
return_model = None
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
@ -338,24 +366,23 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
for chunk in response:
|
||||
if isinstance(chunk, MessageStartEvent):
|
||||
if hasattr(chunk, 'content_block'):
|
||||
if hasattr(chunk, "content_block"):
|
||||
content_block = chunk.content_block
|
||||
if isinstance(content_block, dict):
|
||||
if content_block.get('type') == 'tool_use':
|
||||
if content_block.get("type") == "tool_use":
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=content_block.get('id'),
|
||||
type='function',
|
||||
id=content_block.get("id"),
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=content_block.get('name'),
|
||||
arguments=''
|
||||
)
|
||||
name=content_block.get("name"), arguments=""
|
||||
),
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
elif hasattr(chunk, 'delta'):
|
||||
elif hasattr(chunk, "delta"):
|
||||
delta = chunk.delta
|
||||
if isinstance(delta, dict) and len(tool_calls) > 0:
|
||||
if delta.get('type') == 'input_json_delta':
|
||||
tool_calls[-1].function.arguments += delta.get('partial_json', '')
|
||||
if delta.get("type") == "input_json_delta":
|
||||
tool_calls[-1].function.arguments += delta.get("partial_json", "")
|
||||
elif chunk.message:
|
||||
return_model = chunk.message.model
|
||||
input_tokens = chunk.message.usage.input_tokens
|
||||
@ -369,29 +396,24 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
# transform empty tool call arguments to {}
|
||||
for tool_call in tool_calls:
|
||||
if not tool_call.function.arguments:
|
||||
tool_call.function.arguments = '{}'
|
||||
tool_call.function.arguments = "{}"
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=return_model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index + 1,
|
||||
message=AssistantPromptMessage(
|
||||
content='',
|
||||
tool_calls=tool_calls
|
||||
),
|
||||
message=AssistantPromptMessage(content="", tool_calls=tool_calls),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
elif isinstance(chunk, ContentBlockDeltaEvent):
|
||||
chunk_text = chunk.delta.text if chunk.delta.text else ''
|
||||
chunk_text = chunk.delta.text if chunk.delta.text else ""
|
||||
full_assistant_content += chunk_text
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=chunk_text
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=chunk_text)
|
||||
|
||||
index = chunk.index
|
||||
|
||||
@ -401,7 +423,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk.index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def _to_credential_kwargs(self, credentials: dict) -> dict:
|
||||
@ -412,14 +434,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
:return:
|
||||
"""
|
||||
credentials_kwargs = {
|
||||
"api_key": credentials['anthropic_api_key'],
|
||||
"api_key": credentials["anthropic_api_key"],
|
||||
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
||||
"max_retries": 1,
|
||||
}
|
||||
|
||||
if credentials.get('anthropic_api_url'):
|
||||
credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/')
|
||||
credentials_kwargs['base_url'] = credentials['anthropic_api_url']
|
||||
if credentials.get("anthropic_api_url"):
|
||||
credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip("/")
|
||||
credentials_kwargs["base_url"] = credentials["anthropic_api_url"]
|
||||
|
||||
return credentials_kwargs
|
||||
|
||||
@ -452,10 +474,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "text",
|
||||
"text": message_content.data
|
||||
}
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
@ -465,25 +484,25 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
image_content = requests.get(message_content.data).content
|
||||
with Image.open(io.BytesIO(image_content)) as img:
|
||||
mime_type = f"image/{img.format.lower()}"
|
||||
base64_data = base64.b64encode(image_content).decode('utf-8')
|
||||
base64_data = base64.b64encode(image_content).decode("utf-8")
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
||||
raise ValueError(
|
||||
f"Failed to fetch image data from url {message_content.data}, {ex}"
|
||||
)
|
||||
else:
|
||||
data_split = message_content.data.split(";base64,")
|
||||
mime_type = data_split[0].replace("data:", "")
|
||||
base64_data = data_split[1]
|
||||
|
||||
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
|
||||
raise ValueError(f"Unsupported image type {mime_type}, "
|
||||
f"only support image/jpeg, image/png, image/gif, and image/webp")
|
||||
raise ValueError(
|
||||
f"Unsupported image type {mime_type}, "
|
||||
f"only support image/jpeg, image/png, image/gif, and image/webp"
|
||||
)
|
||||
|
||||
sub_message_dict = {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": base64_data
|
||||
}
|
||||
"source": {"type": "base64", "media_type": mime_type, "data": base64_data},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
prompt_message_dicts.append({"role": "user", "content": sub_messages})
|
||||
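The data-URL branch above in isolation: split `data:<mime>;base64,<payload>` into the media type and base64 payload that the image content block carries. The helper name is invented for illustration.

def to_image_block(data_url: str) -> dict:
    header, base64_data = data_url.split(";base64,")
    mime_type = header.replace("data:", "")
    if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
        raise ValueError(f"Unsupported image type {mime_type}")
    return {"type": "image", "source": {"type": "base64", "media_type": mime_type, "data": base64_data}}


print(to_image_block("data:image/png;base64,iVBORw0KGgo=")["source"]["media_type"])  # image/png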
@ -492,34 +511,28 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
content = []
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
content.append({
|
||||
"type": "tool_use",
|
||||
"id": tool_call.id,
|
||||
"name": tool_call.function.name,
|
||||
"input": json.loads(tool_call.function.arguments)
|
||||
})
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tool_call.id,
|
||||
"name": tool_call.function.name,
|
||||
"input": json.loads(tool_call.function.arguments),
|
||||
}
|
||||
)
|
||||
if message.content:
|
||||
content.append({
|
||||
"type": "text",
|
||||
"text": message.content
|
||||
})
|
||||
|
||||
content.append({"type": "text", "text": message.content})
|
||||
|
||||
if prompt_message_dicts[-1]["role"] == "assistant":
|
||||
prompt_message_dicts[-1]["content"].extend(content)
|
||||
else:
|
||||
prompt_message_dicts.append({
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
})
|
||||
prompt_message_dicts.append({"role": "assistant", "content": content})
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message.tool_call_id,
|
||||
"content": message.content
|
||||
}]
|
||||
"content": [
|
||||
{"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content}
|
||||
],
|
||||
}
|
||||
prompt_message_dicts.append(message_dict)
|
||||
else:
|
||||
@ -576,16 +589,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||
"""
|
||||
if not messages:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
if not isinstance(messages[-1], AssistantPromptMessage):
|
||||
messages.append(AssistantPromptMessage(content=""))
|
||||
|
||||
text = "".join(
|
||||
self._convert_one_message_to_text(message)
|
||||
for message in messages
|
||||
)
|
||||
text = "".join(self._convert_one_message_to_text(message) for message in messages)
|
||||
|
||||
# trim off the trailing ' ' that might come from the "Assistant: "
|
||||
return text.rstrip()
|
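A rough sketch of the legacy completion prompt this builds: role-tagged turns ending on an open Assistant turn. The tag strings below are illustrative; the real ones come from the Anthropic SDK constants.

def render_prompt(messages: list[tuple[str, str]]) -> str:
    tags = {"user": "\n\nHuman:", "assistant": "\n\nAssistant:", "system": ""}
    if not messages or messages[-1][0] != "assistant":
        messages = [*messages, ("assistant", "")]  # ensure the prompt ends on an Assistant turn
    text = "".join(f"{tags[role]} {content}" for role, content in messages)
    return text.rstrip()


print(repr(render_prompt([("user", "Hi")])))  # '\n\nHuman: Hi\n\nAssistant:'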
||||
@ -601,24 +611,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
anthropic.APIConnectionError,
|
||||
anthropic.APITimeoutError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
anthropic.InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
anthropic.RateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
anthropic.AuthenticationError,
|
||||
anthropic.PermissionDeniedError
|
||||
],
|
||||
InvokeConnectionError: [anthropic.APIConnectionError, anthropic.APITimeoutError],
|
||||
InvokeServerUnavailableError: [anthropic.InternalServerError],
|
||||
InvokeRateLimitError: [anthropic.RateLimitError],
|
||||
InvokeAuthorizationError: [anthropic.AuthenticationError, anthropic.PermissionDeniedError],
|
||||
InvokeBadRequestError: [
|
||||
anthropic.BadRequestError,
|
||||
anthropic.NotFoundError,
|
||||
anthropic.UnprocessableEntityError,
|
||||
anthropic.APIError
|
||||
]
|
||||
anthropic.APIError,
|
||||
],
|
||||
}
|
||||
|
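How a mapping like the one above is typically consumed: find the unified error type whose SDK exception list matches the caught exception, otherwise fall back to a generic error. The classes below are local stand-ins, not the runtime's InvokeError hierarchy.

class UnifiedError(Exception): ...
class UnifiedRateLimitError(UnifiedError): ...


class SDKRateLimitError(Exception): ...  # stand-in for an SDK exception such as a 429


ERROR_MAPPING: dict[type[Exception], list[type[Exception]]] = {
    UnifiedRateLimitError: [SDKRateLimitError],
}


def transform_error(ex: Exception) -> Exception:
    for unified, sdk_errors in ERROR_MAPPING.items():
        if isinstance(ex, tuple(sdk_errors)):
            return unified(str(ex))
    return UnifiedError(str(ex))


print(type(transform_error(SDKRateLimitError("429"))).__name__)  # UnifiedRateLimitError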
||||
@ -15,10 +15,10 @@ from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPEN
|
||||
class _CommonAzureOpenAI:
|
||||
@staticmethod
|
||||
def _to_credential_kwargs(credentials: dict) -> dict:
|
||||
api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION)
|
||||
api_version = credentials.get("openai_api_version", AZURE_OPENAI_API_VERSION)
|
||||
credentials_kwargs = {
|
||||
"api_key": credentials['openai_api_key'],
|
||||
"azure_endpoint": credentials['openai_api_base'],
|
||||
"api_key": credentials["openai_api_key"],
|
||||
"azure_endpoint": credentials["openai_api_base"],
|
||||
"api_version": api_version,
|
||||
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
||||
"max_retries": 1,
|
||||
@ -29,24 +29,14 @@ class _CommonAzureOpenAI:
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
openai.APIConnectionError,
|
||||
openai.APITimeoutError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
openai.InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
openai.RateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
openai.AuthenticationError,
|
||||
openai.PermissionDeniedError
|
||||
],
|
||||
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
|
||||
InvokeServerUnavailableError: [openai.InternalServerError],
|
||||
InvokeRateLimitError: [openai.RateLimitError],
|
||||
InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
|
||||
InvokeBadRequestError: [
|
||||
openai.BadRequestError,
|
||||
openai.NotFoundError,
|
||||
openai.UnprocessableEntityError,
|
||||
openai.APIError
|
||||
]
|
||||
openai.APIError,
|
||||
],
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
@ -6,6 +6,5 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureOpenAIProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
||||
|
||||
@ -34,16 +34,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
|
||||
base_model_name = credentials.get('base_model_name')
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
base_model_name = credentials.get("base_model_name")
|
||||
if not base_model_name:
|
||||
raise ValueError('Base Model Name is required')
|
||||
raise ValueError("Base Model Name is required")
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
|
||||
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
@ -56,7 +60,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
# text completion model
|
||||
@ -67,7 +71,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
|
||||
def get_num_tokens(
|
||||
@ -75,14 +79,14 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
base_model_name = credentials.get('base_model_name')
|
||||
base_model_name = credentials.get("base_model_name")
|
||||
if not base_model_name:
|
||||
raise ValueError('Base Model Name is required')
|
||||
raise ValueError("Base Model Name is required")
|
||||
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
if not model_entity:
|
||||
raise ValueError(f'Base Model Name {base_model_name} is invalid')
|
||||
raise ValueError(f"Base Model Name {base_model_name} is invalid")
|
||||
model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
|
||||
|
||||
if model_mode == LLMMode.CHAT.value:
|
||||
@ -92,21 +96,21 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
# text completion model, do not support tool calling
|
||||
content = prompt_messages[0].content
|
||||
assert isinstance(content, str)
|
||||
return self._num_tokens_from_string(credentials,content)
|
||||
return self._num_tokens_from_string(credentials, content)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
if 'openai_api_base' not in credentials:
|
||||
raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')
|
||||
if "openai_api_base" not in credentials:
|
||||
raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required")
|
||||
|
||||
if 'openai_api_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('Azure OpenAI API key is required')
|
||||
if "openai_api_key" not in credentials:
|
||||
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
|
||||
|
||||
if 'base_model_name' not in credentials:
|
||||
raise CredentialsValidateFailedError('Base Model Name is required')
|
||||
if "base_model_name" not in credentials:
|
||||
raise CredentialsValidateFailedError("Base Model Name is required")
|
||||
|
||||
base_model_name = credentials.get('base_model_name')
|
||||
base_model_name = credentials.get("base_model_name")
|
||||
if not base_model_name:
|
||||
raise CredentialsValidateFailedError('Base Model Name is required')
|
||||
raise CredentialsValidateFailedError("Base Model Name is required")
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
|
||||
if not ai_model_entity:
|
||||
@ -118,7 +122,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": 'ping'}],
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
model=model,
|
||||
temperature=0,
|
||||
max_tokens=20,
|
||||
@ -127,7 +131,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
else:
|
||||
# text completion model
|
||||
client.completions.create(
|
||||
prompt='ping',
|
||||
prompt="ping",
|
||||
model=model,
|
||||
temperature=0,
|
||||
max_tokens=20,
|
||||
@ -137,33 +141,35 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
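A condensed sketch of the validation flow above: require the three credential keys, then surface any failure from a small "ping" request as a single validation error. The `_ping` stub replaces the real Azure OpenAI client call, and the local exception class stands in for the runtime's CredentialsValidateFailedError.

class CredentialsValidateFailedError(Exception):  # local stand-in
    pass


REQUIRED_KEYS = ("openai_api_base", "openai_api_key", "base_model_name")


def _ping(credentials: dict) -> None:
    # placeholder for the 1-token "ping" chat/completion request sent via the client
    return None


def validate_azure_credentials(credentials: dict) -> None:
    for key in REQUIRED_KEYS:
        if not credentials.get(key):
            raise CredentialsValidateFailedError(f"{key} is required")
    try:
        _ping(credentials)
    except Exception as ex:
        raise CredentialsValidateFailedError(str(ex))


validate_azure_credentials(
    {"openai_api_base": "https://example.openai.azure.com", "openai_api_key": "sk-...", "base_model_name": "gpt-35-turbo"}
)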
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
base_model_name = credentials.get('base_model_name')
|
||||
base_model_name = credentials.get("base_model_name")
|
||||
if not base_model_name:
|
||||
raise ValueError('Base Model Name is required')
|
||||
raise ValueError("Base Model Name is required")
|
||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
||||
return ai_model_entity.entity if ai_model_entity else None
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
extra_model_kwargs = {}
|
||||
|
||||
if stop:
|
||||
extra_model_kwargs['stop'] = stop
|
||||
extra_model_kwargs["stop"] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs['user'] = user
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
# text completion model
|
||||
response = client.completions.create(
|
||||
prompt=prompt_messages[0].content,
|
||||
model=model,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs
|
||||
)
|
||||
|
||||
if stream:
|
||||
@ -172,15 +178,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, response: Completion,
|
||||
prompt_messages: list[PromptMessage]
|
||||
self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage]
|
||||
):
|
||||
assistant_text = response.choices[0].text
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_text
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
|
||||
|
||||
# calculate num tokens
|
||||
if response.usage:
|
||||
@ -209,24 +212,21 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: Stream[Completion],
|
||||
prompt_messages: list[PromptMessage]
|
||||
self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
full_text = ''
|
||||
full_text = ""
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0]
|
||||
|
||||
if delta.finish_reason is None and (delta.text is None or delta.text == ''):
|
||||
if delta.finish_reason is None and (delta.text is None or delta.text == ""):
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
text = delta.text if delta.text else ''
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=text
|
||||
)
|
||||
text = delta.text if delta.text else ""
|
||||
assistant_prompt_message = AssistantPromptMessage(content=text)
|
||||
|
||||
full_text += text
|
||||
|
||||
@ -254,8 +254,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=delta.finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
@ -265,14 +265,20 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def _chat_generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
|
||||
def _chat_generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
response_format = model_parameters.get("response_format")
|
||||
@ -293,7 +299,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
extra_model_kwargs = {}
|
||||
|
||||
if tools:
|
||||
extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
|
||||
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
|
||||
# extra_model_kwargs['functions'] = [{
|
||||
# "name": tool.name,
|
||||
# "description": tool.description,
|
||||
@ -301,10 +307,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
# } for tool in tools]
|
||||
|
||||
if stop:
|
||||
extra_model_kwargs['stop'] = stop
|
||||
extra_model_kwargs["stop"] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs['user'] = user
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
# chat model
|
||||
messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
|
||||
@ -322,9 +328,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self, model: str, credentials: dict, response: ChatCompletion,
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: ChatCompletion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
):
|
||||
assistant_message = response.choices[0].message
|
||||
assistant_message_tool_calls = assistant_message.tool_calls
|
||||
@ -334,10 +343,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_message.content,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
|
||||
|
||||
# calculate num tokens
|
||||
if response.usage:
|
||||
@ -369,13 +375,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
):
|
||||
index = 0
|
||||
full_assistant_content = ''
|
||||
full_assistant_content = ""
|
||||
real_model = model
|
||||
system_fingerprint = None
|
||||
completion = ''
|
||||
completion = ""
|
||||
tool_calls = []
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
@ -386,7 +392,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
if delta.delta is None:
|
||||
continue
|
||||
|
||||
|
||||
# extract tool calls from response
|
||||
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
|
||||
|
||||
@ -396,15 +401,14 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta.delta.content if delta.delta.content else '',
|
||||
tool_calls=tool_calls
|
||||
content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
|
||||
)
|
||||
|
||||
full_assistant_content += delta.delta.content if delta.delta.content else ''
|
||||
full_assistant_content += delta.delta.content if delta.delta.content else ""
|
||||
|
||||
real_model = chunk.model
|
||||
system_fingerprint = chunk.system_fingerprint
|
||||
completion += delta.delta.content if delta.delta.content else ''
|
||||
completion += delta.delta.content if delta.delta.content else ""
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=real_model,
|
||||
@ -413,7 +417,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
index += 0
|
||||
@ -421,9 +425,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||
|
||||
full_assistant_prompt_message = AssistantPromptMessage(
|
||||
content=completion
|
||||
)
|
||||
full_assistant_prompt_message = AssistantPromptMessage(content=completion)
|
||||
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
@ -434,27 +436,24 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(content=''),
|
||||
finish_reason='stop',
|
||||
usage=usage
|
||||
)
|
||||
index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None:
|
||||
def _update_tool_calls(
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall],
|
||||
tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]],
|
||||
) -> None:
|
||||
if tool_calls_response:
|
||||
for response_tool_call in tool_calls_response:
|
||||
if isinstance(response_tool_call, ChatCompletionMessageToolCall):
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.function.name,
|
||||
arguments=response_tool_call.function.arguments
|
||||
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.id,
|
||||
type=response_tool_call.type,
|
||||
function=function
|
||||
id=response_tool_call.id, type=response_tool_call.type, function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
elif isinstance(response_tool_call, ChoiceDeltaToolCall):
|
||||
@ -463,8 +462,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
tool_calls[index].id = response_tool_call.id or tool_calls[index].id
|
||||
tool_calls[index].type = response_tool_call.type or tool_calls[index].type
|
||||
if response_tool_call.function:
|
||||
tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name
|
||||
tool_calls[index].function.arguments += response_tool_call.function.arguments or ''
|
||||
tool_calls[index].function.name = (
|
||||
response_tool_call.function.name or tool_calls[index].function.name
|
||||
)
|
||||
tool_calls[index].function.arguments += response_tool_call.function.arguments or ""
|
||||
else:
|
||||
assert response_tool_call.id is not None
|
||||
assert response_tool_call.type is not None
|
||||
@ -473,13 +474,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
assert response_tool_call.function.arguments is not None
|
||||
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.function.name,
|
||||
arguments=response_tool_call.function.arguments
|
||||
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
|
||||
)
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.id,
|
||||
type=response_tool_call.type,
|
||||
function=function
|
||||
id=response_tool_call.id, type=response_tool_call.type, function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
@ -495,19 +493,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "text",
|
||||
"text": message_content.data
|
||||
}
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": message_content.data,
|
||||
"detail": message_content.detail.value
|
||||
}
|
||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
@ -525,7 +517,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
"role": "tool",
|
||||
"name": message.name,
|
||||
"content": message.content,
|
||||
"tool_call_id": message.tool_call_id
|
||||
"tool_call_id": message.tool_call_id,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
@ -535,10 +527,11 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_string(self, credentials: dict, text: str,
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def _num_tokens_from_string(
|
||||
self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(credentials['base_model_name'])
|
||||
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
@ -550,14 +543,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self, credentials: dict, messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None
|
||||
self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
model = credentials['base_model_name']
|
||||
model = credentials["base_model_name"]
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
@ -591,10 +583,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ''
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item['type'] == 'text':
|
||||
text += item['text']
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
|
||||
value = text
|
||||
|
||||
@ -626,40 +618,39 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
|
||||
@staticmethod
|
||||
def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int:
|
||||
|
||||
num_tokens = 0
|
||||
for tool in tools:
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode('function'))
|
||||
num_tokens += len(encoding.encode("type"))
|
||||
num_tokens += len(encoding.encode("function"))
|
||||
|
||||
# calculate num tokens for function object
|
||||
num_tokens += len(encoding.encode('name'))
|
||||
num_tokens += len(encoding.encode("name"))
|
||||
num_tokens += len(encoding.encode(tool.name))
|
||||
num_tokens += len(encoding.encode('description'))
|
||||
num_tokens += len(encoding.encode("description"))
|
||||
num_tokens += len(encoding.encode(tool.description))
|
||||
parameters = tool.parameters
|
||||
num_tokens += len(encoding.encode('parameters'))
|
||||
if 'title' in parameters:
|
||||
num_tokens += len(encoding.encode('title'))
|
||||
num_tokens += len(encoding.encode(parameters['title']))
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(parameters['type']))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += len(encoding.encode('properties'))
|
||||
for key, value in parameters['properties'].items():
|
||||
num_tokens += len(encoding.encode("parameters"))
|
||||
if "title" in parameters:
|
||||
num_tokens += len(encoding.encode("title"))
|
||||
num_tokens += len(encoding.encode(parameters["title"]))
|
||||
num_tokens += len(encoding.encode("type"))
|
||||
num_tokens += len(encoding.encode(parameters["type"]))
|
||||
if "properties" in parameters:
|
||||
num_tokens += len(encoding.encode("properties"))
|
||||
for key, value in parameters["properties"].items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == 'enum':
|
||||
if field_key == "enum":
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if 'required' in parameters:
|
||||
num_tokens += len(encoding.encode('required'))
|
||||
for required_field in parameters['required']:
|
||||
if "required" in parameters:
|
||||
num_tokens += len(encoding.encode("required"))
|
||||
for required_field in parameters["required"]:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
|
||||
@ -15,9 +15,7 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
|
||||
Model class for OpenAI Speech to text model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
"""
|
||||
Invoke speech2text model
|
||||
|
||||
@ -40,7 +38,7 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
|
||||
try:
|
||||
audio_file_path = self._get_demo_file_path()
|
||||
|
||||
with open(audio_file_path, 'rb') as audio_file:
|
||||
with open(audio_file_path, "rb") as audio_file:
|
||||
self._speech2text_invoke(model, credentials, audio_file)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
@ -65,10 +63,9 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
|
||||
return response.text
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
||||
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
|
||||
return ai_model_entity.entity
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
|
||||
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
|
||||
|
||||
@ -16,19 +16,18 @@ from core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_
|
||||
|
||||
|
||||
class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
base_model_name = credentials['base_model_name']
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
base_model_name = credentials["base_model_name"]
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = AzureOpenAI(**credentials_kwargs)
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if user:
|
||||
extra_model_kwargs['user'] = user
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
extra_model_kwargs['encoding_format'] = 'base64'
|
||||
extra_model_kwargs["encoding_format"] = "base64"
|
||||
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
@ -44,11 +43,9 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
token = enc.encode(
|
||||
text
|
||||
)
|
||||
token = enc.encode(text)
|
||||
for j in range(0, len(token), context_size):
|
||||
tokens += [token[j: j + context_size]]
|
||||
tokens += [token[j : j + context_size]]
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
@ -56,10 +53,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
|
||||
for i in _iter:
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
client=client,
|
||||
texts=tokens[i: i + max_chunks],
|
||||
extra_model_kwargs=extra_model_kwargs
|
||||
model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
@ -75,10 +69,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
_result = results[i]
|
||||
if len(_result) == 0:
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
client=client,
|
||||
texts="",
|
||||
extra_model_kwargs=extra_model_kwargs
|
||||
model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
@ -88,24 +79,16 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=used_tokens
|
||||
)
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
return TextEmbeddingResult(
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
model=base_model_name
|
||||
)
|
||||
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=base_model_name)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
|
||||
try:
|
||||
enc = tiktoken.encoding_for_model(credentials['base_model_name'])
|
||||
enc = tiktoken.encoding_for_model(credentials["base_model_name"])
|
||||
except KeyError:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
@ -118,57 +101,52 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
return total_num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
if 'openai_api_base' not in credentials:
|
||||
raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')
|
||||
if "openai_api_base" not in credentials:
|
||||
raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required")
|
||||
|
||||
if 'openai_api_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('Azure OpenAI API key is required')
|
||||
if "openai_api_key" not in credentials:
|
||||
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
|
||||
|
||||
if 'base_model_name' not in credentials:
|
||||
raise CredentialsValidateFailedError('Base Model Name is required')
|
||||
if "base_model_name" not in credentials:
|
||||
raise CredentialsValidateFailedError("Base Model Name is required")
|
||||
|
||||
if not self._get_ai_model_entity(credentials['base_model_name'], model):
|
||||
if not self._get_ai_model_entity(credentials["base_model_name"], model):
|
||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
||||
|
||||
try:
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = AzureOpenAI(**credentials_kwargs)
|
||||
|
||||
self._embedding_invoke(
|
||||
model=model,
|
||||
client=client,
|
||||
texts=['ping'],
|
||||
extra_model_kwargs={}
|
||||
)
|
||||
self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={})
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
||||
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
|
||||
return ai_model_entity.entity
|
||||
|
||||
@staticmethod
|
||||
def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str],
|
||||
extra_model_kwargs: dict) -> tuple[list[list[float]], int]:
|
||||
def _embedding_invoke(
|
||||
model: str, client: AzureOpenAI, texts: Union[list[str], str], extra_model_kwargs: dict
|
||||
) -> tuple[list[list[float]], int]:
|
||||
response = client.embeddings.create(
|
||||
input=texts,
|
||||
model=model,
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
|
||||
if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64':
|
||||
if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64":
|
||||
# decode base64 embedding
|
||||
return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data],
|
||||
response.usage.total_tokens)
|
||||
return (
|
||||
[list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data],
|
||||
response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return [data.embedding for data in response.data], response.usage.total_tokens
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
@ -179,7 +157,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@ -17,8 +17,9 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
Model class for OpenAI Speech to text model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, tenant_id: str, credentials: dict,
|
||||
content_text: str, voice: str, user: Optional[str] = None) -> any:
|
||||
def _invoke(
|
||||
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
|
||||
) -> any:
|
||||
"""
|
||||
_invoke text2speech model
|
||||
|
||||
@ -30,13 +31,12 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
|
||||
if not voice or voice not in [
|
||||
d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials)
|
||||
]:
|
||||
voice = self._get_model_default_voice(model, credentials)
|
||||
|
||||
return self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice)
|
||||
return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
@ -50,14 +50,13 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
self._tts_invoke_streaming(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
content_text='Hello Dify!',
|
||||
content_text="Hello Dify!",
|
||||
voice=self._get_model_default_voice(model, credentials),
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
|
||||
voice: str) -> any:
|
||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
|
||||
"""
|
||||
_tts_invoke_streaming text2speech model
|
||||
:param model: model name
|
||||
@ -75,23 +74,29 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
if len(content_text) > max_length:
|
||||
sentences = self._split_text_into_sentences(content_text, max_length=max_length)
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
|
||||
futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
|
||||
response_format="mp3",
|
||||
input=sentences[i], voice=voice) for i in range(len(sentences))]
|
||||
futures = [
|
||||
executor.submit(
|
||||
client.audio.speech.with_streaming_response.create,
|
||||
model=model,
|
||||
response_format="mp3",
|
||||
input=sentences[i],
|
||||
voice=voice,
|
||||
)
|
||||
for i in range(len(sentences))
|
||||
]
|
||||
for index, future in enumerate(futures):
|
||||
yield from future.result().__enter__().iter_bytes(1024)
|
||||
|
||||
else:
|
||||
response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
|
||||
response_format="mp3",
|
||||
input=content_text.strip())
|
||||
response = client.audio.speech.with_streaming_response.create(
|
||||
model=model, voice=voice, response_format="mp3", input=content_text.strip()
|
||||
)
|
||||
|
||||
yield from response.__enter__().iter_bytes(1024)
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
def _process_sentence(self, sentence: str, model: str,
|
||||
voice, credentials: dict):
|
||||
def _process_sentence(self, sentence: str, model: str, voice, credentials: dict):
|
||||
"""
|
||||
_tts_invoke openai text2speech model api
|
||||
|
||||
@ -108,10 +113,9 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
return response.read()
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
||||
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
|
||||
return ai_model_entity.entity
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None:
|
||||
for ai_model_entity in TTS_BASE_MODELS:
|
||||
|
||||
@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaichuanProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
@ -19,12 +20,9 @@ class BaichuanProvider(ModelProvider):
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `baichuan2-turbo` model for validate,
|
||||
model_instance.validate_credentials(
|
||||
model='baichuan2-turbo',
|
||||
credentials=credentials
|
||||
)
|
||||
model_instance.validate_credentials(model="baichuan2-turbo", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
||||
@ -4,17 +4,17 @@ import re
|
||||
class BaichuanTokenizer:
|
||||
@classmethod
|
||||
def count_chinese_characters(cls, text: str) -> int:
|
||||
return len(re.findall(r'[\u4e00-\u9fa5]', text))
|
||||
return len(re.findall(r"[\u4e00-\u9fa5]", text))
|
||||
|
||||
@classmethod
|
||||
def count_english_vocabularies(cls, text: str) -> int:
|
||||
# remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc.
|
||||
text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
|
||||
text = re.sub(r"[^a-zA-Z0-9\s]", "", text)
|
||||
# count the number of words not characters
|
||||
return len(text.split())
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_num_tokens(cls, text: str) -> int:
|
||||
# tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return)
|
||||
# https://platform.baichuan-ai.com/docs/text-Embedding
|
||||
return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
|
||||
return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
|
||||
|
||||
@ -94,7 +94,6 @@ class BaichuanModel:
|
||||
timeout: int,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> Union[Iterator, dict]:
|
||||
|
||||
if model in self._model_mapping.keys():
|
||||
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
|
||||
else:
|
||||
@ -120,9 +119,7 @@ class BaichuanModel:
|
||||
err = resp["error"]["type"]
|
||||
msg = resp["error"]["message"]
|
||||
except Exception as e:
|
||||
raise InternalServerError(
|
||||
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||
)
|
||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
if err == "invalid_api_key":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
|
||||
@ -1,17 +1,22 @@
|
||||
class InvalidAuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidAPIKeyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RateLimitReachedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InsufficientAccountBalance(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InternalServerError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BadRequestError(Exception):
|
||||
pass
|
||||
pass
|
||||
|
||||
@ -38,17 +38,16 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
|
||||
|
||||
|
||||
class BaichuanLanguageModel(LargeLanguageModel):
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
return self._generate(
|
||||
model=model,
|
||||
@ -60,17 +59,17 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
return self._num_tokens_from_messages(prompt_messages)
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[PromptMessage],
|
||||
self,
|
||||
messages: list[PromptMessage],
|
||||
) -> int:
|
||||
"""Calculate num tokens for baichuan model"""
|
||||
|
||||
@ -111,18 +110,13 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
message_dict["tool_calls"] = [tool_call.dict() for tool_call in
|
||||
message.tool_calls]
|
||||
message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls]
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {
|
||||
"role": "tool",
|
||||
"content": message.content,
|
||||
"tool_call_id": message.tool_call_id
|
||||
}
|
||||
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
|
||||
else:
|
||||
raise ValueError(f"Unknown message type {type(message)}")
|
||||
|
||||
@ -146,15 +140,14 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
||||
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stream: bool = True,
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stream: bool = True,
|
||||
) -> LLMResult | Generator:
|
||||
|
||||
instance = BaichuanModel(api_key=credentials["api_key"])
|
||||
messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
|
||||
|
||||
@ -169,23 +162,19 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(
|
||||
model, prompt_messages, credentials, response
|
||||
)
|
||||
return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
|
||||
|
||||
return self._handle_chat_generate_response(
|
||||
model, prompt_messages, credentials, response
|
||||
)
|
||||
return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: dict,
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: dict,
|
||||
) -> LLMResult:
|
||||
choices = response.get("choices", [])
|
||||
assistant_message = AssistantPromptMessage(content='', tool_calls=[])
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
if choices and choices[0]["finish_reason"] == "tool_calls":
|
||||
for choice in choices:
|
||||
for tool_call in choice["message"]["tool_calls"]:
|
||||
@ -194,7 +183,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
||||
type=tool_call.get("type", ""),
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_call.get("function", {}).get("name", ""),
|
||||
arguments=tool_call.get("function", {}).get("arguments", "")
|
||||
arguments=tool_call.get("function", {}).get("arguments", ""),
|
||||
),
|
||||
)
|
||||
assistant_message.tool_calls.append(tool)
|
||||
@ -228,11 +217,11 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Iterator,
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Iterator,
|
||||
) -> Generator:
|
||||
for line in response:
|
||||
if not line:
|
||||
@ -260,9 +249,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=choice["delta"]["content"], tool_calls=[]
|
||||
),
|
||||
message=AssistantPromptMessage(content=choice["delta"]["content"], tool_calls=[]),
|
||||
finish_reason=stop_reason,
|
||||
),
|
||||
)
|
||||
|
||||
@ -31,11 +31,12 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for BaiChuan text embedding model.
|
||||
"""
|
||||
api_base: str = 'http://api.baichuan-ai.com/v1/embeddings'
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
api_base: str = "http://api.baichuan-ai.com/v1/embeddings"
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
@ -45,28 +46,23 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
api_key = credentials['api_key']
|
||||
if model != 'baichuan-text-embedding':
|
||||
raise ValueError('Invalid model name')
|
||||
api_key = credentials["api_key"]
|
||||
if model != "baichuan-text-embedding":
|
||||
raise ValueError("Invalid model name")
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError('api_key is required')
|
||||
|
||||
raise CredentialsValidateFailedError("api_key is required")
|
||||
|
||||
# split into chunks of batch size 16
|
||||
chunks = []
|
||||
for i in range(0, len(texts), 16):
|
||||
chunks.append(texts[i:i + 16])
|
||||
chunks.append(texts[i : i + 16])
|
||||
|
||||
embeddings = []
|
||||
token_usage = 0
|
||||
|
||||
for chunk in chunks:
|
||||
# embedding chunk
|
||||
chunk_embeddings, chunk_usage = self.embedding(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
texts=chunk,
|
||||
user=user
|
||||
)
|
||||
chunk_embeddings, chunk_usage = self.embedding(model=model, api_key=api_key, texts=chunk, user=user)
|
||||
|
||||
embeddings.extend(chunk_embeddings)
|
||||
token_usage += chunk_usage
|
||||
@ -74,17 +70,14 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=embeddings,
|
||||
usage=self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=token_usage
|
||||
)
|
||||
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \
|
||||
-> tuple[list[list[float]], int]:
|
||||
|
||||
def embedding(
|
||||
self, model: str, api_key, texts: list[str], user: Optional[str] = None
|
||||
) -> tuple[list[list[float]], int]:
|
||||
"""
|
||||
Embed given texts
|
||||
|
||||
@ -95,56 +88,47 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
:return: embeddings result
|
||||
"""
|
||||
url = self.api_base
|
||||
headers = {
|
||||
'Authorization': 'Bearer ' + api_key,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||
|
||||
data = {
|
||||
'model': 'Baichuan-Text-Embedding',
|
||||
'input': texts
|
||||
}
|
||||
data = {"model": "Baichuan-Text-Embedding", "input": texts}
|
||||
|
||||
try:
|
||||
response = post(url, headers=headers, data=dumps(data))
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(str(e))
|
||||
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
resp = response.json()
|
||||
# try to parse error message
|
||||
err = resp['error']['code']
|
||||
msg = resp['error']['message']
|
||||
err = resp["error"]["code"]
|
||||
msg = resp["error"]["message"]
|
||||
except Exception as e:
|
||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
if err == 'invalid_api_key':
|
||||
if err == "invalid_api_key":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
elif err == 'insufficient_quota':
|
||||
elif err == "insufficient_quota":
|
||||
raise InsufficientAccountBalance(msg)
|
||||
elif err == 'invalid_authentication':
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif err and 'rate' in err:
|
||||
elif err == "invalid_authentication":
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif err and "rate" in err:
|
||||
raise RateLimitReachedError(msg)
|
||||
elif err and 'internal' in err:
|
||||
elif err and "internal" in err:
|
||||
raise InternalServerError(msg)
|
||||
elif err == 'api_key_empty':
|
||||
elif err == "api_key_empty":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
else:
|
||||
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
|
||||
|
||||
|
||||
try:
|
||||
resp = response.json()
|
||||
embeddings = resp['data']
|
||||
usage = resp['usage']
|
||||
embeddings = resp["data"]
|
||||
usage = resp["usage"]
|
||||
except Exception as e:
|
||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
return [
|
||||
data['embedding'] for data in embeddings
|
||||
], usage['total_tokens']
|
||||
|
||||
return [data["embedding"] for data in embeddings], usage["total_tokens"]
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
@ -170,32 +154,24 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except InvalidAPIKeyError:
|
||||
raise CredentialsValidateFailedError('Invalid api key')
|
||||
raise CredentialsValidateFailedError("Invalid api key")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitReachedError
|
||||
],
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [InternalServerError],
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
BadRequestError,
|
||||
KeyError
|
||||
]
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
}
|
||||
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
@ -207,10 +183,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
@ -221,7 +194,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BedrockProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
@ -19,13 +20,10 @@ class BedrockProvider(ModelProvider):
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `amazon.titan-text-lite-v1` model by default for validating credentials
|
||||
model_for_validation = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1')
|
||||
model_instance.validate_credentials(
|
||||
model=model_for_validation,
|
||||
credentials=credentials
|
||||
)
|
||||
model_for_validation = credentials.get("model_for_validation", "amazon.titan-text-lite-v1")
|
||||
model_instance.validate_credentials(model=model_for_validation, credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
||||
@ -45,36 +45,42 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
# please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
|
||||
# TODO There is invoke issue: context limit on Cohere Model, will add them after fixed.
|
||||
CONVERSE_API_ENABLED_MODEL_INFO=[
|
||||
{'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False},
|
||||
{'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False},
|
||||
{'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True},
|
||||
{'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False},
|
||||
{'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
|
||||
{'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
|
||||
{'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
|
||||
{'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True},
|
||||
{'prefix': 'cohere.command-r', 'support_system_prompts': True, 'support_tool_use': True},
|
||||
{'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False}
|
||||
CONVERSE_API_ENABLED_MODEL_INFO = [
|
||||
{"prefix": "anthropic.claude-v2", "support_system_prompts": True, "support_tool_use": False},
|
||||
{"prefix": "anthropic.claude-v1", "support_system_prompts": True, "support_tool_use": False},
|
||||
{"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
|
||||
{"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False},
|
||||
{"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False},
|
||||
{"prefix": "mistral.mixtral-8x7b-instruct", "support_system_prompts": False, "support_tool_use": False},
|
||||
{"prefix": "mistral.mistral-large", "support_system_prompts": True, "support_tool_use": True},
|
||||
{"prefix": "mistral.mistral-small", "support_system_prompts": True, "support_tool_use": True},
|
||||
{"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True},
|
||||
{"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _find_model_info(model_id):
|
||||
for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO:
|
||||
if model_id.startswith(model['prefix']):
|
||||
if model_id.startswith(model["prefix"]):
|
||||
return model
|
||||
logger.info(f"current model id: {model_id} did not support by Converse API")
|
||||
return None
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -88,17 +94,28 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
model_info= BedrockLargeLanguageModel._find_model_info(model)
|
||||
|
||||
model_info = BedrockLargeLanguageModel._find_model_info(model)
|
||||
if model_info:
|
||||
model_info['model'] = model
|
||||
model_info["model"] = model
|
||||
# invoke models via boto3 converse API
|
||||
return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
|
||||
return self._generate_with_converse(
|
||||
model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools
|
||||
)
|
||||
# invoke other models via boto3 client
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
|
||||
|
||||
def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
|
||||
def _generate_with_converse(
|
||||
self,
|
||||
model_info: dict,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model with converse API
|
||||
|
||||
@ -110,35 +127,39 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param stream: is stream response
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
bedrock_client = boto3.client(service_name='bedrock-runtime',
|
||||
aws_access_key_id=credentials.get("aws_access_key_id"),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key"),
|
||||
region_name=credentials["aws_region"])
|
||||
bedrock_client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=credentials.get("aws_access_key_id"),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key"),
|
||||
region_name=credentials["aws_region"],
|
||||
)
|
||||
|
||||
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
|
||||
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
|
||||
|
||||
parameters = {
|
||||
'modelId': model_info['model'],
|
||||
'messages': prompt_message_dicts,
|
||||
'inferenceConfig': inference_config,
|
||||
'additionalModelRequestFields': additional_model_fields,
|
||||
"modelId": model_info["model"],
|
||||
"messages": prompt_message_dicts,
|
||||
"inferenceConfig": inference_config,
|
||||
"additionalModelRequestFields": additional_model_fields,
|
||||
}
|
||||
|
||||
if model_info['support_system_prompts'] and system and len(system) > 0:
|
||||
parameters['system'] = system
|
||||
if model_info["support_system_prompts"] and system and len(system) > 0:
|
||||
parameters["system"] = system
|
||||
|
||||
if model_info['support_tool_use'] and tools:
|
||||
parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
|
||||
if model_info["support_tool_use"] and tools:
|
||||
parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools)
|
||||
try:
|
||||
if stream:
|
||||
response = bedrock_client.converse_stream(**parameters)
|
||||
return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages)
|
||||
return self._handle_converse_stream_response(
|
||||
model_info["model"], credentials, response, prompt_messages
|
||||
)
|
||||
else:
|
||||
response = bedrock_client.converse(**parameters)
|
||||
return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages)
|
||||
return self._handle_converse_response(model_info["model"], credentials, response, prompt_messages)
|
||||
except ClientError as ex:
|
||||
error_code = ex.response['Error']['Code']
|
||||
error_code = ex.response["Error"]["Code"]
|
||||
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
|
||||
raise self._map_client_to_invoke_error(error_code, full_error_msg)
|
||||
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
|
||||
@ -149,8 +170,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
except Exception as ex:
|
||||
raise InvokeError(str(ex))
|
||||
def _handle_converse_response(self, model: str, credentials: dict, response: dict,
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
|
||||
def _handle_converse_response(
|
||||
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
@ -160,36 +183,30 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param prompt_messages: prompt messages
|
||||
:return: full response chunk generator result
|
||||
"""
|
||||
response_content = response['output']['message']['content']
|
||||
response_content = response["output"]["message"]["content"]
|
||||
# transform assistant message to prompt message
|
||||
if response['stopReason'] == 'tool_use':
|
||||
if response["stopReason"] == "tool_use":
|
||||
tool_calls = []
|
||||
text, tool_use = self._extract_tool_use(response_content)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_use['toolUseId'],
|
||||
type='function',
|
||||
id=tool_use["toolUseId"],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_use['name'],
|
||||
arguments=json.dumps(tool_use['input'])
|
||||
)
|
||||
name=tool_use["name"], arguments=json.dumps(tool_use["input"])
|
||||
),
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=text,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=text, tool_calls=tool_calls)
|
||||
else:
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=response_content[0]['text']
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=response_content[0]["text"])
|
||||
|
||||
# calculate num tokens
|
||||
if response['usage']:
|
||||
if response["usage"]:
|
||||
# transform usage
|
||||
prompt_tokens = response['usage']['inputTokens']
|
||||
completion_tokens = response['usage']['outputTokens']
|
||||
prompt_tokens = response["usage"]["inputTokens"]
|
||||
completion_tokens = response["usage"]["outputTokens"]
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
@ -206,20 +223,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
return result
|
||||
|
||||
def _extract_tool_use(self, content:dict)-> tuple[str, dict]:
|
||||
def _extract_tool_use(self, content: dict) -> tuple[str, dict]:
|
||||
tool_use = {}
|
||||
text = ''
|
||||
text = ""
|
||||
for item in content:
|
||||
if 'toolUse' in item:
|
||||
tool_use = item['toolUse']
|
||||
elif 'text' in item:
|
||||
text = item['text']
|
||||
if "toolUse" in item:
|
||||
tool_use = item["toolUse"]
|
||||
elif "text" in item:
|
||||
text = item["text"]
|
||||
else:
|
||||
raise ValueError(f"Got unknown item: {item}")
|
||||
return text, tool_use
|
||||
|
||||
def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
|
||||
prompt_messages: list[PromptMessage], ) -> Generator:
|
||||
def _handle_converse_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm chat stream response
|
||||
|
||||
@ -231,7 +253,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
|
||||
try:
|
||||
full_assistant_content = ''
|
||||
full_assistant_content = ""
|
||||
return_model = None
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
@ -240,87 +262,85 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
tool_use = {}
|
||||
|
||||
for chunk in response['stream']:
|
||||
if 'messageStart' in chunk:
|
||||
for chunk in response["stream"]:
|
||||
if "messageStart" in chunk:
|
||||
return_model = model
|
||||
elif 'messageStop' in chunk:
|
||||
finish_reason = chunk['messageStop']['stopReason']
|
||||
elif 'contentBlockStart' in chunk:
|
||||
tool = chunk['contentBlockStart']['start']['toolUse']
|
||||
tool_use['toolUseId'] = tool['toolUseId']
|
||||
tool_use['name'] = tool['name']
|
||||
elif 'metadata' in chunk:
|
||||
input_tokens = chunk['metadata']['usage']['inputTokens']
|
||||
output_tokens = chunk['metadata']['usage']['outputTokens']
|
||||
elif "messageStop" in chunk:
|
||||
finish_reason = chunk["messageStop"]["stopReason"]
|
||||
elif "contentBlockStart" in chunk:
|
||||
tool = chunk["contentBlockStart"]["start"]["toolUse"]
|
||||
tool_use["toolUseId"] = tool["toolUseId"]
|
||||
tool_use["name"] = tool["name"]
|
||||
elif "metadata" in chunk:
|
||||
input_tokens = chunk["metadata"]["usage"]["inputTokens"]
|
||||
output_tokens = chunk["metadata"]["usage"]["outputTokens"]
|
||||
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
|
||||
yield LLMResultChunk(
|
||||
model=return_model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(
|
||||
content='',
|
||||
tool_calls=tool_calls
|
||||
),
|
||||
message=AssistantPromptMessage(content="", tool_calls=tool_calls),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
elif 'contentBlockDelta' in chunk:
|
||||
delta = chunk['contentBlockDelta']['delta']
|
||||
if 'text' in delta:
|
||||
chunk_text = delta['text'] if delta['text'] else ''
|
||||
elif "contentBlockDelta" in chunk:
|
||||
delta = chunk["contentBlockDelta"]["delta"]
|
||||
if "text" in delta:
|
||||
chunk_text = delta["text"] if delta["text"] else ""
|
||||
full_assistant_content += chunk_text
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=chunk_text if chunk_text else '',
|
||||
content=chunk_text if chunk_text else "",
|
||||
)
|
||||
index = chunk['contentBlockDelta']['contentBlockIndex']
|
||||
index = chunk["contentBlockDelta"]["contentBlockIndex"]
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index+1,
|
||||
index=index + 1,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
),
|
||||
)
|
||||
elif 'toolUse' in delta:
|
||||
if 'input' not in tool_use:
|
||||
tool_use['input'] = ''
|
||||
tool_use['input'] += delta['toolUse']['input']
|
||||
elif 'contentBlockStop' in chunk:
|
||||
if 'input' in tool_use:
|
||||
elif "toolUse" in delta:
|
||||
if "input" not in tool_use:
|
||||
tool_use["input"] = ""
|
||||
tool_use["input"] += delta["toolUse"]["input"]
|
||||
elif "contentBlockStop" in chunk:
|
||||
if "input" in tool_use:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_use['toolUseId'],
|
||||
type='function',
|
||||
id=tool_use["toolUseId"],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_use['name'],
|
||||
arguments=tool_use['input']
|
||||
)
|
||||
name=tool_use["name"], arguments=tool_use["input"]
|
||||
),
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
tool_use = {}
|
||||
|
||||
except Exception as ex:
|
||||
raise InvokeError(str(ex))
|
||||
|
||||
def _convert_converse_api_model_parameters(self, model_parameters: dict, stop: Optional[list[str]] = None) -> tuple[dict, dict]:

def _convert_converse_api_model_parameters(
self, model_parameters: dict, stop: Optional[list[str]] = None
) -> tuple[dict, dict]:
inference_config = {}
additional_model_fields = {}
if 'max_tokens' in model_parameters:
inference_config['maxTokens'] = model_parameters['max_tokens']
if "max_tokens" in model_parameters:
inference_config["maxTokens"] = model_parameters["max_tokens"]

if 'temperature' in model_parameters:
inference_config['temperature'] = model_parameters['temperature']

if 'top_p' in model_parameters:
inference_config['topP'] = model_parameters['temperature']
if "temperature" in model_parameters:
inference_config["temperature"] = model_parameters["temperature"]

if "top_p" in model_parameters:
inference_config["topP"] = model_parameters["temperature"]

if stop:
inference_config['stopSequences'] = stop

if 'top_k' in model_parameters:
additional_model_fields['top_k'] = model_parameters['top_k']

inference_config["stopSequences"] = stop

if "top_k" in model_parameters:
additional_model_fields["top_k"] = model_parameters["top_k"]

return inference_config, additional_model_fields
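Note: both sides of the hunk above read "topP" from model_parameters["temperature"], which looks like a carried-over slip rather than something the reformatting introduced. A minimal standalone sketch of the mapping, assuming the intended source for topP is "top_p":

from typing import Optional


def convert_converse_parameters(model_parameters: dict, stop: Optional[list[str]] = None) -> tuple[dict, dict]:
    inference_config: dict = {}
    additional_model_fields: dict = {}
    if "max_tokens" in model_parameters:
        inference_config["maxTokens"] = model_parameters["max_tokens"]
    if "temperature" in model_parameters:
        inference_config["temperature"] = model_parameters["temperature"]
    if "top_p" in model_parameters:
        inference_config["topP"] = model_parameters["top_p"]  # assumption; the hunk above reads "temperature" here
    if stop:
        inference_config["stopSequences"] = stop
    if "top_k" in model_parameters:
        # top_k is not a converse inferenceConfig field, so it travels separately
        additional_model_fields["top_k"] = model_parameters["top_k"]
    return inference_config, additional_model_fields


print(convert_converse_parameters({"max_tokens": 256, "temperature": 0.5, "top_p": 0.9, "top_k": 40}, stop=["\n\n"]))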
|
||||
|
||||
def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
|
||||
@ -332,7 +352,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
prompt_message_dicts = []
|
||||
for message in prompt_messages:
|
||||
if isinstance(message, SystemPromptMessage):
|
||||
message.content=message.content.strip()
|
||||
message.content = message.content.strip()
|
||||
system.append({"text": message.content})
|
||||
else:
|
||||
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
|
||||
@ -349,15 +369,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
"toolSpec": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": {
|
||||
"json": tool.parameters
|
||||
}
|
||||
"inputSchema": {"json": tool.parameters},
|
||||
}
|
||||
}
|
||||
)
|
||||
tool_config["tools"] = configs
|
||||
return tool_config
|
||||
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict
|
||||
@ -365,15 +383,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": [{'text': message.content}]}
|
||||
message_dict = {"role": "user", "content": [{"text": message.content}]}
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"text": message_content.data
|
||||
}
|
||||
sub_message_dict = {"text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
@ -384,7 +400,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
image_content = requests.get(url).content
|
||||
with Image.open(io.BytesIO(image_content)) as img:
|
||||
mime_type = f"image/{img.format.lower()}"
|
||||
base64_data = base64.b64encode(image_content).decode('utf-8')
|
||||
base64_data = base64.b64encode(image_content).decode("utf-8")
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
||||
else:
|
||||
@ -394,16 +410,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
image_content = base64.b64decode(base64_data)
|
||||
|
||||
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
|
||||
raise ValueError(f"Unsupported image type {mime_type}, "
|
||||
f"only support image/jpeg, image/png, image/gif, and image/webp")
|
||||
raise ValueError(
|
||||
f"Unsupported image type {mime_type}, "
|
||||
f"only support image/jpeg, image/png, image/gif, and image/webp"
|
||||
)
|
||||
|
||||
sub_message_dict = {
|
||||
"image": {
|
||||
"format": mime_type.replace('image/', ''),
|
||||
"source": {
|
||||
"bytes": image_content
|
||||
}
|
||||
}
|
||||
"image": {"format": mime_type.replace("image/", ""), "source": {"bytes": image_content}}
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
|
||||
@ -412,36 +425,46 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
if message.tool_calls:
|
||||
message_dict = {
|
||||
"role": "assistant", "content":[{
|
||||
"toolUse": {
|
||||
"toolUseId": message.tool_calls[0].id,
|
||||
"name": message.tool_calls[0].function.name,
|
||||
"input": json.loads(message.tool_calls[0].function.arguments)
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": message.tool_calls[0].id,
|
||||
"name": message.tool_calls[0].function.name,
|
||||
"input": json.loads(message.tool_calls[0].function.arguments),
|
||||
}
|
||||
}
|
||||
}]
|
||||
],
|
||||
}
|
||||
else:
|
||||
message_dict = {"role": "assistant", "content": [{'text': message.content}]}
|
||||
message_dict = {"role": "assistant", "content": [{"text": message.content}]}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = [{'text': message.content}]
|
||||
message_dict = [{"text": message.content}]
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"toolResult": {
|
||||
"toolUseId": message.tool_call_id,
|
||||
"content": [{"json": {"text": message.content}}]
|
||||
}
|
||||
}]
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message.tool_call_id,
|
||||
"content": [{"json": {"text": message.content}}],
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_dict
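Note: for reference, these are the converse-style dict shapes the conversion above produces for a few common message types; the values are made up for illustration.

# Illustrative only: converse-style message dicts (hypothetical values).
user_text = {"role": "user", "content": [{"text": "What is the weather in Berlin?"}]}

assistant_tool_call = {
    "role": "assistant",
    "content": [
        {
            "toolUse": {
                "toolUseId": "tool-1",  # hypothetical id
                "name": "get_weather",
                "input": {"city": "Berlin"},
            }
        }
    ],
}

tool_result = {
    "role": "user",
    "content": [
        {
            "toolResult": {
                "toolUseId": "tool-1",
                "content": [{"json": {"text": "18°C, cloudy"}}],
            }
        }
    ],
}

for message in (user_text, assistant_tool_call, tool_result):
    print(message["role"], [next(iter(block)) for block in message["content"]])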
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage] | str,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -451,15 +474,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param tools: tools for tool calling
|
||||
:return:md = genai.GenerativeModel(model)
|
||||
"""
|
||||
prefix = model.split('.')[0]
|
||||
model_name = model.split('.')[1]
|
||||
|
||||
prefix = model.split(".")[0]
|
||||
model_name = model.split(".")[1]
|
||||
|
||||
if isinstance(prompt_messages, str):
|
||||
prompt = prompt_messages
|
||||
else:
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)
|
||||
|
||||
|
||||
return self._get_num_tokens_by_gpt2(prompt)
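Note: the token count above is approximated with a GPT-2 tokenizer. A rough standalone equivalent, assuming the Hugging Face transformers package is available (dify's _get_num_tokens_by_gpt2 helper may differ in detail):

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")


def num_tokens_gpt2(prompt: str) -> int:
    # approximate count; not the exact numbers the provider bills for
    return len(tokenizer.encode(prompt))


print(num_tokens_gpt2("Human: ping\n\nAssistant:"))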
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
@ -482,24 +504,28 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
"topP": 0.9,
|
||||
"maxTokens": 32,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
ping_message = UserPromptMessage(content="ping")
|
||||
self._invoke(model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[ping_message],
|
||||
model_parameters=required_params,
|
||||
stream=False)
|
||||
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[ping_message],
|
||||
model_parameters=required_params,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
except ClientError as ex:
|
||||
error_code = ex.response['Error']['Code']
|
||||
error_code = ex.response["Error"]["Code"]
|
||||
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
|
||||
raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg)))
|
||||
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str:
|
||||
def _convert_one_message_to_text(
|
||||
self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Convert a single message to a string.
|
||||
|
||||
@ -514,7 +540,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
body = content
|
||||
if (isinstance(content, list)):
|
||||
if isinstance(content, list):
|
||||
body = "".join([c.data for c in content if c.type == PromptMessageContentType.TEXT])
|
||||
message_text = f"{human_prompt_prefix} {body} {human_prompt_postfix}"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
@ -528,7 +554,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return message_text
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str:
def _convert_messages_to_prompt(
self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None
) -> str:
"""
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models

@ -537,27 +565,31 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
if not messages:
return ''
return ""

messages = messages.copy()  # don't mutate the original list
if not isinstance(messages[-1], AssistantPromptMessage):
messages.append(AssistantPromptMessage(content=""))

text = "".join(
self._convert_one_message_to_text(message, model_prefix, model_name)
for message in messages
)
text = "".join(self._convert_one_message_to_text(message, model_prefix, model_name) for message in messages)

# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
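Note: a minimal sketch of this kind of chat-to-completion flattening, using plain (role, content) tuples; the "\n\nHuman:" / "\n\nAssistant:" prefixes follow the Anthropic completion convention and the exact strings dify uses per provider may differ.

# Minimal sketch: flatten a chat history into a completion-style prompt.
def messages_to_prompt(messages: list[tuple[str, str]]) -> str:
    parts = []
    for role, content in messages:
        prefix = "\n\nHuman:" if role == "user" else "\n\nAssistant:"
        parts.append(f"{prefix} {content}")
    if messages and messages[-1][0] != "assistant":
        parts.append("\n\nAssistant:")  # leave room for the model to answer
    return "".join(parts).rstrip()


print(messages_to_prompt([("user", "ping")]))
# '\n\nHuman: ping\n\nAssistant:'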
|
||||
|
||||
def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
|
||||
def _create_payload(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
):
|
||||
"""
|
||||
Create payload for bedrock api call depending on model provider
|
||||
"""
|
||||
payload = {}
|
||||
model_prefix = model.split('.')[0]
|
||||
model_name = model.split('.')[1]
|
||||
model_prefix = model.split(".")[0]
|
||||
model_name = model.split(".")[1]
|
||||
|
||||
if model_prefix == "ai21":
|
||||
payload["temperature"] = model_parameters.get("temperature")
|
||||
@ -571,21 +603,27 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")}
|
||||
if model_parameters.get("countPenalty"):
|
||||
payload["countPenalty"] = {model_parameters.get("countPenalty")}
|
||||
|
||||
|
||||
elif model_prefix == "cohere":
|
||||
payload = { **model_parameters }
|
||||
payload = {**model_parameters}
|
||||
payload["prompt"] = prompt_messages[0].content
|
||||
payload["stream"] = stream
|
||||
|
||||
|
||||
else:
|
||||
raise ValueError(f"Got unknown model prefix {model_prefix}")
|
||||
|
||||
|
||||
return payload
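Note: the cohere branch above is the simplest case of this payload assembly. A sketch with plain values in place of PromptMessage objects (values are made up):

model_parameters = {"temperature": 0.7, "max_tokens": 128}
prompt = "ping"
stream = True

payload = {**model_parameters}
payload["prompt"] = prompt
payload["stream"] = stream
print(payload)
# {'temperature': 0.7, 'max_tokens': 128, 'prompt': 'ping', 'stream': True}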
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -598,18 +636,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
client_config = Config(
|
||||
region_name=credentials["aws_region"]
|
||||
)
|
||||
client_config = Config(region_name=credentials["aws_region"])
|
||||
|
||||
runtime_client = boto3.client(
|
||||
service_name='bedrock-runtime',
|
||||
service_name="bedrock-runtime",
|
||||
config=client_config,
|
||||
aws_access_key_id=credentials.get("aws_access_key_id"),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key")
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key"),
|
||||
)
|
||||
|
||||
model_prefix = model.split('.')[0]
|
||||
model_prefix = model.split(".")[0]
|
||||
payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream)
|
||||
|
||||
# need workaround for ai21 models which doesn't support streaming
|
||||
@ -619,18 +655,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
invoke = runtime_client.invoke_model
|
||||
|
||||
try:
|
||||
body_jsonstr=json.dumps(payload)
|
||||
response = invoke(
|
||||
modelId=model,
|
||||
contentType="application/json",
|
||||
accept= "*/*",
|
||||
body=body_jsonstr
|
||||
)
|
||||
body_jsonstr = json.dumps(payload)
|
||||
response = invoke(modelId=model, contentType="application/json", accept="*/*", body=body_jsonstr)
|
||||
except ClientError as ex:
|
||||
error_code = ex.response['Error']['Code']
|
||||
error_code = ex.response["Error"]["Code"]
|
||||
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
|
||||
raise self._map_client_to_invoke_error(error_code, full_error_msg)
|
||||
|
||||
|
||||
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
|
||||
raise InvokeConnectionError(str(ex))
|
||||
|
||||
@ -639,15 +670,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
except Exception as ex:
|
||||
raise InvokeError(str(ex))
|
||||
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
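Note: a sketch of the invoke_model call the method above wraps; the region and model id are placeholders, and AWS credentials are assumed to be configured in the environment.

import json

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError

client = boto3.client("bedrock-runtime", config=Config(region_name="us-east-1"))

payload = {"prompt": "ping", "max_tokens": 16}  # shape depends on the provider prefix
try:
    response = client.invoke_model(
        modelId="cohere.command-text-v14",  # placeholder model id
        contentType="application/json",
        accept="*/*",
        body=json.dumps(payload),
    )
    body = json.loads(response["body"].read().decode("utf-8"))
except ClientError as ex:
    error_code = ex.response["Error"]["Code"]
    raise RuntimeError(f"{error_code}: {ex.response['Error']['Message']}")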
|
||||
|
||||
def _handle_generate_response(self, model: str, credentials: dict, response: dict,
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
|
||||
@ -657,7 +688,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response
|
||||
"""
|
||||
response_body = json.loads(response.get('body').read().decode('utf-8'))
|
||||
response_body = json.loads(response.get("body").read().decode("utf-8"))
|
||||
|
||||
finish_reason = response_body.get("error")
|
||||
|
||||
@ -665,25 +696,23 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
raise InvokeError(finish_reason)
|
||||
|
||||
# get output text and calculate num tokens based on model / provider
|
||||
model_prefix = model.split('.')[0]
|
||||
model_prefix = model.split(".")[0]
|
||||
|
||||
if model_prefix == "ai21":
|
||||
output = response_body.get('completions')[0].get('data').get('text')
|
||||
output = response_body.get("completions")[0].get("data").get("text")
|
||||
prompt_tokens = len(response_body.get("prompt").get("tokens"))
|
||||
completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
|
||||
|
||||
completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens"))
|
||||
|
||||
elif model_prefix == "cohere":
|
||||
output = response_body.get("generations")[0].get("text")
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
|
||||
|
||||
completion_tokens = self.get_num_tokens(model, credentials, output if output else "")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
||||
|
||||
# construct assistant message from output
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=output
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=output)
|
||||
|
||||
# calculate usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
@ -698,8 +727,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict,
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
@ -709,65 +739,59 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator result
|
||||
"""
|
||||
model_prefix = model.split('.')[0]
|
||||
model_prefix = model.split(".")[0]
|
||||
if model_prefix == "ai21":
|
||||
response_body = json.loads(response.get('body').read().decode('utf-8'))
|
||||
response_body = json.loads(response.get("body").read().decode("utf-8"))
|
||||
|
||||
content = response_body.get('completions')[0].get('data').get('text')
|
||||
finish_reason = response_body.get('completions')[0].get('finish_reason')
|
||||
content = response_body.get("completions")[0].get("data").get("text")
|
||||
finish_reason = response_body.get("completions")[0].get("finish_reason")
|
||||
|
||||
prompt_tokens = len(response_body.get("prompt").get("tokens"))
|
||||
completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
|
||||
completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens"))
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=content),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0, message=AssistantPromptMessage(content=content), finish_reason=finish_reason, usage=usage
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
stream = response.get('body')
|
||||
|
||||
stream = response.get("body")
|
||||
if not stream:
|
||||
raise InvokeError('No response body')
|
||||
|
||||
raise InvokeError("No response body")
|
||||
|
||||
index = -1
|
||||
for event in stream:
|
||||
chunk = event.get('chunk')
|
||||
|
||||
chunk = event.get("chunk")
|
||||
|
||||
if not chunk:
|
||||
exception_name = next(iter(event))
|
||||
full_ex_msg = f"{exception_name}: {event[exception_name]['message']}"
|
||||
raise self._map_client_to_invoke_error(exception_name, full_ex_msg)
|
||||
|
||||
payload = json.loads(chunk.get('bytes').decode())
|
||||
payload = json.loads(chunk.get("bytes").decode())
|
||||
|
||||
model_prefix = model.split('.')[0]
|
||||
model_prefix = model.split(".")[0]
|
||||
if model_prefix == "cohere":
|
||||
content_delta = payload.get("text")
|
||||
finish_reason = payload.get("finish_reason")
|
||||
|
||||
|
||||
else:
|
||||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response")
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content = content_delta if content_delta else '',
|
||||
content=content_delta if content_delta else "",
|
||||
)
|
||||
index += 1
|
||||
|
||||
|
||||
if not finish_reason:
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message
|
||||
)
|
||||
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
|
||||
)
|
||||
|
||||
else:
|
||||
@ -777,18 +801,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
index=index, message=assistant_prompt_message, finish_reason=finish_reason, usage=usage
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
@ -804,9 +825,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
InvokeServerUnavailableError: [],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [],
|
||||
InvokeBadRequestError: []
|
||||
InvokeBadRequestError: [],
|
||||
}
|
||||
|
||||
|
||||
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
|
||||
"""
|
||||
Map client error to invoke error
|
||||
@ -822,7 +843,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
return InvokeBadRequestError(error_msg)
|
||||
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
|
||||
return InvokeRateLimitError(error_msg)
|
||||
elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
|
||||
elif error_code in [
|
||||
"ModelTimeoutException",
|
||||
"ModelErrorException",
|
||||
"InternalServerException",
|
||||
"ModelNotReadyException",
|
||||
]:
|
||||
return InvokeServerUnavailableError(error_msg)
|
||||
elif error_code == "ModelStreamErrorException":
|
||||
return InvokeConnectionError(error_msg)
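Note: a standalone sketch of the tail of the mapping above, with stand-in exception classes so it runs outside dify; the real method handles more codes before these branches.

class InvokeError(Exception): ...
class InvokeRateLimitError(InvokeError): ...
class InvokeServerUnavailableError(InvokeError): ...
class InvokeConnectionError(InvokeError): ...


def map_client_error(error_code: str, error_msg: str) -> InvokeError:
    if error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
        return InvokeRateLimitError(error_msg)
    if error_code in [
        "ModelTimeoutException",
        "ModelErrorException",
        "InternalServerException",
        "ModelNotReadyException",
    ]:
        return InvokeServerUnavailableError(error_msg)
    if error_code == "ModelStreamErrorException":
        return InvokeConnectionError(error_msg)
    return InvokeError(error_msg)


print(type(map_client_error("ThrottlingException", "slow down")).__name__)  # InvokeRateLimitError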
|
||||
|
||||
@ -27,12 +27,11 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
@ -42,67 +41,56 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
client_config = Config(
|
||||
region_name=credentials["aws_region"]
|
||||
)
|
||||
client_config = Config(region_name=credentials["aws_region"])
|
||||
|
||||
bedrock_runtime = boto3.client(
|
||||
service_name='bedrock-runtime',
|
||||
service_name="bedrock-runtime",
|
||||
config=client_config,
|
||||
aws_access_key_id=credentials.get("aws_access_key_id"),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key")
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key"),
|
||||
)
|
||||
|
||||
embeddings = []
|
||||
token_usage = 0
|
||||
|
||||
model_prefix = model.split('.')[0]
|
||||
|
||||
if model_prefix == "amazon" :
|
||||
|
||||
model_prefix = model.split(".")[0]
|
||||
|
||||
if model_prefix == "amazon":
|
||||
for text in texts:
|
||||
body = {
|
||||
"inputText": text,
|
||||
"inputText": text,
|
||||
}
|
||||
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
|
||||
embeddings.extend([response_body.get('embedding')])
|
||||
token_usage += response_body.get('inputTextTokenCount')
|
||||
logger.warning(f'Total Tokens: {token_usage}')
|
||||
embeddings.extend([response_body.get("embedding")])
|
||||
token_usage += response_body.get("inputTextTokenCount")
|
||||
logger.warning(f"Total Tokens: {token_usage}")
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=embeddings,
|
||||
usage=self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=token_usage
|
||||
)
|
||||
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
|
||||
)
|
||||
return result
|
||||
|
||||
if model_prefix == "cohere" :
|
||||
input_type = 'search_document' if len(texts) > 1 else 'search_query'
|
||||
if model_prefix == "cohere":
|
||||
input_type = "search_document" if len(texts) > 1 else "search_query"
|
||||
for text in texts:
|
||||
body = {
|
||||
"texts": [text],
|
||||
"input_type": input_type,
|
||||
"texts": [text],
|
||||
"input_type": input_type,
|
||||
}
|
||||
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
|
||||
embeddings.extend(response_body.get('embeddings'))
|
||||
embeddings.extend(response_body.get("embeddings"))
|
||||
token_usage += len(text)
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=embeddings,
|
||||
usage=self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=token_usage
|
||||
)
|
||||
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
|
||||
)
|
||||
return result
|
||||
|
||||
#others
|
||||
# others
|
||||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
||||
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -125,7 +113,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
@ -141,19 +129,25 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
InvokeServerUnavailableError: [],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [],
|
||||
InvokeBadRequestError: []
|
||||
InvokeBadRequestError: [],
|
||||
}
|
||||
|
||||
def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
|
||||
|
||||
def _create_payload(
|
||||
self,
|
||||
model_prefix: str,
|
||||
texts: list[str],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
):
|
||||
"""
|
||||
Create payload for bedrock api call depending on model provider
|
||||
"""
|
||||
payload = {}
|
||||
|
||||
if model_prefix == "amazon":
|
||||
payload['inputText'] = texts
|
||||
payload["inputText"] = texts
|
||||
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
@ -165,10 +159,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
@ -179,7 +170,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
@ -199,31 +190,37 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
return InvokeBadRequestError(error_msg)
|
||||
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
|
||||
return InvokeRateLimitError(error_msg)
|
||||
elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
|
||||
elif error_code in [
|
||||
"ModelTimeoutException",
|
||||
"ModelErrorException",
|
||||
"InternalServerException",
|
||||
"ModelNotReadyException",
|
||||
]:
|
||||
return InvokeServerUnavailableError(error_msg)
|
||||
elif error_code == "ModelStreamErrorException":
|
||||
return InvokeConnectionError(error_msg)
|
||||
|
||||
return InvokeError(error_msg)
|
||||
|
||||
|
||||
def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ):
accept = 'application/json'
content_type = 'application/json'
def _invoke_bedrock_embedding(
self,
model: str,
bedrock_runtime,
body: dict,
):
accept = "application/json"
content_type = "application/json"
try:
response = bedrock_runtime.invoke_model(
body=json.dumps(body),
modelId=model,
accept=accept,
contentType=content_type
body=json.dumps(body), modelId=model, accept=accept, contentType=content_type
)
response_body = json.loads(response.get('body').read().decode('utf-8'))
response_body = json.loads(response.get("body").read().decode("utf-8"))
return response_body
except ClientError as ex:
error_code = ex.response['Error']['Code']
error_code = ex.response["Error"]["Code"]
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
raise self._map_client_to_invoke_error(error_code, full_error_msg)

except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
raise InvokeConnectionError(str(ex))
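Note: a sketch of the embedding invocation path above; the model id and region are placeholders and AWS credentials are assumed to be configured elsewhere.

import json

import boto3
from botocore.config import Config

runtime = boto3.client("bedrock-runtime", config=Config(region_name="us-east-1"))

response = runtime.invoke_model(
    body=json.dumps({"inputText": "hello world"}),
    modelId="amazon.titan-embed-text-v1",  # placeholder model id
    accept="application/json",
    contentType="application/json",
)
response_body = json.loads(response["body"].read().decode("utf-8"))
print(len(response_body.get("embedding", [])), response_body.get("inputTextTokenCount"))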
|
||||
|
||||
|
||||
@ -20,12 +20,9 @@ class ChatGLMProvider(ModelProvider):
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `chatglm3-6b` model for validate,
|
||||
model_instance.validate_credentials(
|
||||
model='chatglm3-6b',
|
||||
credentials=credentials
|
||||
)
|
||||
model_instance.validate_credentials(model="chatglm3-6b", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
||||
@ -43,12 +43,19 @@ from core.model_runtime.utils import helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
|
||||
stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -71,11 +78,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None) -> int:
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -96,11 +108,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, prompt_messages=[
|
||||
UserPromptMessage(content="ping"),
|
||||
], model_parameters={
|
||||
"max_tokens": 16,
|
||||
})
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[
|
||||
UserPromptMessage(content="ping"),
|
||||
],
|
||||
model_parameters={
|
||||
"max_tokens": 16,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(str(e))
|
||||
|
||||
@ -124,24 +141,24 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
ConflictError,
|
||||
NotFoundError,
|
||||
UnprocessableEntityError,
|
||||
PermissionDeniedError
|
||||
PermissionDeniedError,
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
AuthenticationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
ValueError
|
||||
]
|
||||
InvokeRateLimitError: [RateLimitError],
|
||||
InvokeAuthorizationError: [AuthenticationError],
|
||||
InvokeBadRequestError: [ValueError],
|
||||
}
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
|
||||
stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -163,35 +180,31 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if stop:
|
||||
extra_model_kwargs['stop'] = stop
|
||||
extra_model_kwargs["stop"] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs['user'] = user
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
if tools and len(tools) > 0:
|
||||
extra_model_kwargs['functions'] = [
|
||||
helper.dump_model(tool) for tool in tools
|
||||
]
|
||||
extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools]
|
||||
|
||||
result = client.chat.completions.create(
|
||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
||||
model=model,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(
|
||||
model=model, credentials=credentials, response=result, tools=tools,
|
||||
prompt_messages=prompt_messages
|
||||
model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
|
||||
return self._handle_chat_generate_response(
|
||||
model=model, credentials=credentials, response=result, tools=tools,
|
||||
prompt_messages=prompt_messages
|
||||
model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
|
||||
def _check_chatglm_parameters(self, model: str, model_parameters: dict, tools: list[PromptMessageTool]) -> None:
|
||||
if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0:
|
||||
raise InvokeBadRequestError("ChatGLM2 does not support function calling")
|
||||
@ -212,7 +225,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
if message.tool_calls and len(message.tool_calls) > 0:
|
||||
message_dict["function_call"] = {
|
||||
"name": message.tool_calls[0].function.name,
|
||||
"arguments": message.tool_calls[0].function.arguments
|
||||
"arguments": message.tool_calls[0].function.arguments,
|
||||
}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
@ -223,12 +236,12 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
message_dict = {"role": "function", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Unknown message type {type(message)}")
|
||||
|
||||
|
||||
return message_dict
|
||||
|
||||
def _extract_response_tool_calls(self,
|
||||
response_function_calls: list[FunctionCall]) \
|
||||
-> list[AssistantPromptMessage.ToolCall]:
|
||||
|
||||
def _extract_response_tool_calls(
|
||||
self, response_function_calls: list[FunctionCall]
|
||||
) -> list[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
@ -239,19 +252,14 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
if response_function_calls:
|
||||
for response_tool_call in response_function_calls:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.name,
|
||||
arguments=response_tool_call.arguments
|
||||
name=response_tool_call.name, arguments=response_tool_call.arguments
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=0,
|
||||
type='function',
|
||||
function=function
|
||||
)
|
||||
tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _to_client_kwargs(self, credentials: dict) -> dict:
"""
Convert invoke kwargs to client kwargs
@ -265,17 +273,20 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1",
"base_url": str(URL(credentials['api_base']) / 'v1')
"base_url": str(URL(credentials["api_base"]) / "v1"),
}

return client_kwargs
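Note: a sketch of how kwargs like these are typically consumed, assuming an OpenAI-compatible client in front of a local ChatGLM server; the endpoint URL is a placeholder.

from httpx import Timeout
from openai import OpenAI
from yarl import URL

credentials = {"api_base": "http://127.0.0.1:8000"}  # placeholder endpoint

client_kwargs = {
    "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
    "api_key": "1",  # dummy key, as in the hunk above
    "base_url": str(URL(credentials["api_base"]) / "v1"),
}

client = OpenAI(**client_kwargs)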
|
||||
|
||||
def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) \
|
||||
-> Generator:
|
||||
|
||||
full_response = ''
|
||||
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> Generator:
|
||||
full_response = ""
|
||||
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
@ -283,9 +294,9 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
delta = chunk.choices[0]
|
||||
|
||||
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
|
||||
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""):
|
||||
continue
|
||||
|
||||
|
||||
# check if there is a tool call in the response
|
||||
function_calls = None
|
||||
if delta.delta.function_call:
|
||||
@ -295,23 +306,25 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta.delta.content if delta.delta.content else '',
|
||||
tool_calls=assistant_message_tool_calls
|
||||
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
|
||||
)
|
||||
|
||||
if delta.finish_reason is not None:
|
||||
# temp_assistant_prompt_message is used to calculate usage
|
||||
temp_assistant_prompt_message = AssistantPromptMessage(
|
||||
content=full_response,
|
||||
tool_calls=assistant_message_tool_calls
|
||||
content=full_response, tool_calls=assistant_message_tool_calls
|
||||
)
|
||||
|
||||
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
|
||||
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
|
||||
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
@ -320,7 +333,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=delta.finish_reason,
|
||||
usage=usage
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@ -335,11 +348,15 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
full_response += delta.delta.content
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) \
|
||||
-> LLMResult:
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: ChatCompletion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
@ -359,15 +376,14 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else [])
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_message.content,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
|
||||
|
||||
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
|
||||
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
|
||||
usage = self._calc_response_usage(
|
||||
model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
|
||||
)
|
||||
|
||||
response = LLMResult(
|
||||
model=model,
|
||||
@ -378,7 +394,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def _num_tokens_from_string(self, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Calculate num tokens for text completion model with tiktoken package.
|
||||
@ -395,17 +411,19 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_from_messages(self, messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def _num_tokens_from_messages(
|
||||
self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer.
|
||||
|
||||
it's too complex to calculate num tokens for chatglm2 and chatglm3 with ChatGLM tokenizer,
|
||||
As a temporary solution we use GPT2 tokenizer instead.
|
||||
|
||||
"""
|
||||
|
||||
def tokens(text: str):
|
||||
return self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
num_tokens = 0
|
||||
@ -414,10 +432,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if isinstance(value, list):
|
||||
text = ''
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item['type'] == 'text':
|
||||
text += item['text']
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
value = text
|
||||
|
||||
if key == "function_call":
|
||||
@ -452,36 +470,37 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
|
||||
def tokens(text: str):
|
||||
return self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
num_tokens = 0
|
||||
for tool in tools:
|
||||
# calculate num tokens for function object
|
||||
num_tokens += tokens('name')
|
||||
num_tokens += tokens("name")
|
||||
num_tokens += tokens(tool.name)
|
||||
num_tokens += tokens('description')
|
||||
num_tokens += tokens("description")
|
||||
num_tokens += tokens(tool.description)
|
||||
parameters = tool.parameters
|
||||
num_tokens += tokens('parameters')
|
||||
num_tokens += tokens('type')
|
||||
num_tokens += tokens("parameters")
|
||||
num_tokens += tokens("type")
|
||||
num_tokens += tokens(parameters.get("type"))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += tokens('properties')
|
||||
for key, value in parameters.get('properties').items():
|
||||
if "properties" in parameters:
|
||||
num_tokens += tokens("properties")
|
||||
for key, value in parameters.get("properties").items():
|
||||
num_tokens += tokens(key)
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += tokens(field_key)
|
||||
if field_key == 'enum':
|
||||
if field_key == "enum":
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += tokens(enum_field)
|
||||
else:
|
||||
num_tokens += tokens(field_key)
|
||||
num_tokens += tokens(str(field_value))
|
||||
if 'required' in parameters:
|
||||
num_tokens += tokens('required')
|
||||
for required_field in parameters['required']:
|
||||
if "required" in parameters:
|
||||
num_tokens += tokens("required")
|
||||
for required_field in parameters["required"]:
|
||||
num_tokens += 3
|
||||
num_tokens += tokens(required_field)
|
||||
|
||||
|
||||
@ -20,12 +20,9 @@ class CohereProvider(ModelProvider):
|
||||
model_instance = self.get_model_instance(ModelType.RERANK)
|
||||
|
||||
# Use `rerank-english-v2.0` model for validate,
|
||||
model_instance.validate_credentials(
|
||||
model='rerank-english-v2.0',
|
||||
credentials=credentials
|
||||
)
|
||||
model_instance.validate_credentials(model="rerank-english-v2.0", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
||||
@ -55,11 +55,17 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
Model class for Cohere large language model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -85,7 +91,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
return self._generate(
|
||||
@ -95,11 +101,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -136,30 +147,37 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
self._chat_generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[UserPromptMessage(content='ping')],
|
||||
prompt_messages=[UserPromptMessage(content="ping")],
|
||||
model_parameters={
|
||||
'max_tokens': 20,
|
||||
'temperature': 0,
|
||||
"max_tokens": 20,
|
||||
"temperature": 0,
|
||||
},
|
||||
stream=False
|
||||
stream=False,
|
||||
)
|
||||
else:
|
||||
self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[UserPromptMessage(content='ping')],
|
||||
prompt_messages=[UserPromptMessage(content="ping")],
|
||||
model_parameters={
|
||||
'max_tokens': 20,
|
||||
'temperature': 0,
|
||||
"max_tokens": 20,
|
||||
"temperature": 0,
|
||||
},
|
||||
stream=False
|
||||
stream=False,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke llm model
|
||||
|
||||
@ -173,17 +191,17 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
|
||||
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
|
||||
|
||||
if stop:
|
||||
model_parameters['end_sequences'] = stop
|
||||
model_parameters["end_sequences"] = stop
|
||||
|
||||
if stream:
|
||||
response = client.generate_stream(
|
||||
prompt=prompt_messages[0].content,
|
||||
model=model,
|
||||
**model_parameters,
|
||||
request_options=RequestOptions(max_retries=0)
|
||||
request_options=RequestOptions(max_retries=0),
|
||||
)
|
||||
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
@ -192,14 +210,14 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
prompt=prompt_messages[0].content,
|
||||
model=model,
|
||||
**model_parameters,
|
||||
request_options=RequestOptions(max_retries=0)
|
||||
request_options=RequestOptions(max_retries=0),
|
||||
)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(self, model: str, credentials: dict, response: Generation,
prompt_messages: list[PromptMessage]) \
-> LLMResult:
def _handle_generate_response(
self, model: str, credentials: dict, response: Generation, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response

@ -212,9 +230,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
assistant_text = response.generations[0].text

# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)

# calculate num tokens
prompt_tokens = int(response.meta.billed_units.input_tokens)
@ -225,17 +241,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):

# transform response
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage
)

return response
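Note: a sketch of the non-streaming generate path the handler above consumes; the API key and model name are placeholders.

import cohere

client = cohere.Client("your-api-key")  # placeholder key

response = client.generate(prompt="ping", model="command", max_tokens=20, temperature=0)

text = response.generations[0].text
prompt_tokens = int(response.meta.billed_units.input_tokens)
completion_tokens = int(response.meta.billed_units.output_tokens)
print(text, prompt_tokens, completion_tokens)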
|
||||
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict,
|
||||
response: Iterator[GenerateStreamedResponse],
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
def _handle_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Iterator[GenerateStreamedResponse],
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
@ -245,7 +262,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
index = 1
|
||||
full_assistant_content = ''
|
||||
full_assistant_content = ""
|
||||
for chunk in response:
|
||||
if isinstance(chunk, GenerateStreamedResponse_TextGeneration):
|
||||
chunk = cast(GenerateStreamedResponse_TextGeneration, chunk)
|
||||
@ -255,9 +272,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=text
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=text)
|
||||
|
||||
full_assistant_content += text
|
||||
|
||||
@ -267,7 +282,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
index += 1
|
||||
@ -277,9 +292,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
|
||||
completion_tokens = self._num_tokens_from_messages(
|
||||
model,
|
||||
credentials,
|
||||
[AssistantPromptMessage(content=full_assistant_content)]
|
||||
model, credentials, [AssistantPromptMessage(content=full_assistant_content)]
|
||||
)
|
||||
|
||||
# transform usage
|
||||
@ -290,20 +303,27 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(content=''),
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason=chunk.finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
break
|
||||
elif isinstance(chunk, GenerateStreamedResponse_StreamError):
chunk = cast(GenerateStreamedResponse_StreamError, chunk)
raise InvokeBadRequestError(chunk.err)
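A small sketch of how a caller might drain the chunk generator built above; the Delta/Chunk dataclasses are local stand-ins for LLMResultChunkDelta/LLMResultChunk, assumed here only for illustration.

from dataclasses import dataclass
from typing import Iterator, Optional

@dataclass
class Delta:
    index: int
    content: str
    finish_reason: Optional[str] = None

@dataclass
class Chunk:
    delta: Delta

def collect_stream(chunks: Iterator[Chunk]) -> tuple[str, Optional[str]]:
    # Accumulate streamed text and remember the finish reason from the last chunk.
    text, finish_reason = "", None
    for chunk in chunks:
        text += chunk.delta.content
        if chunk.delta.finish_reason is not None:
            finish_reason = chunk.delta.finish_reason
    return text, finish_reason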
|
||||
|
||||
def _chat_generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
def _chat_generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke llm chat model
|
||||
|
||||
@ -318,27 +338,28 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
|
||||
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
|
||||
|
||||
if stop:
|
||||
model_parameters['stop_sequences'] = stop
|
||||
model_parameters["stop_sequences"] = stop
|
||||
|
||||
if tools:
|
||||
if len(tools) == 1:
|
||||
raise ValueError("Cohere tool call requires at least two tools to be specified.")
|
||||
|
||||
model_parameters['tools'] = self._convert_tools(tools)
|
||||
model_parameters["tools"] = self._convert_tools(tools)
|
||||
|
||||
message, chat_histories, tool_results \
|
||||
= self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
|
||||
message, chat_histories, tool_results = self._convert_prompt_messages_to_message_and_chat_histories(
|
||||
prompt_messages
|
||||
)
|
||||
|
||||
if tool_results:
|
||||
model_parameters['tool_results'] = tool_results
|
||||
model_parameters["tool_results"] = tool_results
|
||||
|
||||
# chat model
|
||||
real_model = model
|
||||
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
|
||||
real_model = model.removesuffix('-chat')
|
||||
real_model = model.removesuffix("-chat")
|
||||
|
||||
if stream:
|
||||
response = client.chat_stream(
|
||||
@ -346,7 +367,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
chat_history=chat_histories,
|
||||
model=real_model,
|
||||
**model_parameters,
|
||||
request_options=RequestOptions(max_retries=0)
|
||||
request_options=RequestOptions(max_retries=0),
|
||||
)
|
||||
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
@ -356,14 +377,14 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
chat_history=chat_histories,
|
||||
model=real_model,
|
||||
**model_parameters,
|
||||
request_options=RequestOptions(max_retries=0)
|
||||
request_options=RequestOptions(max_retries=0),
|
||||
)
|
||||
|
||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse,
|
||||
prompt_messages: list[PromptMessage]) \
|
||||
-> LLMResult:
|
||||
def _handle_chat_generate_response(
|
||||
self, model: str, credentials: dict, response: NonStreamedChatResponse, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
@ -380,19 +401,15 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
for cohere_tool_call in response.tool_calls:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=cohere_tool_call.name,
|
||||
type='function',
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=cohere_tool_call.name,
|
||||
arguments=json.dumps(cohere_tool_call.parameters)
|
||||
)
|
||||
name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters)
|
||||
),
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_text,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_text, tool_calls=tool_calls)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
|
||||
@ -403,17 +420,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
# transform response
|
||||
response = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage
|
||||
model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
|
||||
response: Iterator[StreamedChatResponse],
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Iterator[StreamedChatResponse],
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm chat stream response
|
||||
|
||||
@ -423,17 +441,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
|
||||
def final_response(full_text: str,
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall],
|
||||
index: int,
|
||||
finish_reason: Optional[str] = None) -> LLMResultChunk:
|
||||
def final_response(
|
||||
full_text: str,
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall],
|
||||
index: int,
|
||||
finish_reason: Optional[str] = None,
|
||||
) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
|
||||
|
||||
full_assistant_prompt_message = AssistantPromptMessage(
|
||||
content=full_text,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
full_assistant_prompt_message = AssistantPromptMessage(content=full_text, tool_calls=tool_calls)
|
||||
completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
@ -444,14 +461,14 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(content='', tool_calls=tool_calls),
|
||||
message=AssistantPromptMessage(content="", tool_calls=tool_calls),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
|
||||
index = 1
|
||||
full_assistant_content = ''
|
||||
full_assistant_content = ""
|
||||
tool_calls = []
|
||||
for chunk in response:
|
||||
if isinstance(chunk, StreamedChatResponse_TextGeneration):
|
||||
@ -462,9 +479,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=text
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=text)
|
||||
|
||||
full_assistant_content += text
|
||||
|
||||
@ -474,7 +489,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
index += 1
|
||||
@ -484,11 +499,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
for cohere_tool_call in chunk.tool_calls:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=cohere_tool_call.name,
|
||||
type='function',
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=cohere_tool_call.name,
|
||||
arguments=json.dumps(cohere_tool_call.parameters)
|
||||
)
|
||||
name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters)
|
||||
),
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
elif isinstance(chunk, StreamedChatResponse_StreamEnd):
|
||||
@ -496,8 +510,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason)
|
||||
index += 1
|
||||
|
||||
def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
|
||||
-> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
|
||||
def _convert_prompt_messages_to_message_and_chat_histories(
|
||||
self, prompt_messages: list[PromptMessage]
|
||||
) -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
|
||||
"""
|
||||
Convert prompt messages to message and chat histories
|
||||
:param prompt_messages: prompt messages
|
||||
@ -510,13 +525,14 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
prompt_message = cast(AssistantPromptMessage, prompt_message)
|
||||
if prompt_message.tool_calls:
|
||||
for tool_call in prompt_message.tool_calls:
|
||||
latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem(
|
||||
call=ToolCall(
|
||||
name=tool_call.function.name,
|
||||
parameters=json.loads(tool_call.function.arguments)
|
||||
),
|
||||
outputs=[]
|
||||
))
|
||||
latest_tool_call_n_outputs.append(
|
||||
ChatStreamRequestToolResultsItem(
|
||||
call=ToolCall(
|
||||
name=tool_call.function.name, parameters=json.loads(tool_call.function.arguments)
|
||||
),
|
||||
outputs=[],
|
||||
)
|
||||
)
|
||||
else:
|
||||
cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
|
||||
if cohere_prompt_message:
|
||||
@ -529,12 +545,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
if tool_call_n_outputs.call.name == prompt_message.tool_call_id:
|
||||
latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem(
|
||||
call=ToolCall(
|
||||
name=tool_call_n_outputs.call.name,
|
||||
parameters=tool_call_n_outputs.call.parameters
|
||||
name=tool_call_n_outputs.call.name, parameters=tool_call_n_outputs.call.parameters
|
||||
),
|
||||
outputs=[{
|
||||
"result": prompt_message.content
|
||||
}]
|
||||
outputs=[{"result": prompt_message.content}],
|
||||
)
|
||||
break
|
||||
i += 1
|
||||
@ -556,7 +569,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
latest_message = chat_histories.pop()
|
||||
message = latest_message.message
|
||||
else:
|
||||
raise ValueError('Prompt messages is empty')
|
||||
raise ValueError("Prompt messages is empty")
|
||||
|
||||
return message, chat_histories, latest_tool_call_n_outputs
|
||||
|
||||
@ -569,7 +582,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
if isinstance(message.content, str):
|
||||
chat_message = ChatMessage(role="USER", message=message.content)
|
||||
else:
|
||||
sub_message_text = ''
|
||||
sub_message_text = ""
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
@ -597,8 +610,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
cohere_tools = []
|
||||
for tool in tools:
|
||||
properties = tool.parameters['properties']
|
||||
required_properties = tool.parameters['required']
|
||||
properties = tool.parameters["properties"]
|
||||
required_properties = tool.parameters["required"]
|
||||
|
||||
parameter_definitions = {}
|
||||
for p_key, p_val in properties.items():
|
||||
@ -606,21 +619,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
if p_key in required_properties:
|
||||
required = True
|
||||
|
||||
desc = p_val['description']
|
||||
if 'enum' in p_val:
|
||||
desc += (f"; Only accepts one of the following predefined options: "
|
||||
f"[{', '.join(p_val['enum'])}]")
|
||||
desc = p_val["description"]
|
||||
if "enum" in p_val:
|
||||
desc += f"; Only accepts one of the following predefined options: " f"[{', '.join(p_val['enum'])}]"
|
||||
|
||||
parameter_definitions[p_key] = ToolParameterDefinitionsValue(
|
||||
description=desc,
|
||||
type=p_val['type'],
|
||||
required=required
|
||||
description=desc, type=p_val["type"], required=required
|
||||
)
|
||||
|
||||
cohere_tool = Tool(
name=tool.name,
description=tool.description,
parameter_definitions=parameter_definitions
name=tool.name, description=tool.description, parameter_definitions=parameter_definitions
)

cohere_tools.append(cohere_tool)
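A sketch of the schema translation performed above, with plain dicts standing in for Cohere's ToolParameterDefinitionsValue; the function name and sample input are assumptions for illustration.

def to_parameter_definitions(parameters: dict) -> dict:
    # `parameters` follows the JSON-schema style used by PromptMessageTool.
    properties = parameters["properties"]
    required_properties = parameters.get("required", [])
    definitions = {}
    for name, spec in properties.items():
        desc = spec["description"]
        if "enum" in spec:
            desc += f"; Only accepts one of the following predefined options: [{', '.join(spec['enum'])}]"
        definitions[name] = {
            "description": desc,
            "type": spec["type"],
            "required": name in required_properties,
        }
    return definitions

# to_parameter_definitions({"properties": {"unit": {"type": "string",
#     "description": "Temperature unit", "enum": ["celsius", "fahrenheit"]}},
#     "required": ["unit"]})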
@ -637,12 +645,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
:return: number of tokens
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
|
||||
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
|
||||
|
||||
response = client.tokenize(
|
||||
text=text,
|
||||
model=model
|
||||
)
|
||||
response = client.tokenize(text=text, model=model)
|
||||
|
||||
return len(response.tokens)
|
||||
|
||||
@ -658,30 +663,30 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
real_model = model
|
||||
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
|
||||
real_model = model.removesuffix('-chat')
|
||||
real_model = model.removesuffix("-chat")
|
||||
|
||||
return self._num_tokens_from_string(real_model, credentials, message_str)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
Cohere supports fine-tuning of their models. This method returns the schema of the base model
|
||||
but renamed to the fine-tuned model name.
|
||||
Cohere supports fine-tuning of their models. This method returns the schema of the base model
|
||||
but renamed to the fine-tuned model name.
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
|
||||
:return: model schema
|
||||
:return: model schema
|
||||
"""
|
||||
# get model schema
|
||||
models = self.predefined_models()
|
||||
model_map = {model.model: model for model in models}
|
||||
|
||||
mode = credentials.get('mode')
|
||||
mode = credentials.get("mode")
|
||||
|
||||
if mode == 'chat':
|
||||
base_model_schema = model_map['command-light-chat']
|
||||
if mode == "chat":
|
||||
base_model_schema = model_map["command-light-chat"]
|
||||
else:
|
||||
base_model_schema = model_map['command-light']
|
||||
base_model_schema = model_map["command-light"]
|
||||
|
||||
base_model_schema = cast(AIModelEntity, base_model_schema)
|
||||
|
||||
@ -691,16 +696,13 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
zh_Hans=model,
|
||||
en_US=model
|
||||
),
|
||||
label=I18nObject(zh_Hans=model, en_US=model),
|
||||
model_type=ModelType.LLM,
|
||||
features=list(base_model_schema_features),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties=dict(base_model_schema_model_properties.items()),
|
||||
parameter_rules=list(base_model_schema_parameters_rules),
|
||||
pricing=base_model_schema.pricing
|
||||
pricing=base_model_schema.pricing,
|
||||
)
|
||||
|
||||
return entity
|
||||
@ -716,22 +718,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
cohere.errors.service_unavailable_error.ServiceUnavailableError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
cohere.errors.internal_server_error.InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
cohere.errors.too_many_requests_error.TooManyRequestsError
|
||||
],
|
||||
InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError],
|
||||
InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError],
|
||||
InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError],
|
||||
InvokeAuthorizationError: [
|
||||
cohere.errors.unauthorized_error.UnauthorizedError,
|
||||
cohere.errors.forbidden_error.ForbiddenError
|
||||
cohere.errors.forbidden_error.ForbiddenError,
|
||||
],
|
||||
InvokeBadRequestError: [
cohere.core.api_error.ApiError,
cohere.errors.bad_request_error.BadRequestError,
cohere.errors.not_found_error.NotFoundError,
]
],
}
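A sketch of how an error mapping like the one above is typically consumed: the first unified error class whose SDK exception list matches the caught exception wins. The class names below are local stand-ins, not the runtime's real exception types.

class InvokeError(Exception): ...
class InvokeRateLimitError(InvokeError): ...
class TooManyRequestsError(Exception): ...  # stand-in for the SDK error

ERROR_MAPPING = {InvokeRateLimitError: [TooManyRequestsError]}

def transform_invoke_error(ex: Exception) -> InvokeError:
    # Walk the mapping and wrap the exception in the matching unified class.
    for unified_cls, sdk_errors in ERROR_MAPPING.items():
        if isinstance(ex, tuple(sdk_errors)):
            return unified_cls(str(ex))
    return InvokeError(str(ex))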
|
||||
|
||||
@ -21,10 +21,16 @@ class CohereRerankModel(RerankModel):
|
||||
Model class for Cohere rerank model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
@ -38,20 +44,17 @@ class CohereRerankModel(RerankModel):
|
||||
:return: rerank result
|
||||
"""
|
||||
if len(docs) == 0:
|
||||
return RerankResult(
|
||||
model=model,
|
||||
docs=docs
|
||||
)
|
||||
return RerankResult(model=model, docs=docs)
|
||||
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
|
||||
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
|
||||
response = client.rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
model=model,
|
||||
top_n=top_n,
|
||||
return_documents=True,
|
||||
request_options=RequestOptions(max_retries=0)
|
||||
request_options=RequestOptions(max_retries=0),
|
||||
)
|
||||
|
||||
rerank_documents = []
|
||||
@ -70,10 +73,7 @@ class CohereRerankModel(RerankModel):
|
||||
else:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(
model=model,
docs=rerank_documents
)
return RerankResult(model=model, docs=rerank_documents)
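A minimal sketch of the score_threshold filtering applied while assembling the rerank result above; the dataclass and field names are illustrative assumptions, not the runtime's RerankDocument.

from dataclasses import dataclass

@dataclass
class ScoredDoc:
    index: int
    text: str
    score: float

def filter_by_threshold(docs: list[ScoredDoc], score_threshold: float | None) -> list[ScoredDoc]:
    # Keep every document when no threshold is supplied.
    if score_threshold is None:
        return docs
    return [doc for doc in docs if doc.score >= score_threshold]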
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
@ -94,7 +94,7 @@ class CohereRerankModel(RerankModel):
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8
|
||||
score_threshold=0.8,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
@ -110,22 +110,16 @@ class CohereRerankModel(RerankModel):
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
cohere.errors.service_unavailable_error.ServiceUnavailableError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
cohere.errors.internal_server_error.InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
cohere.errors.too_many_requests_error.TooManyRequestsError
|
||||
],
|
||||
InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError],
|
||||
InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError],
|
||||
InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError],
|
||||
InvokeAuthorizationError: [
|
||||
cohere.errors.unauthorized_error.UnauthorizedError,
|
||||
cohere.errors.forbidden_error.ForbiddenError
|
||||
cohere.errors.forbidden_error.ForbiddenError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
cohere.core.api_error.ApiError,
|
||||
cohere.errors.bad_request_error.BadRequestError,
|
||||
cohere.errors.not_found_error.NotFoundError,
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
@ -24,9 +24,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
Model class for Cohere text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
@ -46,14 +46,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
used_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
tokenize_response = self._tokenize(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
text=text
|
||||
)
|
||||
tokenize_response = self._tokenize(model=model, credentials=credentials, text=text)
|
||||
|
||||
for j in range(0, len(tokenize_response), context_size):
|
||||
tokens += [tokenize_response[j: j + context_size]]
|
||||
tokens += [tokenize_response[j : j + context_size]]
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
@ -62,9 +58,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
for i in _iter:
|
||||
# call embedding model
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=["".join(token) for token in tokens[i: i + max_chunks]]
|
||||
model=model, credentials=credentials, texts=["".join(token) for token in tokens[i : i + max_chunks]]
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
@ -80,9 +74,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
_result = results[i]
|
||||
if len(_result) == 0:
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=[" "]
|
||||
model=model, credentials=credentials, texts=[" "]
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
@ -92,17 +84,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=used_tokens
|
||||
)
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
return TextEmbeddingResult(
embeddings=embeddings,
usage=usage,
model=model
)
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
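A sketch of the chunk-and-average strategy used by the embedding _invoke above: inputs are split into context-size token windows, embedded in batches of max_chunks, then length-weighted averaged and L2-normalised. embed_fn is an assumption standing in for _embedding_invoke.

import numpy as np

def embed_long_text(token_windows: list[list[str]], embed_fn, max_chunks: int) -> list[float]:
    embeddings, lengths = [], []
    for i in range(0, len(token_windows), max_chunks):
        batch = token_windows[i : i + max_chunks]
        # embed_fn takes a list of strings and returns one vector per string.
        embeddings.extend(embed_fn(["".join(window) for window in batch]))
        lengths.extend(len(window) for window in batch)
    average = np.average(embeddings, axis=0, weights=lengths)
    return (average / np.linalg.norm(average)).tolist()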
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
@ -116,14 +100,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
|
||||
full_text = ' '.join(texts)
|
||||
full_text = " ".join(texts)
|
||||
|
||||
try:
|
||||
response = self._tokenize(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
text=full_text
|
||||
)
|
||||
response = self._tokenize(model=model, credentials=credentials, text=full_text)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@ -141,14 +121,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
return []
|
||||
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
|
||||
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
|
||||
|
||||
response = client.tokenize(
|
||||
text=text,
|
||||
model=model,
|
||||
offline=False,
|
||||
request_options=RequestOptions(max_retries=0)
|
||||
)
|
||||
response = client.tokenize(text=text, model=model, offline=False, request_options=RequestOptions(max_retries=0))
|
||||
|
||||
return response.token_strings
|
||||
|
||||
@ -162,11 +137,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
try:
|
||||
# call embedding model
|
||||
self._embedding_invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=['ping']
|
||||
)
|
||||
self._embedding_invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@ -180,14 +151,14 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
:return: embeddings and used tokens
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
|
||||
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
|
||||
|
||||
# call embedding model
|
||||
response = client.embed(
|
||||
texts=texts,
|
||||
model=model,
|
||||
input_type='search_document' if len(texts) > 1 else 'search_query',
|
||||
request_options=RequestOptions(max_retries=1)
|
||||
input_type="search_document" if len(texts) > 1 else "search_query",
|
||||
request_options=RequestOptions(max_retries=1),
|
||||
)
|
||||
|
||||
return response.embeddings, int(response.meta.billed_units.input_tokens)
|
||||
@ -203,10 +174,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
@ -217,7 +185,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
@ -233,22 +201,16 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
cohere.errors.service_unavailable_error.ServiceUnavailableError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
cohere.errors.internal_server_error.InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
cohere.errors.too_many_requests_error.TooManyRequestsError
|
||||
],
|
||||
InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError],
|
||||
InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError],
|
||||
InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError],
|
||||
InvokeAuthorizationError: [
|
||||
cohere.errors.unauthorized_error.UnauthorizedError,
|
||||
cohere.errors.forbidden_error.ForbiddenError
|
||||
cohere.errors.forbidden_error.ForbiddenError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
cohere.core.api_error.ApiError,
|
||||
cohere.errors.bad_request_error.BadRequestError,
|
||||
cohere.errors.not_found_error.NotFoundError,
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
@ -7,9 +7,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class DeepSeekProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
@ -22,12 +20,9 @@ class DeepSeekProvider(ModelProvider):
|
||||
|
||||
# Use `deepseek-chat` model for validate,
|
||||
# no matter what model you pass in, text completion model or chat model
|
||||
model_instance.validate_credentials(
|
||||
model='deepseek-chat',
|
||||
credentials=credentials
|
||||
)
|
||||
model_instance.validate_credentials(model="deepseek-chat", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
||||
@ -13,12 +13,17 @@ from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguag
|
||||
|
||||
|
||||
class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
@ -27,10 +32,8 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
|
||||
# refactored from openai model runtime, use cl100k_base for calculate token number
|
||||
def _num_tokens_from_string(self, model: str, text: str,
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Calculate num tokens for text completion model with tiktoken package.
|
||||
|
||||
@ -48,8 +51,9 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
|
||||
return num_tokens
|
||||
|
||||
# refactored from openai model runtime, use cl100k_base for calculate token number
|
||||
def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def _num_tokens_from_messages(
|
||||
self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
@ -69,10 +73,10 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ''
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item['type'] == 'text':
|
||||
text += item['text']
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
|
||||
value = text
|
||||
|
||||
@ -103,11 +107,10 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials['mode'] = 'chat'
|
||||
credentials['openai_api_key']=credentials['api_key']
|
||||
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
||||
credentials['openai_api_base']='https://api.deepseek.com'
|
||||
credentials["mode"] = "chat"
|
||||
credentials["openai_api_key"] = credentials["api_key"]
|
||||
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
||||
credentials["openai_api_base"] = "https://api.deepseek.com"
|
||||
else:
parsed_url = urlparse(credentials['endpoint_url'])
credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}"

parsed_url = urlparse(credentials["endpoint_url"])
credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
@ -1 +0,0 @@
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import logging
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
@ -18,11 +18,9 @@ class FishAudioProvider(ModelProvider):
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.TTS)
|
||||
model_instance.validate_credentials(
|
||||
credentials=credentials
|
||||
)
|
||||
model_instance.validate_credentials(credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
@ -12,9 +12,7 @@ class FishAudioText2SpeechModel(TTSModel):
|
||||
Model class for Fish.audio Text to Speech model.
|
||||
"""
|
||||
|
||||
def get_tts_model_voices(
|
||||
self, model: str, credentials: dict, language: Optional[str] = None
|
||||
) -> list:
|
||||
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
|
||||
api_base = credentials.get("api_base", "https://api.fish.audio")
|
||||
api_key = credentials.get("api_key")
|
||||
use_public_models = credentials.get("use_public_models", "false") == "true"
|
||||
@ -68,9 +66,7 @@ class FishAudioText2SpeechModel(TTSModel):
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def validate_credentials(
|
||||
self, credentials: dict, user: Optional[str] = None
|
||||
) -> None:
|
||||
def validate_credentials(self, credentials: dict, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Validate credentials for text2speech model
|
||||
|
||||
@ -91,9 +87,7 @@ class FishAudioText2SpeechModel(TTSModel):
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _tts_invoke_streaming(
|
||||
self, model: str, credentials: dict, content_text: str, voice: str
|
||||
) -> any:
|
||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
|
||||
"""
|
||||
Invoke streaming text2speech model
|
||||
:param model: model name
|
||||
@ -106,12 +100,10 @@ class FishAudioText2SpeechModel(TTSModel):
|
||||
try:
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
if len(content_text) > word_limit:
|
||||
sentences = self._split_text_into_sentences(
|
||||
content_text, max_length=word_limit
|
||||
)
|
||||
sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
|
||||
else:
|
||||
sentences = [content_text.strip()]
|
||||
|
||||
|
||||
for i in range(len(sentences)):
|
||||
yield from self._tts_invoke_streaming_sentence(
|
||||
credentials=credentials, content_text=sentences[i], voice=voice
|
||||
@ -120,9 +112,7 @@ class FishAudioText2SpeechModel(TTSModel):
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
def _tts_invoke_streaming_sentence(
|
||||
self, credentials: dict, content_text: str, voice: Optional[str] = None
|
||||
) -> any:
|
||||
def _tts_invoke_streaming_sentence(self, credentials: dict, content_text: str, voice: Optional[str] = None) -> any:
|
||||
"""
|
||||
Invoke streaming text2speech model
|
||||
|
||||
@ -141,20 +131,14 @@ class FishAudioText2SpeechModel(TTSModel):
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
api_url + "/v1/tts",
|
||||
json={
|
||||
"text": content_text,
|
||||
"reference_id": voice,
|
||||
"latency": latency
|
||||
},
|
||||
json={"text": content_text, "reference_id": voice, "latency": latency},
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
},
|
||||
timeout=None,
|
||||
) as response:
|
||||
if response.status_code != 200:
raise InvokeBadRequestError(
f"Error: {response.status_code} - {response.text}"
)
raise InvokeBadRequestError(f"Error: {response.status_code} - {response.text}")
yield from response.iter_bytes()
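A sketch of one way the audio byte generator yielded above could be consumed; the file name and helper are assumptions for illustration.

def save_streamed_audio(chunks, path: str = "speech.mp3") -> int:
    # Write each streamed audio chunk to disk and report the total size.
    written = 0
    with open(path, "wb") as f:
        for chunk in chunks:
            f.write(chunk)
            written += len(chunk)
    return written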
|
||||
|
||||
@property
|
||||
|
||||
@ -20,12 +20,9 @@ class GoogleProvider(ModelProvider):
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `gemini-pro` model for validate,
|
||||
model_instance.validate_credentials(
|
||||
model='gemini-pro',
|
||||
credentials=credentials
|
||||
)
|
||||
model_instance.validate_credentials(model="gemini-pro", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
||||
@ -49,12 +49,17 @@ if you are not sure about the structure.
|
||||
|
||||
|
||||
class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -70,9 +75,14 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
# invoke model
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -85,7 +95,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
||||
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Google model
|
||||
@ -95,13 +105,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
text = "".join(
|
||||
self._convert_one_message_to_text(message)
|
||||
for message in messages
|
||||
)
|
||||
text = "".join(self._convert_one_message_to_text(message) for message in messages)
|
||||
|
||||
return text.rstrip()
|
||||
|
||||
|
||||
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
|
||||
"""
|
||||
Convert tool messages to glm tools
|
||||
@ -117,14 +124,16 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
type=glm.Type.OBJECT,
|
||||
properties={
|
||||
key: {
|
||||
'type_': value.get('type', 'string').upper(),
|
||||
'description': value.get('description', ''),
|
||||
'enum': value.get('enum', [])
|
||||
} for key, value in tool.parameters.get('properties', {}).items()
|
||||
"type_": value.get("type", "string").upper(),
|
||||
"description": value.get("description", ""),
|
||||
"enum": value.get("enum", []),
|
||||
}
|
||||
for key, value in tool.parameters.get("properties", {}).items()
|
||||
},
|
||||
required=tool.parameters.get('required', [])
|
||||
required=tool.parameters.get("required", []),
|
||||
),
) for tool in tools
)
for tool in tools
]
)
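A sketch of the property conversion applied above, with plain dicts standing in for glm.Schema: JSON-schema types are upper-cased for the Gemini function-declaration format; the helper name and example are assumptions.

def to_glm_properties(properties: dict) -> dict:
    return {
        key: {
            "type_": value.get("type", "string").upper(),
            "description": value.get("description", ""),
            "enum": value.get("enum", []),
        }
        for key, value in properties.items()
    }

# to_glm_properties({"city": {"type": "string", "description": "City name"}})
# -> {"city": {"type_": "STRING", "description": "City name", "enum": []}}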
@ -136,20 +145,25 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
ping_message = SystemPromptMessage(content="ping")
|
||||
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
|
||||
|
||||
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None
|
||||
) -> Union[LLMResult, Generator]:
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -163,14 +177,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
config_kwargs = model_parameters.copy()
|
||||
config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None)
|
||||
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
|
||||
|
||||
if stop:
|
||||
config_kwargs["stop_sequences"] = stop
|
||||
|
||||
google_model = genai.GenerativeModel(
|
||||
model_name=model
|
||||
)
|
||||
google_model = genai.GenerativeModel(model_name=model)
|
||||
|
||||
history = []
|
||||
|
||||
@ -180,7 +192,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
content = self._format_message_to_glm_content(last_msg)
|
||||
history.append(content)
|
||||
else:
|
||||
for msg in prompt_messages: # makes message roles strictly alternating
|
||||
for msg in prompt_messages: # makes message roles strictly alternating
|
||||
content = self._format_message_to_glm_content(msg)
|
||||
if history and history[-1]["role"] == content["role"]:
|
||||
history[-1]["parts"].extend(content["parts"])
|
||||
@ -194,7 +206,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
google_model._client = new_custom_client
|
||||
|
||||
safety_settings={
|
||||
safety_settings = {
|
||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
||||
@ -203,13 +215,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
response = google_model.generate_content(
|
||||
contents=history,
|
||||
generation_config=genai.types.GenerationConfig(
|
||||
**config_kwargs
|
||||
),
|
||||
generation_config=genai.types.GenerationConfig(**config_kwargs),
|
||||
stream=stream,
|
||||
safety_settings=safety_settings,
|
||||
tools=self._convert_tools_to_glm_tool(tools) if tools else None,
|
||||
request_options={"timeout": 600}
|
||||
request_options={"timeout": 600},
|
||||
)
|
||||
|
||||
if stream:
|
||||
@ -217,8 +227,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(self, model: str, credentials: dict, response: GenerateContentResponse,
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
|
||||
@ -229,9 +240,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
:return: llm response
|
||||
"""
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=response.text
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=response.text)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
@ -250,8 +259,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: GenerateContentResponse,
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
@ -264,9 +274,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
index = -1
|
||||
for chunk in response:
|
||||
for part in chunk.parts:
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=''
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content="")
|
||||
|
||||
if part.text:
|
||||
assistant_prompt_message.content += part.text
|
||||
@ -275,36 +283,31 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
assistant_prompt_message.tool_calls = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=part.function_call.name,
|
||||
type='function',
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=part.function_call.name,
|
||||
arguments=json.dumps(dict(part.function_call.args.items()))
|
||||
)
|
||||
arguments=json.dumps(dict(part.function_call.args.items())),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
index += 1
|
||||
|
||||
|
||||
if not response._done:
|
||||
|
||||
# transform assistant message to prompt message
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message
|
||||
)
|
||||
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
|
||||
)
|
||||
else:
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
@ -312,8 +315,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=str(chunk.candidates[0].finish_reason),
|
||||
usage=usage
|
||||
)
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
|
||||
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
|
||||
@ -328,9 +331,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
content = message.content
|
||||
if isinstance(content, list):
|
||||
content = "".join(
|
||||
c.data for c in content if c.type != PromptMessageContentType.IMAGE
|
||||
)
|
||||
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
@ -353,65 +354,61 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
:return: glm Content representation of message
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
glm_content = {
|
||||
"role": "user",
|
||||
"parts": []
|
||||
}
|
||||
if (isinstance(message.content, str)):
|
||||
glm_content['parts'].append(to_part(message.content))
|
||||
glm_content = {"role": "user", "parts": []}
|
||||
if isinstance(message.content, str):
|
||||
glm_content["parts"].append(to_part(message.content))
|
||||
else:
|
||||
for c in message.content:
|
||||
if c.type == PromptMessageContentType.TEXT:
|
||||
glm_content['parts'].append(to_part(c.data))
|
||||
glm_content["parts"].append(to_part(c.data))
|
||||
elif c.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, c)
|
||||
if message_content.data.startswith("data:"):
|
||||
metadata, base64_data = c.data.split(',', 1)
|
||||
mime_type = metadata.split(';', 1)[0].split(':')[1]
|
||||
metadata, base64_data = c.data.split(",", 1)
|
||||
mime_type = metadata.split(";", 1)[0].split(":")[1]
|
||||
else:
|
||||
# fetch image data from url
|
||||
try:
|
||||
image_content = requests.get(message_content.data).content
|
||||
with Image.open(io.BytesIO(image_content)) as img:
|
||||
mime_type = f"image/{img.format.lower()}"
|
||||
base64_data = base64.b64encode(image_content).decode('utf-8')
|
||||
base64_data = base64.b64encode(image_content).decode("utf-8")
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
||||
blob = {"inline_data":{"mime_type":mime_type,"data":base64_data}}
|
||||
glm_content['parts'].append(blob)
|
||||
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
|
||||
glm_content["parts"].append(blob)
|
||||
|
||||
return glm_content
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
glm_content = {
|
||||
"role": "model",
|
||||
"parts": []
|
||||
}
|
||||
glm_content = {"role": "model", "parts": []}
|
||||
if message.content:
|
||||
glm_content['parts'].append(to_part(message.content))
|
||||
glm_content["parts"].append(to_part(message.content))
|
||||
if message.tool_calls:
|
||||
glm_content["parts"].append(to_part(glm.FunctionCall(
|
||||
name=message.tool_calls[0].function.name,
|
||||
args=json.loads(message.tool_calls[0].function.arguments),
|
||||
)))
|
||||
glm_content["parts"].append(
|
||||
to_part(
|
||||
glm.FunctionCall(
|
||||
name=message.tool_calls[0].function.name,
|
||||
args=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
)
|
||||
)
|
||||
return glm_content
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
return {
|
||||
"role": "user",
|
||||
"parts": [to_part(message.content)]
|
||||
}
|
||||
return {"role": "user", "parts": [to_part(message.content)]}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
return {
|
||||
"role": "function",
|
||||
"parts": [glm.Part(function_response=glm.FunctionResponse(
|
||||
name=message.name,
|
||||
response={
|
||||
"response": message.content
|
||||
}
|
||||
))]
|
||||
"parts": [
|
||||
glm.Part(
|
||||
function_response=glm.FunctionResponse(
|
||||
name=message.name, response={"response": message.content}
|
||||
)
|
||||
)
|
||||
],
|
||||
}
else:
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
@ -423,25 +420,20 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
exceptions.RetryError
|
||||
],
|
||||
InvokeConnectionError: [exceptions.RetryError],
|
||||
InvokeServerUnavailableError: [
|
||||
exceptions.ServiceUnavailable,
|
||||
exceptions.InternalServerError,
|
||||
exceptions.BadGateway,
|
||||
exceptions.GatewayTimeout,
|
||||
exceptions.DeadlineExceeded
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
exceptions.ResourceExhausted,
|
||||
exceptions.TooManyRequests
|
||||
exceptions.DeadlineExceeded,
|
||||
],
|
||||
InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests],
|
||||
InvokeAuthorizationError: [
|
||||
exceptions.Unauthenticated,
|
||||
exceptions.PermissionDenied,
|
||||
exceptions.Unauthenticated,
|
||||
exceptions.Forbidden
|
||||
exceptions.Forbidden,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
exceptions.BadRequest,
|
||||
@ -457,5 +449,5 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
exceptions.PreconditionFailed,
|
||||
exceptions.RequestRangeNotSatisfiable,
|
||||
exceptions.Cancelled,
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GroqProvider(ModelProvider):
|
||||
|
||||
class GroqProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
@ -18,12 +18,9 @@ class GroqProvider(ModelProvider):
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
model_instance.validate_credentials(
|
||||
model='llama3-8b-8192',
|
||||
credentials=credentials
|
||||
)
|
||||
model_instance.validate_credentials(model="llama3-8b-8192", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
||||
@ -7,11 +7,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
|
||||
|
||||
|
||||
class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
@ -21,6 +27,5 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials['mode'] = 'chat'
|
||||
credentials['endpoint_url'] = 'https://api.groq.com/openai/v1'
|
||||
|
||||
credentials["mode"] = "chat"
|
||||
credentials["endpoint_url"] = "https://api.groq.com/openai/v1"
|
||||
|
||||
@ -4,12 +4,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
|
||||
|
||||
|
||||
class _CommonHuggingfaceHub:
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeBadRequestError: [
|
||||
HfHubHTTPError,
|
||||
BadRequestError
|
||||
]
|
||||
}
|
||||
return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]}
|
||||
|
||||
@ -6,6 +6,5 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuggingfaceHubProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
||||
|
||||
@ -29,16 +29,23 @@ from core.model_runtime.model_providers.huggingface_hub._common import _CommonHu
|
||||
|
||||
|
||||
class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
|
||||
|
||||
client = InferenceClient(token=credentials['huggingfacehub_api_token'])
|
||||
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
|
||||
model = credentials["huggingfacehub_endpoint_url"]
|
||||
|
||||
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
|
||||
model = credentials['huggingfacehub_endpoint_url']
|
||||
|
||||
if 'baichuan' in model.lower():
|
||||
if "baichuan" in model.lower():
|
||||
stream = False
|
||||
|
||||
response = client.text_generation(
|
||||
@ -47,98 +54,100 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
stream=stream,
|
||||
model=model,
|
||||
stop_sequences=stop,
|
||||
**model_parameters)
|
||||
**model_parameters,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, prompt_messages, response)
|
||||
|
||||
return self._handle_generate_response(model, credentials, prompt_messages, response)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
try:
|
||||
if 'huggingfacehub_api_type' not in credentials:
|
||||
raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.')
|
||||
if "huggingfacehub_api_type" not in credentials:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
|
||||
|
||||
if credentials['huggingfacehub_api_type'] not in ('inference_endpoints', 'hosted_inference_api'):
|
||||
raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.')
|
||||
if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"):
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
|
||||
|
||||
if 'huggingfacehub_api_token' not in credentials:
|
||||
raise CredentialsValidateFailedError('Huggingface Hub Access Token must be provided.')
|
||||
if "huggingfacehub_api_token" not in credentials:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Access Token must be provided.")
|
||||
|
||||
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
|
||||
if 'huggingfacehub_endpoint_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.')
|
||||
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
|
||||
if "huggingfacehub_endpoint_url" not in credentials:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.")
|
||||
|
||||
if 'task_type' not in credentials:
|
||||
raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.')
|
||||
elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
|
||||
credentials['task_type'] = self._get_hosted_model_task_type(credentials['huggingfacehub_api_token'],
|
||||
model)
|
||||
if "task_type" not in credentials:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.")
|
||||
elif credentials["huggingfacehub_api_type"] == "hosted_inference_api":
|
||||
credentials["task_type"] = self._get_hosted_model_task_type(
|
||||
credentials["huggingfacehub_api_token"], model
|
||||
)
|
||||
|
||||
if credentials['task_type'] not in ("text2text-generation", "text-generation"):
|
||||
raise CredentialsValidateFailedError('Huggingface Hub Task Type must be one of text2text-generation, '
|
||||
'text-generation.')
|
||||
if credentials["task_type"] not in ("text2text-generation", "text-generation"):
|
||||
raise CredentialsValidateFailedError(
|
||||
"Huggingface Hub Task Type must be one of text2text-generation, " "text-generation."
|
||||
)
|
||||
|
||||
client = InferenceClient(token=credentials['huggingfacehub_api_token'])
|
||||
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
|
||||
|
||||
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
|
||||
model = credentials['huggingfacehub_endpoint_url']
|
||||
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
|
||||
model = credentials["huggingfacehub_endpoint_url"]
|
||||
|
||||
try:
|
||||
client.text_generation(
|
||||
prompt='Who are you?',
|
||||
stream=True,
|
||||
model=model)
|
||||
client.text_generation(prompt="Who are you?", stream=True, model=model)
|
||||
except BadRequestError as e:
|
||||
raise CredentialsValidateFailedError('Only available for models running on with the `text-generation-inference`. '
|
||||
'To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.')
|
||||
raise CredentialsValidateFailedError(
|
||||
"Only available for models running on with the `text-generation-inference`. "
|
||||
"To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference."
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
label=I18nObject(en_US=model),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: LLMMode.COMPLETION.value
|
||||
},
|
||||
parameter_rules=self._get_customizable_model_parameter_rules()
|
||||
model_properties={ModelPropertyKey.MODE: LLMMode.COMPLETION.value},
|
||||
parameter_rules=self._get_customizable_model_parameter_rules(),
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
@staticmethod
|
||||
def _get_customizable_model_parameter_rules() -> list[ParameterRule]:
|
||||
temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(
|
||||
DefaultParameterName.TEMPERATURE).copy()
|
||||
temperature_rule_dict['name'] = 'temperature'
|
||||
temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TEMPERATURE).copy()
|
||||
temperature_rule_dict["name"] = "temperature"
|
||||
temperature_rule = ParameterRule(**temperature_rule_dict)
|
||||
temperature_rule.default = 0.5
|
||||
|
||||
top_p_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TOP_P).copy()
|
||||
top_p_rule_dict['name'] = 'top_p'
|
||||
top_p_rule_dict["name"] = "top_p"
|
||||
top_p_rule = ParameterRule(**top_p_rule_dict)
|
||||
top_p_rule.default = 0.5
|
||||
|
||||
top_k_rule = ParameterRule(
|
||||
name='top_k',
|
||||
name="top_k",
|
||||
label={
|
||||
'en_US': 'Top K',
|
||||
'zh_Hans': 'Top K',
|
||||
"en_US": "Top K",
|
||||
"zh_Hans": "Top K",
|
||||
},
|
||||
type='int',
|
||||
type="int",
|
||||
help={
|
||||
'en_US': 'The number of highest probability vocabulary tokens to keep for top-k-filtering.',
|
||||
'zh_Hans': '保留的最高概率词汇标记的数量。',
|
||||
"en_US": "The number of highest probability vocabulary tokens to keep for top-k-filtering.",
|
||||
"zh_Hans": "保留的最高概率词汇标记的数量。",
|
||||
},
|
||||
required=False,
|
||||
default=2,
|
||||
@ -148,15 +157,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
)
|
||||
|
||||
max_new_tokens = ParameterRule(
|
||||
name='max_new_tokens',
|
||||
name="max_new_tokens",
|
||||
label={
|
||||
'en_US': 'Max New Tokens',
|
||||
'zh_Hans': '最大新标记',
|
||||
"en_US": "Max New Tokens",
|
||||
"zh_Hans": "最大新标记",
|
||||
},
|
||||
type='int',
|
||||
type="int",
|
||||
help={
|
||||
'en_US': 'Maximum number of generated tokens.',
|
||||
'zh_Hans': '生成的标记的最大数量。',
|
||||
"en_US": "Maximum number of generated tokens.",
|
||||
"zh_Hans": "生成的标记的最大数量。",
|
||||
},
|
||||
required=False,
|
||||
default=20,
|
||||
@ -166,30 +175,30 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
)
|
||||
|
||||
seed = ParameterRule(
|
||||
name='seed',
|
||||
name="seed",
|
||||
label={
|
||||
'en_US': 'Random sampling seed',
|
||||
'zh_Hans': '随机采样种子',
|
||||
"en_US": "Random sampling seed",
|
||||
"zh_Hans": "随机采样种子",
|
||||
},
|
||||
type='int',
|
||||
type="int",
|
||||
help={
|
||||
'en_US': 'Random sampling seed.',
|
||||
'zh_Hans': '随机采样种子。',
|
||||
"en_US": "Random sampling seed.",
|
||||
"zh_Hans": "随机采样种子。",
|
||||
},
|
||||
required=False,
|
||||
precision=0,
|
||||
)
|
||||
|
||||
repetition_penalty = ParameterRule(
|
||||
name='repetition_penalty',
|
||||
name="repetition_penalty",
|
||||
label={
|
||||
'en_US': 'Repetition Penalty',
|
||||
'zh_Hans': '重复惩罚',
|
||||
"en_US": "Repetition Penalty",
|
||||
"zh_Hans": "重复惩罚",
|
||||
},
|
||||
type='float',
|
||||
type="float",
|
||||
help={
|
||||
'en_US': 'The parameter for repetition penalty. 1.0 means no penalty.',
|
||||
'zh_Hans': '重复惩罚的参数。1.0 表示没有惩罚。',
|
||||
"en_US": "The parameter for repetition penalty. 1.0 means no penalty.",
|
||||
"zh_Hans": "重复惩罚的参数。1.0 表示没有惩罚。",
|
||||
},
|
||||
required=False,
|
||||
precision=1,
|
||||
@ -197,11 +206,9 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
|
||||
return [temperature_rule, top_k_rule, top_p_rule, max_new_tokens, seed, repetition_penalty]
|
||||
|
||||
def _handle_generate_stream_response(self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
response: Generator) -> Generator:
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: Generator
|
||||
) -> Generator:
|
||||
index = -1
|
||||
for chunk in response:
|
||||
# skip special tokens
|
||||
@ -210,9 +217,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
|
||||
index += 1
|
||||
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=chunk.token.text
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=chunk.token.text)
|
||||
|
||||
if chunk.details:
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
@ -240,15 +245,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
),
|
||||
)
|
||||
|
||||
def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult:
|
||||
def _handle_generate_response(
|
||||
self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any
|
||||
) -> LLMResult:
|
||||
if isinstance(response, str):
|
||||
content = response
|
||||
else:
|
||||
content = response.generated_text
|
||||
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=content
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=content)
|
||||
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
@ -270,15 +275,14 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
|
||||
try:
|
||||
if not model_info:
|
||||
raise ValueError(f'Model {model_name} not found.')
|
||||
raise ValueError(f"Model {model_name} not found.")
|
||||
|
||||
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
|
||||
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
|
||||
if "inference" in model_info.cardData and not model_info.cardData["inference"]:
|
||||
raise ValueError(f"Inference API has been turned off for this model {model_name}.")
|
||||
|
||||
valid_tasks = ("text2text-generation", "text-generation")
|
||||
if model_info.pipeline_tag not in valid_tasks:
|
||||
raise ValueError(f"Model {model_name} is not a valid task, "
|
||||
f"must be one of {valid_tasks}.")
|
||||
raise ValueError(f"Model {model_name} is not a valid task, " f"must be one of {valid_tasks}.")
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"{str(e)}")
|
||||
|
||||
@ -287,10 +291,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
text = "".join(
|
||||
self._convert_one_message_to_text(message)
|
||||
for message in messages
|
||||
)
|
||||
text = "".join(self._convert_one_message_to_text(message) for message in messages)
|
||||
|
||||
return text.rstrip()
|
||||
|
||||
|
||||
@ -13,40 +13,30 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub
|
||||
|
||||
HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/'
|
||||
HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/"
|
||||
|
||||
|
||||
class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict, texts: list[str],
|
||||
user: Optional[str] = None) -> TextEmbeddingResult:
|
||||
client = InferenceClient(token=credentials['huggingfacehub_api_token'])
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
|
||||
|
||||
execute_model = model
|
||||
|
||||
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
|
||||
execute_model = credentials['huggingfacehub_endpoint_url']
|
||||
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
|
||||
execute_model = credentials["huggingfacehub_endpoint_url"]
|
||||
|
||||
output = client.post(
|
||||
json={
|
||||
"inputs": texts,
|
||||
"options": {
|
||||
"wait_for_model": False,
|
||||
"use_cache": False
|
||||
}
|
||||
},
|
||||
model=execute_model)
|
||||
json={"inputs": texts, "options": {"wait_for_model": False, "use_cache": False}}, model=execute_model
|
||||
)
|
||||
|
||||
embeddings = json.loads(output.decode())
|
||||
|
||||
tokens = self.get_num_tokens(model, credentials, texts)
|
||||
usage = self._calc_response_usage(model, credentials, tokens)
|
||||
|
||||
return TextEmbeddingResult(
|
||||
embeddings=self._mean_pooling(embeddings),
|
||||
usage=usage,
|
||||
model=model
|
||||
)
|
||||
return TextEmbeddingResult(embeddings=self._mean_pooling(embeddings), usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
num_tokens = 0
|
||||
@ -56,52 +46,48 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
try:
|
||||
if 'huggingfacehub_api_type' not in credentials:
|
||||
raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.')
|
||||
if "huggingfacehub_api_type" not in credentials:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
|
||||
|
||||
if 'huggingfacehub_api_token' not in credentials:
|
||||
raise CredentialsValidateFailedError('Huggingface Hub API Token must be provided.')
|
||||
if "huggingfacehub_api_token" not in credentials:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub API Token must be provided.")
|
||||
|
||||
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
|
||||
if 'huggingface_namespace' not in credentials:
|
||||
raise CredentialsValidateFailedError('Huggingface Hub User Name / Organization Name must be provided.')
|
||||
if credentials["huggingfacehub_api_type"] == "inference_endpoints":
|
||||
if "huggingface_namespace" not in credentials:
|
||||
raise CredentialsValidateFailedError(
|
||||
"Huggingface Hub User Name / Organization Name must be provided."
|
||||
)
|
||||
|
||||
if 'huggingfacehub_endpoint_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.')
|
||||
if "huggingfacehub_endpoint_url" not in credentials:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.")
|
||||
|
||||
if 'task_type' not in credentials:
|
||||
raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.')
|
||||
if "task_type" not in credentials:
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.")
|
||||
|
||||
if credentials['task_type'] != 'feature-extraction':
|
||||
raise CredentialsValidateFailedError('Huggingface Hub Task Type is invalid.')
|
||||
if credentials["task_type"] != "feature-extraction":
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Task Type is invalid.")
|
||||
|
||||
self._check_endpoint_url_model_repository_name(credentials, model)
|
||||
|
||||
model = credentials['huggingfacehub_endpoint_url']
|
||||
model = credentials["huggingfacehub_endpoint_url"]
|
||||
|
||||
elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
|
||||
self._check_hosted_model_task_type(credentials['huggingfacehub_api_token'],
|
||||
model)
|
||||
elif credentials["huggingfacehub_api_type"] == "hosted_inference_api":
|
||||
self._check_hosted_model_task_type(credentials["huggingfacehub_api_token"], model)
|
||||
else:
|
||||
raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.')
|
||||
raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
|
||||
|
||||
client = InferenceClient(token=credentials['huggingfacehub_api_token'])
|
||||
client.feature_extraction(text='hello world', model=model)
|
||||
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
|
||||
client.feature_extraction(text="hello world", model=model)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
label=I18nObject(en_US=model),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model_properties={
|
||||
'context_size': 10000,
|
||||
'max_chunks': 1
|
||||
}
|
||||
model_properties={"context_size": 10000, "max_chunks": 1},
|
||||
)
|
||||
return entity
|
||||
|
||||
@ -128,24 +114,20 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
|
||||
|
||||
try:
|
||||
if not model_info:
|
||||
raise ValueError(f'Model {model_name} not found.')
|
||||
raise ValueError(f"Model {model_name} not found.")
|
||||
|
||||
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
|
||||
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
|
||||
if "inference" in model_info.cardData and not model_info.cardData["inference"]:
|
||||
raise ValueError(f"Inference API has been turned off for this model {model_name}.")
|
||||
|
||||
valid_tasks = "feature-extraction"
|
||||
if model_info.pipeline_tag not in valid_tasks:
|
||||
raise ValueError(f"Model {model_name} is not a valid task, "
|
||||
f"must be one of {valid_tasks}.")
|
||||
raise ValueError(f"Model {model_name} is not a valid task, " f"must be one of {valid_tasks}.")
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"{str(e)}")
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
@ -156,7 +138,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
@ -166,25 +148,26 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
|
||||
try:
|
||||
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
|
||||
headers = {
|
||||
'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}',
|
||||
'Content-Type': 'application/json'
|
||||
"Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}',
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.get(url=url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ValueError('User Name or Organization Name is invalid.')
|
||||
raise ValueError("User Name or Organization Name is invalid.")
|
||||
|
||||
model_repository_name = ''
|
||||
model_repository_name = ""
|
||||
|
||||
for item in response.json().get("items", []):
|
||||
if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']:
|
||||
if item.get("status", {}).get("url") == credentials["huggingfacehub_endpoint_url"]:
|
||||
model_repository_name = item.get("model", {}).get("repository")
|
||||
break
|
||||
|
||||
if model_repository_name != model_name:
|
||||
raise ValueError(
|
||||
f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.')
|
||||
f"Model Name {model_name} is invalid. Please check it on the inference endpoints console."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
@@ -6,6 +6,5 @@ logger = logging.getLogger(__name__)


class HuggingfaceTeiProvider(ModelProvider):

def validate_provider_credentials(self, credentials: dict) -> None:
pass
@@ -47,29 +47,29 @@ class HuggingfaceTeiRerankModel(RerankModel):
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])
server_url = credentials['server_url']
server_url = credentials["server_url"]

if server_url.endswith('/'):
if server_url.endswith("/"):
server_url = server_url[:-1]

try:
results = TeiHelper.invoke_rerank(server_url, query, docs)

rerank_documents = []
for result in results:
for result in results:
rerank_document = RerankDocument(
index=result['index'],
text=result['text'],
score=result['score'],
index=result["index"],
text=result["text"],
score=result["score"],
)
if score_threshold is None or result['score'] >= score_threshold:
if score_threshold is None or result["score"] >= score_threshold:
rerank_documents.append(rerank_document)
if top_n is not None and len(rerank_documents) >= top_n:
break

return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
raise InvokeServerUnavailableError(str(e))

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@@ -80,21 +80,21 @@ class HuggingfaceTeiRerankModel(RerankModel):
:return:
"""
try:
server_url = credentials['server_url']
server_url = credentials["server_url"]
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
if extra_args.model_type != 'reranker':
raise CredentialsValidateFailedError('Current model is not a rerank model')
if extra_args.model_type != "reranker":
raise CredentialsValidateFailedError("Current model is not a rerank model")

credentials['context_size'] = extra_args.max_input_length
credentials["context_size"] = extra_args.max_input_length

self.invoke(
model=model,
credentials=credentials,
query='Whose kasumi',
query="Whose kasumi",
docs=[
'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ',
'and she leads a team named PopiParty.',
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
"and she leads a team named PopiParty.",
],
score_threshold=0.8,
)
@@ -129,7 +129,7 @@ class HuggingfaceTeiRerankModel(RerankModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
},
parameter_rules=[],
)
@ -31,16 +31,16 @@ class TeiHelper:
|
||||
with cache_lock:
|
||||
if model_name not in cache:
|
||||
cache[model_name] = {
|
||||
'expires': time() + 300,
|
||||
'value': TeiHelper._get_tei_extra_parameter(server_url),
|
||||
"expires": time() + 300,
|
||||
"value": TeiHelper._get_tei_extra_parameter(server_url),
|
||||
}
|
||||
return cache[model_name]['value']
|
||||
return cache[model_name]["value"]
|
||||
|
||||
@staticmethod
|
||||
def _clean_cache() -> None:
|
||||
try:
|
||||
with cache_lock:
|
||||
expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
|
||||
expired_keys = [model_uid for model_uid, model in cache.items() if model["expires"] < time()]
|
||||
for model_uid in expired_keys:
|
||||
del cache[model_uid]
|
||||
except RuntimeError as e:
|
||||
@ -52,40 +52,38 @@ class TeiHelper:
|
||||
get tei model extra parameter like model_type, max_input_length, max_batch_requests
|
||||
"""
|
||||
|
||||
url = str(URL(server_url) / 'info')
|
||||
url = str(URL(server_url) / "info")
|
||||
|
||||
# this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
|
||||
session = Session()
|
||||
session.mount('http://', HTTPAdapter(max_retries=3))
|
||||
session.mount('https://', HTTPAdapter(max_retries=3))
|
||||
session.mount("http://", HTTPAdapter(max_retries=3))
|
||||
session.mount("https://", HTTPAdapter(max_retries=3))
|
||||
|
||||
try:
|
||||
response = session.get(url, timeout=10)
|
||||
except (MissingSchema, ConnectionError, Timeout) as e:
|
||||
raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}')
|
||||
raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}'
|
||||
f"get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}"
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
model_type = response_json.get('model_type', {})
|
||||
model_type = response_json.get("model_type", {})
|
||||
if len(model_type.keys()) < 1:
|
||||
raise RuntimeError('model_type is empty')
|
||||
raise RuntimeError("model_type is empty")
|
||||
model_type = list(model_type.keys())[0]
|
||||
if model_type not in ['embedding', 'reranker']:
|
||||
raise RuntimeError(f'invalid model_type: {model_type}')
|
||||
|
||||
max_input_length = response_json.get('max_input_length', 512)
|
||||
max_client_batch_size = response_json.get('max_client_batch_size', 1)
|
||||
if model_type not in ["embedding", "reranker"]:
|
||||
raise RuntimeError(f"invalid model_type: {model_type}")
|
||||
|
||||
max_input_length = response_json.get("max_input_length", 512)
|
||||
max_client_batch_size = response_json.get("max_client_batch_size", 1)
|
||||
|
||||
return TeiModelExtraParameter(
|
||||
model_type=model_type,
|
||||
max_input_length=max_input_length,
|
||||
max_client_batch_size=max_client_batch_size
|
||||
model_type=model_type, max_input_length=max_input_length, max_client_batch_size=max_client_batch_size
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
|
||||
"""
|
||||
@ -116,12 +114,12 @@ class TeiHelper:
|
||||
:param texts: texts to tokenize
|
||||
"""
|
||||
resp = httpx.post(
|
||||
f'{server_url}/tokenize',
|
||||
json={'inputs': texts},
|
||||
f"{server_url}/tokenize",
|
||||
json={"inputs": texts},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
|
||||
"""
|
||||
@ -149,8 +147,8 @@ class TeiHelper:
|
||||
"""
|
||||
# Use OpenAI compatible API here, which has usage tracking
|
||||
resp = httpx.post(
|
||||
f'{server_url}/v1/embeddings',
|
||||
json={'input': texts},
|
||||
f"{server_url}/v1/embeddings",
|
||||
json={"input": texts},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
@ -173,11 +171,11 @@ class TeiHelper:
|
||||
:param texts: texts to rerank
|
||||
:param candidates: candidates to rerank
|
||||
"""
|
||||
params = {'query': query, 'texts': docs, 'return_text': True}
|
||||
params = {"query": query, "texts": docs, "return_text": True}
|
||||
|
||||
response = httpx.post(
|
||||
server_url + '/rerank',
|
||||
server_url + "/rerank",
|
||||
json=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@ -40,12 +40,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
server_url = credentials['server_url']
|
||||
server_url = credentials["server_url"]
|
||||
|
||||
if server_url.endswith('/'):
|
||||
if server_url.endswith("/"):
|
||||
server_url = server_url[:-1]
|
||||
|
||||
|
||||
# get model properties
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
@ -58,7 +57,6 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
||||
batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
|
||||
|
||||
for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
|
||||
|
||||
# Check if the number of tokens is larger than the context size
|
||||
num_tokens = len(tokenize_result)
|
||||
|
||||
@ -66,20 +64,22 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
||||
# Find the best cutoff point
|
||||
pre_special_token_count = 0
|
||||
for token in tokenize_result:
|
||||
if token['special']:
|
||||
if token["special"]:
|
||||
pre_special_token_count += 1
|
||||
else:
|
||||
break
|
||||
rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count
|
||||
rest_special_token_count = (
|
||||
len([token for token in tokenize_result if token["special"]]) - pre_special_token_count
|
||||
)
|
||||
|
||||
# Calculate the cutoff point, leave 20 extra space to avoid exceeding the limit
|
||||
token_cutoff = context_size - rest_special_token_count - 20
|
||||
|
||||
# Find the cutoff index
|
||||
cutpoint_token = tokenize_result[token_cutoff]
|
||||
cutoff = cutpoint_token['start']
|
||||
cutoff = cutpoint_token["start"]
|
||||
|
||||
inputs.append(text[0: cutoff])
|
||||
inputs.append(text[0:cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
indices += [i]
|
||||
@ -92,12 +92,12 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
||||
for i in _iter:
|
||||
iter_texts = inputs[i : i + max_chunks]
|
||||
results = TeiHelper.invoke_embeddings(server_url, iter_texts)
|
||||
embeddings = results['data']
|
||||
embeddings = [embedding['embedding'] for embedding in embeddings]
|
||||
embeddings = results["data"]
|
||||
embeddings = [embedding["embedding"] for embedding in embeddings]
|
||||
batched_embeddings.extend(embeddings)
|
||||
|
||||
usage = results['usage']
|
||||
used_tokens += usage['total_tokens']
|
||||
usage = results["usage"]
|
||||
used_tokens += usage["total_tokens"]
|
||||
except RuntimeError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
|
||||
@ -117,9 +117,9 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
||||
:return:
|
||||
"""
|
||||
num_tokens = 0
|
||||
server_url = credentials['server_url']
|
||||
server_url = credentials["server_url"]
|
||||
|
||||
if server_url.endswith('/'):
|
||||
if server_url.endswith("/"):
|
||||
server_url = server_url[:-1]
|
||||
|
||||
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
|
||||
@ -135,15 +135,15 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
server_url = credentials['server_url']
|
||||
server_url = credentials["server_url"]
|
||||
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
|
||||
print(extra_args)
|
||||
if extra_args.model_type != 'embedding':
|
||||
raise CredentialsValidateFailedError('Current model is not a embedding model')
|
||||
if extra_args.model_type != "embedding":
|
||||
raise CredentialsValidateFailedError("Current model is not a embedding model")
|
||||
|
||||
credentials['context_size'] = extra_args.max_input_length
|
||||
credentials['max_chunks'] = extra_args.max_client_batch_size
|
||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||
credentials["context_size"] = extra_args.max_input_length
|
||||
credentials["max_chunks"] = extra_args.max_client_batch_size
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@ -195,8 +195,8 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model_properties={
|
||||
ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
|
||||
ModelPropertyKey.MAX_CHUNKS: int(credentials.get("max_chunks", 1)),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
|
||||
},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
@@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid

logger = logging.getLogger(__name__)

class HunyuanProvider(ModelProvider):

class HunyuanProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@@ -19,12 +19,9 @@ class HunyuanProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)

# Use `hunyuan-standard` model for validate,
model_instance.validate_credentials(
model='hunyuan-standard',
credentials=credentials
)
model_instance.validate_credentials(model="hunyuan-standard", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex
@ -23,21 +23,27 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HunyuanLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
client = self._setup_hunyuan_client(credentials)
|
||||
request = models.ChatCompletionsRequest()
|
||||
messages_dict = self._convert_prompt_messages_to_dicts(prompt_messages)
|
||||
|
||||
custom_parameters = {
|
||||
'Temperature': model_parameters.get('temperature', 0.0),
|
||||
'TopP': model_parameters.get('top_p', 1.0),
|
||||
'EnableEnhancement': model_parameters.get('enable_enhance', True)
|
||||
"Temperature": model_parameters.get("temperature", 0.0),
|
||||
"TopP": model_parameters.get("top_p", 1.0),
|
||||
"EnableEnhancement": model_parameters.get("enable_enhance", True),
|
||||
}
|
||||
|
||||
params = {
|
||||
@ -47,16 +53,19 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
|
||||
**custom_parameters,
|
||||
}
|
||||
# add Tools and ToolChoice
|
||||
if (tools and len(tools) > 0):
|
||||
params['ToolChoice'] = "auto"
|
||||
params['Tools'] = [{
|
||||
"Type": "function",
|
||||
"Function": {
|
||||
"Name": tool.name,
|
||||
"Description": tool.description,
|
||||
"Parameters": json.dumps(tool.parameters)
|
||||
if tools and len(tools) > 0:
|
||||
params["ToolChoice"] = "auto"
|
||||
params["Tools"] = [
|
||||
{
|
||||
"Type": "function",
|
||||
"Function": {
|
||||
"Name": tool.name,
|
||||
"Description": tool.description,
|
||||
"Parameters": json.dumps(tool.parameters),
|
||||
},
|
||||
}
|
||||
} for tool in tools]
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
request.from_json_string(json.dumps(params))
|
||||
response = client.ChatCompletions(request)
|
||||
@ -76,22 +85,19 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
|
||||
req = models.ChatCompletionsRequest()
|
||||
params = {
|
||||
"Model": model,
|
||||
"Messages": [{
|
||||
"Role": "user",
|
||||
"Content": "hello"
|
||||
}],
|
||||
"Messages": [{"Role": "user", "Content": "hello"}],
|
||||
"TopP": 1,
|
||||
"Temperature": 0,
|
||||
"Stream": False
|
||||
"Stream": False,
|
||||
}
|
||||
req.from_json_string(json.dumps(params))
|
||||
client.ChatCompletions(req)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
|
||||
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
|
||||
|
||||
def _setup_hunyuan_client(self, credentials):
|
||||
secret_id = credentials['secret_id']
|
||||
secret_key = credentials['secret_key']
|
||||
secret_id = credentials["secret_id"]
|
||||
secret_key = credentials["secret_key"]
|
||||
cred = credential.Credential(secret_id, secret_key)
|
||||
httpProfile = HttpProfile()
|
||||
httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
|
||||
@ -106,92 +112,96 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
|
||||
for message in prompt_messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
tool_calls = message.tool_calls
|
||||
if (tool_calls and len(tool_calls) > 0):
|
||||
if tool_calls and len(tool_calls) > 0:
|
||||
dict_tool_calls = [
|
||||
{
|
||||
"Id": tool_call.id,
|
||||
"Type": tool_call.type,
|
||||
"Function": {
|
||||
"Name": tool_call.function.name,
|
||||
"Arguments": tool_call.function.arguments if (tool_call.function.arguments == "") else "{}"
|
||||
}
|
||||
} for tool_call in tool_calls]
|
||||
|
||||
dict_list.append({
|
||||
"Role": message.role.value,
|
||||
# fix set content = "" while tool_call request
|
||||
# fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time.
|
||||
"Content": " ", # message.content if (message.content is not None) else "",
|
||||
"ToolCalls": dict_tool_calls
|
||||
})
|
||||
"Arguments": tool_call.function.arguments
|
||||
if (tool_call.function.arguments == "")
|
||||
else "{}",
|
||||
},
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
dict_list.append(
|
||||
{
|
||||
"Role": message.role.value,
|
||||
# fix set content = "" while tool_call request
|
||||
# fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time.
|
||||
"Content": " ", # message.content if (message.content is not None) else "",
|
||||
"ToolCalls": dict_tool_calls,
|
||||
}
|
||||
)
|
||||
else:
|
||||
dict_list.append({ "Role": message.role.value, "Content": message.content })
|
||||
dict_list.append({"Role": message.role.value, "Content": message.content})
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
tool_execute_result = { "result": message.content }
|
||||
content =json.dumps(tool_execute_result, ensure_ascii=False)
|
||||
dict_list.append({ "Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id })
|
||||
tool_execute_result = {"result": message.content}
|
||||
content = json.dumps(tool_execute_result, ensure_ascii=False)
|
||||
dict_list.append({"Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id})
|
||||
else:
|
||||
dict_list.append({ "Role": message.role.value, "Content": message.content })
|
||||
dict_list.append({"Role": message.role.value, "Content": message.content})
|
||||
return dict_list
|
||||
|
||||
def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp):
|
||||
|
||||
tool_call = None
|
||||
tool_calls = []
|
||||
|
||||
for index, event in enumerate(resp):
|
||||
logging.debug("_handle_stream_chat_response, event: %s", event)
|
||||
|
||||
data_str = event['data']
|
||||
data_str = event["data"]
|
||||
data = json.loads(data_str)
|
||||
|
||||
choices = data.get('Choices', [])
|
||||
choices = data.get("Choices", [])
|
||||
if not choices:
|
||||
continue
|
||||
choice = choices[0]
|
||||
delta = choice.get('Delta', {})
|
||||
message_content = delta.get('Content', '')
|
||||
finish_reason = choice.get('FinishReason', '')
|
||||
delta = choice.get("Delta", {})
|
||||
message_content = delta.get("Content", "")
|
||||
finish_reason = choice.get("FinishReason", "")
|
||||
|
||||
usage = data.get('Usage', {})
|
||||
prompt_tokens = usage.get('PromptTokens', 0)
|
||||
completion_tokens = usage.get('CompletionTokens', 0)
|
||||
usage = data.get("Usage", {})
|
||||
prompt_tokens = usage.get("PromptTokens", 0)
|
||||
completion_tokens = usage.get("CompletionTokens", 0)
|
||||
|
||||
response_tool_calls = delta.get('ToolCalls')
|
||||
if (response_tool_calls is not None):
|
||||
response_tool_calls = delta.get("ToolCalls")
|
||||
if response_tool_calls is not None:
|
||||
new_tool_calls = self._extract_response_tool_calls(response_tool_calls)
|
||||
if (len(new_tool_calls) > 0):
|
||||
if len(new_tool_calls) > 0:
|
||||
new_tool_call = new_tool_calls[0]
|
||||
if (tool_call is None): tool_call = new_tool_call
|
||||
elif (tool_call.id != new_tool_call.id):
|
||||
if tool_call is None:
|
||||
tool_call = new_tool_call
|
||||
elif tool_call.id != new_tool_call.id:
|
||||
tool_calls.append(tool_call)
|
||||
tool_call = new_tool_call
|
||||
else:
|
||||
tool_call.function.name += new_tool_call.function.name
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
if (tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0):
|
||||
if tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0:
|
||||
tool_calls.append(tool_call)
|
||||
tool_call = None
|
||||
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=message_content,
|
||||
tool_calls=[]
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=message_content, tool_calls=[])
|
||||
# rewrite content = "" while tool_call to avoid show content on web page
|
||||
if (len(tool_calls) > 0): assistant_prompt_message.content = ""
|
||||
|
||||
if len(tool_calls) > 0:
|
||||
assistant_prompt_message.content = ""
|
||||
|
||||
# add tool_calls to assistant_prompt_message
|
||||
if (finish_reason == 'tool_calls'):
|
||||
if finish_reason == "tool_calls":
|
||||
assistant_prompt_message.tool_calls = tool_calls
|
||||
tool_call = None
|
||||
tool_calls = []
|
||||
|
||||
if (len(finish_reason) > 0):
|
||||
if len(finish_reason) > 0:
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
delta_chunk = LLMResultChunkDelta(
|
||||
index=index,
|
||||
role=delta.get('Role', 'assistant'),
|
||||
role=delta.get("Role", "assistant"),
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
finish_reason=finish_reason,
|
||||
@ -212,8 +222,9 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
def _handle_chat_response(self, credentials, model, prompt_messages, response):
|
||||
usage = self._calc_response_usage(model, credentials, response.Usage.PromptTokens,
|
||||
response.Usage.CompletionTokens)
|
||||
usage = self._calc_response_usage(
|
||||
model, credentials, response.Usage.PromptTokens, response.Usage.CompletionTokens
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage()
|
||||
assistant_prompt_message.content = response.Choices[0].Message.Content
|
||||
result = LLMResult(
|
||||
@ -225,8 +236,13 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None) -> int:
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
if len(prompt_messages) == 0:
|
||||
return 0
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
||||
@ -241,10 +257,7 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
text = "".join(
|
||||
self._convert_one_message_to_text(message)
|
||||
for message in messages
|
||||
)
|
||||
text = "".join(self._convert_one_message_to_text(message) for message in messages)
|
||||
|
||||
# trim off the trailing ' ' that might come from the "Assistant: "
|
||||
return text.rstrip()
|
||||
@ -287,10 +300,8 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
|
||||
return {
|
||||
InvokeError: [TencentCloudSDKException],
|
||||
}
|
||||
|
||||
def _extract_response_tool_calls(self,
|
||||
response_tool_calls: list[dict]) \
|
||||
-> list[AssistantPromptMessage.ToolCall]:
|
||||
|
||||
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
@ -300,17 +311,14 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
|
||||
tool_calls = []
|
||||
if response_tool_calls:
|
||||
for response_tool_call in response_tool_calls:
|
||||
response_function = response_tool_call.get('Function', {})
|
||||
response_function = response_tool_call.get("Function", {})
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_function.get('Name', ''),
|
||||
arguments=response_function.get('Arguments', '')
|
||||
name=response_function.get("Name", ""), arguments=response_function.get("Arguments", "")
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.get('Id', 0),
|
||||
type='function',
|
||||
function=function
|
||||
id=response_tool_call.get("Id", 0), type="function", function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
return tool_calls
|
||||
|
||||
@ -19,14 +19,15 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HunyuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Hunyuan text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
@ -37,9 +38,9 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
if model != 'hunyuan-embedding':
|
||||
raise ValueError('Invalid model name')
|
||||
|
||||
if model != "hunyuan-embedding":
|
||||
raise ValueError("Invalid model name")
|
||||
|
||||
client = self._setup_hunyuan_client(credentials)
|
||||
|
||||
embeddings = []
|
||||
@ -47,9 +48,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
for input in texts:
|
||||
request = models.GetEmbeddingRequest()
|
||||
params = {
|
||||
"Input": input
|
||||
}
|
||||
params = {"Input": input}
|
||||
request.from_json_string(json.dumps(params))
|
||||
response = client.GetEmbedding(request)
|
||||
usage = response.Usage.TotalTokens
|
||||
@ -60,11 +59,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=embeddings,
|
||||
usage=self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=token_usage
|
||||
)
|
||||
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
|
||||
)
|
||||
|
||||
return result
|
||||
@ -79,22 +74,19 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
req = models.ChatCompletionsRequest()
|
||||
params = {
|
||||
"Model": model,
|
||||
"Messages": [{
|
||||
"Role": "user",
|
||||
"Content": "hello"
|
||||
}],
|
||||
"Messages": [{"Role": "user", "Content": "hello"}],
|
||||
"TopP": 1,
|
||||
"Temperature": 0,
|
||||
"Stream": False
|
||||
"Stream": False,
|
||||
}
|
||||
req.from_json_string(json.dumps(params))
|
||||
client.ChatCompletions(req)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
|
||||
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
|
||||
|
||||
def _setup_hunyuan_client(self, credentials):
|
||||
secret_id = credentials['secret_id']
|
||||
secret_key = credentials['secret_key']
|
||||
secret_id = credentials["secret_id"]
|
||||
secret_key = credentials["secret_key"]
|
||||
cred = credential.Credential(secret_id, secret_key)
|
||||
httpProfile = HttpProfile()
|
||||
httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
|
||||
@ -102,7 +94,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
clientProfile.httpProfile = httpProfile
|
||||
client = hunyuan_client.HunyuanClient(cred, "", clientProfile)
|
||||
return client
|
||||
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
@ -114,10 +106,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
@ -128,11 +117,11 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
@ -146,7 +135,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
return {
|
||||
InvokeError: [TencentCloudSDKException],
|
||||
}
|
||||
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -170,4 +159,4 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
# response = client.GetTokenCount(request)
|
||||
# num_tokens += response.TokenCount
|
||||
|
||||
return num_tokens
|
||||
return num_tokens
|
||||
|
||||
@@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)


class JinaProvider(ModelProvider):

def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@@ -21,12 +20,9 @@ class JinaProvider(ModelProvider):

# Use `jina-embeddings-v2-base-en` model for validate,
# no matter what model you pass in, text completion model or chat model
model_instance.validate_credentials(
model='jina-embeddings-v2-base-en',
credentials=credentials
)
model_instance.validate_credentials(model="jina-embeddings-v2-base-en", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex
@ -22,9 +22,16 @@ class JinaRerankModel(RerankModel):
Model class for Jina rerank model.
"""

def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) -> RerankResult:
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model

@ -40,37 +47,32 @@ class JinaRerankModel(RerankModel):
if len(docs) == 0:
return RerankResult(model=model, docs=[])

base_url = credentials.get('base_url', 'https://api.jina.ai/v1')
if base_url.endswith('/'):
base_url = credentials.get("base_url", "https://api.jina.ai/v1")
if base_url.endswith("/"):
base_url = base_url[:-1]

try:
response = httpx.post(
base_url + '/rerank',
json={
"model": model,
"query": query,
"documents": docs,
"top_n": top_n
},
headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
base_url + "/rerank",
json={"model": model, "query": query, "documents": docs, "top_n": top_n},
headers={"Authorization": f"Bearer {credentials.get('api_key')}"},
)
response.raise_for_status()
response.raise_for_status()
results = response.json()

rerank_documents = []
for result in results['results']:
for result in results["results"]:
rerank_document = RerankDocument(
index=result['index'],
text=result['document']['text'],
score=result['relevance_score'],
index=result["index"],
text=result["document"]["text"],
score=result["relevance_score"],
)
if score_threshold is None or result['relevance_score'] >= score_threshold:
if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)

return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
raise InvokeServerUnavailableError(str(e))

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -81,7 +83,6 @@ class JinaRerankModel(RerankModel):
:return:
"""
try:

self._invoke(
model=model,
credentials=credentials,
@ -92,7 +93,7 @@ class JinaRerankModel(RerankModel):
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -105,23 +106,21 @@ class JinaRerankModel(RerankModel):
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError]
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError],
}

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size'))
}
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
)

return entity
return entity
@ -14,19 +14,19 @@ class JinaTokenizer:
with cls._lock:
if cls._tokenizer is None:
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
gpt2_tokenizer_path = join(dirname(base_path), "tokenizer")
cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
return cls._tokenizer

@classmethod
def _get_num_tokens_by_jina_base(cls, text: str) -> int:
"""
use jina tokenizer to get num tokens
use jina tokenizer to get num tokens
"""
tokenizer = cls._get_tokenizer()
tokens = tokenizer.encode(text)
return len(tokens)

@classmethod
def get_num_tokens(cls, text: str) -> int:
return cls._get_num_tokens_by_jina_base(text)
return cls._get_num_tokens_by_jina_base(text)
@ -24,11 +24,12 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Jina text embedding model.
"""
api_base: str = 'https://api.jina.ai/v1'

def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
api_base: str = "https://api.jina.ai/v1"

def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model

@ -38,29 +39,23 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
:param user: unique user id
:return: embeddings result
"""
api_key = credentials['api_key']
api_key = credentials["api_key"]
if not api_key:
raise CredentialsValidateFailedError('api_key is required')
raise CredentialsValidateFailedError("api_key is required")

base_url = credentials.get('base_url', self.api_base)
if base_url.endswith('/'):
base_url = credentials.get("base_url", self.api_base)
if base_url.endswith("/"):
base_url = base_url[:-1]

url = base_url + '/embeddings'
headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
}
url = base_url + "/embeddings"
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}

def transform_jina_input_text(model, text):
if model == 'jina-clip-v1':
if model == "jina-clip-v1":
return {"text": text}
return text

data = {
'model': model,
'input': [transform_jina_input_text(model, text) for text in texts]
}
data = {"model": model, "input": [transform_jina_input_text(model, text) for text in texts]}

try:
response = post(url, headers=headers, data=dumps(data))
@ -70,7 +65,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
if response.status_code != 200:
try:
resp = response.json()
msg = resp['detail']
msg = resp["detail"]
if response.status_code == 401:
raise InvokeAuthorizationError(msg)
elif response.status_code == 429:
@ -81,25 +76,20 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
raise InvokeBadRequestError(msg)
except JSONDecodeError as e:
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}")
f"Failed to convert response to json: {e} with text: {response.text}"
)

try:
resp = response.json()
embeddings = resp['data']
usage = resp['usage']
embeddings = resp["data"]
usage = resp["usage"]
except Exception as e:
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}")
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")

usage = self._calc_response_usage(
model=model, credentials=credentials, tokens=usage['total_tokens'])
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])

result = TextEmbeddingResult(
model=model,
embeddings=[[
float(data) for data in x['embedding']
] for x in embeddings],
usage=usage
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
)

return result
@ -128,30 +118,18 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
self._invoke(model=model, credentials=credentials, texts=["ping"])
except Exception as e:
raise CredentialsValidateFailedError(
f'Credentials validation failed: {e}')
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError,
InvokeBadRequestError
]
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
}

def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
@ -165,10 +143,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)

# transform usage
@ -179,24 +154,21 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)

return usage

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(
credentials.get('context_size'))
}
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
)

return entity
@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid

logger = logging.getLogger(__name__)

class LeptonAIProvider(ModelProvider):

class LeptonAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -18,12 +18,9 @@ class LeptonAIProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)

model_instance.validate_credentials(
model='llama2-7b',
credentials=credentials
)
model_instance.validate_credentials(model="llama2-7b", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

@ -8,18 +8,25 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI

class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
MODEL_PREFIX_MAP = {
'llama2-7b': 'llama2-7b',
'gemma-7b': 'gemma-7b',
'mistral-7b': 'mistral-7b',
'mixtral-8x7b': 'mixtral-8x7b',
'llama3-70b': 'llama3-70b',
'llama2-13b': 'llama2-13b',
}
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"llama2-7b": "llama2-7b",
"gemma-7b": "gemma-7b",
"mistral-7b": "mistral-7b",
"mixtral-8x7b": "mixtral-8x7b",
"llama3-70b": "llama3-70b",
"llama2-13b": "llama2-13b",
}

def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials, model)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

@ -29,6 +36,5 @@ class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel):

@classmethod
def _add_custom_parameters(cls, credentials: dict, model: str) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = f'https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1'

credentials["mode"] = "chat"
credentials["endpoint_url"] = f"https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1"
@ -52,29 +52,48 @@ from core.model_runtime.utils import helper


class LocalAILanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)

def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
# tools is not supported yet
return self._num_tokens_from_messages(prompt_messages, tools=tools)

def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int:
"""
Calculate num tokens for baichuan model
LocalAI does not supports
Calculate num tokens for baichuan model
LocalAI does not supports
"""

def tokens(text: str):
"""
We could not determine which tokenizer to use, cause the model is customized.
So we use gpt2 tokenizer to calculate the num tokens for convenience.
We could not determine which tokenizer to use, cause the model is customized.
So we use gpt2 tokenizer to calculate the num tokens for convenience.
"""
return self._get_num_tokens_by_gpt2(text)
@ -87,10 +106,10 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if isinstance(value, list):
|
||||
text = ''
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item['type'] == 'text':
|
||||
text += item['text']
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
|
||||
value = text
|
||||
|
||||
@ -142,30 +161,30 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
num_tokens = 0
|
||||
for tool in tools:
|
||||
# calculate num tokens for function object
|
||||
num_tokens += tokens('name')
|
||||
num_tokens += tokens("name")
|
||||
num_tokens += tokens(tool.name)
|
||||
num_tokens += tokens('description')
|
||||
num_tokens += tokens("description")
|
||||
num_tokens += tokens(tool.description)
|
||||
parameters = tool.parameters
|
||||
num_tokens += tokens('parameters')
|
||||
num_tokens += tokens('type')
|
||||
num_tokens += tokens("parameters")
|
||||
num_tokens += tokens("type")
|
||||
num_tokens += tokens(parameters.get("type"))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += tokens('properties')
|
||||
for key, value in parameters.get('properties').items():
|
||||
if "properties" in parameters:
|
||||
num_tokens += tokens("properties")
|
||||
for key, value in parameters.get("properties").items():
|
||||
num_tokens += tokens(key)
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += tokens(field_key)
|
||||
if field_key == 'enum':
|
||||
if field_key == "enum":
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += tokens(enum_field)
|
||||
else:
|
||||
num_tokens += tokens(field_key)
|
||||
num_tokens += tokens(str(field_value))
|
||||
if 'required' in parameters:
|
||||
num_tokens += tokens('required')
|
||||
for required_field in parameters['required']:
|
||||
if "required" in parameters:
|
||||
num_tokens += tokens("required")
|
||||
for required_field in parameters["required"]:
|
||||
num_tokens += 3
|
||||
num_tokens += tokens(required_field)
|
||||
|
||||
@ -180,102 +199,104 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, prompt_messages=[
|
||||
UserPromptMessage(content='ping')
|
||||
], model_parameters={
|
||||
'max_tokens': 10,
|
||||
}, stop=[], stream=False)
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[UserPromptMessage(content="ping")],
|
||||
model_parameters={
|
||||
"max_tokens": 10,
|
||||
},
|
||||
stop=[],
|
||||
stream=False,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')
|
||||
raise CredentialsValidateFailedError(f"Invalid credentials {str(ex)}")
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
completion_model = None
|
||||
if credentials['completion_type'] == 'chat_completion':
|
||||
if credentials["completion_type"] == "chat_completion":
|
||||
completion_model = LLMMode.CHAT.value
|
||||
elif credentials['completion_type'] == 'completion':
|
||||
elif credentials["completion_type"] == "completion":
|
||||
completion_model = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
|
||||
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
name="temperature",
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='temperature',
|
||||
label=I18nObject(
|
||||
zh_Hans='温度',
|
||||
en_US='Temperature'
|
||||
)
|
||||
use_template="temperature",
|
||||
label=I18nObject(zh_Hans="温度", en_US="Temperature"),
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
name="top_p",
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='top_p',
|
||||
label=I18nObject(
|
||||
zh_Hans='Top P',
|
||||
en_US='Top P'
|
||||
)
|
||||
use_template="top_p",
|
||||
label=I18nObject(zh_Hans="Top P", en_US="Top P"),
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens',
|
||||
name="max_tokens",
|
||||
type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
use_template="max_tokens",
|
||||
min=1,
|
||||
max=2048,
|
||||
default=512,
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度',
|
||||
en_US='Max Tokens'
|
||||
)
|
||||
)
|
||||
label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"),
|
||||
),
|
||||
]
|
||||
|
||||
model_properties = {
|
||||
ModelPropertyKey.MODE: completion_model,
|
||||
} if completion_model else {}
|
||||
model_properties = (
|
||||
{
|
||||
ModelPropertyKey.MODE: completion_model,
|
||||
}
|
||||
if completion_model
|
||||
else {}
|
||||
)
|
||||
|
||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))
|
||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get("context_size", "2048"))
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
label=I18nObject(en_US=model),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
model_properties=model_properties,
|
||||
parameter_rules=rules
|
||||
parameter_rules=rules,
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
kwargs = self._to_client_kwargs(credentials)
|
||||
# init model client
|
||||
client = OpenAI(**kwargs)
|
||||
|
||||
model_name = model
|
||||
completion_type = credentials['completion_type']
|
||||
completion_type = credentials["completion_type"]
|
||||
|
||||
extra_model_kwargs = {
|
||||
"timeout": 60,
|
||||
}
|
||||
if stop:
|
||||
extra_model_kwargs['stop'] = stop
|
||||
extra_model_kwargs["stop"] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs['user'] = user
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
if tools and len(tools) > 0:
|
||||
extra_model_kwargs['functions'] = [
|
||||
helper.dump_model(tool) for tool in tools
|
||||
]
|
||||
extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools]
|
||||
|
||||
if completion_type == 'chat_completion':
|
||||
if completion_type == "chat_completion":
|
||||
result = client.chat.completions.create(
|
||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
||||
model=model_name,
|
||||
@ -283,36 +304,32 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
**model_parameters,
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
elif completion_type == 'completion':
|
||||
elif completion_type == "completion":
|
||||
result = client.completions.create(
|
||||
prompt=self._convert_prompt_message_to_completion_prompts(prompt_messages),
|
||||
model=model,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown completion type {completion_type}")
|
||||
|
||||
if stream:
|
||||
if completion_type == 'completion':
|
||||
if completion_type == "completion":
|
||||
return self._handle_completion_generate_stream_response(
|
||||
model=model, credentials=credentials, response=result, tools=tools,
|
||||
prompt_messages=prompt_messages
|
||||
model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages
|
||||
)
|
||||
return self._handle_chat_generate_stream_response(
|
||||
model=model, credentials=credentials, response=result, tools=tools,
|
||||
prompt_messages=prompt_messages
|
||||
model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
if completion_type == 'completion':
|
||||
if completion_type == "completion":
|
||||
return self._handle_completion_generate_response(
|
||||
model=model, credentials=credentials, response=result,
|
||||
prompt_messages=prompt_messages
|
||||
model=model, credentials=credentials, response=result, prompt_messages=prompt_messages
|
||||
)
|
||||
return self._handle_chat_generate_response(
|
||||
model=model, credentials=credentials, response=result, tools=tools,
|
||||
prompt_messages=prompt_messages
|
||||
model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
def _to_client_kwargs(self, credentials: dict) -> dict:
|
||||
@ -322,13 +339,13 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
:param credentials: credentials dict
|
||||
:return: client kwargs
|
||||
"""
|
||||
if not credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] += '/'
|
||||
if not credentials["server_url"].endswith("/"):
|
||||
credentials["server_url"] += "/"
|
||||
|
||||
client_kwargs = {
|
||||
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
||||
"api_key": "1",
|
||||
"base_url": str(URL(credentials['server_url']) / 'v1'),
|
||||
"base_url": str(URL(credentials["server_url"]) / "v1"),
|
||||
}
|
||||
|
||||
return client_kwargs
|
||||
@ -349,7 +366,7 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
if message.tool_calls and len(message.tool_calls) > 0:
|
||||
message_dict["function_call"] = {
|
||||
"name": message.tool_calls[0].function.name,
|
||||
"arguments": message.tool_calls[0].function.arguments
|
||||
"arguments": message.tool_calls[0].function.arguments,
|
||||
}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
@ -359,11 +376,7 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message.tool_call_id,
|
||||
"content": message.content
|
||||
}]
|
||||
"content": [{"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content}],
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown message type {type(message)}")
|
||||
@ -374,27 +387,29 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
Convert PromptMessage to completion prompts
|
||||
"""
|
||||
prompts = ''
|
||||
prompts = ""
|
||||
for message in messages:
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
prompts += f'{message.content}\n'
|
||||
prompts += f"{message.content}\n"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
prompts += f'{message.content}\n'
|
||||
prompts += f"{message.content}\n"
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
prompts += f'{message.content}\n'
|
||||
prompts += f"{message.content}\n"
|
||||
else:
|
||||
raise ValueError(f"Unknown message type {type(message)}")
|
||||
|
||||
return prompts
|
||||
|
||||
def _handle_completion_generate_response(self, model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Completion,
|
||||
) -> LLMResult:
|
||||
def _handle_completion_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Completion,
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
@ -411,18 +426,16 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
assistant_message = response.choices[0].text
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_message,
|
||||
tool_calls=[]
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_message, tool_calls=[])
|
||||
|
||||
prompt_tokens = self._get_num_tokens_by_gpt2(
|
||||
self._convert_prompt_message_to_completion_prompts(prompt_messages)
|
||||
)
|
||||
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[])
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens)
|
||||
usage = self._calc_response_usage(
|
||||
model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
|
||||
)
|
||||
|
||||
response = LLMResult(
|
||||
model=model,
|
||||
@ -434,11 +447,14 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
|
||||
return response
|
||||
|
||||
def _handle_chat_generate_response(self, model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: ChatCompletion,
|
||||
tools: list[PromptMessageTool]) -> LLMResult:
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: ChatCompletion,
|
||||
tools: list[PromptMessageTool],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
@ -459,16 +475,14 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else [])
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_message.content,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
|
||||
|
||||
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
|
||||
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens)
|
||||
usage = self._calc_response_usage(
|
||||
model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
|
||||
)
|
||||
|
||||
response = LLMResult(
|
||||
model=model,
|
||||
@ -480,12 +494,15 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
|
||||
return response
|
||||
|
||||
def _handle_completion_generate_stream_response(self, model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Stream[Completion],
|
||||
tools: list[PromptMessageTool]) -> Generator:
|
||||
full_response = ''
|
||||
def _handle_completion_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Stream[Completion],
|
||||
tools: list[PromptMessageTool],
|
||||
) -> Generator:
|
||||
full_response = ""
|
||||
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
@ -494,17 +511,11 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
delta = chunk.choices[0]
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta.text if delta.text else '',
|
||||
tool_calls=[]
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
|
||||
|
||||
if delta.finish_reason is not None:
|
||||
# temp_assistant_prompt_message is used to calculate usage
|
||||
temp_assistant_prompt_message = AssistantPromptMessage(
|
||||
content=full_response,
|
||||
tool_calls=[]
|
||||
)
|
||||
temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[])
|
||||
|
||||
prompt_tokens = self._get_num_tokens_by_gpt2(
|
||||
self._convert_prompt_message_to_completion_prompts(prompt_messages)
|
||||
@ -512,8 +523,12 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
|
||||
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
@ -523,7 +538,7 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=delta.finish_reason,
|
||||
usage=usage
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@ -539,12 +554,15 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
|
||||
full_response += delta.text
|
||||
|
||||
def _handle_chat_generate_stream_response(self, model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
tools: list[PromptMessageTool]) -> Generator:
|
||||
full_response = ''
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
tools: list[PromptMessageTool],
|
||||
) -> Generator:
|
||||
full_response = ""
|
||||
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
@ -552,7 +570,7 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
|
||||
delta = chunk.choices[0]
|
||||
|
||||
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
|
||||
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""):
|
||||
continue
|
||||
|
||||
# check if there is a tool call in the response
|
||||
@ -564,22 +582,24 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta.delta.content if delta.delta.content else '',
|
||||
tool_calls=assistant_message_tool_calls
|
||||
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
|
||||
)
|
||||
|
||||
if delta.finish_reason is not None:
|
||||
# temp_assistant_prompt_message is used to calculate usage
|
||||
temp_assistant_prompt_message = AssistantPromptMessage(
|
||||
content=full_response,
|
||||
tool_calls=assistant_message_tool_calls
|
||||
content=full_response, tool_calls=assistant_message_tool_calls
|
||||
)
|
||||
|
||||
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
|
||||
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
@ -589,7 +609,7 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=delta.finish_reason,
|
||||
usage=usage
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@ -605,9 +625,9 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
|
||||
full_response += delta.delta.content
|
||||
|
||||
def _extract_response_tool_calls(self,
|
||||
response_function_calls: list[FunctionCall]) \
|
||||
-> list[AssistantPromptMessage.ToolCall]:
|
||||
def _extract_response_tool_calls(
|
||||
self, response_function_calls: list[FunctionCall]
|
||||
) -> list[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
@ -618,15 +638,10 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
if response_function_calls:
|
||||
for response_tool_call in response_function_calls:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.name,
|
||||
arguments=response_tool_call.arguments
|
||||
name=response_tool_call.name, arguments=response_tool_call.arguments
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=0,
|
||||
type='function',
|
||||
function=function
|
||||
)
|
||||
tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
@ -651,15 +666,9 @@ class LocalAILanguageModel(LargeLanguageModel):
|
||||
ConflictError,
|
||||
NotFoundError,
|
||||
UnprocessableEntityError,
|
||||
PermissionDeniedError
|
||||
PermissionDeniedError,
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
AuthenticationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
ValueError
|
||||
]
|
||||
InvokeRateLimitError: [RateLimitError],
|
||||
InvokeAuthorizationError: [AuthenticationError],
|
||||
InvokeBadRequestError: [ValueError],
|
||||
}
|
||||
|
||||
@ -6,6 +6,5 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalAIProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
||||
pass
|
||||
|
||||
@ -25,9 +25,16 @@ class LocalaiRerankModel(RerankModel):
|
||||
LocalAI rerank model API is compatible with Jina rerank model API. So just copy the JinaRerankModel class code here.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) -> RerankResult:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
@ -43,45 +50,37 @@ class LocalaiRerankModel(RerankModel):
|
||||
if len(docs) == 0:
|
||||
return RerankResult(model=model, docs=[])
|
||||
|
||||
server_url = credentials['server_url']
|
||||
server_url = credentials["server_url"]
|
||||
model_name = model
|
||||
|
||||
if not server_url:
|
||||
raise CredentialsValidateFailedError('server_url is required')
|
||||
if not model_name:
|
||||
raise CredentialsValidateFailedError('model_name is required')
|
||||
|
||||
url = server_url
|
||||
headers = {
|
||||
'Authorization': f"Bearer {credentials.get('api_key')}",
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": model_name,
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
"top_n": top_n
|
||||
}
|
||||
if not server_url:
|
||||
raise CredentialsValidateFailedError("server_url is required")
|
||||
if not model_name:
|
||||
raise CredentialsValidateFailedError("model_name is required")
|
||||
|
||||
url = server_url
|
||||
headers = {"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"}
|
||||
|
||||
data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n}
|
||||
|
||||
try:
|
||||
response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10)
|
||||
response.raise_for_status()
|
||||
response = post(str(URL(url) / "rerank"), headers=headers, data=dumps(data), timeout=10)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
|
||||
rerank_documents = []
|
||||
for result in results['results']:
|
||||
for result in results["results"]:
|
||||
rerank_document = RerankDocument(
|
||||
index=result['index'],
|
||||
text=result['document']['text'],
|
||||
score=result['relevance_score'],
|
||||
index=result["index"],
|
||||
text=result["document"]["text"],
|
||||
score=result["relevance_score"],
|
||||
)
|
||||
if score_threshold is None or result['relevance_score'] >= score_threshold:
|
||||
if score_threshold is None or result["relevance_score"] >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
@ -92,7 +91,6 @@ class LocalaiRerankModel(RerankModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
@ -103,7 +101,7 @@ class LocalaiRerankModel(RerankModel):
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8
|
||||
score_threshold=0.8,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
@ -116,21 +114,21 @@ class LocalaiRerankModel(RerankModel):
|
||||
return {
|
||||
InvokeConnectionError: [httpx.ConnectError],
|
||||
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||
InvokeBadRequestError: [httpx.RequestError]
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||
InvokeBadRequestError: [httpx.RequestError],
|
||||
}
|
||||
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.RERANK,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={}
|
||||
model_properties={},
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
@ -32,8 +32,8 @@ class LocalAISpeech2text(Speech2TextModel):
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
|
||||
url = str(URL(credentials['server_url']) / "v1/audio/transcriptions")
|
||||
|
||||
url = str(URL(credentials["server_url"]) / "v1/audio/transcriptions")
|
||||
data = {"model": model}
|
||||
files = {"file": file}
|
||||
|
||||
@ -42,7 +42,7 @@ class LocalAISpeech2text(Speech2TextModel):
|
||||
prepared_request = session.prepare_request(request)
|
||||
response = session.send(prepared_request)
|
||||
|
||||
if 'error' in response.json():
|
||||
if "error" in response.json():
|
||||
raise InvokeServerUnavailableError("Empty response")
|
||||
|
||||
return response.json()["text"]
|
||||
@ -58,7 +58,7 @@ class LocalAISpeech2text(Speech2TextModel):
|
||||
try:
|
||||
audio_file_path = self._get_demo_file_path()
|
||||
|
||||
with open(audio_file_path, 'rb') as audio_file:
|
||||
with open(audio_file_path, "rb") as audio_file:
|
||||
self._invoke(model, credentials, audio_file)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
@ -66,36 +66,24 @@ class LocalAISpeech2text(Speech2TextModel):
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
InvokeBadRequestError
|
||||
],
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [InvokeBadRequestError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
used to define customizable model schema
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
label=I18nObject(en_US=model),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.SPEECH2TEXT,
|
||||
model_properties={},
|
||||
parameter_rules=[]
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
return entity
|
||||
return entity
|
||||
|
||||
@ -24,9 +24,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Jina text embedding model.
|
||||
"""
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
@ -37,39 +38,33 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
:return: embeddings result
|
||||
"""
|
||||
if len(texts) != 1:
|
||||
raise InvokeBadRequestError('Only one text is supported')
|
||||
raise InvokeBadRequestError("Only one text is supported")
|
||||
|
||||
server_url = credentials['server_url']
|
||||
server_url = credentials["server_url"]
|
||||
model_name = model
|
||||
if not server_url:
|
||||
raise CredentialsValidateFailedError('server_url is required')
|
||||
raise CredentialsValidateFailedError("server_url is required")
|
||||
if not model_name:
|
||||
raise CredentialsValidateFailedError('model_name is required')
|
||||
|
||||
url = server_url
|
||||
headers = {
|
||||
'Authorization': 'Bearer 123',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
raise CredentialsValidateFailedError("model_name is required")
|
||||
|
||||
data = {
|
||||
'model': model_name,
|
||||
'input': texts[0]
|
||||
}
|
||||
url = server_url
|
||||
headers = {"Authorization": "Bearer 123", "Content-Type": "application/json"}
|
||||
|
||||
data = {"model": model_name, "input": texts[0]}
|
||||
|
||||
try:
|
||||
response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10)
|
||||
response = post(str(URL(url) / "embeddings"), headers=headers, data=dumps(data), timeout=10)
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(str(e))
|
||||
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
resp = response.json()
|
||||
code = resp['error']['code']
|
||||
msg = resp['error']['message']
|
||||
code = resp["error"]["code"]
|
||||
msg = resp["error"]["message"]
|
||||
if code == 500:
|
||||
raise InvokeServerUnavailableError(msg)
|
||||
|
||||
|
||||
if response.status_code == 401:
|
||||
raise InvokeAuthorizationError(msg)
|
||||
elif response.status_code == 429:
|
||||
@ -79,23 +74,21 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
else:
|
||||
raise InvokeError(msg)
|
||||
except JSONDecodeError as e:
|
||||
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
raise InvokeServerUnavailableError(
|
||||
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||
)
|
||||
|
||||
try:
|
||||
resp = response.json()
|
||||
embeddings = resp['data']
|
||||
usage = resp['usage']
|
||||
embeddings = resp["data"]
|
||||
usage = resp["usage"]
|
||||
except Exception as e:
|
||||
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens'])
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=[[
|
||||
float(data) for data in x['embedding']
|
||||
] for x in embeddings],
|
||||
usage=usage
|
||||
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
|
||||
)
|
||||
|
||||
return result
|
||||
@ -114,7 +107,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
# use GPT2Tokenizer to get num tokens
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
return num_tokens
|
||||
|
||||
|
||||
def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
Get customizable model schema
|
||||
@ -130,10 +123,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
features=[],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512")),
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
},
|
||||
parameter_rules=[]
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
@ -145,32 +138,22 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except InvokeAuthorizationError:
|
||||
raise CredentialsValidateFailedError('Invalid credentials')
|
||||
raise CredentialsValidateFailedError("Invalid credentials")
|
||||
except InvokeConnectionError as e:
|
||||
raise CredentialsValidateFailedError(f'Invalid credentials: {e}')
|
||||
raise CredentialsValidateFailedError(f"Invalid credentials: {e}")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
KeyError
|
||||
]
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [KeyError],
|
||||
}
|
||||
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
@ -182,10 +165,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
@ -196,7 +176,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@ -17,42 +17,48 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
|
||||
|
||||
class MinimaxChatCompletion:
|
||||
"""
|
||||
Minimax Chat Completion API
|
||||
Minimax Chat Completion API
|
||||
"""
|
||||
def generate(self, model: str, api_key: str, group_id: str,
|
||||
prompt_messages: list[MinimaxMessage], model_parameters: dict,
|
||||
tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
|
||||
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
|
||||
|
||||
def generate(
|
||||
self,
|
||||
model: str,
|
||||
api_key: str,
|
||||
group_id: str,
|
||||
prompt_messages: list[MinimaxMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[dict[str, Any]],
|
||||
stop: list[str] | None,
|
||||
stream: bool,
|
||||
user: str,
|
||||
) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
|
||||
"""
|
||||
generate chat completion
|
||||
generate chat completion
|
||||
"""
|
||||
if not api_key or not group_id:
|
||||
raise InvalidAPIKeyError('Invalid API key or group ID')
|
||||
|
||||
url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}'
|
||||
raise InvalidAPIKeyError("Invalid API key or group ID")
|
||||
|
||||
url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}"
|
||||
|
||||
extra_kwargs = {}
|
||||
|
||||
if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int:
|
||||
extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens']
|
||||
if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
|
||||
extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]
|
||||
|
||||
if 'temperature' in model_parameters and type(model_parameters['temperature']) == float:
|
||||
extra_kwargs['temperature'] = model_parameters['temperature']
|
||||
if "temperature" in model_parameters and type(model_parameters["temperature"]) == float:
|
||||
extra_kwargs["temperature"] = model_parameters["temperature"]
|
||||
|
||||
if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
|
||||
extra_kwargs['top_p'] = model_parameters['top_p']
|
||||
if "top_p" in model_parameters and type(model_parameters["top_p"]) == float:
|
||||
extra_kwargs["top_p"] = model_parameters["top_p"]
|
||||
|
||||
prompt = '你是一个什么都懂的专家'
|
||||
prompt = "你是一个什么都懂的专家"
|
||||
|
||||
role_meta = {
|
||||
'user_name': '我',
|
||||
'bot_name': '专家'
|
||||
}
|
||||
role_meta = {"user_name": "我", "bot_name": "专家"}
|
||||
|
||||
# check if there is a system message
|
||||
if len(prompt_messages) == 0:
|
||||
raise BadRequestError('At least one message is required')
|
||||
|
||||
raise BadRequestError("At least one message is required")
|
||||
|
||||
if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
|
||||
if prompt_messages[0].content:
|
||||
prompt = prompt_messages[0].content
|
||||
@ -60,40 +66,39 @@ class MinimaxChatCompletion:
|
||||
|
||||
# check if there is a user message
|
||||
if len(prompt_messages) == 0:
|
||||
raise BadRequestError('At least one user message is required')
|
||||
|
||||
messages = [{
|
||||
'sender_type': message.role,
|
||||
'text': message.content,
|
||||
} for message in prompt_messages]
|
||||
raise BadRequestError("At least one user message is required")
|
||||
|
||||
headers = {
|
||||
'Authorization': 'Bearer ' + api_key,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
messages = [
|
||||
{
|
||||
"sender_type": message.role,
|
||||
"text": message.content,
|
||||
}
|
||||
for message in prompt_messages
|
||||
]
|
||||
|
||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||
|
||||
body = {
|
||||
'model': model,
|
||||
'messages': messages,
|
||||
'prompt': prompt,
|
||||
'role_meta': role_meta,
|
||||
'stream': stream,
|
||||
**extra_kwargs
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"prompt": prompt,
|
||||
"role_meta": role_meta,
|
||||
"stream": stream,
|
||||
**extra_kwargs,
|
||||
}
|
||||
|
||||
try:
|
||||
response = post(
|
||||
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
|
||||
response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
|
||||
except Exception as e:
|
||||
raise InternalServerError(e)
|
||||
|
||||
|
||||
if response.status_code != 200:
|
||||
raise InternalServerError(response.text)
|
||||
|
||||
|
||||
if stream:
|
||||
return self._handle_stream_chat_generate_response(response)
|
||||
return self._handle_chat_generate_response(response)
|
||||
|
||||
|
||||
def _handle_error(self, code: int, msg: str):
|
||||
if code == 1000 or code == 1001 or code == 1013 or code == 1027:
|
||||
raise InternalServerError(msg)
|
||||
@ -110,65 +115,52 @@ class MinimaxChatCompletion:

def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage:
"""
handle chat generate response
handle chat generate response
"""
response = response.json()
if 'base_resp' in response and response['base_resp']['status_code'] != 0:
code = response['base_resp']['status_code']
msg = response['base_resp']['status_msg']
if "base_resp" in response and response["base_resp"]["status_code"] != 0:
code = response["base_resp"]["status_code"]
msg = response["base_resp"]["status_msg"]
self._handle_error(code, msg)

message = MinimaxMessage(
content=response['reply'],
role=MinimaxMessage.Role.ASSISTANT.value
)

message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message.usage = {
'prompt_tokens': 0,
'completion_tokens': response['usage']['total_tokens'],
'total_tokens': response['usage']['total_tokens']
"prompt_tokens": 0,
"completion_tokens": response["usage"]["total_tokens"],
"total_tokens": response["usage"]["total_tokens"],
}
message.stop_reason = response['choices'][0]['finish_reason']
message.stop_reason = response["choices"][0]["finish_reason"]
return message

def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
"""
handle stream chat generate response
handle stream chat generate response
"""
for line in response.iter_lines():
if not line:
continue
line: str = line.decode('utf-8')
if line.startswith('data: '):
line: str = line.decode("utf-8")
if line.startswith("data: "):
line = line[6:].strip()
data = loads(line)

if 'base_resp' in data and data['base_resp']['status_code'] != 0:
code = data['base_resp']['status_code']
msg = data['base_resp']['status_msg']
if "base_resp" in data and data["base_resp"]["status_code"] != 0:
code = data["base_resp"]["status_code"]
msg = data["base_resp"]["status_msg"]
self._handle_error(code, msg)

if data['reply']:
total_tokens = data['usage']['total_tokens']
message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
content=''
)
message.usage = {
'prompt_tokens': 0,
'completion_tokens': total_tokens,
'total_tokens': total_tokens
}
message.stop_reason = data['choices'][0]['finish_reason']
if data["reply"]:
total_tokens = data["usage"]["total_tokens"]
message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="")
message.usage = {"prompt_tokens": 0, "completion_tokens": total_tokens, "total_tokens": total_tokens}
message.stop_reason = data["choices"][0]["finish_reason"]
yield message
return

choices = data.get('choices', [])
choices = data.get("choices", [])
if len(choices) == 0:
continue

for choice in choices:
message = choice['delta']
yield MinimaxMessage(
content=message,
role=MinimaxMessage.Role.ASSISTANT.value
)
message = choice["delta"]
yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value)

@ -17,86 +17,83 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage

class MinimaxChatCompletionPro:
"""
Minimax Chat Completion Pro API, supports function calling
however, we do not have enough time and energy to implement it, but the parameters are reserved
Minimax Chat Completion Pro API, supports function calling
however, we do not have enough time and energy to implement it, but the parameters are reserved
"""
def generate(self, model: str, api_key: str, group_id: str,
prompt_messages: list[MinimaxMessage], model_parameters: dict,
tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:

def generate(
self,
model: str,
api_key: str,
group_id: str,
prompt_messages: list[MinimaxMessage],
model_parameters: dict,
tools: list[dict[str, Any]],
stop: list[str] | None,
stream: bool,
user: str,
) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
"""
generate chat completion
generate chat completion
"""
if not api_key or not group_id:
raise InvalidAPIKeyError('Invalid API key or group ID')
raise InvalidAPIKeyError("Invalid API key or group ID")

url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}'
url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}"

extra_kwargs = {}

if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int:
extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens']
if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]

if 'temperature' in model_parameters and type(model_parameters['temperature']) == float:
extra_kwargs['temperature'] = model_parameters['temperature']
if "temperature" in model_parameters and type(model_parameters["temperature"]) == float:
extra_kwargs["temperature"] = model_parameters["temperature"]

if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
extra_kwargs['top_p'] = model_parameters['top_p']
if "top_p" in model_parameters and type(model_parameters["top_p"]) == float:
extra_kwargs["top_p"] = model_parameters["top_p"]

if 'mask_sensitive_info' in model_parameters and type(model_parameters['mask_sensitive_info']) == bool:
extra_kwargs['mask_sensitive_info'] = model_parameters['mask_sensitive_info']

if model_parameters.get('plugin_web_search'):
extra_kwargs['plugins'] = [
'plugin_web_search'
]
if "mask_sensitive_info" in model_parameters and type(model_parameters["mask_sensitive_info"]) == bool:
extra_kwargs["mask_sensitive_info"] = model_parameters["mask_sensitive_info"]

bot_setting = {
'bot_name': '专家',
'content': '你是一个什么都懂的专家'
}
if model_parameters.get("plugin_web_search"):
extra_kwargs["plugins"] = ["plugin_web_search"]

reply_constraints = {
'sender_type': 'BOT',
'sender_name': '专家'
}
bot_setting = {"bot_name": "专家", "content": "你是一个什么都懂的专家"}

reply_constraints = {"sender_type": "BOT", "sender_name": "专家"}

# check if there is a system message
if len(prompt_messages) == 0:
raise BadRequestError('At least one message is required')
raise BadRequestError("At least one message is required")

if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
if prompt_messages[0].content:
bot_setting['content'] = prompt_messages[0].content
bot_setting["content"] = prompt_messages[0].content
prompt_messages = prompt_messages[1:]

# check if there is a user message
if len(prompt_messages) == 0:
raise BadRequestError('At least one user message is required')
raise BadRequestError("At least one user message is required")

messages = [message.to_dict() for message in prompt_messages]

headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
}
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}

body = {
'model': model,
'messages': messages,
'bot_setting': [bot_setting],
'reply_constraints': reply_constraints,
'stream': stream,
**extra_kwargs
"model": model,
"messages": messages,
"bot_setting": [bot_setting],
"reply_constraints": reply_constraints,
"stream": stream,
**extra_kwargs,
}

if tools:
body['functions'] = tools
body['function_call'] = {'type': 'auto'}
body["functions"] = tools
body["function_call"] = {"type": "auto"}

try:
response = post(
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
except Exception as e:
raise InternalServerError(e)

@ -123,78 +120,72 @@ class MinimaxChatCompletionPro:
|
||||
|
||||
def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage:
|
||||
"""
|
||||
handle chat generate response
|
||||
handle chat generate response
|
||||
"""
|
||||
response = response.json()
|
||||
if 'base_resp' in response and response['base_resp']['status_code'] != 0:
|
||||
code = response['base_resp']['status_code']
|
||||
msg = response['base_resp']['status_msg']
|
||||
if "base_resp" in response and response["base_resp"]["status_code"] != 0:
|
||||
code = response["base_resp"]["status_code"]
|
||||
msg = response["base_resp"]["status_msg"]
|
||||
self._handle_error(code, msg)
|
||||
|
||||
message = MinimaxMessage(
|
||||
content=response['reply'],
|
||||
role=MinimaxMessage.Role.ASSISTANT.value
|
||||
)
|
||||
message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
|
||||
message.usage = {
|
||||
'prompt_tokens': 0,
|
||||
'completion_tokens': response['usage']['total_tokens'],
|
||||
'total_tokens': response['usage']['total_tokens']
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": response["usage"]["total_tokens"],
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
}
|
||||
message.stop_reason = response['choices'][0]['finish_reason']
|
||||
message.stop_reason = response["choices"][0]["finish_reason"]
|
||||
return message
|
||||
|
||||
def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
|
||||
"""
|
||||
handle stream chat generate response
|
||||
handle stream chat generate response
|
||||
"""
|
||||
for line in response.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
line: str = line.decode('utf-8')
|
||||
if line.startswith('data: '):
|
||||
line: str = line.decode("utf-8")
|
||||
if line.startswith("data: "):
|
||||
line = line[6:].strip()
|
||||
data = loads(line)
|
||||
|
||||
if 'base_resp' in data and data['base_resp']['status_code'] != 0:
|
||||
code = data['base_resp']['status_code']
|
||||
msg = data['base_resp']['status_msg']
|
||||
if "base_resp" in data and data["base_resp"]["status_code"] != 0:
|
||||
code = data["base_resp"]["status_code"]
|
||||
msg = data["base_resp"]["status_msg"]
|
||||
self._handle_error(code, msg)
|
||||
|
||||
# final chunk
|
||||
if data['reply'] or data.get('usage'):
|
||||
total_tokens = data['usage']['total_tokens']
|
||||
minimax_message = MinimaxMessage(
|
||||
role=MinimaxMessage.Role.ASSISTANT.value,
|
||||
content=''
|
||||
)
|
||||
if data["reply"] or data.get("usage"):
|
||||
total_tokens = data["usage"]["total_tokens"]
|
||||
minimax_message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="")
|
||||
minimax_message.usage = {
|
||||
'prompt_tokens': 0,
|
||||
'completion_tokens': total_tokens,
|
||||
'total_tokens': total_tokens
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": total_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
minimax_message.stop_reason = data['choices'][0]['finish_reason']
|
||||
minimax_message.stop_reason = data["choices"][0]["finish_reason"]
|
||||
|
||||
choices = data.get('choices', [])
|
||||
choices = data.get("choices", [])
|
||||
if len(choices) > 0:
|
||||
for choice in choices:
|
||||
message = choice['messages'][0]
|
||||
message = choice["messages"][0]
|
||||
# append function_call message
|
||||
if 'function_call' in message:
|
||||
function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
|
||||
function_call_message.function_call = message['function_call']
|
||||
if "function_call" in message:
|
||||
function_call_message = MinimaxMessage(content="", role=MinimaxMessage.Role.ASSISTANT.value)
|
||||
function_call_message.function_call = message["function_call"]
|
||||
yield function_call_message
|
||||
|
||||
yield minimax_message
|
||||
return
|
||||
|
||||
# partial chunk
|
||||
choices = data.get('choices', [])
|
||||
choices = data.get("choices", [])
|
||||
if len(choices) == 0:
|
||||
continue
|
||||
|
||||
for choice in choices:
|
||||
message = choice['messages'][0]
|
||||
message = choice["messages"][0]
|
||||
# append text message
|
||||
if 'text' in message:
|
||||
minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value)
|
||||
if "text" in message:
|
||||
minimax_message = MinimaxMessage(content=message["text"], role=MinimaxMessage.Role.ASSISTANT.value)
|
||||
yield minimax_message
|
||||
|
||||
@ -1,17 +1,22 @@
class InvalidAuthenticationError(Exception):
pass

class InvalidAPIKeyError(Exception):
pass

class RateLimitReachedError(Exception):
pass

class InsufficientAccountBalanceError(Exception):
pass

class InternalServerError(Exception):
pass

class BadRequestError(Exception):
pass
pass

@ -34,18 +34,25 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
|
||||
|
||||
class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
model_apis = {
|
||||
'abab6.5s-chat': MinimaxChatCompletionPro,
|
||||
'abab6.5-chat': MinimaxChatCompletionPro,
|
||||
'abab6-chat': MinimaxChatCompletionPro,
|
||||
'abab5.5s-chat': MinimaxChatCompletionPro,
|
||||
'abab5.5-chat': MinimaxChatCompletionPro,
|
||||
'abab5-chat': MinimaxChatCompletion
|
||||
"abab6.5s-chat": MinimaxChatCompletionPro,
|
||||
"abab6.5-chat": MinimaxChatCompletionPro,
|
||||
"abab6-chat": MinimaxChatCompletionPro,
|
||||
"abab5.5s-chat": MinimaxChatCompletionPro,
|
||||
"abab5.5-chat": MinimaxChatCompletionPro,
|
||||
"abab5-chat": MinimaxChatCompletion,
|
||||
}
|
||||
|
||||
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
@ -53,82 +60,97 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
Validate credentials for Minimax model
|
||||
"""
|
||||
if model not in self.model_apis:
|
||||
raise CredentialsValidateFailedError(f'Invalid model: {model}')
|
||||
raise CredentialsValidateFailedError(f"Invalid model: {model}")
|
||||
|
||||
if not credentials.get('minimax_api_key'):
|
||||
raise CredentialsValidateFailedError('Invalid API key')
|
||||
if not credentials.get("minimax_api_key"):
|
||||
raise CredentialsValidateFailedError("Invalid API key")
|
||||
|
||||
if not credentials.get("minimax_group_id"):
|
||||
raise CredentialsValidateFailedError("Invalid group ID")
|
||||
|
||||
if not credentials.get('minimax_group_id'):
|
||||
raise CredentialsValidateFailedError('Invalid group ID')
|
||||
|
||||
# ping
|
||||
instance = MinimaxChatCompletionPro()
|
||||
try:
|
||||
instance.generate(
|
||||
model=model, api_key=credentials['minimax_api_key'], group_id=credentials['minimax_group_id'],
|
||||
prompt_messages=[
|
||||
MinimaxMessage(content='ping', role='USER')
|
||||
],
|
||||
model=model,
|
||||
api_key=credentials["minimax_api_key"],
|
||||
group_id=credentials["minimax_group_id"],
|
||||
prompt_messages=[MinimaxMessage(content="ping", role="USER")],
|
||||
model_parameters={},
|
||||
tools=[], stop=[],
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=False,
|
||||
user=''
|
||||
user="",
|
||||
)
|
||||
except (InvalidAuthenticationError, InsufficientAccountBalanceError) as e:
|
||||
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None) -> int:
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
return self._num_tokens_from_messages(prompt_messages, tools)
|
||||
|
||||
def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int:
|
||||
"""
|
||||
Calculate num tokens for minimax model
|
||||
Calculate num tokens for minimax model
|
||||
|
||||
not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way
|
||||
to calculate the num tokens, so we use str() to convert the prompt to string
|
||||
not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way
|
||||
to calculate the num tokens, so we use str() to convert the prompt to string
|
||||
|
||||
Minimax does not provide their own tokenizer of adab5.5 and abab5 model
|
||||
therefore, we use gpt2 tokenizer instead
|
||||
Minimax does not provide their own tokenizer of adab5.5 and abab5 model
|
||||
therefore, we use gpt2 tokenizer instead
|
||||
"""
|
||||
messages_dict = [self._convert_prompt_message_to_minimax_message(m).to_dict() for m in messages]
|
||||
return self._get_num_tokens_by_gpt2(str(messages_dict))
|
||||
|
||||
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
"""
|
||||
use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface
|
||||
use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface
|
||||
"""
|
||||
client: MinimaxChatCompletionPro = self.model_apis[model]()
|
||||
|
||||
if tools:
|
||||
tools = [{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters
|
||||
} for tool in tools]
|
||||
tools = [
|
||||
{"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools
|
||||
]
|
||||
|
||||
response = client.generate(
|
||||
model=model,
|
||||
api_key=credentials['minimax_api_key'],
|
||||
group_id=credentials['minimax_group_id'],
|
||||
api_key=credentials["minimax_api_key"],
|
||||
group_id=credentials["minimax_group_id"],
|
||||
prompt_messages=[self._convert_prompt_message_to_minimax_message(message) for message in prompt_messages],
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response)
|
||||
return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response)
|
||||
return self._handle_chat_generate_stream_response(
|
||||
model=model, prompt_messages=prompt_messages, credentials=credentials, response=response
|
||||
)
|
||||
return self._handle_chat_generate_response(
|
||||
model=model, prompt_messages=prompt_messages, credentials=credentials, response=response
|
||||
)
|
||||
|
||||
def _convert_prompt_message_to_minimax_message(self, prompt_message: PromptMessage) -> MinimaxMessage:
|
||||
"""
|
||||
convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface
|
||||
convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface
|
||||
"""
|
||||
if isinstance(prompt_message, SystemPromptMessage):
|
||||
return MinimaxMessage(role=MinimaxMessage.Role.SYSTEM.value, content=prompt_message.content)
|
||||
@ -136,26 +158,27 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
|
||||
elif isinstance(prompt_message, AssistantPromptMessage):
|
||||
if prompt_message.tool_calls:
|
||||
message = MinimaxMessage(
|
||||
role=MinimaxMessage.Role.ASSISTANT.value,
|
||||
content=''
|
||||
)
|
||||
message.function_call={
|
||||
'name': prompt_message.tool_calls[0].function.name,
|
||||
'arguments': prompt_message.tool_calls[0].function.arguments
|
||||
message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="")
|
||||
message.function_call = {
|
||||
"name": prompt_message.tool_calls[0].function.name,
|
||||
"arguments": prompt_message.tool_calls[0].function.arguments,
|
||||
}
|
||||
return message
|
||||
return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
|
||||
elif isinstance(prompt_message, ToolPromptMessage):
|
||||
return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
|
||||
else:
|
||||
raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')
|
||||
raise NotImplementedError(f"Prompt message type {type(prompt_message)} is not supported")
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage) -> LLMResult:
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=response.usage['prompt_tokens'],
|
||||
completion_tokens=response.usage['completion_tokens']
|
||||
)
|
||||
def _handle_chat_generate_response(
|
||||
self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage
|
||||
) -> LLMResult:
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=response.usage["prompt_tokens"],
|
||||
completion_tokens=response.usage["completion_tokens"],
|
||||
)
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
@ -166,31 +189,33 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage],
|
||||
credentials: dict, response: Generator[MinimaxMessage, None, None]) \
|
||||
-> Generator[LLMResultChunk, None, None]:
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Generator[MinimaxMessage, None, None],
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
for message in response:
|
||||
if message.usage:
|
||||
usage = self._calc_response_usage(
|
||||
model=model, credentials=credentials,
|
||||
prompt_tokens=message.usage['prompt_tokens'],
|
||||
completion_tokens=message.usage['completion_tokens']
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=message.usage["prompt_tokens"],
|
||||
completion_tokens=message.usage["completion_tokens"],
|
||||
)
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=message.content,
|
||||
tool_calls=[]
|
||||
),
|
||||
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
||||
usage=usage,
|
||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
||||
),
|
||||
)
|
||||
elif message.function_call:
|
||||
if 'name' not in message.function_call or 'arguments' not in message.function_call:
|
||||
if "name" not in message.function_call or "arguments" not in message.function_call:
|
||||
continue
|
||||
|
||||
yield LLMResultChunk(
|
||||
@ -199,15 +224,16 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content='',
|
||||
tool_calls=[AssistantPromptMessage.ToolCall(
|
||||
id='',
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=message.function_call['name'],
|
||||
arguments=message.function_call['arguments']
|
||||
content="",
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id="",
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=message.function_call["name"], arguments=message.function_call["arguments"]
|
||||
),
|
||||
)
|
||||
)]
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
@ -217,10 +243,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=message.content,
|
||||
tool_calls=[]
|
||||
),
|
||||
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
||||
),
|
||||
)
|
||||
@ -236,22 +259,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitReachedError
|
||||
],
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [InternalServerError],
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalanceError,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
BadRequestError,
|
||||
KeyError
|
||||
]
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
}
|
||||
|
||||
|
||||
@ -4,32 +4,27 @@ from typing import Any
|
||||
|
||||
class MinimaxMessage:
|
||||
class Role(Enum):
|
||||
USER = 'USER'
|
||||
ASSISTANT = 'BOT'
|
||||
SYSTEM = 'SYSTEM'
|
||||
FUNCTION = 'FUNCTION'
|
||||
USER = "USER"
|
||||
ASSISTANT = "BOT"
|
||||
SYSTEM = "SYSTEM"
|
||||
FUNCTION = "FUNCTION"
|
||||
|
||||
role: str = Role.USER.value
|
||||
content: str
|
||||
usage: dict[str, int] = None
|
||||
stop_reason: str = ''
|
||||
stop_reason: str = ""
|
||||
function_call: dict[str, Any] = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
|
||||
return {
|
||||
'sender_type': 'BOT',
|
||||
'sender_name': '专家',
|
||||
'text': '',
|
||||
'function_call': self.function_call
|
||||
}
|
||||
|
||||
return {"sender_type": "BOT", "sender_name": "专家", "text": "", "function_call": self.function_call}
|
||||
|
||||
return {
|
||||
'sender_type': self.role,
|
||||
'sender_name': '我' if self.role == 'USER' else '专家',
|
||||
'text': self.content,
|
||||
"sender_type": self.role,
|
||||
"sender_name": "我" if self.role == "USER" else "专家",
|
||||
"text": self.content,
|
||||
}
|
||||
|
||||
def __init__(self, content: str, role: str = 'USER') -> None:
|
||||
|
||||
def __init__(self, content: str, role: str = "USER") -> None:
|
||||
self.content = content
|
||||
self.role = role
|
||||
self.role = role
|
||||
|
||||
@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MinimaxProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
@ -19,12 +20,9 @@ class MinimaxProvider(ModelProvider):
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `abab5.5-chat` model for validate,
|
||||
model_instance.validate_credentials(
|
||||
model='abab5.5-chat',
|
||||
credentials=credentials
|
||||
)
|
||||
model_instance.validate_credentials(model="abab5.5-chat", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise CredentialsValidateFailedError(f'{ex}')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise CredentialsValidateFailedError(f"{ex}")
|
||||
|
||||
@ -30,11 +30,12 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Minimax text embedding model.
|
||||
"""
|
||||
api_base: str = 'https://api.minimax.chat/v1/embeddings'
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
api_base: str = "https://api.minimax.chat/v1/embeddings"
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
@ -44,54 +45,43 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
api_key = credentials['minimax_api_key']
|
||||
group_id = credentials['minimax_group_id']
|
||||
if model != 'embo-01':
|
||||
raise ValueError('Invalid model name')
|
||||
api_key = credentials["minimax_api_key"]
|
||||
group_id = credentials["minimax_group_id"]
|
||||
if model != "embo-01":
|
||||
raise ValueError("Invalid model name")
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError('api_key is required')
|
||||
url = f'{self.api_base}?GroupId={group_id}'
|
||||
headers = {
|
||||
'Authorization': 'Bearer ' + api_key,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
raise CredentialsValidateFailedError("api_key is required")
|
||||
url = f"{self.api_base}?GroupId={group_id}"
|
||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||
|
||||
data = {
|
||||
'model': 'embo-01',
|
||||
'texts': texts,
|
||||
'type': 'db'
|
||||
}
|
||||
data = {"model": "embo-01", "texts": texts, "type": "db"}
|
||||
|
||||
try:
|
||||
response = post(url, headers=headers, data=dumps(data))
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(str(e))
|
||||
|
||||
|
||||
if response.status_code != 200:
|
||||
raise InvokeServerUnavailableError(response.text)
|
||||
|
||||
|
||||
try:
|
||||
resp = response.json()
|
||||
# check if there is an error
|
||||
if resp['base_resp']['status_code'] != 0:
|
||||
code = resp['base_resp']['status_code']
|
||||
msg = resp['base_resp']['status_msg']
|
||||
if resp["base_resp"]["status_code"] != 0:
|
||||
code = resp["base_resp"]["status_code"]
|
||||
msg = resp["base_resp"]["status_msg"]
|
||||
self._handle_error(code, msg)
|
||||
|
||||
embeddings = resp['vectors']
|
||||
total_tokens = resp['total_tokens']
|
||||
embeddings = resp["vectors"]
|
||||
total_tokens = resp["total_tokens"]
|
||||
except InvalidAuthenticationError:
|
||||
raise InvalidAPIKeyError('Invalid api key')
|
||||
raise InvalidAPIKeyError("Invalid api key")
|
||||
except KeyError as e:
|
||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens)
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=embeddings,
|
||||
usage=usage
|
||||
)
|
||||
result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage)
|
||||
|
||||
return result
|
||||
|
||||
@ -119,9 +109,9 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except InvalidAPIKeyError:
|
||||
raise CredentialsValidateFailedError('Invalid api key')
|
||||
raise CredentialsValidateFailedError("Invalid api key")
|
||||
|
||||
def _handle_error(self, code: int, msg: str):
|
||||
if code == 1000 or code == 1001:
|
||||
@ -148,25 +138,17 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitReachedError
|
||||
],
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [InternalServerError],
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalanceError,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
BadRequestError,
|
||||
KeyError
|
||||
]
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
}
|
||||
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
@ -178,10 +160,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
@ -192,7 +171,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@ -7,14 +7,19 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
|
||||
|
||||
|
||||
class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
|
||||
|
||||
# mistral does not support user/stop arguments
|
||||
stop = []
|
||||
user = None
|
||||
@ -27,5 +32,5 @@ class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials['mode'] = 'chat'
|
||||
credentials['endpoint_url'] = 'https://api.mistral.ai/v1'
|
||||
credentials["mode"] = "chat"
|
||||
credentials["endpoint_url"] = "https://api.mistral.ai/v1"
|
||||
|
||||
@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MistralAIProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
@ -19,12 +18,9 @@ class MistralAIProvider(ModelProvider):
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
model_instance.validate_credentials(
|
||||
model='open-mistral-7b',
|
||||
credentials=credentials
|
||||
)
|
||||
model_instance.validate_credentials(model="open-mistral-7b", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
||||
@ -30,11 +30,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
|
||||
|
||||
|
||||
class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
self._add_function_call(model, credentials)
|
||||
user = user[:32] if user else None
|
||||
@ -49,50 +55,50 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
model=model,
|
||||
label=I18nObject(en_US=model, zh_Hans=model),
|
||||
model_type=ModelType.LLM,
|
||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
||||
if credentials.get('function_calling_type') == 'tool_call'
|
||||
else [],
|
||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
||||
if credentials.get("function_calling_type") == "tool_call"
|
||||
else [],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
use_template='temperature',
|
||||
label=I18nObject(en_US='Temperature', zh_Hans='温度'),
|
||||
name="temperature",
|
||||
use_template="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens',
|
||||
use_template='max_tokens',
|
||||
name="max_tokens",
|
||||
use_template="max_tokens",
|
||||
default=512,
|
||||
min=1,
|
||||
max=int(credentials.get('max_tokens', 4096)),
|
||||
label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
|
||||
max=int(credentials.get("max_tokens", 4096)),
|
||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
||||
type=ParameterType.INT,
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
use_template='top_p',
|
||||
label=I18nObject(en_US='Top P', zh_Hans='Top P'),
|
||||
name="top_p",
|
||||
use_template="top_p",
|
||||
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
def _add_custom_parameters(self, credentials: dict) -> None:
|
||||
credentials['mode'] = 'chat'
|
||||
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
||||
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
|
||||
credentials["mode"] = "chat"
|
||||
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
||||
credentials["endpoint_url"] = "https://api.moonshot.cn/v1"
|
||||
|
||||
def _add_function_call(self, model: str, credentials: dict) -> None:
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
if model_schema and {
|
||||
ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
|
||||
}.intersection(model_schema.features or []):
|
||||
credentials['function_calling_type'] = 'tool_call'
|
||||
if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection(
|
||||
model_schema.features or []
|
||||
):
|
||||
credentials["function_calling_type"] = "tool_call"
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
@ -107,19 +113,13 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(PromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "text",
|
||||
"text": message_content.data
|
||||
}
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": message_content.data,
|
||||
"detail": message_content.detail.value
|
||||
}
|
||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
@ -129,14 +129,16 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
if message.tool_calls:
|
||||
message_dict["tool_calls"] = []
|
||||
for function_call in message.tool_calls:
|
||||
message_dict["tool_calls"].append({
|
||||
"id": function_call.id,
|
||||
"type": function_call.type,
|
||||
"function": {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments
|
||||
message_dict["tool_calls"].append(
|
||||
{
|
||||
"id": function_call.id,
|
||||
"type": function_call.type,
|
||||
"function": {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments,
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
|
||||
@ -162,21 +164,26 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
if response_tool_calls:
|
||||
for response_tool_call in response_tool_calls:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
|
||||
arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
|
||||
name=response_tool_call["function"]["name"]
|
||||
if response_tool_call.get("function", {}).get("name")
|
||||
else "",
|
||||
arguments=response_tool_call["function"]["arguments"]
|
||||
if response_tool_call.get("function", {}).get("arguments")
|
||||
else "",
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call["id"] if response_tool_call.get("id") else "",
|
||||
type=response_tool_call["type"] if response_tool_call.get("type") else "",
|
||||
function=function
|
||||
function=function,
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
@ -186,11 +193,12 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
full_assistant_content = ''
|
||||
full_assistant_content = ""
|
||||
chunk_index = 0
|
||||
|
||||
def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
|
||||
-> LLMResultChunk:
|
||||
def create_final_llm_result_chunk(
|
||||
index: int, message: AssistantPromptMessage, finish_reason: str
|
||||
) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
||||
@ -201,12 +209,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
return LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=message,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
|
||||
)
|
||||
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
@ -220,9 +223,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id='',
|
||||
type='',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
|
||||
id="",
|
||||
type="",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
|
||||
)
|
||||
tools_calls.append(tool_call)
|
||||
|
||||
@ -244,9 +247,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
|
||||
if chunk:
|
||||
# ignore sse comments
|
||||
if chunk.startswith(':'):
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
|
||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
||||
chunk_json = None
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
@ -255,21 +258,21 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index + 1,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason="Non-JSON encountered."
|
||||
finish_reason="Non-JSON encountered.",
|
||||
)
|
||||
break
|
||||
if not chunk_json or len(chunk_json['choices']) == 0:
|
||||
if not chunk_json or len(chunk_json["choices"]) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk_json['choices'][0]
|
||||
finish_reason = chunk_json['choices'][0].get('finish_reason')
|
||||
choice = chunk_json["choices"][0]
|
||||
finish_reason = chunk_json["choices"][0].get("finish_reason")
|
||||
chunk_index += 1
|
||||
|
||||
if 'delta' in choice:
|
||||
delta = choice['delta']
|
||||
delta_content = delta.get('content')
|
||||
if "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
delta_content = delta.get("content")
|
||||
|
||||
assistant_message_tool_calls = delta.get('tool_calls', None)
|
||||
assistant_message_tool_calls = delta.get("tool_calls", None)
|
||||
# assistant_message_function_call = delta.delta.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
@ -277,19 +280,18 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
increase_tool_call(tool_calls)
|
||||
|
||||
if delta_content is None or delta_content == '':
|
||||
if delta_content is None or delta_content == "":
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta_content,
|
||||
tool_calls=tool_calls if assistant_message_tool_calls else []
|
||||
content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else []
|
||||
)
|
||||
|
||||
full_assistant_content += delta_content
|
||||
elif 'text' in choice:
|
||||
choice_text = choice.get('text', '')
|
||||
if choice_text == '':
|
||||
elif "text" in choice:
|
||||
choice_text = choice.get("text", "")
|
||||
if choice_text == "":
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
@ -305,26 +307,21 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
|
||||
if tools_calls:
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(
|
||||
tool_calls=tools_calls,
|
||||
content=""
|
||||
),
|
||||
)
|
||||
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
|
||||
),
|
||||
)
|
||||
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason
|
||||
)
|
||||
|
||||
@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MoonshotProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
@ -19,12 +18,9 @@ class MoonshotProvider(ModelProvider):
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
model_instance.validate_credentials(
|
||||
model='moonshot-v1-8k',
|
||||
credentials=credentials
|
||||
)
|
||||
model_instance.validate_credentials(model="moonshot-v1-8k", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
||||
@ -8,20 +8,25 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
|
||||
|
||||
|
||||
class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
|
||||
def _update_endpoint_url(self, credentials: dict):
|
||||
|
||||
credentials['endpoint_url'] = "https://api.novita.ai/v3/openai"
|
||||
credentials['extra_headers'] = { 'X-Novita-Source': 'dify.ai' }
|
||||
credentials["endpoint_url"] = "https://api.novita.ai/v3/openai"
|
||||
credentials["extra_headers"] = {"X-Novita-Source": "dify.ai"}
|
||||
return credentials
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
|
||||
return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
|
||||
self._add_custom_parameters(credentials, model)
|
||||
@ -29,21 +34,36 @@ class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
|
||||
@classmethod
|
||||
def _add_custom_parameters(cls, credentials: dict, model: str) -> None:
|
||||
credentials['mode'] = 'chat'
|
||||
credentials["mode"] = "chat"
|
||||
|
||||
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
|
||||
return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
return super()._generate(
|
||||
model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user
|
||||
)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
|
||||
|
||||
return super().get_customizable_model_schema(model, cred_with_endpoint)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
|
||||
|
||||
return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools)
|
||||
|
||||
@ -20,12 +20,9 @@ class NovitaProvider(ModelProvider):
|
||||
|
||||
# Use `meta-llama/llama-3-8b-instruct` model for validate,
|
||||
# no matter what model you pass in, text completion model or chat model
|
||||
model_instance.validate_credentials(
|
||||
model='meta-llama/llama-3-8b-instruct',
|
||||
credentials=credentials
|
||||
)
|
||||
model_instance.validate_credentials(model="meta-llama/llama-3-8b-instruct", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
||||
@ -21,31 +21,36 @@ from core.model_runtime.utils import helper

class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
MODEL_SUFFIX_MAP = {
'fuyu-8b': 'vlm/adept/fuyu-8b',
'mistralai/mistral-large': '',
'mistralai/mixtral-8x7b-instruct-v0.1': '',
'mistralai/mixtral-8x22b-instruct-v0.1': '',
'google/gemma-7b': '',
'google/codegemma-7b': '',
'snowflake/arctic':'',
'meta/llama2-70b': '',
'meta/llama3-8b-instruct': '',
'meta/llama3-70b-instruct': '',
'meta/llama-3.1-8b-instruct': '',
'meta/llama-3.1-70b-instruct': '',
'meta/llama-3.1-405b-instruct': '',
'google/recurrentgemma-2b': '',
'nvidia/nemotron-4-340b-instruct': '',
'microsoft/phi-3-medium-128k-instruct':'',
'microsoft/phi-3-mini-128k-instruct':''
"fuyu-8b": "vlm/adept/fuyu-8b",
"mistralai/mistral-large": "",
"mistralai/mixtral-8x7b-instruct-v0.1": "",
"mistralai/mixtral-8x22b-instruct-v0.1": "",
"google/gemma-7b": "",
"google/codegemma-7b": "",
"snowflake/arctic": "",
"meta/llama2-70b": "",
"meta/llama3-8b-instruct": "",
"meta/llama3-70b-instruct": "",
"meta/llama-3.1-8b-instruct": "",
"meta/llama-3.1-70b-instruct": "",
"meta/llama-3.1-405b-instruct": "",
"google/recurrentgemma-2b": "",
"nvidia/nemotron-4-340b-instruct": "",
"microsoft/phi-3-medium-128k-instruct": "",
"microsoft/phi-3-mini-128k-instruct": "",
}
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:

def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials, model)
prompt_messages = self._transform_prompt_messages(prompt_messages)
stop = []
@ -60,16 +65,14 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
for i, p in enumerate(prompt_messages):
if isinstance(p, UserPromptMessage) and isinstance(p.content, list):
content = p.content
content_text = ''
content_text = ""
for prompt_content in content:
if prompt_content.type == PromptMessageContentType.TEXT:
content_text += prompt_content.data
else:
content_text += f' <img src="{prompt_content.data}" />'

prompt_message = UserPromptMessage(
content=content_text
)
prompt_message = UserPromptMessage(content=content_text)
prompt_messages[i] = prompt_message
return prompt_messages

@ -78,15 +81,15 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
self._validate_credentials(model, credentials)

def _add_custom_parameters(self, credentials: dict, model: str) -> None:
credentials['mode'] = 'chat'

if self.MODEL_SUFFIX_MAP[model]:
credentials['server_url'] = f'https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}'
credentials.pop('endpoint_url')
else:
credentials['endpoint_url'] = 'https://integrate.api.nvidia.com/v1'
credentials["mode"] = "chat"

credentials['stream_mode_delimiter'] = '\n'
if self.MODEL_SUFFIX_MAP[model]:
credentials["server_url"] = f"https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}"
credentials.pop("endpoint_url")
else:
credentials["endpoint_url"] = "https://integrate.api.nvidia.com/v1"

credentials["stream_mode_delimiter"] = "\n"

def _validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -97,72 +100,67 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
:return:
"""
try:
headers = {
'Content-Type': 'application/json'
}
headers = {"Content-Type": "application/json"}

api_key = credentials.get('api_key')
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"

endpoint_url = credentials.get('endpoint_url')
if endpoint_url and not endpoint_url.endswith('/'):
endpoint_url += '/'
server_url = credentials.get('server_url')
endpoint_url = credentials.get("endpoint_url")
if endpoint_url and not endpoint_url.endswith("/"):
endpoint_url += "/"
server_url = credentials.get("server_url")

# prepare the payload for a simple ping to the model
data = {
'model': model,
'max_tokens': 5
}
data = {"model": model, "max_tokens": 5}

completion_type = LLMMode.value_of(credentials['mode'])
completion_type = LLMMode.value_of(credentials["mode"])

if completion_type is LLMMode.CHAT:
data['messages'] = [
{
"role": "user",
"content": "ping"
},
data["messages"] = [
{"role": "user", "content": "ping"},
]
if 'endpoint_url' in credentials:
endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions')
elif 'server_url' in credentials:
if "endpoint_url" in credentials:
endpoint_url = str(URL(endpoint_url) / "chat" / "completions")
elif "server_url" in credentials:
endpoint_url = server_url
elif completion_type is LLMMode.COMPLETION:
data['prompt'] = 'ping'
if 'endpoint_url' in credentials:
endpoint_url = str(URL(endpoint_url) / 'completions')
elif 'server_url' in credentials:
data["prompt"] = "ping"
if "endpoint_url" in credentials:
endpoint_url = str(URL(endpoint_url) / "completions")
elif "server_url" in credentials:
endpoint_url = server_url
else:
raise ValueError("Unsupported completion type for model configuration.")

# send a post request to validate the credentials
response = requests.post(
endpoint_url,
headers=headers,
json=data,
timeout=(10, 300)
)
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300))

if response.status_code != 200:
raise CredentialsValidateFailedError(
f'Credentials validation failed with status code {response.status_code}')
f"Credentials validation failed with status code {response.status_code}"
)

try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error')
raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
except CredentialsValidateFailedError:
raise
except Exception as ex:
raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}")

def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, \
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke llm completion model
@ -176,57 +174,51 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
:return: full response or stream response chunk generator result
"""
headers = {
'Content-Type': 'application/json',
'Accept-Charset': 'utf-8',
"Content-Type": "application/json",
"Accept-Charset": "utf-8",
}

api_key = credentials.get('api_key')
api_key = credentials.get("api_key")
if api_key:
headers['Authorization'] = f'Bearer {api_key}'
headers["Authorization"] = f"Bearer {api_key}"

if stream:
headers['Accept'] = 'text/event-stream'
headers["Accept"] = "text/event-stream"

endpoint_url = credentials.get('endpoint_url')
if endpoint_url and not endpoint_url.endswith('/'):
endpoint_url += '/'
server_url = credentials.get('server_url')
endpoint_url = credentials.get("endpoint_url")
if endpoint_url and not endpoint_url.endswith("/"):
endpoint_url += "/"
server_url = credentials.get("server_url")

data = {
"model": model,
"stream": stream,
**model_parameters
}
data = {"model": model, "stream": stream, **model_parameters}

completion_type = LLMMode.value_of(credentials['mode'])
completion_type = LLMMode.value_of(credentials["mode"])

if completion_type is LLMMode.CHAT:
if 'endpoint_url' in credentials:
endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions')
elif 'server_url' in credentials:
if "endpoint_url" in credentials:
endpoint_url = str(URL(endpoint_url) / "chat" / "completions")
elif "server_url" in credentials:
endpoint_url = server_url
data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
elif completion_type is LLMMode.COMPLETION:
data['prompt'] = 'ping'
if 'endpoint_url' in credentials:
endpoint_url = str(URL(endpoint_url) / 'completions')
elif 'server_url' in credentials:
data["prompt"] = "ping"
if "endpoint_url" in credentials:
endpoint_url = str(URL(endpoint_url) / "completions")
elif "server_url" in credentials:
endpoint_url = server_url
else:
raise ValueError("Unsupported completion type for model configuration.")

# annotate tools with names, descriptions, etc.
function_calling_type = credentials.get('function_calling_type', 'no_call')
function_calling_type = credentials.get("function_calling_type", "no_call")
formatted_tools = []
if tools:
if function_calling_type == 'function_call':
data['functions'] = [{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
} for tool in tools]
elif function_calling_type == 'tool_call':
if function_calling_type == "function_call":
data["functions"] = [
{"name": tool.name, "description": tool.description, "parameters": tool.parameters}
for tool in tools
]
elif function_calling_type == "tool_call":
data["tool_choice"] = "auto"

for tool in tools:
@ -240,16 +232,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
if user:
data["user"] = user

response = requests.post(
endpoint_url,
headers=headers,
json=data,
timeout=(10, 300),
stream=stream
)
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream)

if response.encoding is None or response.encoding == 'ISO-8859-1':
response.encoding = 'utf-8'
if response.encoding is None or response.encoding == "ISO-8859-1":
response.encoding = "utf-8"

if not response.ok:
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)


class MistralAIProvider(ModelProvider):

def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -19,12 +18,9 @@ class MistralAIProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)

model_instance.validate_credentials(
model='mistralai/mixtral-8x7b-instruct-v0.1',
credentials=credentials
)
model_instance.validate_credentials(model="mistralai/mixtral-8x7b-instruct-v0.1", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex
@ -22,11 +22,18 @@ class NvidiaRerankModel(RerankModel):
"""

def _sigmoid(self, logit: float) -> float:
return 1/(1+exp(-logit))
return 1 / (1 + exp(-logit))

def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) -> RerankResult:
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model

@ -60,9 +67,9 @@ class NvidiaRerankModel(RerankModel):
results = response.json()

rerank_documents = []
for result in results['rankings']:
index = result['index']
logit = result['logit']
for result in results["rankings"]:
index = result["index"]
logit = result["logit"]
rerank_document = RerankDocument(
index=index,
text=docs[index],
@ -110,5 +117,5 @@ class NvidiaRerankModel(RerankModel):
InvokeServerUnavailableError: [requests.HTTPError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [requests.HTTPError],
InvokeBadRequestError: [requests.RequestException]
InvokeBadRequestError: [requests.RequestException],
}
@ -22,12 +22,13 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Nvidia text embedding model.
"""
api_base: str = 'https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings'
models: list[str] = ['NV-Embed-QA']

def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
api_base: str = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
models: list[str] = ["NV-Embed-QA"]

def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model

@ -37,32 +38,25 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
:param user: unique user id
:return: embeddings result
"""
api_key = credentials['api_key']
api_key = credentials["api_key"]
if model not in self.models:
raise InvokeBadRequestError('Invalid model name')
raise InvokeBadRequestError("Invalid model name")
if not api_key:
raise CredentialsValidateFailedError('api_key is required')
raise CredentialsValidateFailedError("api_key is required")
url = self.api_base
headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
}
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}

data = {
'model': model,
'input': texts[0],
'input_type': 'query'
}
data = {"model": model, "input": texts[0], "input_type": "query"}

try:
response = post(url, headers=headers, data=dumps(data))
except Exception as e:
raise InvokeConnectionError(str(e))

if response.status_code != 200:
try:
resp = response.json()
msg = resp['detail']
msg = resp["detail"]
if response.status_code == 401:
raise InvokeAuthorizationError(msg)
elif response.status_code == 429:
@ -72,23 +66,21 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
else:
raise InvokeError(msg)
except JSONDecodeError as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}"
)

try:
resp = response.json()
embeddings = resp['data']
usage = resp['usage']
embeddings = resp["data"]
usage = resp["usage"]
except Exception as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")

usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens'])
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])

result = TextEmbeddingResult(
model=model,
embeddings=[[
float(data) for data in x['embedding']
] for x in embeddings],
usage=usage
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
)

return result
@ -117,30 +109,20 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
self._invoke(model=model, credentials=credentials, texts=["ping"])
except InvokeAuthorizationError:
raise CredentialsValidateFailedError('Invalid api key')
raise CredentialsValidateFailedError("Invalid api key")

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError],
}

def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
@ -152,10 +134,7 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)

# transform usage
@ -166,7 +145,7 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)

return usage

@ -9,4 +9,5 @@ class NVIDIANIMProvider(OAIAPICompatLargeLanguageModel):
"""
Model class for NVIDIA NIM large language model.
"""

pass

@ -6,6 +6,5 @@ logger = logging.getLogger(__name__)


class NVIDIANIMProvider(ModelProvider):

def validate_provider_credentials(self, credentials: dict) -> None:
pass
@ -33,31 +33,29 @@ logger = logging.getLogger(__name__)

request_template = {
"compartmentId": "",
"servingMode": {
"modelId": "cohere.command-r-plus",
"servingType": "ON_DEMAND"
},
"servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"},
"chatRequest": {
"apiFormat": "COHERE",
#"preambleOverride": "You are a helpful assistant.",
#"message": "Hello!",
#"chatHistory": [],
# "preambleOverride": "You are a helpful assistant.",
# "message": "Hello!",
# "chatHistory": [],
"maxTokens": 600,
"isStream": False,
"frequencyPenalty": 0,
"presencePenalty": 0,
"temperature": 1,
"topP": 0.75
}
"topP": 0.75,
},
}
oci_config_template = {
"user": "",
"fingerprint": "",
"tenancy": "",
"region": "",
"compartment_id": "",
"key_content": ""
}
"user": "",
"fingerprint": "",
"tenancy": "",
"region": "",
"compartment_id": "",
"key_content": "",
}


class OCILargeLanguageModel(LargeLanguageModel):
# https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm
@ -100,11 +98,17 @@ class OCILargeLanguageModel(LargeLanguageModel):
return False
return feature["system"]

def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model

@ -118,22 +122,27 @@ class OCILargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
#print("model"+"*"*20)
#print(model)
#print("credentials"+"*"*20)
#print(credentials)
#print("model_parameters"+"*"*20)
#print(model_parameters)
#print("prompt_messages"+"*"*200)
#print(prompt_messages)
#print("tools"+"*"*20)
#print(tools)
# print("model"+"*"*20)
# print(model)
# print("credentials"+"*"*20)
# print(credentials)
# print("model_parameters"+"*"*20)
# print(model_parameters)
# print("prompt_messages"+"*"*200)
# print(prompt_messages)
# print("tools"+"*"*20)
# print(tools)

# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages

@ -147,8 +156,13 @@ class OCILargeLanguageModel(LargeLanguageModel):

return self._get_num_tokens_by_gpt2(prompt)

def get_num_characters(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_characters(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages

@ -169,10 +183,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
"""
messages = messages.copy() # don't mutate the original list

text = "".join(
self._convert_one_message_to_text(message)
for message in messages
)
text = "".join(self._convert_one_message_to_text(message) for message in messages)

return text.rstrip()

@ -192,11 +203,17 @@ class OCILargeLanguageModel(LargeLanguageModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None
) -> Union[LLMResult, Generator]:
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -218,10 +235,12 @@ class OCILargeLanguageModel(LargeLanguageModel):
# ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat
oci_config = copy.deepcopy(oci_config_template)
if "oci_config_content" in credentials:
oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8")
config_items = oci_config_content.split("/")
if len(config_items) != 5:
raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
raise CredentialsValidateFailedError(
"oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))"
)
oci_config["user"] = config_items[0]
oci_config["fingerprint"] = config_items[1]
oci_config["tenancy"] = config_items[2]
@ -230,12 +249,12 @@ class OCILargeLanguageModel(LargeLanguageModel):
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
if "oci_key_content" in credentials:
oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8")
oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")

#oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
# oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
compartment_id = oci_config["compartment_id"]
client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
# call embedding model
@ -245,9 +264,9 @@ class OCILargeLanguageModel(LargeLanguageModel):

chat_history = []
system_prompts = []
#if "meta.llama" in model:
# if "meta.llama" in model:
# request_args["chatRequest"]["apiFormat"] = "GENERIC"
request_args["chatRequest"]["maxTokens"] = model_parameters.pop('maxTokens', 600)
request_args["chatRequest"]["maxTokens"] = model_parameters.pop("maxTokens", 600)
request_args["chatRequest"].update(model_parameters)
frequency_penalty = model_parameters.get("frequencyPenalty", 0)
presence_penalty = model_parameters.get("presencePenalty", 0)
@ -267,7 +286,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
if not valid_value:
raise InvokeBadRequestError("Does not support function calling")
if model.startswith("cohere"):
#print("run cohere " * 10)
# print("run cohere " * 10)
for message in prompt_messages[:-1]:
text = ""
if isinstance(message.content, str):
@ -279,37 +298,37 @@ class OCILargeLanguageModel(LargeLanguageModel):
if isinstance(message, SystemPromptMessage):
if isinstance(message.content, str):
system_prompts.append(message.content)
args = {"apiFormat": "COHERE",
"preambleOverride": ' '.join(system_prompts),
"message": prompt_messages[-1].content,
"chatHistory": chat_history, }
args = {
"apiFormat": "COHERE",
"preambleOverride": " ".join(system_prompts),
"message": prompt_messages[-1].content,
"chatHistory": chat_history,
}
request_args["chatRequest"].update(args)
elif model.startswith("meta"):
#print("run meta " * 10)
# print("run meta " * 10)
meta_messages = []
for message in prompt_messages:
text = message.content
meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]})
args = {"apiFormat": "GENERIC",
"messages": meta_messages,
"numGenerations": 1,
"topK": -1}
args = {"apiFormat": "GENERIC", "messages": meta_messages, "numGenerations": 1, "topK": -1}
request_args["chatRequest"].update(args)

if stream:
request_args["chatRequest"]["isStream"] = True
#print("final request" + "|" * 20)
#print(request_args)
# print("final request" + "|" * 20)
# print(request_args)
response = client.chat(request_args)
#print(vars(response))
# print(vars(response))

if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

return self._handle_generate_response(model, credentials, response, prompt_messages)

def _handle_generate_response(self, model: str, credentials: dict, response: BaseChatResponse,
prompt_messages: list[PromptMessage]) -> LLMResult:
def _handle_generate_response(
self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response

@ -320,9 +339,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
:return: llm response
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.data.chat_response.text
)
assistant_prompt_message = AssistantPromptMessage(content=response.data.chat_response.text)

# calculate num tokens
prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
@ -341,8 +358,9 @@ class OCILargeLanguageModel(LargeLanguageModel):

return result

def _handle_generate_stream_response(self, model: str, credentials: dict, response: BaseChatResponse,
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response

@ -356,14 +374,12 @@ class OCILargeLanguageModel(LargeLanguageModel):
events = response.data.events()
for stream in events:
chunk = json.loads(stream.data)
#print(chunk)
#chunk: {'apiFormat': 'COHERE', 'text': 'Hello'}
# print(chunk)
# chunk: {'apiFormat': 'COHERE', 'text': 'Hello'}

#for chunk in response:
#for part in chunk.parts:
#if part.function_call:
# for chunk in response:
# for part in chunk.parts:
# if part.function_call:
# assistant_prompt_message.tool_calls = [
# AssistantPromptMessage.ToolCall(
# id=part.function_call.name,
@ -376,9 +392,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
# ]

if "finishReason" not in chunk:
assistant_prompt_message = AssistantPromptMessage(
content=''
)
assistant_prompt_message = AssistantPromptMessage(content="")
if model.startswith("cohere"):
if chunk["text"]:
assistant_prompt_message.content += chunk["text"]
@ -389,10 +403,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
)
else:
# calculate num tokens
@ -409,8 +420,8 @@ class OCILargeLanguageModel(LargeLanguageModel):
index=index,
message=assistant_prompt_message,
finish_reason=str(chunk["finishReason"]),
usage=usage
)
usage=usage,
),
)

def _convert_one_message_to_text(self, message: PromptMessage) -> str:
@ -425,9 +436,7 @@ class OCILargeLanguageModel(LargeLanguageModel):

content = message.content
if isinstance(content, list):
content = "".join(
c.data for c in content if c.type != PromptMessageContentType.IMAGE
)
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)

if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
@ -457,5 +466,5 @@ class OCILargeLanguageModel(LargeLanguageModel):
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: []
InvokeBadRequestError: [],
}
@ -8,7 +8,6 @@ logger = logging.getLogger(__name__)


class OCIGENAIProvider(ModelProvider):

def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
@ -21,14 +20,9 @@ class OCIGENAIProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)

# Use `cohere.command-r-plus` model for validate,
model_instance.validate_credentials(
model='cohere.command-r-plus',
credentials=credentials
)
model_instance.validate_credentials(model="cohere.command-r-plus", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex
@ -21,29 +21,28 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE

request_template = {
"compartmentId": "",
"servingMode": {
"modelId": "cohere.embed-english-light-v3.0",
"servingType": "ON_DEMAND"
},
"servingMode": {"modelId": "cohere.embed-english-light-v3.0", "servingType": "ON_DEMAND"},
"truncate": "NONE",
"inputs": [""]
"inputs": [""],
}
oci_config_template = {
"user": "",
"fingerprint": "",
"tenancy": "",
"region": "",
"compartment_id": "",
"key_content": ""
}
"user": "",
"fingerprint": "",
"tenancy": "",
"region": "",
"compartment_id": "",
"key_content": "",
}


class OCITextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Cohere text embedding model.
"""

def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model

@ -62,14 +61,13 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
used_tokens = 0

for i, text in enumerate(texts):

# Here token count is only an approximation based on the GPT2 tokenizer
num_tokens = self._get_num_tokens_by_gpt2(text)

if num_tokens >= context_size:
cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
# if num tokens is larger than context length, only use the start
inputs.append(text[0: cutoff])
inputs.append(text[0:cutoff])
else:
inputs.append(text)
indices += [i]
@ -80,26 +78,16 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
for i in _iter:
# call embedding model
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
credentials=credentials,
texts=inputs[i: i + max_chunks]
model=model, credentials=credentials, texts=inputs[i : i + max_chunks]
)

used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch

# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=used_tokens
)
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

return TextEmbeddingResult(
embeddings=batched_embeddings,
usage=usage,
model=model
)
return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)

def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
@ -125,6 +113,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
for text in texts:
characters += len(text)
return characters

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
@ -135,11 +124,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
"""
try:
# call embedding model
self._embedding_invoke(
model=model,
credentials=credentials,
texts=['ping']
)
self._embedding_invoke(model=model, credentials=credentials, texts=["ping"])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -157,10 +142,12 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
# initialize client
oci_config = copy.deepcopy(oci_config_template)
if "oci_config_content" in credentials:
oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8")
config_items = oci_config_content.split("/")
if len(config_items) != 5:
raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
raise CredentialsValidateFailedError(
"oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))"
)
oci_config["user"] = config_items[0]
oci_config["fingerprint"] = config_items[1]
oci_config["tenancy"] = config_items[2]
@ -169,7 +156,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
if "oci_key_content" in credentials:
oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8")
oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
@ -195,10 +182,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)

# transform usage
@ -209,7 +193,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
latency=time.perf_counter() - self.started_at,
)

return usage
@ -224,19 +208,9 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError],
}
@ -121,9 +121,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
text = ""
for message_content in first_prompt_message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(
TextPromptMessageContent, message_content
)
message_content = cast(TextPromptMessageContent, message_content)
text = message_content.data
break
return self._get_num_tokens_by_gpt2(text)
@ -145,13 +143,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
stream=False,
)
except InvokeError as ex:
raise CredentialsValidateFailedError(
f"An error occurred during credentials validation: {ex.description}"
)
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}")
except Exception as ex:
raise CredentialsValidateFailedError(
f"An error occurred during credentials validation: {str(ex)}"
)
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}")

def _generate(
self,
@ -201,9 +195,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, "api/chat")
data["messages"] = [
self._convert_prompt_message_to_dict(m) for m in prompt_messages
]
data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
else:
endpoint_url = urljoin(endpoint_url, "api/generate")
first_prompt_message = prompt_messages[0]
@ -216,14 +208,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
images = []
for message_content in first_prompt_message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(
TextPromptMessageContent, message_content
)
message_content = cast(TextPromptMessageContent, message_content)
text = message_content.data
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(
ImagePromptMessageContent, message_content
)
message_content = cast(ImagePromptMessageContent, message_content)
image_data = re.sub(
r"^data:image\/[a-zA-Z]+;base64,",
"",
@ -235,24 +223,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
data["images"] = images

# send a post request to validate the credentials
response = requests.post(
endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream
)
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream)

response.encoding = "utf-8"
if response.status_code != 200:
raise InvokeError(
f"API request failed with status code {response.status_code}: {response.text}"
)
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")

if stream:
return self._handle_generate_stream_response(
model, credentials, completion_type, response, prompt_messages
)
return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)

return self._handle_generate_response(
model, credentials, completion_type, response, prompt_messages
)
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)

def _handle_generate_response(
self,
@ -292,9 +272,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)

# transform usage
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

# transform response
result = LLMResult(
@ -335,9 +313,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
completion_tokens = self._get_num_tokens_by_gpt2(full_text)

# transform usage
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

return LLMResultChunk(
model=model,
@ -394,15 +370,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
completion_tokens = chunk_json["eval_count"]
else:
# calculate num tokens
prompt_tokens = self._get_num_tokens_by_gpt2(
prompt_messages[0].content
)
prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
completion_tokens = self._get_num_tokens_by_gpt2(full_text)

# transform usage
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

yield LLMResultChunk(
model=chunk_json["model"],
@ -439,17 +411,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
images = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(
TextPromptMessageContent, message_content
)
message_content = cast(TextPromptMessageContent, message_content)
text = message_content.data
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(
ImagePromptMessageContent, message_content
)
image_data = re.sub(
r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data
)
message_content = cast(ImagePromptMessageContent, message_content)
image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data)
images.append(image_data)

message_dict = {"role": "user", "content": text, "images": images}
@ -479,9 +445,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

return num_tokens

def get_customizable_model_schema(
self, model: str, credentials: dict
) -> AIModelEntity:
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
Get customizable model schema.

@ -502,9 +466,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: credentials.get("mode"),
ModelPropertyKey.CONTEXT_SIZE: int(
credentials.get("context_size", 4096)
),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
},
parameter_rules=[
ParameterRule(
@ -568,9 +530,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
en_US="Maximum number of tokens to predict when generating text. "
"(Default: 128, -1 = infinite generation, -2 = fill context)"
),
default=(
512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128
),
default=(512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128),
min=-2,
max=int(credentials.get("max_tokens", 4096)),
),
@ -612,22 +572,23 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
label=I18nObject(en_US="Size of context window"),
type=ParameterType.INT,
help=I18nObject(
en_US="Sets the size of the context window used to generate the next token. "
"(Default: 2048)"
en_US="Sets the size of the context window used to generate the next token. " "(Default: 2048)"
),
default=2048,
min=1,
),
ParameterRule(
name='num_gpu',
name="num_gpu",
label=I18nObject(en_US="GPU Layers"),
type=ParameterType.INT,
help=I18nObject(en_US="The number of layers to offload to the GPU(s). "
"On macOS it defaults to 1 to enable metal support, 0 to disable."
"As long as a model fits into one gpu it stays in one. "
"It does not set the number of GPU(s). "),
help=I18nObject(
en_US="The number of layers to offload to the GPU(s). "
"On macOS it defaults to 1 to enable metal support, 0 to disable."
"As long as a model fits into one gpu it stays in one. "
"It does not set the number of GPU(s). "
),
min=-1,
default=1
default=1,
),
ParameterRule(
name="num_thread",
@ -688,8 +649,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
label=I18nObject(en_US="Format"),
type=ParameterType.STRING,
help=I18nObject(
en_US="the format to return a response in."
" Currently the only accepted value is json."
en_US="the format to return a response in." " Currently the only accepted value is json."
),
options=["json"],
),