fedbiomed.node.round

Module: fedbiomed.node.round

implementation of Round class of the node component

Classes

Round

CLASS
Round(
    model_kwargs=None,
    training_kwargs=None,
    training=True,
    dataset=None,
    training_plan_url=None,
    training_plan_class=None,
    params_url=None,
    job_id=None,
    researcher_id=None,
    history_monitor=None,
    aggregator_args=None,
    node_args=None,
    round_number=0,
    dlp_and_loading_block_metadata=None,
)

This class represents the training part executed by a node in a given round

Parameters:

Name Type Description Default
model_kwargs dict

contains model args

None
training_kwargs dict

contains training arguments

None
dataset dict

dataset details to use in this round. It contains the dataset name, dataset's id, data path, its shape, its description...

None
training_plan_url str

url from which to download training plan file

None
training_plan_class str

name of the training plan (eg 'MyTrainingPlan')

None
params_url str

url from which to upload/download model params

None
job_id str

job id

None
researcher_id str

researcher id

None
history_monitor HistoryMonitor

Sends real-time feed-back to end-user during training

None
correction_state

correction state applied in case of SCAFFOLD aggregation strategy

required
node_args Union[dict, None]

command line arguments for node. Can include: - gpu (bool): propose use a GPU device if any is available. - gpu_num (Union[int, None]): if not None, use the specified GPU device instead of default GPU device if this GPU device is available. - gpu_only (bool): force use of a GPU device if any available, even if researcher doesn't request for using a GPU.

None
Source code in fedbiomed/node/round.py
def __init__(self,
             model_kwargs: Optional[dict] = None,
             training_kwargs: Optional[dict] = None,
             training: bool = True,
             dataset: Optional[dict] = None,
             training_plan_url: Optional[str] = None,
             training_plan_class: Optional[str] = None,
             params_url: Optional[str] = None,
             job_id: Optional[str] = None,
             researcher_id: Optional[str] = None,
             history_monitor: Optional[HistoryMonitor] = None,
             aggregator_args: Optional[dict] = None,
             node_args: Union[dict, None] = None,
             round_number: int = 0,
             dlp_and_loading_block_metadata: Optional[Tuple[dict, List[dict]]] = None):

    """Constructor of the class

    Args:
        model_kwargs: contains model args
        training_kwargs: contains training arguments
        training: whether the training routine is executed in this round. When False, the round
            stops after the validation step.
        dataset: dataset details to use in this round. It contains the dataset name, dataset's id,
            data path, its shape, its description...
        training_plan_url: url from which to download training plan file
        training_plan_class: name of the training plan (eg 'MyTrainingPlan')
        params_url: url from which to upload/download model params
        job_id: job id
        researcher_id: researcher id
        history_monitor: Sends real-time feed-back to end-user during training
        aggregator_args: arguments sent by the aggregator. Entries that are dicts may reference
            files to fetch through the file exchange service (see `download_aggregator_args`).
        node_args: command line arguments for node. Can include:
            - `gpu (bool)`: propose use a GPU device if any is available.
            - `gpu_num (Union[int, None])`: if not None, use the specified GPU device instead of default
                GPU device if this GPU device is available.
            - `gpu_only (bool)`: force use of a GPU device if any available, even if researcher
                doesn't request for using a GPU.
        round_number: number of the current round; passed as `current_round` when model parameters
            are encrypted for secure aggregation.
        dlp_and_loading_block_metadata: stored as-is on the instance. Presumably the data loading
            plan metadata and the metadata of its loading blocks -- TODO confirm against callers.
    """

    # Secure aggregation is off until `run_model_training` configures it.
    self._use_secagg: bool = False
    self.dataset = dataset
    self.training_plan_url = training_plan_url
    self.training_plan_class = training_plan_class
    self.params_url = params_url
    self.job_id = job_id
    self.researcher_id = researcher_id
    self.history_monitor = history_monitor
    self.aggregator_args: Optional[Dict[str, Any]] = aggregator_args

    self.tp_security_manager = TrainingPlanSecurityManager()
    self.node_args = node_args
    self.repository = Repository(environ['UPLOADS_URL'], environ['TMP_DIR'], environ['CACHE_DIR'])
    # The training plan instance is created in `run_model_training`.
    self.training_plan = None
    self.training = training
    self._dlp_and_loading_block_metadata = dlp_and_loading_block_metadata

    self.training_kwargs = training_kwargs
    self.model_arguments = model_kwargs
    # The three argument sets below are filled by `initialize_validate_training_arguments`.
    self.testing_arguments = None
    self.loader_arguments = None
    self.training_arguments = None
    self._secagg_crypter = SecaggCrypter()
    self._secagg_clipping_range = None
    self._round = round_number
    self._biprime = None
    self._servkey = None

