Federated 2d image classification with MONAI¶
Introduction¶
This tutorial shows how to deploy in Fed-BioMed the 2d image classification example provided in the MONAI project's MedNIST tutorial.
Since MONAI is based on PyTorch, the deployment within Fed-BioMed seamlessly follows the same general structure as for any other PyTorch model.
Following the MONAI example, this tutorial is based on the MedNIST dataset.
Creating MedNIST nodes¶
MedNIST provides an artificial 2d classification dataset created by gathering different medical imaging datasets from TCIA, the RSNA Bone Age Challenge, and the NIH Chest X-ray dataset. The dataset is kindly made available by Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) under the Creative Commons CC BY-SA 4.0 license.
To proceed with the tutorial, we created an iid partitioning of the MedNIST dataset between 3 clients. Each client has 3000 image samples for each class. You can download the training partitions from here.
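To illustrate what an iid partitioning between clients means, here is a minimal sketch in plain Python. It is not the script used to build the downloadable partitions: the helper name `iid_partition`, the toy file names and the seed are assumptions made for illustration only.

```python
import random


def iid_partition(files_by_class, n_clients, seed=0):
    """Split the files of every class evenly across clients (hypothetical helper).

    Files that do not divide evenly are dropped, so that every client
    receives the same number of samples per class.
    """
    rng = random.Random(seed)
    clients = [{} for _ in range(n_clients)]
    for class_name, files in files_by_class.items():
        shuffled = list(files)
        rng.shuffle(shuffled)
        per_client = len(shuffled) // n_clients
        for c in range(n_clients):
            clients[c][class_name] = shuffled[c * per_client:(c + 1) * per_client]
    return clients


# Toy example: 2 classes with 9 files each, split across 3 clients.
toy = {
    "CXR": [f"cxr_{i}.jpeg" for i in range(9)],
    "Hand": [f"hand_{i}.jpeg" for i in range(9)],
}
parts = iid_partition(toy, n_clients=3)
print([{name: len(files) for name, files in part.items()} for part in parts])
```

Every client ends up with the same number of images per class, which is the property the tutorial relies on.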
The dataset owned by each client has the following structure:
└── client_*/
    ├── AbdomenCT/
    ├── BreastMRI/
    ├── CXR/
    ├── ChestCT/
    ├── Hand/
    └── HeadCT/
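Before adding a client folder to a node, it can be useful to check that it follows this layout. The sketch below is a small stand-alone check. The helper `missing_classes` is hypothetical and not part of Fed-BioMed, and the temporary folder only stands in for a real `client_*` directory.

```python
import os
import tempfile

EXPECTED_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]


def missing_classes(client_dir, expected=EXPECTED_CLASSES):
    """Return the expected class folders that are absent from client_dir."""
    present = {
        name for name in os.listdir(client_dir)
        if os.path.isdir(os.path.join(client_dir, name))
    }
    return [name for name in expected if name not in present]


# Stand-in for a client folder in which the last class folder is missing.
with tempfile.TemporaryDirectory() as tmp:
    for name in EXPECTED_CLASSES[:-1]:
        os.mkdir(os.path.join(tmp, name))
    print(missing_classes(tmp))  # ['HeadCT']
```

An empty list means that all six class folders are present.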
To create the federated dataset, we follow the standard procedure for node creation/population of Fed-BioMed. After activating the fedbiomed network with the commands
$ source ./scripts/fedbiomed_environment network
and
$ ./scripts/fedbiomed_run network
we create a first node by using the commands
$ source ./scripts/fedbiomed_environment node
$ ./scripts/fedbiomed_run node start
We then populate the node with the data of the first client:
$ ./scripts/fedbiomed_run node add
We select option 3 (images) to add the MedNIST partition of client 1 by picking the folder of client 1. Assign the tag mednist to the data when asked.
We can further check that the data has been added by executing:
$ ./scripts/fedbiomed_run node list
Following the same procedure, we create the other two nodes with the datasets of client 2 and client 3 respectively.
Running Fed-BioMed Researcher¶
We are now ready to start the researcher environment with the command source ./scripts/fedbiomed_environment researcher, and open the Jupyter notebook.
We can first query the network for the mednist dataset. In this case, the nodes are sharing their respective partitions using the same tag mednist:
from fedbiomed.researcher.requests import Requests
req = Requests()
req.list(verbose=True)
2022-01-07 17:32:20,661 fedbiomed INFO - Component environment:
2022-01-07 17:32:20,662 fedbiomed INFO - - type = ComponentType.RESEARCHER
2022-01-07 17:32:21,125 fedbiomed INFO - Messaging researcher_92feca42-d5ac-4555-9db3-d676f5fec16b successfully connected to the message broker, object = <fedbiomed.common.messaging.Messaging object at 0x108bab460>
2022-01-07 17:32:21,161 fedbiomed INFO - Listing available datasets in all nodes...
2022-01-07 17:32:21,169 fedbiomed INFO - log from: node_267dfb38-d101-4107-a93e-8fa21ae92d11 / DEBUG - Message received: {'researcher_id': 'researcher_92feca42-d5ac-4555-9db3-d676f5fec16b', 'command': 'list'}
2022-01-07 17:32:21,170 fedbiomed INFO - log from: node_706bed89-d9e8-48ea-a024-94507f7b7baf / DEBUG - Message received: {'researcher_id': 'researcher_92feca42-d5ac-4555-9db3-d676f5fec16b', 'command': 'list'}
2022-01-07 17:32:21,174 fedbiomed INFO - log from: node_5d88d235-9341-4495-ad13-3d129aeaa30e / DEBUG - Message received: {'researcher_id': 'researcher_92feca42-d5ac-4555-9db3-d676f5fec16b', 'command': 'list'}
2022-01-07 17:32:31,166 fedbiomed INFO - Node: node_267dfb38-d101-4107-a93e-8fa21ae92d11 | Number of Datasets: 1
+---------+-------------+-------------+---------------+--------------------+
| name    | data_type   | tags        | description   | shape              |
+=========+=============+=============+===============+====================+
| mednist | images      | ['mednist'] | bla           | [18000, 3, 64, 64] |
+---------+-------------+-------------+---------------+--------------------+
2022-01-07 17:32:31,167 fedbiomed INFO - Node: node_706bed89-d9e8-48ea-a024-94507f7b7baf | Number of Datasets: 1
+---------+-------------+-------------+---------------+--------------------+
| name    | data_type   | tags        | description   | shape              |
+=========+=============+=============+===============+====================+
| mednist | images      | ['mednist'] | bla           | [18000, 3, 64, 64] |
+---------+-------------+-------------+---------------+--------------------+
2022-01-07 17:32:31,168 fedbiomed INFO - Node: node_5d88d235-9341-4495-ad13-3d129aeaa30e | Number of Datasets: 1
+---------+-------------+-------------+---------------+--------------------+
| name    | data_type   | tags        | description   | shape              |
+=========+=============+=============+===============+====================+
| mednist | images      | ['mednist'] | bla           | [16954, 3, 64, 64] |
+---------+-------------+-------------+---------------+--------------------+
{'node_267dfb38-d101-4107-a93e-8fa21ae92d11': [{'name': 'mednist', 'data_type': 'images', 'tags': ['mednist'], 'description': 'bla', 'shape': [18000, 3, 64, 64]}], 'node_706bed89-d9e8-48ea-a024-94507f7b7baf': [{'name': 'mednist', 'data_type': 'images', 'tags': ['mednist'], 'description': 'bla', 'shape': [18000, 3, 64, 64]}], 'node_5d88d235-9341-4495-ad13-3d129aeaa30e': [{'name': 'mednist', 'data_type': 'images', 'tags': ['mednist'], 'description': 'bla', 'shape': [16954, 3, 64, 64]}]}
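The dictionary returned by the listing can be inspected programmatically. As a minimal sketch, the helper below (hypothetical, not part of the Fed-BioMed API) sums the number of samples that each node shares under a given tag. The input mimics the structure printed above, with shortened placeholder node ids.

```python
def samples_per_node(listing, tag):
    """Count the samples shared under `tag` by every node of a listing (hypothetical helper)."""
    totals = {}
    for node_id, datasets in listing.items():
        # The first entry of 'shape' is the number of samples in the dataset.
        totals[node_id] = sum(d["shape"][0] for d in datasets if tag in d["tags"])
    return totals


# Same structure as the listing above; 'node_a', 'node_b', 'node_c' are placeholder ids.
listing = {
    "node_a": [{"name": "mednist", "tags": ["mednist"], "shape": [18000, 3, 64, 64]}],
    "node_b": [{"name": "mednist", "tags": ["mednist"], "shape": [18000, 3, 64, 64]}],
    "node_c": [{"name": "mednist", "tags": ["mednist"], "shape": [16954, 3, 64, 64]}],
}
totals = samples_per_node(listing, "mednist")
print(totals)
print(sum(totals.values()))  # 52954
```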
Create an experiment to train a model on the data found¶
The code for the network and data loader of the MONAI tutorial can now be deployed in Fed-BioMed. We first import the necessary modules from the fedbiomed and monai libraries:
from fedbiomed.researcher.environ import environ
import tempfile
import os
tmp_dir_model = tempfile.TemporaryDirectory(dir=environ['TMP_DIR']+os.sep)
model_file = os.path.join(tmp_dir_model.name, 'class_export_mednist.py')
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import decollate_batch
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet121
from monai.transforms import (
Activations,
AddChannel,
AsDiscrete,
Compose,
LoadImage,
RandFlip,
RandRotate,
RandZoom,
ScaleIntensity,
EnsureType,
)
from monai.utils import set_determinism
We can now define the training plan. Note that we can simply use the standard TorchTrainingPlan natively provided in Fed-BioMed. We reuse the MedNISTDataset class defined in the original MONAI tutorial. The method training_data parses the data found at the node's dataset_path, wraps it in a MedNISTDataset, and returns the corresponding data loader. Following the MONAI tutorial, the model is the DenseNet121.
%%writefile "$model_file"
import os
import numpy as np
import torch
import torch.nn as nn
from fedbiomed.common.torchnn import TorchTrainingPlan
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import decollate_batch
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet121
from monai.transforms import (
Activations,
AddChannel,
AsDiscrete,
Compose,
LoadImage,
RandFlip,
RandRotate,
RandZoom,
ScaleIntensity,
EnsureType,
)
from monai.utils import set_determinism
# Here we define the training plan to be used.
# You can use any class name (here 'MyTrainingPlan')
class MyTrainingPlan(TorchTrainingPlan):
    def __init__(self, model_args: dict = {}):
        super(MyTrainingPlan, self).__init__(model_args)

        # Here we define the custom dependencies that will be needed by our custom Dataloader
        # In this case, we need the torch DataLoader class
        # Since we will train on MedNIST, we also need the MONAI network, transforms and utilities
        deps = ["import numpy as np",
                "import os",
                "from torch.utils.data import DataLoader",
                "from monai.apps import download_and_extract",
                "from monai.config import print_config",
                "from monai.data import decollate_batch",
                "from monai.metrics import ROCAUCMetric",
                "from monai.networks.nets import DenseNet121",
                "from monai.transforms import ( Activations, AddChannel, AsDiscrete, Compose, LoadImage, RandFlip, RandRotate, RandZoom, ScaleIntensity, EnsureType, )",
                "from monai.utils import set_determinism",]
        self.add_dependency(deps)

        self.num_class = model_args['num_class']
        self.model = DenseNet121(spatial_dims=2, in_channels=1,
                                 out_channels = self.num_class)
        self.loss_function = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    class MedNISTDataset(torch.utils.data.Dataset):
        def __init__(self, image_files, labels, transforms):
            self.image_files = image_files
            self.labels = labels
            self.transforms = transforms

        def __len__(self):
            return len(self.image_files)

        def __getitem__(self, index):
            return self.transforms(self.image_files[index]), self.labels[index]

    def parse_data(self, path):
        # Each sub-folder of `path` is a class; its files are the images of that class
        print(self.dataset_path)
        class_names = sorted(x for x in os.listdir(path)
                             if os.path.isdir(os.path.join(path, x)))
        num_class = len(class_names)
        image_files = [
            [
                os.path.join(path, class_names[i], x)
                for x in os.listdir(os.path.join(path, class_names[i]))
            ]
            for i in range(num_class)
        ]
        return image_files, num_class

    def training_data(self, batch_size = 48):
        self.image_files, num_class = self.parse_data(self.dataset_path)

        if self.num_class!=num_class:
            raise Exception('number of available classes does not match declared classes')

        num_each = [len(self.image_files[i]) for i in range(self.num_class)]
        image_files_list = []
        image_class = []
        for i in range(self.num_class):
            image_files_list.extend(self.image_files[i])
            image_class.extend([i] * num_each[i])
        num_total = len(image_class)

        # Shuffle the samples; the whole local partition is used for training
        length = len(image_files_list)
        indices = np.arange(length)
        np.random.shuffle(indices)

        val_split = int(1. * length)
        train_indices = indices[:val_split]

        train_x = [image_files_list[i] for i in train_indices]
        train_y = [image_class[i] for i in train_indices]

        train_transforms = Compose(
            [
                LoadImage(image_only=True),
                AddChannel(),
                ScaleIntensity(),
                RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
                RandFlip(spatial_axis=0, prob=0.5),
                RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
                EnsureType(),
            ]
        )

        val_transforms = Compose(
            [LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])

        y_pred_trans = Compose([EnsureType(), Activations(softmax=True)])
        y_trans = Compose([EnsureType(), AsDiscrete(to_onehot=num_class)])

        print(
            f"Training count: {len(train_x)}")

        train_ds = self.MedNISTDataset(train_x, train_y, train_transforms)
        train_loader = torch.utils.data.DataLoader(
            train_ds, batch_size, shuffle=True)

        return train_loader

    def training_step(self, data, target):
        output = self.forward(data)
        loss = self.loss_function(output, target)
        return loss
Writing /Users/mlorenzi/works/temp/fedbiomed/var/tmp/tmpf5gbxg1r/class_export_mednist.py
We now set the model and training parameters. Note that we use only 1 epoch for this experiment, and that batch_maxnum restricts the local training to batch_maxnum × batch_size = 250 × 20 = 5000 samples, i.e. roughly 28% of an 18000-image local partition.
model_args = {'num_class':6,}
training_args = {
'batch_size': 20,
'lr': 1e-5,
'epochs': 1,
'dry_run': False,
'batch_maxnum':250 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}
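The amount of local data actually seen during training follows from these two parameters. The arithmetic can be sketched as below, assuming (as the comment on batch_maxnum states) that at most batch_maxnum batches of batch_size samples are used. The helper name is hypothetical, and the partition sizes are those reported by the nodes in the listing above.

```python
def max_samples_used(batch_size, batch_maxnum, n_local):
    """Upper bound on the samples used for training on one node (hypothetical helper)."""
    return min(batch_size * batch_maxnum, n_local)


for n_local in (18000, 16954):
    used = max_samples_used(batch_size=20, batch_maxnum=250, n_local=n_local)
    print(f"{used} of {n_local} samples ({used / n_local:.1%})")
```

Reducing batch_maxnum from 250 to 25 lowers this bound to 500 samples per node.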
The experiment can now be defined by providing the mednist tag and running the local training on the nodes with the model defined in model_path, the standard aggregator (FedAvg) and the default node_selection_strategy (all nodes used). Federated learning is going to be performed through 3 optimization rounds.
For running this experiment, you need a computer with the following specifications
- more than 16 GB of RAM
- 2.5 GHz processor or higher, with at least 4 cores
If your computer's specifications are lower, you can reduce the amount of data used when training the model (set batch_maxnum from 250 to 25) and the number of rounds (from 3 to 1), but model performance may decrease dramatically.
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage
tags = ['mednist']
rounds = 3
exp = Experiment(tags=tags,
                 model_path=model_file,
                 model_args=model_args,
                 model_class='MyTrainingPlan',
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None
                 )
2022-01-07 17:32:33,970 fedbiomed INFO - Searching dataset with data tags: ['mednist'] for all nodes
2022-01-07 17:32:33,978 fedbiomed INFO - log from: node_706bed89-d9e8-48ea-a024-94507f7b7baf / DEBUG - Message received: {'researcher_id': 'researcher_92feca42-d5ac-4555-9db3-d676f5fec16b', 'tags': ['mednist'], 'command': 'search'}
2022-01-07 17:32:33,979 fedbiomed INFO - log from: node_267dfb38-d101-4107-a93e-8fa21ae92d11 / DEBUG - Message received: {'researcher_id': 'researcher_92feca42-d5ac-4555-9db3-d676f5fec16b', 'tags': ['mednist'], 'command': 'search'}
2022-01-07 17:32:33,980 fedbiomed INFO - log from: node_5d88d235-9341-4495-ad13-3d129aeaa30e / DEBUG - Message received: {'researcher_id': 'researcher_92feca42-d5ac-4555-9db3-d676f5fec16b', 'tags': ['mednist'], 'command': 'search'}
2022-01-07 17:32:43,975 fedbiomed INFO - Node selected for training -> node_706bed89-d9e8-48ea-a024-94507f7b7baf
2022-01-07 17:32:43,975 fedbiomed INFO - Node selected for training -> node_267dfb38-d101-4107-a93e-8fa21ae92d11
2022-01-07 17:32:43,976 fedbiomed INFO - Node selected for training -> node_5d88d235-9341-4495-ad13-3d129aeaa30e
2022-01-07 17:32:43,977 fedbiomed INFO - Checking data quality of federated datasets...
2022-01-07 17:32:44,134 fedbiomed DEBUG - torchnn saved model filename: /Users/mlorenzi/works/temp/fedbiomed/var/experiments/Experiment_0002/my_model_18bf1810-e08f-4393-b705-819dde05205c.py
Let's start the experiment.
By default, this function doesn't stop until all the round_limit rounds are done for all the clients.
exp.run()
Testing¶
Once the federated model is obtained, it is possible to test it locally on an independent testing partition. The test dataset is available at this link:
https://drive.google.com/file/d/1YbwA0WitMoucoIa_Qao7IC1haPfDp-XD/
!pip install gdown
import os
import shutil
import tempfile
import PIL
import torch
import numpy as np
from sklearn.metrics import classification_report
from monai.config import print_config
from monai.data import decollate_batch
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet121
import zipfile
from monai.transforms import (
Activations,
AddChannel,
AsDiscrete,
Compose,
LoadImage,
RandFlip,
RandRotate,
RandZoom,
ScaleIntensity,
EnsureType,
)
from monai.utils import set_determinism
print_config()
MONAI version: 0.8.0
Numpy version: 1.22.0
Pytorch version: 1.8.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 714d00dffe6653e21260160666c4c201ab66511b
Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: 9.0.0
Tensorboard version: 2.7.0
gdown version: 4.2.0
TorchVision version: 0.9.1
tqdm version: 4.62.3
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: 1.3.5
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
For details about installing the optional dependencies, please visit:
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
Download the testing dataset to the local temporary folder.
import gdown
import zipfile
resource = "https://drive.google.com/uc?id=1YbwA0WitMoucoIa_Qao7IC1haPfDp-XD"
base_dir = tmp_dir_model.name
test_file = os.path.join(base_dir, "MedNIST_testing.zip")
gdown.download(resource, test_file, quiet=False)
zf = zipfile.ZipFile(test_file)
for file in zf.infolist():
    zf.extract(file, base_dir)
data_dir = os.path.join(base_dir, "MedNIST_testing")
Downloading...
From: https://drive.google.com/uc?id=1YbwA0WitMoucoIa_Qao7IC1haPfDp-XD
To: /Users/mlorenzi/works/temp/fedbiomed/var/tmp/tmpf5gbxg1r/MedNIST_testing.zip
100%|██████████| 9.50M/9.50M [00:00<00:00, 35.1MB/s]
Parse the data and create the testing data loader:
class_names = sorted(x for x in os.listdir(data_dir)
                     if os.path.isdir(os.path.join(data_dir, x)))
num_class = len(class_names)
image_files = [
    [
        os.path.join(data_dir, class_names[i], x)
        for x in os.listdir(os.path.join(data_dir, class_names[i]))
    ]
    for i in range(num_class)
]
num_each = [len(image_files[i]) for i in range(num_class)]
image_files_list = []
image_class = []
for i in range(num_class):
    image_files_list.extend(image_files[i])
    image_class.extend([i] * num_each[i])
num_total = len(image_class)
image_width, image_height = PIL.Image.open(image_files_list[0]).size
print(f"Total image count: {num_total}")
print(f"Image dimensions: {image_width} x {image_height}")
print(f"Label names: {class_names}")
print(f"Label counts: {num_each}")
Total image count: 6000
Image dimensions: 64 x 64
Label names: ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']
Label counts: [1000, 1000, 1000, 1000, 1000, 1000]
length = len(image_files_list)
indices = np.arange(length)
np.random.shuffle(indices)
test_split = int(0.1 * length)
test_indices = indices[:test_split]
test_x = [image_files_list[i] for i in test_indices]
test_y = [image_class[i] for i in test_indices]
val_transforms = Compose(
    [LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
y_pred_trans = Compose([EnsureType(), Activations(softmax=True)])
y_trans = Compose([EnsureType(), AsDiscrete(to_onehot=num_class)])

class MedNISTDataset(torch.utils.data.Dataset):
    def __init__(self, image_files, labels, transforms):
        self.image_files = image_files
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        return self.transforms(self.image_files[index]), self.labels[index]

test_ds = MedNISTDataset(test_x, test_y, val_transforms)
test_loader = torch.utils.data.DataLoader(
    test_ds, batch_size=300)
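Note that the 10% test subset selected above comes from an unseeded shuffle, so it differs from run to run. If a reproducible subset is preferred, one option is to draw the indices from a seeded generator. The sketch below uses only the Python standard library; the helper name and the seed value are arbitrary assumptions, not part of the original tutorial.

```python
import random


def sample_test_indices(n_items, fraction, seed):
    """Return a reproducible random subset of indices (hypothetical helper)."""
    rng = random.Random(seed)
    indices = list(range(n_items))
    rng.shuffle(indices)
    return indices[: int(fraction * n_items)]


first = sample_test_indices(6000, 0.1, seed=0)
second = sample_test_indices(6000, 0.1, seed=0)
print(len(first), first == second)  # 600 True
```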
Define testing metric:
auc_metric = ROCAUCMetric()
To test the federated model we need to create a model instance and assign to it the model parameters estimated at the last federated optimization round.
model = exp.model_instance()
model.load_state_dict(exp.aggregated_params()[rounds - 1]['params'])
<All keys matched successfully>
Compute the testing performance:
y_true = []
y_pred = []
with torch.no_grad():
    for test_data in test_loader:
        test_images, test_labels = (
            test_data[0],
            test_data[1],
        )
        pred = model(test_images).argmax(dim=1)
        for i in range(len(pred)):
            y_true.append(test_labels[i].item())
            y_pred.append(pred[i].item())
print(classification_report(
    y_true, y_pred, target_names=class_names, digits=4))
              precision    recall  f1-score   support

   AbdomenCT     1.0000    0.9901    0.9950       101
   BreastMRI     1.0000    1.0000    1.0000        95
         CXR     1.0000    1.0000    1.0000       105
     ChestCT     0.9890    1.0000    0.9945        90
        Hand     1.0000    1.0000    1.0000        99
      HeadCT     1.0000    1.0000    1.0000       110

    accuracy                         0.9983       600
   macro avg     0.9982    0.9983    0.9983       600
weighted avg     0.9984    0.9983    0.9983       600
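The figures in this report are consistent with a single error on the 600 test images: one AbdomenCT image predicted as ChestCT (recall 100/101 for AbdomenCT, precision 90/91 for ChestCT). The sketch below recomputes the two affected rows from toy label lists built under that reading. The helper `per_class_metrics` is a plain-Python illustration, not the scikit-learn implementation.

```python
def per_class_metrics(y_true, y_pred, label):
    """Precision, recall and F1-score of one class, computed from label lists."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Toy labels: 101 AbdomenCT images (class 0), one of them predicted as
# ChestCT (class 3), and 90 ChestCT images that are all predicted correctly.
toy_true = [0] * 101 + [3] * 90
toy_pred = [0] * 100 + [3] + [3] * 90

for label, name in ((0, "AbdomenCT"), (3, "ChestCT")):
    precision, recall, f1 = per_class_metrics(toy_true, toy_pred, label)
    print(f"{name}: precision={precision:.4f} recall={recall:.4f} f1={f1:.4f}")
```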
In spite of the relatively small training performed on the data shared in the 3 nodes, the performance of the federated model seems pretty good. Well done!