-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSampleDataset.py
42 lines (32 loc) · 1.16 KB
/
SampleDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import List
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets
import torch
import numpy as np
import pathlib
import math
class SampleDataset:
    """Expose one small image folder as train/valid/test loaders.

    Every loader attribute (``test_viz_data``, ``sample``, ``trainloader``,
    ``validloader``, ``testloader``) refers to the *same* ``DataLoader``
    object, built with default ``DataLoader`` arguments over
    ``datasets.ImageFolder(root)``. Images are resized to ``image_size`` and
    converted to tensors.

    Attributes:
        classes: Class names as reported by the ``ImageFolder`` dataset, i.e.
            the same list that defines the integer labels the loader yields.
        last: Set to ``True`` by :meth:`get_chunks` once the counter ``k``
            reaches ``LAST_CHUNK_INDEX``; never reset by this class.
        k: Number of times :meth:`get_chunks` has been called.
    """

    # Value of ``k`` at which get_chunks flags the final chunk.
    LAST_CHUNK_INDEX = 10

    def __init__(self, root='./datasets/sample_test', image_size=(840, 840)):
        """Build the shared loader and print the discovered class names.

        Args:
            root: Image directory handed to ``datasets.ImageFolder``.
                Defaults to the path that was previously hard-coded.
            image_size: Target size passed to ``transforms.Resize``.
                Defaults to the previously hard-coded ``(840, 840)``.
        """
        self.last = False
        self.k = 0
        transform_viz_test = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
        ])
        # Build the dataset/loader once; an identical second copy used to be
        # constructed here and immediately overwritten.
        self.test_viz_data = DataLoader(
            datasets.ImageFolder(root, transform=transform_viz_test))
        self.sample = self.test_viz_data
        # Take class names from ImageFolder itself so they always match the
        # labels it assigns. Listing every entry under root would also pick
        # up stray non-directory files (e.g. ".DS_Store").
        self.classes = list(self.test_viz_data.dataset.classes)
        print(self.classes)
        # The same loader deliberately serves every split (sample data only).
        self.trainloader = self.sample
        self.validloader = self.sample
        self.testloader = self.sample

    def get_chunks(self):
        """Advance the chunk counter and flag ``last`` on the final chunk.

        ``last`` becomes ``True`` on the call made while ``k`` equals
        ``LAST_CHUNK_INDEX`` (the 11th call with the defaults) and stays
        ``True`` afterwards. ``k`` is incremented on every call. Returns
        ``None``.
        """
        if self.k == self.LAST_CHUNK_INDEX:
            self.last = True
        self.k += 1