diff --git a/luigi/contrib/postgres.py b/luigi/contrib/postgres.py
index 74ad0836f7..0c1fcddbbb 100644
--- a/luigi/contrib/postgres.py
+++ b/luigi/contrib/postgres.py
@@ -19,6 +19,7 @@
 Also provides a helper task to copy data into a Postgres table.
 """
 
+import contextlib
 import datetime
 import logging
 import re
@@ -167,20 +168,20 @@ def exists(self, connection=None):
         if connection is None:
             connection = self.connect()
             connection.autocommit = True
-        cursor = connection.cursor()
-        try:
-            cursor.execute("""SELECT 1 FROM {marker_table}
-                WHERE update_id = %s
-                LIMIT 1""".format(marker_table=self.marker_table),
-                (self.update_id,)
-            )
-            row = cursor.fetchone()
-        except psycopg2.ProgrammingError as e:
-            if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE:
-                row = None
-            else:
-                raise
-        return row is not None
+        with connection.cursor() as cursor:
+            try:
+                cursor.execute(
+                    """SELECT 1 FROM {marker_table}
+                       WHERE update_id = %s
+                       LIMIT 1""".format(marker_table=self.marker_table),
+                    (self.update_id,))
+                row = cursor.fetchone()
+            except psycopg2.ProgrammingError as e:
+                if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE:
+                    row = None
+                else:
+                    raise
+            return row is not None
 
     def connect(self):
         """
