diff --git a/drevalpy/models/baselines/sklearn_models.py b/drevalpy/models/baselines/sklearn_models.py index 0a00286f..42d91dd2 100644 --- a/drevalpy/models/baselines/sklearn_models.py +++ b/drevalpy/models/baselines/sklearn_models.py @@ -335,7 +335,7 @@ def build_model(self, hyperparameters: dict): Builds the model from hyperparameters. :param hyperparameters: Hyperparameters for the model. Contains n_estimators, criterion, max_samples, - and n_jobs. + max_depth and n_jobs. """ super().build_model(hyperparameters) if self.hyperparameters["max_depth"] == "None": @@ -344,6 +344,7 @@ def build_model(self, hyperparameters: dict): n_estimators=self.hyperparameters["n_estimators"], criterion=self.hyperparameters["criterion"], max_samples=self.hyperparameters["max_samples"], + max_depth=self.hyperparameters["max_depth"], n_jobs=self.hyperparameters["n_jobs"], ) diff --git a/tests/models/test_baselines.py b/tests/models/test_baselines.py index 7168713d..410bb8e8 100644 --- a/tests/models/test_baselines.py +++ b/tests/models/test_baselines.py @@ -20,10 +20,34 @@ NaiveTissueDrugMeanPredictor, NaiveTissueMeanPredictor, ) -from drevalpy.models.baselines.sklearn_models import SklearnModel +from drevalpy.models.baselines.sklearn_models import RandomForest, SklearnModel from drevalpy.models.drp_model import DRPModel +@pytest.mark.parametrize("max_depth_input, expected", [(5, 5), (10, 10), (30, 30), ("None", None)]) +def test_random_forest_respects_max_depth(max_depth_input, expected) -> None: + """Ensure RandomForest forwards max_depth to the underlying RandomForestRegressor. + + Regression test: max_depth was read from the hyperparameters but never passed to the + RandomForestRegressor constructor, so every forest was built with the default max_depth=None + regardless of the configured value. + + :param max_depth_input: max_depth value as provided via the hyperparameters + :param expected: max_depth expected on the built sklearn model + """ + model = RandomForest() + model.build_model( + { + "n_estimators": 10, + "criterion": "squared_error", + "max_samples": 0.5, + "n_jobs": 1, + "max_depth": max_depth_input, + } + ) + assert model.model.max_depth == expected + + @pytest.mark.parametrize( "model_name", [