 import shutil
 import sys
 import urllib.parse
+from enum import Enum
 from tempfile import mkdtemp
 from typing import Any, Tuple, Union
⋯
 from streaming.base.storage.download import BOTOCORE_CLIENT_ERROR_CODES

 __all__ = [
-    'CloudUploader', 'S3Uploader', 'GCSUploader', 'OCIUploader', 'AzureUploader',
-    'AzureDataLakeUploader', 'LocalUploader'
+    'CloudUploader',
+    'S3Uploader',
+    'GCSUploader',
+    'OCIUploader',
+    'AzureUploader',
+    'LocalUploader',
 ]

 logger = logging.getLogger(__name__)
⋯
 }


+class GCSAuthentication(Enum):
+    HMAC = 1
+    SERVICE_ACCOUNT = 2
+
+
 class CloudUploader:
     """Upload local files to a cloud storage."""
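The new `GCSAuthentication` enum records which credential style the uploader was constructed with, so `upload_file` and `check_bucket_exists` can branch on it later. A minimal sketch of the selection rule the enum encodes; `resolve_gcs_auth` is a hypothetical helper for illustration, not part of this change:

```python
import os
from enum import Enum


class GCSAuthentication(Enum):
    HMAC = 1
    SERVICE_ACCOUNT = 2


def resolve_gcs_auth() -> GCSAuthentication:
    # Service-account credentials take precedence over HMAC keys, matching
    # the if/elif order in GCSUploader.__init__ below.
    if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
        return GCSAuthentication.SERVICE_ACCOUNT
    if 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ:
        return GCSAuthentication.HMAC
    raise ValueError('Set GOOGLE_APPLICATION_CREDENTIALS, or GCS_KEY and GCS_SECRET.')
```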
@@ -84,10 +94,9 @@ def _validate(self, out: Union[str, Tuple[str, str]]) -> None:
             obj = urllib.parse.urlparse(out)
         else:
             if len(out) != 2:
-                raise ValueError(''.join([
-                    f'Invalid `out` argument. It is either a string of local/remote directory ',
-                    'or a list of two strings with [local, remote].'
-                ]))
+                raise ValueError(f'Invalid `out` argument. It is either a string of ' +
+                                 f'local/remote directory or a list of two strings with ' +
+                                 f'[local, remote].')
             obj = urllib.parse.urlparse(out[1])
         if obj.scheme not in UPLOADERS:
             raise ValueError(f'Invalid Cloud provider prefix: {obj.scheme}.')
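For reference, the two shapes `_validate` accepts. A hedged usage sketch, assuming the `CloudUploader.get` factory defined elsewhere in this file; bucket and directory names are placeholders, and the calls need valid credentials to run end to end:

```python
# Both forms pass _validate.
uploader = CloudUploader.get('s3://my-bucket/dir')                      # remote only
uploader = CloudUploader.get(('/tmp/mds-cache', 's3://my-bucket/dir'))  # (local, remote)

# Anything else raises, e.g. a 1- or 3-element tuple, or an unknown scheme:
CloudUploader.get('ftp://my-bucket/dir')  # ValueError: Invalid Cloud provider prefix
```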
@@ -183,6 +192,7 @@ def __init__(self,

         import boto3
         from botocore.config import Config
+
         config = Config()
         # Create a session and use it to make our client. Unlike Resources and Sessions,
         # clients are generally thread-safe.
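The comment above carries the rationale for building one client eagerly: boto3 clients are documented as generally thread-safe, while Session and Resource objects are not, so a single client can be shared across upload workers. A minimal sketch under that assumption (bucket and file names are placeholders):

```python
from concurrent.futures import ThreadPoolExecutor

import boto3
from botocore.config import Config

# One client, created once and shared across threads; sessions and
# resources would need to be created per-thread instead.
client = boto3.session.Session().client('s3', config=Config())


def upload(path: str) -> None:
    client.upload_file(path, 'my-bucket', path)  # placeholder bucket/key


with ThreadPoolExecutor(max_workers=8) as pool:
    list(pool.map(upload, ['shard.00000.mds', 'shard.00001.mds']))  # placeholders
```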
@@ -261,19 +271,34 @@ def __init__(self,
                  progress_bar: bool = False) -> None:
         super().__init__(out, keep_local, progress_bar)

-        import boto3
+        if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
+            from google.cloud.storage import Client
+
+            service_account_path = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
+            self.gcs_client = Client.from_service_account_json(service_account_path)
+            self.authentication = GCSAuthentication.SERVICE_ACCOUNT
+        elif 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ:
+            import boto3
+
+            # Create a session and use it to make our client. Unlike Resources and Sessions,
+            # clients are generally thread-safe.
+            session = boto3.session.Session()
+            self.gcs_client = session.client(
+                's3',
+                region_name='auto',
+                endpoint_url='https://storage.googleapis.com',
+                aws_access_key_id=os.environ['GCS_KEY'],
+                aws_secret_access_key=os.environ['GCS_SECRET'],
+            )
+            self.authentication = GCSAuthentication.HMAC
+        else:
+            raise ValueError(f'Either GOOGLE_APPLICATION_CREDENTIALS needs to be set for ' +
+                             f'service level accounts or GCS_KEY and GCS_SECRET needs to ' +
+                             f'be set for HMAC authentication')

-        # Create a session and use it to make our client. Unlike Resources and Sessions,
-        # clients are generally thread-safe.
-        session = boto3.session.Session()
-        self.gcs_client = session.client('s3',
-                                         region_name='auto',
-                                         endpoint_url='https://storage.googleapis.com',
-                                         aws_access_key_id=os.environ['GCS_KEY'],
-                                         aws_secret_access_key=os.environ['GCS_SECRET'])
         self.check_bucket_exists(self.remote)  # pyright: ignore

-    def upload_file(self, filename: str):
+    def upload_file(self, filename: str) -> None:
         """Upload file from local instance to Google Cloud Storage bucket.

         Args:
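A usage sketch of the two credential paths the reworked `__init__` accepts. Service-account credentials are checked first and win when both are configured; all names below are placeholders and a real bucket is needed for the constructor's existence check to pass:

```python
import os

# Path 1: service account (checked first).
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/secrets/gcs-sa.json'  # placeholder

# Path 2: HMAC keys, used only when no service account is configured.
# os.environ['GCS_KEY'] = '<access key>'
# os.environ['GCS_SECRET'] = '<secret>'

uploader = GCSUploader(out='gs://my-bucket/dir')  # placeholder bucket
assert uploader.authentication == GCSAuthentication.SERVICE_ACCOUNT
```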
@@ -283,21 +308,31 @@ def upload_file(self, filename: str):
         remote_filename = os.path.join(self.remote, filename)  # pyright: ignore
         obj = urllib.parse.urlparse(remote_filename)
         logger.debug(f'Uploading to {remote_filename}')
-        file_size = os.stat(local_filename).st_size
-        with tqdm.tqdm(total=file_size,
-                       unit='B',
-                       unit_scale=True,
-                       desc=f'Uploading to {remote_filename}',
-                       disable=(not self.progress_bar)) as pbar:
-            self.gcs_client.upload_file(
-                local_filename,
-                obj.netloc,
-                obj.path.lstrip('/'),
-                Callback=lambda bytes_transferred: pbar.update(bytes_transferred),
-            )
+
+        if self.authentication == GCSAuthentication.HMAC:
+            file_size = os.stat(local_filename).st_size
+            with tqdm.tqdm(
+                    total=file_size,
+                    unit='B',
+                    unit_scale=True,
+                    desc=f'Uploading to {remote_filename}',
+                    disable=(not self.progress_bar),
+            ) as pbar:
+                self.gcs_client.upload_file(
+                    local_filename,
+                    obj.netloc,
+                    obj.path.lstrip('/'),
+                    Callback=lambda bytes_transferred: pbar.update(bytes_transferred),
+                )
+        elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT:
+            from google.cloud.storage import Blob, Bucket
+
+            blob = Blob(obj.path.lstrip('/'), Bucket(self.gcs_client, obj.netloc))
+            blob.upload_from_filename(local_filename)
+
         self.clear_local(local=local_filename)

-    def check_bucket_exists(self, remote: str):
+    def check_bucket_exists(self, remote: str) -> None:
         """Raise an exception if the bucket does not exist.

         Args:
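The service-account branch constructs `Blob` and `Bucket` wrappers by hand; the equivalent, more common spelling uses the client's own factories. A sketch with placeholder names; note that this branch, unlike the HMAC one, does not wire up a tqdm progress bar:

```python
from google.cloud.storage import Client

client = Client.from_service_account_json('/secrets/gcs-sa.json')  # placeholder
blob = client.bucket('my-bucket').blob('dir/shard.00000.mds')      # placeholders
blob.upload_from_filename('/tmp/dir/shard.00000.mds')
```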
@@ -306,16 +341,20 @@ def check_bucket_exists(self, remote: str):
         Raises:
             error: Bucket does not exist.
         """
-        from botocore.exceptions import ClientError
-
         bucket_name = urllib.parse.urlparse(remote).netloc
-        try:
-            self.gcs_client.head_bucket(Bucket=bucket_name)
-        except ClientError as error:
-            if error.response['Error']['Code'] == BOTOCORE_CLIENT_ERROR_CODES:
-                error.args = (f'Either bucket `{bucket_name}` does not exist! ' +
-                              f'or check the bucket permission.',)
-            raise error
+
+        if self.authentication == GCSAuthentication.HMAC:
+            from botocore.exceptions import ClientError
+
+            try:
+                self.gcs_client.head_bucket(Bucket=bucket_name)
+            except ClientError as error:
+                if (error.response['Error']['Code'] == BOTOCORE_CLIENT_ERROR_CODES):
+                    error.args = (f'Either bucket `{bucket_name}` does not exist! ' +
+                                  f'or check the bucket permission.',)
+                raise error
+        elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT:
+            self.gcs_client.get_bucket(bucket_name)


 class OCIUploader(CloudUploader):
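For reference, how each branch surfaces a missing bucket. This sketch is illustrative, not the PR's code: it assumes `BOTOCORE_CLIENT_ERROR_CODES` is a collection of status-code strings (its plural name suggests a membership test with `in`), and that `get_bucket` raises `google.cloud.exceptions.NotFound` for a missing bucket:

```python
from botocore.exceptions import ClientError
from google.cloud.exceptions import NotFound
from google.cloud.storage import Client


def bucket_exists_hmac(s3_client, bucket_name: str) -> bool:
    try:
        s3_client.head_bucket(Bucket=bucket_name)
        return True
    except ClientError as error:
        # head_bucket reports a missing or forbidden bucket via an error code.
        if error.response['Error']['Code'] in ('401', '403', '404'):
            return False
        raise


def bucket_exists_service_account(client: Client, bucket_name: str) -> bool:
    try:
        client.get_bucket(bucket_name)  # raises NotFound for a missing bucket
        return True
    except NotFound:
        return False
```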
@@ -343,6 +382,7 @@ def __init__(self,
         super().__init__(out, keep_local, progress_bar)

         import oci
+
         config = oci.config.from_file()
         self.client = oci.object_storage.ObjectStorageClient(
             config=config, retry_strategy=oci.retry.DEFAULT_RETRY_STRATEGY)
@@ -430,7 +470,8 @@ def __init__(self,
         # clients are generally thread-safe.
         self.azure_service = BlobServiceClient(
             account_url=f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net",
-            credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY'])
+            credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY'],
+        )
         self.check_bucket_exists(self.remote)  # pyright: ignore

     def upload_file(self, filename: str):
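The trailing-comma reformat above leaves behavior unchanged. For completeness, a sketch of constructing the client the same way and probing a container, assuming azure-storage-blob v12 where `ContainerClient.exists()` is available; the container name is a placeholder:

```python
import os

from azure.storage.blob import BlobServiceClient

service = BlobServiceClient(
    account_url=f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net",
    credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY'],
)
# exists() returns False for a missing container instead of raising.
print(service.get_container_client('my-container').exists())
```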