Skip to content

Commit

Permalink
Fix temperature annealing
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoli-ai authored Sep 21, 2022
1 parent 7fb9a9b commit 60a4691
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, us
for batch_idx, (inputs, targets) in enumerate(train_loader):
# update temperature of ODConv
if epoch < args.temp_epoch and hasattr(model.module, 'net_update_temperature'):
temp = get_temperature(batch_idx, epoch, train_loader_len,
temp = get_temperature(batch_idx + 1, epoch, train_loader_len,
temp_epoch=args.temp_epoch, temp_init=args.temp_init)
model.module.net_update_temperature(temp)

Expand Down

0 comments on commit 60a4691

Please sign in to comment.