diff --git a/app/assets/seeder.py b/app/assets/seeder.py index c6107bc55..d65ffb23f 100644 --- a/app/assets/seeder.py +++ b/app/assets/seeder.py @@ -77,7 +77,7 @@ class _AssetSeeder: """ def __init__(self) -> None: - self._lock = threading.Lock() + self._lock = threading.RLock() self._state = State.IDLE self._progress: Progress | None = None self._last_progress: Progress | None = None @@ -637,12 +637,12 @@ class _AssetSeeder: with self._lock: self._reset_to_idle() pending = self._pending_enrich - self._pending_enrich = None - if pending is not None: - self.start_enrich( - roots=pending["roots"], - compute_hashes=pending["compute_hashes"], - ) + if pending is not None: + self._pending_enrich = None + self.start_enrich( + roots=pending["roots"], + compute_hashes=pending["compute_hashes"], + ) def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]: """Run phase 1: fast scan to create stub records. diff --git a/tests-unit/seeder_test/test_seeder.py b/tests-unit/seeder_test/test_seeder.py index db3795e48..6aed6d6f3 100644 --- a/tests-unit/seeder_test/test_seeder.py +++ b/tests-unit/seeder_test/test_seeder.py @@ -1,6 +1,7 @@ """Unit tests for the _AssetSeeder background scanning class.""" import threading +import time from unittest.mock import patch import pytest @@ -771,6 +772,188 @@ class TestSeederStopRestart: assert collected_roots[1] == ("input",) +class TestEnqueueEnrichHandoff: + """Test that the drain of _pending_enrich is atomic with start_enrich.""" + + def test_pending_enrich_runs_after_scan_completes( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + """A queued enrich request runs automatically when a scan finishes.""" + enrich_roots_seen: list[tuple] = [] + original_start = fresh_seeder.start + + def tracking_start(*args, **kwargs): + phase = kwargs.get("phase") + roots = kwargs.get("roots", args[0] if args else None) + result = original_start(*args, **kwargs) + if phase == ScanPhase.ENRICH and result: + enrich_roots_seen.append(roots) + return result + + fresh_seeder.start = tracking_start + + # Start a fast scan, then enqueue an enrich while it's running + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST) + assert reached.wait(timeout=2.0) + + queued = fresh_seeder.enqueue_enrich( + roots=("input",), compute_hashes=True + ) + assert queued is False # queued, not started immediately + + barrier.set() + + # Wait for the original scan + the auto-started enrich scan + deadline = time.monotonic() + 5.0 + while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline: + time.sleep(0.05) + + assert enrich_roots_seen == [("input",)] + + def test_enqueue_enrich_during_drain_does_not_lose_work( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + """enqueue_enrich called concurrently with drain cannot drop work. + + Simulates the race: another thread calls enqueue_enrich right as the + scan thread is draining _pending_enrich. The enqueue must either be + picked up by the draining scan or successfully start its own scan. + """ + barrier = threading.Event() + reached = threading.Event() + enrich_started = threading.Event() + + enrich_call_count = 0 + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + # Track how many times start_enrich actually fires + real_start_enrich = fresh_seeder.start_enrich + enrich_roots_seen: list[tuple] = [] + + def tracking_start_enrich(**kwargs): + nonlocal enrich_call_count + enrich_call_count += 1 + enrich_roots_seen.append(kwargs.get("roots")) + result = real_start_enrich(**kwargs) + if result: + enrich_started.set() + return result + + fresh_seeder.start_enrich = tracking_start_enrich + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + # Start a scan + fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST) + assert reached.wait(timeout=2.0) + + # Queue an enrich while scan is running + fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False) + + # Let scan finish — drain will fire start_enrich atomically + barrier.set() + + # Wait for drain to complete and the enrich scan to start + assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain" + assert ("output",) in enrich_roots_seen + + def test_concurrent_enqueue_during_drain_not_lost( + self, fresh_seeder: _AssetSeeder, + ): + """A second enqueue_enrich arriving while drain is in progress is not lost. + + Because the drain now holds _lock through the start_enrich call, + a concurrent enqueue_enrich will block until start_enrich has + transitioned state to RUNNING, then the enqueue will queue its + payload as _pending_enrich for the *next* drain. + """ + scan_barrier = threading.Event() + scan_reached = threading.Event() + enrich_barrier = threading.Event() + enrich_reached = threading.Event() + + collect_call = 0 + + def gated_collect(*args): + nonlocal collect_call + collect_call += 1 + if collect_call == 1: + # First call: the initial fast scan + scan_reached.set() + scan_barrier.wait(timeout=5.0) + return [] + + enrich_call = 0 + + def gated_get_unenriched(*args, **kwargs): + nonlocal enrich_call + enrich_call += 1 + if enrich_call == 1: + # First enrich batch: signal and block + enrich_reached.set() + enrich_barrier.wait(timeout=5.0) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + # 1. Start fast scan + fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST) + assert scan_reached.wait(timeout=2.0) + + # 2. Queue enrich while fast scan is running + queued = fresh_seeder.enqueue_enrich( + roots=("input",), compute_hashes=False + ) + assert queued is False + + # 3. Let the fast scan finish — drain will start the enrich scan + scan_barrier.set() + + # 4. Wait until the drained enrich scan is running + assert enrich_reached.wait(timeout=5.0) + + # 5. Now enqueue another enrich while the drained scan is running + queued2 = fresh_seeder.enqueue_enrich( + roots=("output",), compute_hashes=True + ) + assert queued2 is False # should be queued, not started + + # Verify _pending_enrich was set (the second enqueue was captured) + with fresh_seeder._lock: + assert fresh_seeder._pending_enrich is not None + assert "output" in fresh_seeder._pending_enrich["roots"] + + # Let the enrich scan finish + enrich_barrier.set() + + deadline = time.monotonic() + 5.0 + while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline: + time.sleep(0.05) + + def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow: return UnenrichedReferenceRow( reference_id=ref_id, asset_id=asset_id,