Skip to content

Commit 8d38e3e

Browse files
committed
refactor: update code to lightning>=2.x.x
1 parent e00f4b0 commit 8d38e3e

9 files changed

+156
-98
lines changed
363 Bytes
Binary file not shown.
Binary file not shown.
6.84 KB
Binary file not shown.
2.38 KB
Binary file not shown.

lprnet/datamodule.py

+45-32
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import cv2
77
from torch.utils.data import Dataset, DataLoader
88
from imutils import paths
9-
import pytorch_lightning as pl
9+
import lightning as L
1010

1111
from lprnet.utils import encode
1212

@@ -24,8 +24,11 @@ def resize_pad(img, size):
2424
sizeas = (int(w * ash), int(h * ash))
2525

2626
pic1 = cv2.resize(pic1, dsize=sizeas)
27-
base_pic[int(size[1] / 2 - sizeas[1] / 2):int(size[1] / 2 + sizeas[1] / 2),
28-
int(size[0] / 2 - sizeas[0] / 2):int(size[0] / 2 + sizeas[0] / 2), :] = pic1
27+
base_pic[
28+
int(size[1] / 2 - sizeas[1] / 2) : int(size[1] / 2 + sizeas[1] / 2),
29+
int(size[0] / 2 - sizeas[0] / 2) : int(size[0] / 2 + sizeas[0] / 2),
30+
:,
31+
] = pic1
2932

3033
return base_pic
3134

@@ -51,20 +54,20 @@ def __init__(self, args, stage, PreprocFun=None):
5154
self.img_paths = []
5255
self.img_size = self.args.img_size
5356

54-
if stage == 'train':
57+
if stage == "train":
5558
self.img_dir = self.args.train_dir
56-
elif stage == 'valid':
59+
elif stage == "valid":
5760
self.img_dir = self.args.valid_dir
58-
elif stage == 'test':
61+
elif stage == "test":
5962
self.img_dir = self.args.test_dir
60-
elif stage == 'predict':
63+
elif stage == "predict":
6164
self.img_dir = self.args.test_dir
6265
else:
6366
assert f"No Such Stage. Your input -> {self.stage}"
6467

6568
self.img_paths = [img_path for img_path in paths.list_images(self.img_dir)]
6669

67-
if stage == 'train':
70+
if stage == "train":
6871
random.shuffle(self.img_paths)
6972

7073
if PreprocFun is not None:
@@ -95,7 +98,7 @@ def __getitem__(self, index):
9598
return Image, label, len(label)
9699

97100
def transform(self, img):
98-
img = img.astype('float32')
101+
img = img.astype("float32")
99102
img -= 127.5
100103
img *= 0.0078125
101104
img = np.transpose(img, (2, 0, 1))
@@ -104,13 +107,15 @@ def transform(self, img):
104107

105108
def check(self, label):
106109
# kor_plate_pattern = re.compile('[가-힣]{0,5}[0-9]{0,3}[가-힣][0-9]{4}')
107-
idn_plate_pattern = re.compile('[A-Z]{0,3}[0-9]{0,4}[A-Z]{0,3}')
108-
plate_name = idn_plate_pattern.findall(''.join([self.args.chars[c] for c in label]))
110+
idn_plate_pattern = re.compile("[A-Z]{0,3}[0-9]{0,4}[A-Z]{0,3}")
111+
plate_name = idn_plate_pattern.findall(
112+
"".join([self.args.chars[c] for c in label])
113+
)
109114

110115
return True if plate_name else False
111116

112117

113-
class DataModule(pl.LightningDataModule):
118+
class DataModule(L.LightningDataModule):
114119
def __init__(self, args):
115120
super().__init__()
116121
self.args = args
@@ -131,29 +136,37 @@ def setup(self, stage: str):
131136
self.predict = LPRNetDataset(self.args, "predict")
132137

133138
def train_dataloader(self):
134-
return DataLoader(self.train,
135-
batch_size=self.args.batch_size,
136-
shuffle=True,
137-
num_workers=4,
138-
collate_fn=collate_fn)
139+
return DataLoader(
140+
self.train,
141+
batch_size=self.args.batch_size,
142+
shuffle=True,
143+
num_workers=4,
144+
collate_fn=collate_fn,
145+
)
139146

