From f66534c6ae018d0cfddb7d09aa50cd1fcf6d36d1 Mon Sep 17 00:00:00 2001 From: Christoph Thiede Date: Thu, 15 Apr 2021 10:12:23 +0000 Subject: [PATCH 1/8] postgres.CopyToTable.copy(): Refine exception type Instead of a generic exception, raise a specific ValueError. The former is considered an anti-pattern because generic exceptions are harder to handle. --- luigi/contrib/postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luigi/contrib/postgres.py b/luigi/contrib/postgres.py index 74ad0836f7..ab4658bbec 100644 --- a/luigi/contrib/postgres.py +++ b/luigi/contrib/postgres.py @@ -285,7 +285,7 @@ def copy(self, cursor, file): elif len(self.columns[0]) == 2: column_names = [c[0] for c in self.columns] else: - raise Exception('columns must consist of column strings or (column string, type string) tuples (was %r ...)' % (self.columns[0],)) + raise ValueError('columns must consist of column strings or (column string, type string) tuples (was %r ...)' % (self.columns[0],)) cursor.copy_from(file, self.table, null=r'\\N', sep=self.column_separator, columns=column_names) def run(self): From 1212dd88d74847a2d312918f059e22a8eeb429b0 Mon Sep 17 00:00:00 2001 From: Christoph Thiede Date: Thu, 15 Apr 2021 10:17:07 +0000 Subject: [PATCH 2/8] postgres.CopyToTable: Robustize resources cleanup Make sure to rollback and close connections as well as to release cursors when an error occurs during CopyToTable.copy(). This can happen, for example, when rows have an invalid format or the method is overriden in a subclass. Tests are added, too (see DailyCopyToTableTest.test_cleanup_on_error()). --- luigi/contrib/postgres.py | 94 ++++++++++++++++++----------------- test/contrib/postgres_test.py | 79 +++++++++++++++++++++++++++-- 2 files changed, 124 insertions(+), 49 deletions(-) diff --git a/luigi/contrib/postgres.py b/luigi/contrib/postgres.py index ab4658bbec..1439c45e2f 100644 --- a/luigi/contrib/postgres.py +++ b/luigi/contrib/postgres.py @@ -299,52 +299,54 @@ def run(self): if not (self.table and self.columns): raise Exception("table and columns need to be specified") - connection = self.output().connect() - # transform all data generated by rows() using map_column and write data - # to a temporary file for import using postgres COPY - tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None) - tmp_file = tempfile.TemporaryFile(dir=tmp_dir) - n = 0 - for row in self.rows(): - n += 1 - if n % 100000 == 0: - logger.info("Wrote %d lines", n) - rowstr = self.column_separator.join(self.map_column(val) for val in row) - rowstr += "\n" - tmp_file.write(rowstr.encode('utf-8')) - - logger.info("Done writing, importing at %s", datetime.datetime.now()) - tmp_file.seek(0) - - # attempt to copy the data into postgres - # if it fails because the target table doesn't exist - # try to create it by running self.create_table - for attempt in range(2): - try: - cursor = connection.cursor() - self.init_copy(connection) - self.copy(cursor, tmp_file) - self.post_copy(connection) - if self.enable_metadata_columns: - self.post_copy_metacolumns(cursor) - except psycopg2.ProgrammingError as e: - if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0: - # if first attempt fails with "relation not found", try creating table - logger.info("Creating table %s", self.table) - connection.reset() - self.create_table(connection) - else: - raise - else: - break - - # mark as complete in same transaction - self.output().touch(connection) - - # commit and clean up - connection.commit() - connection.close() - tmp_file.close() + try: + connection = self.output().connect() + # transform all data generated by rows() using map_column and write data + # to a temporary file for import using postgres COPY + tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None) + tmp_file = tempfile.TemporaryFile(dir=tmp_dir) + n = 0 + for row in self.rows(): + n += 1 + if n % 100000 == 0: + logger.info("Wrote %d lines", n) + rowstr = self.column_separator.join(self.map_column(val) for val in row) + rowstr += "\n" + tmp_file.write(rowstr.encode('utf-8')) + + logger.info("Done writing, importing at %s", datetime.datetime.now()) + tmp_file.seek(0) + + with connection: + # attempt to copy the data into postgres + # if it fails because the target table doesn't exist + # try to create it by running self.create_table + for attempt in range(2): + try: + with connection.cursor() as cursor: + self.init_copy(connection) + self.copy(cursor, tmp_file) + self.post_copy(connection) + if self.enable_metadata_columns: + self.post_copy_metacolumns(cursor) + except psycopg2.ProgrammingError as e: + if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0: + # if first attempt fails with "relation not found", try creating table + logger.info("Creating table %s", self.table) + connection.reset() + self.create_table(connection) + else: + raise + else: + break + + # mark as complete in same transaction + self.output().touch(connection) + finally: + if connection: + connection.close() + if tmp_file: + tmp_file.close() class PostgresQuery(rdbms.Query): diff --git a/test/contrib/postgres_test.py b/test/contrib/postgres_test.py index f08a330f84..38a3bf9c85 100644 --- a/test/contrib/postgres_test.py +++ b/test/contrib/postgres_test.py @@ -26,12 +26,47 @@ def datetime_to_epoch(dt): return td.days * 86400 + td.seconds + td.microseconds / 1E6 -class MockPostgresCursor(mock.Mock): +class MockContextManager(mock.Mock): + + def __init__(self, *args, **kwargs): + super(MockContextManager, self).__init__(*args, **kwargs) + self.context_counter = 0 + self.all_context_counter = 0 + + def __enter__(self): + self.context_counter += 1 + self.all_context_counter += 1 + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.context_counter -= 1 + + def _get_child_mock(self, /, **kwargs): + """Child mocks will be instances of super.""" + return mock.Mock(**kwargs) + + +class MockPostgresConnection(MockContextManager): + def __init__(self, existing_update_ids, *args, **kwargs): + super(MockPostgresConnection, self).__init__(*args, **kwargs) + self.existing = existing_update_ids + self.is_open = False + self.was_open = 0 + + def cursor(self): + self.is_open = True + self.was_open = True + return MockPostgresCursor(existing_update_ids=self.existing) + + def close(self): + self.is_open = False + + +class MockPostgresCursor(MockContextManager): """ Keeps state to simulate executing SELECT queries and fetching results. """ - def __init__(self, existing_update_ids): - super(MockPostgresCursor, self).__init__() + def __init__(self, existing_update_ids, *args, **kwargs): + super(MockPostgresCursor, self).__init__(*args, **kwargs) self.existing = existing_update_ids def execute(self, query, params): @@ -82,6 +117,44 @@ def test_bulk_complete(self, mock_connect): ])) self.assertFalse(task.complete()) + @mock.patch('psycopg2.connect') + @mock.patch("luigi.contrib.postgres.CopyToTable.rows", return_value=['row1', 'row2']) + def test_cleanup_on_error(self, mock_rows, mock_connect): + """ + Test cleanup behavior of CopyToTable in case of an error. + + When an error occured while the connection is open, it should be + closed again so that subsequent tasks do not fail due to the unclosed + connection. + """ + task = DummyPostgresImporter(date=datetime.datetime(2021, 4, 15)) + + mock_connection = MockPostgresConnection([task.task_id]) + mock_connect.return_value = mock_connection + mock_cursor = MockPostgresCursor([task.task_id]) + + original_cursor = mock_connection.cursor + def get_mock_cursor(): + original_cursor() + return mock_cursor + + mock_connection.cursor = mock.MagicMock(side_effect=get_mock_cursor) + + task = DummyPostgresImporter(date=datetime.datetime(2021, 4, 15)) + task.columns = [(42,)] # inject defect + + #import pdb; pdb.set_trace() + with self.assertRaisesRegex(ValueError, "columns"): + task.run() + + self.assertEqual(mock_connection.context_counter, 0) + self.assertTrue(mock_connection.all_context_counter) + self.assertFalse(mock_connection.is_open) + self.assertTrue(mock_connection.was_open) + + self.assertEqual(mock_cursor.context_counter, 0) + self.assertTrue(mock_cursor.all_context_counter) + class DummyPostgresQuery(luigi.contrib.postgres.PostgresQuery): date = luigi.DateParameter() From e0ba49cd44ba929d7452dbe0f0798bd89ed16938 Mon Sep 17 00:00:00 2001 From: Christoph Thiede <38782922+LinqLover@users.noreply.github.com> Date: Thu, 15 Apr 2021 12:28:04 +0200 Subject: [PATCH 3/8] Remove dangling debug statement --- test/contrib/postgres_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/contrib/postgres_test.py b/test/contrib/postgres_test.py index 38a3bf9c85..81bf79ecd8 100644 --- a/test/contrib/postgres_test.py +++ b/test/contrib/postgres_test.py @@ -143,7 +143,6 @@ def get_mock_cursor(): task = DummyPostgresImporter(date=datetime.datetime(2021, 4, 15)) task.columns = [(42,)] # inject defect - #import pdb; pdb.set_trace() with self.assertRaisesRegex(ValueError, "columns"): task.run() From 8a3a9496d4331d2ea24c3025d34ba80294153a17 Mon Sep 17 00:00:00 2001 From: Christoph Thiede Date: Thu, 15 Apr 2021 10:33:41 +0000 Subject: [PATCH 4/8] Make MockContextManager backward-compatible (py38) --- test/contrib/postgres_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/contrib/postgres_test.py b/test/contrib/postgres_test.py index 81bf79ecd8..e4000e7edc 100644 --- a/test/contrib/postgres_test.py +++ b/test/contrib/postgres_test.py @@ -40,7 +40,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): self.context_counter -= 1 - def _get_child_mock(self, /, **kwargs): + def _get_child_mock(self, **kwargs): """Child mocks will be instances of super.""" return mock.Mock(**kwargs) From d0f7589c9ed731c89a2ed06465b9e94b3621fc02 Mon Sep 17 00:00:00 2001 From: Christoph Thiede Date: Thu, 15 Apr 2021 10:38:45 +0000 Subject: [PATCH 5/8] Please linter --- test/contrib/postgres_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/contrib/postgres_test.py b/test/contrib/postgres_test.py index e4000e7edc..6078c106b6 100644 --- a/test/contrib/postgres_test.py +++ b/test/contrib/postgres_test.py @@ -134,6 +134,7 @@ def test_cleanup_on_error(self, mock_rows, mock_connect): mock_cursor = MockPostgresCursor([task.task_id]) original_cursor = mock_connection.cursor + def get_mock_cursor(): original_cursor() return mock_cursor From 66471b612de2a55b15b29dda0b278a0a9099f322 Mon Sep 17 00:00:00 2001 From: Christoph Thiede Date: Thu, 15 Apr 2021 20:31:36 +0000 Subject: [PATCH 6/8] Apply suggestions from @lallea re contextlib --- luigi/contrib/postgres.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/luigi/contrib/postgres.py b/luigi/contrib/postgres.py index 1439c45e2f..bc69b3339a 100644 --- a/luigi/contrib/postgres.py +++ b/luigi/contrib/postgres.py @@ -19,6 +19,7 @@ Also provides a helper task to copy data into a Postgres table. """ +import contextlib import datetime import logging import re @@ -299,12 +300,11 @@ def run(self): if not (self.table and self.columns): raise Exception("table and columns need to be specified") - try: - connection = self.output().connect() - # transform all data generated by rows() using map_column and write data - # to a temporary file for import using postgres COPY + with contextlib.closing(self.output().connect()) as connection: + # transform all data generated by rows() using map_column and + # write data to a temporary file for import using postgres COPY tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None) - tmp_file = tempfile.TemporaryFile(dir=tmp_dir) + with tempfile.TemporaryFile(dir=tmp_dir) as tmp_file: n = 0 for row in self.rows(): n += 1 @@ -314,7 +314,8 @@ def run(self): rowstr += "\n" tmp_file.write(rowstr.encode('utf-8')) - logger.info("Done writing, importing at %s", datetime.datetime.now()) + logger.info( + "Done writing, importing at %s", datetime.datetime.now()) tmp_file.seek(0) with connection: @@ -330,7 +331,8 @@ def run(self): if self.enable_metadata_columns: self.post_copy_metacolumns(cursor) except psycopg2.ProgrammingError as e: - if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0: + if e.pgcode == psycopg2.errorcodes.\ + UNDEFINED_TABLE and attempt == 0: # if first attempt fails with "relation not found", try creating table logger.info("Creating table %s", self.table) connection.reset() @@ -342,11 +344,6 @@ def run(self): # mark as complete in same transaction self.output().touch(connection) - finally: - if connection: - connection.close() - if tmp_file: - tmp_file.close() class PostgresQuery(rdbms.Query): From 4df506b9a912726099657737db46c476b0d359ad Mon Sep 17 00:00:00 2001 From: Christoph Thiede Date: Thu, 15 Apr 2021 20:32:13 +0000 Subject: [PATCH 7/8] Install additional context mgrs in PostgresTarget --- luigi/contrib/postgres.py | 139 +++++++++++++++++++------------------- 1 file changed, 69 insertions(+), 70 deletions(-) diff --git a/luigi/contrib/postgres.py b/luigi/contrib/postgres.py index bc69b3339a..0c1fcddbbb 100644 --- a/luigi/contrib/postgres.py +++ b/luigi/contrib/postgres.py @@ -168,20 +168,20 @@ def exists(self, connection=None): if connection is None: connection = self.connect() connection.autocommit = True - cursor = connection.cursor() - try: - cursor.execute("""SELECT 1 FROM {marker_table} - WHERE update_id = %s - LIMIT 1""".format(marker_table=self.marker_table), - (self.update_id,) - ) - row = cursor.fetchone() - except psycopg2.ProgrammingError as e: - if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE: - row = None - else: - raise - return row is not None + with connection.cursor() as cursor: + try: + cursor.execute( + """SELECT 1 FROM {marker_table} + WHERE update_id = %s + LIMIT 1""".format(marker_table=self.marker_table), + (self.update_id,)) + row = cursor.fetchone() + except psycopg2.ProgrammingError as e: + if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE: + row = None + else: + raise + return row is not None def connect(self): """ @@ -202,30 +202,29 @@ def create_marker_table(self): Using a separate connection since the transaction might have to be reset. """ - connection = self.connect() - connection.autocommit = True - cursor = connection.cursor() - if self.use_db_timestamps: - sql = """ CREATE TABLE {marker_table} ( - update_id TEXT PRIMARY KEY, - target_table TEXT, - inserted TIMESTAMP DEFAULT NOW()) - """.format(marker_table=self.marker_table) - else: - sql = """ CREATE TABLE {marker_table} ( - update_id TEXT PRIMARY KEY, - target_table TEXT, - inserted TIMESTAMP); - """.format(marker_table=self.marker_table) - - try: - cursor.execute(sql) - except psycopg2.ProgrammingError as e: - if e.pgcode == psycopg2.errorcodes.DUPLICATE_TABLE: - pass - else: - raise - connection.close() + with contextlib.closing(self.connect()) as connection: + connection.autocommit = True + with connection.cursor() as cursor: + if self.use_db_timestamps: + sql = """ CREATE TABLE {marker_table} ( + update_id TEXT PRIMARY KEY, + target_table TEXT, + inserted TIMESTAMP DEFAULT NOW()) + """.format(marker_table=self.marker_table) + else: + sql = """ CREATE TABLE {marker_table} ( + update_id TEXT PRIMARY KEY, + target_table TEXT, + inserted TIMESTAMP); + """.format(marker_table=self.marker_table) + + try: + cursor.execute(sql) + except psycopg2.ProgrammingError as e: + if e.pgcode == psycopg2.errorcodes.DUPLICATE_TABLE: + pass + else: + raise def open(self, mode): raise NotImplementedError("Cannot open() PostgresTarget") @@ -305,45 +304,45 @@ def run(self): # write data to a temporary file for import using postgres COPY tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None) with tempfile.TemporaryFile(dir=tmp_dir) as tmp_file: - n = 0 - for row in self.rows(): - n += 1 - if n % 100000 == 0: - logger.info("Wrote %d lines", n) - rowstr = self.column_separator.join(self.map_column(val) for val in row) - rowstr += "\n" - tmp_file.write(rowstr.encode('utf-8')) + n = 0 + for row in self.rows(): + n += 1 + if n % 100000 == 0: + logger.info("Wrote %d lines", n) + rowstr = self.column_separator.join(self.map_column(val) for val in row) + rowstr += "\n" + tmp_file.write(rowstr.encode('utf-8')) logger.info( "Done writing, importing at %s", datetime.datetime.now()) - tmp_file.seek(0) - - with connection: - # attempt to copy the data into postgres - # if it fails because the target table doesn't exist - # try to create it by running self.create_table - for attempt in range(2): - try: - with connection.cursor() as cursor: - self.init_copy(connection) - self.copy(cursor, tmp_file) - self.post_copy(connection) - if self.enable_metadata_columns: - self.post_copy_metacolumns(cursor) - except psycopg2.ProgrammingError as e: + tmp_file.seek(0) + + with connection: + # attempt to copy the data into postgres + # if it fails because the target table doesn't exist + # try to create it by running self.create_table + for attempt in range(2): + try: + with connection.cursor() as cursor: + self.init_copy(connection) + self.copy(cursor, tmp_file) + self.post_copy(connection) + if self.enable_metadata_columns: + self.post_copy_metacolumns(cursor) + except psycopg2.ProgrammingError as e: if e.pgcode == psycopg2.errorcodes.\ UNDEFINED_TABLE and attempt == 0: - # if first attempt fails with "relation not found", try creating table - logger.info("Creating table %s", self.table) - connection.reset() - self.create_table(connection) + # if first attempt fails with "relation not found", try creating table + logger.info("Creating table %s", self.table) + connection.reset() + self.create_table(connection) + else: + raise else: - raise - else: - break + break - # mark as complete in same transaction - self.output().touch(connection) + # mark as complete in same transaction + self.output().touch(connection) class PostgresQuery(rdbms.Query): From af11c2232b3cd073f0b410deb4c39fec8a42487c Mon Sep 17 00:00:00 2001 From: Christoph Thiede Date: Thu, 15 Apr 2021 20:32:30 +0000 Subject: [PATCH 8/8] Fix contextmgr impl of postgres_test.MockContextManager --- test/contrib/postgres_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/contrib/postgres_test.py b/test/contrib/postgres_test.py index 6078c106b6..c0cd313a88 100644 --- a/test/contrib/postgres_test.py +++ b/test/contrib/postgres_test.py @@ -36,6 +36,7 @@ def __init__(self, *args, **kwargs): def __enter__(self): self.context_counter += 1 self.all_context_counter += 1 + return self def __exit__(self, exc_type, exc_value, exc_traceback): self.context_counter -= 1