@@ -201,30 +202,29 @@ def create_marker_table(self):
 
         Using a separate connection since the transaction might have to be reset.
         """
-        connection = self.connect()
-        connection.autocommit = True
-        cursor = connection.cursor()
-        if self.use_db_timestamps:
-            sql = """ CREATE TABLE {marker_table} (
-                      update_id TEXT PRIMARY KEY,
-                      target_table TEXT,
-                      inserted TIMESTAMP DEFAULT NOW())
-                  """.format(marker_table=self.marker_table)
-        else:
-            sql = """ CREATE TABLE {marker_table} (
-                      update_id TEXT PRIMARY KEY,
-                      target_table TEXT,
-                      inserted TIMESTAMP);
-                  """.format(marker_table=self.marker_table)
-
-        try:
-            cursor.execute(sql)
-        except psycopg2.ProgrammingError as e:
-            if e.pgcode == psycopg2.errorcodes.DUPLICATE_TABLE:
-                pass
-            else:
-                raise
-        connection.close()
+        with contextlib.closing(self.connect()) as connection:
+            connection.autocommit = True
+            with connection.cursor() as cursor:
+                if self.use_db_timestamps:
+                    sql = """ CREATE TABLE {marker_table} (
+                              update_id TEXT PRIMARY KEY,
+                              target_table TEXT,
+                              inserted TIMESTAMP DEFAULT NOW())
+                          """.format(marker_table=self.marker_table)
+                else:
+                    sql = """ CREATE TABLE {marker_table} (
+                              update_id TEXT PRIMARY KEY,
+                              target_table TEXT,
+                              inserted TIMESTAMP);
+                          """.format(marker_table=self.marker_table)
+
+                try:
+                    cursor.execute(sql)
+                except psycopg2.ProgrammingError as e:
+                    if e.pgcode == psycopg2.errorcodes.DUPLICATE_TABLE:
+                        pass
+                    else:
+                        raise
 
     def open(self, mode):
         raise NotImplementedError("Cannot open() PostgresTarget")
@@ -285,7 +285,7 @@ def copy(self, cursor, file):
         elif len(self.columns[0]) == 2:
             column_names = [c[0] for c in self.columns]
         else:
-            raise Exception('columns must consist of column strings or (column string, type string) tuples (was %r ...)' % (self.columns[0],))
+            raise ValueError('columns must consist of column strings or (column string, type string) tuples (was %r ...)' % (self.columns[0],))
         cursor.copy_from(file, self.table, null=r'\\N', sep=self.column_separator, columns=column_names)
 
     def run(self):
@@ -299,52 +299,50 @@ def run(self):
         if not (self.table and self.columns):
             raise Exception("table and columns need to be specified")
 
-        connection = self.output().connect()
-        # transform all data generated by rows() using map_column and write data
-        # to a temporary file for import using postgres COPY
-        tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None)
-        tmp_file = tempfile.TemporaryFile(dir=tmp_dir)
-        n = 0
-        for row in self.rows():
-            n += 1
-            if n % 100000 == 0:
-                logger.info("Wrote %d lines", n)
-            rowstr = self.column_separator.join(self.map_column(val) for val in row)
-            rowstr += "\n"
-            tmp_file.write(rowstr.encode('utf-8'))
-
-        logger.info("Done writing, importing at %s", datetime.datetime.now())
-        tmp_file.seek(0)
-
-        # attempt to copy the data into postgres
-        # if it fails because the target table doesn't exist
-        # try to create it by running self.create_table
-        for attempt in range(2):
-            try:
-                cursor = connection.cursor()
-                self.init_copy(connection)
-                self.copy(cursor, tmp_file)
-                self.post_copy(connection)
-                if self.enable_metadata_columns:
-                    self.post_copy_metacolumns(cursor)
-            except psycopg2.ProgrammingError as e:
-                if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0:
-                    # if first attempt fails with "relation not found", try creating table
-                    logger.info("Creating table %s", self.table)
-                    connection.reset()
-                    self.create_table(connection)
-                else:
-                    raise
-            else:
-                break
-
-        # mark as complete in same transaction
-        self.output().touch(connection)
-
-        # commit and clean up
-        connection.commit()
-        connection.close()
-        tmp_file.close()
+        with contextlib.closing(self.output().connect()) as connection:
+            # transform all data generated by rows() using map_column and
+            # write data to a temporary file for import using postgres COPY
+            tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None)
+            with tempfile.TemporaryFile(dir=tmp_dir) as tmp_file:
+                n = 0
+                for row in self.rows():
+                    n += 1
+                    if n % 100000 == 0:
+                        logger.info("Wrote %d lines", n)
+                    rowstr = self.column_separator.join(self.map_column(val) for val in row)
+                    rowstr += "\n"
+                    tmp_file.write(rowstr.encode('utf-8'))
+
+                logger.info(
+                    "Done writing, importing at %s", datetime.datetime.now())
+                tmp_file.seek(0)
+
+                with connection:
+                    # attempt to copy the data into postgres
+                    # if it fails because the target table doesn't exist
+                    # try to create it by running self.create_table
+                    for attempt in range(2):
+                        try:
+                            with connection.cursor() as cursor:
+                                self.init_copy(connection)
+                                self.copy(cursor, tmp_file)
+                                self.post_copy(connection)
+                                if self.enable_metadata_columns:
+                                    self.post_copy_metacolumns(cursor)
+                        except psycopg2.ProgrammingError as e:
+                            if e.pgcode == psycopg2.errorcodes.\
+                                    UNDEFINED_TABLE and attempt == 0:
+                                # if first attempt fails with "relation not found", try creating table
+                                logger.info("Creating table %s", self.table)
+                                connection.reset()
+                                self.create_table(connection)
+                            else:
+                                raise
+                        else:
+                            break
+
+                    # mark as complete in same transaction
+                    self.output().touch(connection)
 
 
 class PostgresQuery(rdbms.Query):
diff --git a/test/contrib/postgres_test.py b/test/contrib/postgres_test.py
index f08a330f84..c0cd313a88 100644
--- a/test/contrib/postgres_test.py
+++ b/test/contrib/postgres_test.py
@@ -26,12 +26,48 @@ def datetime_to_epoch(dt):
     return td.days * 86400 + td.seconds + td.microseconds / 1E6
 
 
-class MockPostgresCursor(mock.Mock):
+class MockContextManager(mock.Mock):
+
+    def __init__(self, *args, **kwargs):
+        super(MockContextManager, self).__init__(*args, **kwargs)
+        self.context_counter = 0
+        self.all_context_counter = 0
+
+    def __enter__(self):
+        self.context_counter += 1
+        self.all_context_counter += 1
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        self.context_counter -= 1
+
+    def _get_child_mock(self, **kwargs):
+        """Child mocks will be plain Mock instances, not MockContextManager."""
+        return mock.Mock(**kwargs)
+
+
+class MockPostgresConnection(MockContextManager):
+    def __init__(self, existing_update_ids, *args, **kwargs):
+        super(MockPostgresConnection, self).__init__(*args, **kwargs)
+        self.existing = existing_update_ids
+        self.is_open = False
+        self.was_open = 0
+
+    def cursor(self):
+        self.is_open = True
+        self.was_open = True
+        return MockPostgresCursor(existing_update_ids=self.existing)
+
+    def close(self):
+        self.is_open = False
+
+
+class MockPostgresCursor(MockContextManager):
     """
     Keeps state to simulate executing SELECT queries and fetching results.
     """
-    def __init__(self, existing_update_ids):
-        super(MockPostgresCursor, self).__init__()
+    def __init__(self, existing_update_ids, *args, **kwargs):
+        super(MockPostgresCursor, self).__init__(*args, **kwargs)
         self.existing = existing_update_ids
 
     def execute(self, query, params):
@@ -82,6 +118,44 @@ def test_bulk_complete(self, mock_connect):
         ]))
         self.assertFalse(task.complete())
 
+    @mock.patch('psycopg2.connect')
+    @mock.patch("luigi.contrib.postgres.CopyToTable.rows", return_value=['row1', 'row2'])
+    def test_cleanup_on_error(self, mock_rows, mock_connect):
+        """
+        Test cleanup behavior of CopyToTable in case of an error.
+
+        When an error occurs while the connection is open, it should be
+        closed again so that subsequent tasks do not fail due to the unclosed
+        connection.
+        """
+        task = DummyPostgresImporter(date=datetime.datetime(2021, 4, 15))
+
+        mock_connection = MockPostgresConnection([task.task_id])
+        mock_connect.return_value = mock_connection
+        mock_cursor = MockPostgresCursor([task.task_id])
+
+        original_cursor = mock_connection.cursor
+
+        def get_mock_cursor():
+            original_cursor()
+            return mock_cursor
+
+        mock_connection.cursor = mock.MagicMock(side_effect=get_mock_cursor)
+
+        task = DummyPostgresImporter(date=datetime.datetime(2021, 4, 15))
+        task.columns = [(42,)]  # inject defect
+
+        with self.assertRaisesRegex(ValueError, "columns"):
+            task.run()
+
+        self.assertEqual(mock_connection.context_counter, 0)
+        self.assertTrue(mock_connection.all_context_counter)
+        self.assertFalse(mock_connection.is_open)
+        self.assertTrue(mock_connection.was_open)
+
+        self.assertEqual(mock_cursor.context_counter, 0)
+        self.assertTrue(mock_cursor.all_context_counter)
+
 
 class DummyPostgresQuery(luigi.contrib.postgres.PostgresQuery):
     date = luigi.DateParameter()