# Federated learning with PyTorch, using the MNIST dataset
import math
from copy import deepcopy
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
class MConv2d(nn.Conv2d):
    """Conv2d extended with FedAvg-style aggregation bookkeeping.

    `agg_data` tracks how many training samples have been folded into the
    current parameters; `aggregate` computes a running weighted average so
    clients with more data contribute proportionally more.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super(MConv2d, self).__init__(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation=dilation,
                                      groups=groups, bias=bias, padding_mode=padding_mode)
        self.agg_data = 0  # running count of samples aggregated into this layer

    @torch.no_grad()
    def download(self, glb_conv):
        """Copy the global layer's parameters into this local layer."""
        self.weight.data[:] = glb_conv.weight.data
        # Bug fix: guard the bias copy — previously crashed when bias=False.
        if self.bias is not None:
            self.bias.data[:] = glb_conv.bias.data

    @torch.no_grad()
    def aggregate(self, oth_conv, oth_data):
        """Fold one client layer into this one, weighted by its sample count.

        This method is ONLY called on the global model. `oth_data` must be
        positive; call zero_agg_factor() before starting a fresh round.
        """
        all_data = self.agg_data + oth_data
        self.weight.data.mul_(self.agg_data / all_data)
        self.weight.data.add_(oth_conv.weight.data * (oth_data / all_data))
        # Bug fix: only touch oth_conv.bias when a bias exists (bias=False case).
        if self.bias is not None:
            self.bias.data.mul_(self.agg_data / all_data)
            self.bias.data.add_(oth_conv.bias.data * (oth_data / all_data))
        self.agg_data = all_data

    def zero_agg_factor(self):
        """Reset the aggregation counter (global model only)."""
        self.agg_data = 0
class MLinear(nn.Linear):
    """Linear layer extended with FedAvg-style aggregation bookkeeping.

    Mirrors MConv2d: `agg_data` counts the samples already folded into the
    parameters, and `aggregate` maintains a running weighted average.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(MLinear, self).__init__(in_features, out_features, bias)
        self.agg_data = 0  # running count of samples aggregated into this layer

    @torch.no_grad()
    def download(self, glb_lin):
        """Copy the global layer's parameters into this local layer."""
        self.weight.data[:] = glb_lin.weight.data
        # Bug fix: guard the bias copy — previously crashed when bias=False.
        if self.bias is not None:
            self.bias.data[:] = glb_lin.bias.data

    @torch.no_grad()
    def aggregate(self, oth_lin, oth_data):
        """Fold one client layer into this one, weighted by its sample count.

        This method is ONLY called on the global model. `oth_data` must be
        positive; call zero_agg_factor() before starting a fresh round.
        """
        all_data = self.agg_data + oth_data
        self.weight.data.mul_(self.agg_data / all_data)
        self.weight.data.add_(oth_lin.weight.data * (oth_data / all_data))
        # Bug fix: only touch oth_lin.bias when a bias exists (bias=False case).
        if self.bias is not None:
            self.bias.data.mul_(self.agg_data / all_data)
            self.bias.data.add_(oth_lin.bias.data * (oth_data / all_data))
        self.agg_data = all_data

    def zero_agg_factor(self):
        """Reset the aggregation counter (global model only)."""
        self.agg_data = 0
import torch
import torch.nn as nn
class Net(nn.Module):
    """Small CNN for MNIST: one 3x3 conv layer, 2x2 max-pool, then a linear
    classifier over the flattened 14x14x8 feature map.

    Both layers are federated-aware (MConv2d / MLinear), so the network can
    download global weights and aggregate local ones module by module.
    """

    def __init__(self, alpha=1):
        super(Net, self).__init__()
        self.conv1 = MConv2d(1, 8, 3, 1, 1)
        self.fc1 = MLinear(14 * 14 * 8, 10)
        # Ordered list of federated modules; download/aggregate walk it.
        self.net_module_list = [self.conv1, self.fc1]

    @torch.no_grad()
    def download(self, glb_net):
        """Copy the global model's parameters into this local model."""
        for glb_mod, loc_mod in zip(glb_net.net_module_list, self.net_module_list):
            loc_mod.download(glb_mod)

    @torch.no_grad()
    def aggregate(self, loc_net, loc_data_num):
        """Fold one client model into this (global) model, weighted by its
        number of training samples. ONLY called on the global model."""
        for glb_mod, loc_mod in zip(self.net_module_list, loc_net.net_module_list):
            glb_mod.aggregate(loc_mod, loc_data_num)

    def zero_agg_factor(self):
        """Reset every module's aggregation counter (global model only)."""
        for mod in self.net_module_list:
            mod.zero_agg_factor()

    def forward(self, x):
        out = self.conv1(x)
        out = F.max_pool2d(F.relu(out), kernel_size=2, stride=2)
        return self.fc1(torch.flatten(out, 1))
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
class DatasetSplit(Dataset):
    """A read-only view of a PyTorch Dataset restricted to a subset of
    indices, so many clients can share one underlying dataset object."""

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Normalize to plain ints (idxs may arrive as numpy integers).
        self.idxs = list(map(int, idxs))

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        # Detached copy so callers cannot mutate the shared dataset tensors.
        return image.clone().detach(), torch.tensor(label)
