Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sdks/python: enrich data with CloudSQL #34398

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,46 @@ def enrichment_with_bigtable():
| "Print" >> beam.Map(print))
# [END enrichment_with_bigtable]

def enrichment_with_cloudsql():
  # [START enrichment_with_cloudsql]
  import os

  import apache_beam as beam
  from apache_beam.transforms.enrichment import Enrichment
  from apache_beam.transforms.enrichment_handlers.cloudsql import CloudSQLEnrichmentHandler
  from apache_beam.transforms.enrichment_handlers.cloudsql import DatabaseTypeAdapter

  # Sample sales records to enrich with product details from Cloud SQL.
  sales = [
      beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1),
      beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3),
      beam.Row(sale_id=5, customer_id=5, product_id=4, quantity=2),
  ]

  # Credentials come from the environment so they are never committed
  # alongside the example code.
  cloudsql_handler = CloudSQLEnrichmentHandler(
      project_id='apache-beam-testing',
      region_id='us-east1',
      instance_id='beam-test',
      table_id='cloudsql-enrichment-test',
      database_type_adapter=DatabaseTypeAdapter.POSTGRESQL,
      database_id='test-database',
      database_user=os.getenv("BEAM_TEST_CLOUDSQL_PG_USER"),
      database_password=os.getenv("BEAM_TEST_CLOUDSQL_PG_PASSWORD"),
      row_key='product_id')

  with beam.Pipeline() as p:
    _ = (
        p
        | "Create" >> beam.Create(sales)
        | "Enrich W/ CloudSQL" >> Enrichment(cloudsql_handler)
        | "Print" >> beam.Map(print))
  # [END enrichment_with_cloudsql]

def enrichment_with_vertex_ai():
# [START enrichment_with_vertex_ai]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ def validate_enrichment_with_bigtable():
[END enrichment_with_bigtable]'''.splitlines()[1:-1]
return expected

def validate_enrichment_with_cloudsql():
  """Returns the stdout lines expected from the CloudSQL enrichment snippet."""
  # One printed beam.Row per input sale, enriched with the matching
  # product record from the Cloud SQL test table.
  return [
      "Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'})",
      "Row(sale_id=3, customer_id=3, product_id=2, quantity=3, product={'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4'})",
      "Row(sale_id=5, customer_id=5, product_id=4, quantity=2, product={'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10'})",
  ]

def validate_enrichment_with_vertex_ai():
expected = '''[START enrichment_with_vertex_ai]
Expand Down Expand Up @@ -68,6 +75,13 @@ def test_enrichment_with_bigtable(self, mock_stdout):
expected = validate_enrichment_with_bigtable()
self.assertEqual(output, expected)

def test_enrichment_with_cloudsql(self, mock_stdout):
  """Runs the CloudSQL enrichment snippet and checks its printed output."""
  # Imported lazily so the module under test is only loaded when this
  # test actually runs.
  from apache_beam.examples.snippets.transforms.elementwise.enrichment import enrichment_with_cloudsql
  enrichment_with_cloudsql()
  printed_lines = mock_stdout.getvalue().splitlines()
  self.assertEqual(printed_lines, validate_enrichment_with_cloudsql())

def test_enrichment_with_vertex_ai(self, mock_stdout):
enrichment_with_vertex_ai()
output = mock_stdout.getvalue().splitlines()
Expand Down
169 changes: 169 additions & 0 deletions sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from collections.abc import Callable
from enum import Enum
from typing import Any
from typing import Optional

from google.cloud.sql.connector import Connector

import apache_beam as beam
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel

__all__ = [
    'CloudSQLEnrichmentHandler',
    'DatabaseTypeAdapter',
]

# RowKeyFn takes beam.Row and returns a (key_field, key_value) pair.
# Annotated as a two-element tuple: callers unpack the result into the
# column name used in the WHERE clause and the value bound to it.
RowKeyFn = Callable[[beam.Row], tuple[str, str]]

_LOGGER = logging.getLogger(__name__)


class DatabaseTypeAdapter(Enum):
  """Maps a Cloud SQL database engine to its Python DB-API driver name.

  The member value is the ``driver`` string passed to
  ``google.cloud.sql.connector.Connector.connect``.
  """
  POSTGRESQL = "pg8000"  # PostgreSQL via the pg8000 driver.
  MYSQL = "pymysql"  # MySQL via the PyMySQL driver.
  SQLSERVER = "pytds"  # SQL Server via the python-tds driver.


class CloudSQLEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]):
  """A handler for :class:`apache_beam.transforms.enrichment.Enrichment`
  transform to interact with Google Cloud SQL databases.

  Args:
    project_id (str): GCP project-id of the Cloud SQL instance.
    region_id (str): GCP region-id of the Cloud SQL instance.
    instance_id (str): GCP instance-id of the Cloud SQL instance.
    database_type_adapter (DatabaseTypeAdapter): The type of database
      adapter to use. Supported adapters are: POSTGRESQL (pg8000),
      MYSQL (pymysql), and SQLSERVER (pytds).
    database_id (str): The id of the database to connect to.
    database_user (str): The username for connecting to the database.
    database_password (str): The password for connecting to the database.
    table_id (str): The name of the table to query.
    row_key (str): Field name from the input `beam.Row` object to use as
      identifier for database querying.
    row_key_fn: A lambda function that returns a ``(field_name, value)``
      pair from the input row. Used to build/extract the identifier for
      the database query.
    exception_level: A `enum.Enum` value from
      ``apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel``
      to set the level when no matching record is found from the database
      query. Defaults to ``ExceptionLevel.WARN``.
  """
  def __init__(
      self,
      region_id: str,
      project_id: str,
      instance_id: str,
      database_type_adapter: DatabaseTypeAdapter,
      database_id: str,
      database_user: str,
      database_password: str,
      table_id: str,
      row_key: str = "",
      *,
      row_key_fn: Optional[RowKeyFn] = None,
      exception_level: ExceptionLevel = ExceptionLevel.WARN,
  ):
    self._project_id = project_id
    self._region_id = region_id
    self._instance_id = instance_id
    self._database_type_adapter = database_type_adapter
    self._database_id = database_id
    self._database_user = database_user
    self._database_password = database_password
    self._table_id = table_id
    self._row_key = row_key
    self._row_key_fn = row_key_fn
    self._exception_level = exception_level
    # Exactly one of `row_key` / `row_key_fn` must be provided: truthiness
    # equality catches both "neither given" and "both given".
    if bool(self._row_key_fn) == bool(self._row_key):
      raise ValueError(
          "Please specify exactly one of `row_key` or a lambda "
          "function with `row_key_fn` to extract the row key "
          "from the input row.")

  def __enter__(self):
    """Connect to the Cloud SQL instance and create a reusable cursor."""
    self.connector = Connector()
    self.client = self.connector.connect(
        f"{self._project_id}:{self._region_id}:{self._instance_id}",
        driver=self._database_type_adapter.value,
        db=self._database_id,
        user=self._database_user,
        password=self._database_password,
    )
    # One cursor is reused for every request handled by this instance.
    self.cursor = self.client.cursor()

  def __call__(self, request: beam.Row, *args, **kwargs):
    """Executes a query against the Cloud SQL instance and returns
    a `Tuple` of request and response.

    Args:
      request: the input `beam.Row` to enrich.

    Raises:
      KeyError: if the configured `row_key` field is missing from the
        input row.
      ValueError: if no matching record is found and `exception_level`
        is ``ExceptionLevel.RAISE``.
    """
    response_dict: dict[str, Any] = {}

    if self._row_key_fn:
      # Keep the extracted (field, value) pair local. The original code
      # assigned the field back to `self._row_key`, mutating shared
      # handler state on every request — unsafe when the handler is
      # reused across bundles.
      key_field, row_key = self._row_key_fn(request)
    else:
      key_field = self._row_key
      try:
        row_key = str(request._asdict()[key_field])
      except KeyError:
        # Report the missing *field name*, not the (empty) value.
        raise KeyError(
            'row_key %s not found in input PCollection.' % key_field)

    # NOTE: `key_field` and `self._table_id` are SQL identifiers and cannot
    # be bound as query parameters; they come from handler configuration,
    # not from the data being enriched. The looked-up value itself is
    # always passed as a bound parameter.
    query = f"SELECT * FROM {self._table_id} WHERE {key_field} = %s"
    self.cursor.execute(query, (row_key, ))
    result = self.cursor.fetchone()

    if result:
      columns = [col[0] for col in self.cursor.description]
      response_dict = dict(zip(columns, result))
    elif self._exception_level == ExceptionLevel.WARN:
      _LOGGER.warning(
          'No matching record found for row_key: %s in table: %s',
          row_key,
          self._table_id)
    elif self._exception_level == ExceptionLevel.RAISE:
      raise ValueError(
          'No matching record found for row_key: %s in table: %s' %
          (row_key, self._table_id))

    return request, beam.Row(**response_dict)

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Close the cursor, connection and connector created in __enter__."""
    self.cursor.close()
    self.client.close()
    self.connector.close()
    self.cursor, self.client, self.connector = None, None, None

  def get_cache_key(self, request: beam.Row) -> str:
    """Returns a string formatted with row key since it is unique to
    a request made to the Cloud SQL instance."""
    if self._row_key_fn:
      # Avoid shadowing the builtin `id`.
      key_field, value = self._row_key_fn(request)
      return f"{key_field}: {value}"
    return f"{self._row_key}: {request._asdict()[self._row_key]}"
Loading
Loading