fedbiomed.researcher.aggregators

Module: fedbiomed.researcher.aggregators

Classes

Aggregator

CLASS
Aggregator()

Defines the methods required by a model aggregation strategy (e.g. FedAvg, FedProx, SCAFFOLD, ...).

Source code in fedbiomed/researcher/aggregators/aggregator.py
def __init__(self):
    self._aggregator_args: dict = None
    self._fds: FederatedDataSet = None
    self._training_plan_type: TrainingPlans = None
    self._secagg_crypter = SecaggCrypter()

Functions

aggregate(model_params, weights, args, kwargs)

Strategy to aggregate models

Parameters:

Name Type Description Default
model_params list

List of model parameters received from each node

required
weights list

Weight for each node-model-parameter set

required

Raises:

Type Description
FedbiomedAggregatorError

If the method is not overridden by the inheriting class

Source code in fedbiomed/researcher/aggregators/aggregator.py
def aggregate(self, model_params: list, weights: list, *args, **kwargs) -> Dict:
    """
    Strategy to aggregate models

    Args:
        model_params: List of model parameters received from each node
        weights: Weight for each node-model-parameter set

    Raises:
        FedbiomedAggregatorError: If the method is not overridden by the inheriting class
    """
    msg = ErrorNumbers.FB401.value + \
        ": aggregate method should be overloaded by the chosen strategy"
    logger.critical(msg)
    raise FedbiomedAggregatorError(msg)
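
Since the base method only raises, every concrete strategy must override it. Below is a minimal sketch of a custom aggregator (the class name and the unweighted-mean logic are illustrative, not part of the library; it assumes model_params is a list of per-node parameter dicts holding numpy arrays):

import numpy as np

from fedbiomed.researcher.aggregators import Aggregator


class MeanAggregator(Aggregator):
    """Toy strategy: plain (unweighted) average of the nodes' parameters."""

    def aggregate(self, model_params: list, weights: list, *args, **kwargs) -> dict:
        # Average each parameter entry across all received node models.
        return {
            key: sum(params[key] for params in model_params) / len(model_params)
            for key in model_params[0]
        }
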
check_values(args, kwargs)
Source code in fedbiomed/researcher/aggregators/aggregator.py
def check_values(self, *args, **kwargs) -> True:
    return True
create_aggregator_args(args, kwargs)

Returns aggregator arguments that are expected by the nodes

Returns:

Name Type Description
dict dict

contains `Aggregator` parameters that will be sent through the MQTT message service

dict dict

contains parameters that will be sent through the file exchange message. Both dictionaries map node_id to `Aggregator` parameters specific to each Node.

Source code in fedbiomed/researcher/aggregators/aggregator.py
def create_aggregator_args(self, *args, **kwargs) -> Tuple[dict, dict]:
    """Returns aggregator arguments that are expecting by the nodes

    Returns:
        dict: contains `Aggregator` parameters that will be sent through MQTT message
                service
        dict: contains parameters that will be sent through file exchange message.
                Both dictionaries are mapping node_id to `Aggregator` parameters specific
                to each Node.
    """
    return self._aggregator_args or {}, {}
load_state(state, kwargs)

Used for breakpoints; loads the aggregator state.

Source code in fedbiomed/researcher/aggregators/aggregator.py
def load_state(self, state: Dict[str, Any], **kwargs) -> None:
    """
    use for breakpoints. load the aggregator state
    """
    self._aggregator_args = state['parameters']
save_state(breakpoint_path=None, aggregator_args_create)

Used for breakpoints; saves the aggregator state.

Source code in fedbiomed/researcher/aggregators/aggregator.py
def save_state(
    self,
    breakpoint_path: Optional[str] = None,
    **aggregator_args_create: Any,
) -> Dict[str, Any]:
    """
    use for breakpoints. save the aggregator state
    """
    aggregator_args_thr_msg, aggregator_args_thr_files = self.create_aggregator_args(**aggregator_args_create)
    if aggregator_args_thr_msg:
        if self._aggregator_args is None:
            self._aggregator_args = {}
        self._aggregator_args.update(aggregator_args_thr_msg)
        # aggregator_args = copy.deepcopy(self._aggregator_args)
        if breakpoint_path is not None and aggregator_args_thr_files:
            for node_id, node_arg in aggregator_args_thr_files.items():
                if isinstance(node_arg, dict):

                    for arg_name, aggregator_arg in node_arg.items():
                        if arg_name != 'aggregator_name': # do not save `aggregator_name` as a file
                            filename = self._save_arg_to_file(breakpoint_path, arg_name, node_id, aggregator_arg)
                            self._aggregator_args.setdefault(arg_name, {})


                            self._aggregator_args[arg_name][node_id] = filename  # replacing value by a path towards a file
                else:
                    filename = self._save_arg_to_file(breakpoint_path, arg_name, node_id, node_arg)
                    self._aggregator_args[arg_name] = filename
    state = {
        "class": type(self).__name__,
        "module": self.__module__,
        "parameters": self._aggregator_args
    }
    return state
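
As a minimal usage sketch (assuming FedAverage is importable from this module; the field values shown are illustrative), the returned state is what gets written into a breakpoint and later fed back to load_state:

from fedbiomed.researcher.aggregators import FedAverage

agg = FedAverage()
state = agg.save_state()
# state has the form:
# {"class": "FedAverage",
#  "module": "fedbiomed.researcher.aggregators.fedavg",
#  "parameters": None}   # plain FedAverage carries no aggregator args
agg.load_state(state)  # restores the aggregator args from state["parameters"]
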
secure_aggregation(params, encryption_factors, secagg_random, aggregation_round, total_sample_size, training_plan)

Apply aggregation for encrypted model parameters

Parameters:

Name Type Description Default
params List[List[int]]

List containing list of encrypted parameters of each node

required
encryption_factors List[Dict[str, List[int]]]

List of encrypted integers to validate encryption

required
secagg_random float

Randomly generated float value to validate secure aggregation correctness

required
aggregation_round int

The round of the aggregation.

required
total_sample_size int

Sum of sample sizes used for training

required
training_plan BaseTrainingPlan

Training plan instance used for the training.

required

Returns:

Type Description

aggregated model parameters

Source code in fedbiomed/researcher/aggregators/aggregator.py
def secure_aggregation(
        self,
        params: List[List[int]],
        encryption_factors: List[Dict[str, List[int]]],
        secagg_random: float,
        aggregation_round: int,
        total_sample_size: int,
        training_plan: 'BaseTrainingPlan'
):
    """ Apply aggregation for encrypted model parameters

    Args:
        params: List containing list of encrypted parameters of each node
        encryption_factors: List of encrypted integers to validate encryption
        secagg_random: Randomly generated float value to validate secure aggregation correctness
        aggregation_round: The round of the aggregation.
        total_sample_size: Sum of sample sizes used for training
        training_plan: Training plan instance used for the training.

    Returns:
        aggregated model parameters
    """

    # TODO: verify with secagg context number of parties
    num_nodes = len(params)

    # TODO: Use server key here
    key = -(len(params) * 10)

    # IMPORTANT = Keep this key for testing purposes
    key = -4521514305280526329525552501850970498079782904248225896786295610941010325354834129826500373412436986239012584207113747347251251180530850751209537684586944643780840182990869969844131477709433555348941386442841023261287875379985666260596635843322044109172782411303407030194453287409138194338286254652273563418119335656859169132074431378389356392955315045979603414700450628308979043208779867835835935403213000649039155952076869962677675951924910959437120608553858253906942559260892494214955907017206115207769238347962438107202114814163305602442458693305475834199715587932463252324681290310458316249381037969151400784780
    logger.info("Securely aggregating model parameters...")

    aggregate = functools.partial(self._secagg_crypter.aggregate,
                                  current_round=aggregation_round,
                                  num_nodes=num_nodes,
                                  key=key,
                                  total_sample_size=total_sample_size
                                  )
    # Validation
    encryption_factors = [f for k, f in encryption_factors.items()]
    validation: List[int] = aggregate(params=encryption_factors)

    if len(validation) != 1 or not math.isclose(validation[0], secagg_random, abs_tol=0.01):
        raise FedbiomedAggregatorError("Aggregation failed due to incorrect decryption.")

    aggregated_params = aggregate(params=params)

    # Convert model params
    model = training_plan._model

    model_params = model.unflatten(aggregated_params)

    return model_params
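
For intuition, the validation step can be reproduced with plain Python: the decrypted aggregate of the nodes' encryption factors must land, within tolerance, on the shared random check value. The numbers below are illustrative stand-ins for the encrypted payloads:

import math

secagg_random = 0.42   # random check value shared with the nodes
validation = [0.4205]  # stand-in for aggregate(params=encryption_factors)

assert len(validation) == 1 and math.isclose(validation[0], secagg_random, abs_tol=0.01)
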
set_fds(fds)
Source code in fedbiomed/researcher/aggregators/aggregator.py
def set_fds(self, fds: FederatedDataSet) -> FederatedDataSet:
    self._fds = fds
    return self._fds
set_training_plan_type(training_plan_type)
Source code in fedbiomed/researcher/aggregators/aggregator.py
def set_training_plan_type(self, training_plan_type: TrainingPlans) -> TrainingPlans:
    self._training_plan_type = training_plan_type
    return self._training_plan_type

FedAverage

CLASS
FedAverage()

Bases: Aggregator

Defines the Federated averaging strategy

Source code in fedbiomed/researcher/aggregators/fedavg.py
def __init__(self):
    """Construct `FedAverage` object as an instance of [`Aggregator`]
    [fedbiomed.researcher.aggregators.Aggregator].
    """
    super(FedAverage, self).__init__()
    self.aggregator_name = "FedAverage"

Attributes

aggregator_name instance-attribute
aggregator_name = 'FedAverage'

Functions

aggregate(model_params, weights, args, kwargs)

Aggregates local models sent by participating nodes into a global model, following the Federated Averaging strategy.

weights is a dictionary mapping each node id to its aggregation weight; model_params is a dictionary mapping each node id to a framework-specific representation of that node's model parameters.

Parameters:

Name Type Description Default
model_params Dict[str, Dict[str, Union[torch.Tensor, numpy.ndarray]]]

maps each node id to that node's model parameters (model layers)

required
weights Dict[str, float]

maps each node id to its aggregation weight.

required

Returns:

Type Description
Mapping[str, Union[torch.Tensor, numpy.ndarray]]

Aggregated parameters

Source code in fedbiomed/researcher/aggregators/fedavg.py
def aggregate(
        self,
        model_params: Dict[str, Dict[str, Union['torch.Tensor', 'numpy.ndarray']]],
        weights: Dict[str, float],
        *args,
        **kwargs
) -> Mapping[str, Union['torch.Tensor', 'numpy.ndarray']]:
    """ Aggregates  local models sent by participating nodes into a global model, following Federated Averaging
    strategy.

    weights is a list of single-item dictionaries, each dictionary has the node id as key, and the weight as value.
    model_params is a list of single-item dictionaries, each dictionary has the node is as key,
    and a framework-specific representation of the model parameters as value.

    Args:
        model_params: maps each node id to that node's model parameters (model layers)
        weights: maps each node id to its aggregation weight.

    Returns:
        Aggregated parameters
    """

    model_params_processed = []
    weights_processed = []

    for node_id, params in model_params.items():

        if node_id not in weights:
            raise FedbiomedAggregatorError(
                f"{ErrorNumbers.FB401.value}. Can not find corresponding calculated weight for the "
                f"node {node_id}. Aggregation is aborted."
            )

        weight = weights[node_id]
        model_params_processed.append(params)
        weights_processed.append(weight)

    if any([x < 0. or x > 1. for x in weights_processed]) or sum(weights_processed) == 0:
        raise FedbiomedAggregatorError(
            f"{ErrorNumbers.FB401.value}. Aggregation aborted due to sum of the weights is equal to 0 {weights}. "
            f"Sample sizes received from nodes might be corrupted."
        )

    agg_params = federated_averaging(model_params_processed, weights_processed)

    return agg_params
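
A short usage sketch with numpy parameters (node ids and values are illustrative; FedAverage is assumed importable from this module):

import numpy as np

from fedbiomed.researcher.aggregators import FedAverage

model_params = {
    "node-1": {"w": np.array([1.0, 2.0])},
    "node-2": {"w": np.array([3.0, 4.0])},
}
weights = {"node-1": 0.25, "node-2": 0.75}

agg_params = FedAverage().aggregate(model_params, weights)
# agg_params["w"] == 0.25 * [1, 2] + 0.75 * [3, 4] == [2.5, 3.5]
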

Scaffold

CLASS
Scaffold(server_lr=1.0, fds=None)

Bases: Aggregator

Defines the Scaffold strategy

Despite being an algorithm of choice for federated learning, it is observed that FedAvg suffers from client-drift when the data is heterogeneous (non-iid), resulting in unstable and slow convergence. SCAFFOLD uses control variates (variance reduction) to correct for the client-drift in its local updates. Intuitively, SCAFFOLD estimates the update direction for the server model (c) and the update direction for each client (c_i). The difference (c - c_i) is then an estimate of the client-drift which is used to correct the local update.

Fed-BioMed implementation details

Our implementation is heavily influenced by our design choice to prevent storing any state on the nodes between FL rounds. In particular, this means that the computation of the control variates (i.e. the correction states) needs to be performed centrally by the aggregator. Roughly, our implementation follows these steps (following the notation of the original Scaffold paper):

  1. let \(\delta_i = \mathbf{c}_i - \mathbf{c} \)
  2. foreach(round):
  3. sample \( S \) nodes participating in this round out of \( N \) total
  4. the server communicates the global model \( \mathbf{x} \) and the correction states \( \delta_i \) to all clients
  5. parallel on each client
  6. initialize local model \( \mathbf{y}_i = \mathbf{x} \)
  7. foreach(update) until K updates have been performed
  8. obtain a data batch
  9. compute the gradients for this batch \( g(\mathbf{y}_i) \)
  10. apply correction term to gradients \( g(\mathbf{y}_i) -= \delta_i \)
  11. update model with one optimizer step e.g. for SGD \( \mathbf{y}_i -= \eta_i g(\mathbf{y}_i) \)
  12. end foreach(update)
  13. communicate updated model \( \mathbf{y}_i \) and learning rate \( \eta_i \)
  14. end parallel section on each client
  15. the server computes the node-wise model update \( \mathbf{\Delta y}_i = \mathbf{x} - \mathbf{y}_i \)
  16. the server updates the node-wise states \( \mathbf{c}_i = \delta_i + (\mathbf{\Delta y}_i) / (\eta_i K) \)
  17. the server updates the global state \( \mathbf{c} = (1/N) \sum_{i \in N} \mathbf{c}_i \)
  18. the server updates the node-wise correction state \(\delta_i = \mathbf{c}_i - \mathbf{c} \)
  19. the server updates the global model by averaging \( \mathbf{x} = \mathbf{x} - (\eta/|S|) \sum_{i \in S} \mathbf{\Delta y}_i \)
  20. end foreach(round)

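A small numeric walk-through of steps 15-18 for a single node (illustrative values: one scalar parameter, K = 1 local update, eta_i = 0.1):

import numpy as np

x = np.array([1.0])        # global model
y_i = np.array([0.6])      # node i's model after local training
delta_i = np.array([0.0])  # correction state from the previous round (zeros at round 0)
eta_i, K = 0.1, 1

update_i = x - y_i                      # step 15: 0.4
c_i = delta_i + update_i / (eta_i * K)  # step 16: 4.0
c = c_i                                 # step 17: with a single node, c equals c_i
delta_i = c_i - c                       # step 18: 0.0, the correction vanishes when N = 1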

References:

SCAFFOLD: Stochastic Controlled Averaging for Federated Learning (https://arxiv.org/abs/1910.06378)

Attributes:

Name Type Description
aggregator_name str

name of the aggregator

server_lr float

value of the server learning rate

global_state Dict[str, Union[torch.Tensor, np.ndarray]]

a dictionary representing the global correction state \( \mathbf{c} \) in the format {parameter name: correction value}

nodes_states Dict[str, Dict[str, Union[torch.Tensor, np.ndarray]]]

a nested dictionary of correction parameters obtained for each client, in the format {node id: node-wise corrections}. The node-wise corrections are a dictionary in the format {parameter name: correction value} where the model parameters are those contained in each node's model.named_parameters().

nodes_deltas Dict[str, Dict[str, Union[torch.Tensor, np.ndarray]]]

a nested dictionary of deltas for each client, in the same format as nodes_states. The deltas are defined as \(\delta_i = \mathbf{c}_i - \mathbf{c} \)

nodes_lr Dict[str, List[float]]

dictionary of learning rates observed at end of the latest round, in the format {node id: learning rate}

Parameters:

Name Type Description Default
server_lr float

server's (or Researcher's) learning rate. Defaults to 1.0.

1.0
fds FederatedDataset

FederatedDataset obtained after a search request. Defaults to None.

None
Source code in fedbiomed/researcher/aggregators/scaffold.py
def __init__(self, server_lr: float = 1., fds: Optional[FederatedDataSet] = None):
    """Constructs `Scaffold` object as an instance of [`Aggregator`]
    [fedbiomed.researcher.aggregators.Aggregator].

    Args:
        server_lr (float): server's (or Researcher's) learning rate. Defaults to 1.0.
        fds (FederatedDataset, optional): FederatedDataset obtained after a `search` request. Defaults to None.

    """
    super().__init__()
    self.aggregator_name: str = "Scaffold"
    if server_lr == 0.:
        raise FedbiomedAggregatorError("SCAFFOLD Error: Server learning rate cannot be equal to 0")
    self.server_lr: float = server_lr
    self.global_state: Dict[str, Union[torch.Tensor, np.ndarray]] = {}
    self.nodes_states: Dict[str, Dict[str, Union[torch.Tensor, np.ndarray]]] = {}
    self.nodes_deltas: Dict[str, Dict[str, Union[torch.Tensor, np.ndarray]]] = {}
    self.nodes_lr: Dict[str, List[float]] = {}
    if fds is not None:
        self.set_fds(fds)
    self._aggregator_args = {}  # we need `_aggregator_args` to be not None

Attributes

aggregator_name instance-attribute
aggregator_name: str = 'Scaffold'
global_state instance-attribute
global_state: Dict[
    str, Union[torch.Tensor, np.ndarray]
] = {}
nodes_deltas instance-attribute
nodes_deltas: Dict[
    str, Dict[str, Union[torch.Tensor, np.ndarray]]
] = {}
nodes_lr instance-attribute
nodes_lr: Dict[str, List[float]] = {}
nodes_states instance-attribute
nodes_states: Dict[
    str, Dict[str, Union[torch.Tensor, np.ndarray]]
] = {}
server_lr instance-attribute
server_lr: float = server_lr

Functions

aggregate(model_params, weights, global_model, training_plan, training_replies, n_updates=1, n_round=0, args, kwargs)

Aggregates local models coming from nodes into a global model, using the SCAFFOLD algorithm (2nd option) from Scaffold: Stochastic Controlled Averaging for Federated Learning (https://arxiv.org/abs/1910.06378)

Performed computations:
  • Compute participating nodes' model update:
    • update_i = y_i - x
  • Compute aggregated model parameters:
    • x(+) = x - (eta_g / |S|) * sum_S(update_i)
  • Update participating nodes' state:
    • c_i = delta_i + 1/(K*eta_i) * update_i
  • Update the global state and all nodes' correction state:
    • c = 1/N sum_{i=1}^n c_i
    • delta_i = (c_i - c)

where, following the paper's notation:
  • c_i: local state variable for node i
  • c: global state variable
  • delta_i: (c_i - c), correction state for node i
  • eta_g: server's learning rate
  • eta_i: node i's learning rate
  • N: total number of nodes participating in the federated learning experiment
  • S: number of nodes considered during the current round (S <= N)
  • K: number of updates done during the round (i.e. number of data batches)
  • x: global model parameters
  • y_i: node i's local model parameters at the end of the round

Parameters:

Name Type Description Default
model_params Dict

model parameters received from the nodes, as a dict mapping node ids to parameters

required
weights Dict[str, float]

weights depicting sample proportions available on each node. Unused for Scaffold.

required
global_model Dict[str, Union[torch.Tensor, np.ndarray]]

global model, i.e. the aggregated model

required
training_plan BaseTrainingPlan

instance of TrainingPlan

required
training_replies Responses

Training replies from each node that participates in the current round

required
n_updates int

number of updates (number of batches performed). Defaults to 1.

1
n_round int

current round. Defaults to 0.

0

Returns:

Type Description
Dict

Aggregated parameters, as a dict mapping weight names to values.

Raises:

Type Description
FedbiomedAggregatorError

If no FederatedDataset is attached to this Scaffold instance, or if node_ids do not belong to the dataset attached to it.

Source code in fedbiomed/researcher/aggregators/scaffold.py
def aggregate(self,
              model_params: Dict,
              weights: Dict[str, float],
              global_model: Dict[str, Union[torch.Tensor, np.ndarray]],
              training_plan: BaseTrainingPlan,
              training_replies: Responses,
              n_updates: int = 1,
              n_round: int = 0,
              *args, **kwargs) -> Dict:
    """
    Aggregates local models coming from nodes into a global model, using the SCAFFOLD algorithm (2nd option) from
    Scaffold: Stochastic Controlled Averaging for Federated Learning (https://arxiv.org/abs/1910.06378)

    Performed computations:
    -----------------------

    - Compute participating nodes' model update:
        * update_i = y_i - x
    - Compute aggregated model parameters:
        * x(+) = x - (eta_g / |S|) * sum_S(update_i)
    - Update participating nodes' state:
        * c_i = delta_i + 1/(K*eta_i) * update_i
    - Update the global state and all nodes' correction state:
        * c = 1/N sum_{i=1}^n c_i
        * delta_i = (c_i - c)

    where, following the paper's notation:
        c_i: local state variable for node `i`
        c: global state variable
        delta_i: (c_i - c), correction state for node `i`
        eta_g: server's learning rate
        eta_i: node i's learning rate
        N: total number of nodes participating in the federated learning experiment
        S: number of nodes considered during the current round (S <= N)
        K: number of updates done during the round (i.e. number of data batches)
        x: global model parameters
        y_i: node i's local model parameters at the end of the round

    Args:
        model_params: model parameters received from the nodes, as a dict mapping node ids to parameters
        weights: weights depicting sample proportions available
            on each node. Unused for Scaffold.
        global_model: global model, i.e. the aggregated model
        training_plan (BaseTrainingPlan): instance of TrainingPlan
        training_replies: Training replies from each node that participates in the current round
        n_updates: number of updates (number of batches performed). Defaults to 1.
        n_round: current round. Defaults to 0.

    Returns:
        Aggregated parameters, as a dict mapping weight names to values.

    Raises:
        FedbiomedAggregatorError: If no FederatedDataset is attached to this
            Scaffold instance, or if `node_ids` do not belong to the dataset
            attached to it.
    """
    # Gather the learning rates used by nodes, updating `self.nodes_lr`.
    self.set_nodes_learning_rate_after_training(training_plan, training_replies, n_round)
    # At round 0, initialize zero-valued correction states.
    if n_round == 0:
        self.init_correction_states(global_model)
    # Check that the input node_ids match known ones.
    if not set(model_params).issubset(self._fds.node_ids()):
        raise FedbiomedAggregatorError(
            "Received updates from nodes that are unknown to this aggregator."
        )
    # Compute the node-wise model update: (x^t - y_i^t).
    model_updates = {
        node_id: {
            key: (global_model[key] - local_value)
            for key, local_value in params.items()
        }
        for node_id, params in model_params.items()
    }
    # Update all Scaffold state variables.
    self.update_correction_states(model_updates, n_updates)
    # Compute and return the aggregated model parameters.
    global_new = {}  # type: Dict[str, Union[torch.Tensor, np.ndarray]]
    for key, val in global_model.items():
        upd = sum(model_updates[node_id][key] for node_id in model_params)
        global_new[key] = val - upd * (self.server_lr / len(model_params))
    return global_new
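
As a quick numeric check of the final update (illustrative values: two participating nodes, server_lr = 1.0, one scalar parameter), take a global model x = 1.0 and local models y_1 = 0.6 and y_2 = 0.8, so the node-wise updates (x - y_i) are 0.4 and 0.2; then

  x(+) = x - (server_lr / |S|) * sum(update_i) = 1.0 - (1.0 / 2) * 0.6 = 0.7

With server_lr = 1.0 the new global model is thus the plain average of the local models.
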
check_values(n_updates, training_plan)

Checks that all values/parameters are correct and have been set before the aggregator is used.

Raises an error otherwise.

This can prove useful if the user has set wrong hyperparameter values, as errors are then raised before the first round of training is performed.

Parameters:

Name Type Description Default
n_updates int

number of updates. Must be non-zero and an integer.

required
training_plan BaseTrainingPlan

training plan, used for checking whether the optimizer is SGD; otherwise, a warning is triggered.

required

Raises:

Type Description
FedbiomedAggregatorError

triggered if num_updates entry is missing (needed for Scaffold aggregator)

FedbiomedAggregatorError

triggered if any of the learning rate(s) equals 0

FedbiomedAggregatorError

triggered if number of updates equals 0 or is not an integer

FedbiomedAggregatorError

triggered if [FederatedDataset][fedbiomed.researcher.datasets.FederatedDataset] has not been set.

Source code in fedbiomed/researcher/aggregators/scaffold.py
def check_values(self, n_updates: int, training_plan: BaseTrainingPlan) -> True:
    """Check if all values/parameters are correct and have been set before using aggregator.

    Raise an error otherwise.

    This can prove useful if user has set wrong hyperparameter values, so that user will
    have errors before performing first round of training

    Args:
        n_updates: number of updates. Must be non-zero and an integer.
        training_plan: training plan, used for checking whether the optimizer is SGD; otherwise,
            a warning is triggered.

    Raises:
        FedbiomedAggregatorError: triggered if `num_updates` entry is missing (needed for Scaffold aggregator)
        FedbiomedAggregatorError: triggered if any of the learning rate(s) equals 0
        FedbiomedAggregatorError: triggered if number of updates equals 0 or is not an integer
        FedbiomedAggregatorError: triggered if [FederatedDataset][fedbiomed.researcher.datasets.FederatedDataset]
            has not been set.
    """
    if n_updates is None:
        raise FedbiomedAggregatorError("Cannot perform Scaffold: missing 'num_updates' entry in the training_args")
    elif n_updates <= 0 or int(n_updates) != float(n_updates):
        raise FedbiomedAggregatorError(
            "n_updates should be a positive non zero integer, but got "
            f"n_updates: {n_updates} in SCAFFOLD aggregator"
        )
    if self._fds is None:
        raise FedbiomedAggregatorError(
            "Federated Dataset not provided, but needed for Scaffold. Please use setter `set_fds()`."
        )
    if hasattr(training_plan, "_optimizer") and training_plan.type() is TrainingPlans.TorchTrainingPlan:
        if not isinstance(training_plan._optimizer, torch.optim.SGD):
            logger.warning(
                f"Found optimizer {training_plan._optimizer}, but SCAFFOLD requieres SGD optimizer."
                "Results may be inconsistants"
            )
    return True
create_aggregator_args(global_model, node_ids)

Return correction states that are to be sent to the nodes.

Parameters:

Name Type Description Default
global_model Dict[str, Union[torch.Tensor, np.ndarray]]

parameters of the global model, formatted as a dict mapping parameter names to weight tensors.

required
node_ids Collection[str]

identifiers of the nodes that are to receive messages.

required

Returns:

Name Type Description
aggregator_msg Dict[str, Dict[str, Any]]

Dict associating MQTT-transmitted messages to node identifiers.

aggregator_dat Dict[str, Dict[str, Any]]

Dict associating file-exchange-transmitted messages to node identifiers. The Scaffold correction states are part of this dict.

Source code in fedbiomed/researcher/aggregators/scaffold.py
def create_aggregator_args(
    self,
    global_model: Dict[str, Union[torch.Tensor, np.ndarray]],
    node_ids: Collection[str]
) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
    """Return correction states that are to be sent to the nodes.

    Args:
        global_model: parameters of the global model, formatted as a dict
            mapping parameter names to weight tensors.
        node_ids: identifiers of the nodes that are to receive messages.

    Returns:
        aggregator_msg: Dict associating MQTT-transmitted messages to node
            identifiers.
        aggregator_dat: Dict associating file-exchange-transmitted messages
            to node identifiers. The Scaffold correction states are part of
            this dict.
    """
    # Optionally initialize states, and verify that nodes are known.
    if not self.nodes_deltas:
        self.init_correction_states(global_model)
    if not set(node_ids).issubset(self._fds.node_ids()):
        raise FedbiomedAggregatorError(
            "Scaffold cannot create aggregator args for nodes that are not"
            "covered by its attached FederatedDataset."
        )
    # Pack node-wise messages, for the MQTT and file exchange channels.
    aggregator_msg = {}
    aggregator_dat = {}
    for node_id in node_ids:
        # If a node was late-added to the FederatedDataset, create states.
        if node_id not in self.nodes_deltas:
            zeros = {key: initialize(val)[1] for key, val in self.global_state.items()}
            self.nodes_deltas[node_id] = zeros
            self.nodes_states[node_id] = copy.deepcopy(zeros)
        # Add information for the current node to the message dicts.
        aggregator_dat[node_id] = {
            'aggregator_name': self.aggregator_name,
            'aggregator_correction': self.nodes_deltas[node_id]
        }
        aggregator_msg[node_id] = {
            'aggregator_name': self.aggregator_name
        }
    return aggregator_msg, aggregator_dat
init_correction_states(global_model)

Initialize Scaffold state variables.

Parameters:

Name Type Description Default
global_model Dict[str, Union[torch.Tensor, np.ndarray]]

parameters of the global model, formatted as a dict mapping parameter names to weight tensors.

required

Raises:

Type Description
FedbiomedAggregatorError

if no FederatedDataset is attached to this aggregator.

Source code in fedbiomed/researcher/aggregators/scaffold.py
def init_correction_states(
    self,
    global_model: Dict[str, Union[torch.Tensor, np.ndarray]],
) -> None:
    """Initialize Scaffold state variables.

    Args:
        global_model: parameters of the global model, formatted as a dict
            mapping parameter names to weight tensors.

    Raises:
        FedbiomedAggregatorError: if no FederatedDataset is attached to
            this aggregator.
    """
    # Gather node ids from the attached FederatedDataset.
    if self._fds is None:
        raise FedbiomedAggregatorError(
            "Cannot initialize correction states: Scaffold aggregator does "
            "not have a FederatedDataset attached."
        )
    node_ids = self._fds.node_ids()
    # Initialize node-wise states with zero-valued parameters matching the global model.
    init_params = {key: initialize(tensor)[1] for key, tensor in global_model.items()}
    self.nodes_deltas = {node_id: copy.deepcopy(init_params) for node_id in node_ids}
    self.nodes_states = copy.deepcopy(self.nodes_deltas)
    self.global_state = init_params
load_state(state=None)
Source code in fedbiomed/researcher/aggregators/scaffold.py
def load_state(self, state: Dict[str, Any] = None):
    super().load_state(state)

    self.server_lr = self._aggregator_args['server_lr']

    # loading global state
    global_state_filename = self._aggregator_args['global_state_filename']
    self.global_state = Serializer.load(global_state_filename)

    for node_id in self._aggregator_args['aggregator_correction']:
        arg_filename = self._aggregator_args['aggregator_correction'][node_id]
        self.nodes_deltas[node_id] = Serializer.load(arg_filename)
save_state(breakpoint_path, global_model)
Source code in fedbiomed/researcher/aggregators/scaffold.py
def save_state(
    self,
    breakpoint_path: str,
    global_model: Mapping[str, Union[torch.Tensor, np.ndarray]]
) -> Dict[str, Any]:
    # adding aggregator parameters to the breakpoint that won't be sent to nodes
    self._aggregator_args['server_lr'] = self.server_lr

    # saving global state variable into a file
    filename = os.path.join(breakpoint_path, f"global_state_{uuid.uuid4()}.mpk")
    Serializer.dump(self.global_state, filename)
    self._aggregator_args['global_state_filename'] = filename
    # adding aggregator parameters that will be sent to nodes afterwards
    return super().save_state(
        breakpoint_path, global_model=global_model, node_ids=self._fds.node_ids()
    )
set_nodes_learning_rate_after_training(training_plan, training_replies, n_round)

Retrieves the optimizer learning rate from each node (needed when a learning rate scheduler is used)

Parameters:

Name Type Description Default
training_plan BaseTrainingPlan

training plan instance

required
training_replies List[Responses]

training replies that must contain an optimizer_args entry and a learning rate

required
n_round int

number of rounds already performed

required

Raises:

Type Description
FedbiomedAggregatorError

raised when setting learning rate has been unsuccessful

Returns:

Type Description
Dict[str, List[float]]

Dict[str, List[float]]: dictionary mapping each node_id to a list of floats, as many as the number of layers contained in the model (in PyTorch, each layer can have a specific learning rate).

Source code in fedbiomed/researcher/aggregators/scaffold.py
def set_nodes_learning_rate_after_training(
    self,
    training_plan: BaseTrainingPlan,
    training_replies: Responses,
    n_round: int
) -> Dict[str, List[float]]:
    """Gets back learning rate of optimizer from Node (if learning rate scheduler is used)

    Args:
        training_plan (BaseTrainingPlan): training plan instance
        training_replies (List[Responses]): training replies that must contain an `optimizer_args`
            entry and a learning rate
        n_round (int): number of rounds already performed

    Raises:
        FedbiomedAggregatorError: raised when setting learning rate has been unsuccessful

    Returns:
        Dict[str, List[float]]: dictionary mapping each node_id to a list of floats, as many as
            the number of layers contained in the model (in PyTorch, each layer can have a specific learning rate).
    """
    # to be implemented in a utils module (for pytorch optimizers)

    n_model_layers = len(training_plan.get_model_params())
    for node_id in self._fds.node_ids():
        lrs: List[float] = []

        if training_replies[n_round].get_index_from_node_id(node_id) is not None:
            # get updated learning rate if provided...
            node_idx: int = training_replies[n_round].get_index_from_node_id(node_id)
            lrs += training_replies[n_round][node_idx]['optimizer_args'].get('lr')

        else:
            # ...otherwise retrieve default learning rate
            lrs += training_plan.get_learning_rate()

        if len(lrs) == 1:
            # case where there is one learning rate
            lr = lrs * n_model_layers

        elif len(lrs) == n_model_layers:
            # case where there are several learning rate values
            lr = lrs

        else:
            raise FedbiomedAggregatorError(
                "Error when setting node learning rate for SCAFFOLD: cannot extract node learning rate."
            )

        self.nodes_lr[node_id] = lr
    return self.nodes_lr
set_training_plan_type(training_plan_type)

Overrides set_training_plan_type from the parent class. Checks the training plan type: if it is SkLearnTrainingPlan, raises an error; otherwise, calls the parent method.

Parameters:

Name Type Description Default
training_plan_type TrainingPlans

training_plan type

required

Raises:

Type Description
FedbiomedAggregatorError

raised if training_plan type has been set to SKLearn training plan

Returns:

Name Type Description
TrainingPlans TrainingPlans

training plan type

Source code in fedbiomed/researcher/aggregators/scaffold.py
def set_training_plan_type(self, training_plan_type: TrainingPlans) -> TrainingPlans:
    """
    Overrides `set_training_plan_type` from parent class.
    Checks the training plan type, and if it is SkLearnTrainingPlan,
    raises an error. Otherwise, calls parent method.

    Args:
        training_plan_type (TrainingPlans): training_plan type

    Raises:
        FedbiomedAggregatorError: raised if training_plan type has been set to SKLearn training plan

    Returns:
        TrainingPlans: training plan type
    """
    if training_plan_type == TrainingPlans.SkLearnTrainingPlan:
        raise FedbiomedAggregatorError("Aggregator SCAFFOLD not implemented for SKlearn")
    training_plan_type = super().set_training_plan_type(training_plan_type)

    # TODO: trigger a warning if user is trying to use scaffold with something else than SGD
    return training_plan_type
update_correction_states(model_updates, n_updates)

Update all Scaffold state variables based on node-wise model updates.

Performed computations:
  • Update participating nodes' state:
    • c_i = delta_i + 1/(K*eta_i) * update_i
  • Update the global state and all nodes' correction state:
    • c = 1/N sum_{i=1}^n c_i
    • delta_i = (c_i - c)

Parameters:

Name Type Description Default
model_updates Dict[str, Dict[str, Union[np.ndarray, torch.Tensor]]]

node-wise model weight updates.

required
n_updates int

number of local optimization steps.

required
Source code in fedbiomed/researcher/aggregators/scaffold.py
def update_correction_states(
    self,
    model_updates: Dict[str, Dict[str, Union[np.ndarray, torch.Tensor]]],
    n_updates: int,
) -> None:
    """Update all Scaffold state variables based on node-wise model updates.

    Performed computations:
    ----------------------

    - Update participating nodes' state:
        * c_i = delta_i + 1/(K*eta_i) * update_i
    - Update the global state and all nodes' correction state:
        * c = 1/N sum_{i=1}^n c_i
        * delta_i = (c_i - c)

    Args:
        model_updates: node-wise model weight updates.
        n_updates: number of local optimization steps.
    """
    # Update the node-wise states for participating nodes:
    # c_i^{t+1} = delta_i^t + (x^t - y_i^t) / (n_updates * eta_i)
    for node_id, updates in model_updates.items():
        d_i = self.nodes_deltas[node_id]
        self.nodes_states[node_id] = {
            key: d_i[key] + val / (self.nodes_lr[node_id][idx] * n_updates)
            for idx, (key, val) in enumerate(updates.items())
        }
    # Update the global state: c^{t+1} = average(c_i^{t+1})
    self.global_state = {
        key: (
            sum(state[key] for state in self.nodes_states.values())
            / len(self.nodes_states)
        )
        for key in self.global_state
    }
    # Compute the new node-wise correction states:
    # delta_i^{t+1} = c_i^{t+1} - c^{t+1}
    self.nodes_deltas = {
        node_id: {
            key: val - self.global_state[key] for key, val in state.items()
        }
        for node_id, state in self.nodes_states.items()
    }
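
A two-node toy run of these updates (illustrative numbers: n_updates = 1, both learning rates equal to 0.1, previous deltas zero) shows how the corrections become non-zero:

# node-wise model updates (x - y_i): node 1 -> 0.4, node 2 -> 0.2
c_1 = 0.0 + 0.4 / (0.1 * 1)   # 4.0
c_2 = 0.0 + 0.2 / (0.1 * 1)   # 2.0
c = (c_1 + c_2) / 2           # 3.0
delta_1 = c_1 - c             # +1.0, node 1 drifted more than average
delta_2 = c_2 - c             # -1.0
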

Functions

federated_averaging(model_params, weights)

Defines Federated Averaging (FedAvg) strategy for model aggregation.

Parameters:

Name Type Description Default
model_params List[Dict[str, Union[torch.Tensor, np.ndarray]]]

list that contains nodes' model parameters; each model is stored as an OrderedDict (maps model layer name to the model weights)

required
weights List[float]

weights for performing the weighted sum in the FedAvg strategy (depending on the dataset size of each node). As shown in the source below, the weights are normalized by their sum.

required

Returns:

Type Description
Mapping[str, Union[torch.Tensor, np.ndarray]]

Final model with aggregated layers, as an OrderedDict object.

Source code in fedbiomed/researcher/aggregators/functional.py
def federated_averaging(model_params: List[Dict[str, Union[torch.Tensor, np.ndarray]]],
                        weights: List[float]) -> Mapping[str, Union[torch.Tensor, np.ndarray]]:
    """Defines Federated Averaging (FedAvg) strategy for model aggregation.

    Args:
        model_params: list that contains nodes' model parameters; each model is stored as an OrderedDict (maps
            model layer name to the model weights)
        weights: weights for performing the weighted sum in the FedAvg strategy (depending on the dataset size of
            each node). The weights are normalized by their sum before averaging.

    Returns:
        Final model with aggregated layers, as an OrderedDict object.
    """
    assert len(model_params) > 0, 'An empty list of models was passed.'
    assert len(weights) == len(model_params), 'List with number of observations must have ' \
                                              'the same number of elements as the list of models.'

    # Compute proportions
    proportions = [n_k / sum(weights) for n_k in weights]
    return weighted_sum(model_params, proportions)
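
A short usage sketch (illustrative sample sizes; federated_averaging is assumed importable from fedbiomed.researcher.aggregators.functional). Since the weights are normalized by their sum, raw sample counts work as well as proportions:

import numpy as np

from fedbiomed.researcher.aggregators.functional import federated_averaging

params = [
    {"w": np.array([0.0, 2.0])},
    {"w": np.array([2.0, 4.0])},
]
agg = federated_averaging(params, weights=[50, 150])  # proportions: 0.25 and 0.75
# agg["w"] -> array([1.5, 3.5])
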

initialize(val)

Initializes a zero-valued tensor or array with the same shape as the input, returned along with a type tag.

Source code in fedbiomed/researcher/aggregators/functional.py
def initialize(val: Union[torch.Tensor, np.ndarray]) -> Tuple[str, Union[torch.Tensor, np.ndarray]]:
    """Initialize tensor or array vector. """
    if isinstance(val, torch.Tensor):
        return 'tensor', torch.zeros_like(val).float()

    if isinstance(val, (list, np.ndarray)):
        val = np.array(val)
        return 'array', np.zeros(val.shape, dtype = float)

weighted_sum(model_params, proportions)

Performs a weighted sum operation

Parameters:

Name Type Description Default
model_params List[Dict[str, Union[torch.Tensor, np.ndarray]]]

list that contains nodes' model parameters; each model is stored as an OrderedDict (maps model layer name to the model weights)

required
proportions List[float]

weights of all items within the model_params list

required

Returns:

Type Description
Mapping[str, Union[torch.Tensor, np.ndarray]]

Mapping[str, Union[torch.Tensor, np.ndarray]]: model resulting from the weighted sum operation

Source code in fedbiomed/researcher/aggregators/functional.py
def weighted_sum(model_params: List[Dict[str, Union[torch.Tensor, np.ndarray]]],
                 proportions: List[float]) -> Mapping[str, Union[torch.Tensor, np.ndarray]]:
    """Performs weighted sum operation

    Args:
        model_params (List[Dict[str, Union[torch.Tensor, np.ndarray]]]): list that contains nodes'
            model parameters; each model is stored as an OrderedDict (maps model layer name to the model weights)
        proportions (List[float]): weights of all items within the model_params list

    Returns:
        Mapping[str, Union[torch.Tensor, np.ndarray]]: model resulting from the weighted sum operation
    """
    # Empty model parameter dictionary
    avg_params = copy.deepcopy(model_params[0])

    for key, val in avg_params.items():
        t, avg_params[key] = initialize(val)

    if t == 'tensor':
        for model, weight in zip(model_params, proportions):
            for key in avg_params.keys():
                avg_params[key] += weight * model[key]

    if t == 'array':
        for key in avg_params.keys():
            matr = np.array([ d[key] for d in model_params ])
            avg_params[key] = np.average(matr, weights=np.array(proportions), axis=0)

    return avg_params
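
And a direct call with pre-computed proportions (illustrative values; weighted_sum is assumed importable from fedbiomed.researcher.aggregators.functional):

import numpy as np

from fedbiomed.researcher.aggregators.functional import weighted_sum

params = [{"b": np.array([1.0])}, {"b": np.array([3.0])}]
weighted_sum(params, [0.5, 0.5])  # -> {"b": array([2.])}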