140147
def val_dataloader(self):
141-
return DataLoader(self.val,
142-
batch_size=self.args.batch_size,
143-
shuffle=False,
144-
num_workers=4,
145-
collate_fn=collate_fn)
148+
return DataLoader(
149+
self.val,
150+
batch_size=self.args.batch_size,
151+
shuffle=False,
152+
num_workers=4,
153+
collate_fn=collate_fn,
154+
)
146155

147156
def test_dataloader(self):
148-
return DataLoader(self.test,
149-
batch_size=self.args.batch_size,
150-
shuffle=False,
151-
num_workers=4,
152-
collate_fn=collate_fn)
157+
return DataLoader(
158+
self.test,
159+
batch_size=self.args.batch_size,
160+
shuffle=False,
161+
num_workers=4,
162+
collate_fn=collate_fn,
163+
)
153164

154165
def predict_dataloader(self):
155-
return DataLoader(self.predict,
156-
batch_size=self.args.batch_size,
157-
shuffle=False,
158-
num_workers=4,
159-
collate_fn=collate_fn)
166+
return DataLoader(
167+
self.predict,
168+
batch_size=self.args.batch_size,
169+
shuffle=False,
170+
num_workers=4,
171+
collate_fn=collate_fn,
172+
)

lprnet/lprnet.py

+78-38
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
import torch.nn as nn
88
import torch.nn.functional as F
9-
import pytorch_lightning as pl
9+
import lightning as L
1010

1111
from lprnet.utils import decode, accuracy
1212

@@ -33,17 +33,17 @@ def __init__(self):
3333
nn.Mish(True),
3434
nn.Conv2d(32, 32, kernel_size=5),
3535
nn.MaxPool2d(3, stride=3),
36-
nn.Mish(True)
36+
nn.Mish(True),
3737
)
3838
# Regressor for the 3x2 affine matrix
3939
self.fc_loc = nn.Sequential(
40-
nn.Linear(32 * 15 * 6, 32),
41-
nn.Mish(True),
42-
nn.Linear(32, 3 * 2)
40+
nn.Linear(32 * 15 * 6, 32), nn.Mish(True), nn.Linear(32, 3 * 2)
4341
)
44-
# Initialize the weights/bias with identity transformation
42+
# Initialize the weights/bias with identity transformation
4543
self.fc_loc[2].weight.data.zero_()
46-
self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
44+
self.fc_loc[2].bias.data.copy_(
45+
torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
46+
)
4747

4848
def forward(self, x):
4949
xs = self.localization(x)
@@ -99,13 +99,19 @@ def __init__(self, class_num, dropout_rate):
9999
nn.BatchNorm2d(num_features=256),
100100
nn.Mish(),
101101
nn.Dropout(dropout_rate),
102-
nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(12, 2), stride=1),
102+
nn.Conv2d(
103+
in_channels=256, out_channels=class_num, kernel_size=(12, 2), stride=1
104+
),
103105
nn.BatchNorm2d(num_features=class_num),
104106
nn.Mish(),
105107
)
106108
self.container = nn.Sequential(
107-
nn.Conv2d(in_channels=256 + class_num + 128 + 64, out_channels=self.class_num, kernel_size=(1, 1),
108-
stride=(1, 1)),
109+
nn.Conv2d(
110+
in_channels=256 + class_num + 128 + 64,
111+
out_channels=self.class_num,
112+
kernel_size=(1, 1),
113+
stride=(1, 1),
114+
),
109115
)
110116

111117
def forward(self, x):
@@ -133,12 +139,14 @@ def forward(self, x):
133139
return logits
134140

135141

136-
class LPRNet(pl.LightningModule):
142+
class LPRNet(L.LightningModule):
137143
def __init__(self, args: Optional[Namespace] = None):
138144
super().__init__()
139145
self.save_hyperparameters(args)
140146
self.STNet = _STNet()
141-
self.LPRNet = _LPRNet(class_num=len(self.hparams.chars), dropout_rate=self.hparams.dropout_rate)
147+
self.LPRNet = _LPRNet(
148+
class_num=len(self.hparams.chars), dropout_rate=self.hparams.dropout_rate
149+
)
142150

