Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix clientside callback #162

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions dash_auth/public_routes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import inspect
import os

from dash import Dash, callback
from dash import Dash, callback, get_app
from dash._callback import GLOBAL_CALLBACK_MAP
from dash import get_app
from werkzeug.routing import Map, MapAdapter, Rule


DASH_PUBLIC_ASSETS_EXTENSIONS = "js,css"
BASE_PUBLIC_ROUTES = [
f"/assets/<path:path>.{ext}"
Expand Down Expand Up @@ -68,20 +66,21 @@ def public_callback(*callback_args, **callback_kwargs):
"""

def decorator(func):

wrapped_func = callback(*callback_args, **callback_kwargs)(func)
callback_id = next(
(
k for k, v in GLOBAL_CALLBACK_MAP.items()
if inspect.getsource(v["callback"]) == inspect.getsource(func)
k
for k, v in GLOBAL_CALLBACK_MAP.items()
if "callback" in v
and inspect.getsource(v["callback"]) == inspect.getsource(func)
),
None,
)
try:
app = get_app()
app.server.config[PUBLIC_CALLBACKS] = (
get_public_callbacks(app) + [callback_id]
)
app.server.config[PUBLIC_CALLBACKS] = get_public_callbacks(app) + [
callback_id
]
except Exception:
print(
"Could not set up the public callback as the Dash object "
Expand Down
84 changes: 50 additions & 34 deletions tests/test_oidc_auth.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import os
from unittest.mock import patch

import requests
from dash import Dash, Input, Output, dcc, html
from flask import redirect

from dash_auth import (
protected_callback,
OIDCAuth,
)
from dash_auth import OIDCAuth, protected_callback


def valid_authorize_redirect(_, redirect_uri, *args, **kwargs):
Expand All @@ -17,7 +13,9 @@ def valid_authorize_redirect(_, redirect_uri, *args, **kwargs):

def invalid_authorize_redirect(_, redirect_uri, *args, **kwargs):
base_url = "/" + redirect_uri.split("/", maxsplit=3)[-1]
return redirect(f"{base_url}?error=Unauthorized&error_description=something went wrong")
return redirect(
f"{base_url}?error=Unauthorized&error_description=something went wrong"
)


def valid_authorize_access_token(*args, **kwargs):
Expand All @@ -27,18 +25,26 @@ def valid_authorize_access_token(*args, **kwargs):
}


@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect)
@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token)
@patch(
"authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect",
valid_authorize_redirect,
)
@patch(
"authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token",
valid_authorize_access_token,
)
def test_oa001_oidc_auth_login_flow_success(dash_br, dash_thread_server):
app = Dash(__name__)
app.layout = html.Div([
dcc.Input(id="input", value="initial value"),
html.Div(id="output1"),
html.Div(id="output2"),
html.Div("static", id="output3"),
html.Div("static", id="output4"),
html.Div("not static", id="output5"),
])
app.layout = html.Div(
[
dcc.Input(id="input", value="initial value"),
html.Div(id="output1"),
html.Div(id="output2"),
html.Div("static", id="output3"),
html.Div("static", id="output4"),
html.Div("not static", id="output5"),
]
)

@app.callback(Output("output1", "children"), Input("input", "value"))
def update_output1(new_value):
Expand Down Expand Up @@ -101,13 +107,15 @@ def update_output5(new_value):
dash_br.wait_for_text_to_equal("#output5", "initial value")


@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", invalid_authorize_redirect)
@patch(
"authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect",
invalid_authorize_redirect,
)
def test_oa002_oidc_auth_login_fail(dash_thread_server):
app = Dash(__name__)
app.layout = html.Div([
dcc.Input(id="input", value="initial value"),
html.Div(id="output")
])
app.layout = html.Div(
[dcc.Input(id="input", value="initial value"), html.Div(id="output")]
)

@app.callback(Output("output", "children"), Input("input", "value"))
def update_output(new_value):
Expand All @@ -122,7 +130,7 @@ def update_output(new_value):
server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration",
)
dash_thread_server(app)
base_url = dash_thread_server.url
base_url = dash_thread_server.url.rstrip("/")

def test_unauthorized(url):
r = requests.get(url)
Expand All @@ -133,17 +141,25 @@ def test_authorized(url):
assert requests.get(url).status_code == 200

test_unauthorized(base_url)
test_authorized(os.path.join(base_url, "public"))
test_authorized(base_url + "/public")


@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect)
@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token)
@patch(
"authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect",
valid_authorize_redirect,
)
@patch(
"authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token",
valid_authorize_access_token,
)
def test_oa003_oidc_auth_login_several_idp(dash_br, dash_thread_server):
app = Dash(__name__)
app.layout = html.Div([
dcc.Input(id="input", value="initial value"),
html.Div(id="output1"),
])
app.layout = html.Div(
[
dcc.Input(id="input", value="initial value"),
html.Div(id="output1"),
]
)

@app.callback(Output("output1", "children"), Input("input", "value"))
def update_output1(new_value):
Expand All @@ -168,21 +184,21 @@ def update_output1(new_value):
)

dash_thread_server(app)
base_url = dash_thread_server.url
base_url = dash_thread_server.url.rstrip("/")

assert requests.get(base_url).status_code == 400

# Login with IDP1
assert requests.get(os.path.join(base_url, "oidc/idp1/login")).status_code == 200
assert requests.get(base_url + "/oidc/idp1/login").status_code == 200

# Logout
assert requests.get(os.path.join(base_url, "oidc/logout")).status_code == 200
assert requests.get(base_url + "/oidc/logout").status_code == 200

assert requests.get(base_url).status_code == 400

# Login with IDP2
assert requests.get(os.path.join(base_url, "oidc/idp2/login")).status_code == 200
assert requests.get(base_url + "/oidc/idp2/login").status_code == 200

dash_br.driver.get(os.path.join(base_url, "oidc/idp2/login"))
dash_br.driver.get(base_url + "/oidc/idp2/login")
dash_br.driver.get(base_url)
dash_br.wait_for_text_to_equal("#output1", "initial value")