Attributes

aggregator_args instance-attribute
aggregator_args: Optional[Dict[str, Any]] = aggregator_args
dataset instance-attribute
dataset = dataset
history_monitor instance-attribute
history_monitor = history_monitor
job_id instance-attribute
job_id = job_id
loader_arguments instance-attribute
loader_arguments = None
model_arguments instance-attribute
model_arguments = model_kwargs
node_args instance-attribute
node_args = node_args
params_url instance-attribute
params_url = params_url
repository instance-attribute
repository = Repository(
    environ["UPLOADS_URL"],
    environ["TMP_DIR"],
    environ["CACHE_DIR"],
)
researcher_id instance-attribute
researcher_id = researcher_id
testing_arguments instance-attribute
testing_arguments = None
tp_security_manager instance-attribute
tp_security_manager = TrainingPlanSecurityManager()
training instance-attribute
training = training
training_arguments instance-attribute
training_arguments = None
training_kwargs instance-attribute
training_kwargs = training_kwargs
training_plan instance-attribute
training_plan = None
training_plan_class instance-attribute
training_plan_class = training_plan_class
training_plan_url instance-attribute
training_plan_url = training_plan_url

Functions

download_aggregator_args()

Retrieves aggregator arguments, that are sent through file exchange service

Returns:

Type Description
Tuple[bool, str]

Tuple[bool, str]: a tuple containing: (1) a bool that indicates the success of the operation; (2) a string containing the error message.

Source code in fedbiomed/node/round.py
def download_aggregator_args(self) -> Tuple[bool, str]:
    """Retrieves aggregator arguments, that are sent through file exchange service

    Each entry of `self.aggregator_args` that is a dict carrying a non-empty `url` field is
    downloaded and replaced, in place, by the deserialized content of the downloaded file.
    Entries that are not dicts, or that carry no `url`, are left untouched.

    Returns:
        Tuple[bool, str]: a tuple containing:
            a bool that indicates the success of operation
            a string containing the error message (empty string when every download succeeded)
    """
    if self.aggregator_args is None:
        return True, "no file downloads required for aggregator args"

    # download heavy aggregator args (if any)
    for arg_name, aggregator_arg in self.aggregator_args.items():
        if not isinstance(aggregator_arg, dict):
            continue

        url = aggregator_arg.get('url')
        if not url:
            # Without a url there is nothing to fetch: the argument was sent inline.
            continue

        # Only values of existing keys are rebound below, so iterating over the dict stays safe.
        success, param_path, error_msg = self.download_file(url, f"{arg_name}_{uuid.uuid4()}.mpk")
        if not success:
            return False, error_msg

        # FIXME: should we load parameters here or in the training plan
        self.aggregator_args[arg_name] = Serializer.load(param_path)

    return True, ''
download_file(url, file_path)

Downloads file from file exchange system

Parameters:

Name Type Description Default
url str

url used to download file

required
file_path str

file path used to store the downloaded content

required

Returns:

Type Description
Tuple[bool, str, str]

