mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
Merge remote-tracking branch 'upstream/main' into feat/rag-2
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from opentelemetry.trace import Link, Status, StatusCode
|
||||
@ -120,7 +119,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
@ -353,8 +352,8 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
TOOL_NAME: node_execution.title,
|
||||
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
|
||||
TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
|
||||
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
|
||||
@ -7,7 +7,6 @@ import uuid
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from opentelemetry import trace as trace_api
|
||||
@ -184,7 +183,7 @@ def generate_span_id() -> int:
|
||||
return span_id
|
||||
|
||||
|
||||
def convert_to_trace_id(uuid_v4: Optional[str]) -> int:
|
||||
def convert_to_trace_id(uuid_v4: str | None) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
return uuid_obj.int
|
||||
@ -192,7 +191,7 @@ def convert_to_trace_id(uuid_v4: Optional[str]) -> int:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
|
||||
|
||||
def convert_string_to_id(string: Optional[str]) -> int:
|
||||
def convert_string_to_id(string: str | None) -> int:
|
||||
if not string:
|
||||
return generate_span_id()
|
||||
hash_bytes = hashlib.sha256(string.encode("utf-8")).digest()
|
||||
@ -200,7 +199,7 @@ def convert_string_to_id(string: Optional[str]) -> int:
|
||||
return id
|
||||
|
||||
|
||||
def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int:
|
||||
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
except Exception as e:
|
||||
@ -209,7 +208,7 @@ def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int:
|
||||
return convert_string_to_id(combined_key)
|
||||
|
||||
|
||||
def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]:
|
||||
def convert_datetime_to_nanoseconds(start_time_a: datetime | None) -> int | None:
|
||||
if start_time_a is None:
|
||||
return None
|
||||
timestamp_in_seconds = start_time_a.timestamp()
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event, Status, StatusCode
|
||||
@ -10,12 +9,12 @@ class SpanData(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
trace_id: int = Field(..., description="The unique identifier for the trace.")
|
||||
parent_span_id: Optional[int] = Field(None, description="The ID of the parent span, if any.")
|
||||
parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
|
||||
span_id: int = Field(..., description="The unique identifier for this span.")
|
||||
name: str = Field(..., description="The name of the span.")
|
||||
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
|
||||
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
|
||||
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
|
||||
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||
start_time: Optional[int] = Field(..., description="The start time of the span in nanoseconds.")
|
||||
end_time: Optional[int] = Field(..., description="The end time of the span in nanoseconds.")
|
||||
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
|
||||
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
# public
|
||||
GEN_AI_SESSION_ID = "gen_ai.session.id"
|
||||
@ -53,7 +53,7 @@ TOOL_DESCRIPTION = "tool.description"
|
||||
TOOL_PARAMETERS = "tool.parameters"
|
||||
|
||||
|
||||
class GenAISpanKind(Enum):
|
||||
class GenAISpanKind(StrEnum):
|
||||
CHAIN = "CHAIN"
|
||||
RETRIEVER = "RETRIEVER"
|
||||
RERANKER = "RERANKER"
|
||||
|
||||
@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Any, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
|
||||
@ -15,6 +15,7 @@ from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
|
||||
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
@ -91,14 +92,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
|
||||
raise
|
||||
|
||||
|
||||
def datetime_to_nanos(dt: Optional[datetime]) -> int:
|
||||
def datetime_to_nanos(dt: datetime | None) -> int:
|
||||
"""Convert datetime to nanoseconds since epoch. If None, use current time."""
|
||||
if dt is None:
|
||||
dt = datetime.now()
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
def string_to_trace_id128(string: Optional[str]) -> int:
|
||||
def string_to_trace_id128(string: str | None) -> int:
|
||||
"""
|
||||
Convert any input string into a stable 128-bit integer trace ID.
|
||||
|
||||
@ -283,7 +284,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
return
|
||||
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
||||
message_file_data: MessageFile | None = trace_info.message_file_data
|
||||
|
||||
if message_file_data is not None:
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
@ -307,7 +308,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
|
||||
# Add end user data if available
|
||||
if trace_info.message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
@ -699,8 +700,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
|
||||
def _get_workflow_nodes(self, workflow_run_id: str):
|
||||
"""Helper method to get workflow nodes"""
|
||||
workflow_nodes = (
|
||||
db.session.query(
|
||||
workflow_nodes = db.session.scalars(
|
||||
select(
|
||||
WorkflowNodeExecutionModel.id,
|
||||
WorkflowNodeExecutionModel.tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id,
|
||||
@ -713,10 +714,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
WorkflowNodeExecutionModel.elapsed_time,
|
||||
WorkflowNodeExecutionModel.process_data,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
)
|
||||
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.all()
|
||||
)
|
||||
).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
).all()
|
||||
return workflow_nodes
|
||||
|
||||
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
|
||||
|
||||
@ -1,20 +1,20 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||
|
||||
|
||||
class BaseTraceInfo(BaseModel):
|
||||
message_id: Optional[str] = None
|
||||
message_data: Optional[Any] = None
|
||||
inputs: Optional[Union[str, dict[str, Any], list]] = None
|
||||
outputs: Optional[Union[str, dict[str, Any], list]] = None
|
||||
start_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None
|
||||
message_id: str | None = None
|
||||
message_data: Any | None = None
|
||||
inputs: Union[str, dict[str, Any], list] | None = None
|
||||
outputs: Union[str, dict[str, Any], list] | None = None
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
metadata: dict[str, Any]
|
||||
trace_id: Optional[str] = None
|
||||
trace_id: str | None = None
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
@ -35,9 +35,9 @@ class BaseTraceInfo(BaseModel):
|
||||
|
||||
|
||||
class WorkflowTraceInfo(BaseTraceInfo):
|
||||
workflow_data: Any
|
||||
conversation_id: Optional[str] = None
|
||||
workflow_app_log_id: Optional[str] = None
|
||||
workflow_data: Any = None
|
||||
conversation_id: str | None = None
|
||||
workflow_app_log_id: str | None = None
|
||||
workflow_id: str
|
||||
tenant_id: str
|
||||
workflow_run_id: str
|
||||
@ -46,7 +46,7 @@ class WorkflowTraceInfo(BaseTraceInfo):
|
||||
workflow_run_inputs: Mapping[str, Any]
|
||||
workflow_run_outputs: Mapping[str, Any]
|
||||
workflow_run_version: str
|
||||
error: Optional[str] = None
|
||||
error: str | None = None
|
||||
total_tokens: int
|
||||
file_list: list[str]
|
||||
query: str
|
||||
@ -58,9 +58,9 @@ class MessageTraceInfo(BaseTraceInfo):
|
||||
message_tokens: int
|
||||
answer_tokens: int
|
||||
total_tokens: int
|
||||
error: Optional[str] = None
|
||||
file_list: Optional[Union[str, dict[str, Any], list]] = None
|
||||
message_file_data: Optional[Any] = None
|
||||
error: str | None = None
|
||||
file_list: Union[str, dict[str, Any], list] | None = None
|
||||
message_file_data: Any | None = None
|
||||
conversation_mode: str
|
||||
|
||||
|
||||
@ -73,23 +73,23 @@ class ModerationTraceInfo(BaseTraceInfo):
|
||||
|
||||
class SuggestedQuestionTraceInfo(BaseTraceInfo):
|
||||
total_tokens: int
|
||||
status: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
from_account_id: Optional[str] = None
|
||||
agent_based: Optional[bool] = None
|
||||
from_source: Optional[str] = None
|
||||
model_provider: Optional[str] = None
|
||||
model_id: Optional[str] = None
|
||||
status: str | None = None
|
||||
error: str | None = None
|
||||
from_account_id: str | None = None
|
||||
agent_based: bool | None = None
|
||||
from_source: str | None = None
|
||||
model_provider: str | None = None
|
||||
model_id: str | None = None
|
||||
suggested_question: list[str]
|
||||
level: str
|
||||
status_message: Optional[str] = None
|
||||
workflow_run_id: Optional[str] = None
|
||||
status_message: str | None = None
|
||||
workflow_run_id: str | None = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class DatasetRetrievalTraceInfo(BaseTraceInfo):
|
||||
documents: Any
|
||||
documents: Any = None
|
||||
|
||||
|
||||
class ToolTraceInfo(BaseTraceInfo):
|
||||
@ -97,23 +97,23 @@ class ToolTraceInfo(BaseTraceInfo):
|
||||
tool_inputs: dict[str, Any]
|
||||
tool_outputs: str
|
||||
metadata: dict[str, Any]
|
||||
message_file_data: Any
|
||||
error: Optional[str] = None
|
||||
message_file_data: Any = None
|
||||
error: str | None = None
|
||||
tool_config: dict[str, Any]
|
||||
time_cost: Union[int, float]
|
||||
tool_parameters: dict[str, Any]
|
||||
file_url: Union[str, None, list]
|
||||
file_url: Union[str, None, list] = None
|
||||
|
||||
|
||||
class GenerateNameTraceInfo(BaseTraceInfo):
|
||||
conversation_id: Optional[str] = None
|
||||
conversation_id: str | None = None
|
||||
tenant_id: str
|
||||
|
||||
|
||||
class TaskData(BaseModel):
|
||||
app_id: str
|
||||
trace_info_type: str
|
||||
trace_info: Any
|
||||
trace_info: Any = None
|
||||
|
||||
|
||||
trace_info_info_map = {
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
@ -52,50 +52,50 @@ class LangfuseTrace(BaseModel):
|
||||
Langfuse trace model
|
||||
"""
|
||||
|
||||
id: Optional[str] = Field(
|
||||
id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the trace can be set, defaults to a random id. Used to link traces to external systems "
|
||||
"or when creating a distributed trace. Traces are upserted on id.",
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Identifier of the trace. Useful for sorting/filtering in the UI.",
|
||||
)
|
||||
input: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
||||
input: Union[str, dict[str, Any], list, None] | None = Field(
|
||||
default=None, description="The input of the trace. Can be any JSON object."
|
||||
)
|
||||
output: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
||||
output: Union[str, dict[str, Any], list, None] | None = Field(
|
||||
default=None, description="The output of the trace. Can be any JSON object."
|
||||
)
|
||||
metadata: Optional[dict[str, Any]] = Field(
|
||||
metadata: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated "
|
||||
"via the API.",
|
||||
)
|
||||
user_id: Optional[str] = Field(
|
||||
user_id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the user that triggered the execution. Used to provide user-level analytics.",
|
||||
)
|
||||
session_id: Optional[str] = Field(
|
||||
session_id: str | None = Field(
|
||||
default=None,
|
||||
description="Used to group multiple traces into a session in Langfuse. Use your own session/thread identifier.",
|
||||
)
|
||||
version: Optional[str] = Field(
|
||||
version: str | None = Field(
|
||||
default=None,
|
||||
description="The version of the trace type. Used to understand how changes to the trace type affect metrics. "
|
||||
"Useful in debugging.",
|
||||
)
|
||||
release: Optional[str] = Field(
|
||||
release: str | None = Field(
|
||||
default=None,
|
||||
description="The release identifier of the current deployment. Used to understand how changes of different "
|
||||
"deployments affect metrics. Useful in debugging.",
|
||||
)
|
||||
tags: Optional[list[str]] = Field(
|
||||
tags: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Tags are used to categorize or label traces. Traces can be filtered by tags in the UI and GET "
|
||||
"API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.",
|
||||
)
|
||||
public: Optional[bool] = Field(
|
||||
public: bool | None = Field(
|
||||
default=None,
|
||||
description="You can make a trace public to share it via a public link. This allows others to view the trace "
|
||||
"without needing to log in or be members of your Langfuse project.",
|
||||
@ -113,61 +113,61 @@ class LangfuseSpan(BaseModel):
|
||||
Langfuse span model
|
||||
"""
|
||||
|
||||
id: Optional[str] = Field(
|
||||
id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the span can be set, otherwise a random id is generated. Spans are upserted on id.",
|
||||
)
|
||||
session_id: Optional[str] = Field(
|
||||
session_id: str | None = Field(
|
||||
default=None,
|
||||
description="Used to group multiple spans into a session in Langfuse. Use your own session/thread identifier.",
|
||||
)
|
||||
trace_id: Optional[str] = Field(
|
||||
trace_id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the trace the span belongs to. Used to link spans to traces.",
|
||||
)
|
||||
user_id: Optional[str] = Field(
|
||||
user_id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the user that triggered the execution. Used to provide user-level analytics.",
|
||||
)
|
||||
start_time: Optional[datetime | str] = Field(
|
||||
start_time: datetime | str | None = Field(
|
||||
default_factory=datetime.now,
|
||||
description="The time at which the span started, defaults to the current time.",
|
||||
)
|
||||
end_time: Optional[datetime | str] = Field(
|
||||
end_time: datetime | str | None = Field(
|
||||
default=None,
|
||||
description="The time at which the span ended. Automatically set by span.end().",
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Identifier of the span. Useful for sorting/filtering in the UI.",
|
||||
)
|
||||
metadata: Optional[dict[str, Any]] = Field(
|
||||
metadata: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated "
|
||||
"via the API.",
|
||||
)
|
||||
level: Optional[str] = Field(
|
||||
level: str | None = Field(
|
||||
default=None,
|
||||
description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of "
|
||||
"traces with elevated error levels and for highlighting in the UI.",
|
||||
)
|
||||
status_message: Optional[str] = Field(
|
||||
status_message: str | None = Field(
|
||||
default=None,
|
||||
description="The status message of the span. Additional field for context of the event. E.g. the error "
|
||||
"message of an error event.",
|
||||
)
|
||||
input: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
|
||||
input: Union[str, Mapping[str, Any], list, None] | None = Field(
|
||||
default=None, description="The input of the span. Can be any JSON object."
|
||||
)
|
||||
output: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
|
||||
output: Union[str, Mapping[str, Any], list, None] | None = Field(
|
||||
default=None, description="The output of the span. Can be any JSON object."
|
||||
)
|
||||
version: Optional[str] = Field(
|
||||
version: str | None = Field(
|
||||
default=None,
|
||||
description="The version of the span type. Used to understand how changes to the span type affect metrics. "
|
||||
"Useful in debugging.",
|
||||
)
|
||||
parent_observation_id: Optional[str] = Field(
|
||||
parent_observation_id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the observation the span belongs to. Used to link spans to observations.",
|
||||
)
|
||||
@ -188,15 +188,15 @@ class UnitEnum(StrEnum):
|
||||
|
||||
|
||||
class GenerationUsage(BaseModel):
|
||||
promptTokens: Optional[int] = None
|
||||
completionTokens: Optional[int] = None
|
||||
total: Optional[int] = None
|
||||
input: Optional[int] = None
|
||||
output: Optional[int] = None
|
||||
unit: Optional[UnitEnum] = None
|
||||
inputCost: Optional[float] = None
|
||||
outputCost: Optional[float] = None
|
||||
totalCost: Optional[float] = None
|
||||
promptTokens: int | None = None
|
||||
completionTokens: int | None = None
|
||||
total: int | None = None
|
||||
input: int | None = None
|
||||
output: int | None = None
|
||||
unit: UnitEnum | None = None
|
||||
inputCost: float | None = None
|
||||
outputCost: float | None = None
|
||||
totalCost: float | None = None
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
@ -206,69 +206,69 @@ class GenerationUsage(BaseModel):
|
||||
|
||||
|
||||
class LangfuseGeneration(BaseModel):
|
||||
id: Optional[str] = Field(
|
||||
id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the generation can be set, defaults to random id.",
|
||||
)
|
||||
trace_id: Optional[str] = Field(
|
||||
trace_id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the trace the generation belongs to. Used to link generations to traces.",
|
||||
)
|
||||
parent_observation_id: Optional[str] = Field(
|
||||
parent_observation_id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the observation the generation belongs to. Used to link generations to observations.",
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Identifier of the generation. Useful for sorting/filtering in the UI.",
|
||||
)
|
||||
start_time: Optional[datetime | str] = Field(
|
||||
start_time: datetime | str | None = Field(
|
||||
default_factory=datetime.now,
|
||||
description="The time at which the generation started, defaults to the current time.",
|
||||
)
|
||||
completion_start_time: Optional[datetime | str] = Field(
|
||||
completion_start_time: datetime | str | None = Field(
|
||||
default=None,
|
||||
description="The time at which the completion started (streaming). Set it to get latency analytics broken "
|
||||
"down into time until completion started and completion duration.",
|
||||
)
|
||||
end_time: Optional[datetime | str] = Field(
|
||||
end_time: datetime | str | None = Field(
|
||||
default=None,
|
||||
description="The time at which the generation ended. Automatically set by generation.end().",
|
||||
)
|
||||
model: Optional[str] = Field(default=None, description="The name of the model used for the generation.")
|
||||
model_parameters: Optional[dict[str, Any]] = Field(
|
||||
model: str | None = Field(default=None, description="The name of the model used for the generation.")
|
||||
model_parameters: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="The parameters of the model used for the generation; can be any key-value pairs.",
|
||||
)
|
||||
input: Optional[Any] = Field(
|
||||
input: Any | None = Field(
|
||||
default=None,
|
||||
description="The prompt used for the generation. Can be any string or JSON object.",
|
||||
)
|
||||
output: Optional[Any] = Field(
|
||||
output: Any | None = Field(
|
||||
default=None,
|
||||
description="The completion generated by the model. Can be any string or JSON object.",
|
||||
)
|
||||
usage: Optional[GenerationUsage] = Field(
|
||||
usage: GenerationUsage | None = Field(
|
||||
default=None,
|
||||
description="The usage object supports the OpenAi structure with tokens and a more generic version with "
|
||||
"detailed costs and units.",
|
||||
)
|
||||
metadata: Optional[dict[str, Any]] = Field(
|
||||
metadata: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Additional metadata of the generation. Can be any JSON object. Metadata is merged when being "
|
||||
"updated via the API.",
|
||||
)
|
||||
level: Optional[LevelEnum] = Field(
|
||||
level: LevelEnum | None = Field(
|
||||
default=None,
|
||||
description="The level of the generation. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering "
|
||||
"of traces with elevated error levels and for highlighting in the UI.",
|
||||
)
|
||||
status_message: Optional[str] = Field(
|
||||
status_message: str | None = Field(
|
||||
default=None,
|
||||
description="The status message of the generation. Additional field for context of the event. E.g. the error "
|
||||
"message of an error event.",
|
||||
)
|
||||
version: Optional[str] = Field(
|
||||
version: str | None = Field(
|
||||
default=None,
|
||||
description="The version of the generation type. Used to understand how changes to the span type affect "
|
||||
"metrics. Useful in debugging.",
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from langfuse import Langfuse # type: ignore
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
@ -145,13 +144,13 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||
inputs = node_execution.inputs or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||
metadata.update(
|
||||
{
|
||||
@ -164,7 +163,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||
process_data = node_execution.process_data or {}
|
||||
model_provider = process_data.get("model_provider", None)
|
||||
model_name = process_data.get("model_name", None)
|
||||
if model_provider is not None and model_name is not None:
|
||||
@ -242,7 +241,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
@ -399,7 +398,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
)
|
||||
self.add_span(langfuse_span_data=name_generation_span_data)
|
||||
|
||||
def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None):
|
||||
def add_trace(self, langfuse_trace_data: LangfuseTrace | None = None):
|
||||
format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
|
||||
try:
|
||||
self.langfuse_client.trace(**format_trace_data)
|
||||
@ -407,7 +406,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangFuse Failed to create trace: {str(e)}")
|
||||
|
||||
def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None):
|
||||
def add_span(self, langfuse_span_data: LangfuseSpan | None = None):
|
||||
format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
|
||||
try:
|
||||
self.langfuse_client.span(**format_span_data)
|
||||
@ -415,12 +414,12 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangFuse Failed to create span: {str(e)}")
|
||||
|
||||
def update_span(self, span, langfuse_span_data: Optional[LangfuseSpan] = None):
|
||||
def update_span(self, span, langfuse_span_data: LangfuseSpan | None = None):
|
||||
format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
|
||||
|
||||
span.end(**format_span_data)
|
||||
|
||||
def add_generation(self, langfuse_generation_data: Optional[LangfuseGeneration] = None):
|
||||
def add_generation(self, langfuse_generation_data: LangfuseGeneration | None = None):
|
||||
format_generation_data = (
|
||||
filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
|
||||
)
|
||||
@ -430,7 +429,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangFuse Failed to create generation: {str(e)}")
|
||||
|
||||
def update_generation(self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None):
|
||||
def update_generation(self, generation, langfuse_generation_data: LangfuseGeneration | None = None):
|
||||
format_generation_data = (
|
||||
filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
|
||||
)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
@ -20,36 +20,36 @@ class LangSmithRunType(StrEnum):
|
||||
|
||||
|
||||
class LangSmithTokenUsage(BaseModel):
|
||||
input_tokens: Optional[int] = None
|
||||
output_tokens: Optional[int] = None
|
||||
total_tokens: Optional[int] = None
|
||||
input_tokens: int | None = None
|
||||
output_tokens: int | None = None
|
||||
total_tokens: int | None = None
|
||||
|
||||
|
||||
class LangSmithMultiModel(BaseModel):
|
||||
file_list: Optional[list[str]] = Field(None, description="List of files")
|
||||
file_list: list[str] | None = Field(None, description="List of files")
|
||||
|
||||
|
||||
class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
name: Optional[str] = Field(..., description="Name of the run")
|
||||
inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run")
|
||||
outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run")
|
||||
name: str | None = Field(..., description="Name of the run")
|
||||
inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the run")
|
||||
outputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the run")
|
||||
run_type: LangSmithRunType = Field(..., description="Type of the run")
|
||||
start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
|
||||
end_time: Optional[datetime | str] = Field(None, description="End time of the run")
|
||||
extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run")
|
||||
error: Optional[str] = Field(None, description="Error message of the run")
|
||||
serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run")
|
||||
parent_run_id: Optional[str] = Field(None, description="Parent run ID")
|
||||
events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run")
|
||||
tags: Optional[list[str]] = Field(None, description="Tags associated with the run")
|
||||
trace_id: Optional[str] = Field(None, description="Trace ID associated with the run")
|
||||
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
|
||||
id: Optional[str] = Field(None, description="ID of the run")
|
||||
session_id: Optional[str] = Field(None, description="Session ID associated with the run")
|
||||
session_name: Optional[str] = Field(None, description="Session name associated with the run")
|
||||
reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
|
||||
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
|
||||
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
|
||||
start_time: datetime | str | None = Field(None, description="Start time of the run")
|
||||
end_time: datetime | str | None = Field(None, description="End time of the run")
|
||||
extra: dict[str, Any] | None = Field(None, description="Extra information of the run")
|
||||
error: str | None = Field(None, description="Error message of the run")
|
||||
serialized: dict[str, Any] | None = Field(None, description="Serialized data of the run")
|
||||
parent_run_id: str | None = Field(None, description="Parent run ID")
|
||||
events: list[dict[str, Any]] | None = Field(None, description="Events associated with the run")
|
||||
tags: list[str] | None = Field(None, description="Tags associated with the run")
|
||||
trace_id: str | None = Field(None, description="Trace ID associated with the run")
|
||||
dotted_order: str | None = Field(None, description="Dotted order of the run")
|
||||
id: str | None = Field(None, description="ID of the run")
|
||||
session_id: str | None = Field(None, description="Session ID associated with the run")
|
||||
session_name: str | None = Field(None, description="Session name associated with the run")
|
||||
reference_example_id: str | None = Field(None, description="Reference example ID associated with the run")
|
||||
input_attachments: dict[str, Any] | None = Field(None, description="Input attachments of the run")
|
||||
output_attachments: dict[str, Any] | None = Field(None, description="Output attachments of the run")
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
@ -128,15 +128,15 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
|
||||
class LangSmithRunUpdateModel(BaseModel):
|
||||
run_id: str = Field(..., description="ID of the run")
|
||||
trace_id: Optional[str] = Field(None, description="Trace ID associated with the run")
|
||||
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
|
||||
parent_run_id: Optional[str] = Field(None, description="Parent run ID")
|
||||
end_time: Optional[datetime | str] = Field(None, description="End time of the run")
|
||||
error: Optional[str] = Field(None, description="Error message of the run")
|
||||
inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run")
|
||||
outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run")
|
||||
events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run")
|
||||
tags: Optional[list[str]] = Field(None, description="Tags associated with the run")
|
||||
extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run")
|
||||
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
|
||||
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
|
||||
trace_id: str | None = Field(None, description="Trace ID associated with the run")
|
||||
dotted_order: str | None = Field(None, description="Dotted order of the run")
|
||||
parent_run_id: str | None = Field(None, description="Parent run ID")
|
||||
end_time: datetime | str | None = Field(None, description="End time of the run")
|
||||
error: str | None = Field(None, description="Error message of the run")
|
||||
inputs: dict[str, Any] | None = Field(None, description="Inputs of the run")
|
||||
outputs: dict[str, Any] | None = Field(None, description="Outputs of the run")
|
||||
events: list[dict[str, Any]] | None = Field(None, description="Events associated with the run")
|
||||
tags: list[str] | None = Field(None, description="Tags associated with the run")
|
||||
extra: dict[str, Any] | None = Field(None, description="Extra information of the run")
|
||||
input_attachments: dict[str, Any] | None = Field(None, description="Input attachments of the run")
|
||||
output_attachments: dict[str, Any] | None = Field(None, description="Output attachments of the run")
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
|
||||
from langsmith import Client
|
||||
from langsmith.schemas import RunBase
|
||||
@ -166,13 +166,13 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||
inputs = node_execution.inputs or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
metadata = {str(key): value for key, value in execution_metadata.items()}
|
||||
metadata.update(
|
||||
@ -187,7 +187,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
}
|
||||
)
|
||||
|
||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||
process_data = node_execution.process_data or {}
|
||||
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
run_type = LangSmithRunType.llm
|
||||
@ -246,7 +246,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
# get message file data
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
||||
message_file_data: MessageFile | None = trace_info.message_file_data
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
metadata = trace_info.metadata
|
||||
@ -259,7 +259,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
metadata["user_id"] = user_id
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
|
||||
from opik import Opik, Trace
|
||||
from opik.id_helpers import uuid4_to_uuid7
|
||||
@ -46,7 +46,7 @@ def wrap_metadata(metadata, **kwargs):
|
||||
return metadata
|
||||
|
||||
|
||||
def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]):
|
||||
def prepare_opik_uuid(user_datetime: datetime | None, user_uuid: str | None):
|
||||
"""Opik needs UUIDv7 while Dify uses UUIDv4 for identifier of most
|
||||
messages and objects. The type-hints of BaseTraceInfo indicates that
|
||||
objects start_time and message_id could be null which means we cannot map
|
||||
@ -181,13 +181,13 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||
inputs = node_execution.inputs or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||
metadata.update(
|
||||
{
|
||||
@ -201,7 +201,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
}
|
||||
)
|
||||
|
||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||
process_data = node_execution.process_data or {}
|
||||
|
||||
provider = None
|
||||
model = None
|
||||
@ -263,7 +263,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
# get message file data
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
||||
message_file_data: MessageFile | None = trace_info.message_file_data
|
||||
|
||||
if message_file_data is not None:
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
@ -281,7 +281,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
metadata["file_list"] = file_list
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@ -42,7 +43,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||
def __getitem__(self, provider: str) -> dict[str, Any]:
|
||||
match provider:
|
||||
case TracingProviderEnum.LANGFUSE:
|
||||
@ -123,7 +124,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
|
||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||
|
||||
|
||||
provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
|
||||
provider_config_map = OpsTraceProviderConfigMap()
|
||||
|
||||
|
||||
class OpsTraceManager:
|
||||
@ -220,7 +221,7 @@ class OpsTraceManager:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
trace_config_data: Optional[TraceAppConfig] = (
|
||||
trace_config_data: TraceAppConfig | None = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
@ -244,7 +245,7 @@ class OpsTraceManager:
|
||||
@classmethod
|
||||
def get_ops_trace_instance(
|
||||
cls,
|
||||
app_id: Optional[Union[UUID, str]] = None,
|
||||
app_id: Union[UUID, str] | None = None,
|
||||
):
|
||||
"""
|
||||
Get ops trace through model config
|
||||
@ -257,7 +258,7 @@ class OpsTraceManager:
|
||||
if app_id is None:
|
||||
return None
|
||||
|
||||
app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
|
||||
app: App | None = db.session.query(App).where(App.id == app_id).first()
|
||||
|
||||
if app is None:
|
||||
return None
|
||||
@ -331,7 +332,7 @@ class OpsTraceManager:
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()
|
||||
app_config: App | None = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app_config:
|
||||
raise ValueError("App not found")
|
||||
app_config.tracing = json.dumps(
|
||||
@ -349,7 +350,7 @@ class OpsTraceManager:
|
||||
:param app_id: app id
|
||||
:return:
|
||||
"""
|
||||
app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
|
||||
app: App | None = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError("App not found")
|
||||
if not app.tracing:
|
||||
@ -825,7 +826,7 @@ class TraceTask:
|
||||
return generate_name_trace_info
|
||||
|
||||
|
||||
trace_manager_timer: Optional[threading.Timer] = None
|
||||
trace_manager_timer: threading.Timer | None = None
|
||||
trace_manager_queue: queue.Queue = queue.Queue()
|
||||
trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
|
||||
trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import select
|
||||
@ -49,9 +49,7 @@ def replace_text_with_content(data):
|
||||
return data
|
||||
|
||||
|
||||
def generate_dotted_order(
|
||||
run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None
|
||||
) -> str:
|
||||
def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_dotted_order: str | None = None) -> str:
|
||||
"""
|
||||
generate dotted_order for langsmith
|
||||
"""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
@ -8,24 +8,24 @@ from core.ops.utils import replace_text_with_content
|
||||
|
||||
|
||||
class WeaveTokenUsage(BaseModel):
|
||||
input_tokens: Optional[int] = None
|
||||
output_tokens: Optional[int] = None
|
||||
total_tokens: Optional[int] = None
|
||||
input_tokens: int | None = None
|
||||
output_tokens: int | None = None
|
||||
total_tokens: int | None = None
|
||||
|
||||
|
||||
class WeaveMultiModel(BaseModel):
|
||||
file_list: Optional[list[str]] = Field(None, description="List of files")
|
||||
file_list: list[str] | None = Field(None, description="List of files")
|
||||
|
||||
|
||||
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||
id: str = Field(..., description="ID of the trace")
|
||||
op: str = Field(..., description="Name of the operation")
|
||||
inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace")
|
||||
outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace")
|
||||
attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
||||
inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the trace")
|
||||
outputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the trace")
|
||||
attributes: Union[str, dict[str, Any], list, None] | None = Field(
|
||||
None, description="Metadata and attributes associated with trace"
|
||||
)
|
||||
exception: Optional[str] = Field(None, description="Exception message of the trace")
|
||||
exception: str | None = Field(None, description="Exception message of the trace")
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import wandb
|
||||
import weave
|
||||
@ -168,13 +168,13 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||
inputs = node_execution.inputs or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
attributes = {str(k): v for k, v in execution_metadata.items()}
|
||||
attributes.update(
|
||||
@ -189,7 +189,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
}
|
||||
)
|
||||
|
||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||
process_data = node_execution.process_data or {}
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
attributes.update(
|
||||
{
|
||||
@ -222,7 +222,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
# get message file data
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
||||
message_file_data: MessageFile | None = trace_info.message_file_data
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
attributes = trace_info.metadata
|
||||
@ -235,7 +235,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
attributes["user_id"] = user_id
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
end_user_data: EndUser | None = (
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
@ -423,7 +423,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
logger.debug("Weave API check failed: %s", str(e))
|
||||
raise ValueError(f"Weave API check failed: {str(e)}")
|
||||
|
||||
def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None):
|
||||
def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
|
||||
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
|
||||
self.calls[run_data.id] = call
|
||||
if parent_run_id:
|
||||
|
||||
Reference in New Issue
Block a user