feat: service api add llm usage (#2051)

This commit is contained in:
takatost
2024-01-17 22:39:47 +08:00
committed by GitHub
parent 1d91535ba6
commit 1a6ad05a23
15 changed files with 152 additions and 187 deletions

View File

@ -5,16 +5,18 @@ from typing import Generator, Optional, Union, cast
from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentThoughtEvent, QueueErrorEvent,
QueueMessageEndEvent, QueueMessageEvent, QueueMessageReplaceEvent,
QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent)
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
PromptMessage, PromptMessageContentType, PromptMessageRole,
TextPromptMessageContent)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
@ -135,6 +137,8 @@ class GenerateTaskPipeline:
completion_tokens
)
self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
@ -145,12 +149,13 @@ class GenerateTaskPipeline:
)
# Save message
self._save_message(event.llm_result)
self._save_message(self._task_state.llm_result)
response = {
'event': 'message',
'task_id': self._application_generate_entity.task_id,
'id': self._message.id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'answer': event.llm_result.message.content,
'metadata': {},
@ -161,7 +166,7 @@ class GenerateTaskPipeline:
response['conversation_id'] = self._conversation.id
if self._task_state.metadata:
response['metadata'] = self._task_state.metadata
response['metadata'] = self._get_response_metadata()
return response
else:
@ -176,7 +181,9 @@ class GenerateTaskPipeline:
event = message.event
if isinstance(event, QueueErrorEvent):
raise self._handle_error(event)
data = self._error_to_stream_response_data(self._handle_error(event))
yield self._yield_response(data)
break
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
@ -213,6 +220,8 @@ class GenerateTaskPipeline:
completion_tokens
)
self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
@ -244,13 +253,14 @@ class GenerateTaskPipeline:
'event': 'message_end',
'task_id': self._application_generate_entity.task_id,
'id': self._message.id,
'message_id': self._message.id,
}
if self._conversation.mode == 'chat':
response['conversation_id'] = self._conversation.id
if self._task_state.metadata:
response['metadata'] = self._task_state.metadata
response['metadata'] = self._get_response_metadata()
yield self._yield_response(response)
elif isinstance(event, QueueRetrieverResourcesEvent):
@ -410,6 +420,86 @@ class GenerateTaskPipeline:
else:
return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
def _error_to_stream_response_data(self, e: Exception) -> dict:
    """
    Error to stream response.

    Translates an exception into the payload of an ``error`` stream event.
    The exception type selects ``code`` and ``status``; the result is tagged
    with the current task id and message id. Exception types that are not
    recognised are logged and reported as a generic internal server error,
    so their text is not included in the response.

    :param e: exception
    :return: dict with ``event``, ``task_id``, ``message_id``, ``code``,
        ``message`` and ``status`` keys
    """
    def describe(error: Exception):
        # Read ``description`` defensively and fall back to str(error), the
        # same fallback this file already applies when wrapping unknown
        # errors. Without it a missing/None description would surface as an
        # AttributeError or a null ``message`` in the response.
        # NOTE(review): whether description can actually be None depends on
        # the exception classes, which are defined outside this file.
        description = getattr(error, 'description', None)
        return description if description is not None else str(error)

    # Order matters: the first matching isinstance check wins.
    if isinstance(e, ValueError):
        data = {
            'code': 'invalid_param',
            'message': str(e),
            'status': 400
        }
    elif isinstance(e, ProviderTokenNotInitError):
        data = {
            'code': 'provider_not_initialize',
            'message': describe(e),
            'status': 400
        }
    elif isinstance(e, QuotaExceededError):
        # Fixed message: the exception's own text is intentionally not used.
        data = {
            'code': 'provider_quota_exceeded',
            'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
                       "Please go to Settings -> Model Provider to complete your own provider credentials.",
            'status': 400
        }
    elif isinstance(e, ModelCurrentlyNotSupportError):
        data = {
            'code': 'model_currently_not_support',
            'message': describe(e),
            'status': 400
        }
    elif isinstance(e, InvokeError):
        data = {
            'code': 'completion_request_error',
            'message': describe(e),
            'status': 400
        }
    else:
        # Unknown failure: keep the details in the server log only.
        logging.error(e)
        data = {
            'code': 'internal_server_error',
            'message': 'Internal Server Error, please contact support.',
            'status': 500
        }

    return {
        'event': 'error',
        'task_id': self._application_generate_entity.task_id,
        'message_id': self._message.id,
        **data
    }
def _get_response_metadata(self) -> dict:
    """
    Build the metadata returned to the caller, depending on the invoke source.

    Debugger and service-API invocations receive the stored retriever
    resources unmodified together with the stored usage entry. Any other
    invoke source receives retriever resources reduced to a fixed set of
    fields and no usage entry.

    :return: response metadata
    """
    stored = self._task_state.metadata
    full_detail_sources = [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]
    response_metadata = {}

    # show_retrieve_source
    if 'retriever_resources' in stored:
        if self._application_generate_entity.invoke_from in full_detail_sources:
            response_metadata['retriever_resources'] = stored['retriever_resources']
        else:
            # Only these fields are exposed for other invoke sources.
            exposed_fields = ('segment_id', 'position', 'document_name', 'score', 'content')
            response_metadata['retriever_resources'] = [
                {field: item[field] for field in exposed_fields}
                for item in stored['retriever_resources']
            ]

    # show usage
    if self._application_generate_entity.invoke_from in full_detail_sources:
        response_metadata['usage'] = stored['usage']

    return response_metadata
def _yield_response(self, response: dict) -> str:
"""
Yield response.

View File

@ -151,6 +151,8 @@ def jsonable_encoder(
return str(obj)
if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, Decimal):
return format(obj, 'f')
if isinstance(obj, dict):
encoded_dict = {}
allowed_keys = set(obj.keys())