From 9ddbc1c0fb22dfb4264e47b0a86f05b9dc068e8d Mon Sep 17 00:00:00 2001 From: Varun Chawla <34209028+veeceey@users.noreply.github.com> Date: Sun, 1 Mar 2026 22:54:26 -0800 Subject: [PATCH] fix: map all NodeType values to span kinds in Arize Phoenix tracing (#32059) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com> --- .../arize_phoenix_trace.py | 29 ++++++++++----- .../core/ops/test_arize_phoenix_trace.py | 36 +++++++++++++++++++ 2 files changed, 57 insertions(+), 8 deletions(-) create mode 100644 api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index a7b73e032e..452255f69e 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -155,6 +155,26 @@ def wrap_span_metadata(metadata, **kwargs): return metadata +# Mapping from NodeType string values to OpenInference span kinds. +# NodeType values not listed here default to CHAIN. +_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = { + "llm": OpenInferenceSpanKindValues.LLM, + "knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER, + "tool": OpenInferenceSpanKindValues.TOOL, + "agent": OpenInferenceSpanKindValues.AGENT, +} + + +def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues: + """Return the OpenInference span kind for a given workflow node type. + + Covers every ``NodeType`` enum value. Nodes that do not have a + specialised span kind (e.g. ``start``, ``end``, ``if-else``, + ``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``. + """ + return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN) + + class ArizePhoenixDataTrace(BaseTraceInstance): def __init__( self, @@ -289,9 +309,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): ) # Determine the correct span kind based on node type - span_kind = OpenInferenceSpanKindValues.CHAIN + span_kind = _get_node_span_kind(node_execution.node_type) if node_execution.node_type == "llm": - span_kind = OpenInferenceSpanKindValues.LLM provider = process_data.get("model_provider") model = process_data.get("model_name") if provider: @@ -306,12 +325,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance): node_metadata["total_tokens"] = usage_data.get("total_tokens", 0) node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0) node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0) - elif node_execution.node_type == "dataset_retrieval": - span_kind = OpenInferenceSpanKindValues.RETRIEVER - elif node_execution.node_type == "tool": - span_kind = OpenInferenceSpanKindValues.TOOL - else: - span_kind = OpenInferenceSpanKindValues.CHAIN workflow_span_context = set_span_in_context(workflow_span) node_span = self.tracer.start_span( diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py new file mode 100644 index 0000000000..4f398ce66e --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py @@ -0,0 +1,36 @@ +from openinference.semconv.trace import OpenInferenceSpanKindValues + +from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind +from core.workflow.enums import NodeType + + +class TestGetNodeSpanKind: + """Tests for _get_node_span_kind helper.""" + + def test_all_node_types_are_mapped_correctly(self): + """Ensure every NodeType enum member is mapped to the correct span kind.""" + # Mappings for node types that have a specialised span kind. + special_mappings = { + NodeType.LLM: OpenInferenceSpanKindValues.LLM, + NodeType.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER, + NodeType.TOOL: OpenInferenceSpanKindValues.TOOL, + NodeType.AGENT: OpenInferenceSpanKindValues.AGENT, + } + + # Test that every NodeType enum member is mapped to the correct span kind. + # Node types not in `special_mappings` should default to CHAIN. + for node_type in NodeType: + expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN) + actual_span_kind = _get_node_span_kind(node_type) + assert actual_span_kind == expected_span_kind, ( + f"NodeType.{node_type.name} was mapped to {actual_span_kind}, but {expected_span_kind} was expected." + ) + + def test_unknown_string_defaults_to_chain(self): + """An unrecognised node type string should still return CHAIN.""" + assert _get_node_span_kind("some-future-node-type") == OpenInferenceSpanKindValues.CHAIN + + def test_stale_dataset_retrieval_not_in_mapping(self): + """The old 'dataset_retrieval' string was never a valid NodeType value; + make sure it is not present in the mapping dictionary.""" + assert "dataset_retrieval" not in _NODE_TYPE_TO_SPAN_KIND