Skip to content

Commit da53bc5

Browse files
authored
Add GCS authentication for service accounts (#315)
1 parent 1259545 commit da53bc5

File tree

8 files changed

+251
-74
lines changed

8 files changed

+251
-74
lines changed

STYLE_GUIDE.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Streaming uses the [yapf](https://github.com/google/yapf) formatter for general
2727
(see section 2.2). These checks can also be run manually via:
2828

2929
```
30-
pre-commit run yapf --all-files # for yahp
30+
pre-commit run yapf --all-files # for yapf
3131
pre-commit run isort --all-files # for isort
3232
```
3333

docs/source/how_to_guides/configure_cloud_storage_credentials.md

+17-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,23 @@ export S3_ENDPOINT_URL='https://<accountid>.r2.cloudflarestorage.com'
9797

9898
For [MosaicML platform](https://www.mosaicml.com/cloud) users, follow the steps mentioned in the [Google Cloud Storage](https://mcli.docs.mosaicml.com/en/latest/secrets/gcp.html) MCLI doc on how to configure the cloud provider credentials.
9999

100-
### Others
100+
101+
### GCP Service Account Credentials Mounted as Environment Variables
102+
103+
Users must set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to point to their service account credentials file in the run environment.
104+
105+
````{tabs}
106+
```{code-tab} py
107+
import os
108+
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'KEY_FILE'
109+
```
110+
111+
```{code-tab} sh
112+
export GOOGLE_APPLICATION_CREDENTIALS='KEY_FILE'
113+
```
114+
````
115+
116+
### GCP User Auth Credentials Mounted as Environment Variables
101117

102118
Streaming dataset supports [GCP user credentials](https://cloud.google.com/storage/docs/authentication#user_accounts) or [HMAC keys for User account](https://cloud.google.com/storage/docs/authentication/hmackeys). Users must set their GCP `user access key` and GCP `user access secret` in the run environment.
103119

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
install_requires = [
4545
'boto3>=1.21.45,<2',
4646
'Brotli>=1.0.9',
47+
'google-cloud-storage>=2.9.0',
4748
'matplotlib>=3.5.2,<4',
4849
'paramiko>=2.11.0,<4',
4950
'python-snappy>=0.6.1,<1',

streaming/base/storage/download.py

+38-4
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,33 @@ def download_from_gcs(remote: str, local: str) -> None:
158158
remote (str): Remote path (GCS).
159159
local (str): Local path (local filesystem).
160160
"""
161-
import boto3
162-
from boto3.s3.transfer import TransferConfig
163-
from botocore.exceptions import ClientError
164-
165161
obj = urllib.parse.urlparse(remote)
166162
if obj.scheme != 'gs':
167163
raise ValueError(
168164
f'Expected obj.scheme to be `gs`, instead, got {obj.scheme} for remote={remote}')
169165

166+
if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
167+
_gcs_with_service_account(local, obj)
168+
elif 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ:
169+
_gcs_with_hmac(remote, local, obj)
170+
else:
171+
raise ValueError(f'Either GOOGLE_APPLICATION_CREDENTIALS needs to be set for ' +
172+
f'service level accounts or GCS_KEY and GCS_SECRET needs to be ' +
173+
f'set for HMAC authentication')
174+
175+
176+
def _gcs_with_hmac(remote: str, local: str, obj: urllib.parse.ParseResult) -> None:
177+
"""Download a file from remote GCS to local using user level credentials.
178+
179+
Args:
180+
remote (str): Remote path (GCS).
181+
local (str): Local path (local filesystem).
182+
obj (ParseResult): ParseResult object of remote.
183+
"""
184+
import boto3
185+
from boto3.s3.transfer import TransferConfig
186+
from botocore.exceptions import ClientError
187+
170188
# Create a new session per thread
171189
session = boto3.session.Session()
172190
# Create a resource client using a thread's session object
@@ -190,6 +208,22 @@ def download_from_gcs(remote: str, local: str) -> None:
190208
raise
191209

192210

211+
def _gcs_with_service_account(local: str, obj: urllib.parse.ParseResult) -> None:
212+
"""Download a file from remote GCS to local using service account credentials.
213+
214+
Args:
215+
local (str): Local path (local filesystem).
216+
obj (ParseResult): ParseResult object of remote path (GCS).
217+
"""
218+
from google.cloud.storage import Blob, Bucket, Client
219+
220+
service_account_path = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
221+
gcs_client = Client.from_service_account_json(service_account_path)
222+
223+
blob = Blob(obj.path.lstrip('/'), Bucket(gcs_client, obj.netloc))
224+
blob.download_to_filename(local)
225+
226+
193227
def download_from_oci(remote: str, local: str) -> None:
194228
"""Download a file from remote OCI to local.
195229

streaming/base/storage/upload.py

+80-39
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import shutil
99
import sys
1010
import urllib.parse
11+
from enum import Enum
1112
from tempfile import mkdtemp
1213
from typing import Any, Tuple, Union
1314

@@ -16,8 +17,12 @@
1617
from streaming.base.storage.download import BOTOCORE_CLIENT_ERROR_CODES
1718

1819
__all__ = [
19-
'CloudUploader', 'S3Uploader', 'GCSUploader', 'OCIUploader', 'AzureUploader',
20-
'AzureDataLakeUploader', 'LocalUploader'
20+
'CloudUploader',
21+
'S3Uploader',
22+
'GCSUploader',
23+
'OCIUploader',
24+
'AzureUploader',
25+
'LocalUploader',
2126
]
2227

2328
logger = logging.getLogger(__name__)
@@ -32,6 +37,11 @@
3237
}
3338

3439

40+
class GCSAuthentication(Enum):
41+
HMAC = 1
42+
SERVICE_ACCOUNT = 2
43+
44+
3545
class CloudUploader:
3646
"""Upload local files to a cloud storage."""
3747

@@ -84,10 +94,9 @@ def _validate(self, out: Union[str, Tuple[str, str]]) -> None:
8494
obj = urllib.parse.urlparse(out)
8595
else:
8696
if len(out) != 2:
87-
raise ValueError(''.join([
88-
f'Invalid `out` argument. It is either a string of local/remote directory ',
89-
'or a list of two strings with [local, remote].'
90-
]))
97+
raise ValueError(f'Invalid `out` argument. It is either a string of ' +
98+
f'local/remote directory or a list of two strings with ' +
99+
f'[local, remote].')
91100
obj = urllib.parse.urlparse(out[1])
92101
if obj.scheme not in UPLOADERS:
93102
raise ValueError(f'Invalid Cloud provider prefix: {obj.scheme}.')
@@ -183,6 +192,7 @@ def __init__(self,
183192

184193
import boto3
185194
from botocore.config import Config
195+
186196
config = Config()
187197
# Create a session and use it to make our client. Unlike Resources and Sessions,
188198
# clients are generally thread-safe.
@@ -261,19 +271,34 @@ def __init__(self,
261271
progress_bar: bool = False) -> None:
262272
super().__init__(out, keep_local, progress_bar)
263273

264-
import boto3
274+
if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
275+
from google.cloud.storage import Client
276+
277+
service_account_path = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
278+
self.gcs_client = Client.from_service_account_json(service_account_path)
279+
self.authentication = GCSAuthentication.SERVICE_ACCOUNT
280+
elif 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ:
281+
import boto3
282+
283+
# Create a session and use it to make our client. Unlike Resources and Sessions,
284+
# clients are generally thread-safe.
285+
session = boto3.session.Session()
286+
self.gcs_client = session.client(
287+
's3',
288+
region_name='auto',
289+
endpoint_url='https://storage.googleapis.com',
290+
aws_access_key_id=os.environ['GCS_KEY'],
291+
aws_secret_access_key=os.environ['GCS_SECRET'],
292+
)
293+
self.authentication = GCSAuthentication.HMAC
294+
else:
295+
raise ValueError(f'Either GOOGLE_APPLICATION_CREDENTIALS needs to be set for ' +
296+
f'service level accounts or GCS_KEY and GCS_SECRET needs to ' +
297+
f'be set for HMAC authentication')
265298

266-
# Create a session and use it to make our client. Unlike Resources and Sessions,
267-
# clients are generally thread-safe.
268-
session = boto3.session.Session()
269-
self.gcs_client = session.client('s3',
270-
region_name='auto',
271-
endpoint_url='https://storage.googleapis.com',
272-
aws_access_key_id=os.environ['GCS_KEY'],
273-
aws_secret_access_key=os.environ['GCS_SECRET'])
274299
self.check_bucket_exists(self.remote) # pyright: ignore
275300

276-
def upload_file(self, filename: str):
301+
def upload_file(self, filename: str) -> None:
277302
"""Upload file from local instance to Google Cloud Storage bucket.
278303
279304
Args:
@@ -283,21 +308,31 @@ def upload_file(self, filename: str):
283308
remote_filename = os.path.join(self.remote, filename) # pyright: ignore
284309
obj = urllib.parse.urlparse(remote_filename)
285310
logger.debug(f'Uploading to {remote_filename}')
286-
file_size = os.stat(local_filename).st_size
287-
with tqdm.tqdm(total=file_size,
288-
unit='B',
289-
unit_scale=True,
290-
desc=f'Uploading to {remote_filename}',
291-
disable=(not self.progress_bar)) as pbar:
292-
self.gcs_client.upload_file(
293-
local_filename,
294-
obj.netloc,
295-
obj.path.lstrip('/'),
296-
Callback=lambda bytes_transferred: pbar.update(bytes_transferred),
297-
)
311+
312+
if self.authentication == GCSAuthentication.HMAC:
313+
file_size = os.stat(local_filename).st_size
314+
with tqdm.tqdm(
315+
total=file_size,
316+
unit='B',
317+
unit_scale=True,
318+
desc=f'Uploading to {remote_filename}',
319+
disable=(not self.progress_bar),
320+
) as pbar:
321+
self.gcs_client.upload_file(
322+
local_filename,
323+
obj.netloc,
324+
obj.path.lstrip('/'),
325+
Callback=lambda bytes_transferred: pbar.update(bytes_transferred),
326+
)
327+
elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT:
328+
from google.cloud.storage import Blob, Bucket
329+
330+
blob = Blob(obj.path.lstrip('/'), Bucket(self.gcs_client, obj.netloc))
331+
blob.upload_from_filename(local_filename)
332+
298333
self.clear_local(local=local_filename)
299334

300-
def check_bucket_exists(self, remote: str):
335+
def check_bucket_exists(self, remote: str) -> None:
301336
"""Raise an exception if the bucket does not exist.
302337
303338
Args:
@@ -306,16 +341,20 @@ def check_bucket_exists(self, remote: str):
306341
Raises:
307342
error: Bucket does not exist.
308343
"""
309-
from botocore.exceptions import ClientError
310-
311344
bucket_name = urllib.parse.urlparse(remote).netloc
312-
try:
313-
self.gcs_client.head_bucket(Bucket=bucket_name)
314-
except ClientError as error:
315-
if error.response['Error']['Code'] == BOTOCORE_CLIENT_ERROR_CODES:
316-
error.args = (f'Either bucket `{bucket_name}` does not exist! ' +
317-
f'or check the bucket permission.',)
318-
raise error
345+
346+
if self.authentication == GCSAuthentication.HMAC:
347+
from botocore.exceptions import ClientError
348+
349+
try:
350+
self.gcs_client.head_bucket(Bucket=bucket_name)
351+
except ClientError as error:
352+
if (error.response['Error']['Code'] == BOTOCORE_CLIENT_ERROR_CODES):
353+
error.args = (f'Either bucket `{bucket_name}` does not exist! ' +
354+
f'or check the bucket permission.',)
355+
raise error
356+
elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT:
357+
self.gcs_client.get_bucket(bucket_name)
319358

320359

321360
class OCIUploader(CloudUploader):
@@ -343,6 +382,7 @@ def __init__(self,
343382
super().__init__(out, keep_local, progress_bar)
344383

345384
import oci
385+
346386
config = oci.config.from_file()
347387
self.client = oci.object_storage.ObjectStorageClient(
348388
config=config, retry_strategy=oci.retry.DEFAULT_RETRY_STRATEGY)
@@ -430,7 +470,8 @@ def __init__(self,
430470
# clients are generally thread-safe.
431471
self.azure_service = BlobServiceClient(
432472
account_url=f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net",
433-
credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY'])
473+
credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY'],
474+
)
434475
self.check_bucket_exists(self.remote) # pyright: ignore
435476

436477
def upload_file(self, filename: str):

0 commit comments

Comments
 (0)