
Using Differential Privacy with OPACUS on Fed-BioMed

In this notebook we show how Opacus (https://opacus.ai/), a library for training PyTorch models with differential privacy, can be used in Fed-BioMed. We will train the basic MNIST example using two nodes.

Setting up Fed-BioMed Environment

Start the network

Before running this notebook, start the network with ./scripts/fedbiomed_run network

Setting the node up

You first need to configure a node:

  1. ./scripts/fedbiomed_run node add

    • Select option 2 (default)
    • Confirm default tags by hitting "y" and ENTER
    • Pick the folder where MNIST is already downloaded (this is due to a torch issue: https://github.com/pytorch/vision/issues/3549)
    • Your data should now have been added (if you get a warning saying that data must be unique, it is because it has already been added)
  2. Check that your data has been added by executing ./scripts/fedbiomed_run node list

  3. Run the node using ./scripts/fedbiomed_run node run. Wait until you get Starting task manager; it means you are online.

Define a model and parameters

Declare a torch.nn MyTrainingPlan class to send for training on the node

In [ ]:
from fedbiomed.researcher.environ import environ
import tempfile
tmp_dir_model = tempfile.TemporaryDirectory(dir=environ['TMP_DIR']+'/')
model_file = tmp_dir_model.name + '/class_export_mnist.py'

In the cell below, we define the model using Opacus for differential privacy. For this example, we use the make_private method of the opacus PrivacyEngine class. Two hyperparameters should be defined:

  • noise_multiplier: The ratio of the standard deviation of the Gaussian noise to the L2-sensitivity of the function to which the noise is added (i.e. how much noise to add)
  • max_grad_norm: The maximum norm of the per-sample gradients. Any gradient with norm higher than this will be clipped to this value.

Note that, in order to use the Opacus PrivacyEngine class, the training plan must properly define a model, a data loader and an optimizer as attributes.
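As an illustration of this requirement, here is a minimal standalone sketch (not part of the Fed-BioMed training plan; the toy model, dataset and parameter values are placeholders only) of how PrivacyEngine.make_private wraps an existing model, optimizer and data loader:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

# Toy model, optimizer and data loader (placeholders for illustration only)
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
data_loader = DataLoader(dataset, batch_size=8)

# make_private returns wrapped versions of the three objects: per-sample
# gradients are clipped to max_grad_norm and Gaussian noise scaled by
# noise_multiplier is added before each optimizer step
privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)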

In [ ]:
%%writefile "$model_file"

import torch
import torch.nn as nn
import torch.nn.functional as F
from fedbiomed.common.torchnn import TorchTrainingPlan
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from opacus import PrivacyEngine 

# Here we define the model to be used. 
# You can use any class name (here 'Net')
class MyTrainingPlan(TorchTrainingPlan):
    def __init__(self, model_args):
        super(MyTrainingPlan, self).__init__()
        
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        # In this case, we need the torch DataLoader classes
        # Since we will train on MNIST, we need datasets and transform from torchvision
        deps = ["from torchvision import datasets, transforms",
                "import torch.nn.functional as F",
                "from torch.utils.data import DataLoader",
                "from opacus import PrivacyEngine",]
        self.add_dependency(deps)
        
        self.model = self.make_model()
        
        self.noise_multiplier = model_args['noise_multiplier']
        self.max_grad_norm = model_args['max_grad_norm']
        
    def make_model(self):
        model = nn.Sequential(nn.Conv2d(1, 32, 3, 1),
                                  nn.ReLU(),
                                  nn.Conv2d(32, 64, 3, 1),
                                  nn.ReLU(),
                                  nn.MaxPool2d(2),
                                  nn.Dropout(0.25),
                                  nn.Flatten(),
                                  nn.Linear(9216, 128),
                                  nn.ReLU(),
                                  nn.Dropout(0.5),
                                  nn.Linear(128, 10),
                                  nn.LogSoftmax(dim=1))
        return model
        
    def forward(self, x):
        return self.model(x)

    def training_data(self, batch_size = 48):
        # Custom torch Dataloader for MNIST data
        transform = transforms.Compose([transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST(self.dataset_path, train=True, download=False, transform=transform)
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        data_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
        
        # enter PrivacyEngine
        privacy_engine = PrivacyEngine()
        self.model, self.optimizer, data_loader = privacy_engine.make_private(module=self.model,
                                                                    optimizer=self.optimizer,
                                                                    data_loader=data_loader,
                                                                    noise_multiplier=self.noise_multiplier,
                                                                    max_grad_norm=self.max_grad_norm,
                                                                    )
        return data_loader
    
    def training_step(self, data, target):
        output = self.forward(data)
        loss   = torch.nn.functional.nll_loss(output, target)
        return loss

    def postprocess(self, params):
        # params keys are changed by the privacy engine (as _module.param_key): should be re-named
        params_keys = list(params.keys())
        for key in params_keys:
            if '_module' in key:
                newkey = key.replace('_module.', '')
                params[newkey] = params.pop(key)
        return params
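To see why the postprocess method above is needed: once make_private has wrapped the model, the parameter names in its state_dict gain a _module. prefix (Opacus stores the original network inside its wrapper module). The hypothetical snippet below, with made-up key names, sketches the renaming that postprocess performs so that the aggregator receives the original parameter names:

# Hypothetical illustration of the renaming done by postprocess()
# (key names are made up; values would normally be tensors)
wrapped_params = {'_module.0.weight': None, '_module.0.bias': None}
restored = {key.replace('_module.', ''): value for key, value in wrapped_params.items()}
print(list(restored.keys()))  # ['0.weight', '0.bias']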
%%writefile "$model_file" import torch import torch.nn as nn import torch.nn.functional as F from fedbiomed.common.torchnn import TorchTrainingPlan from torch.utils.data import DataLoader from torchvision import datasets, transforms from opacus import PrivacyEngine # Here we define the model to be used. # You can use any class name (here 'Net') class MyTrainingPlan(TorchTrainingPlan): def __init__(self, model_args): super(MyTrainingPlan, self).__init__() # Here we define the custom dependencies that will be needed by our custom Dataloader # In this case, we need the torch DataLoader classes # Since we will train on MNIST, we need datasets and transform from torchvision deps = ["from torchvision import datasets, transforms", "import torch.nn.functional as F", "from torch.utils.data import DataLoader", "from opacus import PrivacyEngine",] self.add_dependency(deps) self.model = self.make_model() self.noise_multiplier = model_args['noise_multiplier'] self.max_grad_norm = model_args['max_grad_norm'] def make_model(self): model = nn.Sequential(nn.Conv2d(1, 32, 3, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1), nn.ReLU(), nn.MaxPool2d(2), nn.Dropout(0.25), nn.Flatten(), nn.Linear(9216, 128), nn.ReLU(), nn.Dropout(0.5), nn.Linear(128, 10), nn.LogSoftmax(dim=1)) return model def forward(self, x): return self.model(x) def training_data(self, batch_size = 48): # Custom torch Dataloader for MNIST data transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset1 = datasets.MNIST(self.dataset_path, train=True, download=False, transform=transform) train_kwargs = {'batch_size': batch_size, 'shuffle': True} data_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) # enter PrivacyEngine privacy_engine = PrivacyEngine() self.model, self.optimizer, data_loader = privacy_engine.make_private(module=self.model, optimizer=self.optimizer, data_loader=data_loader, noise_multiplier=self.noise_multiplier, max_grad_norm=self.max_grad_norm, ) return data_loader def training_step(self, data, target): output = self.forward(data) loss = torch.nn.functional.nll_loss(output, target) return loss def postprocess(self, params): # params keys are changed by the privacy engine (as _module.param_key): should be re-named params_keys = list(params.keys()) for key in params_keys: if '_module' in key: newkey = key.replace('_module.', '') params[newkey] = params.pop(key) return params

These groups of arguments correspond respectively to:

  • model_args: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the node side. For instance, the privacy parameters should be passed here.
  • training_args: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the node side.

NOTE: typos and/or missing positional (required) arguments will raise an error. 🤓

In [ ]:
model_args = {'noise_multiplier':1., 'max_grad_norm':1.0}

training_args = {
    'batch_size': 48, 
    'lr': 1e-3, 
    'epochs': 3, 
    'dry_run': False,  
    'batch_maxnum': 250 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}

Declare and run the experiment

  • search for nodes serving data for these tags, optionally filtering on a list of node IDs with nodes
  • run a round of local training on the nodes with the model defined in model_path, followed by federation with aggregator
  • run for rounds rounds, applying the node_selection_strategy between the rounds
In [ ]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags =  ['#MNIST', '#dataset']
rounds = 3

exp = Experiment(tags=tags,
                 #nodes=None,
                 model_path=model_file,
                 model_args=model_args,
                 model_class='MyTrainingPlan',
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

Let's start the experiment.

By default, this function doesn't stop until all the rounds are done for all the nodes.

In [ ]:
exp.run()

Local training results for each round and each node are available in exp.training_replies() (index 0 to (rounds - 1)).

For example you can view the training results for the last round below.

Different timings (in seconds) are reported for each dataset of a node participating in a round:

  • rtime_training: real time (clock time) spent in the training function on the node
  • ptime_training: process time (user and system CPU) spent in the training function on the node
  • rtime_total: real time (clock time) spent in the researcher between sending the request and handling the response, at the Job() layer
In [ ]:
print("\nList the training rounds : ", exp.training_replies().keys())

print("\nList the nodes for the last training round and their timings : ")
round_data = exp.training_replies()[rounds - 1].data()
for c in range(len(round_data)):
    print("\t- {id} :\
    \n\t\trtime_training={rtraining:.2f} seconds\
    \n\t\tptime_training={ptraining:.2f} seconds\
    \n\t\trtime_total={rtotal:.2f} seconds".format(id = round_data[c]['node_id'],
        rtraining = round_data[c]['timing']['rtime_training'],
        ptraining = round_data[c]['timing']['ptime_training'],
        rtotal = round_data[c]['timing']['rtime_total']))
print('\n')
    
exp.training_replies()[rounds - 1].dataframe
print("\nList the training rounds : ", exp.training_replies().keys()) print("\nList the nodes for the last training round and their timings : ") round_data = exp.training_replies()[rounds - 1].data() for c in range(len(round_data)): print("\t- {id} :\ \n\t\trtime_training={rtraining:.2f} seconds\ \n\t\tptime_training={ptraining:.2f} seconds\ \n\t\trtime_total={rtotal:.2f} seconds".format(id = round_data[c]['node_id'], rtraining = round_data[c]['timing']['rtime_training'], ptraining = round_data[c]['timing']['ptime_training'], rtotal = round_data[c]['timing']['rtime_total'])) print('\n') exp.training_replies()[rounds - 1].dataframe

Federated parameters for each round are available in exp.aggregated_params() (index 0 to (rounds - 1)).

For example you can view the federated parameters for the last round of the experiment:

In [ ]:
print("\nList the training rounds : ", exp.aggregated_params().keys())

print("\nAccess the federated params for the last training round :")
print("\t- params_path: ", exp.aggregated_params()[rounds - 1]['params_path'])
print("\t- parameter data: ", exp.aggregated_params()[rounds - 1]['params'].keys())
print("\nList the training rounds : ", exp.aggregated_params().keys()) print("\nAccess the federated params for the last training round :") print("\t- params_path: ", exp.aggregated_params()[rounds - 1]['params_path']) print("\t- parameter data: ", exp.aggregated_params()[rounds - 1]['params'].keys())

Testing

We define a small testing routine to compute the accuracy metrics on the test dataset.

In [ ]:
import torch
import torch.nn.functional as F


def testing_Accuracy(model, data_loader):
    model.eval()
    test_loss = 0
    correct = 0
    device = 'cpu'

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(data_loader.dataset)
    accuracy = 100 * correct / len(data_loader.dataset)

    return test_loss, accuracy
In [ ]:
from torchvision import datasets, transforms
import os

local_mnist = os.path.join(environ['TMP_DIR'], 'local_mnist')

transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

test_set = datasets.MNIST(root = local_mnist, download = True, train = False, transform = transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)
In [ ]:
fed_model = exp.model_instance()
fed_model.load_state_dict(exp.aggregated_params()[rounds - 1]['params'])

acc_federated = testing_Accuracy(fed_model, test_loader)

print('\nAccuracy federated training:  {:.4f}'.format(acc_federated[1]))

print('\nError federated training:  {:.4f}'.format(acc_federated[0]))