Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Dash's create_callback_id to get the callback id #163

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- Allow to define a custom user management via the `after_logged_in` method #156

### Changed
- Updated the `public_callback` to work in more cases #163

## [2.3.0] - 2024-03-18
### Added
- OIDCAuth allows to authenticate via OIDC
Expand Down
31 changes: 20 additions & 11 deletions dash_auth/public_routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import inspect
import logging
import os

from dash import Dash, callback
from dash._callback import GLOBAL_CALLBACK_MAP
from dash import get_app
from dash import Dash, Output, callback, get_app
from dash._callback import handle_grouped_callback_args
from dash._grouping import flatten_grouping
from dash._utils import create_callback_id
from werkzeug.routing import Map, MapAdapter, Rule


Expand Down Expand Up @@ -70,20 +71,28 @@ def public_callback(*callback_args, **callback_kwargs):
def decorator(func):

wrapped_func = callback(*callback_args, **callback_kwargs)(func)
callback_id = next(
(
k for k, v in GLOBAL_CALLBACK_MAP.items()
if inspect.getsource(v["callback"]) == inspect.getsource(func)
),
None,
output, inputs, _, _, _ = handle_grouped_callback_args(
callback_args, callback_kwargs
)
if isinstance(output, Output):
# Insert callback with scalar (non-multi) Output
output = output
has_output = True
else:
# Insert callback as multi Output
output = flatten_grouping(output)
has_output = len(output) > 0

callback_id = create_callback_id(
output, inputs, no_output=not has_output
)
try:
app = get_app()
app.server.config[PUBLIC_CALLBACKS] = (
get_public_callbacks(app) + [callback_id]
)
except Exception:
print(
logging.info(
"Could not set up the public callback as the Dash object "
"has not yet been instantiated."
)
Expand Down
86 changes: 51 additions & 35 deletions tests/test_oidc_auth.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import os
from unittest.mock import patch

import requests
from dash import Dash, Input, Output, dcc, html
from flask import redirect

from dash_auth import (
protected_callback,
OIDCAuth,
)
from dash_auth import OIDCAuth, protected_callback


def valid_authorize_redirect(_, redirect_uri, *args, **kwargs):
Expand All @@ -17,7 +13,9 @@ def valid_authorize_redirect(_, redirect_uri, *args, **kwargs):

def invalid_authorize_redirect(_, redirect_uri, *args, **kwargs):
base_url = "/" + redirect_uri.split("/", maxsplit=3)[-1]
return redirect(f"{base_url}?error=Unauthorized&error_description=something went wrong")
return redirect(
f"{base_url}?error=Unauthorized&error_description=something went wrong"
)


def valid_authorize_access_token(*args, **kwargs):
Expand All @@ -27,18 +25,26 @@ def valid_authorize_access_token(*args, **kwargs):
}


@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect)
@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token)
@patch(
"authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect",
valid_authorize_redirect,
)
@patch(
"authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token",
valid_authorize_access_token,
)
def test_oa001_oidc_auth_login_flow_success(dash_br, dash_thread_server):
app = Dash(__name__)
app.layout = html.Div([
dcc.Input(id="input", value="initial value"),
html.Div(id="output1"),
html.Div(id="output2"),
html.Div("static", id="output3"),
html.Div("static", id="output4"),
html.Div("not static", id="output5"),
])
app.layout = html.Div(
[
dcc.Input(id="input", value="initial value"),
html.Div(id="output1"),
html.Div(id="output2"),
html.Div("static", id="output3"),
html.Div("static", id="output4"),
html.Div("not static", id="output5"),
]
)

@app.callback(Output("output1", "children"), Input("input", "value"))
def update_output1(new_value):
Expand Down Expand Up @@ -101,13 +107,15 @@ def update_output5(new_value):
dash_br.wait_for_text_to_equal("#output5", "initial value")


@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", invalid_authorize_redirect)
@patch(
"authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect",
invalid_authorize_redirect,
)
def test_oa002_oidc_auth_login_fail(dash_thread_server):
app = Dash(__name__)
app.layout = html.Div([
dcc.Input(id="input", value="initial value"),
html.Div(id="output")
])
app.layout = html.Div(
[dcc.Input(id="input", value="initial value"), html.Div(id="output")]
)

@app.callback(Output("output", "children"), Input("input", "value"))
def update_output(new_value):
Expand All @@ -122,7 +130,7 @@ def update_output(new_value):
server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration",
)
dash_thread_server(app)
base_url = dash_thread_server.url
base_url = dash_thread_server.url.rstrip("/")

def test_unauthorized(url):
r = requests.get(url)
Expand All @@ -133,17 +141,25 @@ def test_authorized(url):
assert requests.get(url).status_code == 200

test_unauthorized(base_url)
test_authorized(os.path.join(base_url, "public"))
test_authorized(base_url + "/public")


@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect)
@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token)
@patch(
"authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect",
valid_authorize_redirect,
)
@patch(
"authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token",
valid_authorize_access_token,
)
def test_oa003_oidc_auth_login_several_idp(dash_br, dash_thread_server):
app = Dash(__name__)
app.layout = html.Div([
dcc.Input(id="input", value="initial value"),
html.Div(id="output1"),
])
app.layout = html.Div(
[
dcc.Input(id="input", value="initial value"),
html.Div(id="output1"),
]
)

@app.callback(Output("output1", "children"), Input("input", "value"))
def update_output1(new_value):
Expand All @@ -168,21 +184,21 @@ def update_output1(new_value):
)

dash_thread_server(app)
base_url = dash_thread_server.url
base_url = dash_thread_server.url.rstrip("/")

assert requests.get(base_url).status_code == 400

# Login with IDP1
assert requests.get(os.path.join(base_url, "oidc/idp1/login")).status_code == 200
assert requests.get(base_url + "/oidc/idp1/login").status_code == 200

# Logout
assert requests.get(os.path.join(base_url, "oidc/logout")).status_code == 200
assert requests.get(base_url + "/oidc/logout").status_code == 200

assert requests.get(base_url).status_code == 400

# Login with IDP2
assert requests.get(os.path.join(base_url, "oidc/idp2/login")).status_code == 200
assert requests.get(base_url + "/oidc/idp2/login").status_code == 200

dash_br.driver.get(os.path.join(base_url, "oidc/idp2/login"))
dash_br.driver.get(base_url + "/oidc/idp2/login")
dash_br.driver.get(base_url)
dash_br.wait_for_text_to_equal("#output1", "initial value")
dash_br.wait_for_text_to_equal("#output1", "initial value")