# Independent Identically Distributed(IID)
def get_iid_data(dataset, num_users):
    """Partition dataset indices uniformly at random into IID shards.

    Returns a dict mapping user id -> set of sample indices; each user
    receives len(dataset) // num_users indices drawn without replacement.
    """
    num_items = len(dataset) // num_users
    remaining = list(range(len(dataset)))
    dict_users = {}
    for uid in range(num_users):
        picked = set(np.random.choice(remaining, num_items, replace=False))
        dict_users[uid] = picked
        remaining = list(set(remaining) - picked)
    return dict_users
def get_noniid_data(dataset, num_users, alpha):# get non-IID-distribution data
    np.random.seed(0)  # fixed seed so the partition is reproducible across runs
    """
    dataset: the training set (e.g. MNIST/CIFAR)
    Approach:
    (1) Split the image *indices* rather than the images themselves, so all
        users share one copy of the underlying image storage.
    (2) For each class, divide its samples among users with proportions drawn
        from a Dirichlet(alpha) distribution.
    (3) A redraw safeguard keeps the non-IID split from being too extreme.
    """
    dict_users = {}
    num_classes = 10
    min_size = 0
    labels = np.array(dataset.targets)  # labels of the whole dataset
    num_items = int(len(dataset)/num_users)
    # Redraw the whole partition until every user has at least 10 samples.
    while min_size < 10:
        idx_groups = [[] for _ in range(num_users)]  # per-user index lists; these become dict_users
        # for each class in the dataset
        for k in range(num_classes):  # iterate over all labels; current label is k
            idx_k = np.where(labels == k)[0]  # indices of samples whose label is k
            np.random.shuffle(idx_k)  # shuffle those indices in place
            proportions = np.random.dirichlet(np.repeat(alpha, num_users))  # draw Dirichlet split proportions
            # Balance
            proportions = np.array(
                [p*(len(user_idx) < num_items) for p, user_idx in zip(proportions, idx_groups)]
            )  # safeguard: zero the share of users that already hold a full quota
            # (len(user_idx) < num_items) is True or False, i.e. 1 or 0
            proportions = proportions / proportions.sum()  # renormalize after zeroing
            proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]  # cut points for np.split below
            idx_groups = [user_idx + idx.tolist() for user_idx, idx in zip(idx_groups, np.split(idx_k, proportions))]  # distribute class-k indices to users
        min_size = min([len(user_idx) for user_idx in idx_groups])  # smallest user share this draw (safeguard check)
    for i in range(num_users):
        np.random.shuffle(idx_groups[i])
        dict_users[i] = idx_groups[i]
    return dict_users
def get_dataset_MNIST(data_dir, num_users, iid=True, alpha=0.5):
    """Load MNIST and split the training indices among federated clients.

    Args:
        data_dir: directory where MNIST is stored / downloaded to.
        num_users: number of federated clients.
        iid: if True use a uniform IID split, otherwise a Dirichlet non-IID split.
        alpha: Dirichlet concentration for the non-IID split (ignored when iid).

    Returns:
        (train_dataset, test_dataset, user_groups) where user_groups maps
        client id -> collection of training-sample indices.
    """
    # Train and test were built with identical pipelines; share one transform
    # (standard MNIST mean/std normalization). Also drops a stray semicolon.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                   transform=transform)
    test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                  transform=transform)
    if iid:
        user_groups = get_iid_data(train_dataset, num_users)
    else:
        user_groups = get_noniid_data(train_dataset, num_users, alpha)
    return train_dataset, test_dataset, user_groups
