fedbiomed.researcher.aggregators
Module: fedbiomed.researcher.aggregators
Classes
Aggregator
Aggregator()
Defines methods for model aggregation strategies (e.g. FedAvg, FedProx, SCAFFOLD, ...).
Source code in fedbiomed/researcher/aggregators/aggregator.py
def __init__(self):
    self._aggregator_args: dict = None
    self._fds: FederatedDataSet = None
    self._training_plan_type: TrainingPlans = None
Functions
aggregate(model_params, weights, *args, **kwargs)
aggregate(model_params, weights, *args, **kwargs)
Strategy to aggregate models
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | list | List of model parameters received from each node | required |
weights | list | Weight for each node-model-parameter set | required |
Raises:
Type | Description |
---|---|
FedbiomedAggregatorError | If the method is not defined by the inheritor |
Source code in fedbiomed/researcher/aggregators/aggregator.py
def aggregate(self, model_params: list, weights: list, *args, **kwargs) -> Dict:
    """
    Strategy to aggregate models

    Args:
        model_params: List of model parameters received from each node
        weights: Weight for each node-model-parameter set

    Raises:
        FedbiomedAggregatorError: If the method is not defined by the inheritor
    """
    msg = ErrorNumbers.FB401.value + \
        ": aggregate method should be overloaded by the chosen strategy"
    logger.critical(msg)
    raise FedbiomedAggregatorError(msg)
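As an illustration, here is a minimal sketch of a custom strategy overriding this method. The plain `{parameter name: value}` format assumed for each item of `model_params` is an assumption for the sake of the example, not part of the documented contract:

```python
from typing import Dict, List

from fedbiomed.researcher.aggregators import Aggregator


class UnweightedMean(Aggregator):
    """Toy strategy: average node parameters with equal weights."""

    def aggregate(self, model_params: List[Dict], weights: List[float], *args, **kwargs) -> Dict:
        # Assumes each item of `model_params` is a {parameter name: value} mapping.
        n = len(model_params)
        return {name: sum(params[name] for params in model_params) / n
                for name in model_params[0]}
```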
check_values(*args, **kwargs)
check_values(*args, **kwargs)
Source code in fedbiomed/researcher/aggregators/aggregator.py
def check_values(self, *args, **kwargs) -> bool:
    return True
create_aggregator_args(*args, **kwargs)
create_aggregator_args(*args, **kwargs)
Returns the aggregator arguments expected by the nodes:
- a dictionary of `Aggregator` parameters that will be sent through the MQTT message service;
- a dictionary of parameters that will be sent through the file exchange message.
Both dictionaries map node_id to the `Aggregator` parameters specific to each node.
Source code in fedbiomed/researcher/aggregators/aggregator.py
def create_aggregator_args(self, *args, **kwargs) -> Tuple[dict, dict]:
    """Returns the aggregator arguments expected by the nodes

    Returns:
        dict: contains `Aggregator` parameters that will be sent through the MQTT
            message service
        dict: contains parameters that will be sent through the file exchange message.

    Both dictionaries map node_id to the `Aggregator` parameters specific
    to each node.
    """
    return self._aggregator_args or {}, {}
load_state(state=None, **kwargs)
load_state(state=None, **kwargs)
Used for breakpoints; loads the aggregator state.
Source code in fedbiomed/researcher/aggregators/aggregator.py
def load_state(self, state: Dict[str, Any] = None, **kwargs):
    """
    Used for breakpoints; loads the aggregator state.
    """
    self._aggregator_args = state['parameters']
normalize_weights(weights)
staticmethod
normalize_weights(weights)
Takes the list of weights assigned to each node and normalizes them so they sum to 1, assuming all values are >= 0.0.
Source code in fedbiomed/researcher/aggregators/aggregator.py
@staticmethod
def normalize_weights(weights: list) -> list:
    """
    Load list of weights assigned to each node and
    normalize these weights so they sum up to 1
    assuming that all values are >= 0.0
    """
    _l = len(weights)
    if _l == 0:
        return []
    _s = sum(weights)
    if _s == 0:
        norm = [1.0 / _l] * _l
    else:
        norm = [_w / _s for _w in weights]
    return norm
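For example, weights proportional to dataset sizes are rescaled, and an all-zero list falls back to uniform weights:

```python
>>> Aggregator.normalize_weights([100, 300, 600])
[0.1, 0.3, 0.6]
>>> Aggregator.normalize_weights([0, 0])   # degenerate case: uniform fallback
[0.5, 0.5]
```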
save_state(training_plan=None, breakpoint_path=None, **aggregator_args_create)
save_state(training_plan=None, breakpoint_path=None, **aggregator_args_create)
Used for breakpoints; saves the aggregator state.
Source code in fedbiomed/researcher/aggregators/aggregator.py
def save_state(self,
               training_plan: Optional[BaseTrainingPlan] = None,
               breakpoint_path: Optional[str] = None,
               **aggregator_args_create) -> Dict[str, Any]:
    """
    Used for breakpoints; saves the aggregator state.
    """
    aggregator_args_thr_msg, aggregator_args_thr_files = self.create_aggregator_args(**aggregator_args_create)
    if aggregator_args_thr_msg:
        if self._aggregator_args is None:
            self._aggregator_args = {}
        self._aggregator_args.update(aggregator_args_thr_msg)
    # aggregator_args = copy.deepcopy(self._aggregator_args)
    if breakpoint_path is not None and aggregator_args_thr_files:
        for node_id, node_arg in aggregator_args_thr_files.items():
            if isinstance(node_arg, dict):
                for arg_name, aggregator_arg in node_arg.items():
                    if arg_name != 'aggregator_name':  # do not save `aggregator_name` as a file
                        filename = self._save_arg_to_file(training_plan, breakpoint_path,
                                                          arg_name, node_id, aggregator_arg)
                        self._aggregator_args.setdefault(arg_name, {})
                        self._aggregator_args[arg_name][node_id] = filename  # replacing value by a path towards a file
            else:
                filename = self._save_arg_to_file(training_plan, breakpoint_path, arg_name, node_id, node_arg)
                self._aggregator_args[arg_name] = filename
    state = {
        "class": type(self).__name__,
        "module": self.__module__,
        "parameters": self._aggregator_args
    }
    return state
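The returned breakpoint entry therefore has this shape (illustrative values for a `FedAverage` instance; the module path is taken from the source locations listed on this page):

```python
{
    "class": "FedAverage",
    "module": "fedbiomed.researcher.aggregators.fedavg",
    "parameters": {},   # or the node-wise aggregator arguments, if any
}
```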
set_fds(fds)
set_fds(fds)
Source code in fedbiomed/researcher/aggregators/aggregator.py
def set_fds(self, fds: FederatedDataSet) -> FederatedDataSet:
    self._fds = fds
    return self._fds
set_training_plan_type(training_plan_type)
set_training_plan_type(training_plan_type)
Source code in fedbiomed/researcher/aggregators/aggregator.py
def set_training_plan_type(self, training_plan_type: TrainingPlans) -> TrainingPlans:
    self._training_plan_type = training_plan_type
    return self._training_plan_type
FedAverage
FedAverage()
Bases: Aggregator
Defines the Federated averaging strategy
Source code in fedbiomed/researcher/aggregators/fedavg.py
def __init__(self):
    """Construct `FedAverage` object as an instance of [`Aggregator`]
    [fedbiomed.researcher.aggregators.Aggregator].
    """
    super(FedAverage, self).__init__()
    self.aggregator_name = "FedAverage"
Attributes
aggregator_name instance-attribute
aggregator_name = 'FedAverage'
Functions
aggregate(model_params, weights, *args, **kwargs)
aggregate(model_params, weights, *args, **kwargs)
Aggregates local models sent by participating nodes into a global model, following the Federated Averaging strategy.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | list | list of model parameters received from each node (each item maps node_id to that node's parameters) | required |
weights | list | aggregation weight for each node's parameter set | required |
Returns:
Type | Description |
---|---|
Dict | Aggregated parameters |
Source code in fedbiomed/researcher/aggregators/fedavg.py
def aggregate(self, model_params: list, weights: list, *args, **kwargs) -> Dict:
    """Aggregates local models sent by participating nodes into a global model, following the Federated
    Averaging strategy.

    Args:
        model_params: list of model parameters received from each node (each item maps node_id to
            that node's parameters)
        weights: aggregation weight for each node's parameter set

    Returns:
        Aggregated parameters
    """
    # model params are contained in a dictionary with node_id as key, we just retrieve the params
    model_params_processed = [list(model_param.values())[0] for model_param in model_params]
    weights_processed = [weight if isinstance(weight, float) else list(weight.values())[0] for weight in weights]
    weights_processed = self.normalize_weights(weights_processed)
    return federated_averaging(model_params_processed, weights_processed)
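A minimal usage sketch, assuming the `{node_id: parameters}` and `{node_id: weight}` input formats visible in the code above (node ids and values are made up):

```python
import numpy as np

from fedbiomed.researcher.aggregators import FedAverage

model_params = [
    {'node_1': {'w': np.array([0.0, 2.0])}},
    {'node_2': {'w': np.array([4.0, 6.0])}},
]
weights = [{'node_1': 30}, {'node_2': 10}]  # e.g. per-node sample counts

agg = FedAverage()
print(agg.aggregate(model_params, weights))
# {'w': array([1., 3.])}  i.e. 0.75 * node_1 + 0.25 * node_2
```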
Scaffold
Scaffold(server_lr=1.0, fds=None)
Bases: Aggregator
Defines the Scaffold strategy
Despite being an algorithm of choice for federated learning, it is observed that FedAvg suffers from client-drift
when the data is heterogeneous (non-iid), resulting in unstable and slow convergence. SCAFFOLD uses control variates (variance reduction) to correct for the client-drift
in its local updates. Intuitively, SCAFFOLD estimates the update direction for the server model (c) and the update direction for each client (c_i). The difference (c - c_i) is then an estimate of the client-drift which is used to correct the local update.
Fed-BioMed implementation details
Our implementation is heavily influenced by our design choice to prevent storing any state on the nodes between FL rounds. In particular, this means that the computation of the control variates (i.e. the correction states) needs to be performed centrally by the aggregator. Roughly, our implementation follows these steps (using the notation of the original Scaffold paper; a minimal sketch of the client-side update follows the list):
- let \(\delta_i = \mathbf{c} - \mathbf{c}_i \)
- foreach(round):
- sample \( S \) nodes participating in this round out of \( N \) total
- the server communicates the global model \( \mathbf{x} \) and the correction states \( \delta_i \) to all clients
- parallel on each client
- initialize local model \( \mathbf{y}_i = \mathbf{x} \)
- foreach(update) until K updates have been performed
- obtain a data batch
- compute the gradients for this batch \( g(\mathbf{y}_i) \)
- add correction term to gradients \( g(\mathbf{y}_i) += \delta_i \)
- update model with one optimizer step \( \mathbf{y}_i -= \eta_i g(\mathbf{y}_i) \)
- end foreach(update)
- communicate updated model \( \mathbf{y}_i \) and learning rate \( \eta_i \)
- end parallel section on each client
- the server computes the node-wise average of corrected gradients \( \mathbf{ACG}_i = (\mathbf{x} - \mathbf{y}_i)/(\eta_iK) \)
- the server updates the global correction term \( \mathbf{c} = (1 - S/N)\mathbf{c} + 1/N\sum_{i \in S}\mathbf{ACG}_i \)
- the server updates the correction states of each client \(\delta_i = \mathbf{ACG}_i - \mathbf{c} - \delta_i \)
- the server updates the global model by average \( \mathbf{x} = (1-\eta)\mathbf{x} + \eta/S\sum_{i \in S} \mathbf{y}_i \)
- end foreach(round)
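As announced above, here is a minimal single-client sketch of the corrected update loop, in plain NumPy with a toy quadratic loss. All names and values are illustrative, not Fed-BioMed API:

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.zeros(2)                    # global model sent by the server
delta_i = np.array([0.05, -0.05])  # correction state delta_i = c - c_i for this client
eta_i, K = 0.1, 5                  # client learning rate and number of local updates

y_i = x.copy()                     # initialize local model y_i = x
for _ in range(K):
    batch = rng.normal(loc=1.0, size=(8, 2))  # obtain a data batch
    g = y_i - batch.mean(axis=0)              # gradient of the toy loss 0.5 * ||y_i - mean||^2
    g += delta_i                              # add the correction term
    y_i -= eta_i * g                          # one optimizer step

acg_i = (x - y_i) / (eta_i * K)    # server side: average of corrected gradients ACG_i
```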
References:
- Scaffold: Stochastic Controlled Averaging for Federated Learning
- TCT: Convexifying Federated Learning using Bootstrapped Neural Tangent Kernels
Attributes:
Name | Type | Description |
---|---|---|
aggregator_name | str | name of the aggregator |
server_lr | float | value of the server learning rate |
nodes_correction_states | Dict[str, Mapping[str, Union[torch.Tensor, np.ndarray]]] | a nested dictionary of correction parameters obtained for each client, in the format {node id: node-wise corrections}. The node-wise corrections are a dictionary in the format {parameter name: correction value} where the model parameters are those contained in each node's model.named_parameters(). |
Parameters:
Name | Type | Description | Default |
---|---|---|---|
server_lr | float | server's (or Researcher's) learning rate. Defaults to 1.0. | 1.0 |
fds | FederatedDataset | FederatedDataset obtained after a `search` request. Defaults to None. | None |
Source code in fedbiomed/researcher/aggregators/scaffold.py
def __init__(self, server_lr: float = 1., fds: Optional[FederatedDataSet] = None):
    """Constructs `Scaffold` object as an instance of [`Aggregator`]
    [fedbiomed.researcher.aggregators.Aggregator].

    Args:
        server_lr (float): server's (or Researcher's) learning rate. Defaults to 1.0.
        fds (FederatedDataset, optional): FederatedDataset obtained after a `search` request. Defaults to None.
    """
    super().__init__()
    self.aggregator_name: str = "Scaffold"
    if server_lr == 0.:
        raise FedbiomedAggregatorError("SCAFFOLD Error: Server learning rate cannot be equal to 0")
    self.server_lr: float = server_lr
    self.nodes_correction_states: Dict[str, Mapping[str, Union[torch.Tensor, np.ndarray]]] = {}
    self.global_state: Mapping[str, Union[torch.Tensor, np.ndarray]] = {}
    self.nodes_lr: Dict[str, List[float]] = {}
    if fds is not None:
        self.set_fds(fds)
    self._aggregator_args = {}  # we need `_aggregator_args` to be not None
Attributes
aggregator_name instance-attribute
aggregator_name: str = 'Scaffold'
global_state instance-attribute
global_state: Mapping[str, Union[torch.Tensor, np.ndarray]] = {}
nodes_correction_states instance-attribute
nodes_correction_states: Dict[str, Mapping[str, Union[torch.Tensor, np.ndarray]]] = {}
nodes_lr instance-attribute
nodes_lr: Dict[str, List[float]] = {}
server_lr instance-attribute
server_lr: float = server_lr
Functions
aggregate(model_params, weights, global_model, training_plan, training_replies, node_ids, n_updates=1, n_round=0, *args, **kwargs)
aggregate(model_params, weights, global_model, training_plan, training_replies, node_ids, n_updates=1, n_round=0, *args, **kwargs)
Aggregates local models coming from nodes into a global model, using the SCAFFOLD algorithm (2nd option) from [Scaffold: Stochastic Controlled Averaging for Federated Learning](https://arxiv.org/abs/1910.06378).
Performed computations:
c_i(+) <- c_i - c + 1/(K*eta_l)(x - y_i)
c <- c + 1/N * sum_S(c_i(+) - c_i)
x <- x + eta_g/S * sum_S(y_i - x)
where, according to the paper's notation:
- c_i: correction state for node i
- c: correction state at the beginning of the round
- eta_g: server's learning rate
- eta_l: nodes' learning rate (may differ from one node to another)
- N: total number of nodes participating in federated learning
- S: number of nodes considered during the current round (S <= N)
- K: number of updates done during the round (i.e. number of data batches)
- x: global model parameters
- y_i: node i's local model parameters
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | list | list of model parameters received from nodes | required |
weights | List[Dict[str, float]] | weights depicting sample proportions available on each node. Unused for Scaffold. | required |
global_model | Mapping[str, Union[torch.Tensor, np.ndarray]] | global model, i.e. aggregated model | required |
training_plan | BaseTrainingPlan | instance of TrainingPlan | required |
training_replies | Responses | replies from the current training round, used to retrieve the learning rates used by each node | required |
node_ids | Iterable[str] | iterable containing node_id (string) participating in the current round. | required |
n_updates | int | number of updates (number of batches performed). Defaults to 1. | 1 |
n_round | int | current round. Defaults to 0. | 0 |
Returns:
Name | Type | Description |
---|---|---|
Dict | Dict | aggregated parameters, i.e. mapping of layer names and layer values. |
Source code in fedbiomed/researcher/aggregators/scaffold.py
def aggregate(self,
              model_params: list,
              weights: List[Dict[str, float]],
              global_model: Mapping[str, Union[torch.Tensor, np.ndarray]],
              training_plan: BaseTrainingPlan,
              training_replies: Responses,
              node_ids: Iterable[str],
              n_updates: int = 1,
              n_round: int = 0,
              *args, **kwargs) -> Dict:
    """
    Aggregates local models coming from nodes into a global model, using the SCAFFOLD algorithm (2nd option)
    [Scaffold: Stochastic Controlled Averaging for Federated Learning](https://arxiv.org/abs/1910.06378)

    Performed computations:
    -----------------------
    c_i(+) <- c_i - c + 1/(K*eta_l)(x - y_i)
    c <- c + 1/N * sum_S(c_i(+) - c_i)
    x <- x + eta_g/S * sum_S(y_i - x)

    where, according to the paper's notation:
        c_i: correction state for node `i`;
        c: correction state at the beginning of the round
        eta_g: server's learning rate
        eta_l: nodes' learning rate (may differ from one node to another)
        N: total number of nodes participating in federated learning
        S: number of nodes considered during the current round (S <= N)
        K: number of updates done during the round (i.e. number of data batches).
        x: global model parameters
        y_i: node i's local model parameters

    Args:
        model_params (list): list of model parameters received from nodes
        weights (List[Dict[str, float]]): weights depicting sample proportions available
            on each node. Unused for Scaffold.
        global_model (Mapping[str, Union[torch.Tensor, np.ndarray]]): global model,
            i.e. aggregated model
        training_plan (BaseTrainingPlan): instance of TrainingPlan
        training_replies (Responses): replies from the current training round, used to
            retrieve the learning rates used by each node
        node_ids (Iterable[str]): iterable containing node_id (string) participating in the current round.
        n_updates (int, optional): number of updates (number of batches performed). Defaults to 1.
        n_round (int, optional): current round. Defaults to 0.

    Returns:
        Dict: aggregated parameters, i.e. mapping of layer names and layer values.
    """
    # Gather the learning rates used by nodes, updating `self.nodes_lr`.
    self.set_nodes_learning_rate_after_training(training_plan, training_replies, n_round)
    # Unpack input local model parameters to {node_id: {name: value, ...}, ...} format.
    model_params = {list(node_content.keys())[0]: list(node_content.values())[0] for node_content in model_params}
    # Compute the new aggregated model parameters.
    aggregated_parameters = self.scaling(model_params, global_model)
    # At round 0, initialize zero-valued correction states.
    if n_round == 0:
        self.init_correction_states(global_model, node_ids)
    # Update correction states.
    self.update_correction_states(model_params, global_model, n_updates)
    # Return aggregated parameters.
    return aggregated_parameters
check_values(n_updates, training_plan)
check_values(n_updates, training_plan)
This method checks that all values/parameters are correct and have been set before the aggregator is used, and raises an error otherwise. This can prove useful when a user has set wrong hyperparameter values, so that errors are raised before the first round of training.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_updates | int | number of updates. Must be a non-zero integer. | required |
training_plan | BaseTrainingPlan | training plan. Used for checking if the optimizer is SGD; otherwise, triggers a warning. | required |
Raises:
Type | Description |
---|---|
FedbiomedAggregatorError | triggered if the `num_updates` entry is missing (needed for the Scaffold aggregator) |
FedbiomedAggregatorError | triggered if any of the learning rate(s) equals 0 |
FedbiomedAggregatorError | triggered if number of updates equals 0 or is not an integer |
FedbiomedAggregatorError | triggered if [FederatedDataset][fedbiomed.researcher.datasets.FederatedDataset] has not been set. |
Source code in fedbiomed/researcher/aggregators/scaffold.py
def check_values(self, n_updates: int, training_plan: BaseTrainingPlan) -> bool:
    """
    This method checks that all values/parameters are correct and have been set before the
    aggregator is used, and raises an error otherwise.
    This can prove useful when a user has set wrong hyperparameter values, so that errors are
    raised before the first round of training.

    Args:
        n_updates (int): number of updates. Must be a non-zero integer.
        training_plan (BaseTrainingPlan): training plan. Used for checking if the optimizer is SGD;
            otherwise, triggers a warning.

    Raises:
        FedbiomedAggregatorError: triggered if the `num_updates` entry is missing (needed for the Scaffold aggregator)
        FedbiomedAggregatorError: triggered if any of the learning rate(s) equals 0
        FedbiomedAggregatorError: triggered if the number of updates equals 0 or is not an integer
        FedbiomedAggregatorError: triggered if the [FederatedDataset][fedbiomed.researcher.datasets.FederatedDataset]
            has not been set.
    """
    if n_updates is None:
        raise FedbiomedAggregatorError("Cannot perform Scaffold: missing 'num_updates' entry in the training_args")
    elif n_updates <= 0 or int(n_updates) != float(n_updates):
        raise FedbiomedAggregatorError(
            f"n_updates should be a positive non zero integer, but got n_updates: {n_updates} in SCAFFOLD aggregator")
    if self._fds is None:
        raise FedbiomedAggregatorError(
            "Federated Dataset not provided, but needed for Scaffold. Please use setter `set_fds()`")
    if hasattr(training_plan, "_optimizer") and training_plan.type() is TrainingPlans.TorchTrainingPlan:
        if not isinstance(training_plan._optimizer, torch.optim.SGD):
            logger.warning(
                f"Found optimizer {training_plan._optimizer}, but SCAFFOLD requires the SGD optimizer. "
                "Results may be inconsistent")
    return True
create_aggregator_args(global_model, node_ids)
create_aggregator_args(global_model, node_ids)
Returns the additional arguments to be sent to the nodes. For Scaffold, these are mainly the correction states.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
global_model | Mapping[str, Union[torch.Tensor, np.ndarray]] | aggregated model | required |
node_ids | Iterator[str] | iterable that contains the ids of the nodes that have participated in the round | required |
Returns:
Type | Description |
---|---|
Tuple[Dict, Dict] | Tuple[Dict, Dict]: first dictionary contains parameters that will be sent through the MQTT message service, second dictionary parameters that will be sent through the file exchange message. Aggregator args are dictionaries mapping node_id to the SCAFFOLD parameters specific to each node. |
Source code in fedbiomed/researcher/aggregators/scaffold.py
def create_aggregator_args(self,
                           global_model: Mapping[str, Union[torch.Tensor, np.ndarray]],
                           node_ids: Iterator[str]) -> Tuple[Dict, Dict]:
    """Returns the additional arguments to be sent to the nodes. For Scaffold, these are mainly
    the correction states.

    Args:
        global_model (Mapping[str, Union[torch.Tensor, np.ndarray]]): aggregated model
        node_ids (Iterator[str]): iterable that contains the ids of the nodes that have participated
            in the round

    Returns:
        Tuple[Dict, Dict]: first dictionary contains parameters that will be sent through the MQTT message
            service, second dictionary parameters that will be sent through the file exchange message.
            Aggregator args are dictionaries mapping node_id to the SCAFFOLD parameters specific to
            each node.
    """
    if not self.nodes_correction_states:
        self.init_correction_states(global_model, node_ids)
    aggregator_args_thr_msg, aggregator_args_thr_file = {}, {}
    for node_id in node_ids:
        # in case of a new node, initialize its correction state
        if node_id not in self.nodes_correction_states:
            self.nodes_correction_states[node_id] = {
                key: copy.deepcopy(initialize(tensor))[1] for key, tensor in global_model.items()
            }
        # pack information and parameters to send
        aggregator_args_thr_file[node_id] = {
            'aggregator_name': self.aggregator_name,
            'aggregator_correction': self.nodes_correction_states[node_id]
        }
        aggregator_args_thr_msg[node_id] = {
            'aggregator_name': self.aggregator_name
        }
    return aggregator_args_thr_msg, aggregator_args_thr_file
init_correction_states(global_model, node_ids)
init_correction_states(global_model, node_ids)
Initialises correction_states variable for Scaffold
Parameters:
Name | Type | Description | Default |
---|---|---|---|
global_model | Mapping[str, Union[torch.Tensor, np.ndarray]] | global model mapping layer name to model parameters | required |
node_ids | Iterable[str] | iterable containing node_ids | required |
Source code in fedbiomed/researcher/aggregators/scaffold.py
def init_correction_states(self,
                           global_model: Mapping[str, Union[torch.Tensor, np.ndarray]],
                           node_ids: Iterable[str],
                           ):
    """Initialises the correction_states variable for Scaffold

    Args:
        global_model (Mapping[str, Union[torch.Tensor, np.ndarray]]): global model mapping layer name to model
            parameters
        node_ids (Iterable[str]): iterable containing node_ids
    """
    # initialize nodes states with zero-valued tensors
    init_params = {key: initialize(tensor)[1] for key, tensor in global_model.items()}
    self.nodes_correction_states = {node_id: copy.deepcopy(init_params) for node_id in node_ids}
    self.global_state = init_params
load_state(state=None, training_plan=None)
load_state(state=None, training_plan=None)
Source code in fedbiomed/researcher/aggregators/scaffold.py
def load_state(self, state: Dict[str, Any] = None, training_plan: BaseTrainingPlan = None):
    super().load_state(state)
    self.server_lr = self._aggregator_args['server_lr']
    # loading global state
    global_state_filename = self._aggregator_args['global_state_filename']
    self.global_state = training_plan.load(global_state_filename, to_params=True)
    for node_id in self._aggregator_args['aggregator_correction'].keys():
        arg_filename = self._aggregator_args['aggregator_correction'][node_id]
        self.nodes_correction_states[node_id] = training_plan.load(arg_filename)
save_state(training_plan, breakpoint_path, global_model)
save_state(training_plan, breakpoint_path, global_model)
Source code in fedbiomed/researcher/aggregators/scaffold.py
def save_state(self, training_plan: BaseTrainingPlan,
               breakpoint_path: str,
               global_model: Mapping[str, Union[torch.Tensor, np.ndarray]]) -> Dict[str, Any]:
    # adding aggregator parameters to the breakpoint that won't be sent to nodes
    self._aggregator_args['server_lr'] = self.server_lr
    # saving global state variable into a file
    filename = os.path.join(breakpoint_path, 'global_state_' + str(uuid.uuid4()) + '.pt')
    training_plan.save(filename, self.global_state)
    self._aggregator_args['global_state_filename'] = filename
    # adding aggregator parameters that will be sent to nodes afterwards
    return super().save_state(training_plan,
                              breakpoint_path,
                              global_model=global_model,
                              node_ids=self._fds.node_ids())
scaling(model_params, global_model)
scaling(model_params, global_model)
Computes the aggregated model.
Let
- x = the global model from the previous aggregation round
- y_i = the local model after training for the i^th node
- eta_g = the global learning rate
Then this function computes the quantity x * (1 - eta_g) + eta_g / S * sum_i(y_i)
Proof:
x <- x + eta_g * grad(x)
x <- x + eta_g / S * sum_i(y_i - x)
x <- x * (1 - eta_g) + eta_g / S * sum_i(y_i)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | Dict[str, Mapping[str, Union[np.ndarray, torch.Tensor]]] | dictionary of model parameters obtained after one round of federated training, in the format {node id: {parameter name: parameter value}}. | required |
global_model | Mapping[str, Union[np.ndarray, torch.Tensor]] | dictionary representing the previous iteration of the global model, in the format {parameter name: parameter value}. This corresponds to \(\mathbf{x}\) in the notation of the scaffold paper. | required |
Returns:
Type | Description |
---|---|
Mapping[str, Union[np.ndarray, torch.Tensor]] | A dictionary of aggregated parameters, in the format {parameter name: parameter value}, where the parameter names are the same as those of the input global model |
Source code in fedbiomed/researcher/aggregators/scaffold.py
def scaling(self,
            model_params: Dict[str, Mapping[str, Union[np.ndarray, torch.Tensor]]],
            global_model: Mapping[str, Union[np.ndarray, torch.Tensor]]
            ) -> Mapping[str, Union[np.ndarray, torch.Tensor]]:
    """Computes the aggregated model.

    Let
        - x = the global model from the previous aggregation round
        - y_i = the local model after training for the i^th node
        - eta_g = the global learning rate

    Then this function computes the quantity `x * (1 - eta_g) + eta_g / S * sum_i(y_i)`

    Proof:
        x <- x + eta_g * grad(x)
        x <- x + eta_g / S * sum_i(y_i - x)
        x <- x * (1 - eta_g) + eta_g / S * sum_i(y_i)

    Args:
        model_params: dictionary of model parameters obtained after one round of federated training,
            in the format {node id: {parameter name: parameter value}}.
        global_model: dictionary representing the previous iteration of the global model,
            in the format {parameter name: parameter value}. This corresponds to $\mathbf{x}$ in the notation
            of the scaffold paper.

    Returns:
        A dictionary of aggregated parameters, in the format {parameter name: parameter value}, where the
        parameter names are the same as those of the input global model.
    """
    aggregated_parameters = {}
    for key, val in global_model.items():
        update = sum(params[key] for params in model_params.values()) / len(model_params)
        newval = (1 - self.server_lr) * val + self.server_lr * update
        aggregated_parameters[key] = newval
    return aggregated_parameters
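A small numeric check of this formula, with hypothetical node ids and a single parameter vector:

```python
import numpy as np

from fedbiomed.researcher.aggregators import Scaffold

agg = Scaffold(server_lr=0.5)
global_model = {'w': np.array([2.0, 2.0])}
model_params = {
    'node_1': {'w': np.array([0.0, 4.0])},
    'node_2': {'w': np.array([2.0, 0.0])},
}
print(agg.scaling(model_params, global_model))
# {'w': array([1.5, 2. ])}  i.e. 0.5 * x + 0.5 * mean(y_i)
```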
set_nodes_learning_rate_after_training(training_plan, training_replies, n_round)
set_nodes_learning_rate_after_training(training_plan, training_replies, n_round)
Retrieves the learning rate(s) of the optimizer from each node (needed when a learning rate scheduler is used).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
training_plan | BaseTrainingPlan | training plan instance | required |
training_replies | List[Responses] | training replies that must contain an `optimizer_args` entry and a learning rate | required |
n_round | int | number of rounds already performed | required |
Raises:
Type | Description |
---|---|
FedbiomedAggregatorError | raised when setting learning rate has been unsuccessful |
Returns:
Type | Description |
---|---|
Dict[str, List[float]] | Dict[str, List[float]]: dictionary mapping node_id to a list of floats, one per layer of the model (in PyTorch, each layer can have a specific learning rate). |
Source code in fedbiomed/researcher/aggregators/scaffold.py
def set_nodes_learning_rate_after_training(self, training_plan: BaseTrainingPlan,
                                           training_replies: List[Responses],
                                           n_round: int) -> Dict[str, List[float]]:
    """Gets back the learning rate of the optimizer from each node (if a learning rate scheduler is used)

    Args:
        training_plan (BaseTrainingPlan): training plan instance
        training_replies (List[Responses]): training replies that must contain an `optimizer_args`
            entry and a learning rate
        n_round (int): number of rounds already performed

    Raises:
        FedbiomedAggregatorError: raised when setting the learning rate has been unsuccessful

    Returns:
        Dict[str, List[float]]: dictionary mapping node_id to a list of floats, one per layer
            of the model (in PyTorch, each layer can have a specific learning rate).
    """
    # to be implemented in a utils module (for pytorch optimizers)
    n_model_layers = len(training_plan.get_model_params())
    for node_id in self._fds.node_ids():
        lrs: List[float] = []
        if training_replies[n_round].get_index_from_node_id(node_id) is not None:
            # get updated learning rate if provided...
            node_idx: int = training_replies[n_round].get_index_from_node_id(node_id)
            lrs += training_replies[n_round][node_idx]['optimizer_args'].get('lr')
        else:
            # ...otherwise retrieve default learning rate
            lrs += training_plan.get_learning_rate()
        if len(lrs) == 1:
            # case where there is a single learning rate
            lr = lrs * n_model_layers
        elif len(lrs) == n_model_layers:
            # case where there are several learning rate values
            lr = lrs
        else:
            raise FedbiomedAggregatorError(
                "Error when setting node learning rate for SCAFFOLD: cannot extract node learning rate.")
        self.nodes_lr[node_id] = lr
    return self.nodes_lr
set_training_plan_type(training_plan_type)
set_training_plan_type(training_plan_type)
Overrides `set_training_plan_type` from the parent class. Checks the training plan type; if it is SKLearnTrainingPlan, raises an error, otherwise calls the parent method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
training_plan_type | TrainingPlans | training_plan type | required |
Raises:
Type | Description |
---|---|
FedbiomedAggregatorError | raised if training_plan type has been set to SKLearn training plan |
Returns:
Name | Type | Description |
---|---|---|
TrainingPlans | TrainingPlans | training plan type |
Source code in fedbiomed/researcher/aggregators/scaffold.py
def set_training_plan_type(self, training_plan_type: TrainingPlans) -> TrainingPlans:
    """
    Overrides `set_training_plan_type` from the parent class.
    Checks the training plan type; if it is SKLearnTrainingPlan, raises an error,
    otherwise calls the parent method.

    Args:
        training_plan_type (TrainingPlans): training_plan type

    Raises:
        FedbiomedAggregatorError: raised if the training_plan type has been set to the SKLearn training plan

    Returns:
        TrainingPlans: training plan type
    """
    if training_plan_type == TrainingPlans.SkLearnTrainingPlan:
        raise FedbiomedAggregatorError("Aggregator SCAFFOLD not implemented for SKlearn")
    training_plan_type = super().set_training_plan_type(training_plan_type)
    # TODO: trigger a warning if user is trying to use scaffold with something else than SGD
    return training_plan_type
update_correction_states(local_models, global_model, n_updates=1)
update_correction_states(local_models, global_model, n_updates=1)
Updates correction states
Proof:
c <- c + S/N grad(c)
c <- c + 1/N sum_i(c_i(+) - c_i)
c <- c + 1/N * sum_i( 1/(K * eta_l)(x - y_i) - c)
c <- (1 - S/N) c + ACG_i, where ACG_i = sum_i( 1/(K * eta_l)(x - y_i))
where (according to the Scaffold paper):
- c: the correction term
- S: the number of nodes participating in the current round
- N: the total number of nodes participating in the experiment
- K: number of updates
- eta_l: nodes' learning rate
- x: global model before updates
- y_i: local model updates
Remark: c^{t=0} = 0
Parameters:
Name | Type | Description | Default |
---|---|---|---|
local_models | Dict[str, Mapping[str, Union[torch.Tensor, np.ndarray]]] | Node-wise local model parameters after updates, as {name: value} parameter mappings indexed by node id. | required |
global_model | Mapping[str, Union[torch.Tensor, np.ndarray]] | Global model parameters (before updates), as a single {name: value} parameters mapping. | required |
n_updates | int | number of batches (or updates) performed during one round. Referred to as `K` in the Scaffold paper. Defaults to 1. | 1 |
Raises:
Type | Description |
---|---|
FedbiomedAggregatorError | if no FederatedDataset has been found. |
Source code in fedbiomed/researcher/aggregators/scaffold.py
def update_correction_states(self,
                             local_models: Dict[str, Mapping[str, Union[torch.Tensor, np.ndarray]]],
                             global_model: Mapping[str, Union[torch.Tensor, np.ndarray]],
                             n_updates: int = 1,) -> None:
    """Updates the correction states

    Proof:
        c <- c + S/N grad(c)
        c <- c + 1/N sum_i(c_i(+) - c_i)
        c <- c + 1/N * sum_i( 1/(K * eta_l)(x - y_i) - c)
        c <- (1 - S/N) c + ACG_i, where ACG_i = sum_i( 1/(K * eta_l)(x - y_i))

    where (according to the Scaffold paper):
        c: the correction term
        S: the number of nodes participating in the current round
        N: the total number of nodes participating in the experiment
        K: number of updates
        eta_l: nodes' learning rate
        x: global model before updates
        y_i: local model updates

    Remark:
        c^{t=0} = 0

    Args:
        local_models: Node-wise local model parameters after updates, as
            {name: value} parameter mappings indexed by node id.
        global_model: Global model parameters (before updates), as a single
            {name: value} parameters mapping.
        n_updates: number of batches (or updates) performed during one round.
            Referred to as `K` in the Scaffold paper. Defaults to 1.

    Raises:
        FedbiomedAggregatorError: if no FederatedDataset has been found.
    """
    # Gather the total number of nodes (not just participating ones).
    if self._fds is None:
        raise FedbiomedAggregatorError("Cannot run SCAFFOLD aggregator: No Federated Dataset set")
    total_nb_nodes = len(self._fds.node_ids())
    # Compute the node-wise average of corrected gradients (ACG_i),
    # i.e. (x^t - y_i^t) / (K * eta_l).
    local_state_updates: Dict[str, Mapping[str, Union[torch.Tensor, np.ndarray]]] = {}
    for node_id, params in local_models.items():
        local_state_updates[node_id] = {
            key: (global_model[key] - val) / (self.nodes_lr[node_id][idx] * n_updates)
            for idx, (key, val) in enumerate(params.items())
        }
    # Compute the shared state variable's update by averaging the former.
    global_state_update = {
        key: sum(state[key] for state in local_state_updates.values()) / total_nb_nodes
        for key in global_model
    }
    # Compute the updated shared state variable:
    # c^{t+1} = (1 - S/N)c^t + (1/N) sum_{i=1}^S ACG_i
    share = 1 - len(local_models) / total_nb_nodes
    global_state_new = {
        key: share * self.global_state[key] + val
        for key, val in global_state_update.items()
    }
    # Compute the difference between past and new shared state variables
    # (i.e. c^t - c^{t+1}).
    global_state_diff = {
        key: self.global_state[key] - val
        for key, val in global_state_new.items()
    }
    # Compute the updated node-wise correction terms.
    for node_id in self._fds.node_ids():
        acg = local_state_updates.get(node_id, None)
        if acg is None:
            # Case when the node did not participate in the round:
            # d_i^{t+1} = d_i^t + c^t - c^{t+1}
            for key, val in self.nodes_correction_states[node_id].items():
                self.nodes_correction_states[node_id][key] += global_state_diff[key]
        else:
            # Case when the node participated in the round:
            # d_i^{t+1} = c_i^{t+1} - c^{t+1} = ACG_i - d_i^t - c^{t+1}
            for key, val in self.nodes_correction_states[node_id].items():
                self.nodes_correction_states[node_id][key] = (
                    local_state_updates[node_id][key] - val - global_state_new[key]
                )
    # Assign the updated shared state.
    self.global_state = global_state_new
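A standalone NumPy sketch of these state updates, for two participating nodes out of three (illustrative values, bypassing the class state):

```python
import numpy as np

eta_l, K, N = 0.1, 5, 3                      # node learning rate, updates per round, total nodes
x = np.array([1.0, 1.0])                     # global model before the round
y = {'n1': np.array([0.5, 1.0]),             # local models of the S = 2 participating nodes
     'n2': np.array([1.0, 0.5])}
c = np.zeros(2)                              # shared state c^t (zero at round 0)

acg = {i: (x - y_i) / (K * eta_l) for i, y_i in y.items()}   # ACG_i per node
c_new = (1 - len(y) / N) * c + sum(acg.values()) / N         # c^{t+1}
```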
Functions
federated_averaging(model_params, weights)
federated_averaging(model_params, weights)
Defines Federated Averaging (FedAvg) strategy for model aggregation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | List[Dict[str, Union[torch.Tensor, np.ndarray]]] | list that contains nodes' model parameters; each model is stored as an OrderedDict (maps model layer name to the model weights) | required |
weights | List[float] | weights for performing weighted sum in FedAvg strategy (depending on the dataset size of each node). Items in the list must always sum up to 1 | required |
Returns:
Type | Description |
---|---|
Mapping[str, Union[torch.Tensor, np.ndarray]] | Final model with aggregated layers, as an OrderedDict object. |
Source code in fedbiomed/researcher/aggregators/functional.py
def federated_averaging(model_params: List[Dict[str, Union[torch.Tensor, np.ndarray]]],
                        weights: List[float]) -> Mapping[str, Union[torch.Tensor, np.ndarray]]:
    """Defines Federated Averaging (FedAvg) strategy for model aggregation.

    Args:
        model_params: list that contains nodes' model parameters; each model is stored as an OrderedDict (maps
            model layer name to the model weights)
        weights: weights for performing weighted sum in FedAvg strategy (depending on the dataset size of each
            node). Items in the list must always sum up to 1

    Returns:
        Final model with aggregated layers, as an OrderedDict object.
    """
    assert len(model_params) > 0, 'An empty list of models was passed.'
    assert len(weights) == len(model_params), 'List with number of observations must have ' \
                                              'the same number of elements as the list of models.'
    # Compute proportions
    proportions = [n_k / sum(weights) for n_k in weights]
    return weighted_sum(model_params, proportions)
initialize(val)
initialize(val)
Initialize tensor or array vector.
Source code in fedbiomed/researcher/aggregators/functional.py
def initialize(val: Union[torch.Tensor, np.ndarray]) -> Tuple[str, Union[torch.Tensor, np.ndarray]]:
    """Initialize tensor or array vector."""
    if isinstance(val, torch.Tensor):
        return ('tensor', torch.zeros_like(val).float())
    elif isinstance(val, (np.ndarray, list)):
        # np.asarray handles the case where `val` is a plain list (which has no `.shape`)
        return ('array', np.zeros(np.asarray(val).shape, dtype=float))
weighted_sum(model_params, proportions)
weighted_sum(model_params, proportions)
Performs weighted sum operation
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | List[Dict[str, Union[torch.Tensor, np.ndarray]]] | list that contains nodes' model parameters; each model is stored as an OrderedDict (maps model layer name to the model weights) | required |
proportions | List[float] | weights of all items within model_params's list | required |
Returns:
Type | Description |
---|---|
Mapping[str, Union[torch.Tensor, np.ndarray]] | Mapping[str, Union[torch.Tensor, np.ndarray]]: model resulting from the weighted sum operation |
Source code in fedbiomed/researcher/aggregators/functional.py
def weighted_sum(model_params: List[Dict[str, Union[torch.Tensor, np.ndarray]]],
                 proportions: List[float]) -> Mapping[str, Union[torch.Tensor, np.ndarray]]:
    """Performs a weighted sum operation

    Args:
        model_params (List[Dict[str, Union[torch.Tensor, np.ndarray]]]): list that contains nodes' model
            parameters; each model is stored as an OrderedDict (maps model layer name to the model weights)
        proportions (List[float]): weights of all items within model_params's list

    Returns:
        Mapping[str, Union[torch.Tensor, np.ndarray]]: model resulting from the weighted sum
            operation
    """
    # Initialize the output dictionary with zero-valued parameters of the right type and shape
    avg_params = copy.deepcopy(model_params[0])
    for key, val in avg_params.items():
        (t, avg_params[key]) = initialize(val)
    if t == 'tensor':
        for model, weight in zip(model_params, proportions):
            for key in avg_params.keys():
                avg_params[key] += weight * model[key]
    if t == 'array':
        for key in avg_params.keys():
            matr = np.array([d[key] for d in model_params])
            avg_params[key] = np.average(matr, weights=np.array(proportions), axis=0)
    return avg_params
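A quick usage check with torch tensors (the layer name and proportions are made up; the import path is taken from the source location above):

```python
import torch

from fedbiomed.researcher.aggregators.functional import weighted_sum

params = [
    {'w': torch.tensor([0.0, 2.0])},
    {'w': torch.tensor([4.0, 6.0])},
]
print(weighted_sum(params, proportions=[0.75, 0.25]))
# {'w': tensor([1., 3.])}
```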