Skip to content

Commit

Permalink
update shape information in test_activation
Browse files Browse the repository at this point in the history
  • Loading branch information
Franck Mamalet committed Oct 19, 2024
1 parent fe46f5c commit 6c050f9
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/test_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def test_HouseHolder_theta_zero(dense):
bs = np.random.randint(32, 128)
h, w = np.random.randint(1, 64), np.random.randint(1, 64)
c = np.random.randint(1, 64) * 2
size = uft.to_framework_channel(bs, c // 2, h, w)
size = (bs,) + uft.to_framework_channel((c // 2, h, w))
ch = c

hh = uft.get_instance_framework(
Expand Down Expand Up @@ -275,7 +275,7 @@ def test_HouseHolder_theta_pi(dense):
bs = np.random.randint(32, 128)
h, w = np.random.randint(1, 64), np.random.randint(1, 64)
c = np.random.randint(1, 64) * 2
size = uft.to_framework_channel(bs, c // 2, h, w)
size = (bs,) + uft.to_framework_channel((c // 2, h, w))
ch = c

hh = uft.get_instance_framework(
Expand Down Expand Up @@ -315,7 +315,7 @@ def test_HouseHolder_theta_90(dense):
bs = np.random.randint(32, 128)
h, w = np.random.randint(1, 64), np.random.randint(1, 64)
c = np.random.randint(1, 64) * 2
size = uft.to_framework_channel(bs, c // 2, h, w)
size = (bs,) + uft.to_framework_channel((c // 2, h, w))
ch = c

hh = uft.get_instance_framework(HouseHolder, {"channels": ch})
Expand Down Expand Up @@ -347,7 +347,7 @@ def test_HouseHolder_idempotence():
hh = uft.get_instance_framework(
HouseHolder, {"channels": c, "theta_initializer": "normal"}
)
x = np.random.normal(size=uft.to_framework_channel((bs, c, h, w)))
x = np.random.normal(size=(bs,) + uft.to_framework_channel((c, h, w)))
x = uft.to_tensor(x)

# Run two times the HH activation and compare both outputs
Expand Down

0 comments on commit 6c050f9

Please sign in to comment.