Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion semhash/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def prepare_records(
dict_records_typed: list[dict[str, Any]] = list(records)
dict_records = []
for record in dict_records_typed:
coerced: dict[str, Any] = {}
# Start with a copy of the full record to preserve non-embedding fields
coerced: dict[str, Any] = dict(record)
# Then coerce only the embedding columns
for column in columns:
val = record.get(column)
if val is None:
Expand Down
14 changes: 8 additions & 6 deletions semhash/semhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,14 @@ def _validate_if_strings(self, records: Sequence[dict[str, Any] | str]) -> list[

dict_records: Sequence[dict[str, Any]] = records # type: ignore[assignment]
result: list[dict[str, Any]] = []
for r in dict_records:
out = {}
for c in self.columns:
if (val := r.get(c)) is None:
raise ValueError(f"Column '{c}' has None value in record {r}")
out[c] = coerce_value(val)
for record in dict_records:
# Start with a copy of the full record to preserve non-embedding fields
out = dict(record)
# Then coerce only the embedding columns
for col in self.columns:
if (val := record.get(col)) is None:
raise ValueError(f"Column '{col}' has None value in record {record}")
out[col] = coerce_value(val)
result.append(out)
return result

Expand Down
2 changes: 1 addition & 1 deletion semhash/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version_triple__ = (0, 4, 0) # pragma: no cover
__version_triple__ = (0, 4, 1) # pragma: no cover
__version__ = ".".join(map(str, __version_triple__)) # pragma: no cover
35 changes: 35 additions & 0 deletions tests/test_semhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,41 @@ def test_from_records_edge_cases(model: Encoder) -> None:
SemHash.from_records([{"text": "apple"}, {"text": None}], columns=["text"], model=model)


def test_preserve_non_embedding_fields(model: Encoder) -> None:
"""Test that fields not specified in columns are preserved in results."""
records = [
{"id": 0, "text": "triforce", "metadata": "game1"},
{"id": 1, "text": "master sword", "metadata": "game2"},
{"id": 2, "text": "hylian shield", "metadata": "game3"},
]
semhash = SemHash.from_records(records, columns=["text"], model=model)

# Test self_deduplicate preserves non-embedding fields
result = semhash.self_deduplicate(threshold=0.9)
assert len(result.selected) == 3, "All records should be unique"

# All results should have id and metadata fields preserved
for record in result.selected:
assert "id" in record, "id field should be preserved"
assert "text" in record, "text field should be preserved"
assert "metadata" in record, "metadata field should be preserved"

# Check specific values are correct
ids = {r["id"] for r in result.selected}
assert ids == {0, 1, 2}, "All id values should be preserved"

metadatas = {r["metadata"] for r in result.selected}
assert metadatas == {"game1", "game2", "game3"}, "All metadata values should be preserved"

# Test that cross-dataset deduplication also preserves fields
new_records = [{"id": 10, "text": "triforce", "metadata": "duplicate"}]
dup_result = semhash.deduplicate(new_records, threshold=0.9)

assert len(dup_result.filtered) == 1, "Should detect duplicate"
assert "id" in dup_result.filtered[0].record, "id should be preserved in filtered records"
assert dup_result.filtered[0].record["id"] == 10, "Correct id value"


def test_deduplicate_edge_cases(model: Encoder) -> None:
"""Test deduplicate() edge cases: coercion, None rejection, empty records, type mismatches."""
semhash = SemHash.from_records(["1", "2", "3"], model=model)
Expand Down