diff --git a/django_sql_dashboard/templates/django_sql_dashboard/_parameters.html b/django_sql_dashboard/templates/django_sql_dashboard/_parameters.html new file mode 100644 index 0000000..4cdd797 --- /dev/null +++ b/django_sql_dashboard/templates/django_sql_dashboard/_parameters.html @@ -0,0 +1,13 @@ +{% if parameters %} +

Query parameters

+
+{% for param in parameters %} + {{ param.form_control }} +{% endfor %} +
+ +{% endif %} diff --git a/django_sql_dashboard/templates/django_sql_dashboard/dashboard.html b/django_sql_dashboard/templates/django_sql_dashboard/dashboard.html index f2638fd..22a40a3 100644 --- a/django_sql_dashboard/templates/django_sql_dashboard/dashboard.html +++ b/django_sql_dashboard/templates/django_sql_dashboard/dashboard.html @@ -28,20 +28,7 @@

Unverified SQL

{% if query_results %}

Save this dashboard | Remove all queries

{% endif %} - {% if parameter_values %} -

Query parameters

-
- {% for name, value in parameter_values %} - - - {% endfor %} -
- - {% endif %} + {% include "django_sql_dashboard/_parameters.html" %} {% for result in query_results %} {% include result.templates with result=result %} {% endfor %} diff --git a/django_sql_dashboard/templates/django_sql_dashboard/saved_dashboard.html b/django_sql_dashboard/templates/django_sql_dashboard/saved_dashboard.html index ae0ff03..4d624c9 100644 --- a/django_sql_dashboard/templates/django_sql_dashboard/saved_dashboard.html +++ b/django_sql_dashboard/templates/django_sql_dashboard/saved_dashboard.html @@ -30,20 +30,7 @@