143151
def forward(self, x):
144152
return self.LPRNet(self.STNet(x))
@@ -150,10 +158,17 @@ def training_step(self, batch, batch_idx):
150158
logits = self(imgs)
151159
log_probs = logits.permute(2, 0, 1)
152160
log_probs = log_probs.log_softmax(2).requires_grad_()
153-
input_lengths, target_lengths = sparse_tuple_for_ctc(self.hparams.t_length, lengths)
154-
loss = F.ctc_loss(log_probs=log_probs, targets=labels,
155-
input_lengths=input_lengths, target_lengths=target_lengths,
156-
blank=len(self.hparams.chars) - 1, reduction='mean')
161+
input_lengths, target_lengths = sparse_tuple_for_ctc(
162+
self.hparams.t_length, lengths
163+
)
164+
loss = F.ctc_loss(
165+
log_probs=log_probs,
166+
targets=labels,
167+
input_lengths=input_lengths,
168+
target_lengths=target_lengths,
169+
blank=len(self.hparams.chars) - 1,
170+
reduction="mean",
171+
)
157172
acc = accuracy(logits, labels, lengths, self.hparams.chars)
158173

159174
self.log("train-loss", abs(loss), prog_bar=True, logger=True, sync_dist=True)
@@ -167,10 +182,17 @@ def validation_step(self, batch, batch_idx):
167182
logits = self(imgs)
168183
log_probs = logits.permute(2, 0, 1)
169184
log_probs = log_probs.log_softmax(2).requires_grad_()
170-
input_lengths, target_lengths = sparse_tuple_for_ctc(self.hparams.t_length, lengths)
171-
loss = F.ctc_loss(log_probs=log_probs, targets=labels,
172-
input_lengths=input_lengths, target_lengths=target_lengths,
173-
blank=len(self.hparams.chars) - 1, reduction='mean')
185+
input_lengths, target_lengths = sparse_tuple_for_ctc(
186+
self.hparams.t_length, lengths
187+
)
188+
loss = F.ctc_loss(
189+
log_probs=log_probs,
190+
targets=labels,
191+
input_lengths=input_lengths,
192+
target_lengths=target_lengths,
193+
blank=len(self.hparams.chars) - 1,
194+
reduction="mean",
195+
)
174196
acc = accuracy(logits, labels, lengths, self.hparams.chars)
175197

176198
self.log("val-loss", abs(loss), prog_bar=True, logger=True, sync_dist=True)
@@ -179,14 +201,22 @@ def validation_step(self, batch, batch_idx):
179201
def test_step(self, batch, batch_idx):
180202
imgs, labels, lengths = batch
181203
import time
204+
182205
start = time.time()
183206
logits = self(imgs)
184207
log_probs = logits.permute(2, 0, 1)
185208
log_probs = log_probs.log_softmax(2).requires_grad_()
186-
input_lengths, target_lengths = sparse_tuple_for_ctc(self.hparams.t_length, lengths)
187-
loss = F.ctc_loss(log_probs=log_probs, targets=labels,
188-
input_lengths=input_lengths, target_lengths=target_lengths,
189-
blank=len(self.hparams.chars) - 1, reduction='mean')
209+
input_lengths, target_lengths = sparse_tuple_for_ctc(
210+
self.hparams.t_length, lengths
211+
)
212+
loss = F.ctc_loss(
213+
log_probs=log_probs,
214+
targets=labels,
215+
input_lengths=input_lengths,
216+
target_lengths=target_lengths,
217+
blank=len(self.hparams.chars) - 1,
218+
reduction="mean",
219+
)
190220
acc = accuracy(logits, labels, lengths, self.hparams.chars)
191221
end = time.time()
192222

@@ -204,17 +234,27 @@ def predict_step(self, batch, batch_idx, dataloader_idx: int = 0):
204234
return predict
205235

