The Training Plan

A training plan is a class that defines the federated model training. It is responsible for providing the base methods which allow every node to perform the training process. In Fed-BioMed, each ML framework comes with its own training plan, which is inherited by the child class that you create for your custom model. For instance, PyTorch's training plan requires inheriting from torch.nn.Module, while the training plan for Scikit-Learn only needs the dependencies required by the method that is going to be used for training. Their training processes also differ. Therefore, Fed-BioMed provides a different training plan for each supported ML framework.

The constructor method (__init__)

The constructor method of the training plan allows you to initialize your model class with the attributes that your model needs. For instance, if you are working with PyTorch, you can define your layers in the constructor just as you would for a regular PyTorch model. Since your model will be trained on different nodes, the constructor also receives the model arguments that the experiment sends to the nodes. You can therefore use your custom arguments to define your model. An example constructor for PyTorch is shown below.

class MyTrainingPlan(TorchTrainingPlan):       
    def __init__(self, kwargs):
        super(MyTrainingPlan, self).__init__()
        # kwargs should match the model arguments to be passed below to the experiment class
        self.in_features = kwargs['in_features']
        self.out_features = kwargs['out_features']
        self.fc1 = nn.Linear(self.in_features, 5)
        self.fc2 = nn.Linear(5, self.out_features)
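
On the researcher side, the matching model arguments could be declared as follows (a minimal sketch; the key names must match those read in __init__ above, and the values are illustrative):

model_args = {'in_features': 15, 'out_features': 2}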

forward

The forward method is where the feed-forward operation is defined. It has to be defined by the researcher so that input data can be passed through the network layers and the output returned. It is required for model training with the PyTorch framework.
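
For instance, a forward method matching the constructor above could look like the following sketch (the activation and output functions are illustrative choices; log_softmax pairs with the nll_loss used in the next section):

    def forward(self, x):
        # pass the input through the layers defined in the constructor
        x = torch.nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        # return log-probabilities, to be consumed by nll_loss in training_step
        return torch.nn.functional.log_softmax(x, dim=1)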

training_step

The training_step method of the training plan defines how the cost is computed: input values are forwarded through the network and the loss function is applied. It should return the model error, which is then used for backpropagation through the neural network. By default, it is not defined in the parent TrainingPlan class: like the forward method, it should be defined by the researcher in their model class. An example training step for PyTorch is shown below.

    def training_step(self, data, target):
        output = self.forward(data)
        loss   = torch.nn.functional.nll_loss(output, target)
        return loss

training_data

The training_data method defines how datasets are loaded on the nodes so that they are ready for the training process. By default, this method is only a reference implementation, and it needs to be overridden in the child class that you create. This is because the datasets used for model training vary depending on the problem.

In the following code snippet, taken from the training routine, you can see that the routine calls the training_data method to obtain the data used in each training step.

    # .....

    training_data = self.training_data(batch_size=batch_size)
    for batch_idx, (data, target) in enumerate(training_data):
        self.train()
        data, target = data.to(self.device), target.to(self.device)
        self.optimizer.zero_grad()
        res = self.training_step(data, target)

        # .... 

You can read the documentation on "training data" to learn more about various use cases.
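
As an illustration, below is a minimal sketch of a training_data override for MNIST; it assumes that the node exposes the dataset location as self.dataset_path and that the torchvision and DataLoader dependencies have been declared (see the add_dependency example later in this article):

    def training_data(self, batch_size=48):
        # assumed: self.dataset_path is set by the node to the MNIST folder
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])
        mnist_dataset = datasets.MNIST(self.dataset_path, train=True,
                                       download=False, transform=transform)
        # return an iterable of (data, target) batches, as expected by the training routine
        return DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)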

training_routine

The training routine is the heart of the training plan. This method orchestrates the model training on each node, based on the given model and training arguments. For example, if the model is a neural network based on the PyTorch framework, the training routine performs the training loop over epochs and batches, just as in standard PyTorch training. For a Scikit-Learn model, it simply fits the model with the given ML method. The training routine is executed by the nodes after they have built the model class in their workspace.

As you can see from the following code snippet, the training routine accepts training arguments such as epochs, lr and batch_size. Since the training_routine is already defined by Fed-BioMed, you can only control the training process by changing these arguments. Modifying the training routine from the model class that inherits from the training plan might raise unexpected errors, so it is recommended to stick to the predefined arguments. These arguments are passed to the nodes by the experiment through the network. The next article gives more details about Fed-BioMed's experiment.

    def training_routine(self,
                         epochs: int = 2,
                         log_interval: int = 10,
                         lr: Union[int, float] = 1e-3,
                         batch_size: int = 48,
                         batch_maxnum: int = 0,
                         dry_run: bool = False,
                         logger=None):

        # You can see the details in `fedbiomed.common.torchnn`
        # .....

        for epoch in range(1, epochs + 1):
            training_data = self.training_data(batch_size=batch_size)
            for batch_idx, (data, target) in enumerate(training_data):
                self.train() 
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                res = self.training_step(data, target)
                res.backward()
                self.optimizer.step()

                #.....
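
On the researcher side, these training arguments are typically gathered in a dictionary that the experiment sends to the nodes. A minimal sketch (the values are illustrative):

training_args = {
    'batch_size': 48,
    'lr': 1e-3,
    'epochs': 2,
    'dry_run': False,
    'batch_maxnum': 100  # process at most 100 batches per epoch, useful for quick tests
}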

Method for Adding Dependencies

The add_dependency method allows you to include the module dependencies needed by your custom training model. It is defined in the parent TrainingPlan class for both PyTorch and Scikit-Learn. The dependencies to be added are passed as a list.

# Provided by Fed-BioMed; necessary to save the model code into a file
def add_dependency(self, dep: List[str]):
    """Adds extra Python import(s).

    Args:
        dep (List[str]): package import statements, e.g. 'import torch as th'
    """
    self.dependencies.extend(dep)

Here is how you can use this method in the constructor of your custom training plan.

class MyTrainingPlan(TorchTrainingPlan):
    def __init__(self):
        super(MyTrainingPlan, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

        # Here we define the custom dependencies that will be needed by our custom Dataloader
        # In this case, we need the torch DataLoader classes
        # Since we will train on MNIST, we need datasets and transform from torchvision
        deps = ["from torchvision import datasets, transforms",
                "from torch.utils.data import DataLoader"]
        self.add_dependency(deps)

Saving and Loading the Model

Each training plan provides save and load functionality. These methods are required for saving and loading model parameters to and from the file system during training, on both the node and the researcher side; they are what allows the experiment to upload and download model parameters. They are provided by Fed-BioMed in each framework-specific training plan, since each framework has its own way of saving and loading models, and you may need to define or extend them when creating your own training plan (model class). The save and load methods of the TorchTrainingPlan are shown below.

    # provided by Fed-BioMed
    def save(self, filename, params: dict = None) -> None:
        # save the given parameters if provided, otherwise the current model state
        if params is not None:
            return torch.save(params, filename)
        else:
            return torch.save(self.state_dict(), filename)

    # provided by Fed-BioMed
    def load(self, filename: str, to_params: bool = False) -> dict:
        # load parameters from file and, unless to_params is True,
        # also update the model's state with them
        params = torch.load(filename)
        if to_params is False:
            self.load_state_dict(params)
        return params
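
As a quick usage sketch (the filename is arbitrary), assuming a training plan like the one defined in the add_dependency example above:

    plan = MyTrainingPlan()
    plan.save('my_model.pt')                            # persists plan.state_dict()
    params = plan.load('my_model.pt', to_params=True)   # returns the parameters without loading them into the model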

For Scikit-Learn:

    # provided by Fed-BioMed (dump comes from joblib)
    def save(self, filename, params: dict = None):
        file = open(filename, "wb")
        if params is None:
            # no parameters given: persist the current model as-is
            dump(self.m, file)
        else:
            # update the model attributes with the given parameters, then persist
            if params.get('model_params') is not None:  # called in the Round
                for p in params['model_params'].keys():
                    setattr(self.m, p, params['model_params'][p])
            else:
                for p in params.keys():
                    setattr(self.m, p, params[p])
            dump(self.m, file)
        file.close()

    # provided by Fed-BioMed (load comes from joblib)
    def load(self, filename, to_params: bool = False):
        di_ret = {}
        file = open(filename, "rb")
        # restore the persisted model
        self.m = load(file)
        if not to_params:
            di_ret = self.m
        else:
            # return only the parameter values listed in self.param_list
            di_ret['model_params'] = {key: getattr(self.m, key) for key in self.param_list}
        file.close()
        return di_ret

You can find these classes in the fedbiomed/common directory if you want to explore them in more detail.