Fix race in enqueue_enrich drain: make pending-to-start handoff atomic

Change _lock from Lock to RLock and move the start_enrich call inside the
lock-held block so that enqueue_enrich cannot interleave between clearing
_pending_enrich and starting the enrichment scan. This prevents a concurrent
enqueue_enrich from stealing the IDLE slot and causing the drained payload
to be silently dropped.

Add tests covering:
- pending enrich runs after scan completes
- enqueue during drain does not lose work
- concurrent enqueue during drain is queued for the next cycle

Amp-Thread-ID: https://ampcode.com/threads/T-019cfe02-5710-7506-ae80-34bf16c0171a
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr
2026-03-17 15:59:01 -07:00
parent b9286572d3
commit f9d85fa176
2 changed files with 190 additions and 7 deletions

View File

@ -77,7 +77,7 @@ class _AssetSeeder:
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._lock = threading.RLock()
self._state = State.IDLE
self._progress: Progress | None = None
self._last_progress: Progress | None = None
@ -637,12 +637,12 @@ class _AssetSeeder:
with self._lock:
self._reset_to_idle()
pending = self._pending_enrich
self._pending_enrich = None
if pending is not None:
self.start_enrich(
roots=pending["roots"],
compute_hashes=pending["compute_hashes"],
)
if pending is not None:
self._pending_enrich = None
self.start_enrich(
roots=pending["roots"],
compute_hashes=pending["compute_hashes"],
)
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
"""Run phase 1: fast scan to create stub records.

View File

@ -1,6 +1,7 @@
"""Unit tests for the _AssetSeeder background scanning class."""
import threading
import time
from unittest.mock import patch
import pytest
@ -771,6 +772,188 @@ class TestSeederStopRestart:
assert collected_roots[1] == ("input",)
class TestEnqueueEnrichHandoff:
"""Test that the drain of _pending_enrich is atomic with start_enrich."""
def test_pending_enrich_runs_after_scan_completes(
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
"""A queued enrich request runs automatically when a scan finishes."""
enrich_roots_seen: list[tuple] = []
original_start = fresh_seeder.start
def tracking_start(*args, **kwargs):
phase = kwargs.get("phase")
roots = kwargs.get("roots", args[0] if args else None)
result = original_start(*args, **kwargs)
if phase == ScanPhase.ENRICH and result:
enrich_roots_seen.append(roots)
return result
fresh_seeder.start = tracking_start
# Start a fast scan, then enqueue an enrich while it's running
barrier = threading.Event()
reached = threading.Event()
def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0)
return []
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
assert reached.wait(timeout=2.0)
queued = fresh_seeder.enqueue_enrich(
roots=("input",), compute_hashes=True
)
assert queued is False # queued, not started immediately
barrier.set()
# Wait for the original scan + the auto-started enrich scan
deadline = time.monotonic() + 5.0
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
time.sleep(0.05)
assert enrich_roots_seen == [("input",)]
def test_enqueue_enrich_during_drain_does_not_lose_work(
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
"""enqueue_enrich called concurrently with drain cannot drop work.
Simulates the race: another thread calls enqueue_enrich right as the
scan thread is draining _pending_enrich. The enqueue must either be
picked up by the draining scan or successfully start its own scan.
"""
barrier = threading.Event()
reached = threading.Event()
enrich_started = threading.Event()
enrich_call_count = 0
def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0)
return []
# Track how many times start_enrich actually fires
real_start_enrich = fresh_seeder.start_enrich
enrich_roots_seen: list[tuple] = []
def tracking_start_enrich(**kwargs):
nonlocal enrich_call_count
enrich_call_count += 1
enrich_roots_seen.append(kwargs.get("roots"))
result = real_start_enrich(**kwargs)
if result:
enrich_started.set()
return result
fresh_seeder.start_enrich = tracking_start_enrich
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
# Start a scan
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
assert reached.wait(timeout=2.0)
# Queue an enrich while scan is running
fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
# Let scan finish — drain will fire start_enrich atomically
barrier.set()
# Wait for drain to complete and the enrich scan to start
assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain"
assert ("output",) in enrich_roots_seen
def test_concurrent_enqueue_during_drain_not_lost(
self, fresh_seeder: _AssetSeeder,
):
"""A second enqueue_enrich arriving while drain is in progress is not lost.
Because the drain now holds _lock through the start_enrich call,
a concurrent enqueue_enrich will block until start_enrich has
transitioned state to RUNNING, then the enqueue will queue its
payload as _pending_enrich for the *next* drain.
"""
scan_barrier = threading.Event()
scan_reached = threading.Event()
enrich_barrier = threading.Event()
enrich_reached = threading.Event()
collect_call = 0
def gated_collect(*args):
nonlocal collect_call
collect_call += 1
if collect_call == 1:
# First call: the initial fast scan
scan_reached.set()
scan_barrier.wait(timeout=5.0)
return []
enrich_call = 0
def gated_get_unenriched(*args, **kwargs):
nonlocal enrich_call
enrich_call += 1
if enrich_call == 1:
# First enrich batch: signal and block
enrich_reached.set()
enrich_barrier.wait(timeout=5.0)
return []
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
):
# 1. Start fast scan
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
assert scan_reached.wait(timeout=2.0)
# 2. Queue enrich while fast scan is running
queued = fresh_seeder.enqueue_enrich(
roots=("input",), compute_hashes=False
)
assert queued is False
# 3. Let the fast scan finish — drain will start the enrich scan
scan_barrier.set()
# 4. Wait until the drained enrich scan is running
assert enrich_reached.wait(timeout=5.0)
# 5. Now enqueue another enrich while the drained scan is running
queued2 = fresh_seeder.enqueue_enrich(
roots=("output",), compute_hashes=True
)
assert queued2 is False # should be queued, not started
# Verify _pending_enrich was set (the second enqueue was captured)
with fresh_seeder._lock:
assert fresh_seeder._pending_enrich is not None
assert "output" in fresh_seeder._pending_enrich["roots"]
# Let the enrich scan finish
enrich_barrier.set()
deadline = time.monotonic() + 5.0
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
time.sleep(0.05)
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
return UnenrichedReferenceRow(
reference_id=ref_id, asset_id=asset_id,