diff --git a/packages/python/plotly/plotly/basedatatypes.py b/packages/python/plotly/plotly/basedatatypes.py index f2033cf359a..623e976dc2d 100644 --- a/packages/python/plotly/plotly/basedatatypes.py +++ b/packages/python/plotly/plotly/basedatatypes.py @@ -391,6 +391,9 @@ class is a subclass of both BaseFigure and widgets.DOMWidget. self._animation_duration_validator = animation.DurationValidator() self._animation_easing_validator = animation.EasingValidator() + # Space for auxiliary data + self._aux = dict() + # Template # -------- # ### Check for default template ### diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 7ad2fb4eb01..727ecd687a8 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2057,6 +2057,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): configure_axes(args, constructor, fig, orders) configure_animation_controls(args, constructor, fig) + # store args in figure metadata + fig._aux["px"] = dict(args=args) return fig diff --git a/proto/px_overlay/README.md b/proto/px_overlay/README.md new file mode 100644 index 00000000000..a64dc2ea733 --- /dev/null +++ b/proto/px_overlay/README.md @@ -0,0 +1,15 @@ +# `px.overlay` prototype + +This demonstrates one possible way of combining two figures into a single +figure. + +To see an example, run (from the root of the `plotly.py` repo): + +```bash +PYTHONPATH=proto/px_overlay python proto/px_overlay/multilayered_data_test.py +``` + +To see the code that does the overlaying, start with the `px_simple_overlay` +function in `proto/px_overlay/px_overlay.py`. In this function there are a few +comments marked with `TODO` that indicate places for improvement in the +functionality. diff --git a/proto/px_overlay/facet_col_wrap_test.py b/proto/px_overlay/facet_col_wrap_test.py new file mode 100644 index 00000000000..ae41ae67c61 --- /dev/null +++ b/proto/px_overlay/facet_col_wrap_test.py @@ -0,0 +1,27 @@ +import plotly.express as px +import test_data +from px_combine import px_simple_combine + +df = test_data.multilayered_data(d_divs=[6, 3, 2], rwalk=0.1) +last_cat = df.columns[2] +last_cat_types = list(set(df[last_cat])) +fig0 = px.line( + df.loc[df[last_cat] == last_cat_types[0]], + x="x", + y="y", + facet_col=df.columns[0], + facet_col_wrap=3, + color=df.columns[1], +).update_layout(title="%s=%s" % (last_cat, last_cat_types[0])) +fig1 = px.line( + df.loc[df[last_cat] == last_cat_types[1]], + x="x", + y="y", + facet_col=df.columns[0], + facet_col_wrap=3, + color=df.columns[1], +).update_layout(title="%s=%s" % (last_cat, last_cat_types[1])) +fig = px_simple_combine(fig0, fig1, fig1_secondary_y=True) +fig0.show() +fig1.show() +fig.show() diff --git a/proto/px_overlay/find_field.py b/proto/px_overlay/find_field.py new file mode 100644 index 00000000000..c173ae05107 --- /dev/null +++ b/proto/px_overlay/find_field.py @@ -0,0 +1,43 @@ +import plotly.graph_objects as go +from plotly import basedatatypes + +# Search down an object's composition tree and find fields with a given name + + +def find_field(obj, field, basepath="", max_path_len=80, forbidden=["parent"]): + if obj is not None and len(basepath) < max_path_len: + for f in dir(obj): + joined_path = ".".join([basepath, f]) + if f == field: + print(joined_path) + if ( + (f not in forbidden) + and (not f.startswith("_")) + and (not f.endswith("_")) + ): + find_field(eval("obj.%s" % (f,)), field, joined_path) + + +def find_all_xy_traces(): + for field in dir(go): + call_str = "go.%s" % (field,) + call = eval(call_str) + try: + if issubclass(call, basedatatypes.BaseTraceType): + obj = call() + if "xaxis" in obj and "yaxis" in obj: + yield (call_str) + except TypeError: + pass + + +# s=go.Scatter() +# s=go.Bar() +# find_field(s,"color",basepath="scatter") +# print() +# find_field(s,"color",basepath="bar") + +for call_str in find_all_xy_traces(): + call = eval(call_str) + find_field(call(), "color", basepath=call_str) + print() diff --git a/proto/px_overlay/map_axis_pair_example.py b/proto/px_overlay/map_axis_pair_example.py new file mode 100644 index 00000000000..a22c26136dd --- /dev/null +++ b/proto/px_overlay/map_axis_pair_example.py @@ -0,0 +1,29 @@ +import plotly.graph_objects as go +from plotly.subplots import make_subplots +import px_overlay +import pytest + +fig0 = px_overlay.make_subplots_all_secondary_y(3, 4) +fig1 = px_overlay.make_subplots_all_secondary_y(4, 5) + +for dims, f in zip([(3, 4), (4, 5)], [fig0, fig1]): + for r, c in px_overlay.multi_index(*dims): + for sy in [False, True]: + f.add_trace(go.Scatter(x=[], y=[]), row=r + 1, col=c + 1, secondary_y=sy) + +fig0.add_annotation(row=2, col=3, text="hi", x=0.25, xref="x domain", y=3) +fig0.add_annotation( + row=3, col=4, text="hi", x=0.25, xref="x domain", y=2, secondary_y=True +) + +for an in fig0.layout.annotations: + oldaxpair = tuple([an[ref] for ref in ["xref", "yref"]]) + newaxpair = px_overlay.map_axis_pair(fig0, fig1, oldaxpair) + newan = go.layout.Annotation(an) + print(oldaxpair) + print(newaxpair) + newan["xref"], newan["yref"] = newaxpair + fig1.add_annotation(newan) + +fig0.show() +fig1.show() diff --git a/proto/px_overlay/multilayered_data_test.py b/proto/px_overlay/multilayered_data_test.py new file mode 100644 index 00000000000..6f59c2dcfaa --- /dev/null +++ b/proto/px_overlay/multilayered_data_test.py @@ -0,0 +1,56 @@ +import test_data +import numpy as np +import plotly.express as px +from px_overlay import px_simple_overlay + +# Demonstrates px_overlay prototype. + +# Make some data that can be faceted by row, col and color, and split into 2 +# sets, which will go to the first and second figure respectively. +df = test_data.multilayered_data(d_divs=[2, 3, 4, 2], rwalk=0.1) + +# The titles of the figures use the last dimension in the data. The title is +# formatted "column_name=column_value", so here we extract the column name. +last_cat = df.columns[3] +figs = [] +for px_call, last_cat_0 in zip([px.line, px.bar], list(set(df[last_cat]))): + # px_call is the chart type to make and last_cat_0 is the column_value for + # that figure which is used in forming the title. + df_slice = df.loc[df[last_cat] == last_cat_0] + fig = px_call( + df_slice, + x="x", + y="y", + facet_row=df.columns[0], + facet_col=df.columns[1], + color=df.columns[2], + ) + + fig.update_layout(title="%s=%s" % (last_cat, last_cat_0,)) + figs.append(fig) + +# Add some annotations to make sure they are copied to the final figure properly +figs[0].add_hline(y=1, row=1, col="all") +figs[0].add_annotation( + x=0.25, y=0.5, xref="x domain", yref="y domain", row=2, col=3, text="yo" +) +# Note that these annotations should be mapped to a secondary y axis (observe this +# in the final figure by dragging their corresponding secondary y axes). +figs[1].add_vline(x=10, row="all", col=2) +figs[1].add_annotation( + x=0.5, y=0.35, xref="x domain", yref="y", row=1, col=2, text="budday" +) +# Set the bar modes for both to see that the first figure that the barmode for +# the final figure will be taken from the figure that has bars. +figs[0].layout.barmode = "group" +figs[1].layout.barmode = "relative" + +# overlay the figures +final_fig = px_simple_overlay(*figs, fig1_secondary_y=True) + +# Show the initial figures +for fig in figs: + fig.show() + +# Show the final figure +final_fig.show() diff --git a/proto/px_overlay/px_overlay.py b/proto/px_overlay/px_overlay.py new file mode 100644 index 00000000000..3ebe32330ed --- /dev/null +++ b/proto/px_overlay/px_overlay.py @@ -0,0 +1,360 @@ +# Prototype for px.overlay +# Combine 2 figures containing subplots +# Run as +# python px_overlay.py + +import plotly.express as px +import plotly.graph_objects as go +from plotly.subplots import make_subplots +import test_data +import json +from itertools import product, cycle, chain +from functools import reduce +import re + + +def multi_index(*kwargs): + return product(*[range(k) for k in kwargs]) + + +def extract_axes(layout): + ret = dict() + for k in dir(layout): + if k[1 : 1 + len("axis")] == "axis": + ret[k] = layout[k] + return ret + + +def fig_grid_ref_shape(fig): + grid_ref = fig._validate_get_grid_ref() + return (len(grid_ref), len(grid_ref[0])) + + +def fig_subplot_axes(fig, r, c): + grid_ref = fig._validate_get_grid_ref() + return [fig.layout[k] for k in grid_ref[r - 1][c - 1][0].layout_keys] + + +def extract_axis_titles(fig): + """ + Given figure created using make_subplots, with r rows and c columns, return + r titles from the x axes and y titles from the y axes. + """ + grid_ref_shape = fig_grid_ref_shape(fig) + r_titles = [ + fig_subplot_axes(fig, r + 1, 1)[1]["title"] for r in range(grid_ref_shape[0]) + ] + c_titles = [ + fig_subplot_axes(fig, 1, c + 1)[0]["title"] for c in range(grid_ref_shape[1]) + ] + return (r_titles, c_titles) + + +def make_subplots_all_secondary_y(rows, cols): + """ + Get subplots like make_subplots but all also have secondary y-axes. + """ + grid_ref_shape = [rows, cols] + specs = [ + [dict(secondary_y=True) for __ in range(grid_ref_shape[1])] + for _ in range(grid_ref_shape[0]) + ] + fig = make_subplots(*grid_ref_shape, specs=specs) + return fig + + +def parse_axis_ref(ax): + """ Find the axis letter, optional number, and domain of axis. """ + # TODO: can this be obtained via codegen? + pat = re.compile("([xy])(axis)?([0-9]*)( domain)?") + matches = pat.match(ax) + if matches is None: + raise ValueError('Axis "%s" cannot be parsed.' % (ax,)) + return (matches[1], matches[3], matches[4]) + + +def norm_axis_ref(ax): + """ normalize ax so it is in the format: yaxis, yaxis2, xaxis7 etc. """ + al, an, _ = parse_axis_ref(ax) + return al + "axis" + an + + +def axis_pair_to_row_col(fig, axpair): + """ + returns the row and column of the subplot having the axis pair and whether it is a + secondary y + """ + if "paper" in axpair: + raise ValueError('Cannot find row and column of "paper" axis reference.') + naxpair = tuple([norm_axis_ref(ax) for ax in axpair]) + nrows, ncols = fig_grid_ref_shape(fig) + row = None + col = None + for r, c in multi_index(nrows, ncols): + for sp in fig._grid_ref[r][c]: + if naxpair == sp.layout_keys: + row = r + 1 + col = c + 1 + if row is None or col is None: + raise ValueError("Could not find subplot containing axes (%s,%s)." % nax) + secondary_y = False + yax = naxpair[1] + if fig.layout[yax]["side"] == "right": + secondary_y = True + return (row, col, secondary_y) + + +def find_subplot_axes(fig, row, col, secondary_y=False): + """ + Returns 2-tuple containing (xaxis,yaxis) at specified row, col and secondary y-axis. + """ + nrows, ncols = fig_grid_ref_shape(fig) + try: + sps = fig._grid_ref[row - 1][col - 1] + except (IndexError, TypeError): + # IndexError if fig has _grid_ref but not requested row or column, + # TypeError if fig has no _grid_ref (it is None) + raise IndexError( + "Figure does not have a subplot at the requested row or column." + ) + + def _check_is_secondary_y(sp): + xax, yax = sp.layout_keys + # TODO: It may not be totally accurate to assume if an y-axis' "side" is + # "right" than it is a secondary y axis... + return fig.layout[yax]["side"] == "right" + + # find the secondary y axis + err_msg = ( + "Could not find a y-axis " "at the subplot in the requested row or column." + ) + filter_fun = lambda sp: not _check_is_secondary_y(sp) + if secondary_y: + err_msg = ( + "Could not find a secondary y-axis " + "at the subplot in the requested row or column." + ) + filter_fun = _check_is_secondary_y + try: + sp = list(filter(filter_fun, sps))[0] + except (IndexError, TypeError): + # Catch IndexError if the list is empty, catch TypeError if sps isn't + # iterable (e.g., is None) + raise IndexError(err_msg) + return sp.layout_keys + + +def map_axis_pair( + old_fig, + new_fig, + axpair, + new_row=None, + new_col=None, + new_secondary_y=None, + make_axis_ref=True, +): + """ + Find the axes on the new figure that will give the same subplot and + possibly secondary y axis as on the old figure. This can only + work if the axis pair is ("paper","paper") or the axis pair corresponds to a + subplot on the old figure the new figure has corresponding rows, + columns and secondary y-axes. + if make_axis_ref is True, axis is removed from the resulting strings, e.g., xaxis2 -> x2 + """ + if None in axpair: + raise ValueError("Cannot map axis whose value is None.") + if axpair == ("paper", "paper"): + return axpair + row, col, secondary_y = axis_pair_to_row_col(old_fig, axpair) + row = new_row if new_row is not None else row + col = new_col if new_col is not None else col + secondary_y = new_secondary_y if new_secondary_y is not None else secondary_y + newaxpair = find_subplot_axes(new_fig, row, col, secondary_y) + axpair_extras = [" domain" if ax.endswith("domain") else "" for ax in axpair] + newaxpair = tuple(ax + extra for ax, extra in zip(newaxpair, axpair_extras)) + if make_axis_ref: + newaxpair = tuple(ax.replace("axis", "") for ax in newaxpair) + return newaxpair + + +def map_annotation_like_obj_axis(oldfig, newfig, an, force_secondary_y=False): + """ + Take an annotation-like object with xref and yref referring to axes in oldfig + and map them to axes in newfig. This makes it possible to map an annotation + to the same subplot row, column or secondary y in a new plot even if they do + not have matching subplots. + If force_secondary_y is True, attempt is made to map the annotation to a + secondary y axis in the new figure. + Returns the new annotation. Note that it has not been added to newfig, the + caller must then do this if it wants it added to newfig. + """ + oldaxpair = tuple([an[ref] for ref in ["xref", "yref"]]) + newaxpair = map_axis_pair( + oldfig, newfig, oldaxpair, new_secondary_y=force_secondary_y + ) + newan = an.__class__(an) + newan["xref"], newan["yref"] = newaxpair + return newan + + +def px_simple_overlay(fig0, fig1, fig1_secondary_y=False): + """ + Combines two figures by putting all the traces from fig0 and fig1 on a new + figure (fig). Then the annotation-like objects are copied to fig (i.e., the + titles are not copied). + The colors are reassigned so each trace has a unique color until all the + colors in the colorway are exhausted and then loops through the colorway to + assign additional colors (this is referred to as "reflowing" below). + In order to differentiate the traces in the legend, if fig0 or fig1 have + titles, they are prepended to the trace name. + If fig1_secondary_y is True, then the yaxes from fig1 are placed on + secondary y axes in the new figure. + """ + if fig1_secondary_y and ( + ("px" not in fig0._aux.keys()) or ("px" not in fig0._aux.keys()) + ): + raise ValueError( + "To place fig1's traces on secondary y-axes, both figures must have " + "been made with Plotly Express." + ) + grid_ref_shape = fig_grid_ref_shape(fig0) + if grid_ref_shape != fig_grid_ref_shape(fig1): + raise ValueError( + "Only two figures with the same subplot geometry can be overlayed." + ) + # get colors for reflowing + colorway = fig0.layout.template.layout.colorway + specs = None + if fig1_secondary_y: + specs = [ + [dict(secondary_y=True) for __ in range(grid_ref_shape[1])] + for _ in range(grid_ref_shape[0]) + ] + # TODO: This needs to detect the start_cell of the input figures rather than + # assuming 'bottom-left', which is just the px default start_cell + fig = make_subplots( + *fig_grid_ref_shape(fig0), specs=specs, start_cell="bottom-left" + ) + for r, c in multi_index(*fig_grid_ref_shape(fig)): + print("row,col", r + 1, c + 1) + for (tr, f), color in zip( + chain( + *[ + zip(f.select_traces(row=r + 1, col=c + 1), cycle([f]),) + for f in [fig0, fig1] + ] + ), + # reflow the colors + cycle(colorway), + ): + title = f.layout.title.text + set_main_trace_color(tr, color) + # use figure title to differentiate the legend items + tr["name"] = "%s %s" % (title, tr["name"]) + # TODO: argument to group legend items? + tr["legendgroup"] = None + fig.add_trace( + tr, row=r + 1, col=c + 1, secondary_y=(fig1_secondary_y and (f == fig1)) + ) + # TODO: How to preserve axis sizes when adding secondary y? + + # Map the axes of the annotation-like objects to the new figure. Map the + # fig1 objects to the secondary-y if requested. + selectors = product( + [fig0, fig1], + [ + go.Figure.select_annotations, + go.Figure.select_shapes, + go.Figure.select_layout_images, + ], + ) + adders = product( + [(fig, False), (fig, fig1_secondary_y)], + [go.Figure.add_annotation, go.Figure.add_shape, go.Figure.add_layout_image], + ) + for (oldfig, selector), ((newfig, secy), adder) in zip(selectors, adders): + for ann in selector(oldfig): + # TODO this function needs to eventually take into consideration the + # start_cell arguments of the figures involved in the mapping. + newann = map_annotation_like_obj_axis( + oldfig, newfig, ann, force_secondary_y=secy + ) + adder(newfig, newann) + + # fig.update_layout(fig0.layout) + # title will be wrong + fig.layout.title = None + # preserve bar mode + # if both figures have barmode set, the first is taken from the figure that + # has bars (so just the one from fig0 if both have bars), otherwise the set + # one is taken. + # TODO argument to force barmode? or the user can just update it after + fig.layout.barmode = get_first_set_barmode([fig0, fig1]) + return fig + + +def select_all_traces(figs): + traces = list( + reduce( + lambda a, b: a + b, + map(lambda t: list(go.Figure.select_traces(t)), figs), + [], + ) + ) + return traces + + +def check_trace_type_xy(tr): + return ("xaxis" in tr) and ("yaxis" in tr) + + +def check_figs_trace_types_xy(figs): + traces = select_all_traces(figs) + xy_traces = list(map(check_trace_type_xy, traces)) + return xy_traces + + +def set_main_trace_color(tr, color): + # Set the main color of a trace + if type(tr) == type(go.Scatter()): + if tr["mode"] == "lines": + tr["line_color"] = color + else: + tr["marker_color"] = color + elif type(tr) == type(go.Bar()): + tr["marker_color"] = color + + +def get_first_set_barmode(figs): + """ Get first bar mode from the figure that has it set and has bar traces. """ + + def _bar_mode_filter(f): + return ( + any([type(tr) == type(go.Bar()) for tr in f.data]) + and f.layout.barmode is not None + ) + + barmode = None + try: + barmode = [f.layout.barmode for f in filter(_bar_mode_filter, figs)][0] + except IndexError: + # if no figure sets barmode, then it is not set + pass + return barmode + + +def simple_overlay_example(): + df = test_data.aug_tips() + fig0 = px.scatter(df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker") + fig1 = px.histogram( + df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker" + ) + fig1.update_traces(marker_color="red") + fig = px_simple_overlay(fig0, fig1) + fig.update_layout(title="Simple figure combination") + return fig + + +if __name__ == "__main__": + fig_simple = simple_overlay_example() + fig_simple.show() diff --git a/proto/px_overlay/run_px_simple_overlay_demo b/proto/px_overlay/run_px_simple_overlay_demo new file mode 100755 index 00000000000..ea7891adde6 --- /dev/null +++ b/proto/px_overlay/run_px_simple_overlay_demo @@ -0,0 +1,2 @@ +#/bin/bash +PYTHONPATH=proto/px_combine python3 proto/px_combine/multilayered_data_test.py diff --git a/proto/px_overlay/secondary_y_test.py b/proto/px_overlay/secondary_y_test.py new file mode 100644 index 00000000000..ae7b49edc36 --- /dev/null +++ b/proto/px_overlay/secondary_y_test.py @@ -0,0 +1 @@ +# Put the second plot's y data on a secondary y diff --git a/proto/px_overlay/test_data.py b/proto/px_overlay/test_data.py new file mode 100644 index 00000000000..d45b4428e1b --- /dev/null +++ b/proto/px_overlay/test_data.py @@ -0,0 +1,66 @@ +import numpy as np +import plotly.express as px +import pandas as pd +from random import sample +from itertools import product +from functools import reduce + +# some made up data for demos + + +def words(remove_non_letters=True): + with open("/usr/share/dict/british-english", "r") as fd: + ws = fd.readlines() + return [w.strip().replace("'s", "") for w in ws] + + +def aug_tips(): + """ The tips data buf with "calories consumed". """ + tips = px.data.tips() + calories = np.clip( + tips["total_bill"] * 30 + np.random.standard_normal(tips.shape[0]) * 100, + 100, + None, + ) + tips["calories_consumed"] = calories + return tips + + +def take(it, N): + return [next(it) for n in range(N)] + + +def multilayered_data( + N=20, d_divs=[2, 3, 4], rseed=np.random.RandomState(seed=2), rwalk=0.1 +): + """ + Generate data that can be faceted in len(d_divs) ways (e.g., row, col and + trace color/linestyle. etc.) + """ + ws = words() + tot_divs = np.cumprod(d_divs)[-1] + sample_i = np.arange(len(ws), dtype="int") + rseed.shuffle(sample_i) + names = iter(ws[i] for i in sample_i[: tot_divs + len(d_divs)]) + x = np.arange(N) + cat_div_names = [] + for div in d_divs: + # generate category names + div_names = [next(names) for _ in range(div)] + cat_div_names.append(div_names) + cat_names = [next(names) for _ in d_divs] + dfs = [] + for cat_combo in product(*cat_div_names): + d = dict() + for cat_name, c in zip(cat_names, cat_combo): + d[cat_name] = c + d["x"] = x + if rwalk is not None: + y = np.cumsum(rseed.standard_normal(N)) * rwalk + else: + y = rseed.standard_normal(N) + d["y"] = y + dfs.append(pd.DataFrame(d)) + # combine all the dicts + df = reduce(lambda a, b: pd.concat([a, b]), dfs, pd.DataFrame()) + return df diff --git a/proto/px_overlay/test_find_subplot_axes.py b/proto/px_overlay/test_find_subplot_axes.py new file mode 100644 index 00000000000..e9b763226bb --- /dev/null +++ b/proto/px_overlay/test_find_subplot_axes.py @@ -0,0 +1,40 @@ +from plotly.subplots import make_subplots +import px_overlay +import pytest + +fig = px_overlay.make_subplots_all_secondary_y(3, 4) +fig_no_sy = px_overlay.make_subplots(3, 4) +fig_custom = make_subplots( + rows=2, + cols=2, + specs=[[{}, {}], [{"colspan": 2}, None]], + subplot_titles=("First Subplot", "Second Subplot", "Third Subplot"), +) + + +def test_bad_row_col(): + with pytest.raises( + IndexError, + match=r"^Figure does not have a subplot at the requested row or column\.$", + ): + px_overlay.find_subplot_axes(fig, 4, 2, secondary_y=False) + with pytest.raises( + IndexError, + match=r"^Figure does not have a subplot at the requested row or column\.$", + ): + px_overlay.find_subplot_axes(fig, 4, 2, secondary_y=True) + + +def test_no_secondary_y(): + with pytest.raises( + IndexError, + match=r"^Could not find a secondary y-axis at the subplot in the requested row or column\.$", + ): + px_overlay.find_subplot_axes(fig_no_sy, 2, 2, secondary_y=True) + with pytest.raises( + IndexError, + match=r"^Could not find a y-axis at the subplot in the requested row or column\.$", + ): + px_overlay.find_subplot_axes(fig_custom, 2, 2, secondary_y=False) + axes = px_overlay.find_subplot_axes(fig_custom, 1, 2, secondary_y=False) + assert axes == ("xaxis2", "yaxis2")