Skip to content

Commit

Permalink
add train_metrics for Per-FedAvg
Browse files Browse the repository at this point in the history
  • Loading branch information
TsingZ0 committed Apr 18, 2023
1 parent 68e5acf commit d00b7b3
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 2 deletions.
47 changes: 46 additions & 1 deletion system/flcore/clients/clientperavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,49 @@ def load_train_data_one_step(self, batch_size=None):
if batch_size == None:
batch_size = self.batch_size
train_data = read_client_data(self.dataset, self.id, is_train=True)
return DataLoader(train_data, batch_size, drop_last=True, shuffle=False)
return DataLoader(train_data, batch_size, drop_last=True, shuffle=False)


def train_metrics(self, model=None):
trainloader = self.load_train_data(self.batch_size*2)
if model == None:
model = self.model
model.eval()

train_num = 0
losses = 0
for X, Y in trainloader:
# step 1
if type(X) == type([]):
x = [None, None]
x[0] = X[0][:self.batch_size].to(self.device)
x[1] = X[1][:self.batch_size]
else:
x = X[:self.batch_size].to(self.device)
y = Y[:self.batch_size].to(self.device)
if self.train_slow:
time.sleep(0.1 * np.abs(np.random.rand()))
self.optimizer.zero_grad()
output = self.model(x)
loss = self.loss(output, y)
loss.backward()
self.optimizer.step()

# step 2
if type(X) == type([]):
x = [None, None]
x[0] = X[0][self.batch_size:].to(self.device)
x[1] = X[1][self.batch_size:]
else:
x = X[self.batch_size:].to(self.device)
y = Y[self.batch_size:].to(self.device)
if self.train_slow:
time.sleep(0.1 * np.abs(np.random.rand()))
self.optimizer.zero_grad()
output = self.model(x)
loss1 = self.loss(output, y)

train_num += y.shape[0]
losses += loss1.item() * y.shape[0]

return losses, train_num
39 changes: 38 additions & 1 deletion system/flcore/servers/serverperavg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import torch
import numpy as np
from flcore.clients.clientperavg import clientPerAvg
from flcore.servers.serverbase import Server
from threading import Thread
Expand Down Expand Up @@ -68,4 +69,40 @@ def evaluate_one_step(self):
test_acc = sum(stats[2])*1.0 / sum(stats[1])

self.rs_test_acc.append(test_acc)
print("Average Test Accurancy: {:.4f}".format(test_acc))
print("Average Test Accurancy: {:.4f}".format(test_acc))


def evaluate_one_step(self, acc=None, loss=None):
models_temp = []
for c in self.clients:
models_temp.append(copy.deepcopy(c.model))
c.train_one_step()
stats = self.test_metrics()
# set the local model back on clients for training process
for i, c in enumerate(self.clients):
c.clone_model(models_temp[i], c.model)

stats_train = self.train_metrics()
# set the local model back on clients for training process
for i, c in enumerate(self.clients):
c.clone_model(models_temp[i], c.model)

accs = [a / n for a, n in zip(stats[2], stats[1])]

test_acc = sum(stats[2])*1.0 / sum(stats[1])
train_loss = sum(stats_train[2])*1.0 / sum(stats_train[1])

if acc == None:
self.rs_test_acc.append(test_acc)
else:
acc.append(test_acc)

if loss == None:
self.rs_train_loss.append(train_loss)
else:
loss.append(train_loss)

print("Averaged Train Loss: {:.4f}".format(train_loss))
print("Averaged Test Accurancy: {:.4f}".format(test_acc))
# self.print_(test_acc, train_acc, train_loss)
print("Std Test Accurancy: {:.4f}".format(np.std(accs)))

1 comment on commit d00b7b3

@youngfish42
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FedProto 和 pFedMe 这两个算法好像也只报告了部分结果

Please sign in to comment.