206236
def configure_optimizers(self):
207-
optimizer = torch.optim.Adam([{'params': self.STNet.parameters(),
208-
'weight_decay': self.hparams.weight_decay},
209-
{'params': self.LPRNet.parameters()}],
210-
lr=self.hparams.lr)
211-
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10, 2, 0.0001, -1)
212-
return {"optimizer": optimizer,
213-
"lr_scheduler": {
214-
"scheduler": scheduler,
215-
"interval": "step",
216-
"frequency": 1,
217-
"monitor": "val-loss",
218-
"strict": True,
219-
"name": "lr"
220-
}}
237+
optimizer = torch.optim.Adam(
238+
[
239+
{
240+
"params": self.STNet.parameters(),
241+
"weight_decay": self.hparams.weight_decay,
242+
},
243+
{"params": self.LPRNet.parameters()},
244+
],
245+
lr=self.hparams.lr,
246+
)
247+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
248+
optimizer, 10, 2, 0.0001, -1
249+
)
250+
return {
251+
"optimizer": optimizer,
252+
"lr_scheduler": {
253+
"scheduler": scheduler,
254+
"interval": "step",
255+
"frequency": 1,
256+
"monitor": "val-loss",
257+
"strict": True,
258+
"name": "lr",
259+
},
260+
}

lprnet/utils.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def encode(imgname: str, chars: List[str]):
1717
label.append(chars_dict[imgname[i:j]])
1818
i = j
1919
else:
20-
assert 0, f'no such char in {imgname}'
20+
assert 0, f"no such char in {imgname}"
2121

2222
return label
2323

@@ -32,7 +32,7 @@ def decode(preds, chars):
3232
for j in range(pred.shape[1]):
3333
pred_label.append(np.argmax(pred[:, j], axis=0))
3434
no_repeat_blank_label = list()
35-
pre_c = ''
35+
pre_c = ""
3636
for c in pred_label: # dropout repeated label and blank label
3737
if (pre_c == c) or (c == len(chars) - 1):
3838
if c == len(chars) - 1:
@@ -58,7 +58,7 @@ def accuracy(logits, labels, lengths, chars):
5858
TP, total = 0, 0
5959
start = 0
6060
for i, length in enumerate(lengths):
61-
label = labels[start:start + length]
61+
label = labels[start : start + length]
6262
start += length
6363
if np.array_equal(np.array(pred_labels[i]), label.cpu().numpy()):
6464
TP += 1
@@ -72,19 +72,20 @@ def tensor2numpy(inp):
7272
inp = inp.squeeze(0).cpu()
7373
inp = inp.detach().numpy().transpose((1, 2, 0))
7474
inp = 127.5 + inp / 0.0078125
75-
inp = inp.astype('uint8')
75+
inp = inp.astype("uint8")
7676

7777
return inp
7878

7979

8080
def numpy2tensor(img: np.ndarray, img_size: Sequence[int]):
8181
# convert a numpy image to tensor
8282
import cv2
83+
8384
height, width, _ = img.shape
8485

8586
if height != img_size[1] or width != img_size[0]:
8687
img = cv2.resize(img, img_size, interpolation=cv2.INTER_CUBIC)
87-
img = img.astype('float32')
88+
img = img.astype("float32")
8889
img -= 127.5
8990
img *= 0.0078125
9091
img = np.transpose(img, (2, 0, 1))

test.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@
77
import warnings
88
import yaml
99
import torch
10-
from pytorch_lightning import Trainer
10+
11+
import lightning as L
1112

1213
from lprnet import LPRNet, DataModule
1314

1415
warnings.filterwarnings("ignore")
1516

1617

17-
if __name__ == '__main__':
18-
with open('config/idn_config.yaml') as f:
18+
if __name__ == "__main__":
19+
with open("config/idn_config.yaml") as f:
1920
args = Namespace(**yaml.load(f, Loader=yaml.FullLoader))
2021

2122
load_model_start = time.time()
@@ -29,7 +30,7 @@
2930

3031
dm = DataModule(args)
3132

32-
trainer = Trainer(
33+
trainer = L.Trainer(
3334
accelerator="auto",
3435
precision=16,
3536
devices=torch.cuda.device_count(),

0 commit comments

Comments
 (0)