Compare commits


19 Commits

SHA1 Message Date
8e837dde1a feat: bump version to 0.3.18 (#1000) 2023-08-24 18:13:18 +08:00
9ae91a2ec3 feat: optimize xinference request max token key and stop reason (#998) 2023-08-24 18:11:15 +08:00
276d3d10a0 fix: apps loading issue (#994) 2023-08-24 17:57:38 +08:00
f13623184a fix style in app share (#995) 2023-08-24 17:57:25 +08:00
ef61e1487f fix: safetensor arm compile error (#996) 2023-08-24 17:38:10 +08:00
701e2b334f feat: remove unnecessary prompt of baichuan (#993) 2023-08-24 15:30:59 +08:00
6ebd6e7890 feat: bump version to 0.3.17 (#992) 2023-08-24 15:12:47 +08:00
bd3a9b2f8d fix: xinference-chat-stream-response (#991) 2023-08-24 14:39:34 +08:00
18d3877151 feat: optimize xinference stream (#989) 2023-08-24 13:58:34 +08:00
53e83d8697 feat: optimize baichuan prompt (#988) 2023-08-24 12:07:10 +08:00
6377fc75c6 chore: update lintrc config (#986) 2023-08-24 11:46:59 +08:00
2c30d19cbe feat: add baichuan prompt (#985) 2023-08-24 10:22:36 +08:00
9b247fccd4 feat: adjust hf max tokens (#979) 2023-08-23 22:24:50 +08:00
3d38aa7138 feat: bump version to 0.3.16 2023-08-23 20:16:54 +08:00
7d2552b3f2 feat: upgrade xinference to 0.2.1 which support stream response (#977) 2023-08-23 20:15:45 +08:00
117a209ad4 Fix: condition for dataset availability check (#973) 2023-08-23 19:57:27 +08:00
071e7800a0 fix: add hf task field (#976)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-08-23 19:48:31 +08:00
a76fde3d23 feat: optimize hf inference endpoint (#975) 2023-08-23 19:47:50 +08:00
1fc57d7358 normalize embedding (#974)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-23 19:10:11 +08:00
32 changed files with 424 additions and 231 deletions

View File

@ -100,7 +100,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.15"
self.CURRENT_VERSION = "0.3.18"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')

View File

@ -130,13 +130,12 @@ class Completion:
fake_response = agent_execute_result.output
# get llm prompt
prompt_messages, stop_words = cls.get_main_llm_prompt(
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
agent_execute_result=agent_execute_result,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
@ -154,113 +153,6 @@ class Completion:
return response
@classmethod
def get_main_llm_prompt(cls, mode: str, model: dict,
pre_prompt: str, query: str, inputs: dict,
agent_execute_result: Optional[AgentExecuteResult],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
if mode == 'completion':
prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if agent_execute_result else "")
+ (pre_prompt + "\n" if pre_prompt else "")
+ "{{query}}\n"
)
if agent_execute_result:
inputs['context'] = agent_execute_result.output
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
prompt_content = prompt_template.format(
query=query,
**prompt_inputs
)
return [PromptMessage(content=prompt_content)], None
else:
messages: List[BaseMessage] = []
human_inputs = {
"query": query
}
human_message_prompt = ""
if pre_prompt:
pre_prompt_inputs = {k: inputs[k] for k in
JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
if k in inputs}
if pre_prompt_inputs:
human_inputs.update(pre_prompt_inputs)
if agent_execute_result:
human_inputs['context'] = agent_execute_result.output
human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
"""
if pre_prompt:
human_message_prompt += pre_prompt
query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
if memory:
# append chat histories
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=human_message_prompt + query_prompt,
inputs=human_inputs
)
if memory.model_instance.model_rules.max_tokens.max:
curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = model.get("completion_params").get('max_tokens')
rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
else:
rest_tokens = 2000
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
human_message_prompt += "\n\n" if human_message_prompt else ""
human_message_prompt += "Here is the chat histories between human and assistant, " \
"inside <histories></histories> XML tags.\n\n<histories>\n"
human_message_prompt += histories + "\n</histories>"
human_message_prompt += query_prompt
# construct main prompt
human_message = PromptBuilder.to_human_message(
prompt_content=human_message_prompt,
inputs=human_inputs
)
messages.append(human_message)
for message in messages:
message.content = re.sub(r'<\|.*?\|>', '', message.content)
return to_prompt_messages(messages), ['\nHuman:', '</histories>']
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
max_token_limit: int) -> str:
@ -307,13 +199,12 @@ And answer according to the language of the user's question.
max_tokens = 0
# get prompt without memory and context
prompt_messages, _ = cls.get_main_llm_prompt(
prompt_messages, _ = model_instance.get_prompt(
mode=mode,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
agent_execute_result=None,
query=query,
context=None,
memory=None
)
@ -358,13 +249,12 @@ And answer according to the language of the user's question.
)
# get llm prompt
old_prompt_messages, _ = cls.get_main_llm_prompt(
mode="completion",
model=app_model_config.model_dict,
old_prompt_messages, _ = final_model_instance.get_prompt(
mode='completion',
pre_prompt=pre_prompt,
query=message.query,
inputs=message.inputs,
agent_execute_result=None,
query=message.query,
context=None,
memory=None
)

View File

@ -1,6 +1,7 @@
import logging
from typing import List
import numpy as np
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings):
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)
i = 0
normalized_embedding_results = []
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
try:
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results[i])
vector = embedding_results[i]
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
normalized_embedding_results.append(normalized_embedding)
embedding.set_embedding(normalized_embedding)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings):
finally:
i += 1
text_embeddings.extend(embedding_results)
text_embeddings.extend(normalized_embedding_results)
return text_embeddings
def embed_query(self, text: str) -> List[float]:
@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings):
try:
embedding_results = self._embeddings.client.embed_query(text)
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)
@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings):
return embedding_results
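
For reference, the change in this file is a plain L2 (unit-length) normalization applied to every embedding before it is cached or returned. A minimal standalone sketch of the same operation, assuming numpy and non-degenerate vectors:

import numpy as np

def l2_normalize(vector: list) -> list:
    # Divide each component by the vector's Euclidean norm so that a plain
    # dot product between two normalized embeddings equals their cosine similarity.
    arr = np.asarray(vector, dtype=np.float64)
    norm = np.linalg.norm(arr)
    if norm == 0.0:
        # Hypothetical guard: the diff above does not special-case zero vectors.
        return arr.tolist()
    return (arr / norm).tolist()

# l2_normalize([3.0, 4.0]) -> [0.6, 0.8]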

View File

@ -1,17 +1,24 @@
import json
import os
import re
from abc import abstractmethod
from typing import List, Optional, Any, Union
from typing import List, Optional, Any, Union, Tuple
import decimal
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.third_party.langchain.llms.fake import FakeLLM
import logging
logger = logging.getLogger(__name__)
@ -76,13 +83,14 @@ class BaseLLM(BaseProviderModel):
def price_config(self) -> dict:
def get_or_default():
default_price_config = {
'prompt': decimal.Decimal('0'),
'completion': decimal.Decimal('0'),
'unit': decimal.Decimal('0'),
'currency': 'USD'
}
'prompt': decimal.Decimal('0'),
'completion': decimal.Decimal('0'),
'unit': decimal.Decimal('0'),
'currency': 'USD'
}
rules = self.model_provider.get_rules()
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
price_config = rules['price_config'][
self.base_model_name] if 'price_config' in rules else default_price_config
price_config = {
'prompt': decimal.Decimal(price_config['prompt']),
'completion': decimal.Decimal(price_config['completion']),
@ -90,7 +98,7 @@ class BaseLLM(BaseProviderModel):
'currency': price_config['currency']
}
return price_config
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
logger.debug(f"model: {self.name} price_config: {self._price_config}")
@ -158,7 +166,8 @@ class BaseLLM(BaseProviderModel):
total_tokens = result.llm_output['token_usage']['total_tokens']
else:
prompt_tokens = self.get_num_tokens(messages)
completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
completion_tokens = self.get_num_tokens(
[PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
total_tokens = prompt_tokens + completion_tokens
self.model_provider.update_last_used()
@ -293,6 +302,119 @@ class BaseLLM(BaseProviderModel):
def support_streaming(cls):
return False
def get_prompt(self, mode: str,
pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'common_completion'
else:
return 'common_chat'
def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
context=context
)
pre_prompt_content = ''
if pre_prompt:
prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
pre_prompt_content = prompt_template.format(
**prompt_inputs
)
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''
query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
if memory and 'histories_prompt' in prompt_rules:
# append chat histories
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=prompt + query_prompt,
inputs={
'query': query
}
)
if self.model_rules.max_tokens.max:
curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = self.model_kwargs.max_tokens
rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
else:
rest_tokens = 2000
memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format(
histories=histories
)
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
elif order == 'histories_prompt':
prompt += histories_prompt_content
prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
query_prompt_content = prompt_template.format(
query=query
)
prompt += query_prompt_content
prompt = re.sub(r'<\|.*?\|>', '', prompt)
stops = prompt_rules.get('stops')
if stops is not None and len(stops) == 0:
stops = None
return prompt, stops
def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
# Get the absolute path of the subdirectory
prompt_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
'prompt/generate_prompts')
json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
# Open the JSON file and read its content
with open(json_file_path, 'r') as json_file:
return json.load(json_file)
def _get_history_messages_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> str:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
return external_context[memory_key]
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
if not model_mode:
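
The new get_prompt / _get_prompt_and_stop pair above replaces the hard-coded prompt building removed from Completion: a JSON rules file (selected by prompt_file_name) supplies the section templates, their order, and the stop words. A simplified sketch of that assembly, using plain string replacement instead of JinjaPromptTemplate and a rules dict standing in for the JSON file:

from typing import List, Optional, Tuple

def assemble_prompt(rules: dict, pre_prompt: str, query: str,
                    context: Optional[str], histories: Optional[str]) -> Tuple[str, Optional[List[str]]]:
    # Render each section, then concatenate in the order the rules file dictates.
    context_part = rules.get('context_prompt', '').replace('{{context}}', context) if context else ''
    histories_part = rules.get('histories_prompt', '').replace('{{histories}}', histories) if histories else ''
    prompt = ''
    for order in rules['system_prompt_orders']:
        if order == 'context_prompt':
            prompt += context_part
        elif order == 'pre_prompt':
            prompt += (pre_prompt + '\n') if pre_prompt else ''
        elif order == 'histories_prompt':
            prompt += histories_part
    prompt += rules.get('query_prompt', '{{query}}').replace('{{query}}', query)
    stops = rules.get('stops') or None  # an empty list means "no stop words"
    return prompt, stops

With the common_chat rules added further down, this yields a prompt ending in "Human: {query}\n\nAssistant: " plus the stop words ["\nHuman:", "</histories>"].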

View File

@ -1,16 +1,14 @@
import decimal
from functools import wraps
from typing import List, Optional, Any
from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks
from langchain.llms import HuggingFaceEndpoint
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
class HuggingfaceHubModel(BaseLLM):
@ -19,12 +17,12 @@ class HuggingfaceHubModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
client = HuggingFaceEndpoint(
client = HuggingFaceEndpointLLM(
endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
task='text2text-generation',
task=self.credentials['task_type'],
model_kwargs=provider_model_kwargs,
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
callbacks=self.callbacks,
callbacks=self.callbacks
)
else:
client = HuggingFaceHub(
@ -62,6 +60,15 @@ class HuggingfaceHubModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return self._client.get_num_tokens(prompts)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.model_kwargs = provider_model_kwargs
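
The same Baichuan override is repeated for the OpenLLM and Xinference models below; conceptually it is just model-name matching that decides which rules file _read_prompt_rules_from_file loads. A short sketch (the directory is inferred from the diff and may differ on disk):

import os

GENERATE_PROMPTS_DIR = 'api/core/prompt/generate_prompts'  # assumed location

def resolve_prompt_rules_file(model_name: str, mode: str) -> str:
    # Baichuan models get the dedicated Chinese prompt rules; everything else
    # falls back to the generic common_* templates.
    family = 'baichuan' if 'baichuan' in model_name.lower() else 'common'
    suffix = 'completion' if mode == 'completion' else 'chat'
    return os.path.join(GENERATE_PROMPTS_DIR, f'{family}_{suffix}.json')

# resolve_prompt_rules_file('baichuan-13b-chat', 'chat')
#   -> 'api/core/prompt/generate_prompts/baichuan_chat.json'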

View File

@ -49,6 +49,15 @@ class OpenLLMModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass

View File

@ -59,6 +59,15 @@ class XinferenceModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass

View File

@ -2,7 +2,6 @@ import json
from typing import Type
from huggingface_hub import HfApi
from langchain.llms import HuggingFaceEndpoint
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from models.provider import ProviderType
@ -51,7 +51,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=1500, default=200),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
)
@classmethod
@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
if 'huggingfacehub_endpoint_url' not in credentials:
raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Task Type must be provided.')
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
try:
llm = HuggingFaceEndpoint(
llm = HuggingFaceEndpointLLM(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task="text2text-generation",
task=credentials['task_type'],
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)
@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
}
credentials = json.loads(provider_model.encrypted_config)
if 'task_type' not in credentials:
credentials['task_type'] = 'text-generation'
if credentials['huggingfacehub_api_token']:
credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
self.provider.tenant_id,
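
Condensing the credential handling this file gains: validation now demands an explicit, supported task_type, while credentials saved before the field existed are backfilled with 'text-generation'. A sketch of both paths (a plain ValueError stands in for CredentialsValidateFailedError):

SUPPORTED_TASKS = ('text2text-generation', 'text-generation', 'summarization')

def validate_task_type(credentials: dict) -> None:
    # New inference-endpoint credentials must name a supported task explicitly.
    if 'task_type' not in credentials:
        raise ValueError('Task Type must be provided.')
    if credentials['task_type'] not in SUPPORTED_TASKS:
        raise ValueError('Task Type must be one of ' + ', '.join(SUPPORTED_TASKS) + '.')

def backfill_task_type(stored_credentials: dict) -> dict:
    # Configs stored by earlier versions have no task_type; default it.
    stored_credentials.setdefault('task_type', 'text-generation')
    return stored_credentials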

View File

@ -2,7 +2,6 @@ import json
from typing import Type
import requests
from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@ -73,7 +72,7 @@ class XinferenceProvider(BaseModelProvider):
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
)

View File

@ -0,0 +1,13 @@
{
"human_prefix": "用户",
"assistant_prefix": "助手",
"context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n\n",
"histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{histories}}\n```\n\n",
"system_prompt_orders": [
"context_prompt",
"pre_prompt",
"histories_prompt"
],
"query_prompt": "用户:{{query}}",
"stops": ["用户:"]
}

View File

@ -0,0 +1,9 @@
{
"context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n",
"system_prompt_orders": [
"context_prompt",
"pre_prompt"
],
"query_prompt": "{{query}}",
"stops": null
}

View File

@ -0,0 +1,13 @@
{
"human_prefix": "Human",
"assistant_prefix": "Assistant",
"context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
"histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{histories}}\n</histories>\n\n",
"system_prompt_orders": [
"context_prompt",
"pre_prompt",
"histories_prompt"
],
"query_prompt": "Human: {{query}}\n\nAssistant: ",
"stops": ["\nHuman:", "</histories>"]
}
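
The stops array in these rule files is returned by get_prompt alongside the assembled prompt and is meant to be passed to the model as stop sequences, so generation halts before the model writes the next "Human:" turn or leaks the closing </histories> tag. Roughly, on the Completion side (only get_prompt appears in this diff; the final run call is a hypothetical signature):

prompt_messages, stop_words = model_instance.get_prompt(
    mode='chat',
    pre_prompt=pre_prompt,
    inputs=inputs,
    query=query,
    context=context,
    memory=memory,
)
# common_chat.json    -> stop_words == ['\nHuman:', '</histories>']
# baichuan_chat.json  -> stop_words == ['用户:']
# *_completion.json   -> stop_words is None
# response = model_instance.run(prompt_messages, stop=stop_words)  # hypothetical call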

View File

@ -0,0 +1,9 @@
{
"context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
"system_prompt_orders": [
"context_prompt",
"pre_prompt"
],
"query_prompt": "{{query}}",
"stops": null
}

View File

@ -0,0 +1,39 @@
from typing import Dict
from langchain.llms import HuggingFaceEndpoint
from pydantic import Extra, root_validator
from langchain.utils import get_from_dict_or_env
class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
"""HuggingFace Endpoint models.
To use, you should have the ``huggingface_hub`` python package installed, and the
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
Only supports `text-generation` and `text2text-generation` for now.
Example:
.. code-block:: python
from langchain.llms import HuggingFaceEndpoint
endpoint_url = (
"https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
)
hf = HuggingFaceEndpoint(
endpoint_url=endpoint_url,
huggingfacehub_api_token="my-api-key"
)
"""
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
values["huggingfacehub_api_token"] = huggingfacehub_api_token
return values
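
This subclass exists so the endpoint task can come from the saved credentials instead of being pinned to text2text-generation, with a relaxed validator that only checks the API token. Constructing it mirrors the calls in the hunks above (URL and token are placeholders):

from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM

credentials = {
    'huggingfacehub_endpoint_url': 'https://example.endpoints.huggingface.cloud',  # placeholder
    'huggingfacehub_api_token': 'hf_xxx',  # placeholder
    'task_type': 'text-generation',
}

llm = HuggingFaceEndpointLLM(
    endpoint_url=credentials['huggingfacehub_endpoint_url'],
    task=credentials['task_type'],
    model_kwargs={'temperature': 0.5, 'max_new_tokens': 200},
    huggingfacehub_api_token=credentials['huggingfacehub_api_token'],
)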

View File

@ -9,11 +9,11 @@ from xinference.client import RESTfulChatglmCppChatModelHandle, \
class XinferenceLLM(Xinference):
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call the xinference model and return the output.
@ -56,10 +56,10 @@ class XinferenceLLM(Xinference):
if generate_config and generate_config.get("stream"):
combined_text_output = ""
for token in self._stream_generate(
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
):
combined_text_output += token
return combined_text_output
@ -73,10 +73,10 @@ class XinferenceLLM(Xinference):
if generate_config and generate_config.get("stream"):
combined_text_output = ""
for token in self._stream_generate(
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
):
combined_text_output += token
completion = combined_text_output
@ -89,13 +89,13 @@ class XinferenceLLM(Xinference):
return completion
def _stream_generate(
self,
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
prompt: str,
run_manager: Optional[CallbackManagerForLLMRun] = None,
generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
self,
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
prompt: str,
run_manager: Optional[CallbackManagerForLLMRun] = None,
generate_config: Optional[
Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
) -> Generator[str, None, None]:
"""
Args:
@ -108,12 +108,12 @@ class XinferenceLLM(Xinference):
Yields:
A string token.
"""
if isinstance(model, RESTfulGenerateModelHandle):
streaming_response = model.generate(
if isinstance(model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)):
streaming_response = model.chat(
prompt=prompt, generate_config=generate_config
)
else:
streaming_response = model.chat(
streaming_response = model.generate(
prompt=prompt, generate_config=generate_config
)
@ -123,7 +123,16 @@ class XinferenceLLM(Xinference):
if choices:
choice = choices[0]
if isinstance(choice, dict):
token = choice.get("text", "")
if 'finish_reason' in choice and choice['finish_reason'] \
and choice['finish_reason'] in ['stop', 'length']:
break
if 'text' in choice:
token = choice.get("text", "")
elif 'delta' in choice and 'content' in choice['delta']:
token = choice.get('delta').get('content')
else:
continue
log_probs = choice.get("logprobs")
if run_manager:
run_manager.on_llm_new_token(
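
Distilling the per-chunk handling added above: generate-style chunks carry choice['text'], chat-style chunks carry choice['delta']['content'], and a populated finish_reason of stop or length ends the stream. A standalone sketch over hypothetical chunk dicts:

from typing import Optional

def extract_token(choice: dict) -> Optional[str]:
    # None  -> the stream is finished (finish_reason of stop/length)
    # ''    -> nothing usable in this chunk, keep reading
    # else  -> the next text token
    if choice.get('finish_reason') in ('stop', 'length'):
        return None
    if 'text' in choice:
        return choice['text']
    if 'content' in choice.get('delta', {}):
        return choice['delta']['content']
    return ''

# extract_token({'delta': {'content': 'Hi'}}) == 'Hi'
# extract_token({'text': '', 'finish_reason': 'stop'}) is None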

View File

@ -49,4 +49,5 @@ huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference==0.2.0
xinference==0.2.1
safetensors==0.3.2

View File

@ -19,7 +19,7 @@ from models.dataset import Dataset, DocumentSegment, DatasetQuery
class HitTestingService:
@classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
if dataset.available_document_count == 0 or dataset.available_document_count == 0:
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return {
"query": {
"content": query,

View File

@ -17,7 +17,8 @@ HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = {
INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL = {
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': 'valid_key',
'huggingfacehub_endpoint_url': 'valid_url'
'huggingfacehub_endpoint_url': 'valid_url',
'task_type': 'text-generation'
}
def encrypt_side_effect(tenant_id, encrypt_key):

View File

@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
image: langgenius/dify-api:0.3.15
image: langgenius/dify-api:0.3.18
restart: always
environment:
# Startup mode, 'api' starts the API server.
@ -124,7 +124,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.3.15
image: langgenius/dify-api:0.3.18
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@ -176,7 +176,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.3.15
image: langgenius/dify-web:0.3.18
restart: always
environment:
EDITION: SELF_HOSTED

View File

@ -1,7 +1,7 @@
{
"extends": [
"@antfu",
"plugin:react-hooks/recommended"
"next",
"@antfu"
],
"rules": {
"@typescript-eslint/consistent-type-definitions": [

web/.gitignore vendored
View File

@ -15,6 +15,8 @@
# production
/build
/.history
# misc
.DS_Store
*.pem

View File

@ -1,15 +1,14 @@
'use client'
import { useEffect, useRef, useState } from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useRouter, useSearchParams } from 'next/navigation'
import useSWRInfinite from 'swr/infinite'
import { debounce } from 'lodash-es'
import { useTranslation } from 'react-i18next'
import AppCard from './AppCard'
import NewAppCard from './NewAppCard'
import type { AppListResponse } from '@/models/app'
import { fetchAppList } from '@/service/apps'
import { useAppContext, useSelector } from '@/context/app-context'
import { useAppContext } from '@/context/app-context'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
import Confirm from '@/app/components/base/confirm/common'
@ -24,15 +23,18 @@ const Apps = () => {
const { t } = useTranslation()
const { isCurrentWorkspaceManager } = useAppContext()
const { data, isLoading, setSize, mutate } = useSWRInfinite(getKey, fetchAppList, { revalidateFirstPage: false })
const loadingStateRef = useRef(false)
const pageContainerRef = useSelector(state => state.pageContainerRef)
const anchorRef = useRef<HTMLAnchorElement>(null)
const anchorRef = useRef<HTMLDivElement>(null)
const searchParams = useSearchParams()
const router = useRouter()
const payProviderName = searchParams.get('provider_name')
const payStatus = searchParams.get('payment_result')
const [showPayStatusModal, setShowPayStatusModal] = useState(false)
const handleCancelShowPayStatusModal = useCallback(() => {
setShowPayStatusModal(false)
router.replace('/', { forceOptimisticNavigation: false })
}, [router])
useEffect(() => {
document.title = `${t('app.title')} - Dify`
if (localStorage.getItem(NEED_REFRESH_APP_LIST_KEY) === '1') {
@ -41,35 +43,24 @@ const Apps = () => {
}
if (payProviderName === ProviderEnum.anthropic && (payStatus === 'succeeded' || payStatus === 'cancelled'))
setShowPayStatusModal(true)
}, [])
}, [mutate, payProviderName, payStatus, t])
useEffect(() => {
loadingStateRef.current = isLoading
}, [isLoading])
useEffect(() => {
const onScroll = debounce(() => {
if (!loadingStateRef.current) {
const { scrollTop, clientHeight } = pageContainerRef.current!
const anchorOffset = anchorRef.current!.offsetTop
if (anchorOffset - scrollTop - clientHeight < 100)
let observer: IntersectionObserver | undefined
if (anchorRef.current) {
observer = new IntersectionObserver((entries) => {
if (entries[0].isIntersecting)
setSize(size => size + 1)
}
}, 50)
pageContainerRef.current?.addEventListener('scroll', onScroll)
return () => pageContainerRef.current?.removeEventListener('scroll', onScroll)
}, [])
const handleCancelShowPayStatusModal = () => {
setShowPayStatusModal(false)
router.replace('/', { forceOptimisticNavigation: false })
}
}, { rootMargin: '100px' })
observer.observe(anchorRef.current)
}
return () => observer?.disconnect()
}, [isLoading, setSize, anchorRef, mutate])
return (
<nav className='grid content-start grid-cols-1 gap-4 px-12 pt-8 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 grow shrink-0'>
<><nav className='grid content-start grid-cols-1 gap-4 px-12 pt-8 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 grow shrink-0'>
{ isCurrentWorkspaceManager
&& <NewAppCard ref={anchorRef} onSuccess={mutate} />}
&& <NewAppCard onSuccess={mutate} />}
{data?.map(({ data: apps }) => apps.map(app => (
<AppCard key={app.id} app={app} onRefresh={mutate} />
)))}
@ -95,6 +86,8 @@ const Apps = () => {
)
}
</nav>
<div ref={anchorRef} className='h-0'> </div>
</>
)
}

View File

@ -1,4 +1,4 @@
import React from "react";
import React from 'react'
import type { FC } from 'react'
import GA, { GaType } from '@/app/components/base/ga'
@ -6,13 +6,11 @@ const Layout: FC<{
children: React.ReactNode
}> = ({ children }) => {
return (
<div className="overflow-x-auto">
<div className="w-screen h-screen min-w-[300px]">
<GA gaType={GaType.webapp} />
{children}
</div>
<div className="min-w-[300px]">
<GA gaType={GaType.webapp} />
{children}
</div>
)
}
export default Layout
export default Layout

View File

@ -38,6 +38,7 @@ const config: ProviderConfig = {
defaultValue: {
model_type: 'text-generation',
huggingfacehub_api_type: 'hosted_inference_api',
task_type: 'text-generation',
},
validateKeys: (v?: FormValue) => {
if (v?.huggingfacehub_api_type === 'hosted_inference_api') {
@ -51,10 +52,36 @@ const config: ProviderConfig = {
'huggingfacehub_api_token',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
]
}
return []
},
filterValue: (v?: FormValue) => {
let filteredKeys: string[] = []
if (v?.huggingfacehub_api_type === 'hosted_inference_api') {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'model_name',
'model_type',
]
}
if (v?.huggingfacehub_api_type === 'inference_endpoints') {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
'model_type',
]
}
return filteredKeys.reduce((prev: FormValue, next: string) => {
prev[next] = v?.[next] || ''
return prev
}, {})
},
fields: [
{
type: 'radio',
@ -120,6 +147,32 @@ const config: ProviderConfig = {
'zh-Hans': '在此输入您的端点 URL',
},
},
{
hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api',
type: 'radio',
key: 'task_type',
required: true,
label: {
'en': 'Task',
'zh-Hans': 'Task',
},
options: [
{
key: 'text2text-generation',
label: {
'en': 'Text-to-Text Generation',
'zh-Hans': 'Text-to-Text Generation',
},
},
{
key: 'text-generation',
label: {
'en': 'Text Generation',
'zh-Hans': 'Text Generation',
},
},
],
},
],
},
}

View File

@ -91,6 +91,7 @@ export type ProviderConfigModal = {
icon: ReactElement
defaultValue?: FormValue
validateKeys?: string[] | ((v?: FormValue) => string[])
filterValue?: (v?: FormValue) => FormValue
fields: Field[]
link: {
href: string

View File

@ -124,8 +124,9 @@ const ModelPage = () => {
updateModelList(ModelType.embeddings)
mutateProviders()
}
const handleSave = async (v?: FormValue) => {
if (v && modelModalConfig) {
const handleSave = async (originValue?: FormValue) => {
if (originValue && modelModalConfig) {
const v = modelModalConfig.filterValue ? modelModalConfig.filterValue(originValue) : originValue
let body, url
if (ConfigurableProviders.includes(modelModalConfig.key)) {
const { model_name, model_type, ...config } = v

View File

@ -68,7 +68,7 @@ const Form: FC<FormProps> = ({
return true
},
run: () => {
return validateModelProviderFn(modelModal!.key, v)
return validateModelProviderFn(modelModal!.key, modelModal?.filterValue ? modelModal?.filterValue(v) : v)
},
})
}

View File

@ -574,7 +574,7 @@ const Main: FC<IMainProps> = ({
return <Loading type='app' />
return (
<div className='bg-gray-100'>
<div className='bg-gray-100 flex w-full h-full'>
{!isInstalledApp && (
<Header
title={siteInfo.title}
@ -588,7 +588,7 @@ const Main: FC<IMainProps> = ({
<div
className={cn(
'flex rounded-t-2xl bg-white overflow-hidden',
'flex rounded-t-2xl bg-white overflow-hidden h-full w-full',
isInstalledApp && 'rounded-b-2xl',
)}
style={isInstalledApp
@ -611,7 +611,7 @@ const Main: FC<IMainProps> = ({
)}
{/* main */}
<div className={cn(
isInstalledApp ? s.installedApp : 'h-[calc(100vh_-_3rem)]',
isInstalledApp ? s.installedApp : '',
'flex-grow flex flex-col overflow-y-auto',
)
}>

View File

@ -85,7 +85,7 @@ const Sidebar: FC<ISidebarProps> = ({
<div
className={
cn(
(isInstalledApp || isUniversalChat) ? 'tablet:h-[calc(100vh_-_74px)]' : 'tablet:h-[calc(100vh_-_3rem)]',
(isInstalledApp || isUniversalChat) ? 'tablet:h-[calc(100vh_-_74px)]' : '',
'shrink-0 flex flex-col bg-white pc:w-[244px] tablet:w-[192px] mobile:w-[240px] border-r border-gray-200 mobile:h-screen',
)
}

View File

@ -510,7 +510,7 @@ const Main: FC<IMainProps> = ({
<div className={'flex bg-white overflow-hidden'}>
<div className={cn(
isInstalledApp ? s.installedApp : 'h-[calc(100vh_-_3rem)]',
isInstalledApp ? s.installedApp : '',
'flex-grow flex flex-col overflow-y-auto',
)
}>

View File

@ -1,6 +1,6 @@
{
"name": "dify-web",
"version": "0.3.15",
"version": "0.3.18",
"private": true,
"scripts": {
"dev": "next dev",
@ -94,7 +94,6 @@
"@types/sortablejs": "^1.15.1",
"eslint": "8.36.0",
"eslint-config-next": "^13.4.7",
"eslint-plugin-react-hooks": "^4.6.0",
"husky": "^8.0.3",
"lint-staged": "^13.2.2",
"miragejs": "^0.1.47",

View File

@ -1838,11 +1838,6 @@ eslint-plugin-react-hooks@5.0.0-canary-7118f5dd7-20230705:
resolved "https://registry.yarnpkg.com/eslint-plugin-react-hooks/-/eslint-plugin-react-hooks-5.0.0-canary-7118f5dd7-20230705.tgz#4d55c50e186f1a2b0636433d2b0b2f592ddbccfd"
integrity sha512-AZYbMo/NW9chdL7vk6HQzQhT+PvTAEVqWk9ziruUoW2kAOcN5qNyelv70e0F1VNQAbvutOC9oc+xfWycI9FxDw==
eslint-plugin-react-hooks@^4.6.0:
version "4.6.0"
resolved "https://registry.yarnpkg.com/eslint-plugin-react-hooks/-/eslint-plugin-react-hooks-4.6.0.tgz#4c3e697ad95b77e93f8646aaa1630c1ba607edd3"
integrity sha512-oFc7Itz9Qxh2x4gNHStv3BqJq54ExXmfC+a1NjAta66IAN87Wu0R/QArgIS9qKzX3dXKPI9H5crl9QchNMY9+g==
eslint-plugin-react@^7.31.7:
version "7.33.1"
resolved "https://registry.yarnpkg.com/eslint-plugin-react/-/eslint-plugin-react-7.33.1.tgz#bc27cccf860ae45413a4a4150bf0977345c1ceab"
@ -4510,7 +4505,7 @@ prismjs@~1.27.0:
resolved "https://registry.yarnpkg.com/prismjs/-/prismjs-1.27.0.tgz#bb6ee3138a0b438a3653dd4d6ce0cc6510a45057"
integrity sha512-t13BGPUlFDR7wRB5kQDG4jjl7XeuH6jbJGt11JHPL96qwsEHNX2+68tFXqc1/k+/jALsbSWJKUOT/hcYAZ5LkA==
prop-types@^15.0.0, prop-types@^15.8.1:
prop-types@^15.0.0, prop-types@^15.5.8, prop-types@^15.8.1:
version "15.8.1"
resolved "https://registry.yarnpkg.com/prop-types/-/prop-types-15.8.1.tgz#67d87bf1a694f48435cf332c24af10214a3140b5"
integrity sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==
@ -4548,6 +4543,13 @@ queue-microtask@^1.2.2:
resolved "https://registry.yarnpkg.com/queue-microtask/-/queue-microtask-1.2.3.tgz#4929228bbc724dfac43e0efb058caf7b6cfb6243"
integrity sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==
react-18-input-autosize@^3.0.0:
version "3.0.0"
resolved "https://registry.yarnpkg.com/react-18-input-autosize/-/react-18-input-autosize-3.0.0.tgz#eb34ac8c8335c30f76a56a8902d31f1fc1b62c4c"
integrity sha512-7tsUc9PJWg6Vsp8qYuzlKKBf7hbCoTBdNfjYZSprEPbxf3meuhjklg9QPBe9rIyoR3uDAzmG7NpoJ1+kP5ns+w==
dependencies:
prop-types "^15.5.8"
react-dom@^18.2.0:
version "18.2.0"
resolved "https://registry.yarnpkg.com/react-dom/-/react-dom-18.2.0.tgz#22aaf38708db2674ed9ada224ca4aa708d821e3d"