feat: server multi models support (#799)

This commit is contained in:
takatost
2023-08-12 00:57:00 +08:00
committed by GitHub
parent d8b712b325
commit 5fa2161b05
213 changed files with 10556 additions and 2579 deletions

View File

@ -0,0 +1,58 @@
from typing import Optional
class LLMError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None:
self.description = description
class LLMBadRequestError(LLMError):
"""Raised when the LLM returns bad request."""
description = "Bad Request"
class LLMAPIConnectionError(LLMError):
"""Raised when the LLM returns API connection error."""
description = "API Connection Error"
class LLMAPIUnavailableError(LLMError):
"""Raised when the LLM returns API unavailable error."""
description = "API Unavailable Error"
class LLMRateLimitError(LLMError):
"""Raised when the LLM returns rate limit error."""
description = "Rate Limit Error"
class LLMAuthorizationError(LLMError):
"""Raised when the LLM returns authorization error."""
description = "Authorization Error"
class ProviderTokenNotInitError(Exception):
"""
Custom exception raised when the provider token is not initialized.
"""
description = "Provider Token Not Init"
def __init__(self, *args, **kwargs):
self.description = args[0] if args else self.description
class QuotaExceededError(Exception):
"""
Custom exception raised when the quota for a provider has been exceeded.
"""
description = "Quota Exceeded"
class ModelCurrentlyNotSupportError(Exception):
"""
Custom exception raised when the model not support
"""
description = "Model Currently Not Support"

View File

@ -0,0 +1,293 @@
from typing import Optional
from langchain.callbacks.base import Callbacks
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel
class ModelFactory:
@classmethod
def get_text_generation_model_from_model_config(cls, tenant_id: str,
model_config: dict,
streaming: bool = False,
callbacks: Callbacks = None) -> Optional[BaseLLM]:
provider_name = model_config.get("provider")
model_name = model_config.get("name")
completion_params = model_config.get("completion_params", {})
return cls.get_text_generation_model(
tenant_id=tenant_id,
model_provider_name=provider_name,
model_name=model_name,
model_kwargs=ModelKwargs(
temperature=completion_params.get('temperature', 0),
max_tokens=completion_params.get('max_tokens', 256),
top_p=completion_params.get('top_p', 0),
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1)
),
streaming=streaming,
callbacks=callbacks
)
@classmethod
def get_text_generation_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None,
model_kwargs: Optional[ModelKwargs] = None,
streaming: bool = False,
callbacks: Callbacks = None) -> Optional[BaseLLM]:
"""
get text generation model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:param model_kwargs:
:param streaming:
:param callbacks:
:return:
"""
is_default_model = False
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default System Reasoning Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
is_default_model = True
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init text generation model
model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
try:
model_instance = model_class(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs,
streaming=streaming,
callbacks=callbacks
)
except LLMBadRequestError as e:
if is_default_model:
raise LLMBadRequestError(f"Default model {model_name} is not available. "
f"Please check your model provider credentials.")
else:
raise e
if is_default_model:
model_instance.deduct_quota = False
return model_instance
@classmethod
def get_embedding_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
"""
get embedding model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default Embedding Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init embedding model
model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_speech2text_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
"""
get speech to text model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default Speech-to-Text Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init speech to text model
model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_moderation_model(cls,
tenant_id: str,
model_provider_name: str,
model_name: str) -> Optional[BaseProviderModel]:
"""
get moderation model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init moderation model
model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
"""
get default model of model type.
:param tenant_id:
:param model_type:
:return:
"""
# get default model
default_model = db.session.query(TenantDefaultModel) \
.filter(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.value
).first()
if not default_model:
model_provider_rules = ModelProviderFactory.get_provider_rules()
for model_provider_name, model_provider_rule in model_provider_rules.items():
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
continue
model_list = model_provider.get_supported_model_list(model_type)
if model_list:
model_info = model_list[0]
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.value,
provider_name=model_provider_name,
model_name=model_info['id']
)
db.session.add(default_model)
db.session.commit()
break
return default_model
@classmethod
def update_default_model(cls,
tenant_id: str,
model_type: ModelType,
provider_name: str,
model_name: str) -> TenantDefaultModel:
"""
update default model of model type.
:param tenant_id:
:param model_type:
:param provider_name:
:param model_name:
:return:
"""
model_provider_name = ModelProviderFactory.get_provider_names()
if provider_name not in model_provider_name:
raise ValueError(f'Invalid provider name: {provider_name}')
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
model_list = model_provider.get_supported_model_list(model_type)
model_ids = [model['id'] for model in model_list]
if model_name not in model_ids:
raise ValueError(f'Invalid model name: {model_name}')
# get default model
default_model = db.session.query(TenantDefaultModel) \
.filter(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.value
).first()
if default_model:
# update default model
default_model.provider_name = provider_name
default_model.model_name = model_name
db.session.commit()
else:
# create default model
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.value,
provider_name=provider_name,
model_name=model_name,
)
db.session.add(default_model)
db.session.commit()
return default_model

View File

@ -0,0 +1,228 @@
from typing import Type
from sqlalchemy.exc import IntegrityError
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.rules import provider_rules
from extensions.ext_database import db
from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType
DEFAULT_MODELS = {
ModelType.TEXT_GENERATION.value: {
'provider_name': 'openai',
'model_name': 'gpt-3.5-turbo',
},
ModelType.EMBEDDINGS.value: {
'provider_name': 'openai',
'model_name': 'text-embedding-ada-002',
},
ModelType.SPEECH_TO_TEXT.value: {
'provider_name': 'openai',
'model_name': 'whisper-1',
}
}
class ModelProviderFactory:
@classmethod
def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
if provider_name == 'openai':
from core.model_providers.providers.openai_provider import OpenAIProvider
return OpenAIProvider
elif provider_name == 'anthropic':
from core.model_providers.providers.anthropic_provider import AnthropicProvider
return AnthropicProvider
elif provider_name == 'minimax':
from core.model_providers.providers.minimax_provider import MinimaxProvider
return MinimaxProvider
elif provider_name == 'spark':
from core.model_providers.providers.spark_provider import SparkProvider
return SparkProvider
elif provider_name == 'tongyi':
from core.model_providers.providers.tongyi_provider import TongyiProvider
return TongyiProvider
elif provider_name == 'wenxin':
from core.model_providers.providers.wenxin_provider import WenxinProvider
return WenxinProvider
elif provider_name == 'chatglm':
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
return ChatGLMProvider
elif provider_name == 'azure_openai':
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
return AzureOpenAIProvider
elif provider_name == 'replicate':
from core.model_providers.providers.replicate_provider import ReplicateProvider
return ReplicateProvider
elif provider_name == 'huggingface_hub':
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
return HuggingfaceHubProvider
else:
raise NotImplementedError
@classmethod
def get_provider_names(cls):
"""
Returns a list of provider names.
"""
return list(provider_rules.keys())
@classmethod
def get_provider_rules(cls):
"""
Returns a list of provider rules.
:return:
"""
return provider_rules
@classmethod
def get_provider_rule(cls, provider_name: str):
"""
Returns provider rule.
"""
return provider_rules[provider_name]
@classmethod
def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
"""
get preferred model provider.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:return:
"""
# get preferred provider
preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
if not preferred_provider or not preferred_provider.is_valid:
return None
# init model provider
model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
return model_provider_class(provider=preferred_provider)
@classmethod
def get_preferred_type_by_preferred_model_provider(cls,
tenant_id: str,
model_provider_name: str,
preferred_model_provider: TenantPreferredModelProvider):
"""
get preferred provider type by preferred model provider.
:param model_provider_name:
:param preferred_model_provider:
:return:
"""
if not preferred_model_provider:
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
support_provider_types = model_provider_rules['support_provider_types']
if ProviderType.CUSTOM.value in support_provider_types:
custom_provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.is_valid == True
).first()
if custom_provider:
return ProviderType.CUSTOM.value
model_provider = cls.get_model_provider_class(model_provider_name)
if ProviderType.SYSTEM.value in support_provider_types \
and model_provider.is_provider_type_system_supported():
return ProviderType.SYSTEM.value
elif ProviderType.CUSTOM.value in support_provider_types:
return ProviderType.CUSTOM.value
else:
return preferred_model_provider.preferred_provider_type
@classmethod
def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
"""
get preferred provider of tenant.
:param tenant_id:
:param model_provider_name:
:return:
"""
# get preferred provider type
preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)
# get providers by preferred provider type
providers = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == preferred_provider_type
).all()
no_system_provider = False
if preferred_provider_type == ProviderType.SYSTEM.value:
quota_type_to_provider_dict = {}
for provider in providers:
quota_type_to_provider_dict[provider.quota_type] = provider
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
for quota_type_enum in ProviderQuotaType:
quota_type = quota_type_enum.value
if quota_type in model_provider_rules['system_config']['supported_quota_types'] \
and quota_type in quota_type_to_provider_dict.keys():
provider = quota_type_to_provider_dict[quota_type]
if provider.is_valid and provider.quota_limit > provider.quota_used:
return provider
no_system_provider = True
if no_system_provider:
providers = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value
).all()
if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
if providers:
return providers[0]
else:
try:
provider = Provider(
tenant_id=tenant_id,
provider_name=model_provider_name,
provider_type=ProviderType.CUSTOM.value,
is_valid=False
)
db.session.add(provider)
db.session.commit()
except IntegrityError:
db.session.rollback()
provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
return provider
return None
@classmethod
def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
"""
get preferred provider type of tenant.
:param tenant_id:
:param model_provider_name:
:return:
"""
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
.filter(
TenantPreferredModelProvider.tenant_id == tenant_id,
TenantPreferredModelProvider.provider_name == model_provider_name
).first()
return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)

View File

@ -0,0 +1,22 @@
from abc import ABC
from typing import Any
from core.model_providers.providers.base import BaseModelProvider
class BaseProviderModel(ABC):
_client: Any
_model_provider: BaseModelProvider
def __init__(self, model_provider: BaseModelProvider, client: Any):
self._model_provider = model_provider
self._client = client
@property
def client(self):
return self._client
@property
def model_provider(self):
return self._model_provider

View File

@ -0,0 +1,78 @@
import decimal
import logging
import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \
LLMAPIUnavailableError, LLMAPIConnectionError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureOpenAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
self.credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = OpenAIEmbeddings(
deployment=name,
openai_api_type='azure',
openai_api_version=AZURE_OPENAI_API_VERSION,
chunk_size=16,
max_retries=1,
**self.credentials
)
super().__init__(model_provider, client, name)
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
:param text:
:return:
"""
if len(text) == 0:
return 0
enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name'))
tokenized_text = enc.encode(text)
# calculate the number of tokens in the encoded text
return len(tokenized_text)
def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * decimal.Decimal('0.0001')
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to Azure OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to Azure OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("Azure OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError('Azure ' + str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError('Azure ' + str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
else:
return ex

View File

@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Any
import tiktoken
from langchain.schema.language_model import _get_token_ids_default_method
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
class BaseEmbedding(BaseProviderModel):
name: str
type: ModelType = ModelType.EMBEDDINGS
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
super().__init__(model_provider, client)
self.name = name
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
:param text:
:return:
"""
if len(text) == 0:
return 0
return len(_get_token_ids_default_method(text))
def get_token_price(self, tokens: int):
return 0
def get_currency(self):
return 'USD'
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
raise NotImplementedError

View File

@ -0,0 +1,35 @@
import decimal
import logging
from langchain.embeddings import MiniMaxEmbeddings
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
class MinimaxEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = MiniMaxEmbeddings(
model=name,
**credentials
)
super().__init__(model_provider, client, name)
def get_token_price(self, tokens: int):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"Minimax: {str(ex)}")
else:
return ex

View File

@ -0,0 +1,72 @@
import decimal
import logging
import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
class OpenAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = OpenAIEmbeddings(
max_retries=1,
**credentials
)
super().__init__(model_provider, client, name)
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
:param text:
:return:
"""
if len(text) == 0:
return 0
enc = tiktoken.encoding_for_model(self.name)
tokenized_text = enc.encode(text)
# calculate the number of tokens in the encoded text
return len(tokenized_text)
def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * decimal.Decimal('0.0001')
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex

View File

@ -0,0 +1,36 @@
import decimal
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings
from core.model_providers.models.embedding.base import BaseEmbedding
class ReplicateEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = ReplicateEmbeddings(
model=name + ':' + credentials.get('model_version'),
replicate_api_token=credentials.get('replicate_api_token')
)
super().__init__(model_provider, client, name)
def get_token_price(self, tokens: int):
# replicate only pay for prediction seconds
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Replicate: {str(ex)}")
else:
return ex

View File

@ -0,0 +1,53 @@
import enum
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from pydantic import BaseModel
class LLMRunResult(BaseModel):
content: str
prompt_tokens: int
completion_tokens: int
class MessageType(enum.Enum):
HUMAN = 'human'
ASSISTANT = 'assistant'
SYSTEM = 'system'
class PromptMessage(BaseModel):
type: MessageType = MessageType.HUMAN
content: str = ''
def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
lc_messages.append(AIMessage(content=message.content))
elif message.type == MessageType.SYSTEM:
lc_messages.append(SystemMessage(content=message.content))
return lc_messages
def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
elif isinstance(message, AIMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
return prompt_messages
def str_to_prompt_messages(texts: list[str]):
prompt_messages = []
for text in texts:
prompt_messages.append(PromptMessage(content=text))
return prompt_messages

View File

@ -0,0 +1,59 @@
import enum
from typing import Optional, TypeVar, Generic
from langchain.load.serializable import Serializable
from pydantic import BaseModel
class ModelMode(enum.Enum):
COMPLETION = 'completion'
CHAT = 'chat'
class ModelType(enum.Enum):
TEXT_GENERATION = 'text-generation'
EMBEDDINGS = 'embeddings'
SPEECH_TO_TEXT = 'speech2text'
IMAGE = 'image'
VIDEO = 'video'
MODERATION = 'moderation'
@staticmethod
def value_of(value):
for member in ModelType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ModelKwargs(BaseModel):
max_tokens: Optional[int]
temperature: Optional[float]
top_p: Optional[float]
presence_penalty: Optional[float]
frequency_penalty: Optional[float]
class KwargRuleType(enum.Enum):
STRING = 'string'
INTEGER = 'integer'
FLOAT = 'float'
T = TypeVar('T')
class KwargRule(Generic[T], BaseModel):
enabled: bool = True
min: Optional[T] = None
max: Optional[T] = None
default: Optional[T] = None
alias: Optional[str] = None
class ModelKwargsRules(BaseModel):
max_tokens: KwargRule = KwargRule[int](enabled=False)
temperature: KwargRule = KwargRule[float](enabled=False)
top_p: KwargRule = KwargRule[float](enabled=False)
presence_penalty: KwargRule = KwargRule[float](enabled=False)
frequency_penalty: KwargRule = KwargRule[float](enabled=False)

View File

@ -0,0 +1,10 @@
from enum import Enum
class ProviderQuotaUnit(Enum):
TIMES = 'times'
TOKENS = 'tokens'
class ModelFeature(Enum):
AGENT_THOUGHT = 'agent_thought'

View File

@ -0,0 +1,107 @@
import decimal
import logging
from functools import wraps
from typing import List, Optional, Any
import anthropic
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class AnthropicModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatAnthropic(
model=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
default_request_timeout=60,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'claude-instant-1': {
'prompt': decimal.Decimal('1.63'),
'completion': decimal.Decimal('5.51'),
},
'claude-2': {
'prompt': decimal.Decimal('11.02'),
'completion': decimal.Decimal('32.68'),
},
}
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']
tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1m * unit_price
return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, anthropic.APIConnectionError):
logging.warning("Failed to connect to Anthropic API.")
return LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {ex.__cause__}")
elif isinstance(ex, anthropic.RateLimitError):
return LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
elif isinstance(ex, anthropic.AuthenticationError):
return LLMAuthorizationError(f"Anthropic: {ex.message}")
elif isinstance(ex, anthropic.BadRequestError):
return LLMBadRequestError(f"Anthropic: {ex.message}")
elif isinstance(ex, anthropic.APIStatusError):
return LLMAPIUnavailableError(f"Anthropic: code: {ex.status_code}, cause: {ex.message}")
else:
return ex
@classmethod
def support_streaming(cls):
return True

View File

@ -0,0 +1,177 @@
import decimal
import logging
from functools import wraps
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
from core.third_party.langchain.llms.azure_open_ai import EnhanceAzureOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureOpenAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
if name == 'text-davinci-003':
self.model_mode = ModelMode.COMPLETION
else:
self.model_mode = ModelMode.CHAT
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.name == 'text-davinci-003':
client = EnhanceAzureOpenAI(
deployment_name=self.name,
streaming=self.streaming,
request_timeout=60,
openai_api_type='azure',
openai_api_version=AZURE_OPENAI_API_VERSION,
openai_api_key=self.credentials.get('openai_api_key'),
openai_api_base=self.credentials.get('openai_api_base'),
callbacks=self.callbacks,
**provider_model_kwargs
)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p'),
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
}
client = EnhanceAzureChatOpenAI(
deployment_name=self.name,
temperature=provider_model_kwargs.get('temperature'),
max_tokens=provider_model_kwargs.get('max_tokens'),
model_kwargs=extra_model_kwargs,
streaming=self.streaming,
request_timeout=60,
openai_api_type='azure',
openai_api_version=AZURE_OPENAI_API_VERSION,
openai_api_key=self.credentials.get('openai_api_key'),
openai_api_base=self.credentials.get('openai_api_base'),
callbacks=self.callbacks,
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, str):
return self._client.get_num_tokens(prompts)
else:
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'gpt-4': {
'prompt': decimal.Decimal('0.03'),
'completion': decimal.Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': decimal.Decimal('0.06'),
'completion': decimal.Decimal('0.12')
},
'gpt-35-turbo': {
'prompt': decimal.Decimal('0.0015'),
'completion': decimal.Decimal('0.002')
},
'gpt-35-turbo-16k': {
'prompt': decimal.Decimal('0.003'),
'completion': decimal.Decimal('0.004')
},
'text-davinci-003': {
'prompt': decimal.Decimal('0.02'),
'completion': decimal.Decimal('0.02')
},
}
base_model_name = self.credentials.get("base_model_name")
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[base_model_name]['prompt']
else:
unit_price = model_unit_prices[base_model_name]['completion']
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.name == 'text-davinci-003':
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p'),
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
}
self.client.temperature = provider_model_kwargs.get('temperature')
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
self.client.model_kwargs = extra_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to Azure OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to Azure OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("Azure OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError('Azure ' + str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError('Azure ' + str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True

View File

@ -0,0 +1,269 @@
from abc import abstractmethod
from typing import List, Optional, Any, Union
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.fake import FakeLLM
class BaseLLM(BaseProviderModel):
model_mode: ModelMode = ModelMode.COMPLETION
name: str
model_kwargs: ModelKwargs
credentials: dict
streaming: bool = False
type: ModelType = ModelType.TEXT_GENERATION
deduct_quota: bool = True
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
self.name = name
self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
max_tokens=None,
temperature=None,
top_p=None,
presence_penalty=None,
frequency_penalty=None
)
self.credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
self.streaming = streaming
if streaming:
default_callback = DifyStreamingStdOutCallbackHandler()
else:
default_callback = DifyStdOutCallbackHandler()
if not callbacks:
callbacks = [default_callback]
else:
callbacks.append(default_callback)
self.callbacks = callbacks
client = self._init_client()
super().__init__(model_provider, client)
@abstractmethod
def _init_client(self) -> Any:
raise NotImplementedError
def run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMRunResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
if self.deduct_quota:
self.model_provider.check_quota_over_limit()
if not callbacks:
callbacks = self.callbacks
else:
callbacks.extend(self.callbacks)
if 'fake_response' in kwargs and kwargs['fake_response']:
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
fake_llm = FakeLLM(
response=kwargs['fake_response'],
num_token_func=self.get_num_tokens,
streaming=self.streaming,
callbacks=callbacks
)
result = fake_llm.generate([prompts])
else:
try:
result = self._run(
messages=messages,
stop=stop,
callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
**kwargs
)
except Exception as ex:
raise self.handle_exceptions(ex)
if isinstance(result.generations[0][0], ChatGeneration):
completion_content = result.generations[0][0].message.content
else:
completion_content = result.generations[0][0].text
if self.streaming and not self.support_streaming():
# use FakeLLM to simulate streaming when current model not support streaming but streaming is True
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
fake_llm = FakeLLM(
response=completion_content,
num_token_func=self.get_num_tokens,
streaming=self.streaming,
callbacks=callbacks
)
fake_llm.generate([prompts])
if result.llm_output and result.llm_output['token_usage']:
prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
completion_tokens = result.llm_output['token_usage']['completion_tokens']
total_tokens = result.llm_output['token_usage']['total_tokens']
else:
prompt_tokens = self.get_num_tokens(messages)
completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
total_tokens = prompt_tokens + completion_tokens
if self.deduct_quota:
self.model_provider.deduct_quota(total_tokens)
return LLMRunResult(
content=completion_content,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens
)
@abstractmethod
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_token_price(self, tokens: int, message_type: MessageType):
"""
get token price.
:param tokens:
:param message_type:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_currency(self):
"""
get token currency.
:return:
"""
raise NotImplementedError
def get_model_kwargs(self):
return self.model_kwargs
def set_model_kwargs(self, model_kwargs: ModelKwargs):
self.model_kwargs = model_kwargs
self._set_model_kwargs(model_kwargs)
@abstractmethod
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
raise NotImplementedError
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
"""
Handle llm run exceptions.
:param ex:
:return:
"""
raise NotImplementedError
def add_callbacks(self, callbacks: Callbacks):
"""
Add callbacks to client.
:param callbacks:
:return:
"""
if not self.client.callbacks:
self.client.callbacks = callbacks
else:
self.client.callbacks.extend(callbacks)
@classmethod
def support_streaming(cls):
return False
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
if len(messages) == 0:
raise ValueError("prompt must not be empty.")
if not model_mode:
model_mode = self.model_mode
if model_mode == ModelMode.COMPLETION:
return messages[0].content
else:
chat_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
chat_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
chat_messages.append(AIMessage(content=message.content))
elif message.type == MessageType.SYSTEM:
chat_messages.append(SystemMessage(content=message.content))
return chat_messages
def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
"""
convert model kwargs to provider model kwargs.
:param model_rules:
:param model_kwargs:
:return:
"""
model_kwargs_input = {}
for key, value in model_kwargs.dict().items():
rule = getattr(model_rules, key)
if not rule.enabled:
continue
if rule.alias:
key = rule.alias
if rule.default is not None and value is None:
value = rule.default
if rule.min is not None:
value = max(value, rule.min)
if rule.max is not None:
value = min(value, rule.max)
model_kwargs_input[key] = value
return model_kwargs_input

View File

@ -0,0 +1,70 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import ChatGLM
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class ChatGLMModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatGLM(
callbacks=self.callbacks,
endpoint_url=self.credentials.get('api_base'),
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"ChatGLM: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return False

View File

@ -0,0 +1,82 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks
from langchain.llms import HuggingFaceEndpoint
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class HuggingfaceHubModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
client = HuggingFaceEndpoint(
endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
task='text2text-generation',
model_kwargs=provider_model_kwargs,
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
callbacks=self.callbacks,
)
else:
client = HuggingFaceHub(
repo_id=self.name,
task=self.credentials['task_type'],
model_kwargs=provider_model_kwargs,
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
callbacks=self.callbacks,
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.get_num_tokens(prompts)
def get_token_price(self, tokens: int, message_type: MessageType):
# not support calc price
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.model_kwargs = provider_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@ -0,0 +1,70 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class MinimaxModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return Minimax(
model=self.name,
model_kwargs={
'stream': False
},
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"Minimax: {str(ex)}")
else:
return ex

View File

@ -0,0 +1,219 @@
import decimal
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from models.provider import ProviderType, ProviderQuotaType
COMPLETION_MODELS = [
'text-davinci-003', # 4,097 tokens
]
CHAT_MODELS = [
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
]
MODEL_MAX_TOKENS = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
class OpenAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
if name in COMPLETION_MODELS:
self.model_mode = ModelMode.COMPLETION
else:
self.model_mode = ModelMode.CHAT
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.name in COMPLETION_MODELS:
client = EnhanceOpenAI(
model_name=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
**self.credentials,
**provider_model_kwargs
)
else:
# Fine-tuning is currently only available for the following base models:
# davinci, curie, babbage, and ada.
# This means that except for the fixed `completion` model,
# all other fine-tuned models are `completion` models.
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p'),
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
}
client = EnhanceChatOpenAI(
model_name=self.name,
temperature=provider_model_kwargs.get('temperature'),
max_tokens=provider_model_kwargs.get('max_tokens'),
model_kwargs=extra_model_kwargs,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
**self.credentials
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
if self.name == 'gpt-4' \
and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, str):
return self._client.get_num_tokens(prompts)
else:
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'gpt-4': {
'prompt': decimal.Decimal('0.03'),
'completion': decimal.Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': decimal.Decimal('0.06'),
'completion': decimal.Decimal('0.12')
},
'gpt-3.5-turbo': {
'prompt': decimal.Decimal('0.0015'),
'completion': decimal.Decimal('0.002')
},
'gpt-3.5-turbo-16k': {
'prompt': decimal.Decimal('0.003'),
'completion': decimal.Decimal('0.004')
},
'text-davinci-003': {
'prompt': decimal.Decimal('0.02'),
'completion': decimal.Decimal('0.02')
},
}
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.name in COMPLETION_MODELS:
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p'),
'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
'presence_penalty': provider_model_kwargs.get('presence_penalty'),
}
self.client.temperature = provider_model_kwargs.get('temperature')
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
self.client.model_kwargs = extra_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True
# def is_model_valid_or_raise(self):
# """
# check is a valid model.
#
# :return:
# """
# credentials = self._model_provider.get_credentials()
#
# try:
# result = openai.Model.retrieve(
# id=self.name,
# api_key=credentials.get('openai_api_key'),
# request_timeout=60
# )
#
# if 'id' not in result or result['id'] != self.name:
# raise LLMNotExistsError(f"OpenAI Model {self.name} not exists.")
# except openai.error.OpenAIError as e:
# raise LLMNotExistsError(f"OpenAI Model {self.name} not exists, cause: {e.__class__.__name__}:{str(e)}")
# except Exception as e:
# logging.exception("OpenAI Model retrieve failed.")
# raise e

