diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py
index cfafb58a..c78ffda2 100644
--- a/bertopic/_bertopic.py
+++ b/bertopic/_bertopic.py
@@ -5006,12 +5006,18 @@ def add_mappings(self, mappings: Mapping[int, int], topic_model: BERTopic):
     def add_new_topics(self, mappings: Mapping[int, int]):
         """Add new row(s) of topic mappings.
 
+        New topics did not exist at earlier states, so the intermediate
+        history columns are backfilled with the topic's own ``key`` to
+        keep ``mappings_`` a homogeneous integer matrix. Without this,
+        ``None`` placeholders break ``model.save(serialization="safetensors")``
+        which casts the matrix to ``np.array(..., dtype=int)``.
+
         Arguments:
             mappings: The mappings to add
         """
         length = len(self.mappings_[0])
         for key, value in mappings.items():
-            to_append = [key] + ([None] * (length - 2)) + [value]
+            to_append = [key] * (length - 1) + [value]
             self.mappings_.append(to_append)
 
 
diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py
index 0d315f89..68b9e36c 100644
--- a/tests/test_bertopic.py
+++ b/tests/test_bertopic.py
@@ -1,6 +1,8 @@
 import copy
 
 import pytest
+import numpy as np
 from bertopic import BERTopic
+from bertopic._bertopic import TopicMapper
 
 import importlib.util
@@ -153,3 +155,34 @@ def test_full_model(model, documents, request):
     merged_model = BERTopic.merge_models([topic_model, topic_model1])
 
     assert len(merged_model.get_topic_info()) > len(topic_model.get_topic_info())
+
+
+def test_topic_mapper_add_new_topics_keeps_integer_matrix():
+    """Regression test for #2432: ``TopicMapper.add_new_topics`` must keep
+    ``mappings_`` as a homogeneous integer matrix.
+
+    Previously, ``add_new_topics`` inserted ``None`` placeholders for the
+    intermediate history columns, which broke ``model.save()`` because
+    ``_save_utils.save_topics`` casts the matrix to ``np.array(..., dtype=int)``.
+    """
+    mapper = TopicMapper(topics=[-1, 0, 1, 2])
+    # Simulate two prior reduce_topics calls so the matrix has more than 2
+    # columns (the buggy ``length - 2`` path is hidden when ``__init__``'s
+    # default 2-column shape is used).
+    for row in mapper.mappings_:
+        row.append(row[-1])
+        row.append(row[-1])
+    pre_existing = [list(row) for row in mapper.mappings_]
+
+    # New clusters discovered during partial_fit
+    mapper.add_new_topics({3: 2, 4: 3})
+
+    # The matrix must round-trip through ``np.array(..., dtype=int)``
+    # (mirrors what ``_save_utils.save_topics`` does).
+    matrix = np.array(mapper.mappings_, dtype=int)
+    # Pre-existing rows must be untouched.
+    assert mapper.mappings_[: len(pre_existing)] == pre_existing
+    # Original and current state of new rows must be preserved, and the
+    # intermediate history columns must be backfilled with the topic's own key.
+    assert (matrix[-2, :-1] == 3).all() and matrix[-2, -1] == 2
+    assert (matrix[-1, :-1] == 4).all() and matrix[-1, -1] == 3