skippd_model.py
"""SKIPPD Model Implementation."""
import torch
from torch import nn
class SkippdModel(nn.Module):
"""SKIPPD Model.
https://github.com/yuhao-nie/Stanford-solar-forecasting-dataset
"""

    def __init__(
        self,
        in_chans: int = 48,
        num_classes: int = 1,
        num_filters: int = 24,
        kernel_size: int = 3,
        pool_size: int = 2,
        strides: int = 2,
        dense_size: int = 1024,
        drop_rate: float = 0.4,
    ) -> None:
        """Initialize a new instance of SkippdModel.

        Args:
            in_chans: int, number of input channels
            num_classes: int, number of output values to predict
            num_filters: int, number of filters in the first convolutional layer
            kernel_size: int, size of the convolutional kernel
            pool_size: int, size of the pooling kernel
            strides: int, stride of the pooling kernel
            dense_size: int, size of the dense layer
            drop_rate: float, dropout rate
        """
        super().__init__()
        # First convolutional block: conv -> batch norm -> ReLU -> max pool
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_chans, num_filters, kernel_size, padding="same"),
            nn.BatchNorm2d(num_filters),
            nn.ReLU(),
            nn.MaxPool2d(pool_size, strides),
        )
        # Second convolutional block doubles the number of filters
        self.conv2 = nn.Sequential(
            nn.Conv2d(num_filters, num_filters * 2, kernel_size, padding="same"),
            nn.BatchNorm2d(num_filters * 2),
            nn.ReLU(),
            nn.MaxPool2d(pool_size, strides),
        )
        self.flatten = nn.Flatten()
        # Calculate the output size of the Flatten layer, assuming a (64, 64)
        # spatial input and pool_size == strides
        conv_out_size: int = num_filters * 2 * (64 // (pool_size * strides)) ** 2
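        # Worked example with the defaults (an assumption: 64x64 input,
        # pool_size=2, strides=2, num_filters=24): each max-pool halves the
        # spatial size, 64 -> 32 -> 16, so the flattened feature vector has
        # 48 * 16 * 16 = 12288 elements.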
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, dense_size),
            nn.ReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(dense_size, dense_size),
            nn.ReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(dense_size, num_classes),
        )

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x_in: torch.Tensor, shape [batch_size, in_chans, 64, 64]

        Returns:
            torch.Tensor, shape [batch_size, num_classes]
        """
        x = self.conv1(x_in)
        x = self.conv2(x)
        x = self.flatten(x)
        return self.fc(x)
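

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes the default configuration and a 64x64 spatial input, matching the
# flatten-size calculation above; the model returns one regression value per
# sample when num_classes=1.
if __name__ == "__main__":
    model = SkippdModel(in_chans=48, num_classes=1)
    dummy = torch.randn(8, 48, 64, 64)  # [batch_size, in_chans, height, width]
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([8, 1])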