{% if dashboard.title %}{{ dashboard.title }}{% else %}{{ dashboard.slug }}{

- {% if parameter_values %} -

Query parameters

-
- {% for name, value in parameter_values %} - - - {% endfor %} -
- - {% endif %} + {% include "django_sql_dashboard/_parameters.html" %} {% for result in query_results %} {% include result.templates with result=result %} {% endfor %} diff --git a/django_sql_dashboard/utils.py b/django_sql_dashboard/utils.py index bdd197f..f2b9d76 100644 --- a/django_sql_dashboard/utils.py +++ b/django_sql_dashboard/utils.py @@ -5,6 +5,9 @@ from collections import namedtuple from django.core import signing +from django.conf import settings +from django.utils.html import escape +from django.utils.safestring import mark_safe SQL_SALT = "django_sql_dashboard:query" @@ -58,20 +61,6 @@ def displayable_rows(rows): return fixed -_named_parameters_re = re.compile(r"\%\(([^\)]+)\)s") - - -def extract_named_parameters(sql): - params = _named_parameters_re.findall(sql) - # Validation step: after removing params, are there - # any single `%` symbols that will confuse psycopg2? - without_params = _named_parameters_re.sub("", sql) - without_double_percents = without_params.replace("%%", "") - if "%" in without_double_percents: - raise ValueError(r"Found a single % character") - return params - - def check_for_base64_upgrade(queries): if not queries: return @@ -117,3 +106,79 @@ def apply_sort(sql, sort_column, is_desc=False): else: sql = "select * from ({}) as results".format(sql) return sql + ' order by "{}"{}'.format(sort_column, " desc" if is_desc else "") + + +class Parameter: + extract_re = re.compile(r"\%\(([^\)]+)\)s") + + def __init__(self, name, default_value=""): + self.name = name + self.default_value = self.get_sanitized(default_value, for_default=True) + + def ensure_consistency(self, previous): + if self.name != previous.name: + raise ValueError("Invalid name for parameter '%s': previously registered with name '%s'" % (self.name, previous.name)) + if self.default_value != "" and self.default_value != previous.default_value: + raise ValueError("Invalid default value '%s' for parameter '%s': previously registered with default value '%s'" % (self.default_value, self.name, previous.default_value)) + + def get_sanitized(self, value, for_default=False): + if value is None or (for_default and value == "null"): + return None # represents DB null value + if not isinstance(value, str): + raise ValueError("Invalid %svalue for parameter '%s': '%s'" % ("default " if for_default else "", self.name, type(value).__name__)) + return value + + @property + def value(self): + return self._value if hasattr(self, "_value") else self.default_value + + @value.setter + def value(self, new_value): + self._value = self.get_sanitized(new_value) if new_value != "" else self.default_value + + def form_control(self): + return mark_safe(f""" +""") + + @classmethod + def extract(cls, sql: str, value_sources: list[dict[str, str]], target: list=[]): + for found in cls.extract_re.findall(sql): + # Ensure 'found' is an iterable of capturing groups, even if there is only one capturing group in the regex + if isinstance(found, str): + found = [found] + new_param = cls(*found) + + # Ensure parameters are added only once + previous_param = next((param for param in target if param.name == new_param.name), None) + if previous_param: + new_param.ensure_consistency(previous_param) + else: + target.append(new_param) + + # Validation step: after removing params, are there + # any single `%` symbols that will confuse psycopg2? + without_params = cls.extract_re.sub("", sql) + without_double_percents = without_params.replace("%%", "") + if "%" in without_double_percents: + raise ValueError(r"Found a single % character") + + # Read values form sources + for param in target: + for source in value_sources: + if param.name in source: + param.value = source[param.name] + break + + return target + + @classmethod + def execute(cls, cursor, sql: str, parameters: list=[]): + values = { param.name: param.value for param in parameters } + cursor.execute(sql, values) + +PARAMETER_CLASS = getattr(settings, "DASHBOARD_PARAMETER_CLASS", Parameter) +if isinstance(PARAMETER_CLASS, str): + from importlib import import_module + + [module_name, class_name] = PARAMETER_CLASS.rsplit('.', 1) + PARAMETER_CLASS = getattr(import_module(module_name), class_name) diff --git a/django_sql_dashboard/views.py b/django_sql_dashboard/views.py index c51bcef..870123c 100644 --- a/django_sql_dashboard/views.py +++ b/django_sql_dashboard/views.py @@ -25,7 +25,7 @@ apply_sort, check_for_base64_upgrade, displayable_rows, - extract_named_parameters, + PARAMETER_CLASS, postgresql_reserved_words, sign_sql, unsign_sql, @@ -192,10 +192,7 @@ def _dashboard_index( sql_query_parameter_errors = [] for sql in sql_queries: try: - extracted = extract_named_parameters(sql) - for p in extracted: - if p not in parameters: - parameters.append(p) + PARAMETER_CLASS.extract(sql, value_sources=[request.POST, request.GET], target=parameters) sql_query_parameter_errors.append(False) except ValueError as e: if "%" in sql: @@ -205,9 +202,9 @@ def _dashboard_index( else: sql_query_parameter_errors.append(str(e)) parameter_values = { - parameter: request.POST.get(parameter, request.GET.get(parameter, "")) + parameter.name: parameter.value for parameter in parameters - if parameter != "sql" + if parameter.name != "sql" } extra_qs = "&{}".format(urlencode(parameter_values)) if parameter_values else "" results_index = -1 @@ -250,7 +247,7 @@ def _dashboard_index( # Running a SELECT prevents future SET TRANSACTION READ WRITE: cursor.execute("SELECT 1;") cursor.fetchall() - cursor.execute(sql, parameter_values) + PARAMETER_CLASS.execute(cursor, sql, parameters) try: rows = list(cursor.fetchmany(row_limit + 1)) except ProgrammingError as e: @@ -303,12 +300,12 @@ def _dashboard_index( if dashboard and dashboard.title: html_title = dashboard.title - # Add named parameter values, if any exist + # Add named parameter values to the page title, when they are distinct from the default values provided_values = { - key: value for key, value in parameter_values.items() if value.strip() + param.name: param.value for param in parameters if param.value != param.default_value } if provided_values: - if len(provided_values) == 1: + if len(parameters) == 1: html_title += ": {}".format(list(provided_values.values())[0]) else: html_title += ": {}".format( @@ -343,7 +340,7 @@ def _dashboard_index( "user_can_execute_sql": user_can_execute_sql, "user_can_export_data": getattr(settings, "DASHBOARD_ENABLE_FULL_EXPORT", None) and user_can_execute_sql, - "parameter_values": parameter_values.items(), + "parameters": parameters, "too_long_so_use_post": too_long_so_use_post, "saved_dashboards": saved_dashboards, } @@ -410,10 +407,7 @@ def export_sql_results(request): assert format in ("csv", "tsv") sqls = request.POST.getlist("sql") sql = sqls[int(sql_index)] - parameter_values = { - parameter: request.POST.get(parameter, "") - for parameter in extract_named_parameters(sql) - } + parameters = PARAMETER_CLASS.extract(sql, value_sources=[request.POST]) alias = getattr(settings, "DASHBOARD_DB_ALIAS", "dashboard") # Decide on filename sql_hash = hashlib.sha256(sql.encode("utf-8")).hexdigest()[:6] @@ -443,7 +437,7 @@ def read_and_flush(): def rows(): try: - cursor.execute(sql, parameter_values) + PARAMETER_CLASS.execute(cursor, sql, parameters) done_header = False while True: records = cursor.fetchmany(size=2000) diff --git a/test_project/test_parameters.py b/test_project/test_parameters.py index 8bfc26e..f06f4a0 100644 --- a/test_project/test_parameters.py +++ b/test_project/test_parameters.py @@ -26,12 +26,12 @@ def test_parameter_form(admin_client, dashboard_db): html = response.content.decode("utf-8") # Form should have three form fields for fragment in ( - '', - '', - '', - '', - '', - '', + '', + '', + '', + '', + '', + '', ): assert fragment in html