# s3_download.py

import json
import tarfile
from pathlib import Path
from typing import Optional, Union

import boto3
import numpy as np
import numpy.typing as npt
from botocore import UNSIGNED
from botocore.config import Config
from tqdm import tqdm

BUCKET_NAME = "xanadu-aurora-data"
LOCAL_DATA_DIRECTORY_ROOT = Path("downloaded_data")


def download_and_load_file(
    s3_file_path: str,
    local_directory: Optional[str] = None,
    remove_local_file: bool = False,
    force_redownload: bool = False,
    extract_if_zip: bool = False,
) -> Union[npt.NDArray, list, None]:
    """Downloads a file from an S3 bucket and loads its content based on the
    file extension.

    Args:
        s3_file_path: the path to the file in the S3 bucket, relative to the
            bucket root
        local_directory: the local directory where the file should be saved;
            if None, uses the default directory (LOCAL_DATA_DIRECTORY_ROOT)
        remove_local_file: if True, deletes the local file after loading its
            content; if local_directory is None, also attempts to remove
            empty directories in LOCAL_DATA_DIRECTORY_ROOT
        force_redownload: if True, forces the download even if the file
            already exists locally
        extract_if_zip: if True, extracts the file if it is a .tar.gz archive

    Returns:
        the content of the file, either as a numpy array (for .npy files) or
        a list (for .json files); None for extracted archives

    Raises:
        ValueError: if s3_file_path is not relative to the bucket root or if
            the file extension is unsupported
        RuntimeError: if there is an error during the download or file
            loading process
    """
    if s3_file_path.startswith("/"):
        raise ValueError("s3_file_path must be relative to the bucket root")

    # join all suffixes to handle multi-part extensions such as .tar.gz
    file_extension = "".join(Path(s3_file_path).suffixes)

    if local_directory in ("", "/"):
        raise ValueError("Please specify a local_directory that is not empty or root.")

    local_directory_prefix = Path(local_directory or LOCAL_DATA_DIRECTORY_ROOT)
    local_file_path = local_directory_prefix / s3_file_path
    local_file_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Downloading file {s3_file_path} to {local_file_path}")

    def download_progress(bytes_amount):
        # closure over progress_bar, which is created just before the
        # download starts
        progress_bar.update(bytes_amount)

    # initialize an S3 client that sends unsigned (anonymous) requests
    s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))

    try:
        if local_file_path.exists() and not force_redownload:
            # the file already exists locally, so skip the download
            print(
                f"File {s3_file_path} has already been downloaded at {local_file_path}"
            )
        else:
            # get the file size for the progress bar
            file_info = s3_client.head_object(Bucket=BUCKET_NAME, Key=s3_file_path)
            file_size = file_info["ContentLength"]

            # set up the progress bar
            with tqdm(
                total=file_size, unit="B", unit_scale=True, desc=s3_file_path
            ) as progress_bar:
                # download the file with a progress callback; boto3 expects
                # the destination filename as a string, not a Path
                s3_client.download_file(
                    BUCKET_NAME,
                    s3_file_path,
                    str(local_file_path),
                    Callback=download_progress,
                )

        # load the file based on its extension
        if file_extension.lower() == ".tar.gz":
            if extract_if_zip:
                print(
                    f"Archive detected. Extracting file {s3_file_path} to {local_directory_prefix}"
                )
                with tarfile.open(local_file_path, "r:gz") as tar:
                    # extract members one by one, skipping any that already
                    # exist locally
                    for member in tar.getmembers():
                        print(f"Checking {member.name}")
                        member_path = local_directory_prefix / member.name
                        if not member_path.exists():
                            tar.extract(member, path=local_directory_prefix)
                            print(f"Extracted {member.name}")
                        else:
                            print(f"File {member.name} already exists. Skipping.")
            return None
        elif file_extension.lower() == ".npy":
            file_content = np.load(local_file_path)
        elif file_extension.lower() == ".json":
            with open(local_file_path, "r") as file:
                file_content = json.load(file)
        else:
            raise ValueError(f"Unsupported file extension: {file_extension}")
        return file_content
    except Exception as error:
        print(f"Error downloading file: {error}")
        raise RuntimeError(f"Error downloading file: {error}") from error
    finally:
        # clean up: remove the local file and, if requested, the directory
        # structure around it
        if local_file_path.exists() and remove_local_file:
            if local_file_path.is_file():
                print(f"Delete flag is true. Deleting file {local_file_path}.")
                local_file_path.unlink()
            # only attempt directory cleanup if no local_directory was
            # specified
            if not local_directory:
                root_dir = LOCAL_DATA_DIRECTORY_ROOT
                if root_dir.exists():
                    try:
                        # remove empty directories, deepest paths first
                        for dir_path in sorted(root_dir.glob("**/*"), reverse=True):
                            try:
                                dir_path.rmdir()
                            except OSError:
                                # skip if the directory is not empty
                                continue
                        # try to remove the root directory itself
                        root_dir.rmdir()
                    except OSError:
                        print(
                            f"Some directories in {LOCAL_DATA_DIRECTORY_ROOT} are not empty, skipping deletion."
                        )
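

# Minimal usage sketch. The object keys below are hypothetical placeholders,
# not known contents of the bucket; substitute real keys before running.
if __name__ == "__main__":
    # load a .npy file into a numpy array, deleting the local copy afterwards
    array = download_and_load_file(
        "example/data.npy",  # hypothetical key
        remove_local_file=True,
    )

    # download a .tar.gz archive and extract it under the default directory
    download_and_load_file(
        "example/archive.tar.gz",  # hypothetical key
        extract_if_zip=True,
    )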