Tuple[bool, str, str]: a tuple that contains: (1) a bool that indicates the success of the download; (2) a str giving the complete path of the downloaded file; (3) a str containing the error message (if any), or an empty string if the operation was successful.

Source code in fedbiomed/node/round.py
def download_file(self, url: str, file_path: str) -> Tuple[bool, str, str]:
    """Fetches a file from the file exchange system through the node repository.

    Args:
        url (str): location of the file to download
        file_path (str): name under which the downloaded content is stored

    Returns:
        Tuple[bool, str, str]: in this order:
            whether the download succeeded
            the complete path of the downloaded file (empty string on failure)
            the error message (empty string on success)
    """
    http_status, local_path = self.repository.download_file(url, file_path)

    # A download counts as successful only with an HTTP 200 status AND a usable local path.
    downloaded = http_status == 200 and local_path is not None
    if downloaded:
        return True, local_path, ''

    return False, '', f"Cannot download param file: {url}"
initialize_validate_training_arguments()

Validates and separates training argument for experiment round

Source code in fedbiomed/node/round.py
def initialize_validate_training_arguments(self) -> None:
    """Builds the training arguments of the round and derives the testing and loader arguments.

    `self.training_kwargs` is wrapped into a `TrainingArgs` instance, from which the testing
    arguments and the loader arguments are then extracted.
    """
    validated_args = TrainingArgs(self.training_kwargs, only_required=False)
    self.training_arguments = validated_args

    # Both sub-sets are read from the same freshly built arguments object.
    self.testing_arguments = validated_args.testing_arguments()
    self.loader_arguments = validated_args.loader_arguments()
run_model_training(secagg_arguments=None)

This method downloads training plan file; then runs the training of a model and finally uploads model params to the file repository

Parameters:

Name Type Description Default
secagg_arguments Union[Dict, None]
  • secagg_servkey_id: Secure aggregation Servkey context id. None means that the parameters are not going to be encrypted
  • secagg_biprime_id: Secure aggregation Biprime context ID.
  • secagg_random: Float value to validate secure aggregation on the researcher side
None

Returns:

Type Description
Dict[str, Any]

Returns the corresponding node message, training reply instance