View File

@ -0,0 +1,103 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string
from replicate.exceptions import ReplicateError, ModelError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.error import LLMBadRequestError
from core.third_party.langchain.llms.replicate_llm import EnhanceReplicate
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class ReplicateModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
self.model_mode = ModelMode.CHAT if name.endswith('-chat') else ModelMode.COMPLETION
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return EnhanceReplicate(
model=self.name + ':' + self.credentials.get('model_version'),
input=provider_model_kwargs,
streaming=self.streaming,
replicate_api_token=self.credentials.get('replicate_api_token'),
callbacks=self.callbacks,
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
extra_kwargs = {}
if isinstance(prompts, list):
system_messages = [message for message in messages if message.type == 'system']
if system_messages:
system_message = system_messages[0]
extra_kwargs['system_prompt'] = system_message.content
prompts = [message for message in messages if message.type != 'system']
prompts = get_buffer_string(prompts)
# The maximum length the generated tokens can have.
# Corresponds to the length of the input prompt + max_new_tokens.
if 'max_length' in self._client.input:
self._client.input['max_length'] = min(
self._client.input['max_length'] + self.get_num_tokens(messages),
self.model_rules.max_tokens.max
)
return self._client.generate([prompts], stop, callbacks, **extra_kwargs)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, list):
prompts = get_buffer_string(prompts)
return self._client.get_num_tokens(prompts)
def get_token_price(self, tokens: int, message_type: MessageType):
# replicate only pay for prediction seconds
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.input = provider_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Replicate: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return True

