 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import pytorch_lightning as pl
+import lightning as L

 from lprnet.utils import decode, accuracy

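Note: the `lightning` namespace (Lightning 2.x) re-exports the same core API that `pytorch_lightning` exposed, so this change is an import alias rather than a behavioural one. A minimal sketch of equivalent usage under the new import (the Trainer arguments here are illustrative, not values from this repo):

    import lightning as L

    # pl.LightningModule -> L.LightningModule, pl.Trainer -> L.Trainer
    trainer = L.Trainer(max_epochs=100, accelerator="auto")
    # trainer.fit(model, datamodule=...)  # model/datamodule come from elsewhere in the repo
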
@@ -33,17 +33,17 @@ def __init__(self):
             nn.Mish(True),
             nn.Conv2d(32, 32, kernel_size=5),
             nn.MaxPool2d(3, stride=3),
-            nn.Mish(True)
+            nn.Mish(True),
         )
         # Regressor for the 3x2 affine matrix
         self.fc_loc = nn.Sequential(
-            nn.Linear(32 * 15 * 6, 32),
-            nn.Mish(True),
-            nn.Linear(32, 3 * 2)
+            nn.Linear(32 * 15 * 6, 32), nn.Mish(True), nn.Linear(32, 3 * 2)
         )
-        # Initialize the weights/bias with identity transformation
+        # Initialize the weights/bias with identity transformation
         self.fc_loc[2].weight.data.zero_()
-        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
+        self.fc_loc[2].bias.data.copy_(
+            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
+        )

     def forward(self, x):
         xs = self.localization(x)
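Note: the bias `[1, 0, 0, 0, 1, 0]` is the flattened 2x3 identity affine matrix, so the spatial transformer starts as a no-op and only learns to warp during training. Outside this hunk, an STN forward pass typically turns the regressed theta into a sampling grid; a self-contained sketch with illustrative tensor sizes (not necessarily this repo's real input shape):

    import torch
    import torch.nn.functional as F

    theta = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float).view(1, 2, 3)
    x = torch.randn(1, 3, 24, 94)  # dummy plate crop
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    out = F.grid_sample(x, grid, align_corners=False)  # identity theta leaves x unchanged
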
@@ -99,13 +99,19 @@ def __init__(self, class_num, dropout_rate):
             nn.BatchNorm2d(num_features=256),
             nn.Mish(),
             nn.Dropout(dropout_rate),
-            nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(12, 2), stride=1),
+            nn.Conv2d(
+                in_channels=256, out_channels=class_num, kernel_size=(12, 2), stride=1
+            ),
             nn.BatchNorm2d(num_features=class_num),
             nn.Mish(),
         )
         self.container = nn.Sequential(
-            nn.Conv2d(in_channels=256 + class_num + 128 + 64, out_channels=self.class_num, kernel_size=(1, 1),
-                      stride=(1, 1)),
+            nn.Conv2d(
+                in_channels=256 + class_num + 128 + 64,
+                out_channels=self.class_num,
+                kernel_size=(1, 1),
+                stride=(1, 1),
+            ),
         )

     def forward(self, x):
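Note: `in_channels=256 + class_num + 128 + 64` implies the container receives a channel-wise concatenation of intermediate feature maps, and the 1x1 convolution mixes them down to one channel per character class. A small sketch of that pattern (shapes are made up for illustration):

    import torch
    import torch.nn as nn

    class_num = 37
    feats = [torch.randn(1, c, 4, 18) for c in (64, 128, 256, class_num)]
    x = torch.cat(feats, dim=1)  # (1, 64 + 128 + 256 + class_num, 4, 18)
    head = nn.Conv2d(x.shape[1], class_num, kernel_size=(1, 1), stride=(1, 1))
    logits = head(x)  # (1, class_num, 4, 18)
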
@@ -133,12 +139,14 @@ def forward(self, x):
         return logits


-class LPRNet(pl.LightningModule):
+class LPRNet(L.LightningModule):
     def __init__(self, args: Optional[Namespace] = None):
         super().__init__()
         self.save_hyperparameters(args)
         self.STNet = _STNet()
-        self.LPRNet = _LPRNet(class_num=len(self.hparams.chars), dropout_rate=self.hparams.dropout_rate)
+        self.LPRNet = _LPRNet(
+            class_num=len(self.hparams.chars), dropout_rate=self.hparams.dropout_rate
+        )

     def forward(self, x):
         return self.LPRNet(self.STNet(x))
@@ -150,10 +158,17 @@ def training_step(self, batch, batch_idx):
         logits = self(imgs)
         log_probs = logits.permute(2, 0, 1)
         log_probs = log_probs.log_softmax(2).requires_grad_()
-        input_lengths, target_lengths = sparse_tuple_for_ctc(self.hparams.t_length, lengths)
-        loss = F.ctc_loss(log_probs=log_probs, targets=labels,
-                          input_lengths=input_lengths, target_lengths=target_lengths,
-                          blank=len(self.hparams.chars) - 1, reduction='mean')
+        input_lengths, target_lengths = sparse_tuple_for_ctc(
+            self.hparams.t_length, lengths
+        )
+        loss = F.ctc_loss(
+            log_probs=log_probs,
+            targets=labels,
+            input_lengths=input_lengths,
+            target_lengths=target_lengths,
+            blank=len(self.hparams.chars) - 1,
+            reduction="mean",
+        )
         acc = accuracy(logits, labels, lengths, self.hparams.chars)

         self.log("train-loss", abs(loss), prog_bar=True, logger=True, sync_dist=True)
@@ -167,10 +182,17 @@ def validation_step(self, batch, batch_idx):
         logits = self(imgs)
         log_probs = logits.permute(2, 0, 1)
         log_probs = log_probs.log_softmax(2).requires_grad_()
-        input_lengths, target_lengths = sparse_tuple_for_ctc(self.hparams.t_length, lengths)
-        loss = F.ctc_loss(log_probs=log_probs, targets=labels,
-                          input_lengths=input_lengths, target_lengths=target_lengths,
-                          blank=len(self.hparams.chars) - 1, reduction='mean')
+        input_lengths, target_lengths = sparse_tuple_for_ctc(
+            self.hparams.t_length, lengths
+        )
+        loss = F.ctc_loss(
+            log_probs=log_probs,
+            targets=labels,
+            input_lengths=input_lengths,
+            target_lengths=target_lengths,
+            blank=len(self.hparams.chars) - 1,
+            reduction="mean",
+        )
         acc = accuracy(logits, labels, lengths, self.hparams.chars)

         self.log("val-loss", abs(loss), prog_bar=True, logger=True, sync_dist=True)
@@ -179,14 +201,22 @@ def validation_step(self, batch, batch_idx):
     def test_step(self, batch, batch_idx):
         imgs, labels, lengths = batch
         import time
+
         start = time.time()
         logits = self(imgs)
         log_probs = logits.permute(2, 0, 1)
         log_probs = log_probs.log_softmax(2).requires_grad_()
-        input_lengths, target_lengths = sparse_tuple_for_ctc(self.hparams.t_length, lengths)
-        loss = F.ctc_loss(log_probs=log_probs, targets=labels,
-                          input_lengths=input_lengths, target_lengths=target_lengths,
-                          blank=len(self.hparams.chars) - 1, reduction='mean')
+        input_lengths, target_lengths = sparse_tuple_for_ctc(
+            self.hparams.t_length, lengths
+        )
+        loss = F.ctc_loss(
+            log_probs=log_probs,
+            targets=labels,
+            input_lengths=input_lengths,
+            target_lengths=target_lengths,
+            blank=len(self.hparams.chars) - 1,
+            reduction="mean",
+        )
         acc = accuracy(logits, labels, lengths, self.hparams.chars)
         end = time.time()

@@ -204,17 +234,27 @@ def predict_step(self, batch, batch_idx, dataloader_idx: int = 0):
         return predict

     def configure_optimizers(self):
-        optimizer = torch.optim.Adam([{'params': self.STNet.parameters(),
-                                       'weight_decay': self.hparams.weight_decay},
-                                      {'params': self.LPRNet.parameters()}],
-                                     lr=self.hparams.lr)
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10, 2, 0.0001, -1)
-        return {"optimizer": optimizer,
-                "lr_scheduler": {
-                    "scheduler": scheduler,
-                    "interval": "step",
-                    "frequency": 1,
-                    "monitor": "val-loss",
-                    "strict": True,
-                    "name": "lr"
-                }}
+        optimizer = torch.optim.Adam(
+            [
+                {
+                    "params": self.STNet.parameters(),
+                    "weight_decay": self.hparams.weight_decay,
+                },
+                {"params": self.LPRNet.parameters()},
+            ],
+            lr=self.hparams.lr,
+        )
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer, 10, 2, 0.0001, -1
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "interval": "step",
+                "frequency": 1,
+                "monitor": "val-loss",
+                "strict": True,
+                "name": "lr",
+            },
+        }
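Note: Adam receives two parameter groups, so `weight_decay` applies only to the STN branch; `CosineAnnealingWarmRestarts(optimizer, 10, 2, 0.0001, -1)` maps positionally to T_0=10, T_mult=2, eta_min=0.0001, last_epoch=-1, and the scheduler dict's `"interval": "step"` makes Lightning step it every training batch. A rough usage sketch of driving the module with the new `lightning` import (hyperparameter values and dataloaders here are hypothetical, not this repo's config):

    import lightning as L
    from argparse import Namespace

    args = Namespace(chars="0123456789ABCDEFG-", dropout_rate=0.5,
                     lr=1e-3, weight_decay=1e-5, t_length=18)  # hypothetical values
    model = LPRNet(args)
    trainer = L.Trainer(max_epochs=10, accelerator="auto")
    # trainer.fit(model, train_dataloaders=..., val_dataloaders=...)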