Skip to content

Commit 4976da1

Browse files
authored
Fix crash recovery (#29)
* Fix pending batch being skipped after reorg detection Move the pending batch check to BEFORE fetching a new batch from the stream. Previously, after a reorg was detected: 1. The data batch was stored as _pending_batch 2. A reorg batch was returned 3. On next __next__() call, a NEW batch was fetched BEFORE checking _pending_batch 4. This caused the pending batch to be lost or returned out of order Now the pending batch check happens first, ensuring proper ordering: reorg_batch -> pending_data_batch -> next_batch * Add crash recovery via _rewind_to_watermark On stream start, automatically clean up any data written after the last checkpoint watermark. This handles crash scenarios where data was written but the checkpoint was not saved. The _rewind_to_watermark method: 1. Gets the last watermark from state store 2. Creates invalidation ranges for blocks after the watermark 3. Calls _handle_reorg to delete uncommitted data 4. Falls back gracefully if loader does not support deletion Called automatically at start of load_stream_continuous once per table. * Simplify _rewind_to_watermark to require table_name parameter - Make table_name required instead of Optional (matches actual usage) - Remove unnecessary list scaffolding since callers always provide single table - Include hash/prev_hash in invalidation range for potential validation - Remove test for table_name=None case
1 parent b07241f commit 4976da1

File tree

3 files changed

+196
-8
lines changed

3 files changed

+196
-8
lines changed

src/amp/loaders/base.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,19 @@ def __init__(self, config: Dict[str, Any], label_manager=None) -> None:
7878
else:
7979
self.state_store = NullStreamStateStore()
8080

81+
# Track tables that have undergone crash recovery
82+
self._crash_recovery_done: set[str] = set()
83+
8184
@property
8285
def is_connected(self) -> bool:
8386
"""Check if the loader is connected to the target system."""
8487
return self._is_connected
8588

89+
@property
90+
def loader_type(self) -> str:
91+
"""Get the loader type identifier (e.g., 'postgresql', 'redis')."""
92+
return self.__class__.__name__.replace('Loader', '').lower()
93+
8694
def _parse_config(self, config: Dict[str, Any]) -> TConfig:
8795
"""
8896
Parse configuration into loader-specific format.
@@ -447,11 +455,21 @@ def load_stream_continuous(
447455
if not self._is_connected:
448456
self.connect()
449457

458+
connection_name = kwargs.get('connection_name')
459+
if connection_name is None:
460+
connection_name = self.loader_type
461+
462+
if table_name not in self._crash_recovery_done:
463+
self.logger.info(f'Running crash recovery for table {table_name} (connection: {connection_name})')
464+
self._rewind_to_watermark(table_name, connection_name)
465+
self._crash_recovery_done.add(table_name)
466+
else:
467+
self.logger.info(f'Crash recovery already done for table {table_name}')
468+
450469
rows_loaded = 0
451470
start_time = time.time()
452471
batch_count = 0
453472
reorg_count = 0
454-
connection_name = kwargs.get('connection_name', 'unknown')
455473
worker_id = kwargs.get('worker_id', 0)
456474

457475
try:
@@ -785,6 +803,68 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str,
785803
'Streaming with reorg detection requires implementing this method.'
786804
)
787805

806+
def _rewind_to_watermark(self, table_name: str, connection_name: Optional[str] = None) -> None:
807+
"""
808+
Reset state and data to the last checkpointed watermark.
809+
810+
Removes any data written after the last completed watermark,
811+
ensuring resumable streams start from a consistent state.
812+
813+
This handles crash recovery by removing uncommitted data from
814+
incomplete microbatches between watermarks.
815+
816+
Args:
817+
table_name: Table to clean up.
818+
connection_name: Connection identifier. If None, uses default.
819+
"""
820+
if not self.state_enabled:
821+
self.logger.debug('State tracking disabled, skipping crash recovery')
822+
return
823+
824+
if connection_name is None:
825+
connection_name = self.loader_type
826+
827+
resume_pos = self.state_store.get_resume_position(connection_name, table_name)
828+
if not resume_pos:
829+
self.logger.debug(f'No watermark found for {table_name}, skipping crash recovery')
830+
return
831+
832+
for range_obj in resume_pos.ranges:
833+
from_block = range_obj.end + 1
834+
835+
self.logger.info(
836+
f'Crash recovery: Cleaning up {table_name} data for {range_obj.network} from block {from_block} onwards'
837+
)
838+
839+
invalidation_ranges = [
840+
BlockRange(
841+
network=range_obj.network,
842+
start=from_block,
843+
end=from_block,
844+
hash=range_obj.hash,
845+
prev_hash=range_obj.prev_hash,
846+
)
847+
]
848+
849+
try:
850+
self._handle_reorg(invalidation_ranges, table_name, connection_name)
851+
self.logger.info(f'Crash recovery completed for {range_obj.network} in {table_name}')
852+
853+
except NotImplementedError:
854+
invalidated = self.state_store.invalidate_from_block(
855+
connection_name, table_name, range_obj.network, from_block
856+
)
857+
858+
if invalidated:
859+
self.logger.warning(
860+
f'Crash recovery: Cleared {len(invalidated)} batches from state '
861+
f'for {range_obj.network} but cannot delete data from {table_name}. '
862+
f'{self.__class__.__name__} does not support data deletion. '
863+
f'Duplicates may occur on resume.'
864+
)
865+
else:
866+
self.logger.debug(f'No uncommitted batches found for {range_obj.network}')
867+
788868
def _add_metadata_columns(self, data: pa.RecordBatch, block_ranges: List[BlockRange]) -> pa.RecordBatch:
789869
"""
790870
Add metadata columns for streaming data with compact batch identification.

src/amp/streaming/reorg.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ def __next__(self) -> ResponseBatch:
4646
KeyboardInterrupt: When user cancels the stream
4747
"""
4848
try:
49+
# Check if we have a pending batch from a previous reorg detection
50+
if hasattr(self, '_pending_batch'):
51+
pending = self._pending_batch
52+
delattr(self, '_pending_batch')
53+
return pending
54+
4955
# Get next batch from underlying stream
5056
batch = next(self.stream_iterator)
5157

@@ -63,13 +69,6 @@ def __next__(self) -> ResponseBatch:
6369
self._pending_batch = batch
6470
return ResponseBatch.reorg_batch(invalidation_ranges)
6571

66-
# Check if we have a pending batch from a previous reorg detection
67-
# REVIEW: I think we should remove this
68-
if hasattr(self, '_pending_batch'):
69-
pending = self._pending_batch
70-
delattr(self, '_pending_batch')
71-
return pending
72-
7372
# Normal case - just return the data batch
7473
return batch
7574

tests/unit/test_crash_recovery.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""
2+
Unit tests for crash recovery via _rewind_to_watermark() method.
3+
4+
These tests verify the crash recovery logic works correctly in isolation.
5+
"""
6+
7+
from unittest.mock import Mock
8+
9+
import pytest
10+
11+
from src.amp.loaders.base import LoadResult
12+
from src.amp.streaming.types import BlockRange, ResumeWatermark
13+
from tests.fixtures.mock_clients import MockDataLoader
14+
15+
16+
@pytest.fixture
17+
def mock_loader() -> MockDataLoader:
18+
"""Create a mock loader with state store"""
19+
loader = MockDataLoader({'test': 'config'})
20+
loader.connect()
21+
22+
loader.state_store = Mock()
23+
loader.state_enabled = True
24+
25+
return loader
26+
27+
28+
@pytest.mark.unit
29+
class TestCrashRecovery:
30+
"""Test _rewind_to_watermark() crash recovery method"""
31+
32+
def test_rewind_with_no_state(self, mock_loader):
33+
"""Should return early if state_enabled=False"""
34+
mock_loader.state_enabled = False
35+
36+
mock_loader._rewind_to_watermark('test_table', 'test_conn')
37+
38+
mock_loader.state_store.get_resume_position.assert_not_called()
39+
40+
def test_rewind_with_no_watermark(self, mock_loader):
41+
"""Should return early if no watermark exists"""
42+
mock_loader.state_store.get_resume_position = Mock(return_value=None)
43+
44+
mock_loader._rewind_to_watermark('test_table', 'test_conn')
45+
46+
mock_loader.state_store.get_resume_position.assert_called_once_with('test_conn', 'test_table')
47+
48+
def test_rewind_calls_handle_reorg(self, mock_loader):
49+
"""Should call _handle_reorg with correct invalidation ranges"""
50+
watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')])
51+
mock_loader.state_store.get_resume_position = Mock(return_value=watermark)
52+
mock_loader._handle_reorg = Mock()
53+
54+
mock_loader._rewind_to_watermark('test_table', 'test_conn')
55+
56+
mock_loader._handle_reorg.assert_called_once()
57+
call_args = mock_loader._handle_reorg.call_args
58+
invalidation_ranges = call_args[0][0]
59+
assert len(invalidation_ranges) == 1
60+
assert invalidation_ranges[0].network == 'ethereum'
61+
assert invalidation_ranges[0].start == 1011
62+
assert call_args[0][1] == 'test_table'
63+
assert call_args[0][2] == 'test_conn'
64+
65+
def test_rewind_handles_not_implemented(self, mock_loader):
66+
"""Should gracefully handle loaders without _handle_reorg"""
67+
watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')])
68+
mock_loader.state_store.get_resume_position = Mock(return_value=watermark)
69+
mock_loader._handle_reorg = Mock(side_effect=NotImplementedError())
70+
mock_loader.state_store.invalidate_from_block = Mock(return_value=[])
71+
72+
mock_loader._rewind_to_watermark('test_table', 'test_conn')
73+
74+
mock_loader.state_store.invalidate_from_block.assert_called_once_with(
75+
'test_conn', 'test_table', 'ethereum', 1011
76+
)
77+
78+
def test_rewind_with_multiple_networks(self, mock_loader):
79+
"""Should process ethereum and polygon separately"""
80+
watermark = ResumeWatermark(
81+
ranges=[
82+
BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc'),
83+
BlockRange(network='polygon', start=2000, end=2010, hash='0xdef'),
84+
]
85+
)
86+
mock_loader.state_store.get_resume_position = Mock(return_value=watermark)
87+
mock_loader._handle_reorg = Mock()
88+
89+
mock_loader._rewind_to_watermark('test_table', 'test_conn')
90+
91+
assert mock_loader._handle_reorg.call_count == 2
92+
93+
first_call = mock_loader._handle_reorg.call_args_list[0]
94+
assert first_call[0][0][0].network == 'ethereum'
95+
assert first_call[0][0][0].start == 1011
96+
97+
second_call = mock_loader._handle_reorg.call_args_list[1]
98+
assert second_call[0][0][0].network == 'polygon'
99+
assert second_call[0][0][0].start == 2011
100+
101+
def test_rewind_uses_default_connection_name(self, mock_loader):
102+
"""Should use default connection name from loader class"""
103+
watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')])
104+
mock_loader.state_store.get_resume_position = Mock(return_value=watermark)
105+
mock_loader._handle_reorg = Mock()
106+
107+
mock_loader._rewind_to_watermark('test_table', connection_name=None)
108+
109+
mock_loader.state_store.get_resume_position.assert_called_once_with('mockdata', 'test_table')

0 commit comments

Comments
 (0)