Skip to content

Commit

Permalink
Update multigait++.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jdyjjj authored Feb 12, 2025
1 parent bc9b17e commit 85887ad
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions opengait/modeling/models/multigait++.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def build_network(self, model_cfg):
self.part2_layer3 = copy.deepcopy(self.part1_layer3)
self.layer3 = copy.deepcopy(self.part1_layer3)
self.layer4 = self.make_layer(BasicBlockP3D, 256 * C, stride=[1, 1], blocks_num=B[3], mode='p3d')
self.crossattn1 = CrossAttention(64)
self.csquare = CSquare(64)


self.FCs = SeparateFCs(16, 256*C, 128*C)
Expand Down Expand Up @@ -101,7 +101,7 @@ def forward(self, inputs):
part2 = self.part2_layer1(part2)
part1 = self.part1_layer0(part1)
part1 = self.part1_layer1(part1)
out, attn1, attn2, attn_co = self.crossattn1(part2,part1)
out, attn1, attn2, attn_co = self.csquare(part2,part1)

part2 = self.part2_layer2(part2*attn1)
part1 = self.part1_layer2(part1*attn2)
Expand Down Expand Up @@ -157,9 +157,9 @@ def forward(self, feat_list):
return retun


class CrossAttention(nn.Module):
class CSquare(nn.Module):
def __init__(self, in_channels=64, squeeze_ratio=16, h=32, w=22):
super(CrossAttention, self).__init__()
super(CSquare, self).__init__()
hidden_dim = int(in_channels / squeeze_ratio)
self.TP_mean = PackSequenceWrapper(torch.mean)
self.conv2 = SetBlockWrapper(nn.Sequential(
Expand Down

0 comments on commit 85887ad

Please sign in to comment.