Compare commits


19 Commits
0.4.3 ... 0.4.4

Author SHA1 Message Date
9f58912fd7 bump version to 0.4.4 (#1962) 2024-01-06 03:08:05 +08:00
0c746f5c5a fix: generate not stop when pressing stop link (#1961) 2024-01-06 03:03:56 +08:00
a8cedea15a fix: check result should be string. (#1959) 2024-01-05 22:11:51 +08:00
87832ede17 delete remnant 'required': false (#1955) 2024-01-05 19:18:33 +08:00
4d99c689f0 prohibit enable and disable function when segment is not completed (#1954) (Co-authored-by: jyong <jyong@dify.ai>, Joel <iamjoel007@gmail.com>) 2024-01-05 18:18:38 +08:00
28b26f67e2 optimize qa prompt (#1957) (Co-authored-by: jyong <jyong@dify.ai>) 2024-01-05 18:17:55 +08:00
b934232411 change API key field to 'required' (#1953) 2024-01-05 17:19:04 +08:00
2f120786fd feat: reorder togetherai (#1951) 2024-01-05 17:04:37 +08:00
6075fee556 Add Together.ai's OpenAI API-compatible inference endpoints (#1947) 2024-01-05 16:36:29 +08:00
de584807e1 fix streaming (#1944) 2024-01-05 01:03:54 -06:00
a1285cbf15 fix: text-generation run batch (#1945) 2024-01-05 14:47:00 +08:00
cf1f6f3961 fix: text completion app cannot get data. (#1942) 2024-01-05 14:29:01 +08:00
f4d97ef9fa fix: arg user required and must not be null in service generate api (#1943) 2024-01-05 14:28:03 +08:00
28883e80d4 fix: gpt-4-32k model name empty in OpenAI response (#1941) 2024-01-05 12:49:26 +08:00
a0f74cdd9d fix: llm result usage none (#1940) 2024-01-05 12:47:10 +08:00
296bf443a8 feat: reuse decoding_rsa_key & decoding_cipher_rsa & optimize construct (#1937) 2024-01-05 12:13:45 +08:00
af7be9bdd7 Feat/optimize entity construct (#1935) 2024-01-05 09:43:41 +08:00
2cfd5568e1 fix: vision fail in complete app (#1933) 2024-01-05 04:23:12 +08:00
faf40a42bc feat: optimize memory & invoke error output (#1931) 2024-01-05 03:47:46 +08:00
49 changed files with 547 additions and 257 deletions

View File

@ -88,7 +88,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.4.3"
self.CURRENT_VERSION = "0.4.4"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')

View File

@ -156,6 +156,9 @@ class DatasetDocumentSegmentApi(Resource):
if not segment:
raise NotFound('Segment not found.')
if segment.status != 'completed':
raise NotFound('Segment is not completed, enable or disable function is not allowed')
document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
cache_result = redis_client.get(document_indexing_cache_key)
if cache_result is not None:

View File

@ -31,7 +31,7 @@ class CompletionApi(AppApiResource):
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
@ -96,7 +96,7 @@ class ChatApi(AppApiResource):
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
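
The two 'user' changes above (#1943) make the field mandatory and non-null at the argument-parsing layer, so service API callers that omit it are rejected before any application logic runs. A minimal, self-contained sketch of how Flask-RESTful's reqparse behaves with required=True, nullable=False; the demo app and route below are illustrative, not Dify's actual service API:

from flask import Flask
from flask_restful import Api, Resource, reqparse

app = Flask(__name__)
api = Api(app)


class CompletionDemo(Resource):
    def post(self):
        parser = reqparse.RequestParser()
        # mirrors the change above: a missing or null 'user' aborts the request
        # with HTTP 400 before any application logic runs
        parser.add_argument('user', type=str, required=True, nullable=False, location='json')
        args = parser.parse_args()
        return {'user': args['user']}


api.add_resource(CompletionDemo, '/demo/completion')

if __name__ == '__main__':
    # POST {} to /demo/completion        -> 400, error names the missing 'user' field
    # POST {"user": "end-user-123"}      -> 200, {"user": "end-user-123"}
    app.run()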

View File

@ -1,7 +1,7 @@
import time
from typing import cast, Optional, List, Tuple, Generator, Union
from core.application_queue_manager import ApplicationQueueManager
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
@ -183,7 +183,7 @@ class AppRunner:
index=index,
message=AssistantPromptMessage(content=token)
)
))
), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)
@ -193,7 +193,8 @@ class AppRunner:
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage()
)
),
pub_from=PublishFrom.APPLICATION_MANAGER
)
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
@ -226,7 +227,8 @@ class AppRunner:
:return:
"""
queue_manager.publish_message_end(
llm_result=invoke_result
llm_result=invoke_result,
pub_from=PublishFrom.APPLICATION_MANAGER
)
def _handle_invoke_result_stream(self, invoke_result: Generator,
@ -242,7 +244,7 @@ class AppRunner:
text = ''
usage = None
for result in invoke_result:
queue_manager.publish_chunk_message(result)
queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER)
text += result.delta.message.content
@ -263,5 +265,6 @@ class AppRunner:
)
queue_manager.publish_message_end(
llm_result=llm_result
llm_result=llm_result,
pub_from=PublishFrom.APPLICATION_MANAGER
)

View File

@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
from core.application_queue_manager import ApplicationQueueManager
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.features.annotation_reply import AnnotationReplyFeature
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.external_data_fetch import ExternalDataFetchFeature
@ -121,7 +121,8 @@ class BasicApplicationRunner(AppRunner):
if annotation_reply:
queue_manager.publish_annotation_reply(
message_annotation_id=annotation_reply.id
message_annotation_id=annotation_reply.id,
pub_from=PublishFrom.APPLICATION_MANAGER
)
self.direct_output(
queue_manager=queue_manager,
@ -132,16 +133,16 @@ class BasicApplicationRunner(AppRunner):
)
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_orchestration_config.external_data_variables
if external_data_tools:
inputs = self.fill_in_inputs_from_external_data_tools(
tenant_id=app_record.tenant_id,
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
# fill in variable inputs from external data tools if exists
external_data_tools = app_orchestration_config.external_data_variables
if external_data_tools:
inputs = self.fill_in_inputs_from_external_data_tools(
tenant_id=app_record.tenant_id,
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
# get context from datasets
context = None

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel
from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
from core.entities.application_entities import ApplicationGenerateEntity
from core.application_queue_manager import ApplicationQueueManager
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
AnnotationReplyEvent
@ -312,8 +312,11 @@ class GenerateTaskPipeline:
index=0,
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
)
))
self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
), PublishFrom.TASK_PIPELINE)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
)
continue
else:
self._output_moderation_handler.append_new_token(delta_text)

View File

@ -6,6 +6,7 @@ from typing import Any, Optional, Dict
from flask import current_app, Flask
from pydantic import BaseModel
from core.application_queue_manager import PublishFrom
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory
@ -66,7 +67,7 @@ class OutputModerationHandler(BaseModel):
final_output = result.text
if public_event:
self.on_message_replace_func(final_output)
self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)
return final_output

View File

@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from core.provider_manager import ProviderManager
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, Conversation, Message, MessageFile, App
@ -169,15 +169,18 @@ class ApplicationManager:
except ConversationTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e)
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e)
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e)
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()

View File

@ -1,5 +1,6 @@
import queue
import time
from enum import Enum
from typing import Generator, Any
from sqlalchemy.orm import DeclarativeMeta
@ -13,6 +14,11 @@ from extensions.ext_redis import redis_client
from models.model import MessageAgentThought
class PublishFrom(Enum):
APPLICATION_MANAGER = 1
TASK_PIPELINE = 2
class ApplicationQueueManager:
def __init__(self, task_id: str,
user_id: str,
@ -61,11 +67,14 @@ class ApplicationQueueManager:
if elapsed_time >= listen_timeout or self._is_stopped():
# publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed
self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
self.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
PublishFrom.TASK_PIPELINE
)
self.stop_listen()
if elapsed_time // 10 > last_ping_time:
self.publish(QueuePingEvent())
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
last_ping_time = elapsed_time // 10
def stop_listen(self) -> None:
@ -75,76 +84,83 @@ class ApplicationQueueManager:
"""
self._q.put(None)
def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
"""
Publish chunk message to channel
:param chunk: chunk
:param pub_from: publish from
:return:
"""
self.publish(QueueMessageEvent(
chunk=chunk
))
), pub_from)
def publish_message_replace(self, text: str) -> None:
def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
"""
Publish message replace
:param text: text
:param pub_from: publish from
:return:
"""
self.publish(QueueMessageReplaceEvent(
text=text
))
), pub_from)
def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
"""
Publish retriever resources
:return:
"""
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)
def publish_annotation_reply(self, message_annotation_id: str) -> None:
def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
"""
Publish annotation reply
:param message_annotation_id: message annotation id
:param pub_from: publish from
:return:
"""
self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
def publish_message_end(self, llm_result: LLMResult) -> None:
def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
"""
Publish message end
:param llm_result: llm result
:param pub_from: publish from
:return:
"""
self.publish(QueueMessageEndEvent(llm_result=llm_result))
self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
self.stop_listen()
def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
"""
Publish agent thought
:param message_agent_thought: message agent thought
:param pub_from: publish from
:return:
"""
self.publish(QueueAgentThoughtEvent(
agent_thought_id=message_agent_thought.id
))
), pub_from)
def publish_error(self, e) -> None:
def publish_error(self, e, pub_from: PublishFrom) -> None:
"""
Publish error
:param e: error
:param pub_from: publish from
:return:
"""
self.publish(QueueErrorEvent(
error=e
))
), pub_from)
self.stop_listen()
def publish(self, event: AppQueueEvent) -> None:
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
self._check_for_sqlalchemy_models(event.dict())
@ -162,6 +178,9 @@ class ApplicationQueueManager:
if isinstance(event, QueueStopEvent):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise ConversationTaskStoppedException()
@classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
"""
@ -187,7 +206,6 @@ class ApplicationQueueManager:
stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
result = redis_client.get(stopped_cache_key)
if result is not None:
redis_client.delete(stopped_cache_key)
return True
return False
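
The new PublishFrom enum above is the core of the "generate not stop when pressing stop link" fix (#1961): every publish call now declares who is publishing, and only the application-manager side is interrupted once the stop flag is set, while the task pipeline can still flush its final stop/replace events. A simplified, self-contained sketch of that control flow; the class below is an illustrative stand-in, not the real queue manager:

import queue
from enum import Enum


class ConversationTaskStoppedException(Exception):
    pass


class PublishFrom(Enum):
    APPLICATION_MANAGER = 1
    TASK_PIPELINE = 2


class DemoQueueManager:
    def __init__(self):
        self._q = queue.Queue()
        self._stopped = False  # stands in for the redis stop-flag key

    def set_stop_flag(self):
        self._stopped = True

    def publish(self, event, pub_from: PublishFrom):
        self._q.put(event)
        # only the generation worker is interrupted; the task pipeline may still
        # publish its own events after the flag is set
        if pub_from == PublishFrom.APPLICATION_MANAGER and self._stopped:
            raise ConversationTaskStoppedException()


manager = DemoQueueManager()
manager.publish('chunk-1', PublishFrom.APPLICATION_MANAGER)   # delivered normally
manager.set_stop_flag()                                       # user pressed the stop link
manager.publish('stop-event', PublishFrom.TASK_PIPELINE)      # still allowed
try:
    manager.publish('chunk-2', PublishFrom.APPLICATION_MANAGER)
except ConversationTaskStoppedException:
    print('generation loop stops here')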

View File

@ -8,7 +8,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
from core.application_queue_manager import ApplicationQueueManager
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.entity.agent_loop import AgentLoop
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
@ -232,7 +232,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
db.session.add(message_agent_thought)
db.session.commit()
self.queue_manager.publish_agent_thought(message_agent_thought)
self.queue_manager.publish_agent_thought(message_agent_thought, PublishFrom.APPLICATION_MANAGER)
return message_agent_thought

View File

@ -2,7 +2,7 @@ from typing import List, Union
from langchain.schema import Document
from core.application_queue_manager import ApplicationQueueManager
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import InvokeFrom
from extensions.ext_database import db
from models.dataset import DocumentSegment, DatasetQuery
@ -80,4 +80,4 @@ class DatasetIndexToolCallbackHandler:
db.session.add(dataset_retriever_resource)
db.session.commit()
self._queue_manager.publish_retriever_resources(resource)
self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER)

View File

@ -520,7 +520,13 @@ class ProviderConfiguration(BaseModel):
provider_models.extend(
[
ModelWithProviderEntity(
**m.dict(),
model=m.model,
label=m.label,
model_type=m.model_type,
features=m.features,
fetch_from=m.fetch_from,
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
@ -569,7 +575,13 @@ class ProviderConfiguration(BaseModel):
for m in models:
provider_models.append(
ModelWithProviderEntity(
**m.dict(),
model=m.model,
label=m.label,
model_type=m.model_type,
features=m.features,
fetch_from=m.fetch_from,
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
)
@ -597,7 +609,13 @@ class ProviderConfiguration(BaseModel):
provider_models.append(
ModelWithProviderEntity(
**custom_model_schema.dict(),
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=custom_model_schema.fetch_from,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)

View File

@ -58,7 +58,7 @@ class ApiExternalDataTool(ExternalDataTool):
if not api_based_extension:
raise ValueError("[External data tool] API query failed, variable: {}, "
"error: api_based_extension_id is invalid"
.format(self.config.get('variable')))
.format(self.variable))
# decrypt api_key
api_key = encrypter.decrypt_token(
@ -74,7 +74,7 @@ class ApiExternalDataTool(ExternalDataTool):
)
except Exception as e:
raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
self.config.get('variable'),
self.variable,
e
))
@ -87,6 +87,10 @@ class ApiExternalDataTool(ExternalDataTool):
if 'result' not in response_json:
raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
.format(self.config.get('variable')))
.format(self.variable))
if not isinstance(response_json['result'], str):
raise ValueError("[External data tool] API query failed, variable: {}, error: result is not string"
.format(self.variable))
return response_json['result']

View File

@ -40,7 +40,7 @@ class ProviderCredentialsCache:
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 3600, json.dumps(credentials))
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None:
"""

View File

@ -8,6 +8,9 @@ class InvokeError(Exception):
def __init__(self, description: Optional[str] = None) -> None:
self.description = description
def __str__(self):
return self.description or self.__class__.__name__
class InvokeConnectionError(InvokeError):
"""Raised when the Invoke returns connection error."""

View File

@ -148,7 +148,9 @@ class AIModel(ABC):
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, 'r', encoding='utf-8') as f:
position_map = yaml.safe_load(f)
positions = yaml.safe_load(f)
# convert list to dict with key as model provider name, value as index
position_map = {position: index for index, position in enumerate(positions)}
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:

View File

@ -165,7 +165,7 @@ class LargeLanguageModel(AIModel):
model=real_model,
prompt_messages=prompt_messages,
message=prompt_message,
usage=usage,
usage=usage if usage else LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint
),
credentials=credentials,

View File

@ -112,7 +112,7 @@ class ModelProvider(ABC):
model_class = None
for name, obj in vars(mod).items():
if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
and obj != AIModel):
and obj != AIModel and obj.__module__ == mod.__name__):
model_class = obj
break
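
The extra obj.__module__ == mod.__name__ condition above limits class discovery to classes defined in the inspected provider module itself, so a subclass imported into that module from elsewhere is no longer picked up by mistake. A small, self-contained illustration of what the guard filters out; the class names and module strings are made up for the example:

class AIModel:
    """Stands in for the shared base class from core.model_runtime."""


class ProviderLLM(AIModel):
    """The class the provider module actually defines."""


class ForeignBase(AIModel):
    """Pretend this subclass was imported into the provider module from another package."""


ForeignBase.__module__ = 'core.model_runtime.some_other_module'


def find_model_class(module_vars: dict, module_name: str):
    for _name, obj in module_vars.items():
        if (isinstance(obj, type) and issubclass(obj, AIModel)
                and obj is not AIModel
                and obj.__module__ == module_name):  # the new guard from the diff above
            return obj
    return None


# without the __module__ check, iteration could return ForeignBase first;
# with it, only the locally defined ProviderLLM qualifies
namespace = {'ForeignBase': ForeignBase, 'ProviderLLM': ProviderLLM}
print(find_model_class(namespace, __name__))   # -> <class '__main__.ProviderLLM'>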

View File

@ -1,19 +1,20 @@
openai: 0
anthropic: 1
azure_openai: 2
google: 3
replicate: 4
huggingface_hub: 5
cohere: 6
zhipuai: 7
baichuan: 8
spark: 9
minimax: 10
tongyi: 11
wenxin: 12
jina: 13
chatglm: 14
xinference: 15
openllm: 16
localai: 17
openai_api_compatible: 18
- openai
- anthropic
- azure_openai
- google
- replicate
- huggingface_hub
- cohere
- togetherai
- zhipuai
- baichuan
- spark
- minimax
- tongyi
- wenxin
- jina
- chatglm
- xinference
- openllm
- localai
- openai_api_compatible
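
The _position.yaml files in this release (this one and the OpenAI model list further down) switch from explicit name-to-index pairs to a plain ordered list; the loaders patched in ai_model.py and model_provider_factory.py in this same diff rebuild the index map from list order. A minimal sketch of that conversion with a made-up list; the sort-with-default at the end is an illustration, not the exact repository code:

import yaml

position_yaml = """
- openai
- anthropic
- togetherai
- openai_api_compatible
"""

positions = yaml.safe_load(position_yaml)
# convert list to dict with key as provider name, value as index (same as the diff)
position_map = {name: index for index, name in enumerate(positions)}
print(position_map)   # {'openai': 0, 'anthropic': 1, 'togetherai': 2, 'openai_api_compatible': 3}

# unknown names can then be sorted to the end with a default position, e.g.:
providers = ['togetherai', 'some_new_provider', 'openai']
providers.sort(key=lambda name: position_map.get(name, len(position_map)))
print(providers)      # ['openai', 'togetherai', 'some_new_provider']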

View File

@ -309,7 +309,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
# transform response
response = LLMResult(
model=response.model,
model=response.model or model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,

View File

@ -217,7 +217,9 @@ class ModelProviderFactory:
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, 'r', encoding='utf-8') as f:
position_map = yaml.safe_load(f)
positions = yaml.safe_load(f)
# convert list to dict with key as model provider name, value as index
position_map = {position: index for index, position in enumerate(positions)}
# traverse all model_provider_dir_paths
for model_provider_dir_path in model_provider_dir_paths:

View File

@ -1,9 +1,11 @@
gpt-4: 0
gpt-4-32k: 1
gpt-4-1106-preview: 2
gpt-4-vision-preview: 3
gpt-3.5-turbo: 4
gpt-3.5-turbo-16k: 5
gpt-3.5-turbo-1106: 6
gpt-3.5-turbo-instruct: 7
text-davinci-003: 8
- gpt-4
- gpt-4-32k
- gpt-4-1106-preview
- gpt-4-vision-preview
- gpt-3.5-turbo
- gpt-3.5-turbo-16k
- gpt-3.5-turbo-16k-0613
- gpt-3.5-turbo-1106
- gpt-3.5-turbo-0613
- gpt-3.5-turbo-instruct
- text-davinci-003

View File

@ -40,87 +40,4 @@ class _CommonOAI_API_Compat:
requests.exceptions.ConnectTimeout, # Timeout
requests.exceptions.ReadTimeout # Timeout
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
model_type = ModelType.LLM if credentials.get('__model_type') == 'llm' else ModelType.TEXT_EMBEDDING
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=model_type,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size', 16000),
ModelPropertyKey.MAX_CHUNKS: credentials.get('max_chunks', 1),
},
parameter_rules=[
ParameterRule(
name=DefaultParameterName.TEMPERATURE.value,
label=I18nObject(en_US="Temperature"),
type=ParameterType.FLOAT,
default=float(credentials.get('temperature', 1)),
min=0,
max=2
),
ParameterRule(
name=DefaultParameterName.TOP_P.value,
label=I18nObject(en_US="Top P"),
type=ParameterType.FLOAT,
default=float(credentials.get('top_p', 1)),
min=0,
max=1
),
ParameterRule(
name="top_k",
label=I18nObject(en_US="Top K"),
type=ParameterType.INT,
default=int(credentials.get('top_k', 1)),
min=1,
max=100
),
ParameterRule(
name=DefaultParameterName.FREQUENCY_PENALTY.value,
label=I18nObject(en_US="Frequency Penalty"),
type=ParameterType.FLOAT,
default=float(credentials.get('frequency_penalty', 0)),
min=-2,
max=2
),
ParameterRule(
name=DefaultParameterName.PRESENCE_PENALTY.value,
label=I18nObject(en_US="PRESENCE Penalty"),
type=ParameterType.FLOAT,
default=float(credentials.get('PRESENCE_penalty', 0)),
min=-2,
max=2
),
ParameterRule(
name=DefaultParameterName.MAX_TOKENS.value,
label=I18nObject(en_US="Max Tokens"),
type=ParameterType.INT,
default=1024,
min=1,
max=int(credentials.get('max_tokens_to_sample', 4096)),
)
],
pricing=PriceConfig(
input=Decimal(credentials.get('input_price', 0)),
output=Decimal(credentials.get('output_price', 0)),
unit=Decimal(credentials.get('unit', 0)),
currency=credentials.get('currency', "USD")
)
)
if model_type == ModelType.LLM:
if credentials['mode'] == 'chat':
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
elif credentials['mode'] == 'completion':
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
return entity
}

View File

@ -158,7 +158,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
ModelPropertyKey.MODE: credentials.get('mode'),
},
parameter_rules=[
@ -196,9 +196,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
),
ParameterRule(
name=DefaultParameterName.PRESENCE_PENALTY.value,
label=I18nObject(en_US="PRESENCE Penalty"),
label=I18nObject(en_US="Presence Penalty"),
type=ParameterType.FLOAT,
default=float(credentials.get('PRESENCE_penalty', 0)),
default=float(credentials.get('presence_penalty', 0)),
min=-2,
max=2
),
@ -219,6 +219,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
)
)
if credentials['mode'] == 'chat':
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
elif credentials['mode'] == 'completion':
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
return entity
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
@ -261,7 +268,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, 'chat/completions')
data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
elif completion_type == LLMMode.COMPLETION:
elif completion_type is LLMMode.COMPLETION:
endpoint_url = urljoin(endpoint_url, 'completions')
data['prompt'] = prompt_messages[0].content
else:
@ -291,10 +298,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
stream=stream
)
# Debug: Print request headers and json data
logger.debug(f"Request headers: {headers}")
logger.debug(f"Request JSON data: {data}")
if response.status_code != 200:
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
@ -337,9 +340,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
)
)
for chunk in response.iter_content(chunk_size=2048):
for chunk in response.iter_lines(decode_unicode=True, delimiter='\n\n'):
if chunk:
decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
chunk_json = None
try:
@ -356,7 +359,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
continue
choice = chunk_json['choices'][0]
chunk_index = choice['index'] if 'index' in choice else chunk_index
chunk_index += 1
if 'delta' in choice:
delta = choice['delta']
@ -408,12 +411,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message=assistant_prompt_message,
)
)
else:
yield create_final_llm_result_chunk(
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="End of stream."
)
chunk_index += 1
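
The streaming change above (#1944) replaces iter_content(chunk_size=2048), which can cut a server-sent event in the middle of its JSON payload, with iter_lines(delimiter='\n\n'), which yields one complete 'data: ...' event at a time. A stripped-down sketch of that parsing loop against an OpenAI-compatible endpoint; the URL, payload, and [DONE] handling are illustrative assumptions rather than a copy of the repository code:

import json
import requests


def stream_chat(endpoint_url: str, api_key: str, payload: dict):
    response = requests.post(
        endpoint_url,
        headers={'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'},
        json={**payload, 'stream': True},
        stream=True,
    )
    response.raise_for_status()

    # SSE events are separated by a blank line, so split the stream on '\n\n'
    # instead of on arbitrary 2048-byte chunks that may cut a JSON object in half
    for chunk in response.iter_lines(decode_unicode=True, delimiter='\n\n'):
        if not chunk:
            continue
        decoded = chunk.strip().lstrip('data: ').lstrip()
        if decoded == '[DONE]':
            break
        try:
            event = json.loads(decoded)
        except json.JSONDecodeError:
            continue          # ignore keep-alives / non-JSON lines
        if not event.get('choices'):
            continue
        delta = event['choices'][0].get('delta', {})
        if delta.get('content'):
            yield delta['content']


# usage (illustrative):
# for token in stream_chat('https://api.example.com/v1/chat/completions', 'sk-...',
#                          {'model': 'some-model', 'messages': [{'role': 'user', 'content': 'Hi'}]}):
#     print(token, end='')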

View File

@ -2,8 +2,8 @@ provider: openai_api_compatible
label:
en_US: OpenAI-API-compatible
description:
en_US: All model providers compatible with OpenAI's API standard, such as Together.ai.
zh_Hans: 兼容 OpenAI API 的模型供应商,例如 Together.ai
en_US: Model providers compatible with OpenAI's API standard, such as LM Studio.
zh_Hans: 兼容 OpenAI API 的模型供应商,例如 LM Studio
supported_model_types:
- llm
- text-embedding

View File

@ -112,7 +112,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
credentials=credentials,
tokens=used_tokens
)
return TextEmbeddingResult(
embeddings=batched_embeddings,
usage=usage,

View File

@ -0,0 +1,13 @@
<svg width="114" height="24" viewBox="0 0 114 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M3.21688 7.55431H1V5.74708H3.21688V2.30127H5.19279V5.74708H8.30124V7.55431H5.19279V14.8074C5.19279 15.3214 5.28918 15.6909 5.48195 15.9158C5.69079 16.1246 6.0442 16.2291 6.5422 16.2291H8.68679V18.0363H6.42171C5.26507 18.0363 4.43776 17.7792 3.93977 17.2652C3.45784 16.7511 3.21688 15.9398 3.21688 14.8314V7.55431Z" fill="black"/>
<path d="M15.0554 18.1809C13.8667 18.1809 12.8064 17.9159 11.8747 17.3857C10.959 16.8556 10.2441 16.1166 9.73006 15.1689C9.21601 14.2211 8.95898 13.1287 8.95898 11.8918C8.95898 10.6548 9.21601 9.5624 9.73006 8.6146C10.2441 7.6668 10.959 6.92785 11.8747 6.39772C12.8064 5.8676 13.8667 5.60254 15.0554 5.60254C16.2442 5.60254 17.2964 5.8676 18.212 6.39772C19.1438 6.92785 19.8667 7.6668 20.3807 8.6146C20.8948 9.5624 21.1518 10.6548 21.1518 11.8918C21.1518 13.1287 20.8948 14.2211 20.3807 15.1689C19.8667 16.1166 19.1438 16.8556 18.212 17.3857C17.2964 17.9159 16.2442 18.1809 15.0554 18.1809ZM15.0554 16.4219C15.8586 16.4219 16.5654 16.2291 17.1759 15.8435C17.8023 15.458 18.2844 14.9199 18.6216 14.2291C18.959 13.5383 19.1277 12.7592 19.1277 11.8918C19.1277 11.0242 18.959 10.2451 18.6216 9.55437C18.2844 8.86359 17.8023 8.32545 17.1759 7.9399C16.5654 7.55436 15.8586 7.36159 15.0554 7.36159C14.2521 7.36159 13.5373 7.55436 12.9108 7.9399C12.3004 8.32545 11.8265 8.86359 11.4891 9.55437C11.1518 10.2451 10.9831 11.0242 10.9831 11.8918C10.9831 12.7592 11.1518 13.5383 11.4891 14.2291C11.8265 14.9199 12.3004 15.458 12.9108 15.8435C13.5373 16.2291 14.2521 16.4219 15.0554 16.4219Z" fill="black"/>
<path d="M34.6823 5.74712V17.4339C34.6823 21.1448 32.6503 23.0002 28.5859 23.0002C26.9956 23.0002 25.6944 22.6388 24.6823 21.9158C23.6863 21.193 23.108 20.1649 22.9474 18.8315H24.9715C25.1322 19.6025 25.5418 20.197 26.2004 20.6146C26.8591 21.0323 27.7024 21.2411 28.7305 21.2411C31.3811 21.2411 32.7065 19.948 32.7065 17.3617V15.9159C31.823 17.4259 30.4173 18.1809 28.4896 18.1809C27.349 18.1809 26.3289 17.9319 25.4293 17.4339C24.5458 16.9359 23.847 16.213 23.3329 15.2652C22.8349 14.3174 22.5859 13.193 22.5859 11.8918C22.5859 10.6548 22.8349 9.5624 23.3329 8.6146C23.847 7.6668 24.5538 6.92785 25.4534 6.39772C26.3531 5.8676 27.365 5.60254 28.4896 5.60254C29.4855 5.60254 30.337 5.80334 31.0438 6.20495C31.7507 6.5905 32.3049 7.14472 32.7065 7.86761L32.9715 5.74712H34.6823ZM28.6824 16.4219C29.4695 16.4219 30.1683 16.2371 30.7787 15.8677C31.4053 15.4821 31.8872 14.9519 32.2246 14.2772C32.5618 13.5865 32.7306 12.8074 32.7306 11.9399C32.7306 11.0564 32.5618 10.2692 32.2246 9.57846C31.8872 8.87163 31.4053 8.32545 30.7787 7.9399C30.1683 7.55436 29.4695 7.36159 28.6824 7.36159C27.4615 7.36159 26.4735 7.78729 25.7185 8.63869C24.9795 9.47404 24.61 10.5584 24.61 11.8918C24.61 13.2251 24.9795 14.3174 25.7185 15.1689C26.4735 16.0042 27.4615 16.4219 28.6824 16.4219Z" fill="black"/>
<path d="M36.5449 11.8918C36.5449 10.6387 36.7859 9.5383 37.2678 8.5905C37.7658 7.64271 38.4565 6.91179 39.3401 6.39772C40.2236 5.8676 41.2357 5.60254 42.3763 5.60254C43.5007 5.60254 44.4968 5.83547 45.3642 6.30133C46.2317 6.7672 46.9144 7.4419 47.4124 8.32545C47.9104 9.20898 48.1755 10.2451 48.2076 11.4339C48.2076 11.6106 48.1915 11.8918 48.1594 12.2772H38.6172V12.446C38.6493 13.6507 39.0187 14.6146 39.7256 15.3375C40.4324 16.0605 41.3562 16.4219 42.4967 16.4219C43.3802 16.4219 44.1272 16.205 44.7377 15.7712C45.3642 15.3215 45.7818 14.703 45.9908 13.9158H47.9907C47.7497 15.1689 47.1473 16.197 46.1834 17.0001C45.2196 17.7873 44.0389 18.1809 42.6412 18.1809C41.4204 18.1809 40.3521 17.9239 39.4365 17.4098C38.5208 16.8797 37.806 16.1408 37.2919 15.1929C36.7939 14.2291 36.5449 13.1287 36.5449 11.8918ZM46.1594 10.6387C46.063 9.59452 45.6694 8.78328 44.9787 8.20496C44.304 7.62664 43.4445 7.33749 42.4003 7.33749C41.4686 7.33749 40.6493 7.64271 39.9425 8.25315C39.2357 8.86359 38.8341 9.65878 38.7376 10.6387H46.1594Z" fill="black"/>
<path d="M50.7442 7.55431H48.5273V5.74708H50.7442V2.30127H52.7201V5.74708H55.8285V7.55431H52.7201V14.8074C52.7201 15.3214 52.8165 15.6909 53.0093 15.9158C53.2181 16.1246 53.5715 16.2291 54.0696 16.2291H56.2141V18.0363H53.9491C52.7924 18.0363 51.9651 17.7792 51.4671 17.2652C50.9851 16.7511 50.7442 15.9398 50.7442 14.8314V7.55431Z" fill="black"/>
<path d="M63.2468 5.6027C64.7408 5.6027 65.9456 6.0525 66.8613 6.95211C67.7769 7.8517 68.2348 9.26536 68.2348 11.1931V18.0365H66.2589V11.3136C66.2589 10.0445 65.9697 9.08062 65.3914 8.42199C64.8131 7.74729 63.9858 7.40994 62.9095 7.40994C61.7689 7.40994 60.8613 7.81154 60.1866 8.61476C59.5279 9.41798 59.1986 10.5103 59.1986 11.8919V18.0365H57.2227V1.16895H59.1986V7.77139C59.6002 7.12881 60.1303 6.60672 60.789 6.20511C61.4637 5.8035 62.283 5.6027 63.2468 5.6027Z" fill="black"/>
<path d="M69.9258 11.8918C69.9258 10.6387 70.1667 9.5383 70.6486 8.5905C71.1467 7.64271 71.8374 6.91179 72.721 6.39772C73.6045 5.8676 74.6165 5.60254 75.7571 5.60254C76.8816 5.60254 77.8776 5.83547 78.7451 6.30133C79.6126 6.7672 80.2953 7.4419 80.7933 8.32545C81.2912 9.20898 81.5563 10.2451 81.5885 11.4339C81.5885 11.6106 81.5723 11.8918 81.5403 12.2772H71.998V12.446C72.0302 13.6507 72.3996 14.6146 73.1064 15.3375C73.8133 16.0605 74.737 16.4219 75.8776 16.4219C76.7611 16.4219 77.5081 16.205 78.1186 15.7712C78.7451 15.3215 79.1627 14.703 79.3715 13.9158H81.3715C81.1306 15.1689 80.5282 16.197 79.5643 17.0001C78.6005 17.7873 77.4198 18.1809 76.0221 18.1809C74.8012 18.1809 73.733 17.9239 72.8173 17.4098C71.9017 16.8797 71.1868 16.1408 70.6728 15.1929C70.1747 14.2291 69.9258 13.1287 69.9258 11.8918ZM79.5403 10.6387C79.4438 9.59452 79.0502 8.78328 78.3595 8.20496C77.6848 7.62664 76.8254 7.33749 75.7811 7.33749C74.8495 7.33749 74.0302 7.64271 73.3234 8.25315C72.6165 8.86359 72.2149 9.65878 72.1185 10.6387H79.5403Z" fill="black"/>
<path d="M89.6864 5.74707V7.67478H88.6984C87.5257 7.67478 86.6823 8.06836 86.1682 8.85551C85.6703 9.64266 85.4212 10.6146 85.4212 11.7712V18.0363H83.4453V5.74707H85.1562L85.4212 7.6025C85.7746 7.04024 86.2325 6.59045 86.7947 6.25309C87.357 5.91575 88.1361 5.74707 89.1321 5.74707H89.6864Z" fill="black"/>
<path d="M109.812 16.2291V18.0364H108.726C107.939 18.0364 107.378 17.8757 107.04 17.5543C106.703 17.2331 106.526 16.7592 106.51 16.1327C105.562 17.4982 104.189 18.1809 102.39 18.1809C101.024 18.1809 99.9237 17.8596 99.0883 17.2171C98.269 16.5745 97.8594 15.6989 97.8594 14.5905C97.8594 13.3536 98.2771 12.4058 99.1124 11.7471C99.9637 11.0885 101.193 10.7592 102.799 10.7592H106.414V9.9158C106.414 9.11259 106.14 8.48608 105.594 8.03628C105.064 7.58648 104.317 7.36159 103.353 7.36159C102.502 7.36159 101.795 7.55436 101.233 7.9399C100.687 8.30937 100.349 8.80737 100.221 9.43388H98.2449C98.3894 8.22906 98.9196 7.28929 99.8353 6.61459C100.767 5.93989 101.972 5.60254 103.45 5.60254C105.024 5.60254 106.237 5.98808 107.088 6.75917C107.955 7.5142 108.39 8.60657 108.39 10.0363V15.3375C108.39 15.9319 108.662 16.2291 109.209 16.2291H109.812ZM106.414 12.4218H102.606C100.775 12.4218 99.8594 13.1045 99.8594 14.47C99.8594 15.0805 100.1 15.5704 100.582 15.9399C101.064 16.3094 101.715 16.4942 102.534 16.4942C103.739 16.4942 104.687 16.1809 105.377 15.5544C106.068 14.9118 106.414 14.0684 106.414 13.0242V12.4218Z" fill="black"/>
<path d="M111.922 1C112.291 1 112.597 1.12048 112.837 1.36145C113.079 1.60241 113.199 1.90763 113.199 2.27711C113.199 2.64659 113.079 2.95182 112.837 3.19278C112.597 3.43374 112.291 3.55423 111.922 3.55423C111.552 3.55423 111.247 3.43374 111.007 3.19278C110.765 2.95182 110.645 2.64659 110.645 2.27711C110.645 1.90763 110.765 1.60241 111.007 1.36145C111.247 1.12048 111.552 1 111.922 1ZM110.934 5.74701H112.91V18.0362H110.934V5.74701Z" fill="black"/>
<path d="M93.9949 16.1652C93.9949 17.1986 93.1469 18.0364 92.1009 18.0364C91.055 18.0364 90.207 17.1986 90.207 16.1652C90.207 15.1317 91.055 14.2939 92.1009 14.2939C93.1469 14.2939 93.9949 15.1317 93.9949 16.1652Z" fill="#0F6FFF"/>
</svg>


View File

@ -0,0 +1,19 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_15960_46917)">
<mask id="mask0_15960_46917" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="0" width="16" height="16">
<path d="M16 0H0V16H16V0Z" fill="white"/>
</mask>
<g mask="url(#mask0_15960_46917)">
<path d="M13.1765 0H2.82353C1.26414 0 0 1.26414 0 2.82353V13.1765C0 14.7359 1.26414 16 2.82353 16H13.1765C14.7359 16 16 14.7359 16 13.1765V2.82353C16 1.26414 14.7359 0 13.1765 0Z" fill="#F1EFED"/>
<path d="M11.4119 7.64706C12.9713 7.64706 14.2354 6.38292 14.2354 4.82353C14.2354 3.26414 12.9713 2 11.4119 2C9.85252 2 8.58838 3.26414 8.58838 4.82353C8.58838 6.38292 9.85252 7.64706 11.4119 7.64706Z" fill="#D3D1D1"/>
<path d="M11.4119 14.2354C12.9713 14.2354 14.2354 12.9713 14.2354 11.4119C14.2354 9.85252 12.9713 8.58838 11.4119 8.58838C9.85252 8.58838 8.58838 9.85252 8.58838 11.4119C8.58838 12.9713 9.85252 14.2354 11.4119 14.2354Z" fill="#D3D1D1"/>
<path d="M4.82353 14.2354C6.38292 14.2354 7.64706 12.9713 7.64706 11.4119C7.64706 9.85252 6.38292 8.58838 4.82353 8.58838C3.26414 8.58838 2 9.85252 2 11.4119C2 12.9713 3.26414 14.2354 4.82353 14.2354Z" fill="#D3D1D1"/>
<path d="M4.82353 7.64706C6.38292 7.64706 7.64706 6.38292 7.64706 4.82353C7.64706 3.26414 6.38292 2 4.82353 2C3.26414 2 2 3.26414 2 4.82353C2 6.38292 3.26414 7.64706 4.82353 7.64706Z" fill="#0F6FFF"/>
</g>
</g>
<defs>
<clipPath id="clip0_15960_46917">
<rect width="16" height="16" fill="white"/>
</clipPath>
</defs>
</svg>


View File

@ -0,0 +1,45 @@
from typing import Generator, List, Optional, Union
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _update_endpoint_url(self, credentials: dict):
credentials['endpoint_url'] = "https://api.together.xyz/v1"
return credentials
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
return super().validate_credentials(model, cred_with_endpoint)
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
return super().get_customizable_model_schema(model, cred_with_endpoint)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools)

View File

@ -0,0 +1,13 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class TogetherAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -0,0 +1,75 @@
provider: togetherai
label:
en_US: together.ai
icon_small:
en_US: togetherai_square.svg
icon_large:
en_US: togetherai.svg
background: "#F1EFED"
help:
title:
en_US: Get your API key from together.ai
zh_Hans: 从 together.ai 获取 API Key
url:
en_US: https://api.together.xyz/
supported_model_types:
- llm
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter full model name
zh_Hans: 输入模型全称
credential_form_schemas:
- variable: api_key
required: true
label:
en_US: API Key
type: secret-input
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
type: select
required: false
default: chat
placeholder:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens_to_sample
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
show_on:
- variable: __model_type
value: llm
default: '4096'
type: text-input

View File

@ -334,7 +334,18 @@ class PromptTransform:
prompt = re.sub(r'<\|.*?\|>', '', prompt)
return [UserPromptMessage(content=prompt)]
model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.CHAT and files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_message = UserPromptMessage(content=prompt_message_contents)
else:
prompt_message = UserPromptMessage(content=prompt)
return [prompt_message]
def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
if '#context#' in prompt_template.variable_keys:
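
The hunk above is the "vision fail in complete app" fix (#1933): when the model runs in chat mode and files are attached, the prompt text and the files are packed into one multi-part user message instead of a plain string. A simplified sketch of that branching with stand-in entity classes rather than Dify's real prompt message types:

from dataclasses import dataclass
from typing import List, Union


@dataclass
class TextContent:
    data: str


@dataclass
class ImageContent:
    url: str


@dataclass
class UserMessage:
    content: Union[str, List[object]]


def build_user_message(prompt: str, files: List[ImageContent], chat_mode: bool) -> UserMessage:
    if chat_mode and files:
        # multi-part content: the text first, then one part per attached file
        parts: List[object] = [TextContent(data=prompt)]
        parts.extend(files)
        return UserMessage(content=parts)
    # no files (or completion mode): keep the plain string form
    return UserMessage(content=prompt)


msg = build_user_message('Describe this picture.',
                         [ImageContent(url='https://example.com/cat.png')],
                         chat_mode=True)
print(msg)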

View File

@ -75,7 +75,7 @@ GENERATOR_QA_PROMPT = (
'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
'Step 4: Generate 20 questions and answers based on these key information and concepts.'
'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
"Answer according to the the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
"Answer MUST according to the the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
)
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \

View File

@ -24,6 +24,9 @@ class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
"""
def __init__(self) -> None:
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@ -229,11 +232,18 @@ class ProviderManager:
return None
provider_instance = model_provider_factory.get_provider_instance(default_model.provider_name)
provider_schema = provider_instance.get_provider_schema()
return DefaultModelEntity(
model=default_model.model_name,
model_type=model_type,
provider=DefaultModelProviderEntity(**provider_instance.get_provider_schema().to_simple_provider().dict())
provider=DefaultModelProviderEntity(
provider=provider_schema.provider,
label=provider_schema.label,
icon_small=provider_schema.icon_small,
icon_large=provider_schema.icon_large,
supported_model_types=provider_schema.supported_model_types
)
)
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
@ -465,15 +475,16 @@ class ProviderManager:
provider_credentials = {}
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
self.decoding_rsa_key,
self.decoding_cipher_rsa
)
except ValueError:
pass
@ -517,15 +528,16 @@ class ProviderManager:
continue
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in model_credential_secret_variables:
if variable in provider_model_credentials:
try:
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
self.decoding_rsa_key,
self.decoding_cipher_rsa
)
except ValueError:
pass
@ -634,15 +646,16 @@ class ProviderManager:
)
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
self.decoding_rsa_key,
self.decoding_cipher_rsa
)
except ValueError:
pass
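
The change above (#1937, "reuse decoding_rsa_key & decoding_cipher_rsa") stops re-deriving the tenant's RSA decoding key and OAEP cipher for every encrypted credential by caching both on the ProviderManager instance after the first lookup; the old module-level LRU cache in the encrypter is dropped in the next file. A tiny sketch of that lazy per-instance caching pattern; the key-loading function is a stand-in for encrypter.get_decrypt_decoding:

class DemoProviderManager:
    def __init__(self) -> None:
        # same idea as the diff: start empty, fill on first use, reuse afterwards
        self.decoding_rsa_key = None
        self.decoding_cipher_rsa = None

    def _load_decoding_material(self, tenant_id: str):
        # stand-in for reading the tenant's private key from storage and building a cipher
        print(f'expensive load for tenant {tenant_id}')
        return f'rsa-key-{tenant_id}', f'cipher-{tenant_id}'

    def decrypt(self, tenant_id: str, token: str) -> str:
        if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
            self.decoding_rsa_key, self.decoding_cipher_rsa = self._load_decoding_material(tenant_id)
        return f'decrypted({token}) with {self.decoding_cipher_rsa}'


pm = DemoProviderManager()
pm.decrypt('tenant-a', 'tok1')   # loads once
pm.decrypt('tenant-a', 'tok2')   # reuses the cached key/cipher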

View File

@ -5,7 +5,6 @@ from Crypto.Cipher import PKCS1_OAEP, AES
from Crypto.PublicKey import RSA
from Crypto.Random import get_random_bytes
from core.helper.lru_cache import LRUCache
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -46,15 +45,7 @@ def encrypt(text, public_key):
return prefix_hybrid + encrypted_data
tenant_rsa_keys = LRUCache(capacity=1000)
def get_decrypt_decoding(tenant_id):
rsa_key = tenant_rsa_keys.get(tenant_id)
if rsa_key:
cipher_rsa = PKCS1_OAEP.new(rsa_key)
return rsa_key, cipher_rsa
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
@ -70,8 +61,6 @@ def get_decrypt_decoding(tenant_id):
rsa_key = RSA.import_key(private_key)
cipher_rsa = PKCS1_OAEP.new(rsa_key)
tenant_rsa_keys.put(tenant_id, rsa_key)
return rsa_key, cipher_rsa

View File

@ -14,7 +14,7 @@ from core.provider_manager import ProviderManager
from models.provider import ProviderType
from services.entities.model_provider_entities import ProviderResponse, CustomConfigurationResponse, \
SystemConfigurationResponse, CustomConfigurationStatus, ProviderWithModelsResponse, ModelResponse, \
DefaultModelResponse, ModelWithProviderEntityResponse
DefaultModelResponse, ModelWithProviderEntityResponse, SimpleProviderEntityResponse
logger = logging.getLogger(__name__)
@ -45,7 +45,17 @@ class ModelProviderService:
continue
provider_response = ProviderResponse(
**provider_configuration.provider.dict(),
provider=provider_configuration.provider.provider,
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
configurate_methods=provider_configuration.provider.configurate_methods,
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
model_credential_schema=provider_configuration.provider.model_credential_schema,
preferred_provider_type=provider_configuration.preferred_provider_type,
custom_configuration=CustomConfigurationResponse(
status=CustomConfigurationStatus.ACTIVE
@ -53,7 +63,9 @@ class ModelProviderService:
else CustomConfigurationStatus.NO_CONFIGURE
),
system_configuration=SystemConfigurationResponse(
**provider_configuration.system_configuration.dict()
enabled=provider_configuration.system_configuration.enabled,
current_quota_type=provider_configuration.system_configuration.current_quota_type,
quota_configurations=provider_configuration.system_configuration.quota_configurations
)
)
@ -369,7 +381,15 @@ class ModelProviderService:
)
return DefaultModelResponse(
**result.dict()
model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types
)
) if result else None
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:

View File

@ -27,7 +27,7 @@ def disable_segment_from_index_task(segment_id: str):
raise NotFound('Segment not found')
if segment.status != 'completed':
return
raise NotFound('Segment is not completed , disable action is not allowed.')
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

View File

@ -29,7 +29,7 @@ def enable_segment_to_index_task(segment_id: str):
raise NotFound('Segment not found')
if segment.status != 'completed':
return
raise NotFound('Segment is not completed, enable action is not allowed.')
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

View File

@ -39,13 +39,15 @@ def test_invoke_model(setup_openai_mock):
},
texts=[
"hello",
"world"
"world",
" ".join(["long_text"] * 100),
" ".join(["another_long_text"] * 100)
],
user="abc-123"
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert len(result.embeddings) == 4
assert result.usage.total_tokens == 2

View File

@ -22,7 +22,7 @@ def test_validate_credentials():
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': 'invalid_key',
'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
'endpoint_url': 'https://api.together.xyz/v1/',
'mode': 'chat'
}
)
@ -31,7 +31,7 @@ def test_validate_credentials():
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
'endpoint_url': 'https://api.together.xyz/v1/',
'mode': 'chat'
}
)
@ -43,7 +43,7 @@ def test_invoke_model():
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'endpoint_url': 'https://api.together.xyz/v1/completions',
'endpoint_url': 'https://api.together.xyz/v1/',
'mode': 'completion'
},
prompt_messages=[
@ -74,7 +74,7 @@ def test_invoke_stream_model():
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
'endpoint_url': 'https://api.together.xyz/v1/',
'mode': 'chat'
},
prompt_messages=[
@ -110,7 +110,7 @@ def test_invoke_chat_model_with_tools():
model='gpt-3.5-turbo',
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/chat/completions',
'endpoint_url': 'https://api.openai.com/v1/',
'mode': 'chat'
},
prompt_messages=[
@ -165,7 +165,7 @@ def test_get_num_tokens():
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/chat/completions'
'endpoint_url': 'https://api.openai.com/v1/'
},
prompt_messages=[
SystemPromptMessage(

View File

@ -18,9 +18,8 @@ def test_validate_credentials():
model='text-embedding-ada-002',
credentials={
'api_key': 'invalid_key',
'endpoint_url': 'https://api.openai.com/v1/embeddings',
'context_size': 8184,
'max_chunks': 32
'endpoint_url': 'https://api.openai.com/v1/',
'context_size': 8184
}
)
@ -29,9 +28,8 @@ def test_validate_credentials():
model='text-embedding-ada-002',
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/embeddings',
'context_size': 8184,
'max_chunks': 32
'endpoint_url': 'https://api.openai.com/v1/',
'context_size': 8184
}
)
@ -43,20 +41,21 @@ def test_invoke_model():
model='text-embedding-ada-002',
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/embeddings',
'context_size': 8184,
'max_chunks': 32
'endpoint_url': 'https://api.openai.com/v1/',
'context_size': 8184
},
texts=[
"hello",
"world"
"world",
" ".join(["long_text"] * 100),
" ".join(["another_long_text"] * 100)
],
user="abc-123"
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 2
assert len(result.embeddings) == 4
assert result.usage.total_tokens == 502
def test_get_num_tokens():
@ -67,8 +66,7 @@ def test_get_num_tokens():
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/embeddings',
'context_size': 8184,
'max_chunks': 32
'context_size': 8184
},
texts=[
"hello",

View File

@ -0,0 +1,117 @@
import os
from typing import Generator
import pytest
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \
SystemPromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
LLMResultChunk
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.togetherai.llm.llm import TogetherAILargeLanguageModel
def test_validate_credentials():
model = TogetherAILargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': 'invalid_key',
'mode': 'chat'
}
)
model.validate_credentials(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'chat'
}
)
def test_invoke_model():
model = TogetherAILargeLanguageModel()
response = model.invoke(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'completion'
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
},
stop=['How'],
stream=False,
user="abc-123"
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
def test_invoke_stream_model():
model = TogetherAILargeLanguageModel()
response = model.invoke(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'chat'
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
},
stop=['How'],
stream=True,
user="abc-123"
)
assert isinstance(response, Generator)
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
def test_get_num_tokens():
model = TogetherAILargeLanguageModel()
num_tokens = model.get_num_tokens(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Hello World!'
)
]
)
assert isinstance(num_tokens, int)
assert num_tokens == 21

View File

@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
image: langgenius/dify-api:0.4.3
image: langgenius/dify-api:0.4.4
restart: always
environment:
# Startup mode, 'api' starts the API server.
@ -130,7 +130,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.4.3
image: langgenius/dify-api:0.4.4
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@ -200,7 +200,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.4.3
image: langgenius/dify-web:0.4.4
restart: always
environment:
EDITION: SELF_HOSTED

View File

@ -128,7 +128,7 @@ const SegmentCard: FC<ISegmentCardProps> = ({
>
<Switch
size='md'
disabled={archived}
disabled={archived || detail.status !== 'completed'}
defaultValue={enabled}
onChange={async (val) => {
await onChangeSwitch?.(id, val)

View File

@ -279,7 +279,7 @@ const TextGeneration: FC<IMainProps> = ({
}
})
setAllTaskList(allTaskList)
setCurrGroupNum(0)
setControlSend(Date.now())
// clear run once task status
setControlStopResponding(Date.now())
@ -295,10 +295,7 @@ const TextGeneration: FC<IMainProps> = ({
// avoid add many task at the same time
if (needToAddNextGroupTask)
setCurrGroupNum(hadRunedTaskNum)
// console.group()
// console.log(`[#${taskId}]: ${isSuccess ? 'success' : 'fail'}.currGroupNum: ${getCurrGroupNum()}.hadRunedTaskNum: ${hadRunedTaskNum}, needToAddNextGroupTask: ${needToAddNextGroupTask}`)
// console.log([...allTasklistLatest.filter(task => [TaskStatus.completed, TaskStatus.failed].includes(task.status)).map(item => item.id), taskId].sort((a: any, b: any) => a - b).join(','))
// console.groupEnd()
const nextPendingTaskIds = needToAddNextGroupTask ? pendingTaskList.slice(0, GROUP_SIZE).map(item => item.id) : []
const newAllTaskList = allTasklistLatest.map((item) => {
if (item.id === taskId) {

View File

@ -1,6 +1,6 @@
{
"name": "dify-web",
"version": "0.4.3",
"version": "0.4.4",
"private": true,
"scripts": {
"dev": "next dev",