diff --git a/packages/python/plotly/plotly/basedatatypes.py b/packages/python/plotly/plotly/basedatatypes.py
index f2033cf359a..623e976dc2d 100644
--- a/packages/python/plotly/plotly/basedatatypes.py
+++ b/packages/python/plotly/plotly/basedatatypes.py
@@ -391,6 +391,9 @@ class is a subclass of both BaseFigure and widgets.DOMWidget.
         self._animation_duration_validator = animation.DurationValidator()
         self._animation_easing_validator = animation.EasingValidator()
 
+        # Space for auxiliary data
+        self._aux = dict()
+
         # Template
         # --------
         # ### Check for default template ###
diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py
index 7ad2fb4eb01..727ecd687a8 100644
--- a/packages/python/plotly/plotly/express/_core.py
+++ b/packages/python/plotly/plotly/express/_core.py
@@ -2057,6 +2057,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
 
     configure_axes(args, constructor, fig, orders)
     configure_animation_controls(args, constructor, fig)
+    # store args in figure metadata
+    fig._aux["px"] = dict(args=args)
     return fig
 
 
diff --git a/proto/px_overlay/README.md b/proto/px_overlay/README.md
new file mode 100644
index 00000000000..a64dc2ea733
--- /dev/null
+++ b/proto/px_overlay/README.md
@@ -0,0 +1,15 @@
+# `px.overlay` prototype
+
+This demonstrates one possible way of combining two figures into a single
+figure.
+
+To see an example, run (from the root of the `plotly.py` repo):
+
+```bash
+PYTHONPATH=proto/px_overlay python proto/px_overlay/multilayered_data_test.py
+```
+
+To see the code that does the overlaying, start with the `px_simple_overlay`
+function in `proto/px_overlay/px_overlay.py`. In this function there are a few
+comments marked with `TODO` that indicate places for improvement in the
+functionality.
diff --git a/proto/px_overlay/facet_col_wrap_test.py b/proto/px_overlay/facet_col_wrap_test.py
new file mode 100644
index 00000000000..ae41ae67c61
--- /dev/null
+++ b/proto/px_overlay/facet_col_wrap_test.py
@@ -0,0 +1,27 @@
+import plotly.express as px
+import test_data
+from px_combine import px_simple_combine
+
+df = test_data.multilayered_data(d_divs=[6, 3, 2], rwalk=0.1)
+last_cat = df.columns[2]
+last_cat_types = list(set(df[last_cat]))
+fig0 = px.line(
+    df.loc[df[last_cat] == last_cat_types[0]],
+    x="x",
+    y="y",
+    facet_col=df.columns[0],
+    facet_col_wrap=3,
+    color=df.columns[1],
+).update_layout(title="%s=%s" % (last_cat, last_cat_types[0]))
+fig1 = px.line(
+    df.loc[df[last_cat] == last_cat_types[1]],
+    x="x",
+    y="y",
+    facet_col=df.columns[0],
+    facet_col_wrap=3,
+    color=df.columns[1],
+).update_layout(title="%s=%s" % (last_cat, last_cat_types[1]))
+fig = px_simple_combine(fig0, fig1, fig1_secondary_y=True)
+fig0.show()
+fig1.show()
+fig.show()
diff --git a/proto/px_overlay/find_field.py b/proto/px_overlay/find_field.py
new file mode 100644
index 00000000000..c173ae05107
--- /dev/null
+++ b/proto/px_overlay/find_field.py
@@ -0,0 +1,43 @@
+import plotly.graph_objects as go
+from plotly import basedatatypes
+
+# Search down an object's composition tree and find fields with a given name
+
+
+def find_field(obj, field, basepath="", max_path_len=80, forbidden=["parent"]):
+    if obj is not None and len(basepath) < max_path_len:
+        for f in dir(obj):
+            joined_path = ".".join([basepath, f])
+            if f == field:
+                print(joined_path)
+            if (
+                (f not in forbidden)
+                and (not f.startswith("_"))
+                and (not f.endswith("_"))
+            ):
+                find_field(eval("obj.%s" % (f,)), field, joined_path)
+
+
+def find_all_xy_traces():
+    for field in dir(go):
+        call_str = "go.%s" % (field,)
+        call = eval(call_str)
+        try:
+            if issubclass(call, basedatatypes.BaseTraceType):
+                obj = call()
+                if "xaxis" in obj and "yaxis" in obj:
+                    yield (call_str)
+        except TypeError:
+            pass
+
+
+# s=go.Scatter()
+# s=go.Bar()
+# find_field(s,"color",basepath="scatter")
+# print()
+# find_field(s,"color",basepath="bar")
+
+for call_str in find_all_xy_traces():
+    call = eval(call_str)
+    find_field(call(), "color", basepath=call_str)
+    print()
diff --git a/proto/px_overlay/map_axis_pair_example.py b/proto/px_overlay/map_axis_pair_example.py
new file mode 100644
index 00000000000..a22c26136dd
--- /dev/null
+++ b/proto/px_overlay/map_axis_pair_example.py
@@ -0,0 +1,29 @@
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+import px_overlay
+import pytest
+
+fig0 = px_overlay.make_subplots_all_secondary_y(3, 4)
+fig1 = px_overlay.make_subplots_all_secondary_y(4, 5)
+
+for dims, f in zip([(3, 4), (4, 5)], [fig0, fig1]):
+    for r, c in px_overlay.multi_index(*dims):
+        for sy in [False, True]:
+            f.add_trace(go.Scatter(x=[], y=[]), row=r + 1, col=c + 1, secondary_y=sy)
+
+fig0.add_annotation(row=2, col=3, text="hi", x=0.25, xref="x domain", y=3)
+fig0.add_annotation(
+    row=3, col=4, text="hi", x=0.25, xref="x domain", y=2, secondary_y=True
+)
+
+for an in fig0.layout.annotations:
+    oldaxpair = tuple([an[ref] for ref in ["xref", "yref"]])
+    newaxpair = px_overlay.map_axis_pair(fig0, fig1, oldaxpair)
+    newan = go.layout.Annotation(an)
+    print(oldaxpair)
+    print(newaxpair)
+    newan["xref"], newan["yref"] = newaxpair
+    fig1.add_annotation(newan)
+
+fig0.show()
+fig1.show()
diff --git a/proto/px_overlay/multilayered_data_test.py b/proto/px_overlay/multilayered_data_test.py
new file mode 100644
index 00000000000..6f59c2dcfaa
--- /dev/null
+++ b/proto/px_overlay/multilayered_data_test.py
@@ -0,0 +1,56 @@
+import test_data
+import numpy as np
+import plotly.express as px
+from px_overlay import px_simple_overlay
+
+# Demonstrates px_overlay prototype.
+
+# Make some data that can be faceted by row, col and color, and split into 2
+# sets, which will go to the first and second figure respectively.
+df = test_data.multilayered_data(d_divs=[2, 3, 4, 2], rwalk=0.1)
+
+# The titles of the figures use the last dimension in the data. The title is
+# formatted "column_name=column_value", so here we extract the column name.
+last_cat = df.columns[3]
+figs = []
+for px_call, last_cat_0 in zip([px.line, px.bar], list(set(df[last_cat]))):
+    # px_call is the chart type to make and last_cat_0 is the column_value for
+    # that figure which is used in forming the title.
+    df_slice = df.loc[df[last_cat] == last_cat_0]
+    fig = px_call(
+        df_slice,
+        x="x",
+        y="y",
+        facet_row=df.columns[0],
+        facet_col=df.columns[1],
+        color=df.columns[2],
+    )
+
+    fig.update_layout(title="%s=%s" % (last_cat, last_cat_0,))
+    figs.append(fig)
+
+# Add some annotations to make sure they are copied to the final figure properly
+figs[0].add_hline(y=1, row=1, col="all")
+figs[0].add_annotation(
+    x=0.25, y=0.5, xref="x domain", yref="y domain", row=2, col=3, text="yo"
+)
+# Note that these annotations should be mapped to a secondary y axis (observe this
+# in the final figure by dragging their corresponding secondary y axes).
+figs[1].add_vline(x=10, row="all", col=2)
+figs[1].add_annotation(
+    x=0.5, y=0.35, xref="x domain", yref="y", row=1, col=2, text="budday"
+)
+# Set the bar modes for both to see that the first figure that the barmode for
+# the final figure will be taken from the figure that has bars.
+figs[0].layout.barmode = "group"
+figs[1].layout.barmode = "relative"
+
+# overlay the figures
+final_fig = px_simple_overlay(*figs, fig1_secondary_y=True)
+
+# Show the initial figures
+for fig in figs:
+    fig.show()
+
+# Show the final figure
+final_fig.show()
diff --git a/proto/px_overlay/px_overlay.py b/proto/px_overlay/px_overlay.py
new file mode 100644
index 00000000000..3ebe32330ed
--- /dev/null
+++ b/proto/px_overlay/px_overlay.py
@@ -0,0 +1,360 @@
+# Prototype for px.overlay
+# Combine 2 figures containing subplots
+# Run as
+# python px_overlay.py
+
+import plotly.express as px
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+import test_data
+import json
+from itertools import product, cycle, chain
+from functools import reduce
+import re
+
+
+def multi_index(*kwargs):
+    return product(*[range(k) for k in kwargs])
+
+
+def extract_axes(layout):
+    ret = dict()
+    for k in dir(layout):
+        if k[1 : 1 + len("axis")] == "axis":
+            ret[k] = layout[k]
+    return ret
+
+
+def fig_grid_ref_shape(fig):
+    grid_ref = fig._validate_get_grid_ref()
+    return (len(grid_ref), len(grid_ref[0]))
+
+
+def fig_subplot_axes(fig, r, c):
+    grid_ref = fig._validate_get_grid_ref()
+    return [fig.layout[k] for k in grid_ref[r - 1][c - 1][0].layout_keys]
+
+
+def extract_axis_titles(fig):
+    """
+    Given figure created using make_subplots, with r rows and c columns, return
+    r titles from the x axes and y titles from the y axes.
+    """
+    grid_ref_shape = fig_grid_ref_shape(fig)
+    r_titles = [
+        fig_subplot_axes(fig, r + 1, 1)[1]["title"] for r in range(grid_ref_shape[0])
+    ]
+    c_titles = [
+        fig_subplot_axes(fig, 1, c + 1)[0]["title"] for c in range(grid_ref_shape[1])
+    ]
+    return (r_titles, c_titles)
+
+
+def make_subplots_all_secondary_y(rows, cols):
+    """
+    Get subplots like make_subplots but all also have secondary y-axes.
+    """
+    grid_ref_shape = [rows, cols]
+    specs = [
+        [dict(secondary_y=True) for __ in range(grid_ref_shape[1])]
+        for _ in range(grid_ref_shape[0])
+    ]
+    fig = make_subplots(*grid_ref_shape, specs=specs)
+    return fig
+
+
+def parse_axis_ref(ax):
+    """ Find the axis letter, optional number, and domain of axis. """
+    # TODO: can this be obtained via codegen?
+    pat = re.compile("([xy])(axis)?([0-9]*)( domain)?")
+    matches = pat.match(ax)
+    if matches is None:
+        raise ValueError('Axis "%s" cannot be parsed.' % (ax,))
+    return (matches[1], matches[3], matches[4])
+
+
+def norm_axis_ref(ax):
+    """ normalize ax so it is in the format: yaxis, yaxis2, xaxis7 etc. """
+    al, an, _ = parse_axis_ref(ax)
+    return al + "axis" + an
+
+
+def axis_pair_to_row_col(fig, axpair):
+    """
+    returns the row and column of the subplot having the axis pair and whether it is a
+    secondary y
+    """
+    if "paper" in axpair:
+        raise ValueError('Cannot find row and column of "paper" axis reference.')
+    naxpair = tuple([norm_axis_ref(ax) for ax in axpair])
+    nrows, ncols = fig_grid_ref_shape(fig)
+    row = None
+    col = None
+    for r, c in multi_index(nrows, ncols):
+        for sp in fig._grid_ref[r][c]:
+            if naxpair == sp.layout_keys:
+                row = r + 1
+                col = c + 1
+    if row is None or col is None:
+        raise ValueError("Could not find subplot containing axes (%s,%s)." % nax)
+    secondary_y = False
+    yax = naxpair[1]
+    if fig.layout[yax]["side"] == "right":
+        secondary_y = True
+    return (row, col, secondary_y)
+
+
+def find_subplot_axes(fig, row, col, secondary_y=False):
+    """
+    Returns 2-tuple containing (xaxis,yaxis) at specified row, col and secondary y-axis. 
+    """
+    nrows, ncols = fig_grid_ref_shape(fig)
+    try:
+        sps = fig._grid_ref[row - 1][col - 1]
+    except (IndexError, TypeError):
+        # IndexError if fig has _grid_ref but not requested row or column,
+        # TypeError if fig has no _grid_ref (it is None)
+        raise IndexError(
+            "Figure does not have a subplot at the requested row or column."
+        )
+
+    def _check_is_secondary_y(sp):
+        xax, yax = sp.layout_keys
+        # TODO: It may not be totally accurate to assume if an y-axis' "side" is
+        # "right" than it is a secondary y axis...
+        return fig.layout[yax]["side"] == "right"
+
+    # find the secondary y axis
+    err_msg = (
+        "Could not find a y-axis " "at the subplot in the requested row or column."
+    )
+    filter_fun = lambda sp: not _check_is_secondary_y(sp)
+    if secondary_y:
+        err_msg = (
+            "Could not find a secondary y-axis "
+            "at the subplot in the requested row or column."
+        )
+        filter_fun = _check_is_secondary_y
+    try:
+        sp = list(filter(filter_fun, sps))[0]
+    except (IndexError, TypeError):
+        # Catch IndexError if the list is empty, catch TypeError if sps isn't
+        # iterable (e.g., is None)
+        raise IndexError(err_msg)
+    return sp.layout_keys
+
+
+def map_axis_pair(
+    old_fig,
+    new_fig,
+    axpair,
+    new_row=None,
+    new_col=None,
+    new_secondary_y=None,
+    make_axis_ref=True,
+):
+    """
+    Find the axes on the new figure that will give the same subplot and
+    possibly secondary y axis as on the old figure. This can only
+    work if the axis pair is ("paper","paper") or the axis pair corresponds to a
+    subplot on the old figure the new figure has corresponding rows,
+    columns and secondary y-axes.
+    if make_axis_ref is True, axis is removed from the resulting strings, e.g., xaxis2 -> x2
+    """
+    if None in axpair:
+        raise ValueError("Cannot map axis whose value is None.")
+    if axpair == ("paper", "paper"):
+        return axpair
+    row, col, secondary_y = axis_pair_to_row_col(old_fig, axpair)
+    row = new_row if new_row is not None else row
+    col = new_col if new_col is not None else col
+    secondary_y = new_secondary_y if new_secondary_y is not None else secondary_y
+    newaxpair = find_subplot_axes(new_fig, row, col, secondary_y)
+    axpair_extras = [" domain" if ax.endswith("domain") else "" for ax in axpair]
+    newaxpair = tuple(ax + extra for ax, extra in zip(newaxpair, axpair_extras))
+    if make_axis_ref:
+        newaxpair = tuple(ax.replace("axis", "") for ax in newaxpair)
+    return newaxpair
+
+
+def map_annotation_like_obj_axis(oldfig, newfig, an, force_secondary_y=False):
+    """
+    Take an annotation-like object with xref and yref referring to axes in oldfig
+    and map them to axes in newfig. This makes it possible to map an annotation
+    to the same subplot row, column or secondary y in a new plot even if they do
+    not have matching subplots.
+    If force_secondary_y is True, attempt is made to map the annotation to a
+    secondary y axis in the new figure.
+    Returns the new annotation. Note that it has not been added to newfig, the
+    caller must then do this if it wants it added to newfig.
+    """
+    oldaxpair = tuple([an[ref] for ref in ["xref", "yref"]])
+    newaxpair = map_axis_pair(
+        oldfig, newfig, oldaxpair, new_secondary_y=force_secondary_y
+    )
+    newan = an.__class__(an)
+    newan["xref"], newan["yref"] = newaxpair
+    return newan
+
+
+def px_simple_overlay(fig0, fig1, fig1_secondary_y=False):
+    """
+    Combines two figures by putting all the traces from fig0 and fig1 on a new
+    figure (fig). Then the annotation-like objects are copied to fig (i.e., the
+    titles are not copied).
+    The colors are reassigned so each trace has a unique color until all the
+    colors in the colorway are exhausted and then loops through the colorway to
+    assign additional colors (this is referred to as "reflowing" below).
+    In order to differentiate the traces in the legend, if fig0 or fig1 have
+    titles, they are prepended to the trace name.
+    If fig1_secondary_y is True, then the yaxes from fig1 are placed on
+    secondary y axes in the new figure.
+    """
+    if fig1_secondary_y and (
+        ("px" not in fig0._aux.keys()) or ("px" not in fig0._aux.keys())
+    ):
+        raise ValueError(
+            "To place fig1's traces on secondary y-axes, both figures must have "
+            "been made with Plotly Express."
+        )
+    grid_ref_shape = fig_grid_ref_shape(fig0)
+    if grid_ref_shape != fig_grid_ref_shape(fig1):
+        raise ValueError(
+            "Only two figures with the same subplot geometry can be overlayed."
+        )
+    # get colors for reflowing
+    colorway = fig0.layout.template.layout.colorway
+    specs = None
+    if fig1_secondary_y:
+        specs = [
+            [dict(secondary_y=True) for __ in range(grid_ref_shape[1])]
+            for _ in range(grid_ref_shape[0])
+        ]
+    # TODO: This needs to detect the start_cell of the input figures rather than
+    # assuming 'bottom-left', which is just the px default start_cell
+    fig = make_subplots(
+        *fig_grid_ref_shape(fig0), specs=specs, start_cell="bottom-left"
+    )
+    for r, c in multi_index(*fig_grid_ref_shape(fig)):
+        print("row,col", r + 1, c + 1)
+        for (tr, f), color in zip(
+            chain(
+                *[
+                    zip(f.select_traces(row=r + 1, col=c + 1), cycle([f]),)
+                    for f in [fig0, fig1]
+                ]
+            ),
+            # reflow the colors
+            cycle(colorway),
+        ):
+            title = f.layout.title.text
+            set_main_trace_color(tr, color)
+            # use figure title to differentiate the legend items
+            tr["name"] = "%s %s" % (title, tr["name"])
+            # TODO: argument to group legend items?
+            tr["legendgroup"] = None
+            fig.add_trace(
+                tr, row=r + 1, col=c + 1, secondary_y=(fig1_secondary_y and (f == fig1))
+            )
+    # TODO: How to preserve axis sizes when adding secondary y?
+
+    # Map the axes of the annotation-like objects to the new figure. Map the
+    # fig1 objects to the secondary-y if requested.
+    selectors = product(
+        [fig0, fig1],
+        [
+            go.Figure.select_annotations,
+            go.Figure.select_shapes,
+            go.Figure.select_layout_images,
+        ],
+    )
+    adders = product(
+        [(fig, False), (fig, fig1_secondary_y)],
+        [go.Figure.add_annotation, go.Figure.add_shape, go.Figure.add_layout_image],
+    )
+    for (oldfig, selector), ((newfig, secy), adder) in zip(selectors, adders):
+        for ann in selector(oldfig):
+            # TODO this function needs to eventually take into consideration the
+            # start_cell arguments of the figures involved in the mapping.
+            newann = map_annotation_like_obj_axis(
+                oldfig, newfig, ann, force_secondary_y=secy
+            )
+            adder(newfig, newann)
+
+    # fig.update_layout(fig0.layout)
+    # title will be wrong
+    fig.layout.title = None
+    # preserve bar mode
+    # if both figures have barmode set, the first is taken from the figure that
+    # has bars (so just the one from fig0 if both have bars), otherwise the set
+    # one is taken.
+    # TODO argument to force barmode? or the user can just update it after
+    fig.layout.barmode = get_first_set_barmode([fig0, fig1])
+    return fig
+
+
+def select_all_traces(figs):
+    traces = list(
+        reduce(
+            lambda a, b: a + b,
+            map(lambda t: list(go.Figure.select_traces(t)), figs),
+            [],
+        )
+    )
+    return traces
+
+
+def check_trace_type_xy(tr):
+    return ("xaxis" in tr) and ("yaxis" in tr)
+
+
+def check_figs_trace_types_xy(figs):
+    traces = select_all_traces(figs)
+    xy_traces = list(map(check_trace_type_xy, traces))
+    return xy_traces
+
+
+def set_main_trace_color(tr, color):
+    # Set the main color of a trace
+    if type(tr) == type(go.Scatter()):
+        if tr["mode"] == "lines":
+            tr["line_color"] = color
+        else:
+            tr["marker_color"] = color
+    elif type(tr) == type(go.Bar()):
+        tr["marker_color"] = color
+
+
+def get_first_set_barmode(figs):
+    """ Get first bar mode from the figure that has it set and has bar traces. """
+
+    def _bar_mode_filter(f):
+        return (
+            any([type(tr) == type(go.Bar()) for tr in f.data])
+            and f.layout.barmode is not None
+        )
+
+    barmode = None
+    try:
+        barmode = [f.layout.barmode for f in filter(_bar_mode_filter, figs)][0]
+    except IndexError:
+        # if no figure sets barmode, then it is not set
+        pass
+    return barmode
+
+
+def simple_overlay_example():
+    df = test_data.aug_tips()
+    fig0 = px.scatter(df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker")
+    fig1 = px.histogram(
+        df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker"
+    )
+    fig1.update_traces(marker_color="red")
+    fig = px_simple_overlay(fig0, fig1)
+    fig.update_layout(title="Simple figure combination")
+    return fig
+
+
+if __name__ == "__main__":
+    fig_simple = simple_overlay_example()
+    fig_simple.show()
diff --git a/proto/px_overlay/run_px_simple_overlay_demo b/proto/px_overlay/run_px_simple_overlay_demo
new file mode 100755
index 00000000000..ea7891adde6
--- /dev/null
+++ b/proto/px_overlay/run_px_simple_overlay_demo
@@ -0,0 +1,2 @@
+#/bin/bash
+PYTHONPATH=proto/px_combine python3 proto/px_combine/multilayered_data_test.py
diff --git a/proto/px_overlay/secondary_y_test.py b/proto/px_overlay/secondary_y_test.py
new file mode 100644
index 00000000000..ae7b49edc36
--- /dev/null
+++ b/proto/px_overlay/secondary_y_test.py
@@ -0,0 +1 @@
+# Put the second plot's y data on a secondary y
diff --git a/proto/px_overlay/test_data.py b/proto/px_overlay/test_data.py
new file mode 100644
index 00000000000..d45b4428e1b
--- /dev/null
+++ b/proto/px_overlay/test_data.py
@@ -0,0 +1,66 @@
+import numpy as np
+import plotly.express as px
+import pandas as pd
+from random import sample
+from itertools import product
+from functools import reduce
+
+# some made up data for demos
+
+
+def words(remove_non_letters=True):
+    with open("/usr/share/dict/british-english", "r") as fd:
+        ws = fd.readlines()
+    return [w.strip().replace("'s", "") for w in ws]
+
+
+def aug_tips():
+    """ The tips data buf with "calories consumed". """
+    tips = px.data.tips()
+    calories = np.clip(
+        tips["total_bill"] * 30 + np.random.standard_normal(tips.shape[0]) * 100,
+        100,
+        None,
+    )
+    tips["calories_consumed"] = calories
+    return tips
+
+
+def take(it, N):
+    return [next(it) for n in range(N)]
+
+
+def multilayered_data(
+    N=20, d_divs=[2, 3, 4], rseed=np.random.RandomState(seed=2), rwalk=0.1
+):
+    """
+    Generate data that can be faceted in len(d_divs) ways (e.g., row, col and
+    trace color/linestyle. etc.)
+    """
+    ws = words()
+    tot_divs = np.cumprod(d_divs)[-1]
+    sample_i = np.arange(len(ws), dtype="int")
+    rseed.shuffle(sample_i)
+    names = iter(ws[i] for i in sample_i[: tot_divs + len(d_divs)])
+    x = np.arange(N)
+    cat_div_names = []
+    for div in d_divs:
+        # generate category names
+        div_names = [next(names) for _ in range(div)]
+        cat_div_names.append(div_names)
+    cat_names = [next(names) for _ in d_divs]
+    dfs = []
+    for cat_combo in product(*cat_div_names):
+        d = dict()
+        for cat_name, c in zip(cat_names, cat_combo):
+            d[cat_name] = c
+        d["x"] = x
+        if rwalk is not None:
+            y = np.cumsum(rseed.standard_normal(N)) * rwalk
+        else:
+            y = rseed.standard_normal(N)
+        d["y"] = y
+        dfs.append(pd.DataFrame(d))
+    # combine all the dicts
+    df = reduce(lambda a, b: pd.concat([a, b]), dfs, pd.DataFrame())
+    return df
diff --git a/proto/px_overlay/test_find_subplot_axes.py b/proto/px_overlay/test_find_subplot_axes.py
new file mode 100644
index 00000000000..e9b763226bb
--- /dev/null
+++ b/proto/px_overlay/test_find_subplot_axes.py
@@ -0,0 +1,40 @@
+from plotly.subplots import make_subplots
+import px_overlay
+import pytest
+
+fig = px_overlay.make_subplots_all_secondary_y(3, 4)
+fig_no_sy = px_overlay.make_subplots(3, 4)
+fig_custom = make_subplots(
+    rows=2,
+    cols=2,
+    specs=[[{}, {}], [{"colspan": 2}, None]],
+    subplot_titles=("First Subplot", "Second Subplot", "Third Subplot"),
+)
+
+
+def test_bad_row_col():
+    with pytest.raises(
+        IndexError,
+        match=r"^Figure does not have a subplot at the requested row or column\.$",
+    ):
+        px_overlay.find_subplot_axes(fig, 4, 2, secondary_y=False)
+    with pytest.raises(
+        IndexError,
+        match=r"^Figure does not have a subplot at the requested row or column\.$",
+    ):
+        px_overlay.find_subplot_axes(fig, 4, 2, secondary_y=True)
+
+
+def test_no_secondary_y():
+    with pytest.raises(
+        IndexError,
+        match=r"^Could not find a secondary y-axis at the subplot in the requested row or column\.$",
+    ):
+        px_overlay.find_subplot_axes(fig_no_sy, 2, 2, secondary_y=True)
+    with pytest.raises(
+        IndexError,
+        match=r"^Could not find a y-axis at the subplot in the requested row or column\.$",
+    ):
+        px_overlay.find_subplot_axes(fig_custom, 2, 2, secondary_y=False)
+    axes = px_overlay.find_subplot_axes(fig_custom, 1, 2, secondary_y=False)
+    assert axes == ("xaxis2", "yaxis2")