import os
import copy
import random as rn
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pickle
# Make cuDNN deterministic so runs with the same seed are reproducible
# (at the cost of disabling cuDNN's autotuned, faster kernels).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # restrict all CUDA work to GPU 0
class Client:
    """A federated-learning client holding a local model and a data shard.

    The dataset / hyper-parameter attributes are class-level and must be
    assigned once (e.g. in main()) before any Client is instantiated.
    """

    train_set = None     # shared training Dataset (set externally)
    test_set = None      # shared test Dataset (set externally)
    local_epoch = None   # local epochs per communication round
    batch_size = None    # local training batch size
    lr = None            # initial learning rate
    lr_decay = 0.996     # per-round multiplicative learning-rate decay
    # Robustness fix: fall back to CPU so the code also runs without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    criterion = nn.CrossEntropyLoss()

    def __init__(self, train_idx, loc_model, client_id, test_dataset):
        self.trainloader = DataLoader(DatasetSplit(self.train_set, train_idx),
                                      batch_size=self.batch_size, shuffle=True)
        self.testloader = DataLoader(test_dataset, batch_size=100, shuffle=False)
        self.client_id = client_id
        self.num_data = len(train_idx)  # used as this client's aggregation weight
        self.loc_model = loc_model

    def local_update(self, cr):
        """Train the local model for `local_epoch` epochs in round `cr`."""
        self.loc_model.to(self.device)
        lr = self.lr * (self.lr_decay ** cr)  # decayed learning rate for this round
        optimizer = torch.optim.SGD(self.loc_model.parameters(), lr=lr, momentum=0.9)
        self.loc_model.train()
        for ep in range(self.local_epoch):
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()  # idiom: clear grads via the optimizer
                output = self.loc_model(images)
                loss = self.criterion(output, labels)
                loss.backward()
                optimizer.step()

    def download(self, glb_model):
        """Replace the local parameters with the global model's."""
        self.loc_model.download(glb_model)

    def eval_test(self):
        """Evaluate the local model on the test set.

        Returns:
            (accuracy_percent, scaled_test_loss)
        """
        test_loss = 0
        correct = 0
        total = 0
        self.loc_model.eval()
        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(self.testloader):
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.loc_model(images)
                loss = self.criterion(outputs, labels)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        acc = 100.*correct/total
        # NOTE(review): this divides the sum of per-batch MEAN losses by the
        # sample count, so it is a scaled (not exact) average loss; kept
        # as-is so reported numbers stay comparable with earlier runs.
        test_loss /= total
        print('\n client id = %d, Test acc is %.4f'%(self.client_id , acc))
        return acc, test_loss
class Server:
    """Federated-learning server: owns the global model and drives rounds."""

    def __init__(self, clients, test_dataset):
        self.clients = clients
        self.global_model = Net()
        self.testloader = DataLoader(test_dataset, batch_size=100, shuffle=False)
        self.criterion = nn.CrossEntropyLoss()
        # Robustness fix: fall back to CPU when no GPU is available.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def distribute(self):
        """Send the global model to every client."""
        self.global_model.zero_agg_factor()
        for c in self.clients:
            c.download(self.global_model)

    def model_update(self, cr):
        """Run local training on every client for round `cr`."""
        for c in self.clients:
            c.local_update(cr)
            print("\rComm round %d: updating in client %d " % (cr, c.client_id), end='')

    def agg(self, cr, info):
        """Aggregate all client models into the global model (FedAvg-style,
        weighted by each client's number of training samples)."""
        idx_list = list(range(len(self.clients)))
        self.global_model.zero_agg_factor()
        self.global_model.to(self.device)
        for c_idx in idx_list:
            self.global_model.aggregate(self.clients[c_idx].loc_model, self.clients[c_idx].num_data)
            print("\rComm round %d: aggregated client %d " % (cr, c_idx), end='')

    def eval_test(self):
        """Evaluate the global model on the test set.

        Returns:
            (accuracy_percent, scaled_test_loss)
        """
        test_loss = 0
        correct = 0
        total = 0
        self.global_model.eval()
        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(self.testloader):
                images, labels = images.to(self.device), labels.to(self.device)
                # BUG FIX: outputs were sliced with [:5], truncating the batch
                # to 5 rows against 100 labels — use the full batch output.
                outputs = self.global_model(images)
                loss = self.criterion(outputs, labels)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        acc = 100.*correct/total
        test_loss /= total
        print('\n Test acc is ', acc)
        return acc, test_loss

    def fedlearn(self, total_cr, info):
        """Run `total_cr` communication rounds.

        Returns:
            (acc_list, test_loss_list) — per-round global-model metrics.
            (Previously the lists were collected but never returned.)
        """
        acc_list = []
        test_loss_list = []
        for cr in range(total_cr):
            self.distribute()      # server sends the global model to all clients
            self.model_update(cr)  # clients train locally
            self.agg(cr, info)     # clients upload; server aggregates
            acc, test_loss = self.eval_test()  # evaluate the new global model
            acc_list.append(acc)
            test_loss_list.append(test_loss)
        return acc_list, test_loss_list
