3
3
4
4
If a function gets defined once and could be used over and over, it'll go in here.
5
5
"""
6
+ import torchvision
7
+ from typing import List
6
8
import torch
7
9
import matplotlib .pyplot as plt
8
10
import numpy as np
18
20
19
21
# Walk through an image classification directory and find out how many files (images)
20
22
# are in each subdirectory.
21
- import os
23
+
22
24
23
25
def walk_through_dir (dir_path ):
24
26
"""
@@ -35,6 +37,7 @@ def walk_through_dir(dir_path):
35
37
for dirpath , dirnames , filenames in os .walk (dir_path ):
36
38
print (f"There are { len (dirnames )} directories and { len (filenames )} images in '{ dirpath } '." )
37
39
40
+
38
41
def plot_decision_boundary (model : torch .nn .Module , X : torch .Tensor , y : torch .Tensor ):
39
42
"""Plots decision boundaries of model predicting on X in comparison to y.
40
43
@@ -166,8 +169,6 @@ def plot_loss_curves(results):
166
169
167
170
# Pred and plot image function from notebook 04
168
171
# See creation: https://www.learnpytorch.io/04_pytorch_custom_datasets/#113-putting-custom-image-prediction-together-building-a-function
169
- from typing import List
170
- import torchvision
171
172
172
173
173
174
def pred_and_plot_image (
@@ -185,7 +186,7 @@ def pred_and_plot_image(
185
186
class_names (List[str], optional): different class names for target image. Defaults to None.
186
187
transform (_type_, optional): transform of target image. Defaults to None.
187
188
device (torch.device, optional): target device to compute on. Defaults to "cuda" if torch.cuda.is_available() else "cpu".
188
-
189
+
189
190
Returns:
190
191
Matplotlib plot of target image and model prediction as title.
191
192
@@ -236,7 +237,8 @@ def pred_and_plot_image(
236
237
plt .title (title )
237
238
plt .axis (False )
238
239
239
- def set_seeds (seed : int = 42 ):
240
+
241
+ def set_seeds (seed : int = 42 ):
240
242
"""Sets random sets for torch operations.
241
243
242
244
Args:
@@ -247,7 +249,8 @@ def set_seeds(seed: int=42):
247
249
# Set the seed for CUDA torch operations (ones that happen on the GPU)
248
250
torch .cuda .manual_seed (seed )
249
251
250
- def download_data (source : str ,
252
+
253
+ def download_data (source : str ,
251
254
destination : str ,
252
255
remove_source : bool = True ) -> Path :
253
256
"""Downloads a zipped dataset from source and unzips to destination.
@@ -256,10 +259,10 @@ def download_data(source: str,
256
259
source (str): A link to a zipped file containing data.
257
260
destination (str): A target directory to unzip data to.
258
261
remove_source (bool): Whether to remove the source after downloading and extracting.
259
-
262
+
260
263
Returns:
261
264
pathlib.Path to downloaded data.
262
-
265
+
263
266
Example usage:
264
267
download_data(source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
265
268
destination="pizza_steak_sushi")
@@ -268,13 +271,13 @@ def download_data(source: str,
268
271
data_path = Path ("data/" )
269
272
image_path = data_path / destination
270
273
271
- # If the image folder doesn't exist, download it and prepare it...
274
+ # If the image folder doesn't exist, download it and prepare it...
272
275
if image_path .is_dir ():
273
276
print (f"[INFO] { image_path } directory exists, skipping download." )
274
277
else :
275
278
print (f"[INFO] Did not find { image_path } directory, creating one..." )
276
279
image_path .mkdir (parents = True , exist_ok = True )
277
-
280
+
278
281
# Download pizza, steak, sushi data
279
282
target_file = Path (source ).name
280
283
with open (data_path / target_file , "wb" ) as f :
@@ -284,11 +287,11 @@ def download_data(source: str,
284
287
285
288
# Unzip pizza, steak, sushi data
286
289
with zipfile .ZipFile (data_path / target_file , "r" ) as zip_ref :
287
- print (f"[INFO] Unzipping { target_file } data..." )
290
+ print (f"[INFO] Unzipping { target_file } data..." )
288
291
zip_ref .extractall (image_path )
289
292
290
293
# Remove .zip file
291
294
if remove_source :
292
295
os .remove (data_path / target_file )
293
-
296
+
294
297
return image_path
0 commit comments