test_no_grad.py
from typing import List

import pytest
import torch
from torch import nn

import heavyball
import heavyball.utils
from benchmark.utils import get_optim
from heavyball.utils import clean, set_torch


class Param(nn.Module):
    """A trivial module whose single parameter never receives a gradient here."""

    def __init__(self, size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(size))

    def forward(self, inp):
        return self.weight.mean() * inp


@pytest.mark.parametrize("opt", heavyball.__all__)
@pytest.mark.parametrize("size", [(4, 4, 4, 4)])
def test_no_grad(opt, size: List[int], depth: int = 2, iterations: int = 5, outer_iterations: int = 3):
    clean()
    set_torch()

    opt = getattr(heavyball, opt)

    for _ in range(outer_iterations):
        clean()
        model = nn.Sequential(*[Param(size) for _ in range(depth)]).cuda()
        o = get_optim(opt, model.parameters(), lr=1e-3)

        for _ in range(iterations):
            # No backward pass has run, so every parameter's .grad is None;
            # step() must not allocate any per-parameter optimizer state.
            o.step()
            o.zero_grad()
            assert o.state_size() == 0

        del model, o
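

# For contrast, a minimal sketch of the same invariant against a stock
# torch.optim optimizer rather than heavyball (an illustrative addition,
# not part of the original suite; the test name below is hypothetical).
# Stock optimizers allocate per-parameter state lazily, on the first
# step() that actually sees a gradient, so stepping before any backward
# pass leaves optimizer.state empty.
def test_no_grad_stock_optimizer():
    param = nn.Parameter(torch.randn(4))
    o = torch.optim.SGD([param], lr=1e-3, momentum=0.9)
    for _ in range(5):
        o.step()  # param.grad is None, so the update skips this parameter
        o.zero_grad()
        # No momentum buffer was created despite momentum=0.9.
        assert len(o.state) == 0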