def main():
    """Entry point: seed RNGs, build 100 clients from a non-IID MNIST split,
    and run the federated-learning loop."""
    seed = 0
    rn.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    num_users = 100
    data_dir = './dataset/MNIST'
    # non-IID data distribution (iid=False -> Dirichlet split, default alpha=0.5)
    # NOTE(review): the original comment claimed IID, contradicting iid=False.
    train_dataset, test_dataset, dict_users = get_dataset_MNIST(data_dir, num_users, iid=False)
    Client.train_set = train_dataset
    Client.test_set = test_dataset
    clients = []
    # local epochs, batch size, initial lr, number of communication rounds
    epoch, batch_size, lr, cr = 2, 64, 0.05, 10
    for ir in range(1,2):  # experiments usually average several runs; only one run here
        info = (epoch, batch_size, lr, ir)
        print("***********************************************************************")
        print("******* ir = %d, epoch=%d, batch_size=%d, lr=%.4f starting ***********"%(info[3], info[0],info[1],info[2]))  # hyper-parameter summary
        print("***********************************************************************")
        Client.local_epoch = epoch
        Client.batch_size = batch_size
        Client.lr = lr
        for user_idx in range(num_users):
            clients.append(Client(dict_users[user_idx], Net(), user_idx, test_dataset))
        for i in range(num_users):
            print("clients[%d].num_data = "%(i), clients[i].num_data)
        server = Server(clients, test_dataset)
        server.fedlearn(cr, info)
        clients.clear()  # drop this run's clients before the next repetition
main()
***********************************************************************
******* ir = 1, epoch=2, batch_size=64, lr=0.0500 starting ***********
***********************************************************************
clients[0].num_data = 534
clients[1].num_data = 787
clients[2].num_data = 921
clients[3].num_data = 664
clients[4].num_data = 647
clients[5].num_data = 693
clients[6].num_data = 456
clients[7].num_data = 290
clients[8].num_data = 659
clients[9].num_data = 632
clients[10].num_data = 787
clients[11].num_data = 706
clients[12].num_data = 544
clients[13].num_data = 587
clients[14].num_data = 667
clients[15].num_data = 711
clients[16].num_data = 865
clients[17].num_data = 577
clients[18].num_data = 620
clients[19].num_data = 594
clients[20].num_data = 545
clients[21].num_data = 510
clients[22].num_data = 528
clients[23].num_data = 1263
clients[24].num_data = 600
clients[25].num_data = 633
clients[26].num_data = 718
clients[27].num_data = 1180
clients[28].num_data = 724
clients[29].num_data = 625
clients[30].num_data = 688
clients[31].num_data = 465
clients[32].num_data = 296
clients[33].num_data = 753
clients[34].num_data = 974
clients[35].num_data = 221
clients[36].num_data = 416
clients[37].num_data = 519
clients[38].num_data = 499
clients[39].num_data = 382
clients[40].num_data = 159
clients[41].num_data = 609
clients[42].num_data = 469
clients[43].num_data = 453
clients[44].num_data = 256
clients[45].num_data = 857
clients[46].num_data = 554
clients[47].num_data = 655
clients[48].num_data = 354
clients[49].num_data = 223
clients[50].num_data = 361
clients[51].num_data = 774
clients[52].num_data = 329
clients[53].num_data = 287
clients[54].num_data = 728
clients[55].num_data = 814
clients[56].num_data = 678
clients[57].num_data = 610
clients[58].num_data = 973
clients[59].num_data = 360
clients[60].num_data = 1262
clients[61].num_data = 680
clients[62].num_data = 464
clients[63].num_data = 698
clients[64].num_data = 778
clients[65].num_data = 609
clients[66].num_data = 498
clients[67].num_data = 704
clients[68].num_data = 212
clients[69].num_data = 95
clients[70].num_data = 563
clients[71].num_data = 802
clients[72].num_data = 805
clients[73].num_data = 589
clients[74].num_data = 698
clients[75].num_data = 439
clients[76].num_data = 542
clients[77].num_data = 362
clients[78].num_data = 1073
clients[79].num_data = 654
clients[80].num_data = 624
clients[81].num_data = 715
clients[82].num_data = 629
clients[83].num_data = 642
clients[84].num_data = 308
clients[85].num_data = 370
clients[86].num_data = 563
clients[87].num_data = 334
clients[88].num_data = 513
clients[89].num_data = 1024
clients[90].num_data = 640
clients[91].num_data = 349
clients[92].num_data = 742
clients[93].num_data = 443
clients[94].num_data = 782
clients[95].num_data = 881
clients[96].num_data = 453
clients[97].num_data = 354
clients[98].num_data = 617
clients[99].num_data = 507
Comm round 0: aggregated client 99
Test acc is 71.01
Comm round 1: aggregated client 99
Test acc is 83.05
Comm round 2: aggregated client 99
Test acc is 91.71
Comm round 3: aggregated client 99
Test acc is 91.13
Comm round 4: aggregated client 99
Test acc is 93.64
Comm round 5: aggregated client 99
Test acc is 93.88
Comm round 6: aggregated client 99
Test acc is 94.51
Comm round 7: aggregated client 99
Test acc is 94.78
Comm round 8: aggregated client 99
Test acc is 95.06
Comm round 9: aggregated client 99
Test acc is 95.26
Question: how can only the first five clients be made to participate in local training?