|
| 1 | +import os |
| 2 | +import math |
| 3 | +import copy |
| 4 | +import time |
| 5 | +import torch |
| 6 | +import random |
| 7 | +import numpy as np |
| 8 | +from tqdm import tqdm |
| 9 | +from csv import writer |
| 10 | +from update import LocalUpdate, test_inference |
| 11 | +from models.models import CNNMnistSmall, VGG, RNNModel |
| 12 | +from utils.utils_training import get_dataset, average_weights, add_delta_weights |
| 13 | + |
| 14 | +from user import UserType |
| 15 | +import json |
| 16 | + |
# Load the experiment configuration once at import time; the rest of the
# module reads its settings from this dict. The file is opened explicitly as
# UTF-8 so parsing does not depend on the platform's default encoding.
with open('config.json', encoding='utf-8') as config_file:
    config = json.load(config_file)
| 19 | + |
| 20 | + |
def _append_csv_row(csv_path, row):
    """Append one row to the CSV file at *csv_path*, creating it if needed."""
    # newline='' hands line-ending control to the csv module, which avoids
    # blank lines between records on platforms that translate '\n'.
    with open(csv_path, 'a+', newline='') as fp:
        writer(fp).writerow(row)


def main(dataset, payload, num_users, frac, run_name):
    """Run one federated training experiment with payload injection.

    Builds the global model for ``dataset``, marks a subset of users as
    senders, and then runs ``config["epochs"]`` rounds. In each round a random
    subset of users trains locally; sender users inject the payload into
    their update once ``config["injection"]`` rounds have passed. The averaged
    update is applied to the global model, and metrics plus a model checkpoint
    are written under
    ``<cwd>/<run_name>/<dataset>/<payload>/<num_users>/<frac*100>``.

    Args:
        dataset: Dataset name; 'mnist', 'cifar10' and 'wiki' select a model.
        payload: Payload identifier, forwarded to the users' inject/extract
            methods and used as a path component of the results folder.
        num_users: Total number of users in the federation.
        frac: Fraction of users sampled in each round (at least one user is
            always sampled).
        run_name: Name of the top-level results folder.

    Raises:
        ValueError: If ``dataset`` does not match any supported model.
    """
    result_folder_tree = os.path.join(os.getcwd(), run_name, dataset,
                                      payload, str(num_users), str(int(frac * 100)))

    # exist_ok=True also repairs a partially created tree (e.g. the root
    # exists but "models" does not), which would otherwise break torch.save.
    for subfolder in ("models", "payloads"):
        os.makedirs(os.path.join(result_folder_tree, subfolder), exist_ok=True)

    device = 'cuda' if config["gpu"] else 'cpu'
    # Encoding state shared between senders and receivers; stays None until
    # the first call to inject_payload returns real values.
    H, G, enc_length, preamble1, global_model = None, None, None, None, None

    # load datasets
    train_dataset, test_dataset, user_groups, ntokens = get_dataset(dataset, config["iid"], config["unequal"],
                                                                    num_users)
    # BUILD MODEL
    if dataset == 'mnist':
        global_model = CNNMnistSmall()
    elif dataset == 'cifar10':
        global_model = VGG('VGG11')
    elif dataset == "wiki":
        global_model = RNNModel("LSTM", ntokens, 200, 200, 2, 0.2, True)

    error_correction = config["error_correction"]
    stealthiness_level = config["stealthy"]

    if global_model is None:
        # Fail with a clear message instead of crashing on None.to(device).
        raise ValueError(f"Configuration Error! Unsupported dataset: {dataset!r}")
    global_model.to(device)

    # Becomes whatever the senders' extract_payload returns; gates the
    # extraction attempt for all sampled users at the end of a round.
    payload_alive = False

    # Define the number of sender users
    m_comp = max(int(config["senders"] * num_users), 1)
    sender_users = np.random.choice(range(num_users), m_comp, replace=False)

    for user in user_groups:
        if user.user_id in sender_users:
            user.user_type = UserType.SENDER

    # Number of users sampled per round; does not change between rounds.
    m = max(int(frac * num_users), 1)

    with tqdm(range(config["epochs"])) as bar:
        for epoch in bar:
            local_weights = []
            np.random.seed(random.randint(100, 1000))
            idxs_users = np.random.choice(range(num_users), m, replace=False)

            global_model.train()

            for idx in idxs_users:
                user = user_groups[idx]
                local_model = LocalUpdate(gpu=config["gpu"], dataset=train_dataset, idxs=user.data,
                                          local_bs=config["local_bs"], dataset_name=dataset)

                w, loss = local_model.update_weights(model=copy.deepcopy(global_model), global_round=epoch,
                                                     optimizer=config["optimizer"], lr=config["lr"],
                                                     local_ep=config["local_ep"])

                bar.set_postfix({'Loss': loss})

                # Before injection starts, each sampled user keeps a snapshot
                # of the global model and the round it was taken in.
                if config["store_global"] <= epoch < config["injection"]:
                    user.global_model = copy.deepcopy(global_model)
                    user.previous_round = epoch

                if user.user_type == UserType.SENDER and epoch >= config["injection"]:
                    if enc_length is None:
                        user.global_model = copy.deepcopy(global_model)

                    # NOTE(review): the extraction below is assumed to run for
                    # every sender turn, not only while enc_length is None —
                    # confirm against the original indentation.
                    payload_alive = user.extract_payload(copy.deepcopy(global_model),
                                                         payload,
                                                         result_folder_tree,
                                                         enc_length,
                                                         H, G,
                                                         preamble1,
                                                         error_correction)

                    sender_weights, enc_length, H, G, preamble1 = user.inject_payload(copy.deepcopy(w),
                                                                                      device,
                                                                                      payload,
                                                                                      stealthiness_level,
                                                                                      error_correction)

                    local_weights.append(copy.deepcopy(sender_weights))
                else:
                    local_weights.append(copy.deepcopy(w.state_dict()))

            # Once injection has started and a sender reported the payload as
            # alive, let every user sampled this round try to extract it.
            if epoch >= config["injection"] and payload_alive:
                for idx in idxs_users:
                    user_groups[idx].extract_payload(copy.deepcopy(global_model),
                                                     payload,
                                                     result_folder_tree,
                                                     enc_length,
                                                     H, G,
                                                     preamble1,
                                                     error_correction)

            global_weights_delta = average_weights(local_weights)
            global_weights = add_delta_weights(copy.deepcopy(global_model), global_weights_delta)
            global_model.load_state_dict(global_weights)

            # Evaluate only every fifth round.
            if epoch % 5 == 0:
                metrics_path = os.path.join(result_folder_tree, "acc_loss.csv")
                if dataset == "wiki":
                    train_loss = test_inference(config["gpu"], copy.deepcopy(global_model), train_dataset, dataset)
                    _append_csv_row(metrics_path, [epoch, train_loss, math.exp(train_loss)])
                else:
                    train_acc, train_loss = test_inference(config["gpu"], copy.deepcopy(global_model), train_dataset,
                                                           dataset)
                    test_acc, test_loss = test_inference(config["gpu"], copy.deepcopy(global_model), test_dataset,
                                                         dataset)
                    _append_csv_row(metrics_path, [epoch, train_acc, train_loss, test_acc, test_loss])

            # NOTE(review): coverage logging and checkpointing are assumed to
            # run every round rather than only on evaluation rounds — confirm
            # against the original indentation.
            # Coverage counts the non-sender users flagged as having
            # correctly extracted the payload.
            rnd_coverage = sum(
                1 for u in user_groups if u.correctly_extracted and u.user_type != UserType.SENDER)
            _append_csv_row(os.path.join(result_folder_tree, "coverage.csv"), [epoch, rnd_coverage])
            torch.save(global_model.state_dict(),
                       os.path.join(result_folder_tree, "models", f"checkpoint.epoch{epoch}.pt"))
| 157 | + |
| 158 | + |
if __name__ == '__main__':
    # Run one experiment for every (payload, num_users, frac) combination
    # listed in the config, then report the wall-clock time of the sweep.
    sweep_started = time.time()
    for payload_name in config["payload"]:
        for user_count in config["num_users"]:
            for fraction in config["frac"]:
                main(
                    dataset=config["dataset"],
                    payload=payload_name,
                    num_users=user_count,
                    frac=fraction,
                    run_name=config["run_name"],
                )
    elapsed_seconds = time.time() - sweep_started
    print('\n Total Run Time: {0:0.4f}'.format(elapsed_seconds))
0 commit comments