Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor named parameters feature to allow extensibility #148

Merged
merged 6 commits into from
Apr 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{% if parameters %}
<h3>Query parameters</h3>
<div class="query-parameters">
{% for param in parameters %}
{{ param.form_control }}
{% endfor %}
</div>
<input
class="btn"
type="submit"
value="Run quer{% if query_results|length > 1 %}ies{% else %}y{% endif %}"
/>
{% endif %}
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,7 @@ <h2 style="margin-top: 0.5em">Unverified SQL</h2>
{% if query_results %}
<p>↓ <a href="#save-dashboard">Save this dashboard</a> | <a href="{{ request.path }}">Remove all queries</a></p>
{% endif %}
{% if parameter_values %}
<h3>Query parameters</h3>
<div class="query-parameters">
{% for name, value in parameter_values %}
<label for="qp{{ forloop.counter }}">{{ name }}</label>
<input type="text" id="qp{{ forloop.counter }}" name="{{ name }}" value="{{ value }}">
{% endfor %}
</div>
<input
class="btn"
type="submit"
value="Run quer{% if query_results|length > 1 %}ies{% else %}y{% endif %}"
/>
{% endif %}
{% include "django_sql_dashboard/_parameters.html" %}
{% for result in query_results %}
{% include result.templates with result=result %}
{% endfor %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,7 @@ <h1>{% if dashboard.title %}{{ dashboard.title }}{% else %}{{ dashboard.slug }}{
</p>

<form action="{{ request.path }}" method="GET">
{% if parameter_values %}
<h3>Query parameters</h3>
<div class="query-parameters">
{% for name, value in parameter_values %}
<label for="qp{{ forloop.counter }}">{{ name }}</label>
<input type="text" id="qp{{ forloop.counter }}" name="{{ name }}" value="{{ value }}">
{% endfor %}
</div>
<input
class="btn"
type="submit"
value="Run quer{% if query_results|length > 1 %}ies{% else %}y{% endif %}"
/>
{% endif %}
{% include "django_sql_dashboard/_parameters.html" %}
{% for result in query_results %}
{% include result.templates with result=result %}
{% endfor %}
Expand Down
93 changes: 79 additions & 14 deletions django_sql_dashboard/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from collections import namedtuple

from django.core import signing
from django.conf import settings
from django.utils.html import escape
from django.utils.safestring import mark_safe

SQL_SALT = "django_sql_dashboard:query"

Expand Down Expand Up @@ -58,20 +61,6 @@ def displayable_rows(rows):
return fixed


_named_parameters_re = re.compile(r"\%\(([^\)]+)\)s")


def extract_named_parameters(sql):
params = _named_parameters_re.findall(sql)
# Validation step: after removing params, are there
# any single `%` symbols that will confuse psycopg2?
without_params = _named_parameters_re.sub("", sql)
without_double_percents = without_params.replace("%%", "")
if "%" in without_double_percents:
raise ValueError(r"Found a single % character")
return params


def check_for_base64_upgrade(queries):
if not queries:
return
Expand Down Expand Up @@ -117,3 +106,79 @@ def apply_sort(sql, sort_column, is_desc=False):
else:
sql = "select * from ({}) as results".format(sql)
return sql + ' order by "{}"{}'.format(sort_column, " desc" if is_desc else "")


class Parameter:
extract_re = re.compile(r"\%\(([^\)]+)\)s")

def __init__(self, name, default_value=""):
self.name = name
self.default_value = self.get_sanitized(default_value, for_default=True)

def ensure_consistency(self, previous):
if self.name != previous.name:
raise ValueError("Invalid name for parameter '%s': previously registered with name '%s'" % (self.name, previous.name))
if self.default_value != "" and self.default_value != previous.default_value:
raise ValueError("Invalid default value '%s' for parameter '%s': previously registered with default value '%s'" % (self.default_value, self.name, previous.default_value))

def get_sanitized(self, value, for_default=False):
if value is None or (for_default and value == "null"):
return None # represents DB null value
if not isinstance(value, str):
raise ValueError("Invalid %svalue for parameter '%s': '%s'" % ("default " if for_default else "", self.name, type(value).__name__))
return value

@property
def value(self):
return self._value if hasattr(self, "_value") else self.default_value

@value.setter
def value(self, new_value):
self._value = self.get_sanitized(new_value) if new_value != "" else self.default_value

def form_control(self):
return mark_safe(f"""<label for="qp_{escape(self.name)}">{escape(self.name)}</label>
<input type="text" id="qp_{escape(self.name)}" name="{escape(self.name)}" value="{escape(self.value) if self.value is not None else ""}">""")

@classmethod
def extract(cls, sql: str, value_sources: list[dict[str, str]], target: list=[]):
for found in cls.extract_re.findall(sql):
# Ensure 'found' is an iterable of capturing groups, even if there is only one capturing group in the regex
if isinstance(found, str):
found = [found]
new_param = cls(*found)

# Ensure parameters are added only once
previous_param = next((param for param in target if param.name == new_param.name), None)
if previous_param:
new_param.ensure_consistency(previous_param)
else:
target.append(new_param)

# Validation step: after removing params, are there
# any single `%` symbols that will confuse psycopg2?
without_params = cls.extract_re.sub("", sql)
without_double_percents = without_params.replace("%%", "")
if "%" in without_double_percents:
raise ValueError(r"Found a single % character")

# Read values form sources
for param in target:
for source in value_sources:
if param.name in source:
param.value = source[param.name]
break

return target

@classmethod
def execute(cls, cursor, sql: str, parameters: list=[]):
values = { param.name: param.value for param in parameters }
cursor.execute(sql, values)

PARAMETER_CLASS = getattr(settings, "DASHBOARD_PARAMETER_CLASS", Parameter)
if isinstance(PARAMETER_CLASS, str):
from importlib import import_module

[module_name, class_name] = PARAMETER_CLASS.rsplit('.', 1)
PARAMETER_CLASS = getattr(import_module(module_name), class_name)
28 changes: 11 additions & 17 deletions django_sql_dashboard/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
apply_sort,
check_for_base64_upgrade,
displayable_rows,
extract_named_parameters,
PARAMETER_CLASS,
postgresql_reserved_words,
sign_sql,
unsign_sql,
Expand Down Expand Up @@ -192,10 +192,7 @@ def _dashboard_index(
sql_query_parameter_errors = []
for sql in sql_queries:
try:
extracted = extract_named_parameters(sql)
for p in extracted:
if p not in parameters:
parameters.append(p)
PARAMETER_CLASS.extract(sql, value_sources=[request.POST, request.GET], target=parameters)
sql_query_parameter_errors.append(False)
except ValueError as e:
if "%" in sql:
Expand All @@ -205,9 +202,9 @@ def _dashboard_index(
else:
sql_query_parameter_errors.append(str(e))
parameter_values = {
parameter: request.POST.get(parameter, request.GET.get(parameter, ""))
parameter.name: parameter.value
for parameter in parameters
if parameter != "sql"
if parameter.name != "sql"
}
extra_qs = "&{}".format(urlencode(parameter_values)) if parameter_values else ""
results_index = -1
Expand Down Expand Up @@ -250,7 +247,7 @@ def _dashboard_index(
# Running a SELECT prevents future SET TRANSACTION READ WRITE:
cursor.execute("SELECT 1;")
cursor.fetchall()
cursor.execute(sql, parameter_values)
PARAMETER_CLASS.execute(cursor, sql, parameters)
try:
rows = list(cursor.fetchmany(row_limit + 1))
except ProgrammingError as e:
Expand Down Expand Up @@ -303,12 +300,12 @@ def _dashboard_index(
if dashboard and dashboard.title:
html_title = dashboard.title

# Add named parameter values, if any exist
# Add named parameter values to the page title, when they are distinct from the default values
provided_values = {
key: value for key, value in parameter_values.items() if value.strip()
param.name: param.value for param in parameters if param.value != param.default_value
}
if provided_values:
if len(provided_values) == 1:
if len(parameters) == 1:
html_title += ": {}".format(list(provided_values.values())[0])
else:
html_title += ": {}".format(
Expand Down Expand Up @@ -343,7 +340,7 @@ def _dashboard_index(
"user_can_execute_sql": user_can_execute_sql,
"user_can_export_data": getattr(settings, "DASHBOARD_ENABLE_FULL_EXPORT", None)
and user_can_execute_sql,
"parameter_values": parameter_values.items(),
"parameters": parameters,
"too_long_so_use_post": too_long_so_use_post,
"saved_dashboards": saved_dashboards,
}
Expand Down Expand Up @@ -410,10 +407,7 @@ def export_sql_results(request):
assert format in ("csv", "tsv")
sqls = request.POST.getlist("sql")
sql = sqls[int(sql_index)]
parameter_values = {
parameter: request.POST.get(parameter, "")
for parameter in extract_named_parameters(sql)
}
parameters = PARAMETER_CLASS.extract(sql, value_sources=[request.POST])
alias = getattr(settings, "DASHBOARD_DB_ALIAS", "dashboard")
# Decide on filename
sql_hash = hashlib.sha256(sql.encode("utf-8")).hexdigest()[:6]
Expand Down Expand Up @@ -443,7 +437,7 @@ def read_and_flush():

def rows():
try:
cursor.execute(sql, parameter_values)
PARAMETER_CLASS.execute(cursor, sql, parameters)
done_header = False
while True:
records = cursor.fetchmany(size=2000)
Expand Down
12 changes: 6 additions & 6 deletions test_project/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ def test_parameter_form(admin_client, dashboard_db):
html = response.content.decode("utf-8")
# Form should have three form fields
for fragment in (
'<label for="qp1">foo</label>',
'<input type="text" id="qp1" name="foo" value="">',
'<label for="qp2">bar</label>',
'<input type="text" id="qp2" name="bar" value="">',
'<label for="qp3">baz</label>',
'<input type="text" id="qp3" name="baz" value="">',
'<label for="qp_foo">foo</label>',
'<input type="text" id="qp_foo" name="foo" value="">',
'<label for="qp_bar">bar</label>',
'<input type="text" id="qp_bar" name="bar" value="">',
'<label for="qp_baz">baz</label>',
'<input type="text" id="qp_baz" name="baz" value="">',
):
assert fragment in html

Expand Down