feat: server multi models support (#799)
58  api/core/model_providers/error.py  Normal file
@@ -0,0 +1,58 @@
from typing import Optional


class LLMError(Exception):
    """Base class for all LLM exceptions."""
    description: Optional[str] = None

    def __init__(self, description: Optional[str] = None) -> None:
        self.description = description


class LLMBadRequestError(LLMError):
    """Raised when the LLM returns a bad request."""
    description = "Bad Request"


class LLMAPIConnectionError(LLMError):
    """Raised when the LLM returns an API connection error."""
    description = "API Connection Error"


class LLMAPIUnavailableError(LLMError):
    """Raised when the LLM returns an API unavailable error."""
    description = "API Unavailable Error"


class LLMRateLimitError(LLMError):
    """Raised when the LLM returns a rate limit error."""
    description = "Rate Limit Error"


class LLMAuthorizationError(LLMError):
    """Raised when the LLM returns an authorization error."""
    description = "Authorization Error"


class ProviderTokenNotInitError(Exception):
    """
    Custom exception raised when the provider token is not initialized.
    """
    description = "Provider Token Not Init"

    def __init__(self, *args, **kwargs):
        self.description = args[0] if args else self.description


class QuotaExceededError(Exception):
    """
    Custom exception raised when the quota for a provider has been exceeded.
    """
    description = "Quota Exceeded"


class ModelCurrentlyNotSupportError(Exception):
    """
    Custom exception raised when the model is not currently supported.
    """
    description = "Model Currently Not Support"
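For orientation, a minimal sketch (not part of this commit) of how these exception types are meant to be consumed: callers can branch on the shared LLMError base and surface its description. The wrapper function is hypothetical.

import logging

from core.model_providers.error import LLMError, LLMRateLimitError


def call_model_safely(run_fn, *args, **kwargs):
    """Invoke a model run function and normalize provider failures."""
    try:
        return run_fn(*args, **kwargs)
    except LLMRateLimitError as e:
        # Rate limits are typically retryable; description defaults to "Rate Limit Error".
        logging.warning("rate limited: %s", e.description)
        raise
    except LLMError as e:
        # Every provider-side failure above shares the LLMError base class.
        logging.error("model call failed: %s", e.description)
        raise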
293  api/core/model_providers/model_factory.py  Normal file
@@ -0,0 +1,293 @@
from typing import Optional

from langchain.callbacks.base import Callbacks

from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel


class ModelFactory:

    @classmethod
    def get_text_generation_model_from_model_config(cls, tenant_id: str,
                                                    model_config: dict,
                                                    streaming: bool = False,
                                                    callbacks: Callbacks = None) -> Optional[BaseLLM]:
        provider_name = model_config.get("provider")
        model_name = model_config.get("name")
        completion_params = model_config.get("completion_params", {})

        return cls.get_text_generation_model(
            tenant_id=tenant_id,
            model_provider_name=provider_name,
            model_name=model_name,
            model_kwargs=ModelKwargs(
                temperature=completion_params.get('temperature', 0),
                max_tokens=completion_params.get('max_tokens', 256),
                top_p=completion_params.get('top_p', 0),
                frequency_penalty=completion_params.get('frequency_penalty', 0.1),
                presence_penalty=completion_params.get('presence_penalty', 0.1)
            ),
            streaming=streaming,
            callbacks=callbacks
        )

    @classmethod
    def get_text_generation_model(cls,
                                  tenant_id: str,
                                  model_provider_name: Optional[str] = None,
                                  model_name: Optional[str] = None,
                                  model_kwargs: Optional[ModelKwargs] = None,
                                  streaming: bool = False,
                                  callbacks: Callbacks = None) -> Optional[BaseLLM]:
        """
        Get a text generation model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: provider name; falls back to the tenant default if omitted.
        :param model_name: model name; falls back to the tenant default if omitted.
        :param model_kwargs: generation parameters (temperature, max_tokens, ...).
        :param streaming: whether to stream the completion.
        :param callbacks: optional LangChain callbacks.
        :return: the initialized text generation model.
        """
        is_default_model = False
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)

            if not default_model:
                raise LLMBadRequestError("Default model is not available. "
                                         "Please configure a Default System Reasoning Model "
                                         "in Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name
            is_default_model = True

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")

        # init text generation model
        model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)

        try:
            model_instance = model_class(
                model_provider=model_provider,
                name=model_name,
                model_kwargs=model_kwargs,
                streaming=streaming,
                callbacks=callbacks
            )
        except LLMBadRequestError:
            if is_default_model:
                raise LLMBadRequestError(f"Default model {model_name} is not available. "
                                         f"Please check your model provider credentials.")
            else:
                raise

        if is_default_model:
            model_instance.deduct_quota = False

        return model_instance

    @classmethod
    def get_embedding_model(cls,
                            tenant_id: str,
                            model_provider_name: Optional[str] = None,
                            model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
        """
        Get an embedding model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: provider name; falls back to the tenant default if omitted.
        :param model_name: model name; falls back to the tenant default if omitted.
        :return: the initialized embedding model.
        """
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)

            if not default_model:
                raise LLMBadRequestError("Default model is not available. "
                                         "Please configure a Default Embedding Model "
                                         "in Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")

        # init embedding model
        model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_speech2text_model(cls,
                              tenant_id: str,
                              model_provider_name: Optional[str] = None,
                              model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
        """
        Get a speech-to-text model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: provider name; falls back to the tenant default if omitted.
        :param model_name: model name; falls back to the tenant default if omitted.
        :return: the initialized speech-to-text model.
        """
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)

            if not default_model:
                raise LLMBadRequestError("Default model is not available. "
                                         "Please configure a Default Speech-to-Text Model "
                                         "in Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")

        # init speech-to-text model
        model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_moderation_model(cls,
                             tenant_id: str,
                             model_provider_name: str,
                             model_name: str) -> Optional[BaseProviderModel]:
        """
        Get a moderation model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: the provider name.
        :param model_name: the model name.
        :return: the initialized moderation model.
        """
        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")

        # init moderation model
        model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
        """
        Get the default model of a model type, creating one from the first
        available provider if none is configured yet.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_type: the model type (text generation, embeddings, ...).
        :return: the tenant's default model record.
        """
        # get default model
        default_model = db.session.query(TenantDefaultModel) \
            .filter(
            TenantDefaultModel.tenant_id == tenant_id,
            TenantDefaultModel.model_type == model_type.value
        ).first()

        if not default_model:
            model_provider_rules = ModelProviderFactory.get_provider_rules()
            for model_provider_name, model_provider_rule in model_provider_rules.items():
                model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
                if not model_provider:
                    continue

                model_list = model_provider.get_supported_model_list(model_type)
                if model_list:
                    model_info = model_list[0]
                    default_model = TenantDefaultModel(
                        tenant_id=tenant_id,
                        model_type=model_type.value,
                        provider_name=model_provider_name,
                        model_name=model_info['id']
                    )
                    db.session.add(default_model)
                    db.session.commit()
                    break

        return default_model

    @classmethod
    def update_default_model(cls,
                             tenant_id: str,
                             model_type: ModelType,
                             provider_name: str,
                             model_name: str) -> TenantDefaultModel:
        """
        Update the default model of a model type.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_type: the model type.
        :param provider_name: the new default provider name.
        :param model_name: the new default model name.
        :return: the updated default model record.
        """
        model_provider_names = ModelProviderFactory.get_provider_names()
        if provider_name not in model_provider_names:
            raise ValueError(f'Invalid provider name: {provider_name}')

        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")

        model_list = model_provider.get_supported_model_list(model_type)
        model_ids = [model['id'] for model in model_list]
        if model_name not in model_ids:
            raise ValueError(f'Invalid model name: {model_name}')

        # get default model
        default_model = db.session.query(TenantDefaultModel) \
            .filter(
            TenantDefaultModel.tenant_id == tenant_id,
            TenantDefaultModel.model_type == model_type.value
        ).first()

        if default_model:
            # update default model
            default_model.provider_name = provider_name
            default_model.model_name = model_name
            db.session.commit()
        else:
            # create default model
            default_model = TenantDefaultModel(
                tenant_id=tenant_id,
                model_type=model_type.value,
                provider_name=provider_name,
                model_name=model_name,
            )
            db.session.add(default_model)
            db.session.commit()

        return default_model
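For orientation, a usage sketch of the factory above (not part of this commit; the tenant ID and config values are hypothetical). Completion params left out of the config fall back to the defaults shown in get_text_generation_model_from_model_config.

from core.model_providers.model_factory import ModelFactory

model_instance = ModelFactory.get_text_generation_model_from_model_config(
    tenant_id='tenant-1234',  # hypothetical tenant ID
    model_config={
        'provider': 'openai',
        'name': 'gpt-3.5-turbo',
        'completion_params': {'temperature': 0.7, 'max_tokens': 512},
    },
    streaming=False,
)
# Omitting the provider and model name in get_text_generation_model instead
# resolves the tenant's default system reasoning model.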
228  api/core/model_providers/model_provider_factory.py  Normal file
@@ -0,0 +1,228 @@
from typing import Type

from sqlalchemy.exc import IntegrityError

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.rules import provider_rules
from extensions.ext_database import db
from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType

DEFAULT_MODELS = {
    ModelType.TEXT_GENERATION.value: {
        'provider_name': 'openai',
        'model_name': 'gpt-3.5-turbo',
    },
    ModelType.EMBEDDINGS.value: {
        'provider_name': 'openai',
        'model_name': 'text-embedding-ada-002',
    },
    ModelType.SPEECH_TO_TEXT.value: {
        'provider_name': 'openai',
        'model_name': 'whisper-1',
    }
}


class ModelProviderFactory:
    @classmethod
    def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
        if provider_name == 'openai':
            from core.model_providers.providers.openai_provider import OpenAIProvider
            return OpenAIProvider
        elif provider_name == 'anthropic':
            from core.model_providers.providers.anthropic_provider import AnthropicProvider
            return AnthropicProvider
        elif provider_name == 'minimax':
            from core.model_providers.providers.minimax_provider import MinimaxProvider
            return MinimaxProvider
        elif provider_name == 'spark':
            from core.model_providers.providers.spark_provider import SparkProvider
            return SparkProvider
        elif provider_name == 'tongyi':
            from core.model_providers.providers.tongyi_provider import TongyiProvider
            return TongyiProvider
        elif provider_name == 'wenxin':
            from core.model_providers.providers.wenxin_provider import WenxinProvider
            return WenxinProvider
        elif provider_name == 'chatglm':
            from core.model_providers.providers.chatglm_provider import ChatGLMProvider
            return ChatGLMProvider
        elif provider_name == 'azure_openai':
            from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
            return AzureOpenAIProvider
        elif provider_name == 'replicate':
            from core.model_providers.providers.replicate_provider import ReplicateProvider
            return ReplicateProvider
        elif provider_name == 'huggingface_hub':
            from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
            return HuggingfaceHubProvider
        else:
            raise NotImplementedError

    @classmethod
    def get_provider_names(cls):
        """
        Returns a list of all supported provider names.
        """
        return list(provider_rules.keys())

    @classmethod
    def get_provider_rules(cls):
        """
        Returns the provider rules, keyed by provider name.
        """
        return provider_rules

    @classmethod
    def get_provider_rule(cls, provider_name: str):
        """
        Returns the rule of a single provider.
        """
        return provider_rules[provider_name]

    @classmethod
    def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
        """
        Get the preferred model provider, or None if the tenant has no valid
        provider configuration.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: the provider name, e.g. 'openai'.
        :return: an initialized model provider instance, or None.
        """
        # get preferred provider
        preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
        if not preferred_provider or not preferred_provider.is_valid:
            return None

        # init model provider
        model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
        return model_provider_class(provider=preferred_provider)

    @classmethod
    def get_preferred_type_by_preferred_model_provider(cls,
                                                       tenant_id: str,
                                                       model_provider_name: str,
                                                       preferred_model_provider: TenantPreferredModelProvider):
        """
        Get the preferred provider type from the tenant's preference record,
        falling back to the provider's supported types when no record exists.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: the provider name.
        :param preferred_model_provider: the tenant's preference record, if any.
        :return: a ProviderType value.
        """
        if not preferred_model_provider:
            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
            support_provider_types = model_provider_rules['support_provider_types']

            if ProviderType.CUSTOM.value in support_provider_types:
                custom_provider = db.session.query(Provider) \
                    .filter(
                    Provider.tenant_id == tenant_id,
                    Provider.provider_name == model_provider_name,
                    Provider.provider_type == ProviderType.CUSTOM.value,
                    Provider.is_valid == True
                ).first()

                if custom_provider:
                    return ProviderType.CUSTOM.value

            model_provider = cls.get_model_provider_class(model_provider_name)

            if ProviderType.SYSTEM.value in support_provider_types \
                    and model_provider.is_provider_type_system_supported():
                return ProviderType.SYSTEM.value
            elif ProviderType.CUSTOM.value in support_provider_types:
                return ProviderType.CUSTOM.value
        else:
            return preferred_model_provider.preferred_provider_type

    @classmethod
    def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
        """
        Get the preferred provider record of the tenant.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: the provider name.
        :return: a Provider record, or None.
        """
        # get preferred provider type
        preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)

        # get providers by preferred provider type
        providers = db.session.query(Provider) \
            .filter(
            Provider.tenant_id == tenant_id,
            Provider.provider_name == model_provider_name,
            Provider.provider_type == preferred_provider_type
        ).all()

        no_system_provider = False
        if preferred_provider_type == ProviderType.SYSTEM.value:
            quota_type_to_provider_dict = {}
            for provider in providers:
                quota_type_to_provider_dict[provider.quota_type] = provider

            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
            for quota_type_enum in ProviderQuotaType:
                quota_type = quota_type_enum.value
                if quota_type in model_provider_rules['system_config']['supported_quota_types'] \
                        and quota_type in quota_type_to_provider_dict.keys():
                    provider = quota_type_to_provider_dict[quota_type]
                    if provider.is_valid and provider.quota_limit > provider.quota_used:
                        return provider

            no_system_provider = True

        if no_system_provider:
            providers = db.session.query(Provider) \
                .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == model_provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).all()

        if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
            if providers:
                return providers[0]
            else:
                try:
                    provider = Provider(
                        tenant_id=tenant_id,
                        provider_name=model_provider_name,
                        provider_type=ProviderType.CUSTOM.value,
                        is_valid=False
                    )
                    db.session.add(provider)
                    db.session.commit()
                except IntegrityError:
                    db.session.rollback()
                    provider = db.session.query(Provider) \
                        .filter(
                        Provider.tenant_id == tenant_id,
                        Provider.provider_name == model_provider_name,
                        Provider.provider_type == ProviderType.CUSTOM.value
                    ).first()

                return provider

        return None

    @classmethod
    def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
        """
        Get the preferred provider type of the tenant.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: the provider name.
        :return: a ProviderType value.
        """
        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
            .filter(
            TenantPreferredModelProvider.tenant_id == tenant_id,
            TenantPreferredModelProvider.provider_name == model_provider_name
        ).first()

        return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)
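A sketch of resolving a provider through the factory above (hypothetical tenant ID; not part of this commit). None comes back when the tenant has neither valid custom credentials nor a usable system quota for that provider.

from core.model_providers.model_provider_factory import ModelProviderFactory

provider = ModelProviderFactory.get_preferred_model_provider('tenant-1234', 'openai')
if provider is None:
    print('openai is not configured for this tenant')
else:
    print(type(provider).__name__)  # e.g. OpenAIProvider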
0  api/core/model_providers/models/__init__.py  Normal file
22  api/core/model_providers/models/base.py  Normal file
@@ -0,0 +1,22 @@
from abc import ABC
from typing import Any

from core.model_providers.providers.base import BaseModelProvider


class BaseProviderModel(ABC):
    _client: Any
    _model_provider: BaseModelProvider

    def __init__(self, model_provider: BaseModelProvider, client: Any):
        self._model_provider = model_provider
        self._client = client

    @property
    def client(self):
        return self._client

    @property
    def model_provider(self):
        return self._model_provider
78  api/core/model_providers/models/embedding/azure_openai_embedding.py  Normal file
@@ -0,0 +1,78 @@
import decimal
import logging

import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings

from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \
    LLMAPIUnavailableError, LLMAPIConnectionError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider

AZURE_OPENAI_API_VERSION = '2023-07-01-preview'


class AzureOpenAIEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = OpenAIEmbeddings(
            deployment=name,
            openai_api_type='azure',
            openai_api_version=AZURE_OPENAI_API_VERSION,
            chunk_size=16,
            max_retries=1,
            **self.credentials
        )

        super().__init__(model_provider, client, name)

    def get_num_tokens(self, text: str) -> int:
        """
        Get the number of tokens of the text.

        :param text: the input text.
        :return: the token count.
        """
        if len(text) == 0:
            return 0

        enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name'))

        tokenized_text = enc.encode(text)

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    def get_token_price(self, tokens: int):
        # unit price is in USD per 1K tokens
        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1k * decimal.Decimal('0.0001')
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to Azure OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to Azure OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("Azure OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            raise LLMAuthorizationError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex
40  api/core/model_providers/models/embedding/base.py  Normal file
@@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Any

import tiktoken
from langchain.schema.language_model import _get_token_ids_default_method

from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider


class BaseEmbedding(BaseProviderModel):
    name: str
    type: ModelType = ModelType.EMBEDDINGS

    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
        super().__init__(model_provider, client)
        self.name = name

    def get_num_tokens(self, text: str) -> int:
        """
        Get the number of tokens of the text.

        :param text: the input text.
        :return: the token count.
        """
        if len(text) == 0:
            return 0

        return len(_get_token_ids_default_method(text))

    def get_token_price(self, tokens: int):
        return 0

    def get_currency(self):
        return 'USD'

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        raise NotImplementedError
35  api/core/model_providers/models/embedding/minimax_embedding.py  Normal file
@@ -0,0 +1,35 @@
import decimal
import logging

from langchain.embeddings import MiniMaxEmbeddings

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider


class MinimaxEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = MiniMaxEmbeddings(
            model=name,
            **credentials
        )

        super().__init__(model_provider, client, name)

    def get_token_price(self, tokens: int):
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, ValueError):
            return LLMBadRequestError(f"Minimax: {str(ex)}")
        else:
            return ex
72  api/core/model_providers/models/embedding/openai_embedding.py  Normal file
@@ -0,0 +1,72 @@
import decimal
import logging

import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings

from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider


class OpenAIEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = OpenAIEmbeddings(
            max_retries=1,
            **credentials
        )

        super().__init__(model_provider, client, name)

    def get_num_tokens(self, text: str) -> int:
        """
        Get the number of tokens of the text.

        :param text: the input text.
        :return: the token count.
        """
        if len(text) == 0:
            return 0

        enc = tiktoken.encoding_for_model(self.name)

        tokenized_text = enc.encode(text)

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    def get_token_price(self, tokens: int):
        # unit price is in USD per 1K tokens
        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1k * decimal.Decimal('0.0001')
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            raise LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex
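A worked check of the pricing arithmetic above, assuming 1,000 tokens of text-embedding-ada-002:

import decimal

tokens_per_1k = (decimal.Decimal(1000) / 1000).quantize(
    decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)   # 1.000
price = (tokens_per_1k * decimal.Decimal('0.0001')).quantize(
    decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(price)  # 0.0001000 (USD)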
36  api/core/model_providers/models/embedding/replicate_embedding.py  Normal file
@@ -0,0 +1,36 @@
import decimal

from replicate.exceptions import ModelError, ReplicateError

from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings
from core.model_providers.models.embedding.base import BaseEmbedding


class ReplicateEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = ReplicateEmbeddings(
            model=name + ':' + credentials.get('model_version'),
            replicate_api_token=credentials.get('replicate_api_token')
        )

        super().__init__(model_provider, client, name)

    def get_token_price(self, tokens: int):
        # Replicate bills by prediction time (seconds), not by tokens
        return decimal.Decimal('0')

    def get_currency(self):
        return 'USD'

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, (ModelError, ReplicateError)):
            return LLMBadRequestError(f"Replicate: {str(ex)}")
        else:
            return ex
0  api/core/model_providers/models/entity/__init__.py  Normal file
53  api/core/model_providers/models/entity/message.py  Normal file
@@ -0,0 +1,53 @@
import enum

from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from pydantic import BaseModel


class LLMRunResult(BaseModel):
    content: str
    prompt_tokens: int
    completion_tokens: int


class MessageType(enum.Enum):
    HUMAN = 'human'
    ASSISTANT = 'assistant'
    SYSTEM = 'system'


class PromptMessage(BaseModel):
    type: MessageType = MessageType.HUMAN
    content: str = ''


def to_lc_messages(messages: list[PromptMessage]):
    lc_messages = []
    for message in messages:
        if message.type == MessageType.HUMAN:
            lc_messages.append(HumanMessage(content=message.content))
        elif message.type == MessageType.ASSISTANT:
            lc_messages.append(AIMessage(content=message.content))
        elif message.type == MessageType.SYSTEM:
            lc_messages.append(SystemMessage(content=message.content))

    return lc_messages


def to_prompt_messages(messages: list[BaseMessage]):
    prompt_messages = []
    for message in messages:
        if isinstance(message, HumanMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
        elif isinstance(message, AIMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
        elif isinstance(message, SystemMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
    return prompt_messages


def str_to_prompt_messages(texts: list[str]):
    prompt_messages = []
    for text in texts:
        prompt_messages.append(PromptMessage(content=text))
    return prompt_messages
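A round-trip sketch of the conversion helpers above (not part of this commit):

from core.model_providers.models.entity.message import (
    PromptMessage, MessageType, to_lc_messages, to_prompt_messages,
)

msgs = [
    PromptMessage(type=MessageType.SYSTEM, content='You are helpful.'),
    PromptMessage(type=MessageType.HUMAN, content='Hi!'),
]
lc_msgs = to_lc_messages(msgs)              # [SystemMessage(...), HumanMessage(...)]
assert to_prompt_messages(lc_msgs) == msgs  # the mapping round-trips (pydantic models compare by field values)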
59  api/core/model_providers/models/entity/model_params.py  Normal file
@@ -0,0 +1,59 @@
import enum
from typing import Optional, TypeVar, Generic

from langchain.load.serializable import Serializable
from pydantic import BaseModel


class ModelMode(enum.Enum):
    COMPLETION = 'completion'
    CHAT = 'chat'


class ModelType(enum.Enum):
    TEXT_GENERATION = 'text-generation'
    EMBEDDINGS = 'embeddings'
    SPEECH_TO_TEXT = 'speech2text'
    IMAGE = 'image'
    VIDEO = 'video'
    MODERATION = 'moderation'

    @staticmethod
    def value_of(value):
        for member in ModelType:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class ModelKwargs(BaseModel):
    max_tokens: Optional[int]
    temperature: Optional[float]
    top_p: Optional[float]
    presence_penalty: Optional[float]
    frequency_penalty: Optional[float]


class KwargRuleType(enum.Enum):
    STRING = 'string'
    INTEGER = 'integer'
    FLOAT = 'float'


T = TypeVar('T')


class KwargRule(Generic[T], BaseModel):
    enabled: bool = True
    min: Optional[T] = None
    max: Optional[T] = None
    default: Optional[T] = None
    alias: Optional[str] = None


class ModelKwargsRules(BaseModel):
    max_tokens: KwargRule = KwargRule[int](enabled=False)
    temperature: KwargRule = KwargRule[float](enabled=False)
    top_p: KwargRule = KwargRule[float](enabled=False)
    presence_penalty: KwargRule = KwargRule[float](enabled=False)
    frequency_penalty: KwargRule = KwargRule[float](enabled=False)
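A sketch of how a provider might declare its tunable parameters with these rule types (the concrete bounds are hypothetical; not part of this commit). Fields left at their defaults stay disabled and are later dropped when kwargs are converted for the provider.

from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules

rules = ModelKwargsRules(
    temperature=KwargRule[float](enabled=True, min=0.0, max=2.0, default=1.0),
)
assert rules.max_tokens.enabled is False  # untouched fields remain disabled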
10  api/core/model_providers/models/entity/provider.py  Normal file
@@ -0,0 +1,10 @@
from enum import Enum


class ProviderQuotaUnit(Enum):
    TIMES = 'times'
    TOKENS = 'tokens'


class ModelFeature(Enum):
    AGENT_THOUGHT = 'agent_thought'
0  api/core/model_providers/models/llm/__init__.py  Normal file
107  api/core/model_providers/models/llm/anthropic_model.py  Normal file
@@ -0,0 +1,107 @@
import decimal
import logging
from functools import wraps
from typing import List, Optional, Any

import anthropic
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class AnthropicModel(BaseLLM):
    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ChatAnthropic(
            model=self.name,
            streaming=self.streaming,
            callbacks=self.callbacks,
            default_request_timeout=60,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run predict by prompt messages and stop words.

        :param messages: the prompt messages.
        :param stop: optional stop words.
        :param callbacks: optional LangChain callbacks.
        :return: the raw LLMResult.
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens of the prompt messages.

        :param messages: the prompt messages.
        :return: the token count.
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        # unit prices are in USD per 1M tokens
        model_unit_prices = {
            'claude-instant-1': {
                'prompt': decimal.Decimal('1.63'),
                'completion': decimal.Decimal('5.51'),
            },
            'claude-2': {
                'prompt': decimal.Decimal('11.02'),
                'completion': decimal.Decimal('32.68'),
            },
        }

        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = model_unit_prices[self.name]['prompt']
        else:
            unit_price = model_unit_prices[self.name]['completion']

        tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
                                                                     rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1m * unit_price
        return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, anthropic.APIConnectionError):
            logging.warning("Failed to connect to Anthropic API.")
            return LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {ex.__cause__}")
        elif isinstance(ex, anthropic.RateLimitError):
            return LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
        elif isinstance(ex, anthropic.AuthenticationError):
            return LLMAuthorizationError(f"Anthropic: {ex.message}")
        elif isinstance(ex, anthropic.BadRequestError):
            return LLMBadRequestError(f"Anthropic: {ex.message}")
        elif isinstance(ex, anthropic.APIStatusError):
            return LLMAPIUnavailableError(f"Anthropic: code: {ex.status_code}, cause: {ex.message}")
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
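A worked check of the per-million pricing arithmetic above, assuming 1,500 prompt tokens on claude-2:

import decimal

tokens_per_1m = (decimal.Decimal(1500) / 1000000).quantize(
    decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)  # 0.001500
price = (tokens_per_1m * decimal.Decimal('11.02')).quantize(
    decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
print(price)  # 0.01653000 (USD)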
177  api/core/model_providers/models/llm/azure_openai_model.py  Normal file
@@ -0,0 +1,177 @@
import decimal
import logging
from functools import wraps
from typing import List, Optional, Any

import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
from core.third_party.langchain.llms.azure_open_ai import EnhanceAzureOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs

AZURE_OPENAI_API_VERSION = '2023-07-01-preview'


class AzureOpenAIModel(BaseLLM):
    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        if name == 'text-davinci-003':
            self.model_mode = ModelMode.COMPLETION
        else:
            self.model_mode = ModelMode.CHAT

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.name == 'text-davinci-003':
            client = EnhanceAzureOpenAI(
                deployment_name=self.name,
                streaming=self.streaming,
                request_timeout=60,
                openai_api_type='azure',
                openai_api_version=AZURE_OPENAI_API_VERSION,
                openai_api_key=self.credentials.get('openai_api_key'),
                openai_api_base=self.credentials.get('openai_api_base'),
                callbacks=self.callbacks,
                **provider_model_kwargs
            )
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }

            client = EnhanceAzureChatOpenAI(
                deployment_name=self.name,
                temperature=provider_model_kwargs.get('temperature'),
                max_tokens=provider_model_kwargs.get('max_tokens'),
                model_kwargs=extra_model_kwargs,
                streaming=self.streaming,
                request_timeout=60,
                openai_api_type='azure',
                openai_api_version=AZURE_OPENAI_API_VERSION,
                openai_api_key=self.credentials.get('openai_api_key'),
                openai_api_base=self.credentials.get('openai_api_base'),
                callbacks=self.callbacks,
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run predict by prompt messages and stop words.

        :param messages: the prompt messages.
        :param stop: optional stop words.
        :param callbacks: optional LangChain callbacks.
        :return: the raw LLMResult.
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens of the prompt messages.

        :param messages: the prompt messages.
        :return: the token count.
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, str):
            return self._client.get_num_tokens(prompts)
        else:
            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        # unit prices are in USD per 1K tokens
        model_unit_prices = {
            'gpt-4': {
                'prompt': decimal.Decimal('0.03'),
                'completion': decimal.Decimal('0.06'),
            },
            'gpt-4-32k': {
                'prompt': decimal.Decimal('0.06'),
                'completion': decimal.Decimal('0.12')
            },
            'gpt-35-turbo': {
                'prompt': decimal.Decimal('0.0015'),
                'completion': decimal.Decimal('0.002')
            },
            'gpt-35-turbo-16k': {
                'prompt': decimal.Decimal('0.003'),
                'completion': decimal.Decimal('0.004')
            },
            'text-davinci-003': {
                'prompt': decimal.Decimal('0.02'),
                'completion': decimal.Decimal('0.02')
            },
        }

        base_model_name = self.credentials.get("base_model_name")
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = model_unit_prices[base_model_name]['prompt']
        else:
            unit_price = model_unit_prices[base_model_name]['completion']

        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1k * unit_price
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        if self.name == 'text-davinci-003':
            for k, v in provider_model_kwargs.items():
                if hasattr(self.client, k):
                    setattr(self.client, k, v)
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }

            self.client.temperature = provider_model_kwargs.get('temperature')
            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
            self.client.model_kwargs = extra_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to Azure OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to Azure OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("Azure OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            raise LLMAuthorizationError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
269  api/core/model_providers/models/llm/base.py  Normal file
@@ -0,0 +1,269 @@
from abc import abstractmethod
from typing import List, Optional, Any, Union

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.fake import FakeLLM


class BaseLLM(BaseProviderModel):
    model_mode: ModelMode = ModelMode.COMPLETION
    name: str
    model_kwargs: ModelKwargs
    credentials: dict
    streaming: bool = False
    type: ModelType = ModelType.TEXT_GENERATION
    deduct_quota: bool = True

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        self.name = name
        self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
        self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
            max_tokens=None,
            temperature=None,
            top_p=None,
            presence_penalty=None,
            frequency_penalty=None
        )
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        self.streaming = streaming

        if streaming:
            default_callback = DifyStreamingStdOutCallbackHandler()
        else:
            default_callback = DifyStdOutCallbackHandler()

        if not callbacks:
            callbacks = [default_callback]
        else:
            callbacks.append(default_callback)

        self.callbacks = callbacks

        client = self._init_client()
        super().__init__(model_provider, client)

    @abstractmethod
    def _init_client(self) -> Any:
        raise NotImplementedError

    def run(self, messages: List[PromptMessage],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs) -> LLMRunResult:
        """
        Run predict by prompt messages and stop words.

        :param messages: the prompt messages.
        :param stop: optional stop words.
        :param callbacks: optional LangChain callbacks.
        :return: the run result with content and token usage.
        """
        if self.deduct_quota:
            self.model_provider.check_quota_over_limit()

        if not callbacks:
            callbacks = self.callbacks
        else:
            callbacks.extend(self.callbacks)

        if 'fake_response' in kwargs and kwargs['fake_response']:
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=kwargs['fake_response'],
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            result = fake_llm.generate([prompts])
        else:
            try:
                result = self._run(
                    messages=messages,
                    stop=stop,
                    callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
                    **kwargs
                )
            except Exception as ex:
                raise self.handle_exceptions(ex)

        if isinstance(result.generations[0][0], ChatGeneration):
            completion_content = result.generations[0][0].message.content
        else:
            completion_content = result.generations[0][0].text

        if self.streaming and not self.support_streaming():
            # use FakeLLM to simulate streaming when the current model does not
            # support streaming but streaming is requested
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=completion_content,
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            fake_llm.generate([prompts])

        if result.llm_output and result.llm_output['token_usage']:
            prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
            completion_tokens = result.llm_output['token_usage']['completion_tokens']
            total_tokens = result.llm_output['token_usage']['total_tokens']
        else:
            prompt_tokens = self.get_num_tokens(messages)
            completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
            total_tokens = prompt_tokens + completion_tokens

        if self.deduct_quota:
            self.model_provider.deduct_quota(total_tokens)

        return LLMRunResult(
            content=completion_content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens
        )

    @abstractmethod
    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run predict by prompt messages and stop words.

        :param messages: the prompt messages.
        :param stop: optional stop words.
        :param callbacks: optional LangChain callbacks.
        :return: the raw LLMResult.
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens of the prompt messages.

        :param messages: the prompt messages.
        :return: the token count.
        """
        raise NotImplementedError

    @abstractmethod
    def get_token_price(self, tokens: int, message_type: MessageType):
        """
        Get the token price.

        :param tokens: the token count.
        :param message_type: prompt or completion side.
        :return: the price as a Decimal.
        """
        raise NotImplementedError

    @abstractmethod
    def get_currency(self):
        """
        Get the pricing currency.

        :return: a currency code such as 'USD'.
        """
        raise NotImplementedError

    def get_model_kwargs(self):
        return self.model_kwargs

    def set_model_kwargs(self, model_kwargs: ModelKwargs):
        self.model_kwargs = model_kwargs
        self._set_model_kwargs(model_kwargs)

    @abstractmethod
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Handle LLM run exceptions by mapping them to the shared error types.

        :param ex: the raised exception.
        :return: the mapped exception.
        """
        raise NotImplementedError

    def add_callbacks(self, callbacks: Callbacks):
        """
        Add callbacks to the client.

        :param callbacks: the callbacks to add.
        """
        if not self.client.callbacks:
            self.client.callbacks = callbacks
        else:
            self.client.callbacks.extend(callbacks)

    @classmethod
    def support_streaming(cls):
        return False

    def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
        if len(messages) == 0:
            raise ValueError("prompt must not be empty.")

        if not model_mode:
            model_mode = self.model_mode

        if model_mode == ModelMode.COMPLETION:
            return messages[0].content
        else:
            chat_messages = []
            for message in messages:
                if message.type == MessageType.HUMAN:
                    chat_messages.append(HumanMessage(content=message.content))
                elif message.type == MessageType.ASSISTANT:
                    chat_messages.append(AIMessage(content=message.content))
                elif message.type == MessageType.SYSTEM:
                    chat_messages.append(SystemMessage(content=message.content))

            return chat_messages

    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
        """
        Convert model kwargs to provider-specific model kwargs, applying
        aliases, defaults, and min/max clamping from the rules.

        :param model_rules: the provider's parameter rules.
        :param model_kwargs: the requested parameters.
        :return: the provider-ready kwargs dict.
        """
        model_kwargs_input = {}
        for key, value in model_kwargs.dict().items():
            rule = getattr(model_rules, key)
            if not rule.enabled:
                continue

            if rule.alias:
                key = rule.alias

            if rule.default is not None and value is None:
                value = rule.default

            if rule.min is not None:
                value = max(value, rule.min)

            if rule.max is not None:
                value = min(value, rule.max)

            model_kwargs_input[key] = value

        return model_kwargs_input
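A worked sketch of the clamping in _to_model_kwargs_input above (the rule values are hypothetical; not part of this commit):

from core.model_providers.models.entity.model_params import (
    KwargRule, ModelKwargs, ModelKwargsRules,
)

rules = ModelKwargsRules(
    temperature=KwargRule[float](enabled=True, min=0.0, max=1.0, default=0.7),
    max_tokens=KwargRule[int](enabled=True, min=1, max=4096, default=256,
                              alias='max_tokens_to_sample'),
)
kwargs = ModelKwargs(temperature=1.5, max_tokens=None, top_p=None,
                     presence_penalty=None, frequency_penalty=None)
# Per the method above: temperature clamps to its max (1.0); max_tokens falls
# back to its default (256) and is emitted under its alias; disabled keys are
# dropped. Expected: {'max_tokens_to_sample': 256, 'temperature': 1.0}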
70  api/core/model_providers/models/llm/chatglm_model.py  Normal file
@@ -0,0 +1,70 @@
import decimal
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.llms import ChatGLM
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class ChatGLMModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ChatGLM(
            callbacks=self.callbacks,
            endpoint_url=self.credentials.get('api_base'),
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, ValueError):
            return LLMBadRequestError(f"ChatGLM: {str(ex)}")
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return False
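ChatGLMModel wraps a self-hosted HTTP endpoint. A minimal sketch of the underlying LangChain client in isolation; the endpoint address is a placeholder:

from langchain.llms import ChatGLM

llm = ChatGLM(endpoint_url="http://127.0.0.1:8000")  # placeholder URL
result = llm.generate(["Say hello in one sentence."])
print(result.generations[0][0].text)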
82
api/core/model_providers/models/llm/huggingface_hub_model.py
Normal file
@ -0,0 +1,82 @@
import decimal
from functools import wraps
from typing import List, Optional, Any

from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks
from langchain.llms import HuggingFaceEndpoint
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class HuggingfaceHubModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
            client = HuggingFaceEndpoint(
                endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
                task='text2text-generation',
                model_kwargs=provider_model_kwargs,
                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
                callbacks=self.callbacks,
            )
        else:
            client = HuggingFaceHub(
                repo_id=self.name,
                task=self.credentials['task_type'],
                model_kwargs=provider_model_kwargs,
                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
                callbacks=self.callbacks,
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.get_num_tokens(prompts)

    def get_token_price(self, tokens: int, message_type: MessageType):
        # not support calc price
        return decimal.Decimal('0')

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        self.client.model_kwargs = provider_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")

    @classmethod
    def support_streaming(cls):
        return False
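The client choice above is driven entirely by the credentials dict. A hedged sketch of the two shapes _init_client distinguishes; the key names come from the code, the values are placeholders:

# Dedicated Inference Endpoint (selected by 'inference_endpoints'):
endpoint_credentials = {
    'huggingfacehub_api_type': 'inference_endpoints',
    'huggingfacehub_endpoint_url': 'https://xyz.endpoints.huggingface.cloud',  # placeholder
    'huggingfacehub_api_token': 'hf_...',  # placeholder
}

# Any other api_type falls back to the shared Hub inference client:
hub_credentials = {
    'huggingfacehub_api_type': 'hosted_inference_api',  # illustrative value
    'huggingfacehub_api_token': 'hf_...',  # placeholder
    'task_type': 'text-generation',
}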
70
api/core/model_providers/models/llm/minimax_model.py
Normal file
@ -0,0 +1,70 @@
import decimal
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class MinimaxModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return Minimax(
            model=self.name,
            model_kwargs={
                'stream': False
            },
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, ValueError):
            return LLMBadRequestError(f"Minimax: {str(ex)}")
        else:
            return ex
219
api/core/model_providers/models/llm/openai_model.py
Normal file
@ -0,0 +1,219 @@
import decimal
import logging
from typing import List, Optional, Any

import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from models.provider import ProviderType, ProviderQuotaType

COMPLETION_MODELS = [
    'text-davinci-003',  # 4,097 tokens
]

CHAT_MODELS = [
    'gpt-4',  # 8,192 tokens
    'gpt-4-32k',  # 32,768 tokens
    'gpt-3.5-turbo',  # 4,096 tokens
    'gpt-3.5-turbo-16k',  # 16,384 tokens
]

MODEL_MAX_TOKENS = {
    'gpt-4': 8192,
    'gpt-4-32k': 32768,
    'gpt-3.5-turbo': 4096,
    'gpt-3.5-turbo-16k': 16384,
    'text-davinci-003': 4097,
}


class OpenAIModel(BaseLLM):
    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        if name in COMPLETION_MODELS:
            self.model_mode = ModelMode.COMPLETION
        else:
            self.model_mode = ModelMode.CHAT

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.name in COMPLETION_MODELS:
            client = EnhanceOpenAI(
                model_name=self.name,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                **self.credentials,
                **provider_model_kwargs
            )
        else:
            # Fine-tuning is currently only available for the following base models:
            # davinci, curie, babbage, and ada.
            # This means that except for the fixed `completion` model,
            # all other fine-tuned models are `completion` models.
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }

            client = EnhanceChatOpenAI(
                model_name=self.name,
                temperature=provider_model_kwargs.get('temperature'),
                max_tokens=provider_model_kwargs.get('max_tokens'),
                model_kwargs=extra_model_kwargs,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                **self.credentials
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        if self.name == 'gpt-4' \
                and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
                and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
            raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 is currently not supported.")

        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, str):
            return self._client.get_num_tokens(prompts)
        else:
            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        model_unit_prices = {
            'gpt-4': {
                'prompt': decimal.Decimal('0.03'),
                'completion': decimal.Decimal('0.06'),
            },
            'gpt-4-32k': {
                'prompt': decimal.Decimal('0.06'),
                'completion': decimal.Decimal('0.12')
            },
            'gpt-3.5-turbo': {
                'prompt': decimal.Decimal('0.0015'),
                'completion': decimal.Decimal('0.002')
            },
            'gpt-3.5-turbo-16k': {
                'prompt': decimal.Decimal('0.003'),
                'completion': decimal.Decimal('0.004')
            },
            'text-davinci-003': {
                'prompt': decimal.Decimal('0.02'),
                'completion': decimal.Decimal('0.02')
            },
        }

        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = model_unit_prices[self.name]['prompt']
        else:
            unit_price = model_unit_prices[self.name]['completion']

        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1k * unit_price
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
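Prices are per 1,000 tokens, so 1,500 prompt tokens on gpt-3.5-turbo cost 1.5 x 0.0015 = 0.00225 USD. A minimal sketch of the same arithmetic:

import decimal

tokens = decimal.Decimal(1500)
unit_price = decimal.Decimal('0.0015')  # gpt-3.5-turbo prompt price per 1K tokens

tokens_per_1k = (tokens / 1000).quantize(decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)
total = (tokens_per_1k * unit_price).quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(total)  # 0.0022500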

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        if self.name in COMPLETION_MODELS:
            for k, v in provider_model_kwargs.items():
                if hasattr(self.client, k):
                    setattr(self.client, k, v)
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }

            self.client.temperature = provider_model_kwargs.get('temperature')
            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
            self.client.model_kwargs = extra_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            raise LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True

    # def is_model_valid_or_raise(self):
    #     """
    #     check is a valid model.
    #
    #     :return:
    #     """
    #     credentials = self._model_provider.get_credentials()
    #
    #     try:
    #         result = openai.Model.retrieve(
    #             id=self.name,
    #             api_key=credentials.get('openai_api_key'),
    #             request_timeout=60
    #         )
    #
    #         if 'id' not in result or result['id'] != self.name:
    #             raise LLMNotExistsError(f"OpenAI Model {self.name} not exists.")
    #     except openai.error.OpenAIError as e:
    #         raise LLMNotExistsError(f"OpenAI Model {self.name} not exists, cause: {e.__class__.__name__}:{str(e)}")
    #     except Exception as e:
    #         logging.exception("OpenAI Model retrieve failed.")
    #         raise e
103
api/core/model_providers/models/llm/replicate_model.py
Normal file
@ -0,0 +1,103 @@
import decimal
from functools import wraps
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string
from replicate.exceptions import ReplicateError, ModelError

from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.error import LLMBadRequestError
from core.third_party.langchain.llms.replicate_llm import EnhanceReplicate
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class ReplicateModel(BaseLLM):
    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        self.model_mode = ModelMode.CHAT if name.endswith('-chat') else ModelMode.COMPLETION

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

        return EnhanceReplicate(
            model=self.name + ':' + self.credentials.get('model_version'),
            input=provider_model_kwargs,
            streaming=self.streaming,
            replicate_api_token=self.credentials.get('replicate_api_token'),
            callbacks=self.callbacks,
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        extra_kwargs = {}
        if isinstance(prompts, list):
            system_messages = [message for message in messages if message.type == 'system']
            if system_messages:
                system_message = system_messages[0]
                extra_kwargs['system_prompt'] = system_message.content
                prompts = [message for message in messages if message.type != 'system']

            prompts = get_buffer_string(prompts)

        # The maximum length the generated tokens can have.
        # Corresponds to the length of the input prompt + max_new_tokens.
        if 'max_length' in self._client.input:
            self._client.input['max_length'] = min(
                self._client.input['max_length'] + self.get_num_tokens(messages),
                self.model_rules.max_tokens.max
            )

        return self._client.generate([prompts], stop, callbacks, **extra_kwargs)
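Because max_length counts prompt plus generated tokens, the prompt length is added back before capping at the model limit. A small sketch of the adjustment with illustrative numbers:

max_new_tokens = 256   # the user's generation budget, stored in input['max_length']
prompt_tokens = 120    # as measured by get_num_tokens(messages)
model_max = 4096       # model_rules.max_tokens.max (illustrative)

max_length = min(max_new_tokens + prompt_tokens, model_max)
print(max_length)  # 376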

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, list):
            prompts = get_buffer_string(prompts)

        return self._client.get_num_tokens(prompts)

    def get_token_price(self, tokens: int, message_type: MessageType):
        # Replicate only charges for prediction seconds
        return decimal.Decimal('0')

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        self.client.input = provider_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, (ModelError, ReplicateError)):
            return LLMBadRequestError(f"Replicate: {str(ex)}")
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
73
api/core/model_providers/models/llm/spark_model.py
Normal file
@ -0,0 +1,73 @@
import decimal
from functools import wraps
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.spark import ChatSpark
from core.third_party.spark.spark_llm import SparkError


class SparkModel(BaseLLM):
    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ChatSpark(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        contents = [message.content for message in messages]
        return max(self._client.get_num_tokens("".join(contents)), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, SparkError):
            return LLMBadRequestError(f"Spark: {str(ex)}")
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
77
api/core/model_providers/models/llm/tongyi_model.py
Normal file
@ -0,0 +1,77 @@
import decimal
from functools import wraps
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from requests import HTTPError

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi


class TongyiModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        del provider_model_kwargs['max_tokens']
        return EnhanceTongyi(
            model_name=self.name,
            max_retries=1,
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        del provider_model_kwargs['max_tokens']
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, (ValueError, HTTPError)):
            return LLMBadRequestError(f"Tongyi: {str(ex)}")
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
92
api/core/model_providers/models/llm/wenxin_model.py
Normal file
@ -0,0 +1,92 @@
import decimal
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.wenxin import Wenxin


class WenxinModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return Wenxin(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        model_unit_prices = {
            'ernie-bot': {
                'prompt': decimal.Decimal('0.012'),
                'completion': decimal.Decimal('0.012'),
            },
            'ernie-bot-turbo': {
                'prompt': decimal.Decimal('0.008'),
                'completion': decimal.Decimal('0.008')
            },
            'bloomz-7b': {
                'prompt': decimal.Decimal('0.006'),
                'completion': decimal.Decimal('0.006')
            }
        }

        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = model_unit_prices[self.name]['prompt']
        else:
            unit_price = model_unit_prices[self.name]['completion']

        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1k * unit_price
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'RMB'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Wenxin: {str(ex)}")

    @classmethod
    def support_streaming(cls):
        return False
@ -0,0 +1,48 @@
import logging

import openai

from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider

DEFAULT_AUDIO_MODEL = 'whisper-1'


class OpenAIModeration(BaseProviderModel):
    type: ModelType = ModelType.MODERATION

    def __init__(self, model_provider: BaseModelProvider, name: str):
        super().__init__(model_provider, openai.Moderation)

    def run(self, text):
        credentials = self.model_provider.get_model_credentials(
            model_name=DEFAULT_AUDIO_MODEL,
            model_type=self.type
        )

        try:
            return self._client.create(input=text, api_key=credentials['openai_api_key'])
        except Exception as ex:
            raise self.handle_exceptions(ex)
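A hedged usage sketch; the provider object is assumed to be an initialized OpenAI provider, and the name argument (which the constructor currently ignores) is illustrative:

moderation = OpenAIModeration(model_provider=provider, name='moderation')
result = moderation.run("Some user-generated text to screen.")
print(result['results'][0]['flagged'])  # True if the endpoint flags the text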

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            raise LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex
29
api/core/model_providers/models/speech2text/base.py
Normal file
@ -0,0 +1,29 @@
from abc import abstractmethod
from typing import Any

from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider


class BaseSpeech2Text(BaseProviderModel):
    name: str
    type: ModelType = ModelType.SPEECH_TO_TEXT

    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
        super().__init__(model_provider, client)
        self.name = name

    def run(self, file):
        try:
            return self._run(file)
        except Exception as ex:
            raise self.handle_exceptions(ex)

    @abstractmethod
    def _run(self, file):
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        raise NotImplementedError
@ -0,0 +1,47 @@
import logging

import openai

from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from core.model_providers.providers.base import BaseModelProvider


class OpenAIWhisper(BaseSpeech2Text):

    def __init__(self, model_provider: BaseModelProvider, name: str):
        super().__init__(model_provider, openai.Audio, name)

    def _run(self, file):
        credentials = self.model_provider.get_model_credentials(
            model_name=self.name,
            model_type=self.type
        )

        return self._client.transcribe(
            model=self.name,
            file=file,
            api_key=credentials.get('openai_api_key'),
            api_base=credentials.get('openai_api_base'),
            organization=credentials.get('openai_organization'),
        )

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            raise LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex
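A hedged usage sketch; the provider object and the audio file path are illustrative:

speech2text = OpenAIWhisper(model_provider=provider, name='whisper-1')
with open('recording.mp3', 'rb') as audio_file:   # placeholder path
    transcription = speech2text.run(audio_file)
print(transcription['text'])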
0
api/core/model_providers/providers/__init__.py
Normal file
224
api/core/model_providers/providers/anthropic_provider.py
Normal file
@ -0,0 +1,224 @@
import json
import logging
from json import JSONDecodeError
from typing import Type, Optional

import anthropic
from flask import current_app
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
from models.provider import ProviderType


class AnthropicProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'anthropic'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'claude-instant-1',
                    'name': 'claude-instant-1',
                },
                {
                    'id': 'claude-2',
                    'name': 'claude-2',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = AnthropicModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=1),
            top_p=KwargRule[float](min=0, max=1, default=0.7),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
        )
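Anthropic does not accept the OpenAI-style penalty parameters, so those rules are disabled, and max_tokens is renamed through the alias. A small sketch of what _to_model_kwargs_input would produce under these rules (input values illustrative):

# Given ModelKwargs(temperature=0.5, top_p=None, max_tokens=300):
provider_kwargs = {
    'temperature': 0.5,            # passed through, within [0, 1]
    'top_p': 0.7,                  # None replaced by the rule's default
    'max_tokens_to_sample': 300,   # key renamed via the alias
    # presence_penalty / frequency_penalty dropped: their rules are disabled
}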

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'anthropic_api_key' not in credentials:
            raise CredentialsValidateFailedError('Anthropic API Key must be provided.')

        try:
            credential_kwargs = {
                'anthropic_api_key': credentials['anthropic_api_key']
            }

            if 'anthropic_api_url' in credentials:
                credential_kwargs['anthropic_api_url'] = credentials['anthropic_api_url']

            chat_llm = ChatAnthropic(
                model='claude-instant-1',
                max_tokens_to_sample=10,
                temperature=0,
                default_request_timeout=60,
                **credential_kwargs
            )

            messages = [
                HumanMessage(
                    content="ping"
                )
            ]

            chat_llm(messages)
        except anthropic.APIConnectionError as ex:
            raise CredentialsValidateFailedError(str(ex))
        except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
            raise CredentialsValidateFailedError(str(ex))
        except Exception as ex:
            logging.exception('Anthropic config validation failed')
            raise ex

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['anthropic_api_key'] = encrypter.encrypt_token(tenant_id, credentials['anthropic_api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'anthropic_api_url': None,
                    'anthropic_api_key': None
                }

            if credentials['anthropic_api_key']:
                credentials['anthropic_api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['anthropic_api_key']
                )

                if obfuscated:
                    credentials['anthropic_api_key'] = encrypter.obfuscated_token(credentials['anthropic_api_key'])

            if 'anthropic_api_url' not in credentials:
                credentials['anthropic_api_url'] = None

            return credentials
        else:
            if hosted_model_providers.anthropic:
                return {
                    'anthropic_api_url': hosted_model_providers.anthropic.api_base,
                    'anthropic_api_key': hosted_model_providers.anthropic.api_key,
                }
            else:
                return {
                    'anthropic_api_url': None,
                    'anthropic_api_key': None
                }

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        if current_app.config['EDITION'] != 'CLOUD':
            return False

        if hosted_model_providers.anthropic:
            return True

        return False

    def should_deduct_quota(self):
        if hosted_model_providers.anthropic and \
                hosted_model_providers.anthropic.quota_limit and hosted_model_providers.anthropic.quota_limit > 0:
            return True

        return False

    def get_payment_info(self) -> Optional[dict]:
        """
        get product info if it is payable.

        :return:
        """
        if hosted_model_providers.anthropic \
                and hosted_model_providers.anthropic.paid_enabled:
            return {
                'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
                'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
            }

        return None

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
387
api/core/model_providers/providers/azure_openai_provider.py
Normal file
@ -0,0 +1,387 @@
import json
import logging
from json import JSONDecodeError
from typing import Type

import openai
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
    AZURE_OPENAI_API_VERSION
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
from extensions.ext_database import db
from models.provider import ProviderType, ProviderModel, ProviderQuotaType

BASE_MODELS = [
    'gpt-4',
    'gpt-4-32k',
    'gpt-35-turbo',
    'gpt-35-turbo-16k',
    'text-davinci-003',
    'text-embedding-ada-002',
]


class AzureOpenAIProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'azure_openai'

    def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
        # convert old provider config to provider models
        self._convert_provider_config_to_model_config()

        if self.provider.provider_type == ProviderType.CUSTOM.value:
            # get configurable provider models
            provider_models = db.session.query(ProviderModel).filter(
                ProviderModel.tenant_id == self.provider.tenant_id,
                ProviderModel.provider_name == self.provider.provider_name,
                ProviderModel.model_type == model_type.value,
                ProviderModel.is_valid == True
            ).order_by(ProviderModel.created_at.asc()).all()

            model_list = []
            for provider_model in provider_models:
                model_dict = {
                    'id': provider_model.model_name,
                    'name': provider_model.model_name
                }

                credentials = json.loads(provider_model.encrypted_config)
                if credentials['base_model_name'] in [
                    'gpt-4',
                    'gpt-4-32k',
                    'gpt-35-turbo',
                    'gpt-35-turbo-16k',
                ]:
                    model_dict['features'] = [
                        ModelFeature.AGENT_THOUGHT.value
                    ]

                model_list.append(model_dict)
        else:
            model_list = self._get_fixed_model_list(model_type)

        return model_list

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            models = [
                {
                    'id': 'gpt-3.5-turbo',
                    'name': 'gpt-3.5-turbo',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'gpt-3.5-turbo-16k',
                    'name': 'gpt-3.5-turbo-16k',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'gpt-4',
                    'name': 'gpt-4',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'gpt-4-32k',
                    'name': 'gpt-4-32k',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'text-davinci-003',
                    'name': 'text-davinci-003',
                }
            ]

            if self.provider.provider_type == ProviderType.SYSTEM.value \
                    and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
                models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']]

            return models
        elif model_type == ModelType.EMBEDDINGS:
            return [
                {
                    'id': 'text-embedding-ada-002',
                    'name': 'text-embedding-ada-002'
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = AzureOpenAIModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = AzureOpenAIEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        base_model_max_tokens = {
            'gpt-4': 8192,
            'gpt-4-32k': 32768,
            'gpt-35-turbo': 4096,
            'gpt-35-turbo-16k': 16384,
            'text-davinci-003': 4097,
        }

        model_credentials = self.get_model_credentials(model_name, model_type)

        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1),
            top_p=KwargRule[float](min=0, max=1, default=1),
            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
            max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
                model_credentials['base_model_name'],
                4097
            ), default=16),
        )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        if 'openai_api_key' not in credentials:
            raise CredentialsValidateFailedError('Azure OpenAI API key is required')

        if 'openai_api_base' not in credentials:
            raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')

        if 'base_model_name' not in credentials:
            raise CredentialsValidateFailedError('Base Model Name is required')

        if credentials['base_model_name'] not in BASE_MODELS:
            raise CredentialsValidateFailedError('Base Model Name is invalid')

        if model_type == ModelType.TEXT_GENERATION:
            try:
                client = EnhanceAzureChatOpenAI(
                    deployment_name=model_name,
                    temperature=0,
                    max_tokens=15,
                    request_timeout=10,
                    openai_api_type='azure',
                    openai_api_version='2023-07-01-preview',
                    openai_api_key=credentials['openai_api_key'],
                    openai_api_base=credentials['openai_api_base'],
                )

                client.generate([[HumanMessage(content='hi!')]])
            except openai.error.OpenAIError as e:
                raise CredentialsValidateFailedError(
                    f"Azure OpenAI deployment {model_name} not exists, cause: {e.__class__.__name__}:{str(e)}")
            except Exception as e:
                logging.exception("Azure OpenAI Model retrieve failed.")
                raise e
        elif model_type == ModelType.EMBEDDINGS:
            try:
                client = OpenAIEmbeddings(
                    openai_api_type='azure',
                    openai_api_version=AZURE_OPENAI_API_VERSION,
                    deployment=model_name,
                    chunk_size=16,
                    max_retries=1,
                    openai_api_key=credentials['openai_api_key'],
                    openai_api_base=credentials['openai_api_base']
                )

                client.embed_query('hi')
            except openai.error.OpenAIError as e:
                logging.exception("Azure OpenAI Model check error.")
                raise CredentialsValidateFailedError(
                    f"Azure OpenAI deployment {model_name} not exists, cause: {e.__class__.__name__}:{str(e)}")
            except Exception as e:
                logging.exception("Azure OpenAI Model retrieve failed.")
                raise e

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        credentials['openai_api_key'] = encrypter.encrypt_token(tenant_id, credentials['openai_api_key'])
        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            # convert old provider config to provider models
            self._convert_provider_config_to_model_config()

            provider_model = self._get_provider_model(model_name, model_type)

            if not provider_model.encrypted_config:
                return {
                    'openai_api_base': '',
                    'openai_api_key': '',
                    'base_model_name': ''
                }

            credentials = json.loads(provider_model.encrypted_config)
            if credentials['openai_api_key']:
                credentials['openai_api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['openai_api_key']
                )

                if obfuscated:
                    credentials['openai_api_key'] = encrypter.obfuscated_token(credentials['openai_api_key'])

            return credentials
        else:
            if hosted_model_providers.azure_openai:
                return {
                    'openai_api_base': hosted_model_providers.azure_openai.api_base,
                    'openai_api_key': hosted_model_providers.azure_openai.api_key,
                    'base_model_name': model_name
                }
            else:
                return {
                    'openai_api_base': None,
                    'openai_api_key': None,
                    'base_model_name': None
                }

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        if current_app.config['EDITION'] != 'CLOUD':
            return False

        if hosted_model_providers.azure_openai:
            return True

        return False

    def should_deduct_quota(self):
        if hosted_model_providers.azure_openai \
                and hosted_model_providers.azure_openai.quota_limit and hosted_model_providers.azure_openai.quota_limit > 0:
            return True

        return False

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        return {}

    def _convert_provider_config_to_model_config(self):
        if self.provider.provider_type == ProviderType.CUSTOM.value \
                and self.provider.is_valid \
                and self.provider.encrypted_config:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'openai_api_base': '',
                    'openai_api_key': '',
                    'base_model_name': ''
                }

            self._add_provider_model(
                model_name='gpt-35-turbo',
                model_type=ModelType.TEXT_GENERATION,
                provider_credentials=credentials
            )

            self._add_provider_model(
                model_name='gpt-35-turbo-16k',
                model_type=ModelType.TEXT_GENERATION,
                provider_credentials=credentials
            )

            self._add_provider_model(
                model_name='gpt-4',
                model_type=ModelType.TEXT_GENERATION,
                provider_credentials=credentials
            )

            self._add_provider_model(
                model_name='text-davinci-003',
                model_type=ModelType.TEXT_GENERATION,
                provider_credentials=credentials
            )

            self._add_provider_model(
                model_name='text-embedding-ada-002',
                model_type=ModelType.EMBEDDINGS,
                provider_credentials=credentials
            )

            self.provider.encrypted_config = None
            db.session.commit()

    def _add_provider_model(self, model_name: str, model_type: ModelType, provider_credentials: dict):
        credentials = provider_credentials.copy()
        credentials['base_model_name'] = model_name
        provider_model = ProviderModel(
            tenant_id=self.provider.tenant_id,
            provider_name=self.provider.provider_name,
            model_name=model_name,
            model_type=model_type.value,
            encrypted_config=json.dumps(credentials),
            is_valid=True
        )
        db.session.add(provider_model)
        db.session.commit()
283
api/core/model_providers/providers/base.py
Normal file
@ -0,0 +1,283 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Type, Optional

from flask import current_app
from pydantic import BaseModel

from core.model_providers.error import QuotaExceededError, LLMBadRequestError
from extensions.ext_database import db
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.models.entity.provider import ProviderQuotaUnit
from core.model_providers.rules import provider_rules
from models.provider import Provider, ProviderType, ProviderModel


class BaseModelProvider(BaseModel, ABC):

    provider: Provider

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        raise NotImplementedError

    def get_rules(self):
        """
        Returns the rules of a provider.
        """
        return provider_rules[self.provider_name]

    def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
        """
        get supported model object list for use.

        :param model_type:
        :return:
        """
        rules = self.get_rules()
        if 'custom' not in rules['support_provider_types']:
            return self._get_fixed_model_list(model_type)

        if 'model_flexibility' not in rules:
            return self._get_fixed_model_list(model_type)

        if rules['model_flexibility'] == 'fixed':
            return self._get_fixed_model_list(model_type)

        # get configurable provider models
        provider_models = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True
        ).order_by(ProviderModel.created_at.asc()).all()

        return [{
            'id': provider_model.model_name,
            'name': provider_model.model_name
        } for provider_model in provider_models]

    @abstractmethod
    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """
        get supported model object list for use.

        :param model_type:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_class(self, model_type: ModelType) -> Type:
        """
        get specific model class.

        :param model_type:
        :return:
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        check provider credentials valid.

        :param credentials:
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """
        encrypt provider credentials for save.

        :param tenant_id:
        :param credentials:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param obfuscated:
        :return:
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        raise NotImplementedError

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        return current_app.config['EDITION'] == 'CLOUD'

    def check_quota_over_limit(self):
        """
        check provider quota over limit.

        :return:
        """
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return

        rules = self.get_rules()
        if 'system' not in rules['support_provider_types']:
            return

        provider = db.session.query(Provider).filter(
            db.and_(
                Provider.id == self.provider.id,
                Provider.is_valid == True,
                Provider.quota_limit > Provider.quota_used
            )
        ).first()

        if not provider:
            raise QuotaExceededError()

    def deduct_quota(self, used_tokens: int = 0) -> None:
        """
        deduct available quota when provider type is system or paid.

        :return:
        """
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return

        rules = self.get_rules()
        if 'system' not in rules['support_provider_types']:
            return

        if not self.should_deduct_quota():
            return

        if 'system_config' not in rules:
            quota_unit = ProviderQuotaUnit.TIMES.value
        elif 'quota_unit' not in rules['system_config']:
            quota_unit = ProviderQuotaUnit.TIMES.value
        else:
            quota_unit = rules['system_config']['quota_unit']

        if quota_unit == ProviderQuotaUnit.TOKENS.value:
            used_quota = used_tokens
        else:
            used_quota = 1

        db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name,
            Provider.provider_type == self.provider.provider_type,
            Provider.quota_type == self.provider.quota_type,
            Provider.quota_limit > Provider.quota_used
        ).update({'quota_used': Provider.quota_used + used_quota})
        db.session.commit()

    def should_deduct_quota(self):
        return False

    def update_last_used(self) -> None:
        """
        update last used time.

        :return:
        """
        db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name
        ).update({'last_used': datetime.utcnow()})
        db.session.commit()

    def get_payment_info(self) -> Optional[dict]:
        """
        get product info if it is payable.

        :return:
        """
        return None

    def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
        """
        get provider model.

        :param model_name:
        :param model_type:
        :return:
        """
        provider_model = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_name == model_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True
        ).first()

        if not provider_model:
            raise LLMBadRequestError(f"The model {model_name} does not exist. "
                                     f"Please check the configuration.")

        return provider_model


class CredentialsValidateFailedError(Exception):
    pass
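For orientation, a minimal caller-side sketch of how the quota helpers above are meant to be combined around a model invocation. `provider_row` and `call_model()` are hypothetical stand-ins; `OpenAIProvider` is the concrete subclass added later in this commit:

# Hypothetical caller-side sketch; `provider_row` and `call_model()` are
# illustrative stand-ins, not part of this commit's API.
provider = OpenAIProvider(provider=provider_row)

provider.check_quota_over_limit()   # raises QuotaExceededError when exhausted
result, used_tokens = call_model()  # run the actual LLM request elsewhere
provider.deduct_quota(used_tokens)  # no-op for non-system providers or when
                                    # should_deduct_quota() is False
provider.update_last_used()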
157
api/core/model_providers/providers/chatglm_provider.py
Normal file
@ -0,0 +1,157 @@
import json
from json import JSONDecodeError
from typing import Type

from langchain.llms import ChatGLM

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType


class ChatGLMProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'chatglm'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'chatglm2-6b',
                    'name': 'ChatGLM2-6B',
                },
                {
                    'id': 'chatglm-6b',
                    'name': 'ChatGLM-6B',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = ChatGLMModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        model_max_tokens = {
            'chatglm-6b': 2000,
            'chatglm2-6b': 32000,
        }

        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1),
            top_p=KwargRule[float](min=0, max=1, default=0.7),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'api_base' not in credentials:
            raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')

        try:
            credential_kwargs = {
                'endpoint_url': credentials['api_base']
            }

            llm = ChatGLM(
                max_token=10,
                **credential_kwargs
            )

            llm("ping")
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['api_base'] = encrypter.encrypt_token(tenant_id, credentials['api_base'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'api_base': None
                }

            if credentials['api_base']:
                credentials['api_base'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_base']
                )

                if obfuscated:
                    credentials['api_base'] = encrypter.obfuscated_token(credentials['api_base'])

            return credentials

        return {}

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
76
api/core/model_providers/providers/hosted.py
Normal file
@ -0,0 +1,76 @@
import os
from typing import Optional

import langchain
from flask import Flask
from pydantic import BaseModel


class HostedOpenAI(BaseModel):
    api_base: str = None
    api_organization: str = None
    api_key: str
    quota_limit: int = 0
    """Quota limit for the openai hosted model. 0 means unlimited."""
    paid_enabled: bool = False
    paid_stripe_price_id: str = None
    paid_increase_quota: int = 1


class HostedAzureOpenAI(BaseModel):
    api_base: str
    api_key: str
    quota_limit: int = 0
    """Quota limit for the azure openai hosted model. 0 means unlimited."""


class HostedAnthropic(BaseModel):
    api_base: str = None
    api_key: str
    quota_limit: int = 0
    """Quota limit for the anthropic hosted model. 0 means unlimited."""
    paid_enabled: bool = False
    paid_stripe_price_id: str = None
    paid_increase_quota: int = 1


class HostedModelProviders(BaseModel):
    openai: Optional[HostedOpenAI] = None
    azure_openai: Optional[HostedAzureOpenAI] = None
    anthropic: Optional[HostedAnthropic] = None


hosted_model_providers = HostedModelProviders()


def init_app(app: Flask):
    if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
        langchain.verbose = True

    if app.config.get("HOSTED_OPENAI_ENABLED"):
        hosted_model_providers.openai = HostedOpenAI(
            api_base=app.config.get("HOSTED_OPENAI_API_BASE"),
            api_organization=app.config.get("HOSTED_OPENAI_API_ORGANIZATION"),
            api_key=app.config.get("HOSTED_OPENAI_API_KEY"),
            quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"),
            paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"),
            paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
            paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"),
        )

    if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"):
        hosted_model_providers.azure_openai = HostedAzureOpenAI(
            api_base=app.config.get("HOSTED_AZURE_OPENAI_API_BASE"),
            api_key=app.config.get("HOSTED_AZURE_OPENAI_API_KEY"),
            quota_limit=app.config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT"),
        )

    if app.config.get("HOSTED_ANTHROPIC_ENABLED"):
        hosted_model_providers.anthropic = HostedAnthropic(
            api_base=app.config.get("HOSTED_ANTHROPIC_API_BASE"),
            api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"),
            quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"),
            paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
            paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
            paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
        )
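A minimal sketch of the Flask config a deployment might set so that init_app above populates the hosted OpenAI provider. The config keys are exactly the ones init_app reads; the values are illustrative placeholders:

# Illustrative values only; keys match those read by init_app above.
class Config:
    EDITION = 'CLOUD'                      # system/hosted providers are cloud-edition only
    HOSTED_OPENAI_ENABLED = True
    HOSTED_OPENAI_API_BASE = 'https://api.openai.com'
    HOSTED_OPENAI_API_ORGANIZATION = None
    HOSTED_OPENAI_API_KEY = 'sk-...'       # placeholder, not a real key
    HOSTED_OPENAI_QUOTA_LIMIT = 200        # 0 would mean unlimited
    HOSTED_OPENAI_PAID_ENABLED = False
    HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = None
    HOSTED_OPENAI_PAID_INCREASE_QUOTA = 1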
183
api/core/model_providers/providers/huggingface_hub_provider.py
Normal file
@ -0,0 +1,183 @@
import json
from typing import Type

from huggingface_hub import HfApi
from langchain.llms import HuggingFaceEndpoint

from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

from core.model_providers.models.base import BaseProviderModel
from models.provider import ProviderType


class HuggingfaceHubProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'huggingface_hub'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = HuggingfaceHubModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1),
            top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=1500, default=200),
        )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        if model_type != ModelType.TEXT_GENERATION:
            raise NotImplementedError

        if 'huggingfacehub_api_type' not in credentials \
                or credentials['huggingfacehub_api_type'] not in ['hosted_inference_api', 'inference_endpoints']:
            raise CredentialsValidateFailedError('Hugging Face Hub API Type invalid, '
                                                 'must be hosted_inference_api or inference_endpoints.')

        if 'huggingfacehub_api_token' not in credentials:
            raise CredentialsValidateFailedError('Hugging Face Hub API Token must be provided.')

        hfapi = HfApi(token=credentials['huggingfacehub_api_token'])

        try:
            hfapi.whoami()
        except Exception:
            raise CredentialsValidateFailedError("Invalid API Token.")

        if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
            if 'huggingfacehub_endpoint_url' not in credentials:
                raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')

            try:
                llm = HuggingFaceEndpoint(
                    endpoint_url=credentials['huggingfacehub_endpoint_url'],
                    task="text2text-generation",
                    model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
                    huggingfacehub_api_token=credentials['huggingfacehub_api_token']
                )

                llm("ping")
            except Exception as e:
                raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
        else:
            try:
                model_info = hfapi.model_info(repo_id=model_name)
                if not model_info:
                    raise ValueError(f'Model {model_name} not found.')

                if 'inference' in model_info.cardData and not model_info.cardData['inference']:
                    raise ValueError(f'Inference API has been turned off for this model {model_name}.')

                VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
                if model_info.pipeline_tag not in VALID_TASKS:
                    raise ValueError(f"Model {model_name} is not a valid task, "
                                     f"must be one of {VALID_TASKS}.")
            except Exception as e:
                raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        credentials['huggingfacehub_api_token'] = encrypter.encrypt_token(tenant_id, credentials['huggingfacehub_api_token'])

        if credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
            hfapi = HfApi(token=credentials['huggingfacehub_api_token'])
            model_info = hfapi.model_info(repo_id=model_name)
            if not model_info:
                raise ValueError(f'Model {model_name} not found.')

            if 'inference' in model_info.cardData and not model_info.cardData['inference']:
                raise ValueError(f'Inference API has been turned off for this model {model_name}.')

            credentials['task_type'] = model_info.pipeline_tag

        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        if not provider_model.encrypted_config:
            return {
                'huggingfacehub_api_token': None,
                'task_type': None
            }

        credentials = json.loads(provider_model.encrypted_config)
        if credentials['huggingfacehub_api_token']:
            credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['huggingfacehub_api_token']
            )

            if obfuscated:
                credentials['huggingfacehub_api_token'] = encrypter.obfuscated_token(credentials['huggingfacehub_api_token'])

        return credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        return {}
179
api/core/model_providers/providers/minimax_provider.py
Normal file
@ -0,0 +1,179 @@
import json
from json import JSONDecodeError
from typing import Type

from langchain.llms import Minimax

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType, ProviderQuotaType


class MinimaxProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'minimax'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'abab5.5-chat',
                    'name': 'abab5.5-chat',
                },
                {
                    'id': 'abab5-chat',
                    'name': 'abab5-chat',
                }
            ]
        elif model_type == ModelType.EMBEDDINGS:
            return [
                {
                    'id': 'embo-01',
                    'name': 'embo-01',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = MinimaxModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = MinimaxEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        model_max_tokens = {
            'abab5.5-chat': 16384,
            'abab5-chat': 6144,
        }

        return ModelKwargsRules(
            temperature=KwargRule[float](min=0.01, max=1, default=0.9),
            top_p=KwargRule[float](min=0, max=1, default=0.95),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'minimax_group_id' not in credentials:
            raise CredentialsValidateFailedError('MiniMax Group ID must be provided.')

        if 'minimax_api_key' not in credentials:
            raise CredentialsValidateFailedError('MiniMax API Key must be provided.')

        try:
            credential_kwargs = {
                'minimax_group_id': credentials['minimax_group_id'],
                'minimax_api_key': credentials['minimax_api_key'],
            }

            llm = Minimax(
                model='abab5.5-chat',
                max_tokens=10,
                temperature=0.01,
                **credential_kwargs
            )

            llm("ping")
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['minimax_api_key'] = encrypter.encrypt_token(tenant_id, credentials['minimax_api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value \
                or (self.provider.provider_type == ProviderType.SYSTEM.value
                    and self.provider.quota_type == ProviderQuotaType.FREE.value):
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'minimax_group_id': None,
                    'minimax_api_key': None,
                }

            if credentials['minimax_api_key']:
                credentials['minimax_api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['minimax_api_key']
                )

                if obfuscated:
                    credentials['minimax_api_key'] = encrypter.obfuscated_token(credentials['minimax_api_key'])

            return credentials

        return {}

    def should_deduct_quota(self):
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
289
api/core/model_providers/providers/openai_provider.py
Normal file
@ -0,0 +1,289 @@
import json
import logging
from json import JSONDecodeError
from typing import Type, Optional

from flask import current_app
from openai.error import AuthenticationError, OpenAIError

import openai

from core.helper import encrypter
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
from models.provider import ProviderType, ProviderQuotaType


class OpenAIProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'openai'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            models = [
                {
                    'id': 'gpt-3.5-turbo',
                    'name': 'gpt-3.5-turbo',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'gpt-3.5-turbo-16k',
                    'name': 'gpt-3.5-turbo-16k',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'gpt-4',
                    'name': 'gpt-4',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'gpt-4-32k',
                    'name': 'gpt-4-32k',
                    'features': [
                        ModelFeature.AGENT_THOUGHT.value
                    ]
                },
                {
                    'id': 'text-davinci-003',
                    'name': 'text-davinci-003',
                }
            ]

            if self.provider.provider_type == ProviderType.SYSTEM.value \
                    and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
                models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']]

            return models
        elif model_type == ModelType.EMBEDDINGS:
            return [
                {
                    'id': 'text-embedding-ada-002',
                    'name': 'text-embedding-ada-002'
                }
            ]
        elif model_type == ModelType.SPEECH_TO_TEXT:
            return [
                {
                    'id': 'whisper-1',
                    'name': 'whisper-1'
                }
            ]
        elif model_type == ModelType.MODERATION:
            return [
                {
                    'id': 'text-moderation-stable',
                    'name': 'text-moderation-stable'
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = OpenAIModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = OpenAIEmbedding
        elif model_type == ModelType.MODERATION:
            model_class = OpenAIModeration
        elif model_type == ModelType.SPEECH_TO_TEXT:
            model_class = OpenAIWhisper
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        model_max_tokens = {
            'gpt-4': 8192,
            'gpt-4-32k': 32768,
            'gpt-3.5-turbo': 4096,
            'gpt-3.5-turbo-16k': 16384,
            'text-davinci-003': 4097,
        }

        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1),
            top_p=KwargRule[float](min=0, max=1, default=1),
            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'openai_api_key' not in credentials:
            raise CredentialsValidateFailedError('OpenAI API key is required')

        try:
            credentials_kwargs = {
                "api_key": credentials['openai_api_key']
            }

            if 'openai_api_base' in credentials and credentials['openai_api_base']:
                credentials_kwargs['api_base'] = credentials['openai_api_base'] + '/v1'

            if 'openai_organization' in credentials:
                credentials_kwargs['organization'] = credentials['openai_organization']

            openai.ChatCompletion.create(
                messages=[{"role": "user", "content": 'ping'}],
                model='gpt-3.5-turbo',
                timeout=10,
                request_timeout=(5, 30),
                max_tokens=20,
                **credentials_kwargs
            )
        except (AuthenticationError, OpenAIError) as ex:
            raise CredentialsValidateFailedError(str(ex))
        except Exception as ex:
            logging.exception('OpenAI config validation failed')
            raise ex

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['openai_api_key'] = encrypter.encrypt_token(tenant_id, credentials['openai_api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'openai_api_base': None,
                    'openai_api_key': self.provider.encrypted_config,
                    'openai_organization': None
                }

            if credentials['openai_api_key']:
                credentials['openai_api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['openai_api_key']
                )

                if obfuscated:
                    credentials['openai_api_key'] = encrypter.obfuscated_token(credentials['openai_api_key'])

            if 'openai_api_base' not in credentials or not credentials['openai_api_base']:
                credentials['openai_api_base'] = None
            else:
                credentials['openai_api_base'] = credentials['openai_api_base'] + '/v1'

            if 'openai_organization' not in credentials:
                credentials['openai_organization'] = None

            return credentials
        else:
            if hosted_model_providers.openai:
                return {
                    'openai_api_base': hosted_model_providers.openai.api_base,
                    'openai_api_key': hosted_model_providers.openai.api_key,
                    'openai_organization': hosted_model_providers.openai.api_organization
                }
            else:
                return {
                    'openai_api_base': None,
                    'openai_api_key': None,
                    'openai_organization': None
                }

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        if current_app.config['EDITION'] != 'CLOUD':
            return False

        if hosted_model_providers.openai:
            return True

        return False

    def should_deduct_quota(self):
        if hosted_model_providers.openai \
                and hosted_model_providers.openai.quota_limit and hosted_model_providers.openai.quota_limit > 0:
            return True

        return False

    def get_payment_info(self) -> Optional[dict]:
        """
        get payment info if it is payable.

        :return:
        """
        if hosted_model_providers.openai \
                and hosted_model_providers.openai.paid_enabled:
            return {
                'product_id': hosted_model_providers.openai.paid_stripe_price_id,
                'increase_quota': hosted_model_providers.openai.paid_increase_quota,
            }

        return None

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
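A short usage sketch of the provider above, assuming a models.provider.Provider row (`provider_row`, hypothetical here) has already been loaded for the tenant; the same pattern applies to the other providers in this commit:

# `provider_row` is a hypothetical Provider record for a tenant's custom OpenAI setup.
openai_provider = OpenAIProvider(provider=provider_row)

# Model list offered in the UI; gpt-4 variants are filtered out for trial quota.
chat_models = openai_provider.get_supported_model_list(ModelType.TEXT_GENERATION)

# Obfuscated credentials are safe to echo back to the frontend.
masked = openai_provider.get_provider_credentials(obfuscated=True)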
184
api/core/model_providers/providers/replicate_provider.py
Normal file
@ -0,0 +1,184 @@
import json
import logging
from typing import Type

import replicate
from replicate.exceptions import ReplicateError

from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding
from models.provider import ProviderType


class ReplicateProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'replicate'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = ReplicateModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = ReplicateEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        model_credentials = self.get_model_credentials(model_name, model_type)

        model = replicate.Client(api_token=model_credentials.get("replicate_api_token")).models.get(model_name)

        try:
            version = model.versions.get(model_credentials['model_version'])
        except ReplicateError as e:
            raise CredentialsValidateFailedError(f"Model {model_name}:{model_credentials['model_version']} does not exist, "
                                                 f"cause: {e.__class__.__name__}:{str(e)}")
        except Exception as e:
            logging.exception("Model validate failed.")
            raise e

        model_kwargs_rules = ModelKwargsRules()
        for key, value in version.openapi_schema['components']['schemas']['Input']['properties'].items():
            if key not in ['debug', 'prompt'] and value['type'] in ['number', 'integer']:
                if key in ['temperature', 'top_p']:
                    kwarg_rule = KwargRule[float](
                        type=KwargRuleType.FLOAT.value if value['type'] == 'number' else KwargRuleType.INTEGER.value,
                        min=float(value.get('minimum')) if value.get('minimum') is not None else None,
                        max=float(value.get('maximum')) if value.get('maximum') is not None else None,
                        default=float(value.get('default')) if value.get('default') is not None else None,
                    )
                    if key == 'temperature':
                        model_kwargs_rules.temperature = kwarg_rule
                    else:
                        model_kwargs_rules.top_p = kwarg_rule
                elif key in ['max_length', 'max_new_tokens']:
                    model_kwargs_rules.max_tokens = KwargRule[int](
                        alias=key,
                        type=KwargRuleType.INTEGER.value,
                        min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
                        max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
                        default=int(value.get('default')) if value.get('default') is not None else 500,
                    )

        return model_kwargs_rules

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        if 'replicate_api_token' not in credentials:
            raise CredentialsValidateFailedError('Replicate API Key must be provided.')

        if 'model_version' not in credentials:
            raise CredentialsValidateFailedError('Replicate Model Version must be provided.')

        if model_name.count("/") != 1:
            raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
                                                 'format: {user_name}/{model_name}')

        version = credentials['model_version']
        try:
            model = replicate.Client(api_token=credentials.get("replicate_api_token")).models.get(model_name)
            rst = model.versions.get(version)

            if model_type == ModelType.EMBEDDINGS \
                    and 'Embedding' not in rst.openapi_schema['components']['schemas']:
                raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not an Embedding model.")
            elif model_type == ModelType.TEXT_GENERATION \
                    and ('type' not in rst.openapi_schema['components']['schemas']['Output']['items']
                         or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
                raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
        except ReplicateError as e:
            raise CredentialsValidateFailedError(
                f"Model {model_name}:{version} does not exist, cause: {e.__class__.__name__}:{str(e)}")
        except Exception as e:
            logging.exception("Replicate config validation failed.")
            raise e

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        credentials['replicate_api_token'] = encrypter.encrypt_token(tenant_id, credentials['replicate_api_token'])
        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        if not provider_model.encrypted_config:
            return {
                'replicate_api_token': None,
            }

        credentials = json.loads(provider_model.encrypted_config)
        if credentials['replicate_api_token']:
            credentials['replicate_api_token'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['replicate_api_token']
            )

            if obfuscated:
                credentials['replicate_api_token'] = encrypter.obfuscated_token(credentials['replicate_api_token'])

        return credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        return {}
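To make the schema-driven rule extraction above concrete: given a hypothetical Replicate Input schema fragment like the one below, get_model_parameter_rules would produce the indicated rules (values are illustrative, not from a real model):

# Hypothetical Input schema fragment (illustrative values only):
input_properties = {
    'temperature': {'type': 'number', 'minimum': 0.01, 'maximum': 5, 'default': 0.75},
    'max_new_tokens': {'type': 'integer', 'minimum': 1, 'maximum': 2048, 'default': 128},
}
# get_model_parameter_rules maps these onto:
#   temperature -> KwargRule[float](type=FLOAT, min=0.01, max=5.0, default=0.75)
#   max_tokens  -> KwargRule[int](alias='max_new_tokens', type=INTEGER,
#                                 min=1, max=2048, default=128)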
191
api/core/model_providers/providers/spark_provider.py
Normal file
@ -0,0 +1,191 @@
import json
import logging
from json import JSONDecodeError
from typing import Type

from flask import current_app
from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.spark import ChatSpark
from core.third_party.spark.spark_llm import SparkError
from models.provider import ProviderType, ProviderQuotaType


class SparkProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'spark'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'spark',
                    'name': '星火认知大模型',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = SparkModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=0.5),
            top_p=KwargRule[float](enabled=False),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](min=10, max=4096, default=2048),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'app_id' not in credentials:
            raise CredentialsValidateFailedError('Spark app_id must be provided.')

        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('Spark api_key must be provided.')

        if 'api_secret' not in credentials:
            raise CredentialsValidateFailedError('Spark api_secret must be provided.')

        try:
            credential_kwargs = {
                'app_id': credentials['app_id'],
                'api_key': credentials['api_key'],
                'api_secret': credentials['api_secret'],
            }

            chat_llm = ChatSpark(
                max_tokens=10,
                temperature=0.01,
                **credential_kwargs
            )

            messages = [
                HumanMessage(
                    content="ping"
                )
            ]

            chat_llm(messages)
        except SparkError as ex:
            raise CredentialsValidateFailedError(str(ex))
        except Exception as ex:
            logging.exception('Spark config validation failed')
            raise ex

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        credentials['api_secret'] = encrypter.encrypt_token(tenant_id, credentials['api_secret'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value \
                or (self.provider.provider_type == ProviderType.SYSTEM.value
                    and self.provider.quota_type == ProviderQuotaType.FREE.value):
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'app_id': None,
                    'api_key': None,
                    'api_secret': None,
                }

            if credentials['api_key']:
                credentials['api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_key']
                )

                if obfuscated:
                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])

            if credentials['api_secret']:
                credentials['api_secret'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_secret']
                )

                if obfuscated:
                    credentials['api_secret'] = encrypter.obfuscated_token(credentials['api_secret'])

            return credentials
        else:
            return {
                'app_id': None,
                'api_key': None,
                'api_secret': None,
            }

    def should_deduct_quota(self):
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
157
api/core/model_providers/providers/tongyi_provider.py
Normal file
@ -0,0 +1,157 @@
import json
from json import JSONDecodeError
from typing import Type

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
from models.provider import ProviderType


class TongyiProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'tongyi'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'qwen-v1',
                    'name': 'qwen-v1',
                },
                {
                    'id': 'qwen-plus-v1',
                    'name': 'qwen-plus-v1',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = TongyiModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        model_max_tokens = {
            'qwen-v1': 1500,
            'qwen-plus-v1': 6500
        }

        return ModelKwargsRules(
            temperature=KwargRule[float](enabled=False),
            top_p=KwargRule[float](min=0, max=1, default=0.8),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'dashscope_api_key' not in credentials:
            raise CredentialsValidateFailedError('Dashscope API Key must be provided.')

        try:
            credential_kwargs = {
                'dashscope_api_key': credentials['dashscope_api_key']
            }

            llm = EnhanceTongyi(
                model_name='qwen-v1',
                max_retries=1,
                **credential_kwargs
            )

            llm("ping")
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['dashscope_api_key'] = encrypter.encrypt_token(tenant_id, credentials['dashscope_api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'dashscope_api_key': None
                }

            if credentials['dashscope_api_key']:
                credentials['dashscope_api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['dashscope_api_key']
                )

                if obfuscated:
                    credentials['dashscope_api_key'] = encrypter.obfuscated_token(credentials['dashscope_api_key'])

            return credentials

        return {}

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
182
api/core/model_providers/providers/wenxin_provider.py
Normal file
@ -0,0 +1,182 @@
import json
from json import JSONDecodeError
from typing import Type

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.wenxin import Wenxin
from models.provider import ProviderType


class WenxinProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'wenxin'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'ernie-bot',
                    'name': 'ERNIE-Bot',
                },
                {
                    'id': 'ernie-bot-turbo',
                    'name': 'ERNIE-Bot-turbo',
                },
                {
                    'id': 'bloomz-7b',
                    'name': 'BLOOMZ-7B',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = WenxinModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        Get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        if model_name in ['ernie-bot', 'ernie-bot-turbo']:
            return ModelKwargsRules(
                temperature=KwargRule[float](min=0.01, max=1, default=0.95),
                top_p=KwargRule[float](min=0.01, max=1, default=0.8),
                presence_penalty=KwargRule[float](enabled=False),
                frequency_penalty=KwargRule[float](enabled=False),
                max_tokens=KwargRule[int](enabled=False),
            )
        else:
            return ModelKwargsRules(
                temperature=KwargRule[float](enabled=False),
                top_p=KwargRule[float](enabled=False),
                presence_penalty=KwargRule[float](enabled=False),
                frequency_penalty=KwargRule[float](enabled=False),
                max_tokens=KwargRule[int](enabled=False),
            )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('Wenxin api_key must be provided.')

        if 'secret_key' not in credentials:
            raise CredentialsValidateFailedError('Wenxin secret_key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key'],
                'secret_key': credentials['secret_key'],
            }

            llm = Wenxin(
                temperature=0.01,
                **credential_kwargs
            )

            llm("ping")
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'api_key': None,
                    'secret_key': None,
                }

            if credentials['api_key']:
                credentials['api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_key']
                )

                if obfuscated:
                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])

            if credentials['secret_key']:
                credentials['secret_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['secret_key']
                )

                if obfuscated:
                    credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key'])

            return credentials
        else:
            return {
                'api_key': None,
                'secret_key': None,
            }

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        Check whether the given model credentials are valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        Encrypt model credentials for saving.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        Get credentials for LLM use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
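As a quick illustration of the parameter gating above, a sketch of inspecting the rules for ERNIE-Bot. The KwargRule attribute names (enabled, min, max, default) are read off their usage in this diff, and wenxin_provider stands in for a properly constructed WenxinProvider instance:

# Sketch: what get_model_parameter_rules reports for an ERNIE model.
from core.model_providers.models.entity.model_params import ModelType

rules = wenxin_provider.get_model_parameter_rules('ernie-bot', ModelType.TEXT_GENERATION)

# temperature and top_p are tunable within narrow bounds...
assert rules.temperature.min == 0.01 and rules.temperature.default == 0.95
assert rules.top_p.max == 1
# ...while the penalty and max_tokens knobs are disabled for Wenxin models.
assert rules.max_tokens.enabled is False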
47
api/core/model_providers/rules.py
Normal file
@ -0,0 +1,47 @@
import json
import os


def init_provider_rules():
    # Get the absolute path of the rules subdirectory
    subdirectory_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'rules')

    # Path to the _providers.json index file
    providers_json_file_path = os.path.join(subdirectory_path, '_providers.json')

    try:
        # Open the index file and read the ordered list of provider names
        with open(providers_json_file_path, 'r') as json_file:
            provider_names = json.load(json_file)
    except FileNotFoundError:
        return "JSON file not found or path error"
    except json.JSONDecodeError:
        return "JSON file decoding error"

    # Dictionary to store the content of all JSON files
    json_data = {}

    try:
        # Loop through the providers listed in the index
        for provider_name in provider_names:
            filename = provider_name + '.json'

            # Path to each JSON file
            json_file_path = os.path.join(subdirectory_path, filename)

            # Open each JSON file and read its content
            with open(json_file_path, 'r') as json_file:
                data = json.load(json_file)
                # Store the content keyed by the file name (without extension)
                json_data[os.path.splitext(filename)[0]] = data

        return json_data
    except FileNotFoundError:
        return "JSON file not found or path error"
    except json.JSONDecodeError:
        return "JSON file decoding error"


provider_rules = init_provider_rules()
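Concretely, when every file listed in _providers.json below loads cleanly, provider_rules is a dict keyed by provider name. Note that on an I/O or parse failure the function returns an error string instead of raising, so a defensive caller has to type-check before indexing; a sketch:

# Sketch of consuming the module-level mapping built above.
from core.model_providers.rules import provider_rules

if isinstance(provider_rules, dict):  # guard against the error-string returns
    wenxin_rule = provider_rules['wenxin']
    print(wenxin_rule['support_provider_types'])  # ['custom']
    print(wenxin_rule['model_flexibility'])       # 'fixed'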
12
api/core/model_providers/rules/_providers.json
Normal file
@ -0,0 +1,12 @@
[
    "openai",
    "azure_openai",
    "anthropic",
    "minimax",
    "tongyi",
    "spark",
    "wenxin",
    "chatglm",
    "replicate",
    "huggingface_hub"
]
15
api/core/model_providers/rules/anthropic.json
Normal file
@ -0,0 +1,15 @@
{
    "support_provider_types": [
        "system",
        "custom"
    ],
    "system_config": {
        "supported_quota_types": [
            "trial",
            "paid"
        ],
        "quota_unit": "times",
        "quota_limit": 1000
    },
    "model_flexibility": "fixed"
}
7
api/core/model_providers/rules/azure_openai.json
Normal file
@ -0,0 +1,7 @@
{
    "support_provider_types": [
        "custom"
    ],
    "system_config": null,
    "model_flexibility": "configurable"
}
7
api/core/model_providers/rules/chatglm.json
Normal file
@ -0,0 +1,7 @@
{
    "support_provider_types": [
        "custom"
    ],
    "system_config": null,
    "model_flexibility": "fixed"
}
7
api/core/model_providers/rules/huggingface_hub.json
Normal file
@ -0,0 +1,7 @@
{
    "support_provider_types": [
        "custom"
    ],
    "system_config": null,
    "model_flexibility": "configurable"
}
13
api/core/model_providers/rules/minimax.json
Normal file
@ -0,0 +1,13 @@
{
    "support_provider_types": [
        "system",
        "custom"
    ],
    "system_config": {
        "supported_quota_types": [
            "free"
        ],
        "quota_unit": "tokens"
    },
    "model_flexibility": "fixed"
}
14
api/core/model_providers/rules/openai.json
Normal file
@ -0,0 +1,14 @@
{
    "support_provider_types": [
        "system",
        "custom"
    ],
    "system_config": {
        "supported_quota_types": [
            "trial"
        ],
        "quota_unit": "times",
        "quota_limit": 200
    },
    "model_flexibility": "fixed"
}
7
api/core/model_providers/rules/replicate.json
Normal file
@ -0,0 +1,7 @@
{
    "support_provider_types": [
        "custom"
    ],
    "system_config": null,
    "model_flexibility": "configurable"
}
13
api/core/model_providers/rules/spark.json
Normal file
@ -0,0 +1,13 @@
{
    "support_provider_types": [
        "system",
        "custom"
    ],
    "system_config": {
        "supported_quota_types": [
            "free"
        ],
        "quota_unit": "tokens"
    },
    "model_flexibility": "fixed"
}
7
api/core/model_providers/rules/tongyi.json
Normal file
@ -0,0 +1,7 @@
{
    "support_provider_types": [
        "custom"
    ],
    "system_config": null,
    "model_flexibility": "fixed"
}
7
api/core/model_providers/rules/wenxin.json
Normal file
@ -0,0 +1,7 @@
{
    "support_provider_types": [
        "custom"
    ],
    "system_config": null,
    "model_flexibility": "fixed"
}
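Taken together, the rule files share one informal schema: support_provider_types (hosted "system" quota and/or tenant-supplied "custom" credentials), an optional system_config describing the hosted quota, and model_flexibility (a "fixed" model list vs. user-"configurable" deployments). A small helper to summarize a rule, assuming that inferred shape holds:

# Summarize one provider rule dict; the schema is inferred from the files
# above rather than formally defined anywhere in this diff.
def describe_rule(name: str, rule: dict) -> str:
    types = '/'.join(rule['support_provider_types'])
    system_config = rule.get('system_config')
    if system_config is None:
        quota = 'no hosted quota'
    else:
        limit = system_config.get('quota_limit', 'unlimited')
        quota = f"{limit} {system_config['quota_unit']}"
    return f"{name}: {types}; {rule['model_flexibility']} models; {quota}"

# describe_rule('openai', provider_rules['openai'])
#   -> 'openai: system/custom; fixed models; 200 times'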