Source code in fedbiomed/node/round.py
def run_model_training(
        self,
        secagg_arguments: Union[Dict, None] = None,
) -> Dict[str, Any]:
    """This method downloads training plan file; then runs the training of a model
    and finally uploads model params to the file repository

    Failures of the individual steps (argument validation, downloads, training plan
    instantiation, parameter loading, data loader creation, training, upload) are reported
    through an unsuccessful round reply. The secure aggregation configuration is not wrapped
    in a `try` block, so an error raised there propagates to the caller.

    Args:
        secagg_arguments:
            - secagg_servkey_id: Secure aggregation Servkey context id. None means that the parameters
                are not going to be encrypted
            - secagg_biprime_id: Secure aggregation Biprime context ID.
            - secagg_random: Float value to validate secure aggregation on the researcher side
            - secagg_clipping_range: optional clipping range forwarded to the encrypter

    Returns:
        Returns the corresponding node message, training reply instance
    """
    # Validate secagg status. Raises error if the training request is incompatible with
    # secure aggregation settings
    secagg_arguments = {} if secagg_arguments is None else secagg_arguments
    self._use_secagg = self._configure_secagg(
        secagg_servkey_id=secagg_arguments.get('secagg_servkey_id'),
        secagg_biprime_id=secagg_arguments.get('secagg_biprime_id'),
        secagg_random=secagg_arguments.get('secagg_random')
    )

    # Initialize and validate requested experiment/training arguments
    try:
        self.initialize_validate_training_arguments()
    except FedbiomedUserInputError as e:
        return self._send_round_reply(success=False, message=repr(e))
    except Exception as e:
        msg = 'Unexpected error while validating training argument'
        logger.debug(f"{msg}: {repr(e)}")
        return self._send_round_reply(success=False, message=f'{msg}. Please contact system provider')

    # Download the training plan file, check its approval, then download parameters
    try:
        # module name cannot contain dashes
        import_module = 'training_plan_' + str(uuid.uuid4().hex)
        status, _ = self.repository.download_file(self.training_plan_url,
                                                  import_module + '.py')

        if status != 200:
            error_message = "Cannot download training plan file: " + self.training_plan_url
            return self._send_round_reply(success=False, message=error_message)

        if environ["TRAINING_PLAN_APPROVAL"]:
            approved, training_plan_ = self.tp_security_manager.check_training_plan_status(
                os.path.join(environ["TMP_DIR"], import_module + '.py'),
                TrainingPlanApprovalStatus.APPROVED)

            if not approved:
                error_message = f'Requested training plan is not approved by the node: {environ["NODE_ID"]}'
                return self._send_round_reply(success=False, message=error_message)

            logger.info(f'Training plan has been approved by the node {training_plan_["name"]}')

        success, params_path, error_msg = self.download_file(self.params_url, f"my_model_{uuid.uuid4()}.mpk")
        if success:
            # retrieving aggregator args
            success, error_msg = self.download_aggregator_args()

        if not success:
            return self._send_round_reply(success=False, message=error_msg)

    except Exception as e:
        # FIXME: this will trigger if model is not approved by node
        error_message = f"Cannot download training plan files: {repr(e)}"
        return self._send_round_reply(success=False, message=error_message)

    # import module, declare the training plan, load parameters
    try:
        sys.path.insert(0, environ['TMP_DIR'])
        try:
            module = importlib.import_module(import_module)
            train_class = getattr(module, self.training_plan_class)
            self.training_plan = train_class()
        finally:
            # Always undo the `sys.path` change, even when the import or instantiation fails.
            sys.path.pop(0)
    except Exception as e:
        error_message = f"Cannot instantiate training plan object: {repr(e)}"
        return self._send_round_reply(success=False, message=error_message)

    try:
        self.training_plan.post_init(model_args=self.model_arguments,
                                     training_args=self.training_arguments,
                                     aggregator_args=self.aggregator_args)
    except Exception as e:
        error_message = f"Can't initialize training plan with the arguments: {repr(e)}"
        return self._send_round_reply(success=False, message=error_message)

    # import model params into the training plan instance
    try:
        params = Serializer.load(params_path)["model_weights"]
        self.training_plan.set_model_params(params)
    except Exception as e:
        error_message = f"Cannot initialize model parameters: {repr(e)}"
        return self._send_round_reply(success=False, message=error_message)

    # Split training and validation data
    try:
        self._set_training_testing_data_loaders()
    except FedbiomedError as e:
        error_message = f"Can not create validation/train data: {repr(e)}"
        return self._send_round_reply(success=False, message=error_message)
    except Exception as e:
        error_message = f"Undetermined error while creating data for training/validation. Can not create " \
                        f"validation/train data: {repr(e)}"
        return self._send_round_reply(success=False, message=error_message)

    # Validation Before Training
    if self.testing_arguments.get('test_on_global_updates', False) is not False:

        # Last control to make sure validation data loader is set.
        if self.training_plan.testing_data_loader is not None:
            # Validation errors are only logged: they do not abort the round.
            try:
                self.training_plan.testing_routine(metric=self.testing_arguments.get('test_metric', None),
                                                   metric_args=self.testing_arguments.get('test_metric_args', {}),
                                                   history_monitor=self.history_monitor,
                                                   before_train=True)
            except FedbiomedError as e:
                logger.error(f"{ErrorNumbers.FB314}: During the validation phase on global parameter updates; "
                             f"{repr(e)}")
            except Exception as e:
                logger.error(f"Undetermined error during the testing phase on global parameter updates: "
                             f"{repr(e)}")
        else:
            logger.error(f"{ErrorNumbers.FB314}: Can not execute validation routine due to missing testing dataset"
                         f"Please make sure that `test_ratio` has been set correctly")

    if not self.training:
        # Only for validation
        return self._send_round_reply(success=True)

    # Training is activated: without a training data loader there is nothing to train on, and
    # neither the sample size nor the timings reported below could be computed.
    if self.training_plan.training_data_loader is None:
        error_message = "Cannot train model in round: training data loader is not set"
        return self._send_round_reply(success=False, message=error_message)

    results = {}
    try:
        rtime_before = time.perf_counter()
        ptime_before = time.process_time()
        self.training_plan.training_routine(history_monitor=self.history_monitor,
                                            node_args=self.node_args)
        rtime_after = time.perf_counter()
        ptime_after = time.process_time()
    except Exception as e:
        error_message = f"Cannot train model in round: {repr(e)}"
        return self._send_round_reply(success=False, message=error_message)

    # Validation after training
    if self.testing_arguments.get('test_on_local_updates', False) is not False:

        if self.training_plan.testing_data_loader is not None:
            try:
                self.training_plan.testing_routine(metric=self.testing_arguments.get('test_metric', None),
                                                   metric_args=self.testing_arguments.get('test_metric_args',
                                                                                          {}),
                                                   history_monitor=self.history_monitor,
                                                   before_train=False)
            except FedbiomedError as e:
                logger.error(
                    f"{ErrorNumbers.FB314.value}: During the validation phase on local parameter updates; "
                    f"{repr(e)}")
            except Exception as e:
                logger.error(f"Undetermined error during the validation phase on local parameter updates"
                             f"{repr(e)}")
        else:
            logger.error(
                f"{ErrorNumbers.FB314.value}: Can not execute validation routine due to missing testing "
                f"dataset please make sure that test_ratio has been set correctly")

    sample_size = len(self.training_plan.training_data_loader.dataset)

    results["encrypted"] = False
    # Parameters are requested flattened only when they are going to be encrypted.
    model_weights = self.training_plan.after_training_params(flatten=self._use_secagg)
    if self._use_secagg:
        logger.info("Encrypting model parameters. This process can take some time depending on model size.")

        encrypt = functools.partial(
            self._secagg_crypter.encrypt,
            num_nodes=len(self._servkey["parties"]) - 1,  # -1: don't count researcher
            current_round=self._round,
            key=self._servkey["context"]["server_key"],
            biprime=self._biprime["context"]["biprime"],
            weight=sample_size,
            clipping_range=secagg_arguments.get('secagg_clipping_range')
        )
        model_weights = encrypt(params=model_weights)
        results["encrypted"] = True
        results["encryption_factor"] = encrypt(params=[secagg_arguments["secagg_random"]])
        logger.info("Encryption is completed!")

    results['researcher_id'] = self.researcher_id
    results['job_id'] = self.job_id
    results['model_weights'] = model_weights
    results['node_id'] = environ['NODE_ID']
    results['optimizer_args'] = self.training_plan.optimizer_args()

    try:
        # TODO: add validation status to these results?
        # Dump the results to a msgpack file.
        filename = os.path.join(environ["TMP_DIR"], f"node_params_{uuid.uuid4()}.mpk")
        Serializer.dump(results, filename)
        # Upload that file to the remote repository.
        res = self.repository.upload_file(filename)
        logger.info("results uploaded successfully ")
    except Exception as exc:
        return self._send_round_reply(success=False, message=f"Cannot upload results: {exc}")

    # end : clean the namespace
    try:
        del self.training_plan
    except Exception as e:
        logger.debug(f'Exception raise while deleting training plan instance: {repr(e)}')

    return self._send_round_reply(success=True,
                                  timing={'rtime_training': rtime_after - rtime_before,
                                          'ptime_training': ptime_after - ptime_before},
                                  params_url=res['file'],
                                  sample_size=sample_size)