Fix UpliftRandomForest predict shape mismatch with multiple treatments (#569) #884
jeongyoonlee wants to merge 5 commits into master
Conversation
Fix UpliftRandomForest predict shape mismatch with multiple treatments (#569)

Bootstrap sampling can exclude entire treatment groups from a tree's training data, causing individual trees to produce prediction arrays of different widths. When summing predictions across trees, this causes a ValueError for shape mismatch. Added `_align_tree_predict()` that maps each tree's predictions to the forest-level class ordering, filling zeros for missing treatment groups. This is a module-level function (not a closure) so it works with joblib's parallel pickling.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
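For reference, the alignment logic can be sketched as a standalone function (a simplified pure-Python version; the actual helper in `uplift.pyx` takes a fitted tree and calls `tree.predict(X)` itself):

```python
import numpy as np

def align_tree_predict(raw, tree_classes, forest_classes):
    """Map one tree's prediction columns onto the forest-level class
    ordering, filling zero columns for treatment groups the tree
    never saw in its bootstrap sample."""
    if len(tree_classes) == len(forest_classes):
        return raw
    aligned = np.zeros((raw.shape[0], len(forest_classes)), dtype=raw.dtype)
    class_to_forest_idx = {cls: idx for idx, cls in enumerate(forest_classes)}
    for tree_idx, cls in enumerate(tree_classes):
        forest_idx = class_to_forest_idx.get(cls)
        if forest_idx is not None:
            aligned[:, forest_idx] = raw[:, tree_idx]
    return aligned

# A tree trained without 'treatment_B' produces 2 columns instead of 3.
raw = np.array([[0.1, 0.4], [0.2, 0.5]])
out = align_tree_predict(raw, ["control", "treatment_A"],
                         ["control", "treatment_A", "treatment_B"])
print(out.shape)  # (2, 3); the treatment_B column is all zeros
```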
Pull request overview
Fixes UpliftRandomForestClassifier.predict() failing with shape-mismatch errors when bootstrap sampling yields trees trained without all treatment groups, by aligning per-tree prediction outputs to the forest-level class ordering.
Changes:
- Add `_align_tree_predict()` helper to map each tree's predictions into the forest's `classes_` layout, filling missing treatment columns with zeros.
- Use `_align_tree_predict()` in both parallel (joblib) and non-parallel prediction paths.
- Update `.gitignore` entries for Claude-related worktree artifacts.
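The underlying failure is plain NumPy broadcasting: summing per-tree outputs of different widths raises exactly the reported error (a minimal illustration, not code from the PR):

```python
import numpy as np

# One tree saw all 4 response classes; another tree's bootstrap
# sample dropped a treatment group and produced only 3 columns.
pred_full = np.zeros((5, 4))
pred_missing = np.zeros((5, 3))

err_name = None
try:
    _ = pred_full + pred_missing  # what summing per-tree predictions does
except ValueError as err:
    err_name = type(err).__name__
    print(err)  # operands could not be broadcast together with shapes (5,4) (5,3)
print(err_name)  # ValueError
```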
Reviewed changes
Copilot reviewed 1 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `causalml/inference/tree/uplift.pyx` | Adds per-tree prediction alignment to prevent ensemble summation broadcast errors when some trees lack treatment groups. |
| `.gitignore` | Adjusts ignored Claude/worktree artifact paths (currently with a duplicated entry). |
causalml/inference/tree/uplift.pyx
Outdated
```python
for tree_idx, cls in enumerate(tree.classes_):
    if cls in forest_classes:
        forest_idx = forest_classes.index(cls)
```
_align_tree_predict does an O(k^2) alignment by repeatedly checking cls in forest_classes and calling forest_classes.index(cls) inside the loop. This can become a noticeable overhead when there are many treatment groups and/or many trees. Consider precomputing a {class_label: forest_idx} mapping once per forest predict call (and passing it into the helper) or building it once per helper invocation, then doing direct lookups; this also makes the logic clearer and avoids repeated linear scans.
Suggested change:

```diff
-for tree_idx, cls in enumerate(tree.classes_):
-    if cls in forest_classes:
-        forest_idx = forest_classes.index(cls)
+# Precompute mapping from class label to forest index to avoid repeated
+# linear scans over forest_classes.
+class_to_forest_idx = {cls: idx for idx, cls in enumerate(forest_classes)}
+for tree_idx, cls in enumerate(tree.classes_):
+    forest_idx = class_to_forest_idx.get(cls)
+    if forest_idx is not None:
```
causalml/inference/tree/uplift.pyx
Outdated
```python
raw = tree.predict(X=X)
if len(tree.classes_) == len(forest_classes):
    return raw
aligned = np.zeros((raw.shape[0], len(forest_classes)))
```
aligned = np.zeros((raw.shape[0], len(forest_classes))) will always default to float64, potentially changing dtype vs the tree output and doing unnecessary upcasting. Consider setting dtype=raw.dtype so the aligned array matches the per-tree prediction dtype.
Suggested change:

```diff
-aligned = np.zeros((raw.shape[0], len(forest_classes)))
+aligned = np.zeros((raw.shape[0], len(forest_classes)), dtype=raw.dtype)
```
causalml/inference/tree/uplift.pyx
Outdated
```diff
 if self.n_jobs != 1:
     y_pred_ensemble = sum(
         Parallel(n_jobs=self.n_jobs, prefer=self.joblib_prefer)
-        (delayed(tree.predict)(X=X) for tree in self.uplift_forest)
+        (delayed(_align_tree_predict)(tree, X, self.classes_) for tree in self.uplift_forest)
     ) / len(self.uplift_forest)
 else:
-    y_pred_ensemble = sum([tree.predict(X=X) for tree in self.uplift_forest]) / len(self.uplift_forest)
+    y_pred_ensemble = sum([_align_tree_predict(tree, X, self.classes_) for tree in self.uplift_forest]) / len(self.uplift_forest)
```
There’s no targeted regression test ensuring that UpliftRandomForestClassifier.predict() works when bootstrap sampling causes at least one tree to be fit without some treatment groups (the scenario from #569). Current tests use large, balanced samples (see tests/const.py), which makes it unlikely for bootstraps to drop a group. Consider adding a small multi-treatment dataset (or a deterministic bootstrap seed) that reproduces the missing-group case and asserts predict() succeeds for both n_jobs=1 and n_jobs>1.
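The missing-group condition is easy to provoke deliberately. A pure-NumPy sketch of the bootstrap behaviour such a test needs to trigger (the group sizes here are illustrative choices, not necessarily what the test ended up using):

```python
import numpy as np

rng = np.random.default_rng(42)

# 102 samples, with a single sample in each of two minority groups.
treatment = np.array(["control"] * 50 + ["treatment_A"] * 50
                     + ["treatment_B", "treatment_C"])
n = len(treatment)

n_trees = 10
trees_missing_a_group = 0
for _ in range(n_trees):
    idx = rng.integers(0, n, size=n)  # bootstrap with replacement
    seen = set(treatment[idx])
    if not {"treatment_B", "treatment_C"} <= seen:
        trees_missing_a_group += 1

# Each 1-sample group is absent from a given bootstrap with probability
# (1 - 1/102)**102, about 0.37, so nearly every run hits the condition.
print(trees_missing_a_group)
```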
- Use dict for O(1) class-to-index mapping instead of repeated list scans
- Preserve dtype with dtype=raw.dtype in aligned array
- Add test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Thanks for the review. All three suggestions have been addressed.
Pull request overview
Copilot reviewed 2 out of 3 changed files in this pull request and generated 4 comments.
```python
aligned = np.zeros((raw.shape[0], len(forest_classes)), dtype=raw.dtype)
class_to_forest_idx = {cls: idx for idx, cls in enumerate(forest_classes)}
for tree_idx, cls in enumerate(tree.classes_):
    forest_idx = class_to_forest_idx.get(cls)
    if forest_idx is not None:
        aligned[:, forest_idx] = raw[:, tree_idx]
```
class_to_forest_idx is rebuilt for every tree prediction call. Consider constructing this mapping once in UpliftRandomForestClassifier.predict() and passing it (or passing precomputed forest indices for tree.classes_) to reduce overhead when n_estimators or number of treatment groups is large.
- Build class_to_forest_idx dict once in predict() instead of per tree
- Use model.n_jobs instead of parallel_backend for parallel test
- Assert that sparse-group condition actually occurred in test

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Addressed the new feedback.
With only 1 sample per minority treatment group out of 102 total, bootstrap sampling will miss them in most trees, making the test deterministic regardless of seed or CI environment.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Fixed the probabilistic test failure: reduced minority treatment groups from 5 samples each to 1 sample each (out of 102 total). With bootstrap sampling drawing n=102 with replacement, each tree has a ~37% chance of missing a 1-sample group. Across 10 trees and 2 minority groups, the probability of the sparse-group condition not occurring is ~0.01%, making the test effectively deterministic.
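The arithmetic behind those figures can be checked directly (using the same independence approximation as the estimate above):

```python
n = 102
# Chance that a bootstrap of size n (with replacement) misses a
# specific 1-sample group: (1 - 1/n)**n, close to e**-1.
p_miss = (1 - 1 / n) ** n                 # ~ 0.366
# Chance that a single tree draws BOTH 1-sample minority groups
# (treating the two groups as independent, a close approximation):
p_tree_sees_both = (1 - p_miss) ** 2
# Chance the sparse-group condition never occurs across 10 trees:
p_never = p_tree_sees_both ** 10          # ~ 1e-4, i.e. ~0.01%
print(p_miss, p_never)
```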
Summary
- Fixes `ValueError: operands could not be broadcast together with shapes (N,4) (N,3)`
- Adds `_align_tree_predict()`, a module-level function that maps each tree's predictions to the forest-level class ordering, filling zeros for missing groups

Test plan

- `pytest tests/test_uplift_trees.py`: 23 passed

🤖 Generated with Claude Code