ocklen 2021-11-28 19:20

In federated learning, how can I make the first five clients skip the federated process and train only on their own local data?

PyTorch-based federated learning, using the MNIST dataset.

import os

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class MConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super(MConv2d, self).__init__(in_channels, out_channels, kernel_size, 
                                      stride, padding, dilation=dilation, 
                                      groups=groups, bias=bias, padding_mode=padding_mode)
        self.agg_data = 0
    
    @torch.no_grad()
    def download(self, glb_conv):
        self.weight.data[:] = glb_conv.weight.data
        if self.bias is not None:  # guard, matching the check in aggregate()
            self.bias.data[:] = glb_conv.bias.data
           
    @torch.no_grad()
    def aggregate(self, oth_conv, oth_data):
        """This method is ONLY called by the global model!!
        """
        loc_weight, loc_bias = oth_conv.weight.data, oth_conv.bias.data
        all_data = self.agg_data + oth_data
        self.weight.data.mul_(self.agg_data / all_data)
        self.weight.data.add_(loc_weight * (oth_data / all_data))
        if self.bias is not None:
            self.bias.data.mul_(self.agg_data / all_data)
            self.bias.data.add_(loc_bias * (oth_data / all_data))
        self.agg_data = all_data
    
    def zero_agg_factor(self):
        """This method is ONLY called by the global model!!
        """
        self.agg_data = 0


class MLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super(MLinear, self).__init__(in_features, out_features, bias)
        self.agg_data = 0

    @torch.no_grad()
    def download(self, glb_lin):
        self.weight.data[:] = glb_lin.weight.data
        if self.bias is not None:  # guard, matching the check in aggregate()
            self.bias.data[:] = glb_lin.bias.data
    
    @torch.no_grad()
    def aggregate(self, oth_lin, oth_data):
        """This method is ONLY called by the global model!!
        """
        loc_weight, loc_bias = oth_lin.weight.data, oth_lin.bias.data
        all_data = self.agg_data + oth_data
        self.weight.data.mul_(self.agg_data / all_data)
        self.weight.data.add_(loc_weight * (oth_data / all_data))
        if self.bias is not None:
            self.bias.data.mul_(self.agg_data / all_data)
            self.bias.data.add_(loc_bias * (oth_data / all_data))
        self.agg_data = all_data
    
    def zero_agg_factor(self):
        """This method is ONLY called by the global model!!
        """
        self.agg_data = 0
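
# Added note (not in the original post): the running update in aggregate()
# is an incremental form of the FedAvg weighted average: after folding in
# clients with sample counts n_1..n_k, the weights equal
# sum(n_i * w_i) / sum(n_i). The sketch below checks that with made-up
# tensors and counts; it is defined but never called by the training script.
def _demo_incremental_fedavg():
    client_weights = [torch.randn(3, 3) for _ in range(4)]  # made-up client weights
    client_counts = [100, 250, 80, 570]                     # made-up sample counts
    agg, agg_data = torch.zeros(3, 3), 0
    for w, n in zip(client_weights, client_counts):
        all_data = agg_data + n
        agg = agg * (agg_data / all_data) + w * (n / all_data)  # same math as aggregate()
        agg_data = all_data
    expected = sum(w * n for w, n in zip(client_weights, client_counts)) / sum(client_counts)
    assert torch.allclose(agg, expected, atol=1e-5)
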

class Net(nn.Module):
    def __init__(self, alpha=1):
        super(Net, self).__init__()
        self.conv1 = MConv2d(1, 8, 3, 1, 1)
        self.fc1   = MLinear(14*14*8, 10)
        self.net_module_list = [self.conv1, self.fc1]

    
    @torch.no_grad()
    def download(self, glb_net):
        for glb_module, loc_module in zip(glb_net.net_module_list, self.net_module_list):
            loc_module.download(glb_module)
            
    @torch.no_grad()
    def aggregate(self, loc_net, loc_data_num):
        """This method CAN ONLY called by the global model!!
        """
        for glb_module, loc_module in zip(self.net_module_list, loc_net.net_module_list):
            glb_module.aggregate(loc_module, loc_data_num)
    
    def zero_agg_factor(self):
        for module in self.net_module_list:
            module.zero_agg_factor()
    

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x
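
# Added note (not in the original post): a quick shape check. The 3x3 conv
# with padding 1 keeps the 28x28 MNIST input size, the 2x2 max-pool halves
# it to 14x14, so the flattened feature has 14*14*8 = 1568 elements,
# matching MLinear(14*14*8, 10). Defined but never called automatically.
def _demo_net_shapes():
    x = torch.randn(1, 1, 28, 28)     # (batch, channels, height, width)
    assert Net()(x).shape == (1, 10)  # one logit per digit class
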
from torch.utils.data import Dataset
from torchvision import datasets, transforms


class DatasetSplit(Dataset):
    """An abstract Dataset class wrapped around Pytorch Dataset class.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image.clone().detach(), torch.tensor(label)

# Independent, identically distributed (IID)
def get_iid_data(dataset, num_users):  # get an IID data partition
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items,
                                             replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users

def get_noniid_data(dataset, num_users, alpha):  # get a non-IID data partition
    """Partition the MNIST training set across `num_users` clients, non-IID.

    The approach, roughly:
        (1) Partition indices rather than the images themselves, so all
            clients share a single copy of the underlying data.
        (2) Split each class among the clients, with the per-client
            proportions drawn from a Dirichlet distribution.
        (3) Apply a floor mechanism so the non-IID split never gets too
            extreme (see the retry loop below).
    """
    np.random.seed(0)
    dict_users = {}
    num_classes = 10
    min_size = 0
    labels = np.array(dataset.targets)              # labels of the whole dataset
    num_items = int(len(dataset)/num_users)

    while min_size < 10:
        idx_groups = [[] for _ in range(num_users)] # one index list per client; this is what gets returned
        # for each class in the dataset
        for k in range(num_classes):                # current class label is k
            idx_k = np.where(labels == k)[0]        # indices of all samples with label k
            np.random.shuffle(idx_k)                # shuffle the index list idx_k
            proportions = np.random.dirichlet(np.repeat(alpha, num_users))  # sample split proportions from a Dirichlet distribution

            # Balance: zero the proportion of any client that already holds its
            # fair share; (len(user_idx) < num_items) is True/False, i.e. 1 or 0
            proportions = np.array(
                [p*(len(user_idx) < num_items) for p, user_idx in zip(proportions, idx_groups)]
                )
            proportions = proportions / proportions.sum()  # renormalize after zeroing
            proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]  # cumulative split points, in the form np.split expects
            idx_groups = [user_idx + idx.tolist() for user_idx, idx in zip(idx_groups, np.split(idx_k, proportions))]  # distribute the indices of class k to idx_groups
        min_size = min([len(user_idx) for user_idx in idx_groups])  # floor mechanism: retry if any client got fewer than 10 samples
    for i in range(num_users):
        np.random.shuffle(idx_groups[i])
        dict_users[i] = idx_groups[i]
    return dict_users

def get_dataset_MNIST(data_dir, num_users, iid=True, alpha=0.5):
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                     transform=transform_train)
    test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                    transform=transform_test)
    
    if iid:
        user_groups = get_iid_data(train_dataset, num_users)
    else:
        user_groups = get_noniid_data(train_dataset, num_users, alpha)

    return train_dataset, test_dataset, user_groups
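
# Added note (not in the original post): a sketch for inspecting how non-IID
# the Dirichlet split is, by printing per-client label histograms. It reuses
# the './dataset/MNIST' path from main() below (MNIST is downloaded on first
# use) and is defined but never called by the training script.
def _demo_inspect_split(num_users=100, alpha=0.5):
    train_dataset, _, user_groups = get_dataset_MNIST(
        './dataset/MNIST', num_users, iid=False, alpha=alpha)
    labels = np.array(train_dataset.targets)
    for uid in range(3):  # look at the first three clients only
        idxs = list(user_groups[uid])
        counts = np.bincount(labels[idxs], minlength=10)
        print('client %d: %d samples, label counts %s' % (uid, len(idxs), counts))
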

import random as rn
from torch.utils.data import DataLoader


torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

os.environ["CUDA_VISIBLE_DEVICES"] = '0'


class Client:
    train_set = None
    test_set = None
    local_epoch = None
    batch_size = None
    lr = None
    lr_decay = 0.996
    device = 'cuda'
    criterion = nn.CrossEntropyLoss()

    def __init__(self, train_idx, loc_model, client_id, test_dataset):
        self.trainloader = DataLoader(DatasetSplit(self.train_set, train_idx), 
                                      batch_size=self.batch_size, shuffle=True)
        self.testloader = DataLoader(test_dataset, batch_size=100, shuffle=False)
        self.client_id = client_id
        self.num_data = len(train_idx)
        self.loc_model = loc_model

    def local_update(self, cr):
        self.loc_model.to(self.device)
        lr = self.lr * (self.lr_decay ** cr)
        optimizer = torch.optim.SGD(self.loc_model.parameters(), lr=lr, momentum=0.9)
        self.loc_model.train()
        for ep in range(self.local_epoch):
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)
                self.loc_model.zero_grad()
                output = self.loc_model(images)
                loss = self.criterion(output, labels)
                loss.backward()
                optimizer.step()
    
    def download(self, glb_model):
        self.loc_model.download(glb_model)
    
    def eval_test(self):
        test_loss = 0
        correct = 0
        total = 0
        self.loc_model.eval()
        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(self.testloader):
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.loc_model(images)
                loss = self.criterion(outputs, labels)
                test_loss += loss.item() * labels.size(0)  # sum of per-sample losses
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        acc = 100.*correct/total
        test_loss /= total
        print('\n client id = %d, Test acc is %.4f'%(self.client_id , acc))
        return acc, test_loss
    
    
    

class Server:
    def __init__(self, clients, test_dataset):
        self.clients = clients
        self.global_model = Net()
        self.testloader = DataLoader(test_dataset, batch_size=100, shuffle=False)
        self.criterion = nn.CrossEntropyLoss()
        self.device = 'cuda'
    
    def distribute(self):
        self.global_model.zero_agg_factor()
        for c in self.clients:
            c.download(self.global_model)
        
    def model_update(self, cr):
        for c in self.clients:
            c.local_update(cr)
            print("\rComm round %d: updating in client %d   " % (cr, c.client_id), end='')
            

    def agg(self, cr, info):
        idx_list = list(range(len(self.clients)))
        self.global_model.zero_agg_factor()
        self.global_model.to(self.device)
        for c_idx in idx_list:
            self.global_model.aggregate(self.clients[c_idx].loc_model, self.clients[c_idx].num_data)
            print("\rComm round %d: aggregated client %d   " % (cr, c_idx), end='')
    
    def eval_test(self):
        test_loss = 0
        correct = 0
        total = 0
        self.global_model.eval()
        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(self.testloader):
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.global_model(images)
                loss = self.criterion(outputs, labels)
                test_loss += loss.item() * labels.size(0)  # sum of per-sample losses
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        acc = 100.*correct/total
        test_loss /= total
        print('\n Test acc is ', acc)
        return acc, test_loss
        
    def fedlearn(self, total_cr, info):
        acc_list = []
        test_loss_list = []
        for cr in range(total_cr):
            self.distribute()                 # server sends the global model to every client
            self.model_update(cr)             # each client trains locally
            self.agg(cr, info)                # clients upload; the server aggregates
            acc, test_loss = self.eval_test() # evaluate the aggregated global model
            acc_list.append(acc)
            test_loss_list.append(test_loss)
        return acc_list, test_loss_list
            
def main():
    seed = 0
    rn.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    num_users = 100
    data_dir = './dataset/MNIST'
    # non-IID data distribution (Dirichlet split; alpha defaults to 0.5)
    train_dataset, test_dataset, dict_users = get_dataset_MNIST(data_dir, num_users, iid=False)
    Client.train_set = train_dataset
    Client.test_set = test_dataset
    clients = []
    epoch, batch_size, lr, cr = 2, 64, 0.05, 10
    for ir in range(1,2): # experiments are usually averaged over several runs; here we run just once
        info = (epoch, batch_size, lr, ir)
        print("***********************************************************************")
        print("******* ir = %d, epoch=%d, batch_size=%d, lr=%.4f starting ***********"%(info[3], info[0],info[1],info[2])) # 超参设定
        print("***********************************************************************")
        Client.local_epoch = epoch
        Client.batch_size = batch_size
        Client.lr = lr
        for user_idx in range(num_users):
            clients.append(Client(dict_users[user_idx], Net(), user_idx, test_dataset))
        for i in range(num_users):
            print("clients[%d].num_data = "%(i), clients[i].num_data)
        server = Server(clients, test_dataset)
        server.fedlearn(cr, info)
        clients.clear()           
  
if __name__ == '__main__':
    main()

***********************************************************************
******* ir = 1, epoch=2, batch_size=64, lr=0.0500 starting ***********
***********************************************************************
clients[0].num_data =  534
clients[1].num_data =  787
clients[2].num_data =  921
clients[3].num_data =  664
clients[4].num_data =  647
clients[5].num_data =  693
clients[6].num_data =  456
clients[7].num_data =  290
clients[8].num_data =  659
clients[9].num_data =  632
clients[10].num_data =  787
clients[11].num_data =  706
clients[12].num_data =  544
clients[13].num_data =  587
clients[14].num_data =  667
clients[15].num_data =  711
clients[16].num_data =  865
clients[17].num_data =  577
clients[18].num_data =  620
clients[19].num_data =  594
clients[20].num_data =  545
clients[21].num_data =  510
clients[22].num_data =  528
clients[23].num_data =  1263
clients[24].num_data =  600
clients[25].num_data =  633
clients[26].num_data =  718
clients[27].num_data =  1180
clients[28].num_data =  724
clients[29].num_data =  625
clients[30].num_data =  688
clients[31].num_data =  465
clients[32].num_data =  296
clients[33].num_data =  753
clients[34].num_data =  974
clients[35].num_data =  221
clients[36].num_data =  416
clients[37].num_data =  519
clients[38].num_data =  499
clients[39].num_data =  382
clients[40].num_data =  159
clients[41].num_data =  609
clients[42].num_data =  469
clients[43].num_data =  453
clients[44].num_data =  256
clients[45].num_data =  857
clients[46].num_data =  554
clients[47].num_data =  655
clients[48].num_data =  354
clients[49].num_data =  223
clients[50].num_data =  361
clients[51].num_data =  774
clients[52].num_data =  329
clients[53].num_data =  287
clients[54].num_data =  728
clients[55].num_data =  814
clients[56].num_data =  678
clients[57].num_data =  610
clients[58].num_data =  973
clients[59].num_data =  360
clients[60].num_data =  1262
clients[61].num_data =  680
clients[62].num_data =  464
clients[63].num_data =  698
clients[64].num_data =  778
clients[65].num_data =  609
clients[66].num_data =  498
clients[67].num_data =  704
clients[68].num_data =  212
clients[69].num_data =  95
clients[70].num_data =  563
clients[71].num_data =  802
clients[72].num_data =  805
clients[73].num_data =  589
clients[74].num_data =  698
clients[75].num_data =  439
clients[76].num_data =  542
clients[77].num_data =  362
clients[78].num_data =  1073
clients[79].num_data =  654
clients[80].num_data =  624
clients[81].num_data =  715
clients[82].num_data =  629
clients[83].num_data =  642
clients[84].num_data =  308
clients[85].num_data =  370
clients[86].num_data =  563
clients[87].num_data =  334
clients[88].num_data =  513
clients[89].num_data =  1024
clients[90].num_data =  640
clients[91].num_data =  349
clients[92].num_data =  742
clients[93].num_data =  443
clients[94].num_data =  782
clients[95].num_data =  881
clients[96].num_data =  453
clients[97].num_data =  354
clients[98].num_data =  617
clients[99].num_data =  507

Comm round 0: aggregated client 99    
 Test acc is  71.01
Comm round 1: aggregated client 99    
 Test acc is  83.05
Comm round 2: aggregated client 99    
 Test acc is  91.71
Comm round 3: aggregated client 99    
 Test acc is  91.13
Comm round 4: aggregated client 99    
 Test acc is  93.64
Comm round 5: aggregated client 99    
 Test acc is  93.88
Comm round 6: aggregated client 99    
 Test acc is  94.51
Comm round 7: aggregated client 99    
 Test acc is  94.78
Comm round 8: aggregated client 99    
 Test acc is  95.06
Comm round 9: aggregated client 99    
 Test acc is  95.26 

How can the first five clients be made to train only locally?
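One way to do it, as a hedged sketch (this is an addition, not from the original post): keep `Client` and the training loop unchanged and subclass `Server` so the first few clients are skipped in `distribute` and `agg`. The inherited `model_update` still iterates over all clients, so clients 0 to 4 keep running `local_update` every round; they simply never download from or upload to the server, and their models evolve purely on their own local data. The class name `PartialServer` and the `num_local` parameter are introduced here for illustration:

class PartialServer(Server):
    """Server variant in which the first num_local clients train purely locally."""
    def __init__(self, clients, test_dataset, num_local=5):
        super().__init__(clients, test_dataset)
        self.num_local = num_local  # clients [0, num_local) never sync with the server

    def distribute(self):
        self.global_model.zero_agg_factor()
        for c in self.clients[self.num_local:]:  # skip the local-only clients
            c.download(self.global_model)

    def agg(self, cr, info):
        self.global_model.zero_agg_factor()
        self.global_model.to(self.device)
        for c in self.clients[self.num_local:]:  # aggregate federated clients only
            self.global_model.aggregate(c.loc_model, c.num_data)
            print("\rComm round %d: aggregated client %d   " % (cr, c.client_id), end='')

In main(), replace `server = Server(clients, test_dataset)` with `server = PartialServer(clients, test_dataset, num_local=5)`. To compare the purely local clients against the federated ones, call `clients[i].eval_test()` for `i < 5` after `server.fedlearn(cr, info)` returns.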
