diff --git a/CHANGELOG.md b/CHANGELOG.md index afd0a77ebc8..f4bbefada7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ This project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +### Fixed +- Fix issue where user-specified `color_continuous_scale` was ignored when template had `autocolorscale=True` [[#5439](https://github.com/plotly/plotly.py/pull/5439)] + ### Updated - Speed up `validate_gantt` function [[#5386](https://github.com/plotly/plotly.py/pull/5386)], with thanks to @misrasaurabh1 for the contribution! diff --git a/plotly/express/_core.py b/plotly/express/_core.py index d2dbc84c0e7..dde9e33888a 100644 --- a/plotly/express/_core.py +++ b/plotly/express/_core.py @@ -2486,6 +2486,11 @@ def get_groups_and_orders(args, grouper): def make_figure(args, constructor, trace_patch=None, layout_patch=None): trace_patch = trace_patch or {} layout_patch = layout_patch or {} + # Track if color_continuous_scale was explicitly provided by user + # (before apply_default_cascade fills it from template/defaults) + user_provided_colorscale = ( + "color_continuous_scale" in args and args["color_continuous_scale"] is not None + ) apply_default_cascade(args) args = build_dataframe(args, constructor) @@ -2704,7 +2709,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): range_color = args["range_color"] or [None, None] colorscale_validator = ColorscaleValidator("colorscale", "make_figure") - layout_patch["coloraxis1"] = dict( + coloraxis_dict = dict( colorscale=colorscale_validator.validate_coerce( args["color_continuous_scale"] ), @@ -2715,6 +2720,11 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): title_text=get_decorated_label(args, args[colorvar], colorvar) ), ) + # Set autocolorscale=False if user explicitly provided colorscale. Otherwise a template + # that sets autocolorscale=True would override the user provided colorscale. + if user_provided_colorscale: + coloraxis_dict["autocolorscale"] = False + layout_patch["coloraxis1"] = coloraxis_dict for v in ["height", "width"]: if args[v]: layout_patch[v] = args[v] diff --git a/plotly/express/_imshow.py b/plotly/express/_imshow.py index ce6ddb84286..ddf83cf3bc2 100644 --- a/plotly/express/_imshow.py +++ b/plotly/express/_imshow.py @@ -233,6 +233,11 @@ def imshow( axes labels and ticks. """ args = locals() + # Track if color_continuous_scale was explicitly provided by user + # (before apply_default_cascade fills it from template/defaults) + user_provided_colorscale = ( + "color_continuous_scale" in args and args["color_continuous_scale"] is not None + ) apply_default_cascade(args) labels = labels.copy() nslices_facet = 1 @@ -419,7 +424,7 @@ def imshow( layout["xaxis"] = dict(scaleanchor="y", constrain="domain") layout["yaxis"]["constrain"] = "domain" colorscale_validator = ColorscaleValidator("colorscale", "imshow") - layout["coloraxis1"] = dict( + coloraxis_dict = dict( colorscale=colorscale_validator.validate_coerce( args["color_continuous_scale"] ), @@ -427,6 +432,11 @@ def imshow( cmin=zmin, cmax=zmax, ) + # Set autocolorscale=False if user explicitly provided colorscale. Otherwise a template + # that sets autocolorscale=True would override the user provided colorscale. + if user_provided_colorscale: + coloraxis_dict["autocolorscale"] = False + layout["coloraxis1"] = coloraxis_dict if labels["color"]: layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"]) diff --git a/tests/test_optional/test_px/test_colors.py b/tests/test_optional/test_px/test_colors.py index 8f6e599d88a..d41dc7aa6e5 100644 --- a/tests/test_optional/test_px/test_colors.py +++ b/tests/test_optional/test_px/test_colors.py @@ -60,3 +60,24 @@ def test_color_categorical_dtype(): px.scatter( df[df.day != df.day.cat.categories[0]], x="total_bill", y="tip", color="day" ) + + +def test_color_continuous_scale_autocolorscale(): + # User-provided colorscale should override template autocolorscale=True + fig = px.scatter( + x=[1, 2], + y=[1, 2], + color=[1, 2], + color_continuous_scale="Viridis", + template=dict(layout_coloraxis_autocolorscale=True), + ) + assert fig.layout.coloraxis1.autocolorscale is False + + # Without user-provided colorscale, template autocolorscale should be respected + fig2 = px.scatter( + x=[1, 2], + y=[1, 2], + color=[1, 2], + template=dict(layout_coloraxis_autocolorscale=True), + ) + assert fig2.layout.coloraxis1.autocolorscale is None diff --git a/tests/test_optional/test_px/test_imshow.py b/tests/test_optional/test_px/test_imshow.py index 86e843a6ff9..bd75e8374de 100644 --- a/tests/test_optional/test_px/test_imshow.py +++ b/tests/test_optional/test_px/test_imshow.py @@ -98,6 +98,23 @@ def test_colorscale(): assert fig.layout.coloraxis1.colorscale[0] == (0.0, "#440154") +def test_imshow_color_continuous_scale_autocolorscale(): + # User-provided colorscale should override template autocolorscale=True + fig = px.imshow( + img_gray, + color_continuous_scale="Viridis", + template=dict(layout_coloraxis_autocolorscale=True), + ) + assert fig.layout.coloraxis1.autocolorscale is False + + # Without user-provided colorscale, template autocolorscale should be respected + fig2 = px.imshow( + img_gray, + template=dict(layout_coloraxis_autocolorscale=True), + ) + assert fig2.layout.coloraxis1.autocolorscale is None + + def test_wrong_dimensions(): imgs = [1, np.ones((5,) * 3), np.ones((5,) * 4)] msg = "px.imshow only accepts 2D single-channel, RGB or RGBA images."