Commit 61ae35c

Use sklearn in runif (Lightning-AI#15426)
* Use sklearn in runif
* test by removing sklearn dep
* remove repeated code
* seed
1 parent 7ee0994 commit 61ae35c

26 files changed (+114 -177 lines)

tests/legacy/simple_classif_training.py (+3 -128)
@@ -14,141 +14,16 @@
 import os
 
 import torch
-import torch.nn.functional as F
-from sklearn.datasets import make_classification
-from sklearn.model_selection import train_test_split
-from torch import nn
-from torch.utils.data import DataLoader, Dataset
-from torchmetrics import Accuracy
 
 import pytorch_lightning as pl
-from pytorch_lightning import LightningDataModule, LightningModule, seed_everything
+from pytorch_lightning import seed_everything
 from pytorch_lightning.callbacks import EarlyStopping
+from tests_pytorch.helpers.datamodules import ClassifDataModule
+from tests_pytorch.helpers.simple_models import ClassificationModel
 
 PATH_LEGACY = os.path.dirname(__file__)
 
 
-class SklearnDataset(Dataset):
-    def __init__(self, x, y, x_type, y_type):
-        self.x = x
-        self.y = y
-        self._x_type = x_type
-        self._y_type = y_type
-
-    def __getitem__(self, idx):
-        return torch.tensor(self.x[idx], dtype=self._x_type), torch.tensor(self.y[idx], dtype=self._y_type)
-
-    def __len__(self):
-        return len(self.y)
-
-
-class SklearnDataModule(LightningDataModule):
-    def __init__(self, sklearn_dataset, x_type, y_type, batch_size: int = 128):
-        super().__init__()
-        self.batch_size = batch_size
-        self._x, self._y = sklearn_dataset
-        self._split_data()
-        self._x_type = x_type
-        self._y_type = y_type
-
-    def _split_data(self):
-        self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
-            self._x, self._y, test_size=0.20, random_state=42
-        )
-        self.x_train, self.x_predict, self.y_train, self.y_predict = train_test_split(
-            self._x, self._y, test_size=0.20, random_state=42
-        )
-        self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(
-            self.x_train, self.y_train, test_size=0.40, random_state=42
-        )
-
-    def train_dataloader(self):
-        return DataLoader(
-            SklearnDataset(self.x_train, self.y_train, self._x_type, self._y_type),
-            shuffle=True,
-            batch_size=self.batch_size,
-        )
-
-    def val_dataloader(self):
-        return DataLoader(
-            SklearnDataset(self.x_valid, self.y_valid, self._x_type, self._y_type), batch_size=self.batch_size
-        )
-
-    def test_dataloader(self):
-        return DataLoader(
-            SklearnDataset(self.x_test, self.y_test, self._x_type, self._y_type), batch_size=self.batch_size
-        )
-
-    def predict_dataloader(self):
-        return DataLoader(
-            SklearnDataset(self.x_predict, self.y_predict, self._x_type, self._y_type), batch_size=self.batch_size
-        )
-
-
-class ClassifDataModule(SklearnDataModule):
-    def __init__(self, num_features=24, length=6000, num_classes=3, batch_size=128):
-        data = make_classification(
-            n_samples=length,
-            n_features=num_features,
-            n_classes=num_classes,
-            n_clusters_per_class=2,
-            n_informative=int(num_features / num_classes),
-            random_state=42,
-        )
-        super().__init__(data, x_type=torch.float32, y_type=torch.long, batch_size=batch_size)
-
-
-class ClassificationModel(LightningModule):
-    def __init__(self, num_features=24, num_classes=3, lr=0.01):
-        super().__init__()
-        self.save_hyperparameters()
-
-        self.lr = lr
-        for i in range(3):
-            setattr(self, f"layer_{i}", nn.Linear(num_features, num_features))
-            setattr(self, f"layer_{i}a", torch.nn.ReLU())
-        setattr(self, "layer_end", nn.Linear(num_features, num_classes))
-
-        self.train_acc = Accuracy()
-        self.valid_acc = Accuracy()
-        self.test_acc = Accuracy()
-
-    def forward(self, x):
-        x = self.layer_0(x)
-        x = self.layer_0a(x)
-        x = self.layer_1(x)
-        x = self.layer_1a(x)
-        x = self.layer_2(x)
-        x = self.layer_2a(x)
-        x = self.layer_end(x)
-        logits = F.softmax(x, dim=1)
-        return logits
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
-        return [optimizer], []
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        logits = self.forward(x)
-        loss = F.cross_entropy(logits, y)
-        self.log("train_loss", loss, prog_bar=True)
-        self.log("train_acc", self.train_acc(logits, y), prog_bar=True)
-        return {"loss": loss}
-
-    def validation_step(self, batch, batch_idx):
-        x, y = batch
-        logits = self.forward(x)
-        self.log("val_loss", F.cross_entropy(logits, y), prog_bar=False)
-        self.log("val_acc", self.valid_acc(logits, y), prog_bar=True)
-
-    def test_step(self, batch, batch_idx):
-        x, y = batch
-        logits = self.forward(x)
-        self.log("test_loss", F.cross_entropy(logits, y), prog_bar=False)
-        self.log("test_acc", self.test_acc(logits, y), prog_bar=True)
-
-
 def main_train(dir_path, max_epochs: int = 20):
     seed_everything(42)
     stopping = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005)
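With the dataset and model classes deleted, the remainder of the legacy script only needs to wire the shared helpers into a Trainer. A hedged sketch of what main_train plausibly looks like after this change: the seed_everything and EarlyStopping lines are the context lines above, while the Trainer arguments and the fit call are illustrative assumptions, not the file's actual values.

import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import EarlyStopping
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.simple_models import ClassificationModel


def main_train(dir_path, max_epochs: int = 20):
    seed_everything(42)
    stopping = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005)
    trainer = pl.Trainer(default_root_dir=dir_path, callbacks=[stopping], max_epochs=max_epochs)  # args assumed

    dm = ClassifDataModule()        # shared sklearn-backed datamodule from tests_pytorch.helpers
    model = ClassificationModel()   # shared classification model from tests_pytorch.helpers
    trainer.fit(model, datamodule=dm)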

tests/tests_pytorch/accelerators/test_hpu.py (+2 -2)
@@ -76,8 +76,8 @@ def test_all_stages(tmpdir, hpus):
     trainer.predict(model)
 
 
-@RunIf(hpu=True)
-@mock.patch.dict(os.environ, os.environ.copy())
+@RunIf(hpu=True, sklearn=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_optimization(tmpdir):
     seed_everything(42)

tests/tests_pytorch/accelerators/test_ipu.py (+2 -1)
@@ -149,7 +149,8 @@ def test_inference_only(tmpdir, devices):
     trainer.predict(model)
 
 
-@RunIf(ipu=True)
+@RunIf(ipu=True, sklearn=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_optimization(tmpdir):
     seed_everything(42)
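Both accelerator tests converge on the same pattern: the hardware requirement and the scikit-learn requirement are combined in a single RunIf marker, and os.environ is snapshotted around the test. Restated as a plain skeleton, where only the decorators and the seed call come from the hunks and the body is an assumption:

import os
from unittest import mock

from pytorch_lightning import seed_everything
from tests_pytorch.helpers.runif import RunIf


@RunIf(ipu=True, sklearn=True)  # skip unless an IPU and scikit-learn are both available
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)  # restore os.environ to its prior state after the test
def test_optimization(tmpdir):
    seed_everything(42)
    ...  # build the sklearn-backed ClassifDataModule / ClassificationModel and run fit/test (assumption)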

tests/tests_pytorch/callbacks/test_early_stopping.py (+5)
@@ -13,6 +13,7 @@
 # limitations under the License.
 import logging
 import math
+import os
 import pickle
 from typing import List, Optional
 from unittest import mock
@@ -56,6 +57,8 @@ def on_train_epoch_end(self, trainer, pl_module):
         self.saved_states.append(self.state_dict().copy())
 
 
+@RunIf(sklearn=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_resume_early_stopping_from_checkpoint(tmpdir):
     """Prevent regressions to bugs:
@@ -98,6 +101,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     new_trainer.fit(model, datamodule=dm, ckpt_path=checkpoint_filepath)
 
 
+@RunIf(sklearn=True)
 def test_early_stopping_no_extraneous_invocations(tmpdir):
     """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
     model = ClassificationModel()
@@ -195,6 +199,7 @@ def test_pickling(tmpdir):
     assert vars(early_stopping) == vars(early_stopping_loaded)
 
 
+@RunIf(sklearn=True)
 def test_early_stopping_no_val_step(tmpdir):
     """Test that early stopping callback falls back to training metrics when no validation defined."""

tests/tests_pytorch/callbacks/test_lr_monitor.py (+2)
@@ -22,6 +22,7 @@
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.datamodules import ClassifDataModule
+from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel
 
 
@@ -284,6 +285,7 @@ def configure_optimizers(self):
     assert all(len(lr) == expected_number_logged for lr in lr_monitor.lrs.values())
 
 
+@RunIf(sklearn=True)
 def test_lr_monitor_param_groups(tmpdir):
     """Test that learning rates are extracted and logged for single lr scheduler."""

tests/tests_pytorch/callbacks/test_quantization.py (+5 -5)
@@ -35,7 +35,7 @@
 @pytest.mark.parametrize("observe", ["average", "histogram"])
 @pytest.mark.parametrize("fuse", [True, False])
 @pytest.mark.parametrize("convert", [True, False])
-@RunIf(quantization=True, max_torch="1.11")
+@RunIf(quantization=True, sklearn=True, max_torch="1.11")
 def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
     """Parity test for quant model."""
     cuda_available = CUDAAccelerator.is_available()
@@ -100,7 +100,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
     assert torch.allclose(org_score, quant2_score, atol=0.45)
 
 
-@RunIf(quantization=True)
+@RunIf(quantization=True, sklearn=True)
 def test_quantize_torchscript(tmpdir):
     """Test converting to torchscript."""
     dm = RegressDataModule()
@@ -116,7 +116,7 @@ def test_quantize_torchscript(tmpdir):
     tsmodel(tsmodel.quant(batch[0]))
 
 
-@RunIf(quantization=True)
+@RunIf(quantization=True, sklearn=True)
 def test_quantization_exceptions(tmpdir):
     """Test wrong fuse layers."""
     with pytest.raises(MisconfigurationException, match="Unsupported qconfig"):
@@ -157,7 +157,7 @@ def custom_trigger_last(trainer):
     "trigger_fn,expected_count",
     [(None, 9), (3, 3), (custom_trigger_never, 0), (custom_trigger_even, 5), (custom_trigger_last, 2)],
 )
-@RunIf(quantization=True)
+@RunIf(quantization=True, sklearn=True)
 def test_quantization_triggers(tmpdir, trigger_fn: Union[None, int, Callable], expected_count: int):
     """Test how many times the quant is called."""
     dm = RegressDataModule()
@@ -216,7 +216,7 @@ def test_quantization_disable_observers(tmpdir, observer_enabled_stages):
     )
 
 
-@RunIf(quantization=True)
+@RunIf(quantization=True, sklearn=True)
 def test_quantization_val_test_predict(tmpdir):
     """Test the default quantization aware training not affected by validating, testing and predicting."""
     seed_everything(42)

tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py (+10 -8)
@@ -23,6 +23,9 @@
 import pytorch_lightning as pl
 from pytorch_lightning import Callback, Trainer
 from tests_pytorch import _PATH_LEGACY
+from tests_pytorch.helpers.datamodules import ClassifDataModule
+from tests_pytorch.helpers.runif import RunIf
+from tests_pytorch.helpers.simple_models import ClassificationModel
 
 LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints")
 CHECKPOINT_EXTENSION = ".ckpt"
@@ -32,18 +35,17 @@
 
 
 @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
+@RunIf(sklearn=True)
 def test_load_legacy_checkpoints(tmpdir, pl_version: str):
     PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
     with patch("sys.path", [PATH_LEGACY] + sys.path):
-        from simple_classif_training import ClassifDataModule, ClassificationModel
-
         path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}")))
         assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
         path_ckpt = path_ckpts[-1]
 
-        model = ClassificationModel.load_from_checkpoint(path_ckpt)
+        model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24)
         trainer = Trainer(default_root_dir=str(tmpdir))
-        dm = ClassifDataModule()
+        dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8)
         res = trainer.test(model, datamodule=dm)
         assert res[0]["test_loss"] <= 0.7
         assert res[0]["test_acc"] >= 0.85
@@ -62,6 +64,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
 
 
 @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
+@RunIf(sklearn=True)
 def test_legacy_ckpt_threading(tmpdir, pl_version: str):
     def load_model():
         import torch
@@ -84,17 +87,16 @@ def load_model():
 
 
 @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
+@RunIf(sklearn=True)
 def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
     PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
     with patch("sys.path", [PATH_LEGACY] + sys.path):
-        from simple_classif_training import ClassifDataModule, ClassificationModel
-
         path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}")))
        assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
        path_ckpt = path_ckpts[-1]
 
-        dm = ClassifDataModule()
-        model = ClassificationModel()
+        dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8)
+        model = ClassificationModel(num_features=24)
         stop = LimitNbEpochs(1)
 
         trainer = Trainer(
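The explicit keyword arguments in the changed lines pin the shared helpers to the configuration the deleted legacy classes used by default (24 features, 6000 samples, batch size 128, 2 clusters per class, int(24 / 3) = 8 informative features), presumably because the tests_pytorch helpers ship different defaults and legacy checkpoints only load into a model with matching shapes. The correspondence, spelled out as a short sketch:

from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.simple_models import ClassificationModel

# Defaults of the removed legacy classes (see the deleted code in
# tests/legacy/simple_classif_training.py above):
#   num_features=24, length=6000, num_classes=3, batch_size=128,
#   n_clusters_per_class=2, n_informative=int(24 / 3) == 8
# The shared helpers are pinned to the same shapes so that old checkpoints
# keep loading into a compatible model (assumed motivation):
dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8)
model = ClassificationModel(num_features=24)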

tests/tests_pytorch/core/test_datamodules.py (+3)
@@ -148,6 +148,7 @@ def test_dm_pickle_after_init():
     pickle.dumps(dm)
 
 
+@RunIf(sklearn=True)
 def test_train_loop_only(tmpdir):
     seed_everything(7)
 
@@ -169,6 +170,7 @@ def test_train_loop_only(tmpdir):
     assert trainer.callback_metrics["train_loss"] < 1.1
 
 
+@RunIf(sklearn=True)
 def test_train_val_loop_only(tmpdir):
     seed_everything(7)
 
@@ -226,6 +228,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
     assert dm.my_state_dict == {"my": "state_dict"}
 
 
+@RunIf(sklearn=True)
 def test_full_loop(tmpdir):
     seed_everything(7)
