mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
feat: optimize hf inference endpoint (#975)
This commit is contained in:
@ -2,7 +2,6 @@ import json
|
||||
from typing import Type
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from langchain.llms import HuggingFaceEndpoint
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||
@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
if 'huggingfacehub_endpoint_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
|
||||
|
||||
if 'task_type' not in credentials:
|
||||
raise CredentialsValidateFailedError('Task Type must be provided.')
|
||||
|
||||
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
|
||||
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
|
||||
|
||||
try:
|
||||
llm = HuggingFaceEndpoint(
|
||||
llm = HuggingFaceEndpointLLM(
|
||||
endpoint_url=credentials['huggingfacehub_endpoint_url'],
|
||||
task="text2text-generation",
|
||||
task=credentials['task_type'],
|
||||
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
|
||||
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
|
||||
)
|
||||
@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
}
|
||||
|
||||
credentials = json.loads(provider_model.encrypted_config)
|
||||
|
||||
if 'task_type' not in credentials:
|
||||
credentials['task_type'] = 'text-generation'
|
||||
|
||||
if credentials['huggingfacehub_api_token']:
|
||||
credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
|
||||
Reference in New Issue
Block a user