Skip to content

Commit 0bd9445

Browse files
committed
Add test for zoneout
1 parent bd2dac2 commit 0bd9445

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

neural_sp/models/modules/zoneout.py

+9
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ def __init__(self, cell, zoneout_prob_h, zoneout_prob_c):
2525
self.prob = zoneout_prob_h
2626

2727
def forward(self, inputs, state):
28+
"""Forward pass.
29+
30+
Args:
31+
inputs (FloatTensor): `[B, input_dim]'
32+
state (tuple or FloatTensor):
33+
Returns:
34+
state (tuple or FloatTensor):
35+
36+
"""
2837
return self.zoneout(state, self.cell(inputs, state), self.prob)
2938

3039
def zoneout(self, state, next_state, prob):

test/modules/test_zoneout.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#! /usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
4+
"""Test for Zoneout."""
5+
6+
import importlib
7+
import pytest
8+
import torch
9+
10+
11+
def make_args(**kwargs):
12+
args = dict(
13+
zoneout_prob_h=0,
14+
zoneout_prob_c=0,
15+
)
16+
args.update(kwargs)
17+
return args
18+
19+
20+
@pytest.mark.parametrize(
21+
"rnn_type, args",
22+
[
23+
('lstm', {'zoneout_prob_h': 0.1}),
24+
('lstm', {'zoneout_prob_c': 0.1}),
25+
('gru', {'zoneout_prob_h': 0.1}),
26+
('gru', {'zoneout_prob_c': 0.1}),
27+
]
28+
)
29+
def test_forward(rnn_type, args):
30+
args = make_args(**args)
31+
32+
batch_size = 4
33+
cell_size = 32
34+
35+
xs = torch.FloatTensor(batch_size, cell_size)
36+
hxs = torch.zeros(batch_size, cell_size)
37+
cxs = torch.zeros(batch_size, cell_size) if rnn_type == 'lstm' else None
38+
39+
if rnn_type == 'lstm':
40+
cell = torch.nn.LSTMCell(cell_size, cell_size)
41+
elif rnn_type == 'gru':
42+
cell = torch.nn.GRUCell(cell_size, cell_size)
43+
else:
44+
raise ValueError(rnn_type)
45+
args['cell'] = cell
46+
47+
module = importlib.import_module('neural_sp.models.modules.zoneout')
48+
zoneout_cell = module.ZoneoutCell(**args)
49+
50+
if rnn_type == 'lstm':
51+
h, c = zoneout_cell(xs, (hxs, cxs))
52+
assert h.size() == (batch_size, cell_size)
53+
assert c.size() == (batch_size, cell_size)
54+
elif rnn_type == 'gru':
55+
h = zoneout_cell(xs, hxs)
56+
assert h.size() == (batch_size, cell_size)
57+
else:
58+
raise ValueError(rnn_type)

0 commit comments

Comments
 (0)