
Brain Segmentation

This tutorial shows how to use Fed-BioMed to perform image segmentation on 3D brain MRI scans, using the publicly available IXI dataset. It uses a 3D U-Net model for the segmentation, trained on data from three separate centers.

This is a fairly advanced use case, exercising Fed-BioMed functionalities such as:

  • loading a MedicalFolderDataset
  • implementing a custom Node Selection Strategy
  • setting a non-default Optimizer
  • monitoring training loss with Tensorboard

This tutorial is based on TorchIO's tutorial.

Automatic download and wrangling for the impatient

If you're not interested in the details, you may simply execute the download_and_split_ixi.py script that we provide, as explained below:

mkdir -p ixi-data/notebooks/data
download_and_split_ixi.py -f ./ixi-data

After the command completes successfully, follow the printed instructions to add the datasets and start the nodes. The tag used for this experiment is ixi-train.

Details about data preparation

If you just want to run the notebook, you may skip this section and jump directly to Define a new Strategy.

First, download the IXI dataset from the Mendeley archive.

In this tutorial we are going to use the MedicalFolderDataset class provided by the Fed-BioMed library to load medical images in NIFTI format. Using this dataset class for image segmentation problems guarantees maximum compatibility with the rest of the Fed-BioMed functionalities and features.
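
As a quick preview, here is a minimal sketch of how this class can be instantiated once the folders are prepared (the root path is hypothetical, and we assume each item unpacks as in the validation loop shown at the end of this tutorial):

from fedbiomed.common.data import MedicalFolderDataset

dataset = MedicalFolderDataset(
    root='./ixi-data/Hospital-Centers/Guys/train',  # hypothetical path
    data_modalities='T1',
    target_modalities='label',
)
(images, demographics), targets = dataset[0]  # data for the first subject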

Folder structure for MedicalFolderDataset

The MedicalFolderDataset is heavily inspired by PyTorch's ImageFolder dataset, and requires you to manually arrange the image folders in a precise structure. The format assumes that you are dealing with imaging data, possibly acquired through multiple modalities, for different study subjects. Hence, you should provide one folder per subject, each containing one subfolder per image acquisition modality. Optionally, you may provide a csv file containing additional tabular data associated with each subject. This file is typically used for demographics data, and by default is called participants.csv.

_ root-folder
 |_ participants.csv
 |_ subject-1
 | |_ modality-1
 | |_ modality-2
 |_ subject-2
 | |_ modality-1
 | |_ modality-2
 |_ subject-3
 | |_ modality-1
 . .
 . .
 . .

Folder structure for this tutorial

In the specific case of this tutorial, we encourage you to further divide your images into additional subfolders, according to two criteria: the hospital that generated the data (there are three: Guys, HH and IOP) and a random train/holdout split. Note that each subject's folder has a name of the form IXI<SUBJECT_ID>-<HOSPITAL>-<RANDOM_ID>, for example IXI002-Guys-0828. Combining these splits with the structure required by the MedicalFolderDataset, your folder tree should look like this:

_root-folder
 |_ Guys
 | |_ train
 | | |_ participants.csv
 | | |_ IXI002-Guys-0828
 | | | |_ T1                <-- T1 is the first imaging modality
 | | | |_ T2
 | | | |_ label
 | | |_ IXI022-Guys-0701
 | | | |_ T1
 | | | |_ T2
 . . .
 . . .
 . . .
 | |_ holdout
 | | |_ participants.csv
 | | |_ IXI004-Guys-0321
 | | | |_ T1
 | | | |_ T2
 | | | |_ label
 . . .
 . . .
 . . .
 |_ HH
 | |_ train
 . . .
 . . .
 . . .
 | |_ holdout
 . . .
 . . .
 . . .
 |_ IOP
 . . .
 . . .
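
Before adding the datasets, a short sanity check along these lines can help verify that your folder tree matches the expected layout (a sketch; the root path is hypothetical and should be adapted to where you stored the data):

from pathlib import Path

root_folder = Path('./ixi-data/Hospital-Centers')  # hypothetical path
for center in ['Guys', 'HH', 'IOP']:
    for split in ['train', 'holdout']:
        split_dir = root_folder / center / split
        # each split must ship its own demographics file
        assert (split_dir / 'participants.csv').exists(), f'missing csv in {split_dir}'
        subjects = [p for p in split_dir.iterdir() if p.is_dir()]
        print(f'{center}/{split}: {len(subjects)} subjects')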

Add the IXI dataset to the federated nodes

For each of the three hospitals, create a federated node and add the corresponding train dataset by selecting the medical-folder data type and entering ixi-train as the tag. Then start the nodes.

Dataset for demographics of the subjects

After selecting the folder that contains the subjects used for training, the CLI will ask for the CSV file where the demographics of the subjects are stored. These CSV files are named `participants.csv`, and you can find them in the folder where the subject folders are located, e.g. `Guys/train/participants.csv`.

If you don't know how to add datasets to a node, or start a node, please read our user guide or follow the basic tutorial.

Define a new Strategy

Fed-BioMed's default strategy reads the number of samples per node through the shape parameter that is computed when the data is uploaded. For technical reasons, we need to change this to account for the fact that different modalities may have been used during the experiment.
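
The key change concerns the aggregation weights, which are simply per-node sample counts normalized across nodes. As a toy illustration with hypothetical counts, mirroring the computation in the refine method below:

shapes = [100, 200, 300]            # hypothetical number of samples per node
totalrows = sum(shapes)
weights = [x / totalrows for x in shapes]
print(weights)                      # [0.166..., 0.333..., 0.5]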

In [ ]:
from fedbiomed.researcher.strategies.default_strategy import DefaultStrategy
from fedbiomed.common.constants import ErrorNumbers
from fedbiomed.common.exceptions import FedbiomedStrategyError
from fedbiomed.common.logger import logger  # needed for the log calls below

class MedicalFolderStrategy(DefaultStrategy):
    def __init__(self, data, modalities = ['T1']):
        super().__init__(data)
        self._modalities = modalities
        
    def refine(self, training_replies, round_i):
        models_params = []
        weights = []

        # check that all nodes answered
        cl_answered = [val['node_id'] for val in training_replies.data()]

        answers_count = 0
        for cl in self.sample_nodes(round_i):
            if cl in cl_answered:
                answers_count += 1
            else:
                # this node did not answer
                logger.error(f'{ErrorNumbers.FB408.value} (node = {cl})')

        if len(self.sample_nodes(round_i)) != answers_count:
            if answers_count == 0:
                # none of the nodes answered
                msg = ErrorNumbers.FB407.value

            else:
                msg = ErrorNumbers.FB408.value

            logger.critical(msg)
            raise FedbiomedStrategyError(msg)

        # check that all nodes that answer could successfully train
        self._success_node_history[round_i] = []
        all_success = True
        for tr in training_replies:
            if tr['success'] is True:
                model_params = tr['params']
                models_params.append(model_params)
                self._success_node_history[round_i].append(tr['node_id'])
            else:
                # node did not succeed
                all_success = False
                logger.error(f'{ErrorNumbers.FB409.value} (node = {tr["node_id"]})')

        if not all_success:
            raise FedbiomedStrategyError(ErrorNumbers.FB402.value)

        # so far, everything is OK
        shapes = [
            sum(val[0]["shape"][modality][0] for modality in self._modalities)
            for (key, val) in self._fds.data().items()
        ]
        totalrows = sum(shapes)
        weights = [x / totalrows for x in shapes]
        logger.info('Nodes that successfully reply in round ' +
                    str(round_i) + ' ' +
                    str(self._success_node_history[round_i]))
        return models_params, weights

Create a Training Plan

We create a training plan that incorporates the UNet model. We rely on the unet package for simplicity. Please refer to the original package for more details about UNet: Pérez-García, Fernando. (2020). fepegar/unet: PyTorch implementation of 2D and 3D U-Net (v0.7.5). Zenodo. https://doi.org/10.5281/zenodo.3697931

Define the model via the init_model function

The init_model function must return a UNet instance. Please refer to the TrainingPlan documentation for more details.

Define the loss function via the training_step function

The loss function is based on the Dice loss:

Carole H Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin, and M Jorge Cardoso. Generalised dice overlap as a deep learning loss function for highly unbalanced segmentations. In Deep learning in medical image analysis and multimodal learning for clinical decision support, pages 240–248. Springer, 2017.
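
Concretely, writing TP, FP and FN for the soft true positives, false positives and false negatives computed from the predicted probabilities (matching the get_dice_loss implementation below), the per-class loss reads:

$$\mathcal{L}_{\mathrm{Dice}} = 1 - \frac{2\,\mathrm{TP}}{2\,\mathrm{TP} + \mathrm{FP} + \mathrm{FN} + \epsilon}$$

where $\epsilon$ is a small constant (1e-9 in the code) ensuring numerical stability.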

Define data loading and transformations via the training_data function

Within the training_data function, we create an instance of MedicalFolderDataset and pass it to Fed-BioMed's DataManager class.

To preprocess the images, we define transformations for both the input images and the labels, leveraging MONAI's transforms. Note that we also include the corresponding dependencies in the init_dependencies function.

Additionally, we define a transformation for the demographics data contained in the associated csv file. To be able to use information extracted from the demographics data as input to UNet, we must convert it to a torch.Tensor object. To achieve this, we exploit the demographics_transform argument of the MedicalFolderDataset. The transformation defined in this tutorial is just for illustration purposes: it does little more than extract some variables from the tabular data and convert them to the appropriate format.

Define the training step

Here we take as input one batch of (data, target), train the model and compute the loss function.

Note that the MedicalFolderDataset class returns data as a tuple of (images, demographics), where:

  • images is a dict of {modality: image} (after image transformations)
  • demographics is a dict of {column_name: values}, where the column names are taken from the demographics csv file

The target is a dict of {modality: image} (after target transformations).

In our case, the modality used is T1 for the input images, while the modality used for the target is label. In this tutorial, we ignore the values of the demographics data during training because the UNet model only takes images as input. However, the code is provided for illustration purposes as it shows the recommended way to handle the associated tabular data.
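
As a minimal sketch, one batch from the data loader defined below unpacks as follows (the loader variable is hypothetical and the shapes are indicative, assuming the transforms and batch size used in this tutorial):

(images, demographics), targets = next(iter(train_data_loader))  # hypothetical loader
img = images['T1']       # e.g. torch.Size([4, 1, 48, 60, 48]) after transforms
tgt = targets['label']   # e.g. torch.Size([4, 2, 48, 60, 48]), one-hot over 2 classes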

In [ ]:
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.logger import logger
from fedbiomed.common.data import DataManager, MedicalFolderDataset
import torch.nn as nn
from torch.optim import AdamW
from unet import UNet

class UNetTrainingPlan(TorchTrainingPlan):

    def init_model(self, model_args):
        model = self.Net(model_args)
        return model


    def init_optimizer(self):
        optimizer = AdamW(self.model().parameters())
        return optimizer

    def init_dependencies(self):
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        deps = ["from monai.transforms import (Compose, NormalizeIntensity, AddChannel, Resize, AsDiscrete)",
               "import torch.nn as nn",
               'import torch.nn.functional as F',
               "from fedbiomed.common.data import MedicalFolderDataset",
               'import numpy as np',
               'from torch.optim import AdamW',
               'from unet import UNet']
        return deps


    class Net(nn.Module):
        # Init of UNetTrainingPlan
        def __init__(self, model_args: dict = {}):
            super().__init__()
            self.CHANNELS_DIMENSION = 1

            self.unet = UNet(
                in_channels = model_args.get('in_channels',1),
                out_classes = model_args.get('out_classes',2),
                dimensions = model_args.get('dimensions',2),
                num_encoding_blocks = model_args.get('num_encoding_blocks',5),
                out_channels_first_layer = model_args.get('out_channels_first_layer',64),
                normalization = model_args.get('normalization', None),
                pooling_type = model_args.get('pooling_type', 'max'),
                upsampling_type = model_args.get('upsampling_type','conv'),
                preactivation = model_args.get('preactivation',False),
                residual = model_args.get('residual',False),
                padding = model_args.get('padding',0),
                padding_mode = model_args.get('padding_mode','zeros'),
                activation = model_args.get('activation','ReLU'),
                initial_dilation = model_args.get('initial_dilation',None),
                dropout = model_args.get('dropout',0),
                monte_carlo_dropout = model_args.get('monte_carlo_dropout',0)
            )

        def forward(self, x):
            x = self.unet.forward(x)
            x = F.softmax(x, dim=self.CHANNELS_DIMENSION)
            return x

    @staticmethod
    def get_dice_loss(output, target, epsilon=1e-9):
        SPATIAL_DIMENSIONS = 2, 3, 4
        p0 = output
        g0 = target
        p1 = 1 - p0
        g1 = 1 - g0
        tp = (p0 * g0).sum(dim=SPATIAL_DIMENSIONS)
        fp = (p0 * g1).sum(dim=SPATIAL_DIMENSIONS)
        fn = (p1 * g0).sum(dim=SPATIAL_DIMENSIONS)
        num = 2 * tp
        denom = 2 * tp + fp + fn + epsilon
        dice_score = num / denom
        return 1. - dice_score

    @staticmethod
    def demographics_transform(demographics: dict):
        """Transforms dict of demographics into data type for ML.

        This function is provided for demonstration purposes, but
        note that if you intend to use demographics data as part
        of your model's input, you **must** provide a
        `demographics_transform` function which at the very least
        converts the demographics dict into a torch.Tensor.

        Must return either a torch Tensor or something Tensor-like
        that can be easily converted through the torch.as_tensor()
        function."""

        if isinstance(demographics, dict) and len(demographics) == 0:
            # when input is empty dict, we don't want to transform anything
            return demographics

        # simple example: keep only some keys
        keys_to_keep = ['HEIGHT', 'WEIGHT']
        out = np.array([float(val) for key, val in demographics.items() if key in keys_to_keep])

        # more complex: generate dummy variables for site name
        # not ideal as it requires knowing the site names in advance
        # could be better implemented with some preprocess
        site_names = ['Guys', 'IOP', 'HH']
        len_dummy_vars = len(site_names) + 1
        dummy_vars = np.zeros(shape=(len_dummy_vars,))
        site_name = demographics['SITE_NAME']
        if site_name in site_names:
            site_idx = site_names.index(site_name)
        else:
            site_idx = len_dummy_vars - 1
        dummy_vars[site_idx] = 1.

        return np.concatenate((out, dummy_vars))


    def training_data(self, batch_size=4):
        # training_data creates the Dataloader to be used for training
        # in the general Torchnn class of fedbiomed
        common_shape = (48, 60, 48)
        training_transform = Compose([AddChannel(), Resize(common_shape), NormalizeIntensity(),])
        target_transform = Compose([AddChannel(), Resize(common_shape), AsDiscrete(to_onehot=2)])

        dataset = MedicalFolderDataset(
            root=self.dataset_path,
            data_modalities='T1',
            target_modalities='label',
            transform=training_transform,
            target_transform=target_transform,
            demographics_transform=UNetTrainingPlan.demographics_transform)
        loader_arguments = {'batch_size': batch_size, 'shuffle': True}
        return DataManager(dataset, **loader_arguments)


    def training_step(self, data, target):
        #this function must return the loss to backward it
        img = data[0]['T1']
        demographics = data[1]
        output = self.model().forward(img)
        loss = UNetTrainingPlan.get_dice_loss(output, target['label'])
        avg_loss = loss.mean()
        return avg_loss

    def testing_step(self, data, target):
        img = data[0]['T1']
        demographics = data[1]
        target = target['label']
        prediction = self.model().forward(img)
        loss = UNetTrainingPlan.get_dice_loss(prediction, target)
        avg_loss = loss.mean()  # average per batch
        return avg_loss
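
Before moving on, you may optionally smoke-test the network, the loss and the demographics transform locally on random tensors. This is only a sketch, assuming the unet package is installed in your researcher environment; it reuses the same arguments as the model_args defined in the next section, and the printed values will vary:

import numpy as np
import torch
import torch.nn.functional as F

# Instantiate the nested Net with the same arguments used for the experiment
net = UNetTrainingPlan.Net({'in_channels': 1, 'out_classes': 2, 'dimensions': 3,
                            'num_encoding_blocks': 3, 'out_channels_first_layer': 8,
                            'normalization': 'batch', 'upsampling_type': 'linear',
                            'padding': True, 'activation': 'PReLU'})
x = torch.rand(2, 1, 48, 60, 48)                     # (batch, channel, D, H, W)
y = torch.randint(0, 2, (2, 1, 48, 60, 48)).float()
y = torch.cat((1 - y, y), dim=1)                     # one-hot target over 2 classes
out = net(x)                                         # softmax probabilities
print(out.shape, UNetTrainingPlan.get_dice_loss(out, y).mean().item())

row = {'HEIGHT': 170.0, 'WEIGHT': 65.0, 'SITE_NAME': 'Guys'}  # hypothetical record
print(UNetTrainingPlan.demographics_transform(row))           # [170. 65. 1. 0. 0. 0.]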

Prepare the experiment

In [ ]:
model_args = {
    'in_channels': 1,
    'out_classes': 2,
    'dimensions': 3,
    'num_encoding_blocks': 3,
    'out_channels_first_layer': 8,
    'normalization': 'batch',
    'upsampling_type': 'linear',
    'padding': True,
    'activation': 'PReLU',
}

training_args = {
    'batch_size': 16, 
    'epochs': 2, 
    'dry_run': False,
    'log_interval': 2,
    'test_ratio' : 0.1,
    'test_on_global_updates': True,
    'test_on_local_updates': True,
}
In [ ]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags = ['ixi-train']
num_rounds = 3

exp = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=UNetTrainingPlan,
                 training_args=training_args,
                 round_limit=num_rounds,
                 aggregator=FedAverage(),
                 tensorboard=True
                )
medical_folder_strategy = MedicalFolderStrategy(exp._fds, modalities=['T1'])
_ = exp.set_strategy(node_selection_strategy=medical_folder_strategy)
In [ ]:
exp.training_plan_file()

Tensorboard setup

In [ ]:
%load_ext tensorboard
In [ ]:
from fedbiomed.researcher.environ import environ
tensorboard_dir = environ['TENSORBOARD_RESULTS_DIR']
In [ ]:
%tensorboard --logdir "$tensorboard_dir"

On a 2015 MacBook Pro with a 2.5 GHz quad-core Intel Core i7 processor and 16 GB of RAM, training for 3 rounds of 2 epochs each took about 30 minutes. The final training curves look like this:

[Image: Tensorboard training loss curves]

Run the experiment

In [ ]:
exp.run()

Validate on a local holdout set

To ensure consistency and keep things simple, we reuse the already-available code as much as possible. Note that this process assumes that the held-out data is stored locally on the machine.

Create an instance of the global model

First, we create an instance of the model using the parameters from the latest aggregation round.

In [ ]:
local_training_plan = UNetTrainingPlan()
local_model = local_training_plan.init_model(model_args)
In [ ]:
for dependency_statement in local_training_plan.init_dependencies():
    exec(dependency_statement)
In [ ]:
local_model.load_state_dict(exp.aggregated_params()[num_rounds-1]['params'])

Define a validation data loader

We extract the validation data loader from the training plan as well. This requires some knowledge about the internals of the MedicalFolderDataset class. At the end of the process, calling the split function with a ratio of 0 will return a data loader that loads all of the data.

In [ ]:
from torch.utils.data import DataLoader
dataset_parameters = {
    'tabular_file': '../data/Hospital-Centers/Guys/holdout/participants.csv',
    'index_col': 14
}
local_training_plan.dataset_path = '../data/Hospital-Centers/Guys/holdout/'
val_data_manager = local_training_plan.training_data(batch_size=4)
val_data_manager._dataset.set_dataset_parameters(dataset_parameters)
val_data_loader = DataLoader(val_data_manager._dataset)

Compute the loss on validation images

In [ ]:
losses = []
local_model.eval()

import torch

with torch.no_grad():
    for (images, demographics), targets in val_data_loader:
        image = images['T1']
        target = targets['label']
        prediction = local_model.forward(image)
        loss = UNetTrainingPlan.get_dice_loss(prediction, target)
        losses.append(loss)
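
Each element of losses holds the per-class Dice loss for one subject. To obtain a single summary number, you could for instance average them (a small sketch, not part of the original notebook):

mean_dice_loss = torch.cat([loss.flatten() for loss in losses]).mean()
print(f'Mean validation Dice loss: {mean_dice_loss.item():.4f}')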

Visualize the outputs

As a bonus, we visualize the outputs of our model on the holdout dataset.

In [ ]:
one_batch = next(iter(val_data_loader))

one_batch contains both input features and labels. Both are 3D images, which can be accessed in the following way (k is the index of a slice along the last spatial dimension):

In [ ]:
k = 24
one_batch[1]['label'][..., k].shape
In [ ]:
k = 24
one_batch[0][0]['T1'][..., k].shape
In [ ]:
import matplotlib.pyplot as plt

%config InlineBackend.figure_format = 'retina'
plt.rcParams['figure.figsize'] = 12, 6
import torchvision
from IPython import display
In [ ]:
k = 24
batch_mri = one_batch[0][0]['T1'][..., k]
batch_label = one_batch[1]['label'][:, 1:, ..., k]
slices = torch.cat((batch_mri, batch_label))
image_path = 'batch_whole_images.png'
torchvision.utils.save_image(
    slices,
    image_path,
    nrow=max(val_data_loader.batch_size//2,1),
    normalize=True,
    scale_each=True,
    padding=4,
)
display.Image(image_path, width=300)