Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased

### Fixed
- Fix issue where user-specified `color_continuous_scale` was ignored when template had `autocolorscale=True` [[#5439](https://github.com/plotly/plotly.py/pull/5439)]

### Updated
- Speed up `validate_gantt` function [[#5386](https://github.com/plotly/plotly.py/pull/5386)], with thanks to @misrasaurabh1 for the contribution!

Expand Down
12 changes: 11 additions & 1 deletion plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2486,6 +2486,11 @@ def get_groups_and_orders(args, grouper):
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
trace_patch = trace_patch or {}
layout_patch = layout_patch or {}
# Track if color_continuous_scale was explicitly provided by user
# (before apply_default_cascade fills it from template/defaults)
user_provided_colorscale = (
"color_continuous_scale" in args and args["color_continuous_scale"] is not None
)
apply_default_cascade(args)

args = build_dataframe(args, constructor)
Expand Down Expand Up @@ -2704,7 +2709,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
range_color = args["range_color"] or [None, None]

colorscale_validator = ColorscaleValidator("colorscale", "make_figure")
layout_patch["coloraxis1"] = dict(
coloraxis_dict = dict(
colorscale=colorscale_validator.validate_coerce(
args["color_continuous_scale"]
),
Expand All @@ -2715,6 +2720,11 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
title_text=get_decorated_label(args, args[colorvar], colorvar)
),
)
# Set autocolorscale=False if user explicitly provided colorscale. Otherwise a template
# that sets autocolorscale=True would override the user provided colorscale.
if user_provided_colorscale:
coloraxis_dict["autocolorscale"] = False
layout_patch["coloraxis1"] = coloraxis_dict
for v in ["height", "width"]:
if args[v]:
layout_patch[v] = args[v]
Expand Down
12 changes: 11 additions & 1 deletion plotly/express/_imshow.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ def imshow(
axes labels and ticks.
"""
args = locals()
# Track if color_continuous_scale was explicitly provided by user
# (before apply_default_cascade fills it from template/defaults)
user_provided_colorscale = (
"color_continuous_scale" in args and args["color_continuous_scale"] is not None
)
apply_default_cascade(args)
labels = labels.copy()
nslices_facet = 1
Expand Down Expand Up @@ -419,14 +424,19 @@ def imshow(
layout["xaxis"] = dict(scaleanchor="y", constrain="domain")
layout["yaxis"]["constrain"] = "domain"
colorscale_validator = ColorscaleValidator("colorscale", "imshow")
layout["coloraxis1"] = dict(
coloraxis_dict = dict(
colorscale=colorscale_validator.validate_coerce(
args["color_continuous_scale"]
),
cmid=color_continuous_midpoint,
cmin=zmin,
cmax=zmax,
)
# Set autocolorscale=False if user explicitly provided colorscale. Otherwise a template
# that sets autocolorscale=True would override the user provided colorscale.
if user_provided_colorscale:
coloraxis_dict["autocolorscale"] = False
layout["coloraxis1"] = coloraxis_dict
if labels["color"]:
layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"])

Expand Down
21 changes: 21 additions & 0 deletions tests/test_optional/test_px/test_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,24 @@ def test_color_categorical_dtype():
px.scatter(
df[df.day != df.day.cat.categories[0]], x="total_bill", y="tip", color="day"
)


def test_color_continuous_scale_autocolorscale():
# User-provided colorscale should override template autocolorscale=True
fig = px.scatter(
x=[1, 2],
y=[1, 2],
color=[1, 2],
color_continuous_scale="Viridis",
template=dict(layout_coloraxis_autocolorscale=True),
)
assert fig.layout.coloraxis1.autocolorscale is False

# Without user-provided colorscale, template autocolorscale should be respected
fig2 = px.scatter(
x=[1, 2],
y=[1, 2],
color=[1, 2],
template=dict(layout_coloraxis_autocolorscale=True),
)
assert fig2.layout.coloraxis1.autocolorscale is None
17 changes: 17 additions & 0 deletions tests/test_optional/test_px/test_imshow.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,23 @@ def test_colorscale():
assert fig.layout.coloraxis1.colorscale[0] == (0.0, "#440154")


def test_imshow_color_continuous_scale_autocolorscale():
# User-provided colorscale should override template autocolorscale=True
fig = px.imshow(
img_gray,
color_continuous_scale="Viridis",
template=dict(layout_coloraxis_autocolorscale=True),
)
assert fig.layout.coloraxis1.autocolorscale is False

# Without user-provided colorscale, template autocolorscale should be respected
fig2 = px.imshow(
img_gray,
template=dict(layout_coloraxis_autocolorscale=True),
)
assert fig2.layout.coloraxis1.autocolorscale is None


def test_wrong_dimensions():
imgs = [1, np.ones((5,) * 3), np.ones((5,) * 4)]
msg = "px.imshow only accepts 2D single-channel, RGB or RGBA images."
Expand Down