How to Create Your Custom PyTorch Training Plan

Fed-BioMed allows you to train a model in a federated way without completely rewriting your PyTorch code. Integrating your PyTorch model into Fed-BioMed only requires adding a few extra attributes and methods. In this tutorial, you will learn how to define your TrainingPlan (which wraps your model) in Fed-BioMed for the PyTorch framework.

Note: Before starting this tutorial, we highly recommend that you follow the previous tutorials to understand the basics of Fed-BioMed.

In this tutorial, we will be using the CelebA (CelebFaces) dataset to train the model. You can see details of the dataset here. The following sections give instructions for downloading and configuring the CelebA dataset for the Fed-BioMed framework.

1. Fed-BioMed Training Plan

In this section, you will learn how to write your custom training plan.

What is a Training Plan?

The training plan is the class in which all the methods and attributes needed to train your model on the nodes are defined. Each training plan should inherit from the base training plan class that Fed-BioMed provides for the corresponding ML framework. For more details, you can visit the training plan documentation. The following code snippet shows a basic training plan that can be defined in Fed-BioMed for the PyTorch framework.

from fedbiomed.common.training_plans import TorchTrainingPlan


class CustomTrainingPlan(TorchTrainingPlan):
    def init_model(self, model_args):
        # ....
        pass

    def init_dependencies(self):
        #...
        pass

    def training_data(self,  batch_size = 48):
        # ...
        return

    def training_step(self, data, target):
        # ...
        return

init_model Method of Training Plan

The init_model method of the training plan is where you initialize your neural network module, as in a classical PyTorch model class. The network should be defined inside the training plan class, and init_model should instantiate this network (Module) and return it.

In this tutorial, we will be training a classification model on the CelebA image dataset that predicts whether a given face is smiling.

def init_model(self, model_args: dict = {}):
    return self.Net(model_args)

class Net(nn.Module):

    def __init__(self, model_args):
        super().__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 32, 3, 1)
        self.conv4 = nn.Conv2d(32, 32, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # Classifier
        self.fc1 = nn.Linear(3168, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):

        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)

        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)

        x = self.conv3(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)

        x = self.conv4(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)

        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)

        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
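
The input size of fc1 (3168) is not arbitrary: it follows from the image size and the four convolution and pooling stages. The sketch below recomputes it in plain Python, assuming the 218 x 178 CelebA image size mentioned in this tutorial, unpadded 3x3 convolutions with stride 1, and 2x2 max pooling. The helper names are illustrative and are not part of Fed-BioMed or PyTorch.

```python
def conv_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 0) -> int:
    """Output length of one spatial dimension after a Conv2d-style convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int = 2) -> int:
    """Output length after max pooling whose stride equals the kernel size."""
    return size // kernel


def flatten_size(height: int, width: int, channels: int = 32, blocks: int = 4) -> int:
    """Number of features reaching fc1 after `blocks` conv + pool stages."""
    for _ in range(blocks):
        height = pool_out(conv_out(height))
        width = pool_out(conv_out(width))
    return channels * height * width


# 218 -> 108 -> 53 -> 25 -> 11 and 178 -> 88 -> 43 -> 20 -> 9
print(flatten_size(218, 178))  # 3168 = 32 * 11 * 9
```

If you train on images of a different size, the first argument of fc1 has to be recomputed in the same way.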

init_dependencies Method

Next, you should define the init_dependencies method to declare the modules that are used in the training plan. The modules should be supported by Fed-BioMed.

def init_dependencies(self):
    # Here we define the custom dependencies that will be needed by our custom Dataloader
    deps = ["from torch.utils.data import Dataset, DataLoader",
            "from torchvision import transforms",
            "import pandas as pd",
            "from PIL import Image",
            "import os",
            "import numpy as np"]
    return deps
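
Because the dependencies are declared as plain strings, a typo in one of them is easy to miss. As a local sanity check, the sketch below uses Python's ast module to verify that every entry is a single, syntactically valid import statement. The helper invalid_dependencies is hypothetical and is not part of Fed-BioMed; it only parses the strings and does not check whether the modules are installed or supported on the nodes.

```python
import ast


def invalid_dependencies(deps):
    """Return the entries that are not exactly one valid import statement."""
    bad = []
    for line in deps:
        try:
            body = ast.parse(line).body
        except SyntaxError:
            bad.append(line)
            continue
        if len(body) != 1 or not isinstance(body[0], (ast.Import, ast.ImportFrom)):
            bad.append(line)
    return bad


deps = ["from torch.utils.data import Dataset, DataLoader",
        "from torchvision import transforms",
        "import pandas as pd",
        "from PIL import Image",
        "import os",
        "import numpy as np"]

print(invalid_dependencies(deps))  # []
```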

training_data() and Custom Dataset

training_data is the method where the data is loaded for training on the node side. During each round of training, each node that participates in federated training builds the model, loads the dataset using the training_data method, and performs the training_step by passing it the loaded data.

The dataset that we will be using in this tutorial is an image dataset. Therefore, your custom PyTorch Dataset should be able to load an image by a given index. Please see the details of custom PyTorch datasets.

class CelebaDataset(Dataset):
        """Custom Dataset for loading CelebA face images"""


        def __init__(self, txt_path, img_dir, transform=None):

            # Read the csv file that includes classes for each image
            df = pd.read_csv(txt_path, sep="\t", index_col=0)
            self.img_dir = img_dir
            self.txt_path = txt_path
            self.img_names = df.index.values
            self.y = df['Smiling'].values
            self.transform = transform

        def __getitem__(self, index):
            img = np.asarray(Image.open(os.path.join(self.img_dir, self.img_names[index])))
            img = transforms.ToTensor()(img)
            label = self.y[index]
            return img, label

        def __len__(self):
            return self.y.shape[0]
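
The important idea in the class above is that only the file names and labels are kept in memory, while each image file is opened when its index is requested. The sketch below illustrates that protocol without PyTorch, pandas or PIL, so it can be run anywhere. LazyImageIndex and build_demo are hypothetical stand-ins, the "images" are fake byte strings, and the target.csv layout assumes the tab-separated file with an unnamed index column written later in this tutorial.

```python
import csv
import os
import tempfile


class LazyImageIndex:
    """Torch-free stand-in for a map-style dataset: labels in memory, files read on demand."""

    def __init__(self, txt_path, img_dir):
        with open(txt_path, newline="") as handle:
            rows = list(csv.reader(handle, delimiter="\t"))
        header, records = rows[0], rows[1:]
        label_col = header.index("Smiling")
        self.img_dir = img_dir
        self.img_names = [row[0] for row in records]
        self.y = [int(row[label_col]) for row in records]

    def __getitem__(self, index):
        # The file is opened only when this particular sample is requested
        path = os.path.join(self.img_dir, self.img_names[index])
        with open(path, "rb") as handle:
            return handle.read(), self.y[index]

    def __len__(self):
        return len(self.y)


def build_demo(root):
    """Create a tiny fake dataset folder (hypothetical names and contents)."""
    os.makedirs(os.path.join(root, "data"))
    with open(os.path.join(root, "target.csv"), "w", newline="") as handle:
        handle.write("\tSmiling\n000001.jpg\t1\n000002.jpg\t0\n")
    for name in ("000001.jpg", "000002.jpg"):
        with open(os.path.join(root, "data", name), "wb") as handle:
            handle.write(name.encode())
    return LazyImageIndex(os.path.join(root, "target.csv"), os.path.join(root, "data"))


with tempfile.TemporaryDirectory() as tmp:
    demo = build_demo(tmp)
    demo_len, second_item = len(demo), demo[1]

print(demo_len, second_item)
```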

Now, you need to define a training_data method that will create a Fed-BioMed DataManager using custom CelebaDataset class.

def training_data(self,  batch_size = 48):
        # The training_data creates the dataset and returns DataManager to be used for training in the general class Torchnn of Fed-BioMed
        dataset = self.CelebaDataset(self.dataset_path + "/target.csv", self.dataset_path + "/data/")
        loader_arguments = {'batch_size': batch_size, 'shuffle': True}
        return DataManager(dataset, **loader_arguments)

training_step()

The last method that needs to be defined is training_step. This method is responsible for executing the forward method and calculating the loss value for the backward pass of the network. To access the forward method of the torch.nn.Module that is created in init_model, the getter method model() of the training plan class should be used.

def training_step(self, data, target): 
    output = self.model().forward(data)
    loss   = torch.nn.functional.nll_loss(output, target)
    return loss
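
To make the loss computation concrete, here is a small torch-free sketch of what the combination of log_softmax (the last line of forward) and nll_loss computes for a batch. It is only an illustration of the arithmetic, assuming the default mean reduction; the two functions below are plain-Python re-implementations written for this example, not the PyTorch ones.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of raw scores."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_norm for v in logits]


def nll_loss(log_probs, targets):
    """Mean negative log-likelihood of the target class over the batch."""
    picked = [row[t] for row, t in zip(log_probs, targets)]
    return -sum(picked) / len(picked)


# Two samples, two classes (not smiling = 0, smiling = 1); both targets are class 0
logits = [[0.0, 0.0], [2.0, 0.0]]
targets = [0, 0]
loss = nll_loss([log_softmax(row) for row in logits], targets)
print(loss)
```

The first sample is undecided between the classes and contributes log(2); the second favours the correct class and contributes much less, so the mean loss is lower than log(2).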

You are now ready to create your training plan class. All you need to do is place every method explained in the previous sections in your training plan class. In the next steps we will:

  1. download the CelebA dataset and deploy it on the nodes
  2. define our complete training plan
  3. create an experiment and run it
  4. evaluate our model using a testing dataset

2. Configuring Nodes

We will be working with the CelebA (CelebFaces) dataset. Please visit here and download the files img/img_align_celeba.zip and Anno/list_attr_celeba.txt. After the download is completed:

  • Go to ./notebooks/data/Celeba in the Fed-BioMed project.
  • Create the Celeba_raw/raw directory and copy the list_attr_celeba.txt file into it.
  • Extract the zip file img_align_celeba.zip into the same directory.

Your folder should look like the tree below:

Celeba
    README.md
    create_node_data.py    
    .gitignore

    Celeba_raw
        raw
            list_attr_celeba.txt
            img_align_celeba.zip
            img_align_celeba
              lots of images

The dataset has to be processed and split to create three distinct datasets for Node 1, Node 2, and Node 3. You can do this by running the following script in your notebook. If you are working in a directory other than fedbiomed/notebooks, please make sure that you define the correct paths in the following example.

Running the following script might take some time. Please be patient.

import os
import numpy as np
import pandas as pd
import shutil

# Celeba folder
parent_dir = os.path.join(".", "data", "Celeba") 
celeba_raw_folder = os.path.join("Celeba_raw", "raw")
img_dir = os.path.join(parent_dir, celeba_raw_folder, 'img_align_celeba') + os.sep
out_dir = os.path.join(".", "data", "Celeba", "celeba_preprocessed")

# Read the attribute file and only load the Smiling column
df = pd.read_csv(os.path.join(parent_dir, celeba_raw_folder, 'list_attr_celeba.txt'),
                 sep=r"\s+", skiprows=1, usecols=['Smiling'])

# Labels are 1 if the person is smiling and -1 otherwise. Map -1 to 0 so that
# the labels are valid class indices (0 or 1) for the loss function
df.loc[df['Smiling'] == -1, 'Smiling'] = 0

# Split the table into 3 parts
length = len(df)
data_node_1 = df.iloc[:int(length/3)]
data_node_2 = df.iloc[int(length/3):int(length/3) * 2]
data_node_3 = df.iloc[int(length/3) * 2:]

# Create a folder for each node (no error if it already exists)
os.makedirs(os.path.join(out_dir, "data_node_1", "data"), exist_ok=True)
os.makedirs(os.path.join(out_dir, "data_node_2", "data"), exist_ok=True)
os.makedirs(os.path.join(out_dir, "data_node_3", "data"), exist_ok=True)

# Save each node's target CSV to the correct folder
data_node_1.to_csv(os.path.join(out_dir, 'data_node_1', 'target.csv'), sep='\t')
data_node_2.to_csv(os.path.join(out_dir, 'data_node_2', 'target.csv'), sep='\t')
data_node_3.to_csv(os.path.join(out_dir, 'data_node_3', 'target.csv'), sep='\t')

# Copy the images of each node into the correct folder
for im in data_node_1.index:
    shutil.copy(img_dir+im, os.path.join(out_dir, "data_node_1", "data", im))
print("data for node 1 successfully created")

for im in data_node_2.index:
    shutil.copy(img_dir+im, os.path.join(out_dir, "data_node_2", "data", im))
print("data for node 2 successfully created")

for im in data_node_3.index:
    shutil.copy(img_dir+im, os.path.join(out_dir, "data_node_3", "data", im))
print("data for node 3 successfully created")
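
The three iloc slices in the script cover every row exactly once, with any remainder going to the third node. The short sketch below reproduces only the index arithmetic so that you can check it for any table length; three_way_bounds is a hypothetical helper written for this illustration.

```python
def three_way_bounds(length):
    """Return the (start, stop) row ranges matching the iloc slices of the split."""
    third = int(length / 3)
    return [(0, third), (third, third * 2), (third * 2, length)]


bounds = three_way_bounds(10)
print(bounds)  # [(0, 3), (3, 6), (6, 10)]

# Every row is assigned to exactly one node
assert sum(stop - start for start, stop in bounds) == 10
```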

Now if you go to the ${FEDBIOMED_DIR}/notebooks/data/Celeba directory, you can see a folder called celeba_preprocessed. It contains three folders, each holding an image dataset for one of the 3 nodes. The next step is configuring the nodes and deploying the datasets. We will configure only two nodes; the dataset for the third node will be used for testing.

Create 2 nodes for training:

  • ${FEDBIOMED_DIR}/scripts/fedbiomed_run node config node1.ini start
  • ${FEDBIOMED_DIR}/scripts/fedbiomed_run node config node2.ini start

Add data to each node:

  • ${FEDBIOMED_DIR}/scripts/fedbiomed_run node config node1.ini add
  • ${FEDBIOMED_DIR}/scripts/fedbiomed_run node config node2.ini add

Note: ${FEDBIOMED_DIR} is the path to the base directory of the cloned Fed-BioMed repository. You can set it by running the command export FEDBIOMED_DIR=/path/to/fedbiomed. This is not required for Fed-BioMed to work, but it makes the tutorials easier to run.

2.1. Configuration Steps

At least one node must be configured before training:

  1. ${FEDBIOMED_DIR}/scripts/fedbiomed_run node config <ini file> add

    • Select option 4 (images) to add an image dataset to the node
    • Add a name and the tag for the dataset (the tag should contain '#celeba', as it is the tag used for this training) and finally add the description
    • Pick a data folder from the 3 generated datasets inside data/Celeba/celeba_preprocessed (e.g. data_node_1)
    • The data should now be added (a warning saying that data must be unique means that it has already been added)
  2. Check that your data has been added by executing ${FEDBIOMED_DIR}/scripts/fedbiomed_run node config <ini file> list

  3. Run the node using ${FEDBIOMED_DIR}/scripts/fedbiomed_run node config <ini file> start. Wait until you get Starting task manager, which means the node is online.

After the steps above are completed, you will be ready to train your classification model on two different nodes.
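
Before adding a folder to a node, it can help to check that it has the layout the training plan expects: a target.csv file and a data/ directory, which are the two paths used by training_data. The sketch below is a hypothetical local check and is not part of the Fed-BioMed command line tools; the demo builds a throwaway folder in a temporary directory.

```python
import os
import tempfile


def missing_entries(node_folder):
    """List what is missing from a node data folder: target.csv and a data/ directory."""
    missing = []
    if not os.path.isfile(os.path.join(node_folder, "target.csv")):
        missing.append("target.csv")
    if not os.path.isdir(os.path.join(node_folder, "data")):
        missing.append("data/")
    return missing


with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "data"))
    before = missing_entries(tmp)       # target.csv not created yet
    open(os.path.join(tmp, "target.csv"), "w").close()
    after = missing_entries(tmp)        # layout is now complete

print(before, after)
```

On a real setup you would call missing_entries on a folder such as data/Celeba/celeba_preprocessed/data_node_1 and expect an empty list.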

3. Defining Custom PyTorch Model and Training Plan

The next step is to create our Net class based on the methods explained in the previous sections. This class is part of the training plan that will be passed to the Experiment. The nodes will then receive the training plan and perform the training by retrieving the training data and passing it to training_step.

import torch
import torch.nn as nn
from fedbiomed.common.training_plans import TorchTrainingPlan
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset
from fedbiomed.common.data import DataManager
import pandas as pd
import numpy as np
from PIL import Image
import os


class CelebaTrainingPlan(TorchTrainingPlan):

    # Defines model
    def init_model(self):
        model = self.Net()
        return model

    # Here we define the custom dependencies that will be needed by our custom Dataloader
    def init_dependencies(self):
        deps = ["from torch.utils.data import Dataset",
                "from torchvision import transforms",
                "import pandas as pd",
                "from PIL import Image",
                "import os",
                "import numpy as np"]
        return deps

    # Torch modules class
    class Net(nn.Module):

        def __init__(self):
            super().__init__()
            #convolution layers
            self.conv1 = nn.Conv2d(3, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 32, 3, 1)
            self.conv3 = nn.Conv2d(32, 32, 3, 1)
            self.conv4 = nn.Conv2d(32, 32, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            # classifier
            self.fc1 = nn.Linear(3168, 128)
            self.fc2 = nn.Linear(128, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.conv2(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.conv3(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.conv4(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)

            x = self.dropout2(x)
            x = self.fc2(x)
            output = F.log_softmax(x, dim=1)
            return output


    class CelebaDataset(Dataset):
        """Custom Dataset for loading CelebA face images"""

        # We don't load all the images into memory; each image is read on demand in __getitem__.
        # In our case, each image is 218 * 178 pixels * 3 colors and there are 67533 images,
        # which would take at least 7 GB of RAM. Loading images when needed takes more time
        # during training, but uses far less RAM than loading everything.
        def __init__(self, txt_path, img_dir, transform=None):
            df = pd.read_csv(txt_path, sep="\t", index_col=0)
            self.img_dir = img_dir
            self.txt_path = txt_path
            self.img_names = df.index.values
            self.y = df['Smiling'].values
            self.transform = transform
            print("celeba dataset finished")

        def __getitem__(self, index):
            img = np.asarray(Image.open(os.path.join(self.img_dir,
                                        self.img_names[index])))
            img = transforms.ToTensor()(img)
            label = self.y[index]
            return img, label

        def __len__(self):
            return self.y.shape[0]

    # training_data creates the DataManager to be used for training in the
    # general class Torchnn of fedbiomed
    def training_data(self,  batch_size = 48):
        dataset = self.CelebaDataset(self.dataset_path + "/target.csv", self.dataset_path + "/data/")
        loader_arguments = {'batch_size': batch_size, 'shuffle': True}
        return DataManager(dataset, **loader_arguments)

    # This function must return the loss to backward it
    def training_step(self, data, target):

        output = self.model().forward(data)
        loss   = torch.nn.functional.nll_loss(output, target)
        return loss

The arguments passed to the experiment fall into two groups:

  • model_args: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the node-side.
  • training_args: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the node-side.

Note: Typos and/or lack of positional (required) arguments might raise an error.

training_args = {
    'batch_size': 32, 
    'optimizer_args': {
        'lr': 1e-3
    },
    'epochs': 1, 
    'dry_run': False,  
    'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}
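
The comment on batch_maxnum says that at most batch_maxnum * batch_size samples are used. The arithmetic below shows what that means for the values above, assuming one node holds the 67533 images mentioned in this tutorial. The helper samples_per_epoch is hypothetical, and treating a batch_maxnum of 0 or None as "no cap" is an assumption of this sketch rather than documented Fed-BioMed behaviour.

```python
def samples_per_epoch(dataset_size, batch_size, batch_maxnum):
    """Upper bound on the samples seen in one epoch when the number of batches is capped."""
    if not batch_maxnum:  # assumption: 0 or None means no cap
        return dataset_size
    return min(dataset_size, batch_size * batch_maxnum)


used = samples_per_epoch(dataset_size=67533, batch_size=32, batch_maxnum=100)
print(used)  # 3200, about 5% of the node's images
```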

4. Training Federated Model

To orchestrate training over two nodes, we need to define an experiment which:

  • searches for nodes serving data for the given tags,
  • defines the local training on nodes with the training plan class given in training_plan_class, and federates all local updates at each round with the aggregator,
  • runs training for round_limit rounds.

You can visit the user guide to learn more about the experiment.

from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags =  ['#celeba']
rounds = 3

exp = Experiment(tags=tags,
                 training_plan_class=CelebaTrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)
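
To give an intuition for what the aggregator does at the end of each round, here is a simplified, torch-free sketch of the usual federated averaging idea: each parameter is averaged across nodes, weighted by how many samples each node trained on. This illustrates the general technique only; it is not the code of Fed-BioMed's FedAverage class, and the parameter names and sample counts are made up.

```python
def federated_average(node_params, node_sample_counts):
    """Weighted average of per-node parameter dicts, weights proportional to sample counts."""
    total = sum(node_sample_counts)
    averaged = {}
    for name in node_params[0]:
        size = len(node_params[0][name])
        averaged[name] = [
            sum(params[name][i] * n / total
                for params, n in zip(node_params, node_sample_counts))
            for i in range(size)
        ]
    return averaged


# Hypothetical flattened parameters returned by two nodes after local training
node1 = {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]}
node2 = {"fc.weight": [3.0, 6.0], "fc.bias": [4.0]}

aggregated = federated_average([node1, node2], [100, 300])
print(aggregated)  # {'fc.weight': [2.5, 5.0], 'fc.bias': [3.0]}
```

Node 2 contributes three times as many samples, so the aggregated values lie closer to its parameters.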

Let's start the experiment.

By default, this function doesn't stop until all round_limit rounds have been completed on all the nodes. While the experiment runs, you can open the terminals where you started the nodes and follow the training progress. The loss values obtained from each node during training are also printed as output in real time. Since we are working on an image dataset, training might take some time.

exp.run()

Loading Training Parameters

After all the rounds have been completed, you can retrieve the aggregated parameters from the last round and load them into the model.

fed_model = exp.training_plan().model()
fed_model.load_state_dict(exp.aggregated_params()[rounds - 1]['params'])

5. Testing Federated Model

We will define a testing routine to compute the loss and accuracy on the testing dataset. We will use the dataset that was extracted into data_node_3.

import torch
import torch.nn as nn

import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from PIL import Image
import os

def testing_Accuracy(model, data_loader):
    """Return (mean NLL loss, accuracy in %) over roughly the first 10% of batches."""
    model.eval()
    device = "cpu"
    test_loss = 0.0
    correct = 0
    n_samples = 0

    loader_size = len(data_loader)
    with torch.no_grad():
        for idx, (data, target) in enumerate(data_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            n_samples += target.size(0)

            # Only use about 10% of the dataset: results are similar but faster
            if idx >= loader_size / 10:
                break

    # Normalize by the number of samples actually evaluated
    test_loss /= n_samples
    accuracy = 100 * correct / n_samples

    return test_loss, accuracy
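
When only part of the test set is evaluated, the accuracy has to be divided by the number of samples actually seen, not by the size of the full dataset or by an estimate based on the batch size. The torch-free sketch below shows that bookkeeping on made-up data; partial_accuracy and argmax are hypothetical helpers written for this illustration, and the batches are plain lists of log-probabilities and targets.

```python
def argmax(row):
    """Index of the largest value in a row of scores."""
    return max(range(len(row)), key=row.__getitem__)


def partial_accuracy(batches, fraction=0.1):
    """Accuracy (%) over about the first `fraction` of batches, per sample actually seen."""
    limit = max(1, int(len(batches) * fraction))
    correct = 0
    seen = 0
    for outputs, targets in batches[:limit]:
        correct += sum(argmax(o) == t for o, t in zip(outputs, targets))
        seen += len(targets)
    return 100.0 * correct / seen


# Hypothetical toy data: 20 identical batches of 2 samples with 2 classes each.
# The first sample is predicted correctly (class 0), the second is not.
batch = ([[-0.1, -2.3], [-1.6, -0.2]], [0, 0])
batches = [batch] * 20

accuracy = partial_accuracy(batches)
print(accuracy)  # 50.0, computed from the 4 samples in the first 2 batches
```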

We also need to define a custom Dataset class for the test dataset in order to load it using PyTorch's DataLoader. This will be the same class that has been already defined in the training plan.

from torch.utils.data import DataLoader
# Same location as out_dir in the preprocessing script; adjust if you ran it elsewhere
test_dataset_path = "./data/Celeba/celeba_preprocessed/data_node_3"

class CelebaDataset(Dataset):
    """Custom Dataset for loading CelebA face images"""

    def __init__(self, txt_path, img_dir, transform=None):
        df = pd.read_csv(txt_path, sep="\t", index_col=0)
        self.img_dir = img_dir
        self.txt_path = txt_path
        self.img_names = df.index.values
        self.y = df['Smiling'].values
        self.transform = transform
        print("celeba dataset finished")

    def __getitem__(self, index):
        img = np.asarray(Image.open(os.path.join(self.img_dir,
                                        self.img_names[index])))
        img = transforms.ToTensor()(img)
        label = self.y[index]
        return img, label

    def __len__(self):
        return self.y.shape[0]


dataset = CelebaDataset(test_dataset_path + "/target.csv", test_dataset_path + "/data/")
loader_kwargs = {'batch_size': 128, 'shuffle': True}
data_loader = DataLoader(dataset, **loader_kwargs)
acc_federated = testing_Accuracy(fed_model, data_loader)
acc_federated[1]

Conclusions

This tutorial explained how to run a custom model on Fed-BioMed, by wrapping it in a custom training plan, for the PyTorch framework. Because the examples are designed for the development environment, we ran the nodes on the same host machine. In production, the nodes used to train your model will run on remote servers. Since Fed-BioMed is still under development, the functions and methods used in these tutorials may change in the future. Please follow our GitLab repository to stay up to date.

Fed-BioMed © 2022