From 85887ad93bf0f9d7f589aa1e6644b0aeec4a1316 Mon Sep 17 00:00:00 2001 From: Dongyang Jin <73057174+jdyjjj@users.noreply.github.com> Date: Wed, 12 Feb 2025 14:18:34 +0800 Subject: [PATCH] Update multigait++.py --- opengait/modeling/models/multigait++.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/opengait/modeling/models/multigait++.py b/opengait/modeling/models/multigait++.py index cbc3cb5..9fe9deb 100644 --- a/opengait/modeling/models/multigait++.py +++ b/opengait/modeling/models/multigait++.py @@ -54,7 +54,7 @@ def build_network(self, model_cfg): self.part2_layer3 = copy.deepcopy(self.part1_layer3) self.layer3 = copy.deepcopy(self.part1_layer3) self.layer4 = self.make_layer(BasicBlockP3D, 256 * C, stride=[1, 1], blocks_num=B[3], mode='p3d') - self.crossattn1 = CrossAttention(64) + self.csquare = CSquare(64) self.FCs = SeparateFCs(16, 256*C, 128*C) @@ -101,7 +101,7 @@ def forward(self, inputs): part2 = self.part2_layer1(part2) part1 = self.part1_layer0(part1) part1 = self.part1_layer1(part1) - out, attn1, attn2, attn_co = self.crossattn1(part2,part1) + out, attn1, attn2, attn_co = self.csquare(part2,part1) part2 = self.part2_layer2(part2*attn1) part1 = self.part1_layer2(part1*attn2) @@ -157,9 +157,9 @@ def forward(self, feat_list): return retun -class CrossAttention(nn.Module): +class CSquare(nn.Module): def __init__(self, in_channels=64, squeeze_ratio=16, h=32, w=22): - super(CrossAttention, self).__init__() + super(CSquare, self).__init__() hidden_dim = int(in_channels / squeeze_ratio) self.TP_mean = PackSequenceWrapper(torch.mean) self.conv2 = SetBlockWrapper(nn.Sequential(