Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion drevalpy/models/baselines/sklearn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def build_model(self, hyperparameters: dict):
Builds the model from hyperparameters.

:param hyperparameters: Hyperparameters for the model. Contains n_estimators, criterion, max_samples,
and n_jobs.
max_depth and n_jobs.
"""
super().build_model(hyperparameters)
if self.hyperparameters["max_depth"] == "None":
Expand All @@ -344,6 +344,7 @@ def build_model(self, hyperparameters: dict):
n_estimators=self.hyperparameters["n_estimators"],
criterion=self.hyperparameters["criterion"],
max_samples=self.hyperparameters["max_samples"],
max_depth=self.hyperparameters["max_depth"],
n_jobs=self.hyperparameters["n_jobs"],
)

Expand Down
26 changes: 25 additions & 1 deletion tests/models/test_baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,34 @@
NaiveTissueDrugMeanPredictor,
NaiveTissueMeanPredictor,
)
from drevalpy.models.baselines.sklearn_models import SklearnModel
from drevalpy.models.baselines.sklearn_models import RandomForest, SklearnModel
from drevalpy.models.drp_model import DRPModel


@pytest.mark.parametrize("max_depth_input, expected", [(5, 5), (10, 10), (30, 30), ("None", None)])
def test_random_forest_respects_max_depth(max_depth_input, expected) -> None:
"""Ensure RandomForest forwards max_depth to the underlying RandomForestRegressor.

Regression test: max_depth was read from the hyperparameters but never passed to the
RandomForestRegressor constructor, so every forest was built with the default max_depth=None
regardless of the configured value.

:param max_depth_input: max_depth value as provided via the hyperparameters
:param expected: max_depth expected on the built sklearn model
"""
model = RandomForest()
model.build_model(
{
"n_estimators": 10,
"criterion": "squared_error",
"max_samples": 0.5,
"n_jobs": 1,
"max_depth": max_depth_input,
}
)
assert model.model.max_depth == expected


@pytest.mark.parametrize(
"model_name",
[
Expand Down
Loading