diff --git a/postgres_copy/copy_to.py b/postgres_copy/copy_to.py
index 3b32a11..2b198b3 100644
--- a/postgres_copy/copy_to.py
+++ b/postgres_copy/copy_to.py
@@ -10,6 +10,7 @@ from psycopg2.extensions import adapt
 from django.db.models.sql.query import Query
 from django.db.models.sql.compiler import SQLCompiler
+
 
 logger = logging.getLogger(__name__)
 
 
@@ -17,6 +18,7 @@ class SQLCopyToCompiler(SQLCompiler):
     """
     Custom SQL compiler for creating a COPY TO query (postgres backend only).
     """
+
     def setup_query(self):
         """
         Extend the default SQLCompiler.setup_query to add re-ordering of items in select.
@@ -34,7 +36,7 @@ def setup_query(self):
             )
             self.select.append(selection)
 
-    def execute_sql(self, csv_path_or_obj=None):
+    def execute_sql(self, csv_path_or_obj=None, client_encoding=None):
         """
         Run the COPY TO query.
         """
@@ -46,6 +48,16 @@ def execute_sql(self, csv_path_or_obj=None):
 
         # use stdout to avoid file permission issues
         with connections[self.using].cursor() as c:
+            # adapted params that expose ``encoding`` are aligned with the
+            # connection encoding before being interpolated into the SELECT
+            if client_encoding is None:
+                client_encoding = c.connection.encoding
+            elif client_encoding != c.connection.encoding:
+                raise ValueError('client_encoding does not match'
+                                 ' db encoding: {} != {}'.format(client_encoding, c.connection.encoding))
+            for p in adapted_params:
+                if hasattr(p, 'encoding'):
+                    p.encoding = client_encoding
             # compile the SELECT query
             select_sql = self.as_sql()[0] % adapted_params
             # then the COPY TO query
@@ -86,6 +98,7 @@ class CopyToQuery(Query):
     """
     Represents a "copy to" SQL query.
     """
+
     def get_compiler(self, using=None, connection=None):
         """
         Return a SQLCopyToCompiler object.
diff --git a/postgres_copy/managers.py b/postgres_copy/managers.py
index 6df7244..fb188db 100644
--- a/postgres_copy/managers.py
+++ b/postgres_copy/managers.py
@@ -215,9 +215,13 @@ def to_csv(self, csv_path=None, *fields, **kwargs):
         escape_char = kwargs.get('escape', None)
         query.copy_to_escape = "ESCAPE '{}'".format(escape_char) if escape_char else ""
 
+        # Client encoding
+        client_encoding = kwargs.get('client_encoding', None)
+
         # Run the query
         compiler = query.get_compiler(self.db, connection=connection)
-        data = compiler.execute_sql(csv_path)
+
+        data = compiler.execute_sql(csv_path, client_encoding=client_encoding)
 
         # If no csv_path is provided, then the query will come back as a string.
         if csv_path is None:
diff --git a/tests/data/special_names.csv b/tests/data/special_names.csv
new file mode 100644
index 0000000..03ff4e0
--- /dev/null
+++ b/tests/data/special_names.csv
@@ -0,0 +1,5 @@
+NAME,NUMBER,DATE
+ben,1,2012-01-01
+joe,2,2012-01-02
+jane,3,2012-01-03
+björn,4,2012-01-04
diff --git a/tests/tests.py b/tests/tests.py
index 4db9563..895993d 100644
--- a/tests/tests.py
+++ b/tests/tests.py
@@ -39,6 +39,7 @@ def setUp(self):
             'matching_headers.csv'
         )
         self.secondarydb_path = os.path.join(self.data_dir, 'secondary_db.csv')
+        self.special_names_path = os.path.join(self.data_dir, 'special_names.csv')
 
     def tearDown(self):
         MockObject.objects.all().delete()
@@ -227,6 +228,33 @@ def test_filter(self, _):
             [i['name'] for i in reader]
         )
 
+    @mock.patch("django.db.connection.validate_no_atomic_block")
+    def test_filter_special_names(self, _):
+        # a non-ASCII filter value must survive parameter adaptation
+        self._load_objects(self.special_names_path)
+        MockObject.objects.filter(name="BJÖRN").to_csv(self.export_path)
+        with open(self.export_path, 'r') as f:
+            names = [i['name'] for i in csv.DictReader(f)]
+        self.assertEqual(['BJÖRN'], names)
+
+    @mock.patch("django.db.connection.validate_no_atomic_block")
+    def test_filter_special_names_encoding_error(self, _):
+        # load outside assertRaises so that only to_csv can satisfy the expectation
+        self._load_objects(self.special_names_path)
+        with self.assertRaises(ValueError):
+            MockObject.objects.filter(name="björn").to_csv(
+                self.export_path, client_encoding='latin1'
+            )
+
+    @mock.patch("django.db.connection.validate_no_atomic_block")
+    def test_filter_number(self, _):
+        # test filter by number (int), because adapted int parameters do not have an encoding attribute.
+        self._load_objects(self.name_path)
+        MockObject.objects.filter(number=3).to_csv(self.export_path)
+        with open(self.export_path, 'r') as f:
+            names = [i['name'] for i in csv.DictReader(f)]
+        self.assertEqual(['JANE'], names)
+
     @mock.patch("django.db.connection.validate_no_atomic_block")
     def test_fewer_fields(self, _):
         self._load_objects(self.name_path)