View File

@ -0,0 +1,73 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.spark import ChatSpark
from core.third_party.spark.spark_llm import SparkError
class SparkModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatSpark(
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
contents = [message.content for message in messages]
return max(self._client.get_num_tokens("".join(contents)), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, SparkError):
return LLMBadRequestError(f"Spark: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return True

View File

@ -0,0 +1,77 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from requests import HTTPError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
class TongyiModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
del provider_model_kwargs['max_tokens']
return EnhanceTongyi(
model_name=self.name,
max_retries=1,
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
del provider_model_kwargs['max_tokens']
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ValueError, HTTPError)):
return LLMBadRequestError(f"Tongyi: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return True

View File

@ -0,0 +1,92 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.wenxin import Wenxin
class WenxinModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return Wenxin(
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'ernie-bot': {
'prompt': decimal.Decimal('0.012'),
'completion': decimal.Decimal('0.012'),
},
'ernie-bot-turbo': {
'prompt': decimal.Decimal('0.008'),
'completion': decimal.Decimal('0.008')
},
'bloomz-7b': {
'prompt': decimal.Decimal('0.006'),
'completion': decimal.Decimal('0.006')
}
}
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'RMB'
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@ -0,0 +1,48 @@
import logging
import openai
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
DEFAULT_AUDIO_MODEL = 'whisper-1'
class OpenAIModeration(BaseProviderModel):
type: ModelType = ModelType.MODERATION
def __init__(self, model_provider: BaseModelProvider, name: str):
super().__init__(model_provider, openai.Moderation)
def run(self, text):
credentials = self.model_provider.get_model_credentials(
model_name=DEFAULT_AUDIO_MODEL,
model_type=self.type
)
try:
return self._client.create(input=text, api_key=credentials['openai_api_key'])
except Exception as ex:
raise self.handle_exceptions(ex)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex

View File

@ -0,0 +1,29 @@
from abc import abstractmethod
from typing import Any
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
class BaseSpeech2Text(BaseProviderModel):
name: str
type: ModelType = ModelType.SPEECH_TO_TEXT
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
super().__init__(model_provider, client)
self.name = name
def run(self, file):
try:
return self._run(file)
except Exception as ex:
raise self.handle_exceptions(ex)
@abstractmethod
def _run(self, file):
raise NotImplementedError
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
raise NotImplementedError

View File

@ -0,0 +1,47 @@
import logging
import openai
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from core.model_providers.providers.base import BaseModelProvider
class OpenAIWhisper(BaseSpeech2Text):
def __init__(self, model_provider: BaseModelProvider, name: str):
super().__init__(model_provider, openai.Audio, name)
def _run(self, file):
credentials = self.model_provider.get_model_credentials(
model_name=self.name,
model_type=self.type
)
return self._client.transcribe(
model=self.name,
file=file,
api_key=credentials.get('openai_api_key'),
api_base=credentials.get('openai_api_base'),
organization=credentials.get('openai_organization'),
)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
raise LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex

View File

@ -0,0 +1,224 @@
import json
import logging
from json import JSONDecodeError
from typing import Type, Optional
import anthropic
from flask import current_app
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
from models.provider import ProviderType
class AnthropicProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'anthropic'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
},
{
'id': 'claude-2',
'name': 'claude-2',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = AnthropicModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'anthropic_api_key' not in credentials:
raise CredentialsValidateFailedError('Anthropic API Key must be provided.')
try:
credential_kwargs = {
'anthropic_api_key': credentials['anthropic_api_key']
}
if 'anthropic_api_url' in credentials:
credential_kwargs['anthropic_api_url'] = credentials['anthropic_api_url']
chat_llm = ChatAnthropic(
model='claude-instant-1',
max_tokens_to_sample=10,
temperature=0,
default_request_timeout=60,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except anthropic.APIConnectionError as ex:
raise CredentialsValidateFailedError(str(ex))
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Anthropic config validation failed')
raise ex
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['anthropic_api_key'] = encrypter.encrypt_token(tenant_id, credentials['anthropic_api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'anthropic_api_url': None,
'anthropic_api_key': None
}
if credentials['anthropic_api_key']:
credentials['anthropic_api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['anthropic_api_key']
)
if obfuscated:
credentials['anthropic_api_key'] = encrypter.obfuscated_token(credentials['anthropic_api_key'])
if 'anthropic_api_url' not in credentials:
credentials['anthropic_api_url'] = None
return credentials
else:
if hosted_model_providers.anthropic:
return {
'anthropic_api_url': hosted_model_providers.anthropic.api_base,
'anthropic_api_key': hosted_model_providers.anthropic.api_key,
}
else:
return {
'anthropic_api_url': None,
'anthropic_api_key': None
}
@classmethod
def is_provider_type_system_supported(cls) -> bool:
if current_app.config['EDITION'] != 'CLOUD':
return False
if hosted_model_providers.anthropic:
return True
return False
def should_deduct_quota(self):
if hosted_model_providers.anthropic and \
hosted_model_providers.anthropic.quota_limit and hosted_model_providers.anthropic.quota_limit > 0:
return True
return False
def get_payment_info(self) -> Optional[dict]:
"""
get product info if it payable.
:return:
"""
if hosted_model_providers.anthropic \
and hosted_model_providers.anthropic.paid_enabled:
return {
'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
}
return None
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)

View File

@ -0,0 +1,387 @@
import json
import logging
from json import JSONDecodeError
from typing import Type
import openai
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
AZURE_OPENAI_API_VERSION
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
from extensions.ext_database import db
from models.provider import ProviderType, ProviderModel, ProviderQuotaType
BASE_MODELS = [
'gpt-4',
'gpt-4-32k',
'gpt-35-turbo',
'gpt-35-turbo-16k',
'text-davinci-003',
'text-embedding-ada-002',
]
class AzureOpenAIProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'azure_openai'
def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
# convert old provider config to provider models
self._convert_provider_config_to_model_config()
if self.provider.provider_type == ProviderType.CUSTOM.value:
# get configurable provider models
provider_models = db.session.query(ProviderModel).filter(
ProviderModel.tenant_id == self.provider.tenant_id,
ProviderModel.provider_name == self.provider.provider_name,
ProviderModel.model_type == model_type.value,
ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all()
model_list = []
for provider_model in provider_models:
model_dict = {
'id': provider_model.model_name,
'name': provider_model.model_name
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['base_model_name'] in [
'gpt-4',
'gpt-4-32k',
'gpt-35-turbo',
'gpt-35-turbo-16k',
]:
model_dict['features'] = [
ModelFeature.AGENT_THOUGHT.value
]
model_list.append(model_dict)
else:
model_list = self._get_fixed_model_list(model_type)
return model_list
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
models = [
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4',
'name': 'gpt-4',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
}
]
if self.provider.provider_type == ProviderType.SYSTEM.value \
and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']]
return models
elif model_type == ModelType.EMBEDDINGS:
return [
{
'id': 'text-embedding-ada-002',
'name': 'text-embedding-ada-002'
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = AzureOpenAIModel
elif model_type == ModelType.EMBEDDINGS:
model_class = AzureOpenAIEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
base_model_max_tokens = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-35-turbo': 4096,
'gpt-35-turbo-16k': 16384,
'text-davinci-003': 4097,
}
model_credentials = self.get_model_credentials(model_name, model_type)
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=1),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
model_credentials['base_model_name'],
4097
), default=16),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'openai_api_key' not in credentials:
raise CredentialsValidateFailedError('Azure OpenAI API key is required')
if 'openai_api_base' not in credentials:
raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')
if 'base_model_name' not in credentials:
raise CredentialsValidateFailedError('Base Model Name is required')
if credentials['base_model_name'] not in BASE_MODELS:
raise CredentialsValidateFailedError('Base Model Name is invalid')
if model_type == ModelType.TEXT_GENERATION:
try:
client = EnhanceAzureChatOpenAI(
deployment_name=model_name,
temperature=0,
max_tokens=15,
request_timeout=10,
openai_api_type='azure',
openai_api_version='2023-07-01-preview',
openai_api_key=credentials['openai_api_key'],
openai_api_base=credentials['openai_api_base'],
)
client.generate([[HumanMessage(content='hi!')]])
except openai.error.OpenAIError as e:
raise CredentialsValidateFailedError(
f"Azure OpenAI deployment {model_name} not exists, cause: {e.__class__.__name__}:{str(e)}")
except Exception as e:
logging.exception("Azure OpenAI Model retrieve failed.")
raise e
elif model_type == ModelType.EMBEDDINGS:
try:
client = OpenAIEmbeddings(
openai_api_type='azure',
openai_api_version=AZURE_OPENAI_API_VERSION,
deployment=model_name,
chunk_size=16,
max_retries=1,
openai_api_key=credentials['openai_api_key'],
openai_api_base=credentials['openai_api_base']
)
client.embed_query('hi')
except openai.error.OpenAIError as e:
logging.exception("Azure OpenAI Model check error.")
raise CredentialsValidateFailedError(
f"Azure OpenAI deployment {model_name} not exists, cause: {e.__class__.__name__}:{str(e)}")
except Exception as e:
logging.exception("Azure OpenAI Model retrieve failed.")
raise e
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['openai_api_key'] = encrypter.encrypt_token(tenant_id, credentials['openai_api_key'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type == ProviderType.CUSTOM.value:
# convert old provider config to provider models
self._convert_provider_config_to_model_config()
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'openai_api_base': '',
'openai_api_key': '',
'base_model_name': ''
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['openai_api_key']:
credentials['openai_api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['openai_api_key']
)
if obfuscated:
credentials['openai_api_key'] = encrypter.obfuscated_token(credentials['openai_api_key'])
return credentials
else:
if hosted_model_providers.azure_openai:
return {
'openai_api_base': hosted_model_providers.azure_openai.api_base,
'openai_api_key': hosted_model_providers.azure_openai.api_key,
'base_model_name': model_name
}
else:
return {
'openai_api_base': None,
'openai_api_key': None,
'base_model_name': None
}
@classmethod
def is_provider_type_system_supported(cls) -> bool:
if current_app.config['EDITION'] != 'CLOUD':
return False
if hosted_model_providers.azure_openai:
return True
return False
def should_deduct_quota(self):
if hosted_model_providers.azure_openai \
and hosted_model_providers.azure_openai.quota_limit and hosted_model_providers.azure_openai.quota_limit > 0:
return True
return False
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}
def _convert_provider_config_to_model_config(self):
if self.provider.provider_type == ProviderType.CUSTOM.value \
and self.provider.is_valid \
and self.provider.encrypted_config:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'openai_api_base': '',
'openai_api_key': '',
'base_model_name': ''
}
self._add_provider_model(
model_name='gpt-35-turbo',
model_type=ModelType.TEXT_GENERATION,
provider_credentials=credentials
)
self._add_provider_model(
model_name='gpt-35-turbo-16k',
model_type=ModelType.TEXT_GENERATION,
provider_credentials=credentials
)
self._add_provider_model(
model_name='gpt-4',
model_type=ModelType.TEXT_GENERATION,
provider_credentials=credentials
)
self._add_provider_model(
model_name='text-davinci-003',
model_type=ModelType.TEXT_GENERATION,
provider_credentials=credentials
)
self._add_provider_model(
model_name='text-embedding-ada-002',
model_type=ModelType.EMBEDDINGS,
provider_credentials=credentials
)
self.provider.encrypted_config = None
db.session.commit()
def _add_provider_model(self, model_name: str, model_type: ModelType, provider_credentials: dict):
credentials = provider_credentials.copy()
credentials['base_model_name'] = model_name
provider_model = ProviderModel(
tenant_id=self.provider.tenant_id,
provider_name=self.provider.provider_name,
model_name=model_name,
model_type=model_type.value,
encrypted_config=json.dumps(credentials),
is_valid=True
)
db.session.add(provider_model)
db.session.commit()

View File

@ -0,0 +1,283 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Type, Optional
from flask import current_app
from pydantic import BaseModel
from core.model_providers.error import QuotaExceededError, LLMBadRequestError
from extensions.ext_database import db
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.models.entity.provider import ProviderQuotaUnit
from core.model_providers.rules import provider_rules
from models.provider import Provider, ProviderType, ProviderModel
class BaseModelProvider(BaseModel, ABC):
provider: Provider
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
@abstractmethod
def provider_name(self):
"""
Returns the name of a provider.
"""
raise NotImplementedError
def get_rules(self):
"""
Returns the rules of a provider.
"""
return provider_rules[self.provider_name]
def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
"""
get supported model object list for use.
:param model_type:
:return:
"""
rules = self.get_rules()
if 'custom' not in rules['support_provider_types']:
return self._get_fixed_model_list(model_type)
if 'model_flexibility' not in rules:
return self._get_fixed_model_list(model_type)
if rules['model_flexibility'] == 'fixed':
return self._get_fixed_model_list(model_type)
# get configurable provider models
provider_models = db.session.query(ProviderModel).filter(
ProviderModel.tenant_id == self.provider.tenant_id,
ProviderModel.provider_name == self.provider.provider_name,
ProviderModel.model_type == model_type.value,
ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all()
return [{
'id': provider_model.model_name,
'name': provider_model.model_name
} for provider_model in provider_models]
@abstractmethod
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
"""
get supported model object list for use.
:param model_type:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_model_class(self, model_type: ModelType) -> Type:
"""
get specific model class.
:param model_type:
:return:
"""
raise NotImplementedError
@classmethod
@abstractmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
check provider credentials valid.
:param credentials:
"""
raise NotImplementedError
@classmethod
@abstractmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
"""
encrypt provider credentials for save.
:param tenant_id:
:param credentials:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param obfuscated:
:return:
"""
raise NotImplementedError
@classmethod
@abstractmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
raise NotImplementedError
@classmethod
@abstractmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
raise NotImplementedError
@classmethod
def is_provider_type_system_supported(cls) -> bool:
return current_app.config['EDITION'] == 'CLOUD'
def check_quota_over_limit(self):
"""
check provider quota over limit.
:return:
"""
if self.provider.provider_type != ProviderType.SYSTEM.value:
return
rules = self.get_rules()
if 'system' not in rules['support_provider_types']:
return
provider = db.session.query(Provider).filter(
db.and_(
Provider.id == self.provider.id,
Provider.is_valid == True,
Provider.quota_limit > Provider.quota_used
)
).first()
if not provider:
raise QuotaExceededError()
def deduct_quota(self, used_tokens: int = 0) -> None:
"""
deduct available quota when provider type is system or paid.
:return:
"""
if self.provider.provider_type != ProviderType.SYSTEM.value:
return
rules = self.get_rules()
if 'system' not in rules['support_provider_types']:
return
if not self.should_deduct_quota():
return
if 'system_config' not in rules:
quota_unit = ProviderQuotaUnit.TIMES.value
elif 'quota_unit' not in rules['system_config']:
quota_unit = ProviderQuotaUnit.TIMES.value
else:
quota_unit = rules['system_config']['quota_unit']
if quota_unit == ProviderQuotaUnit.TOKENS.value:
used_quota = used_tokens
else:
used_quota = 1
db.session.query(Provider).filter(
Provider.tenant_id == self.provider.tenant_id,
Provider.provider_name == self.provider.provider_name,
Provider.provider_type == self.provider.provider_type,
Provider.quota_type == self.provider.quota_type,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + used_quota})
db.session.commit()
def should_deduct_quota(self):
return False
def update_last_used(self) -> None:
"""
update last used time.
:return:
"""
db.session.query(Provider).filter(
Provider.tenant_id == self.provider.tenant_id,
Provider.provider_name == self.provider.provider_name
).update({'last_used': datetime.utcnow()})
db.session.commit()
def get_payment_info(self) -> Optional[dict]:
"""
get product info if it payable.
:return:
"""
return None
def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
"""
get provider model.
:param model_name:
:param model_type:
:return:
"""
provider_model = db.session.query(ProviderModel).filter(
ProviderModel.tenant_id == self.provider.tenant_id,
ProviderModel.provider_name == self.provider.provider_name,
ProviderModel.model_name == model_name,
ProviderModel.model_type == model_type.value,
ProviderModel.is_valid == True
).first()
if not provider_model:
raise LLMBadRequestError(f"The model {model_name} does not exist. "
f"Please check the configuration.")
return provider_model
class CredentialsValidateFailedError(Exception):
pass

View File

@ -0,0 +1,157 @@
import json
from json import JSONDecodeError
from typing import Type
from langchain.llms import ChatGLM
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType
class ChatGLMProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'chatglm'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B',
},
{
'id': 'chatglm-6b',
'name': 'ChatGLM-6B',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = ChatGLMModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
model_max_tokens = {
'chatglm-6b': 2000,
'chatglm2-6b': 32000,
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_base' not in credentials:
raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')
try:
credential_kwargs = {
'endpoint_url': credentials['api_base']
}
llm = ChatGLM(
max_token=10,
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_base'] = encrypter.encrypt_token(tenant_id, credentials['api_base'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_base': None
}
if credentials['api_base']:
credentials['api_base'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_base']
)
if obfuscated:
credentials['api_base'] = encrypter.obfuscated_token(credentials['api_base'])
return credentials
return {}
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)

View File

@ -0,0 +1,76 @@
import os
from typing import Optional
import langchain
from flask import Flask
from pydantic import BaseModel
class HostedOpenAI(BaseModel):
api_base: str = None
api_organization: str = None
api_key: str
quota_limit: int = 0
"""Quota limit for the openai hosted model. 0 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1
class HostedAzureOpenAI(BaseModel):
api_base: str
api_key: str
quota_limit: int = 0
"""Quota limit for the azure openai hosted model. 0 means unlimited."""
class HostedAnthropic(BaseModel):
api_base: str = None
api_key: str
quota_limit: int = 0
"""Quota limit for the anthropic hosted model. 0 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1
class HostedModelProviders(BaseModel):
openai: Optional[HostedOpenAI] = None
azure_openai: Optional[HostedAzureOpenAI] = None
anthropic: Optional[HostedAnthropic] = None
hosted_model_providers = HostedModelProviders()
def init_app(app: Flask):
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
if app.config.get("HOSTED_OPENAI_ENABLED"):
hosted_model_providers.openai = HostedOpenAI(
api_base=app.config.get("HOSTED_OPENAI_API_BASE"),
api_organization=app.config.get("HOSTED_OPENAI_API_ORGANIZATION"),
api_key=app.config.get("HOSTED_OPENAI_API_KEY"),
quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"),
paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"),
)
if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"):
hosted_model_providers.azure_openai = HostedAzureOpenAI(
api_base=app.config.get("HOSTED_AZURE_OPENAI_API_BASE"),
api_key=app.config.get("HOSTED_AZURE_OPENAI_API_KEY"),
quota_limit=app.config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT"),
)
if app.config.get("HOSTED_ANTHROPIC_ENABLED"):
hosted_model_providers.anthropic = HostedAnthropic(
api_base=app.config.get("HOSTED_ANTHROPIC_API_BASE"),
api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"),
quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"),
paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
)

View File

@ -0,0 +1,183 @@
import json
from typing import Type
from huggingface_hub import HfApi
from langchain.llms import HuggingFaceEndpoint
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from models.provider import ProviderType
class HuggingfaceHubProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'huggingface_hub'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = HuggingfaceHubModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=1500, default=200),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if model_type != ModelType.TEXT_GENERATION:
raise NotImplementedError
if 'huggingfacehub_api_type' not in credentials \
or credentials['huggingfacehub_api_type'] not in ['hosted_inference_api', 'inference_endpoints']:
raise CredentialsValidateFailedError('Hugging Face Hub API Type invalid, '
'must be hosted_inference_api or inference_endpoints.')
if 'huggingfacehub_api_token' not in credentials:
raise CredentialsValidateFailedError('Hugging Face Hub API Token must be provided.')
hfapi = HfApi(token=credentials['huggingfacehub_api_token'])
try:
hfapi.whoami()
except Exception:
raise CredentialsValidateFailedError("Invalid API Token.")
if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
if 'huggingfacehub_endpoint_url' not in credentials:
raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
try:
llm = HuggingFaceEndpoint(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task="text2text-generation",
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)
llm("ping")
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
else:
try:
model_info = hfapi.model_info(repo_id=model_name)
if not model_info:
raise ValueError(f'Model {model_name} not found.')
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
if model_info.pipeline_tag not in VALID_TASKS:
raise ValueError(f"Model {model_name} is not a valid task, "
f"must be one of {VALID_TASKS}.")
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['huggingfacehub_api_token'] = encrypter.encrypt_token(tenant_id, credentials['huggingfacehub_api_token'])
if credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
hfapi = HfApi(token=credentials['huggingfacehub_api_token'])
model_info = hfapi.model_info(repo_id=model_name)
if not model_info:
raise ValueError(f'Model {model_name} not found.')
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
credentials['task_type'] = model_info.pipeline_tag
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'huggingfacehub_api_token': None,
'task_type': None
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['huggingfacehub_api_token']:
credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['huggingfacehub_api_token']
)
if obfuscated:
credentials['huggingfacehub_api_token'] = encrypter.obfuscated_token(credentials['huggingfacehub_api_token'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}

View File

@ -0,0 +1,179 @@
import json
from json import JSONDecodeError
from typing import Type
from langchain.llms import Minimax
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType, ProviderQuotaType
class MinimaxProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'minimax'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'abab5.5-chat',
'name': 'abab5.5-chat',
},
{
'id': 'abab5-chat',
'name': 'abab5-chat',
}
]
elif model_type == ModelType.EMBEDDINGS:
return [
{
'id': 'embo-01',
'name': 'embo-01',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = MinimaxModel
elif model_type == ModelType.EMBEDDINGS:
model_class = MinimaxEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
model_max_tokens = {
'abab5.5-chat': 16384,
'abab5-chat': 6144,
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.9),
top_p=KwargRule[float](min=0, max=1, default=0.95),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'minimax_group_id' not in credentials:
raise CredentialsValidateFailedError('MiniMax Group ID must be provided.')
if 'minimax_api_key' not in credentials:
raise CredentialsValidateFailedError('MiniMax API Key must be provided.')
try:
credential_kwargs = {
'minimax_group_id': credentials['minimax_group_id'],
'minimax_api_key': credentials['minimax_api_key'],
}
llm = Minimax(
model='abab5.5-chat',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['minimax_api_key'] = encrypter.encrypt_token(tenant_id, credentials['minimax_api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value \
or (self.provider.provider_type == ProviderType.SYSTEM.value
and self.provider.quota_type == ProviderQuotaType.FREE.value):
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'minimax_group_id': None,
'minimax_api_key': None,
}
if credentials['minimax_api_key']:
credentials['minimax_api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['minimax_api_key']
)
if obfuscated:
credentials['minimax_api_key'] = encrypter.obfuscated_token(credentials['minimax_api_key'])
return credentials
return {}
def should_deduct_quota(self):
return True
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)

View File

@ -0,0 +1,289 @@
import json
import logging
from json import JSONDecodeError
from typing import Type, Optional
from flask import current_app
from openai.error import AuthenticationError, OpenAIError
import openai
from core.helper import encrypter
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
from models.provider import ProviderType, ProviderQuotaType
class OpenAIProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'openai'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
models = [
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4',
'name': 'gpt-4',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
}
]
if self.provider.provider_type == ProviderType.SYSTEM.value \
and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']]
return models
elif model_type == ModelType.EMBEDDINGS:
return [
{
'id': 'text-embedding-ada-002',
'name': 'text-embedding-ada-002'
}
]
elif model_type == ModelType.SPEECH_TO_TEXT:
return [
{
'id': 'whisper-1',
'name': 'whisper-1'
}
]
elif model_type == ModelType.MODERATION:
return [
{
'id': 'text-moderation-stable',
'name': 'text-moderation-stable'
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = OpenAIModel
elif model_type == ModelType.EMBEDDINGS:
model_class = OpenAIEmbedding
elif model_type == ModelType.MODERATION:
model_class = OpenAIModeration
elif model_type == ModelType.SPEECH_TO_TEXT:
model_class = OpenAIWhisper
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
model_max_tokens = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=1),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'openai_api_key' not in credentials:
raise CredentialsValidateFailedError('OpenAI API key is required')
try:
credentials_kwargs = {
"api_key": credentials['openai_api_key']
}
if 'openai_api_base' in credentials and credentials['openai_api_base']:
credentials_kwargs['api_base'] = credentials['openai_api_base'] + '/v1'
if 'openai_organization' in credentials:
credentials_kwargs['organization'] = credentials['openai_organization']
openai.ChatCompletion.create(
messages=[{"role": "user", "content": 'ping'}],
model='gpt-3.5-turbo',
timeout=10,
request_timeout=(5, 30),
max_tokens=20,
**credentials_kwargs
)
except (AuthenticationError, OpenAIError) as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('OpenAI config validation failed')
raise ex
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['openai_api_key'] = encrypter.encrypt_token(tenant_id, credentials['openai_api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'openai_api_base': None,
'openai_api_key': self.provider.encrypted_config,
'openai_organization': None
}
if credentials['openai_api_key']:
credentials['openai_api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['openai_api_key']
)
if obfuscated:
credentials['openai_api_key'] = encrypter.obfuscated_token(credentials['openai_api_key'])
if 'openai_api_base' not in credentials or not credentials['openai_api_base']:
credentials['openai_api_base'] = None
else:
credentials['openai_api_base'] = credentials['openai_api_base'] + '/v1'
if 'openai_organization' not in credentials:
credentials['openai_organization'] = None
return credentials
else:
if hosted_model_providers.openai:
return {
'openai_api_base': hosted_model_providers.openai.api_base,
'openai_api_key': hosted_model_providers.openai.api_key,
'openai_organization': hosted_model_providers.openai.api_organization
}
else:
return {
'openai_api_base': None,
'openai_api_key': None,
'openai_organization': None
}
@classmethod
def is_provider_type_system_supported(cls) -> bool:
if current_app.config['EDITION'] != 'CLOUD':
return False
if hosted_model_providers.openai:
return True
return False
def should_deduct_quota(self):
if hosted_model_providers.openai \
and hosted_model_providers.openai.quota_limit and hosted_model_providers.openai.quota_limit > 0:
return True
return False
def get_payment_info(self) -> Optional[dict]:
"""
get payment info if it payable.
:return:
"""
if hosted_model_providers.openai \
and hosted_model_providers.openai.paid_enabled:
return {
'product_id': hosted_model_providers.openai.paid_stripe_price_id,
'increase_quota': hosted_model_providers.openai.paid_increase_quota,
}
return None
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)

View File

@ -0,0 +1,184 @@
import json
import logging
from typing import Type
import replicate
from replicate.exceptions import ReplicateError
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding
from models.provider import ProviderType
class ReplicateProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'replicate'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = ReplicateModel
elif model_type == ModelType.EMBEDDINGS:
model_class = ReplicateEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
model_credentials = self.get_model_credentials(model_name, model_type)
model = replicate.Client(api_token=model_credentials.get("replicate_api_token")).models.get(model_name)
try:
version = model.versions.get(model_credentials['model_version'])
except ReplicateError as e:
raise CredentialsValidateFailedError(f"Model {model_name}:{model_credentials['model_version']} not exists, "
f"cause: {e.__class__.__name__}:{str(e)}")
except Exception as e:
logging.exception("Model validate failed.")
raise e
model_kwargs_rules = ModelKwargsRules()
for key, value in version.openapi_schema['components']['schemas']['Input']['properties'].items():
if key not in ['debug', 'prompt'] and value['type'] in ['number', 'integer']:
if key == ['temperature', 'top_p']:
kwarg_rule = KwargRule[float](
type=KwargRuleType.FLOAT.value if value['type'] == 'number' else KwargRuleType.INTEGER.value,
min=float(value.get('minimum')) if value.get('minimum') is not None else None,
max=float(value.get('maximum')) if value.get('maximum') is not None else None,
default=float(value.get('default')) if value.get('default') is not None else None,
)
if key == 'temperature':
model_kwargs_rules.temperature = kwarg_rule
else:
model_kwargs_rules.top_p = kwarg_rule
elif key in ['max_length', 'max_new_tokens']:
model_kwargs_rules.max_tokens = KwargRule[int](
alias=key,
type=KwargRuleType.INTEGER.value,
min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
default=int(value.get('default')) if value.get('default') is not None else 500,
)
return model_kwargs_rules
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'replicate_api_token' not in credentials:
raise CredentialsValidateFailedError('Replicate API Key must be provided.')
if 'model_version' not in credentials:
raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
if model_name.count("/") != 1:
raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
'format: {user_name}/{model_name}')
version = credentials['model_version']
try:
model = replicate.Client(api_token=credentials.get("replicate_api_token")).models.get(model_name)
rst = model.versions.get(version)
if model_type == ModelType.EMBEDDINGS \
and 'Embedding' not in rst.openapi_schema['components']['schemas']:
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Embedding model.")
elif model_type == ModelType.TEXT_GENERATION \
and ('type' not in rst.openapi_schema['components']['schemas']['Output']['items']
or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
except ReplicateError as e:
raise CredentialsValidateFailedError(
f"Model {model_name}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}")
except Exception as e:
logging.exception("Replicate config validation failed.")
raise e
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['replicate_api_token'] = encrypter.encrypt_token(tenant_id, credentials['replicate_api_token'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'replicate_api_token': None,
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['replicate_api_token']:
credentials['replicate_api_token'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['replicate_api_token']
)
if obfuscated:
credentials['replicate_api_token'] = encrypter.obfuscated_token(credentials['replicate_api_token'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}

View File

@ -0,0 +1,191 @@
import json
import logging
from json import JSONDecodeError
from typing import Type
from flask import current_app
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.spark import ChatSpark
from core.third_party.spark.spark_llm import SparkError
from models.provider import ProviderType, ProviderQuotaType
class SparkProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'spark'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'spark',
'name': '星火认知大模型',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = SparkModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=0.5),
top_p=KwargRule[float](enabled=False),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4096, default=2048),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'app_id' not in credentials:
raise CredentialsValidateFailedError('Spark app_id must be provided.')
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('Spark api_key must be provided.')
if 'api_secret' not in credentials:
raise CredentialsValidateFailedError('Spark api_secret must be provided.')
try:
credential_kwargs = {
'app_id': credentials['app_id'],
'api_key': credentials['api_key'],
'api_secret': credentials['api_secret'],
}
chat_llm = ChatSpark(
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
credentials['api_secret'] = encrypter.encrypt_token(tenant_id, credentials['api_secret'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value \
or (self.provider.provider_type == ProviderType.SYSTEM.value
and self.provider.quota_type == ProviderQuotaType.FREE.value):
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'app_id': None,
'api_key': None,
'api_secret': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
if credentials['api_secret']:
credentials['api_secret'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_secret']
)
if obfuscated:
credentials['api_secret'] = encrypter.obfuscated_token(credentials['api_secret'])
return credentials
else:
return {
'app_id': None,
'api_key': None,
'api_secret': None,
}
def should_deduct_quota(self):
return True
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)

View File

@ -0,0 +1,157 @@
import json
from json import JSONDecodeError
from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
from models.provider import ProviderType
class TongyiProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'tongyi'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'qwen-v1',
'name': 'qwen-v1',
},
{
'id': 'qwen-plus-v1',
'name': 'qwen-plus-v1',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = TongyiModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
model_max_tokens = {
'qwen-v1': 1500,
'qwen-plus-v1': 6500
}
return ModelKwargsRules(
temperature=KwargRule[float](enabled=False),
top_p=KwargRule[float](min=0, max=1, default=0.8),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'dashscope_api_key' not in credentials:
raise CredentialsValidateFailedError('Dashscope API Key must be provided.')
try:
credential_kwargs = {
'dashscope_api_key': credentials['dashscope_api_key']
}
llm = EnhanceTongyi(
model_name='qwen-v1',
max_retries=1,
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['dashscope_api_key'] = encrypter.encrypt_token(tenant_id, credentials['dashscope_api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'dashscope_api_key': None
}
if credentials['dashscope_api_key']:
credentials['dashscope_api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['dashscope_api_key']
)
if obfuscated:
credentials['dashscope_api_key'] = encrypter.obfuscated_token(credentials['dashscope_api_key'])
return credentials
return {}
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)

View File

@ -0,0 +1,182 @@
import json
from json import JSONDecodeError
from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.wenxin import Wenxin
from models.provider import ProviderType
class WenxinProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'wenxin'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = WenxinModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
if model_name in ['ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95),
top_p=KwargRule[float](min=0.01, max=1, default=0.8),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
)
else:
return ModelKwargsRules(
temperature=KwargRule[float](enabled=False),
top_p=KwargRule[float](enabled=False),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('Wenxin api_key must be provided.')
if 'secret_key' not in credentials:
raise CredentialsValidateFailedError('Wenxin secret_key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key'],
'secret_key': credentials['secret_key'],
}
llm = Wenxin(
temperature=0.01,
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
'secret_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
if credentials['secret_key']:
credentials['secret_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['secret_key']
)
if obfuscated:
credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key'])
return credentials
else:
return {
'api_key': None,
'secret_key': None,
}
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)

View File

@ -0,0 +1,47 @@
import json
import os
def init_provider_rules():
# Get the absolute path of the subdirectory
subdirectory_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'rules')
# Path to the providers.json file
providers_json_file_path = os.path.join(subdirectory_path, '_providers.json')
try:
# Open the JSON file and read its content
with open(providers_json_file_path, 'r') as json_file:
data = json.load(json_file)
# Store the content in a dictionary with the key as the file name (without extension)
provider_names = data
except FileNotFoundError:
return "JSON file not found or path error"
except json.JSONDecodeError:
return "JSON file decoding error"
# Dictionary to store the content of all JSON files
json_data = {}
try:
# Loop through all files in the directory
for provider_name in provider_names:
filename = provider_name + '.json'
# Path to each JSON file
json_file_path = os.path.join(subdirectory_path, filename)
# Open each JSON file and read its content
with open(json_file_path, 'r') as json_file:
data = json.load(json_file)
# Store the content in the dictionary with the key as the file name (without extension)
json_data[os.path.splitext(filename)[0]] = data
return json_data
except FileNotFoundError:
return "JSON file not found or path error"
except json.JSONDecodeError:
return "JSON file decoding error"
provider_rules = init_provider_rules()

View File

@ -0,0 +1,12 @@
[
"openai",
"azure_openai",
"anthropic",
"minimax",
"tongyi",
"spark",
"wenxin",
"chatglm",
"replicate",
"huggingface_hub"
]

View File

@ -0,0 +1,15 @@
{
"support_provider_types": [
"system",
"custom"
],
"system_config": {
"supported_quota_types": [
"trial",
"paid"
],
"quota_unit": "times",
"quota_limit": 1000
},
"model_flexibility": "fixed"
}

View File

@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
}

View File

@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@ -0,0 +1,13 @@
{
"support_provider_types": [
"system",
"custom"
],
"system_config": {
"supported_quota_types": [
"free"
],
"quota_unit": "tokens"
},
"model_flexibility": "fixed"
}

View File

@ -0,0 +1,14 @@
{
"support_provider_types": [
"system",
"custom"
],
"system_config": {
"supported_quota_types": [
"trial"
],
"quota_unit": "times",
"quota_limit": 200
},
"model_flexibility": "fixed"
}

View File

@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@ -0,0 +1,13 @@
{
"support_provider_types": [
"system",
"custom"
],
"system_config": {
"supported_quota_types": [
"free"
],
"quota_unit": "tokens"
},
"model_flexibility": "fixed"
}

View File

@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
}

View File

@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
}