feat: optimize hf inference endpoint (#975)

This commit is contained in:
takatost
2023-08-23 19:47:50 +08:00
committed by GitHub
parent 1fc57d7358
commit a76fde3d23
4 changed files with 59 additions and 11 deletions

View File

@ -2,7 +2,6 @@ import json
from typing import Type
from huggingface_hub import HfApi
from langchain.llms import HuggingFaceEndpoint
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from models.provider import ProviderType
@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
if 'huggingfacehub_endpoint_url' not in credentials:
raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Task Type must be provided.')
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
try:
llm = HuggingFaceEndpoint(
llm = HuggingFaceEndpointLLM(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task="text2text-generation",
task=credentials['task_type'],
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)
@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
}
credentials = json.loads(provider_model.encrypted_config)
if 'task_type' not in credentials:
credentials['task_type'] = 'text-generation'
if credentials['huggingfacehub_api_token']:
credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
self.provider.tenant_id,