mirror of
https://github.com/langgenius/dify.git
synced 2026-01-27 23:35:51 +08:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8e837dde1a | |||
| 9ae91a2ec3 | |||
| 276d3d10a0 | |||
| f13623184a | |||
| ef61e1487f | |||
| 701e2b334f | |||
| 6ebd6e7890 | |||
| bd3a9b2f8d | |||
| 18d3877151 | |||
| 53e83d8697 | |||
| 6377fc75c6 | |||
| 2c30d19cbe | |||
| 9b247fccd4 | |||
| 3d38aa7138 | |||
| 7d2552b3f2 | |||
| 117a209ad4 | |||
| 071e7800a0 | |||
| a76fde3d23 | |||
| 1fc57d7358 |
@ -100,7 +100,7 @@ class Config:
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
self.CURRENT_VERSION = "0.3.15"
|
||||
self.CURRENT_VERSION = "0.3.18"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
|
||||
@ -130,13 +130,12 @@ class Completion:
|
||||
fake_response = agent_execute_result.output
|
||||
|
||||
# get llm prompt
|
||||
prompt_messages, stop_words = cls.get_main_llm_prompt(
|
||||
prompt_messages, stop_words = model_instance.get_prompt(
|
||||
mode=mode,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
agent_execute_result=agent_execute_result,
|
||||
query=query,
|
||||
context=agent_execute_result.output if agent_execute_result else None,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
@ -154,113 +153,6 @@ class Completion:
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def get_main_llm_prompt(cls, mode: str, model: dict,
|
||||
pre_prompt: str, query: str, inputs: dict,
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
|
||||
Tuple[List[PromptMessage], Optional[List[str]]]:
|
||||
if mode == 'completion':
|
||||
prompt_template = JinjaPromptTemplate.from_template(
|
||||
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
|
||||
<context>
|
||||
{{context}}
|
||||
</context>
|
||||
|
||||
When answer to user:
|
||||
- If you don't know, just say that you don't know.
|
||||
- If you don't know when you are not sure, ask for clarification.
|
||||
Avoid mentioning that you obtained the information from the context.
|
||||
And answer according to the language of the user's question.
|
||||
""" if agent_execute_result else "")
|
||||
+ (pre_prompt + "\n" if pre_prompt else "")
|
||||
+ "{{query}}\n"
|
||||
)
|
||||
|
||||
if agent_execute_result:
|
||||
inputs['context'] = agent_execute_result.output
|
||||
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
|
||||
prompt_content = prompt_template.format(
|
||||
query=query,
|
||||
**prompt_inputs
|
||||
)
|
||||
|
||||
return [PromptMessage(content=prompt_content)], None
|
||||
else:
|
||||
messages: List[BaseMessage] = []
|
||||
|
||||
human_inputs = {
|
||||
"query": query
|
||||
}
|
||||
|
||||
human_message_prompt = ""
|
||||
|
||||
if pre_prompt:
|
||||
pre_prompt_inputs = {k: inputs[k] for k in
|
||||
JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
|
||||
if k in inputs}
|
||||
|
||||
if pre_prompt_inputs:
|
||||
human_inputs.update(pre_prompt_inputs)
|
||||
|
||||
if agent_execute_result:
|
||||
human_inputs['context'] = agent_execute_result.output
|
||||
human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
|
||||
<context>
|
||||
{{context}}
|
||||
</context>
|
||||
|
||||
When answer to user:
|
||||
- If you don't know, just say that you don't know.
|
||||
- If you don't know when you are not sure, ask for clarification.
|
||||
Avoid mentioning that you obtained the information from the context.
|
||||
And answer according to the language of the user's question.
|
||||
"""
|
||||
|
||||
if pre_prompt:
|
||||
human_message_prompt += pre_prompt
|
||||
|
||||
query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
|
||||
|
||||
if memory:
|
||||
# append chat histories
|
||||
tmp_human_message = PromptBuilder.to_human_message(
|
||||
prompt_content=human_message_prompt + query_prompt,
|
||||
inputs=human_inputs
|
||||
)
|
||||
|
||||
if memory.model_instance.model_rules.max_tokens.max:
|
||||
curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
|
||||
max_tokens = model.get("completion_params").get('max_tokens')
|
||||
rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
else:
|
||||
rest_tokens = 2000
|
||||
|
||||
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
|
||||
human_message_prompt += "\n\n" if human_message_prompt else ""
|
||||
human_message_prompt += "Here is the chat histories between human and assistant, " \
|
||||
"inside <histories></histories> XML tags.\n\n<histories>\n"
|
||||
human_message_prompt += histories + "\n</histories>"
|
||||
|
||||
human_message_prompt += query_prompt
|
||||
|
||||
# construct main prompt
|
||||
human_message = PromptBuilder.to_human_message(
|
||||
prompt_content=human_message_prompt,
|
||||
inputs=human_inputs
|
||||
)
|
||||
|
||||
messages.append(human_message)
|
||||
|
||||
for message in messages:
|
||||
message.content = re.sub(r'<\|.*?\|>', '', message.content)
|
||||
|
||||
return to_prompt_messages(messages), ['\nHuman:', '</histories>']
|
||||
|
||||
@classmethod
|
||||
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
|
||||
max_token_limit: int) -> str:
|
||||
@ -307,13 +199,12 @@ And answer according to the language of the user's question.
|
||||
max_tokens = 0
|
||||
|
||||
# get prompt without memory and context
|
||||
prompt_messages, _ = cls.get_main_llm_prompt(
|
||||
prompt_messages, _ = model_instance.get_prompt(
|
||||
mode=mode,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
agent_execute_result=None,
|
||||
query=query,
|
||||
context=None,
|
||||
memory=None
|
||||
)
|
||||
|
||||
@ -358,13 +249,12 @@ And answer according to the language of the user's question.
|
||||
)
|
||||
|
||||
# get llm prompt
|
||||
old_prompt_messages, _ = cls.get_main_llm_prompt(
|
||||
mode="completion",
|
||||
model=app_model_config.model_dict,
|
||||
old_prompt_messages, _ = final_model_instance.get_prompt(
|
||||
mode='completion',
|
||||
pre_prompt=pre_prompt,
|
||||
query=message.query,
|
||||
inputs=message.inputs,
|
||||
agent_execute_result=None,
|
||||
query=message.query,
|
||||
context=None,
|
||||
memory=None
|
||||
)
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings):
|
||||
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
|
||||
except Exception as ex:
|
||||
raise self._embeddings.handle_exceptions(ex)
|
||||
|
||||
i = 0
|
||||
normalized_embedding_results = []
|
||||
for text in embedding_queue_texts:
|
||||
hash = helper.generate_text_hash(text)
|
||||
|
||||
try:
|
||||
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
|
||||
embedding.set_embedding(embedding_results[i])
|
||||
vector = embedding_results[i]
|
||||
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
|
||||
normalized_embedding_results.append(normalized_embedding)
|
||||
embedding.set_embedding(normalized_embedding)
|
||||
db.session.add(embedding)
|
||||
db.session.commit()
|
||||
except IntegrityError:
|
||||
@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings):
|
||||
finally:
|
||||
i += 1
|
||||
|
||||
text_embeddings.extend(embedding_results)
|
||||
text_embeddings.extend(normalized_embedding_results)
|
||||
return text_embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
try:
|
||||
embedding_results = self._embeddings.client.embed_query(text)
|
||||
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
|
||||
except Exception as ex:
|
||||
raise self._embeddings.handle_exceptions(ex)
|
||||
|
||||
@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
return embedding_results
|
||||
|
||||
|
||||
|
||||
@ -1,17 +1,24 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Any, Union
|
||||
from typing import List, Optional, Any, Union, Tuple
|
||||
import decimal
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
|
||||
|
||||
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -76,13 +83,14 @@ class BaseLLM(BaseProviderModel):
|
||||
def price_config(self) -> dict:
|
||||
def get_or_default():
|
||||
default_price_config = {
|
||||
'prompt': decimal.Decimal('0'),
|
||||
'completion': decimal.Decimal('0'),
|
||||
'unit': decimal.Decimal('0'),
|
||||
'currency': 'USD'
|
||||
}
|
||||
'prompt': decimal.Decimal('0'),
|
||||
'completion': decimal.Decimal('0'),
|
||||
'unit': decimal.Decimal('0'),
|
||||
'currency': 'USD'
|
||||
}
|
||||
rules = self.model_provider.get_rules()
|
||||
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
|
||||
price_config = rules['price_config'][
|
||||
self.base_model_name] if 'price_config' in rules else default_price_config
|
||||
price_config = {
|
||||
'prompt': decimal.Decimal(price_config['prompt']),
|
||||
'completion': decimal.Decimal(price_config['completion']),
|
||||
@ -90,7 +98,7 @@ class BaseLLM(BaseProviderModel):
|
||||
'currency': price_config['currency']
|
||||
}
|
||||
return price_config
|
||||
|
||||
|
||||
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
|
||||
|
||||
logger.debug(f"model: {self.name} price_config: {self._price_config}")
|
||||
@ -158,7 +166,8 @@ class BaseLLM(BaseProviderModel):
|
||||
total_tokens = result.llm_output['token_usage']['total_tokens']
|
||||
else:
|
||||
prompt_tokens = self.get_num_tokens(messages)
|
||||
completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
|
||||
completion_tokens = self.get_num_tokens(
|
||||
[PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
self.model_provider.update_last_used()
|
||||
@ -293,6 +302,119 @@ class BaseLLM(BaseProviderModel):
|
||||
def support_streaming(cls):
|
||||
return False
|
||||
|
||||
def get_prompt(self, mode: str,
               pre_prompt: str, inputs: dict,
               query: str,
               context: Optional[str],
               memory: Optional[BaseChatMemory]) -> \
        Tuple[List[PromptMessage], Optional[List[str]]]:
    """Build the prompt message list and stop words for one LLM call.

    The rules file name is chosen by ``prompt_file_name(mode)``, the rules
    are loaded by ``_read_prompt_rules_from_file`` and the final text is
    produced by ``_get_prompt_and_stop``.

    :param mode: app mode; forwarded to ``prompt_file_name``
    :param pre_prompt: user-defined pre-prompt template
    :param inputs: values available to the pre-prompt template
    :param query: the user query
    :param context: optional context text to embed in the prompt
    :param memory: optional chat memory used to render the histories section
    :return: a one-element list holding the rendered prompt, and the stop
        words taken from the rules (or None)
    """
    rules_name = self.prompt_file_name(mode)
    rules = self._read_prompt_rules_from_file(rules_name)
    rendered, stop_words = self._get_prompt_and_stop(rules, pre_prompt, inputs, query, context, memory)
    return [PromptMessage(content=rendered)], stop_words
|
||||
|
||||
def prompt_file_name(self, mode: str) -> str:
    """Return the base name (no extension) of the prompt rules file for ``mode``.

    ``'completion'`` maps to ``'common_completion'``; every other mode maps
    to ``'common_chat'``.
    """
    return 'common_completion' if mode == 'completion' else 'common_chat'
|
||||
|
||||
def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                         query: str,
                         context: Optional[str],
                         memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
    """Render the final prompt text and the stop words from ``prompt_rules``.

    Sections are concatenated in the order given by
    ``prompt_rules['system_prompt_orders']`` and the rendered query prompt is
    appended last. When ``memory`` is given and the rules define a
    ``histories_prompt``, chat histories are loaded within a token budget and
    the prompt is rebuilt with the histories section included.

    :param prompt_rules: parsed prompt rules
    :param pre_prompt: user-defined pre-prompt template
    :param inputs: values available to the pre-prompt template
    :param query: the user query
    :param context: optional context text
    :param memory: optional chat memory
    :return: ``(prompt, stops)``; ``stops`` is None when the rules have no
        stop words or an empty list of them
    """
    # The context section needs both a context value and a template in the rules.
    rendered_context = ''
    if context and 'context_prompt' in prompt_rules:
        context_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
        rendered_context = context_template.format(context=context)

    # Only variables that the pre-prompt template declares AND the caller supplied are passed in.
    rendered_pre_prompt = ''
    if pre_prompt:
        pre_template = JinjaPromptTemplate.from_template(template=pre_prompt)
        known_inputs = {name: inputs[name] for name in pre_template.input_variables if name in inputs}
        rendered_pre_prompt = pre_template.format(**known_inputs)

    section_order = prompt_rules['system_prompt_orders']

    def assemble(pre_prompt_separator: str, rendered_histories: str = '') -> str:
        # Concatenate the rendered sections in the configured order; unknown names are ignored.
        pieces = []
        for section in section_order:
            if section == 'context_prompt':
                pieces.append(rendered_context)
            elif section == 'pre_prompt':
                pieces.append((rendered_pre_prompt + pre_prompt_separator) if rendered_pre_prompt else '')
            elif section == 'histories_prompt':
                pieces.append(rendered_histories)
        return ''.join(pieces)

    # First pass: no histories. This is the final prompt when memory is unused,
    # and the basis for the token estimate otherwise.
    # NOTE(review): this pass separates the pre-prompt with '\n\n' while the
    # histories pass below uses '\n'; kept as-is, confirm whether intended.
    prompt = assemble('\n\n')

    query_prompt = prompt_rules.get('query_prompt', '{{query}}')

    if memory and 'histories_prompt' in prompt_rules:
        # Measure the prompt without histories to learn how many tokens are left for them.
        probe_message = PromptBuilder.to_human_message(
            prompt_content=prompt + query_prompt,
            inputs={'query': query}
        )

        token_ceiling = self.model_rules.max_tokens.max
        if token_ceiling:
            used_tokens = self.get_num_tokens(to_prompt_messages([probe_message]))
            reserved_tokens = self.model_kwargs.max_tokens
            history_budget = max(token_ceiling - reserved_tokens - used_tokens, 0)
        else:
            # No model limit known: fall back to a fixed budget.
            history_budget = 2000

        memory.human_prefix = prompt_rules.get('human_prefix', 'Human')
        memory.ai_prefix = prompt_rules.get('assistant_prefix', 'Assistant')

        histories = self._get_history_messages_from_memory(memory, history_budget)
        histories_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
        rendered_histories = histories_template.format(histories=histories)

        # Second pass: rebuild from scratch, this time with the histories section.
        prompt = assemble('\n', rendered_histories)

    query_template = JinjaPromptTemplate.from_template(template=query_prompt)
    prompt += query_template.format(query=query)

    # Strip '<|...|>' style marker tokens from the assembled text.
    prompt = re.sub(r'<\|.*?\|>', '', prompt)

    stops = prompt_rules.get('stops')
    if stops is not None and len(stops) == 0:
        # An empty stop list is reported as "no stop words".
        stops = None

    return prompt, stops
|
||||
|
||||
def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
|
||||
# Get the absolute path of the subdirectory
|
||||
prompt_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
|
||||
'prompt/generate_prompts')
|
||||
|
||||
json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
|
||||
# Open the JSON file and read its content
|
||||
with open(json_file_path, 'r') as json_file:
|
||||
return json.load(json_file)
|
||||
|
||||
def _get_history_messages_from_memory(self, memory: BaseChatMemory,
                                      max_token_limit: int) -> str:
    """Load the chat histories from ``memory`` under a token limit.

    Sets ``memory.max_token_limit`` to ``max_token_limit``, then returns the
    value stored under the first memory variable after loading the memory
    variables with an empty input dict.
    """
    memory.max_token_limit = max_token_limit
    history_key = memory.memory_variables[0]
    return memory.load_memory_variables({})[history_key]
|
||||
|
||||
def _get_prompt_from_messages(self, messages: List[PromptMessage],
|
||||
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
|
||||
if not model_mode:
|
||||
|
||||
@ -1,16 +1,14 @@
|
||||
import decimal
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain import HuggingFaceHub
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.llms import HuggingFaceEndpoint
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
|
||||
|
||||
|
||||
class HuggingfaceHubModel(BaseLLM):
|
||||
@ -19,12 +17,12 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
|
||||
client = HuggingFaceEndpoint(
|
||||
client = HuggingFaceEndpointLLM(
|
||||
endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
|
||||
task='text2text-generation',
|
||||
task=self.credentials['task_type'],
|
||||
model_kwargs=provider_model_kwargs,
|
||||
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
|
||||
callbacks=self.callbacks,
|
||||
callbacks=self.callbacks
|
||||
)
|
||||
else:
|
||||
client = HuggingFaceHub(
|
||||
@ -62,6 +60,15 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.get_num_tokens(prompts)
|
||||
|
||||
def prompt_file_name(self, mode: str) -> str:
    """Pick the prompt rules file, using Baichuan-specific rules when applicable.

    A model whose name contains ``baichuan`` (case-insensitive) gets
    ``'baichuan_completion'`` in completion mode and ``'baichuan_chat'``
    otherwise; any other model defers to the parent implementation.
    """
    if 'baichuan' not in self.name.lower():
        return super().prompt_file_name(mode)
    return 'baichuan_completion' if mode == 'completion' else 'baichuan_chat'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
self.client.model_kwargs = provider_model_kwargs
|
||||
|
||||
@ -49,6 +49,15 @@ class OpenLLMModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def prompt_file_name(self, mode: str) -> str:
    """Pick the prompt rules file, using Baichuan-specific rules when applicable.

    A model whose name contains ``baichuan`` (case-insensitive) gets
    ``'baichuan_completion'`` in completion mode and ``'baichuan_chat'``
    otherwise; any other model defers to the parent implementation.
    """
    is_baichuan = 'baichuan' in self.name.lower()
    if not is_baichuan:
        return super().prompt_file_name(mode)
    if mode == 'completion':
        return 'baichuan_completion'
    return 'baichuan_chat'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
pass
|
||||
|
||||
|
||||
@ -59,6 +59,15 @@ class XinferenceModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def prompt_file_name(self, mode: str) -> str:
    """Pick the prompt rules file, using Baichuan-specific rules when applicable.

    A model whose name contains ``baichuan`` (case-insensitive) gets
    ``'baichuan_completion'`` in completion mode and ``'baichuan_chat'``
    otherwise; any other model defers to the parent implementation.
    """
    if 'baichuan' in self.name.lower():
        suffix = 'completion' if mode == 'completion' else 'chat'
        return 'baichuan_' + suffix
    return super().prompt_file_name(mode)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
pass
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@ import json
|
||||
from typing import Type
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from langchain.llms import HuggingFaceEndpoint
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||
@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=1500, default=200),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
if 'huggingfacehub_endpoint_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
|
||||
|
||||
if 'task_type' not in credentials:
|
||||
raise CredentialsValidateFailedError('Task Type must be provided.')
|
||||
|
||||
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
|
||||
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
|
||||
|
||||
try:
|
||||
llm = HuggingFaceEndpoint(
|
||||
llm = HuggingFaceEndpointLLM(
|
||||
endpoint_url=credentials['huggingfacehub_endpoint_url'],
|
||||
task="text2text-generation",
|
||||
task=credentials['task_type'],
|
||||
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
|
||||
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
|
||||
)
|
||||
@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
}
|
||||
|
||||
credentials = json.loads(provider_model.encrypted_config)
|
||||
|
||||
if 'task_type' not in credentials:
|
||||
credentials['task_type'] = 'text-generation'
|
||||
|
||||
if credentials['huggingfacehub_api_token']:
|
||||
credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
|
||||
@ -2,7 +2,6 @@ import json
|
||||
from typing import Type
|
||||
|
||||
import requests
|
||||
from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||
@ -73,7 +72,7 @@ class XinferenceProvider(BaseModelProvider):
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
)
|
||||
|
||||
|
||||
|
||||
13
api/core/prompt/generate_prompts/baichuan_chat.json
Normal file
13
api/core/prompt/generate_prompts/baichuan_chat.json
Normal file
@ -0,0 +1,13 @@
|
||||
{
|
||||
"human_prefix": "用户",
|
||||
"assistant_prefix": "助手",
|
||||
"context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n\n",
|
||||
"histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{histories}}\n```\n\n",
|
||||
"system_prompt_orders": [
|
||||
"context_prompt",
|
||||
"pre_prompt",
|
||||
"histories_prompt"
|
||||
],
|
||||
"query_prompt": "用户:{{query}}",
|
||||
"stops": ["用户:"]
|
||||
}
|
||||
@ -0,0 +1,9 @@
|
||||
{
|
||||
"context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n",
|
||||
"system_prompt_orders": [
|
||||
"context_prompt",
|
||||
"pre_prompt"
|
||||
],
|
||||
"query_prompt": "{{query}}",
|
||||
"stops": null
|
||||
}
|
||||
13
api/core/prompt/generate_prompts/common_chat.json
Normal file
13
api/core/prompt/generate_prompts/common_chat.json
Normal file
@ -0,0 +1,13 @@
|
||||
{
|
||||
"human_prefix": "Human",
|
||||
"assistant_prefix": "Assistant",
|
||||
"context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
|
||||
"histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{histories}}\n</histories>\n\n",
|
||||
"system_prompt_orders": [
|
||||
"context_prompt",
|
||||
"pre_prompt",
|
||||
"histories_prompt"
|
||||
],
|
||||
"query_prompt": "Human: {{query}}\n\nAssistant: ",
|
||||
"stops": ["\nHuman:", "</histories>"]
|
||||
}
|
||||
9
api/core/prompt/generate_prompts/common_completion.json
Normal file
9
api/core/prompt/generate_prompts/common_completion.json
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
|
||||
"system_prompt_orders": [
|
||||
"context_prompt",
|
||||
"pre_prompt"
|
||||
],
|
||||
"query_prompt": "{{query}}",
|
||||
"stops": null
|
||||
}
|
||||
39
api/core/third_party/langchain/llms/huggingface_endpoint_llm.py
vendored
Normal file
39
api/core/third_party/langchain/llms/huggingface_endpoint_llm.py
vendored
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import Dict
|
||||
|
||||
from langchain.llms import HuggingFaceEndpoint
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
    """HuggingFace Endpoint models.

    Subclass of langchain's ``HuggingFaceEndpoint`` that declares its own
    ``validate_environment`` root validator, which only resolves the API
    token (see below).

    To use, either set the environment variable ``HUGGINGFACEHUB_API_TOKEN``
    with your API token, or pass it as the named parameter
    ``huggingfacehub_api_token`` to the constructor.

    Only supports `text-generation` and `text2text-generation` for now.
    NOTE(review): callers may pass other task types (e.g. `summarization`);
    confirm whether the sentence above is still accurate.

    Example:
        .. code-block:: python

            from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
            endpoint_url = (
                "https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
            )
            hf = HuggingFaceEndpointLLM(
                endpoint_url=endpoint_url,
                huggingfacehub_api_token="my-api-key"
            )
    """

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API token and store it back into ``values``.

        The token is looked up via ``get_from_dict_or_env`` using the
        ``huggingfacehub_api_token`` key and the ``HUGGINGFACEHUB_API_TOKEN``
        environment variable. No other check is performed here.
        """
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )

        values["huggingfacehub_api_token"] = huggingfacehub_api_token
        return values
|
||||
@ -9,11 +9,11 @@ from xinference.client import RESTfulChatglmCppChatModelHandle, \
|
||||
|
||||
class XinferenceLLM(Xinference):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call the xinference model and return the output.
|
||||
|
||||
@ -56,10 +56,10 @@ class XinferenceLLM(Xinference):
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
@ -73,10 +73,10 @@ class XinferenceLLM(Xinference):
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
completion = combined_text_output
|
||||
@ -89,13 +89,13 @@ class XinferenceLLM(Xinference):
|
||||
|
||||
return completion
|
||||
|
||||
|
||||
def _stream_generate(
|
||||
self,
|
||||
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
|
||||
prompt: str,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
|
||||
self,
|
||||
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
|
||||
prompt: str,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
generate_config: Optional[
|
||||
Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Args:
|
||||
@ -108,12 +108,12 @@ class XinferenceLLM(Xinference):
|
||||
Yields:
|
||||
A string token.
|
||||
"""
|
||||
if isinstance(model, RESTfulGenerateModelHandle):
|
||||
streaming_response = model.generate(
|
||||
if isinstance(model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)):
|
||||
streaming_response = model.chat(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
else:
|
||||
streaming_response = model.chat(
|
||||
streaming_response = model.generate(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
|
||||
@ -123,7 +123,16 @@ class XinferenceLLM(Xinference):
|
||||
if choices:
|
||||
choice = choices[0]
|
||||
if isinstance(choice, dict):
|
||||
token = choice.get("text", "")
|
||||
if 'finish_reason' in choice and choice['finish_reason'] \
|
||||
and choice['finish_reason'] in ['stop', 'length']:
|
||||
break
|
||||
|
||||
if 'text' in choice:
|
||||
token = choice.get("text", "")
|
||||
elif 'delta' in choice and 'content' in choice['delta']:
|
||||
token = choice.get('delta').get('content')
|
||||
else:
|
||||
continue
|
||||
log_probs = choice.get("logprobs")
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
|
||||
@ -49,4 +49,5 @@ huggingface_hub~=0.16.4
|
||||
transformers~=4.31.0
|
||||
stripe~=5.5.0
|
||||
pandas==1.5.3
|
||||
xinference==0.2.0
|
||||
xinference==0.2.1
|
||||
safetensors==0.3.2
|
||||
@ -19,7 +19,7 @@ from models.dataset import Dataset, DocumentSegment, DatasetQuery
|
||||
class HitTestingService:
|
||||
@classmethod
|
||||
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
|
||||
if dataset.available_document_count == 0 or dataset.available_document_count == 0:
|
||||
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return {
|
||||
"query": {
|
||||
"content": query,
|
||||
|
||||
@ -17,7 +17,8 @@ HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = {
|
||||
INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL = {
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': 'valid_key',
|
||||
'huggingfacehub_endpoint_url': 'valid_url'
|
||||
'huggingfacehub_endpoint_url': 'valid_url',
|
||||
'task_type': 'text-generation'
|
||||
}
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
|
||||
@ -2,7 +2,7 @@ version: '3.1'
|
||||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:0.3.15
|
||||
image: langgenius/dify-api:0.3.18
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'api' starts the API server.
|
||||
@ -124,7 +124,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:0.3.15
|
||||
image: langgenius/dify-api:0.3.18
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'worker' starts the Celery worker for processing the queue.
|
||||
@ -176,7 +176,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:0.3.15
|
||||
image: langgenius/dify-web:0.3.18
|
||||
restart: always
|
||||
environment:
|
||||
EDITION: SELF_HOSTED
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
{
|
||||
"extends": [
|
||||
"@antfu",
|
||||
"plugin:react-hooks/recommended"
|
||||
"next",
|
||||
"@antfu"
|
||||
],
|
||||
"rules": {
|
||||
"@typescript-eslint/consistent-type-definitions": [
|
||||
|
||||
2
web/.gitignore
vendored
2
web/.gitignore
vendored
@ -15,6 +15,8 @@
|
||||
# production
|
||||
/build
|
||||
|
||||
/.history
|
||||
|
||||
# misc
|
||||
.DS_Store
|
||||
*.pem
|
||||
|
||||
@ -1,15 +1,14 @@
|
||||
'use client'
|
||||
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import useSWRInfinite from 'swr/infinite'
|
||||
import { debounce } from 'lodash-es'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import AppCard from './AppCard'
|
||||
import NewAppCard from './NewAppCard'
|
||||
import type { AppListResponse } from '@/models/app'
|
||||
import { fetchAppList } from '@/service/apps'
|
||||
import { useAppContext, useSelector } from '@/context/app-context'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
import Confirm from '@/app/components/base/confirm/common'
|
||||
@ -24,15 +23,18 @@ const Apps = () => {
|
||||
const { t } = useTranslation()
|
||||
const { isCurrentWorkspaceManager } = useAppContext()
|
||||
const { data, isLoading, setSize, mutate } = useSWRInfinite(getKey, fetchAppList, { revalidateFirstPage: false })
|
||||
const loadingStateRef = useRef(false)
|
||||
const pageContainerRef = useSelector(state => state.pageContainerRef)
|
||||
const anchorRef = useRef<HTMLAnchorElement>(null)
|
||||
const anchorRef = useRef<HTMLDivElement>(null)
|
||||
const searchParams = useSearchParams()
|
||||
const router = useRouter()
|
||||
const payProviderName = searchParams.get('provider_name')
|
||||
const payStatus = searchParams.get('payment_result')
|
||||
const [showPayStatusModal, setShowPayStatusModal] = useState(false)
|
||||
|
||||
const handleCancelShowPayStatusModal = useCallback(() => {
|
||||
setShowPayStatusModal(false)
|
||||
router.replace('/', { forceOptimisticNavigation: false })
|
||||
}, [router])
|
||||
|
||||
useEffect(() => {
|
||||
document.title = `${t('app.title')} - Dify`
|
||||
if (localStorage.getItem(NEED_REFRESH_APP_LIST_KEY) === '1') {
|
||||
@ -41,35 +43,24 @@ const Apps = () => {
|
||||
}
|
||||
if (payProviderName === ProviderEnum.anthropic && (payStatus === 'succeeded' || payStatus === 'cancelled'))
|
||||
setShowPayStatusModal(true)
|
||||
}, [])
|
||||
}, [mutate, payProviderName, payStatus, t])
|
||||
|
||||
useEffect(() => {
|
||||
loadingStateRef.current = isLoading
|
||||
}, [isLoading])
|
||||
|
||||
useEffect(() => {
|
||||
const onScroll = debounce(() => {
|
||||
if (!loadingStateRef.current) {
|
||||
const { scrollTop, clientHeight } = pageContainerRef.current!
|
||||
const anchorOffset = anchorRef.current!.offsetTop
|
||||
if (anchorOffset - scrollTop - clientHeight < 100)
|
||||
let observer: IntersectionObserver | undefined
|
||||
if (anchorRef.current) {
|
||||
observer = new IntersectionObserver((entries) => {
|
||||
if (entries[0].isIntersecting)
|
||||
setSize(size => size + 1)
|
||||
}
|
||||
}, 50)
|
||||
|
||||
pageContainerRef.current?.addEventListener('scroll', onScroll)
|
||||
return () => pageContainerRef.current?.removeEventListener('scroll', onScroll)
|
||||
}, [])
|
||||
|
||||
const handleCancelShowPayStatusModal = () => {
|
||||
setShowPayStatusModal(false)
|
||||
router.replace('/', { forceOptimisticNavigation: false })
|
||||
}
|
||||
}, { rootMargin: '100px' })
|
||||
observer.observe(anchorRef.current)
|
||||
}
|
||||
return () => observer?.disconnect()
|
||||
}, [isLoading, setSize, anchorRef, mutate])
|
||||
|
||||
return (
|
||||
<nav className='grid content-start grid-cols-1 gap-4 px-12 pt-8 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 grow shrink-0'>
|
||||
<><nav className='grid content-start grid-cols-1 gap-4 px-12 pt-8 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 grow shrink-0'>
|
||||
{ isCurrentWorkspaceManager
|
||||
&& <NewAppCard ref={anchorRef} onSuccess={mutate} />}
|
||||
&& <NewAppCard onSuccess={mutate} />}
|
||||
{data?.map(({ data: apps }) => apps.map(app => (
|
||||
<AppCard key={app.id} app={app} onRefresh={mutate} />
|
||||
)))}
|
||||
@ -95,6 +86,8 @@ const Apps = () => {
|
||||
)
|
||||
}
|
||||
</nav>
|
||||
<div ref={anchorRef} className='h-0'> </div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import React from "react";
|
||||
import React from 'react'
|
||||
import type { FC } from 'react'
|
||||
import GA, { GaType } from '@/app/components/base/ga'
|
||||
|
||||
@ -6,13 +6,11 @@ const Layout: FC<{
|
||||
children: React.ReactNode
|
||||
}> = ({ children }) => {
|
||||
return (
|
||||
<div className="overflow-x-auto">
|
||||
<div className="w-screen h-screen min-w-[300px]">
|
||||
<GA gaType={GaType.webapp} />
|
||||
{children}
|
||||
</div>
|
||||
<div className="min-w-[300px]">
|
||||
<GA gaType={GaType.webapp} />
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default Layout
|
||||
export default Layout
|
||||
|
||||
@ -38,6 +38,7 @@ const config: ProviderConfig = {
|
||||
defaultValue: {
|
||||
model_type: 'text-generation',
|
||||
huggingfacehub_api_type: 'hosted_inference_api',
|
||||
task_type: 'text-generation',
|
||||
},
|
||||
validateKeys: (v?: FormValue) => {
|
||||
if (v?.huggingfacehub_api_type === 'hosted_inference_api') {
|
||||
@ -51,10 +52,36 @@ const config: ProviderConfig = {
|
||||
'huggingfacehub_api_token',
|
||||
'model_name',
|
||||
'huggingfacehub_endpoint_url',
|
||||
'task_type',
|
||||
]
|
||||
}
|
||||
return []
|
||||
},
|
||||
filterValue: (v?: FormValue) => {
|
||||
let filteredKeys: string[] = []
|
||||
if (v?.huggingfacehub_api_type === 'hosted_inference_api') {
|
||||
filteredKeys = [
|
||||
'huggingfacehub_api_type',
|
||||
'huggingfacehub_api_token',
|
||||
'model_name',
|
||||
'model_type',
|
||||
]
|
||||
}
|
||||
if (v?.huggingfacehub_api_type === 'inference_endpoints') {
|
||||
filteredKeys = [
|
||||
'huggingfacehub_api_type',
|
||||
'huggingfacehub_api_token',
|
||||
'model_name',
|
||||
'huggingfacehub_endpoint_url',
|
||||
'task_type',
|
||||
'model_type',
|
||||
]
|
||||
}
|
||||
return filteredKeys.reduce((prev: FormValue, next: string) => {
|
||||
prev[next] = v?.[next] || ''
|
||||
return prev
|
||||
}, {})
|
||||
},
|
||||
fields: [
|
||||
{
|
||||
type: 'radio',
|
||||
@ -120,6 +147,32 @@ const config: ProviderConfig = {
|
||||
'zh-Hans': '在此输入您的端点 URL',
|
||||
},
|
||||
},
|
||||
{
|
||||
hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api',
|
||||
type: 'radio',
|
||||
key: 'task_type',
|
||||
required: true,
|
||||
label: {
|
||||
'en': 'Task',
|
||||
'zh-Hans': 'Task',
|
||||
},
|
||||
options: [
|
||||
{
|
||||
key: 'text2text-generation',
|
||||
label: {
|
||||
'en': 'Text-to-Text Generation',
|
||||
'zh-Hans': 'Text-to-Text Generation',
|
||||
},
|
||||
},
|
||||
{
|
||||
key: 'text-generation',
|
||||
label: {
|
||||
'en': 'Text Generation',
|
||||
'zh-Hans': 'Text Generation',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
@ -91,6 +91,7 @@ export type ProviderConfigModal = {
|
||||
icon: ReactElement
|
||||
defaultValue?: FormValue
|
||||
validateKeys?: string[] | ((v?: FormValue) => string[])
|
||||
filterValue?: (v?: FormValue) => FormValue
|
||||
fields: Field[]
|
||||
link: {
|
||||
href: string
|
||||
|
||||
@ -124,8 +124,9 @@ const ModelPage = () => {
|
||||
updateModelList(ModelType.embeddings)
|
||||
mutateProviders()
|
||||
}
|
||||
const handleSave = async (v?: FormValue) => {
|
||||
if (v && modelModalConfig) {
|
||||
const handleSave = async (originValue?: FormValue) => {
|
||||
if (originValue && modelModalConfig) {
|
||||
const v = modelModalConfig.filterValue ? modelModalConfig.filterValue(originValue) : originValue
|
||||
let body, url
|
||||
if (ConfigurableProviders.includes(modelModalConfig.key)) {
|
||||
const { model_name, model_type, ...config } = v
|
||||
|
||||
@ -68,7 +68,7 @@ const Form: FC<FormProps> = ({
|
||||
return true
|
||||
},
|
||||
run: () => {
|
||||
return validateModelProviderFn(modelModal!.key, v)
|
||||
return validateModelProviderFn(modelModal!.key, modelModal?.filterValue ? modelModal?.filterValue(v) : v)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@ -574,7 +574,7 @@ const Main: FC<IMainProps> = ({
|
||||
return <Loading type='app' />
|
||||
|
||||
return (
|
||||
<div className='bg-gray-100'>
|
||||
<div className='bg-gray-100 flex w-full h-full'>
|
||||
{!isInstalledApp && (
|
||||
<Header
|
||||
title={siteInfo.title}
|
||||
@ -588,7 +588,7 @@ const Main: FC<IMainProps> = ({
|
||||
|
||||
<div
|
||||
className={cn(
|
||||
'flex rounded-t-2xl bg-white overflow-hidden',
|
||||
'flex rounded-t-2xl bg-white overflow-hidden h-full w-full',
|
||||
isInstalledApp && 'rounded-b-2xl',
|
||||
)}
|
||||
style={isInstalledApp
|
||||
@ -611,7 +611,7 @@ const Main: FC<IMainProps> = ({
|
||||
)}
|
||||
{/* main */}
|
||||
<div className={cn(
|
||||
isInstalledApp ? s.installedApp : 'h-[calc(100vh_-_3rem)]',
|
||||
isInstalledApp ? s.installedApp : '',
|
||||
'flex-grow flex flex-col overflow-y-auto',
|
||||
)
|
||||
}>
|
||||
|
||||
@ -85,7 +85,7 @@ const Sidebar: FC<ISidebarProps> = ({
|
||||
<div
|
||||
className={
|
||||
cn(
|
||||
(isInstalledApp || isUniversalChat) ? 'tablet:h-[calc(100vh_-_74px)]' : 'tablet:h-[calc(100vh_-_3rem)]',
|
||||
(isInstalledApp || isUniversalChat) ? 'tablet:h-[calc(100vh_-_74px)]' : '',
|
||||
'shrink-0 flex flex-col bg-white pc:w-[244px] tablet:w-[192px] mobile:w-[240px] border-r border-gray-200 mobile:h-screen',
|
||||
)
|
||||
}
|
||||
|
||||
@ -510,7 +510,7 @@ const Main: FC<IMainProps> = ({
|
||||
|
||||
<div className={'flex bg-white overflow-hidden'}>
|
||||
<div className={cn(
|
||||
isInstalledApp ? s.installedApp : 'h-[calc(100vh_-_3rem)]',
|
||||
isInstalledApp ? s.installedApp : '',
|
||||
'flex-grow flex flex-col overflow-y-auto',
|
||||
)
|
||||
}>
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "dify-web",
|
||||
"version": "0.3.15",
|
||||
"version": "0.3.18",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "next dev",
|
||||
@ -94,7 +94,6 @@
|
||||
"@types/sortablejs": "^1.15.1",
|
||||
"eslint": "8.36.0",
|
||||
"eslint-config-next": "^13.4.7",
|
||||
"eslint-plugin-react-hooks": "^4.6.0",
|
||||
"husky": "^8.0.3",
|
||||
"lint-staged": "^13.2.2",
|
||||
"miragejs": "^0.1.47",
|
||||
|
||||
@ -1838,11 +1838,6 @@ eslint-plugin-react-hooks@5.0.0-canary-7118f5dd7-20230705:
|
||||
resolved "https://registry.yarnpkg.com/eslint-plugin-react-hooks/-/eslint-plugin-react-hooks-5.0.0-canary-7118f5dd7-20230705.tgz#4d55c50e186f1a2b0636433d2b0b2f592ddbccfd"
|
||||
integrity sha512-AZYbMo/NW9chdL7vk6HQzQhT+PvTAEVqWk9ziruUoW2kAOcN5qNyelv70e0F1VNQAbvutOC9oc+xfWycI9FxDw==
|
||||
|
||||
eslint-plugin-react-hooks@^4.6.0:
|
||||
version "4.6.0"
|
||||
resolved "https://registry.yarnpkg.com/eslint-plugin-react-hooks/-/eslint-plugin-react-hooks-4.6.0.tgz#4c3e697ad95b77e93f8646aaa1630c1ba607edd3"
|
||||
integrity sha512-oFc7Itz9Qxh2x4gNHStv3BqJq54ExXmfC+a1NjAta66IAN87Wu0R/QArgIS9qKzX3dXKPI9H5crl9QchNMY9+g==
|
||||
|
||||
eslint-plugin-react@^7.31.7:
|
||||
version "7.33.1"
|
||||
resolved "https://registry.yarnpkg.com/eslint-plugin-react/-/eslint-plugin-react-7.33.1.tgz#bc27cccf860ae45413a4a4150bf0977345c1ceab"
|
||||
@ -4510,7 +4505,7 @@ prismjs@~1.27.0:
|
||||
resolved "https://registry.yarnpkg.com/prismjs/-/prismjs-1.27.0.tgz#bb6ee3138a0b438a3653dd4d6ce0cc6510a45057"
|
||||
integrity sha512-t13BGPUlFDR7wRB5kQDG4jjl7XeuH6jbJGt11JHPL96qwsEHNX2+68tFXqc1/k+/jALsbSWJKUOT/hcYAZ5LkA==
|
||||
|
||||
prop-types@^15.0.0, prop-types@^15.8.1:
|
||||
prop-types@^15.0.0, prop-types@^15.5.8, prop-types@^15.8.1:
|
||||
version "15.8.1"
|
||||
resolved "https://registry.yarnpkg.com/prop-types/-/prop-types-15.8.1.tgz#67d87bf1a694f48435cf332c24af10214a3140b5"
|
||||
integrity sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==
|
||||
@ -4548,6 +4543,13 @@ queue-microtask@^1.2.2:
|
||||
resolved "https://registry.yarnpkg.com/queue-microtask/-/queue-microtask-1.2.3.tgz#4929228bbc724dfac43e0efb058caf7b6cfb6243"
|
||||
integrity sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==
|
||||
|
||||
react-18-input-autosize@^3.0.0:
|
||||
version "3.0.0"
|
||||
resolved "https://registry.yarnpkg.com/react-18-input-autosize/-/react-18-input-autosize-3.0.0.tgz#eb34ac8c8335c30f76a56a8902d31f1fc1b62c4c"
|
||||
integrity sha512-7tsUc9PJWg6Vsp8qYuzlKKBf7hbCoTBdNfjYZSprEPbxf3meuhjklg9QPBe9rIyoR3uDAzmG7NpoJ1+kP5ns+w==
|
||||
dependencies:
|
||||
prop-types "^15.5.8"
|
||||
|
||||
react-dom@^18.2.0:
|
||||
version "18.2.0"
|
||||
resolved "https://registry.yarnpkg.com/react-dom/-/react-dom-18.2.0.tgz#22aaf38708db2674ed9ada224ca4aa708d821e3d"
|
||||
|
||||
Reference in New Issue
Block a user