|
14 | 14 | import os
|
15 | 15 |
|
16 | 16 | import torch
|
17 |
| -import torch.nn.functional as F |
18 |
| -from sklearn.datasets import make_classification |
19 |
| -from sklearn.model_selection import train_test_split |
20 |
| -from torch import nn |
21 |
| -from torch.utils.data import DataLoader, Dataset |
22 |
| -from torchmetrics import Accuracy |
23 | 17 |
|
24 | 18 | import pytorch_lightning as pl
|
25 |
| -from pytorch_lightning import LightningDataModule, LightningModule, seed_everything |
| 19 | +from pytorch_lightning import seed_everything |
26 | 20 | from pytorch_lightning.callbacks import EarlyStopping
|
| 21 | +from tests_pytorch.helpers.datamodules import ClassifDataModule |
| 22 | +from tests_pytorch.helpers.simple_models import ClassificationModel |
27 | 23 |
|
28 | 24 | PATH_LEGACY = os.path.dirname(__file__)
|
29 | 25 |
|
30 | 26 |
|
31 |
| -class SklearnDataset(Dataset): |
32 |
| - def __init__(self, x, y, x_type, y_type): |
33 |
| - self.x = x |
34 |
| - self.y = y |
35 |
| - self._x_type = x_type |
36 |
| - self._y_type = y_type |
37 |
| - |
38 |
| - def __getitem__(self, idx): |
39 |
| - return torch.tensor(self.x[idx], dtype=self._x_type), torch.tensor(self.y[idx], dtype=self._y_type) |
40 |
| - |
41 |
| - def __len__(self): |
42 |
| - return len(self.y) |
43 |
| - |
44 |
| - |
45 |
| -class SklearnDataModule(LightningDataModule): |
46 |
| - def __init__(self, sklearn_dataset, x_type, y_type, batch_size: int = 128): |
47 |
| - super().__init__() |
48 |
| - self.batch_size = batch_size |
49 |
| - self._x, self._y = sklearn_dataset |
50 |
| - self._split_data() |
51 |
| - self._x_type = x_type |
52 |
| - self._y_type = y_type |
53 |
| - |
54 |
| - def _split_data(self): |
55 |
| - self.x_train, self.x_test, self.y_train, self.y_test = train_test_split( |
56 |
| - self._x, self._y, test_size=0.20, random_state=42 |
57 |
| - ) |
58 |
| - self.x_train, self.x_predict, self.y_train, self.y_predict = train_test_split( |
59 |
| - self._x, self._y, test_size=0.20, random_state=42 |
60 |
| - ) |
61 |
| - self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split( |
62 |
| - self.x_train, self.y_train, test_size=0.40, random_state=42 |
63 |
| - ) |
64 |
| - |
65 |
| - def train_dataloader(self): |
66 |
| - return DataLoader( |
67 |
| - SklearnDataset(self.x_train, self.y_train, self._x_type, self._y_type), |
68 |
| - shuffle=True, |
69 |
| - batch_size=self.batch_size, |
70 |
| - ) |
71 |
| - |
72 |
| - def val_dataloader(self): |
73 |
| - return DataLoader( |
74 |
| - SklearnDataset(self.x_valid, self.y_valid, self._x_type, self._y_type), batch_size=self.batch_size |
75 |
| - ) |
76 |
| - |
77 |
| - def test_dataloader(self): |
78 |
| - return DataLoader( |
79 |
| - SklearnDataset(self.x_test, self.y_test, self._x_type, self._y_type), batch_size=self.batch_size |
80 |
| - ) |
81 |
| - |
82 |
| - def predict_dataloader(self): |
83 |
| - return DataLoader( |
84 |
| - SklearnDataset(self.x_predict, self.y_predict, self._x_type, self._y_type), batch_size=self.batch_size |
85 |
| - ) |
86 |
| - |
87 |
| - |
88 |
| -class ClassifDataModule(SklearnDataModule): |
89 |
| - def __init__(self, num_features=24, length=6000, num_classes=3, batch_size=128): |
90 |
| - data = make_classification( |
91 |
| - n_samples=length, |
92 |
| - n_features=num_features, |
93 |
| - n_classes=num_classes, |
94 |
| - n_clusters_per_class=2, |
95 |
| - n_informative=int(num_features / num_classes), |
96 |
| - random_state=42, |
97 |
| - ) |
98 |
| - super().__init__(data, x_type=torch.float32, y_type=torch.long, batch_size=batch_size) |
99 |
| - |
100 |
| - |
101 |
| -class ClassificationModel(LightningModule): |
102 |
| - def __init__(self, num_features=24, num_classes=3, lr=0.01): |
103 |
| - super().__init__() |
104 |
| - self.save_hyperparameters() |
105 |
| - |
106 |
| - self.lr = lr |
107 |
| - for i in range(3): |
108 |
| - setattr(self, f"layer_{i}", nn.Linear(num_features, num_features)) |
109 |
| - setattr(self, f"layer_{i}a", torch.nn.ReLU()) |
110 |
| - setattr(self, "layer_end", nn.Linear(num_features, num_classes)) |
111 |
| - |
112 |
| - self.train_acc = Accuracy() |
113 |
| - self.valid_acc = Accuracy() |
114 |
| - self.test_acc = Accuracy() |
115 |
| - |
116 |
| - def forward(self, x): |
117 |
| - x = self.layer_0(x) |
118 |
| - x = self.layer_0a(x) |
119 |
| - x = self.layer_1(x) |
120 |
| - x = self.layer_1a(x) |
121 |
| - x = self.layer_2(x) |
122 |
| - x = self.layer_2a(x) |
123 |
| - x = self.layer_end(x) |
124 |
| - logits = F.softmax(x, dim=1) |
125 |
| - return logits |
126 |
| - |
127 |
| - def configure_optimizers(self): |
128 |
| - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) |
129 |
| - return [optimizer], [] |
130 |
| - |
131 |
| - def training_step(self, batch, batch_idx): |
132 |
| - x, y = batch |
133 |
| - logits = self.forward(x) |
134 |
| - loss = F.cross_entropy(logits, y) |
135 |
| - self.log("train_loss", loss, prog_bar=True) |
136 |
| - self.log("train_acc", self.train_acc(logits, y), prog_bar=True) |
137 |
| - return {"loss": loss} |
138 |
| - |
139 |
| - def validation_step(self, batch, batch_idx): |
140 |
| - x, y = batch |
141 |
| - logits = self.forward(x) |
142 |
| - self.log("val_loss", F.cross_entropy(logits, y), prog_bar=False) |
143 |
| - self.log("val_acc", self.valid_acc(logits, y), prog_bar=True) |
144 |
| - |
145 |
| - def test_step(self, batch, batch_idx): |
146 |
| - x, y = batch |
147 |
| - logits = self.forward(x) |
148 |
| - self.log("test_loss", F.cross_entropy(logits, y), prog_bar=False) |
149 |
| - self.log("test_acc", self.test_acc(logits, y), prog_bar=True) |
150 |
| - |
151 |
| - |
152 | 27 | def main_train(dir_path, max_epochs: int = 20):
|
153 | 28 | seed_everything(42)
|
154 | 29 | stopping = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005)
|
|
0 commit comments