From e417b899a9daa2edbca550ecfa7b69c10e408cc9 Mon Sep 17 00:00:00 2001 From: Oscar Villiger Date: Mon, 13 Mar 2023 16:36:42 +0100 Subject: [PATCH 1/3] Added client encoding to adapted SQL parameters. --- postgres_copy/copy_to.py | 4 ++++ tests/data/special_names.csv | 5 +++++ tests/tests.py | 11 +++++++++++ 3 files changed, 20 insertions(+) create mode 100644 tests/data/special_names.csv diff --git a/postgres_copy/copy_to.py b/postgres_copy/copy_to.py index 3b32a11..3777534 100644 --- a/postgres_copy/copy_to.py +++ b/postgres_copy/copy_to.py @@ -46,6 +46,10 @@ def execute_sql(self, csv_path_or_obj=None): # use stdout to avoid file permission issues with connections[self.using].cursor() as c: + # set client encoding to adapted params + client_encoding = c.connection.encoding + for p in adapted_params: + p.encoding = client_encoding if client_encoding else p.encoding # compile the SELECT query select_sql = self.as_sql()[0] % adapted_params # then the COPY TO query diff --git a/tests/data/special_names.csv b/tests/data/special_names.csv new file mode 100644 index 0000000..03ff4e0 --- /dev/null +++ b/tests/data/special_names.csv @@ -0,0 +1,5 @@ +NAME,NUMBER,DATE +ben,1,2012-01-01 +joe,2,2012-01-02 +jane,3,2012-01-03 +björn,4,2012-01-04 diff --git a/tests/tests.py b/tests/tests.py index 4db9563..2526f26 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -39,6 +39,7 @@ def setUp(self): 'matching_headers.csv' ) self.secondarydb_path = os.path.join(self.data_dir, 'secondary_db.csv') + self.special_names_path = os.path.join(self.data_dir, 'special_names.csv') def tearDown(self): MockObject.objects.all().delete() @@ -227,6 +228,16 @@ def test_filter(self, _): [i['name'] for i in reader] ) + @mock.patch("django.db.connection.validate_no_atomic_block") + def test_filter_special_names(self, _): + self._load_objects(self.special_names_path) + MockObject.objects.filter(name="björn").to_csv(self.export_path) + reader = csv.DictReader(open(self.export_path, 'r')) + self.assertTrue( + ['BJÖRN'], + [i['name'] for i in reader] + ) + @mock.patch("django.db.connection.validate_no_atomic_block") def test_fewer_fields(self, _): self._load_objects(self.name_path) From d288a866d992e597065963f5e86ab96e22285a6e Mon Sep 17 00:00:00 2001 From: Oscar Villiger Date: Mon, 13 Mar 2023 19:31:38 +0100 Subject: [PATCH 2/3] Added keyword argument client_encoding with default value None to execute_sql. It now raises ValueError if client_encoding does not match the database encoding. --- postgres_copy/copy_to.py | 11 +++++++++-- postgres_copy/managers.py | 6 +++++- tests/tests.py | 6 ++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/postgres_copy/copy_to.py b/postgres_copy/copy_to.py index 3777534..455c599 100644 --- a/postgres_copy/copy_to.py +++ b/postgres_copy/copy_to.py @@ -10,6 +10,7 @@ from psycopg2.extensions import adapt from django.db.models.sql.query import Query from django.db.models.sql.compiler import SQLCompiler + logger = logging.getLogger(__name__) @@ -17,6 +18,7 @@ class SQLCopyToCompiler(SQLCompiler): """ Custom SQL compiler for creating a COPY TO query (postgres backend only). """ + def setup_query(self): """ Extend the default SQLCompiler.setup_query to add re-ordering of items in select. @@ -34,7 +36,7 @@ def setup_query(self): ) self.select.append(selection) - def execute_sql(self, csv_path_or_obj=None): + def execute_sql(self, csv_path_or_obj=None, client_encoding=None): """ Run the COPY TO query. """ @@ -47,7 +49,11 @@ def execute_sql(self, csv_path_or_obj=None): # use stdout to avoid file permission issues with connections[self.using].cursor() as c: # set client encoding to adapted params - client_encoding = c.connection.encoding + if client_encoding is None: + client_encoding = c.connection.encoding + elif client_encoding != c.connection.encoding: + raise ValueError('client_encoding does not match' + ' db encoding: {} != {}'.format(client_encoding, c.connection.encoding)) for p in adapted_params: p.encoding = client_encoding if client_encoding else p.encoding # compile the SELECT query @@ -90,6 +96,7 @@ class CopyToQuery(Query): """ Represents a "copy to" SQL query. """ + def get_compiler(self, using=None, connection=None): """ Return a SQLCopyToCompiler object. diff --git a/postgres_copy/managers.py b/postgres_copy/managers.py index 6df7244..fb188db 100644 --- a/postgres_copy/managers.py +++ b/postgres_copy/managers.py @@ -215,9 +215,13 @@ def to_csv(self, csv_path=None, *fields, **kwargs): escape_char = kwargs.get('escape', None) query.copy_to_escape = "ESCAPE '{}'".format(escape_char) if escape_char else "" + # Client encoding + client_encoding = kwargs.get('client_encoding', None) + # Run the query compiler = query.get_compiler(self.db, connection=connection) - data = compiler.execute_sql(csv_path) + + data = compiler.execute_sql(csv_path, client_encoding=client_encoding) # If no csv_path is provided, then the query will come back as a string. if csv_path is None: diff --git a/tests/tests.py b/tests/tests.py index 2526f26..95cd23e 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -238,6 +238,12 @@ def test_filter_special_names(self, _): [i['name'] for i in reader] ) + @mock.patch("django.db.connection.validate_no_atomic_block") + def test_filter_special_names_encoding_error(self, _): + with self.assertRaises(ValueError): + self._load_objects(self.special_names_path) + MockObject.objects.filter(name="björn").to_csv(self.export_path, client_encoding='latin1') + @mock.patch("django.db.connection.validate_no_atomic_block") def test_fewer_fields(self, _): self._load_objects(self.name_path) From 01232beb42a49d98d92f1e6d256549c9bf9f504f Mon Sep 17 00:00:00 2001 From: Oscar Villiger Date: Tue, 14 Mar 2023 08:02:57 +0100 Subject: [PATCH 3/3] Added check of attribute `encoding` on adapted parameter before assigning. --- postgres_copy/copy_to.py | 3 ++- tests/tests.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/postgres_copy/copy_to.py b/postgres_copy/copy_to.py index 455c599..2b198b3 100644 --- a/postgres_copy/copy_to.py +++ b/postgres_copy/copy_to.py @@ -55,7 +55,8 @@ def execute_sql(self, csv_path_or_obj=None, client_encoding=None): raise ValueError('client_encoding does not match' ' db encoding: {} != {}'.format(client_encoding, c.connection.encoding)) for p in adapted_params: - p.encoding = client_encoding if client_encoding else p.encoding + if hasattr(p, 'encoding'): + p.encoding = client_encoding if client_encoding else p.encoding # compile the SELECT query select_sql = self.as_sql()[0] % adapted_params # then the COPY TO query diff --git a/tests/tests.py b/tests/tests.py index 95cd23e..895993d 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -244,6 +244,16 @@ def test_filter_special_names_encoding_error(self, _): self._load_objects(self.special_names_path) MockObject.objects.filter(name="björn").to_csv(self.export_path, client_encoding='latin1') + @mock.patch("django.db.connection.validate_no_atomic_block") + def test_filter_number(self, _): + # test filter by number (int), because adapted int parameters do not have an encoding attribute. + self._load_objects(self.name_path) + MockObject.objects.filter(number=3).to_csv(self.export_path) + reader = csv.DictReader(open(self.export_path, 'r')) + self.assertTrue( + ['JANE'], + [i['name'] for i in reader]) + @mock.patch("django.db.connection.validate_no_atomic_block") def test_fewer_fields(self, _): self._load_objects(self.name_path)