TensorFlow : Parallélisme de données multi-GPU et multi-nœuds

Cette page explique comment distribuer sur plusieurs GPU un modèle de réseau de neurones implémenté dans un code TensorFlow selon la méthode de parallélisme de données.

Un exemple applicatif est proposé sous forme d'un Notebook en bas de page pour vous permettre d'accéder à une implémentation fonctionnelle des explications ci-dessous.

Implémentation d'une stratégie de distribution

Pour distribuer un modèle en TensorFlow, on définit une stratégie de distribution en créant une instance de la classe tf.distribute.Strategy. Cette stratégie permet de contrôler la manière dont sont répartis les données et les calculs sur les GPU.

Choix de la stratégie

MultiWorkerMirroredStrategy

TensorFlow fournit plusieurs stratégies pré-implémentées. Dans cette documentation, nous présentons seulement la stratégie tf.distribute.MultiWorkerMirroredStrategy. Celle-ci a l'avantage d'être générique dans le sens où elle permet le passage aussi bien en multi-GPU qu'en multi-nœuds, sans perte de performance par rapport aux autres stratégies testées.

En utilisant la stratégie MultiWorkerMirroredStrategy, les variables du modèle sont répliquées ou mirrorées sur l'ensemble des GPU détectés. Chaque GPU traite une partie des données (mini-batch) et des opérations de réduction collectives sont effectuées pour agréger les tenseurs et mettre à jour les variables sur chaque GPU à chaque étape.

Environnement multi-workers

La MultiWorkerMirroredStrategy est une version multi-workers de la stratégie tf.distribute.MirroredStrategy. Elle s'exécute sur plusieurs tâches (ou workers), chacune d'entre elles étant assimilée à un nom d'hôte et un numéro de port. L'ensemble des tâches constitue un cluster sur lequel se base la stratégie de distribution pour synchroniser les GPU.

La notion de worker (et de cluster) permet notamment une exécution sur plusieurs nœuds de calcul. Chaque worker peut être associé à un ou plusieurs GPU. Sur Jean Zay, nous conseillons de définir un worker par GPU.

L'environnement multi-workers d'une exécution est automatiquement détectée à partir des variables Slurm définies dans votre script de soumission grâce à la classe tf.distribute.cluster_resolver.SlurmClusterResolver.

La stratégie MultiWorkerMirroredStrategy peut se baser sur deux types de protocoles de communications inter-GPU : gRPC ou NCCL. Sur Jean Zay, il est conseillé de demander l'utilisation du protocole de communication NCCL pour obtenir les meilleures performances.

Déclaration

La déclaration de la stratégie MultiWorkerMirroredStrategy se fait finalement en quelques lignes :

# build multi-worker environment from Slurm variables
cluster_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver(port_base=12345)           
 
# use NCCL communication protocol
implementation = tf.distribute.experimental.CommunicationImplementation.NCCL
communication_options = tf.distribute.experimental.CommunicationOptions(implementation=implementation) 
 
# declare distribution strategy
strategy = tf.distribute.MultiWorkerMirroredStrategy(cluster_resolver=cluster_resolver,
                                                     communication_options=communication_options) 

Remarque : sur Jean Zay, vous pouvez exploiter les numéros de port compris entre 10000 et 20000.

Attention : il y a actuellement une limitation TensorFlow sur la déclaration de la stratégie, celle-ci doit être faite avant tout autre appel à une opération TensorFlow.

Intégration au modèle d'apprentissage

Pour répliquer un modèle sur plusieurs GPU, celui-ci doit être créé dans le contexte strategy.scope().

La méthode .scope() fournit un gestionnaire de contexte qui capture les variables TensorFlow et les communique à chaque GPU en fonction de la stratégie choisie. On y déclare les éléments qui créent les variables relatives au modèle : le chargement d'un modèle enregistré, la déclaration d'un modèle, la fonction model.compile(), l'optimiseur,…

Voici un exemple de déclaration d'un modèle à répliquer sur l'ensemble des GPU :

# get total number of workers 
n_workers = int(os.environ['SLURM_NTASKS'])
 
# define batch size
batch_size_per_gpu = 64
global_batch_size = batch_size_per_gpu * n_workers
 
# model building/compiling need to be within `strategy.scope()`
with strategy.scope():
 
  multi_worker_model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
    #...
    ])
 
  multi_worker_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.001*n_workers),      
    metrics=['accuracy'])

Remarque : l'utilisation de certains optimiseurs comme SGD nécessite un ajustement du learning rate proportionnel à la taille de batch globale, et donc au nombre de GPU.

Entraînement distribué d'un modèle de type tf.keras.Model

La fonction model.fit() de la librairie Keras prend en charge automatiquement la distribution de l'entraînement selon la stratégie choisie. L'entraînement est donc lancé de manière habituelle. Par exemple :

multi_worker_model.fit(train_dataset, epochs=10, steps_per_epoch=100)

Remarques :

  • La distribution des données d'entrées sur les différents processus (data sharding) est gérée automatiquement au sein de la fonction model.fit().
  • L'étape d'évaluation se fait automatiquement en mode distribué en passant le dataset d'évaluation à la fonction model.fit() :
    multi_worker_model.fit(train_dataset, epochs=3, steps_per_epoch=100, validation_data=valid_dataset)

Entraînement distribué d'un modèle personnalisé (Custom Training Loop)

Dans un modèle personnalisé, il faut prendre en charge la parallélisation dans la définition de l'étape d'entraînement train_step.

  1. L'étape d'entraînement doit être lancée en parallèle grâce à la fonction strategy.run(fn=step_fn). Celle-ci exécutera le contenu de la fonction step_fn sur chaque GPU.
  2. Le calcul de la loss doit se faire localement sur chaque GPU en fonction de la taille de batch globale. Il est donc nécessaire de :
    1. désactiver l'opération de réduction utilisée par défaut dans les fonctions de calcul de la loss par batch :
      reduction = tf.keras.losses.Reduction.NONE
      loss_per_batch = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,reduction=reduction)(y, predictions) 
    2. expliciter la division par la taille de batch globale dans le calcul de la loss par GPU :
      loss = tf.nn.compute_average_loss(loss_per_batch, global_batch_size=global_batch_size)
  3. Une fois la loss définie localement sur chaque GPU, la loss totale est calculée grâce à l'opération collective de réduction strategy.reduce().

Voici un exemple de définition d'une étape d'entraînement pour une exécution distribuée :

@tf.function
def train_step(iterator): # ---------------------------------------------------- training step 
 
  def step_fn(inputs): # ------------------------------------------------------- per-GPU step
 
    x, y = inputs
    with tf.GradientTape() as tape:
 
      # compute predictions
      predictions = multi_worker_model(x, training=True)
 
      # compute the loss for each batch between the labels and predictions
      # IMPORTANT: set reduction to NONE so we can do the reduction afterwards and divide by global batch size
      reduction = tf.keras.losses.Reduction.NONE
      losses_per_batch = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                                       reduction=reduction)(y, predictions)
 
      # sum losses per batch and divide by global batch size
      loss = tf.nn.compute_average_loss(losses_per_batch, global_batch_size=global_batch_size)
 
    # compute and apply gradients
    grads = tape.gradient(loss, multi_worker_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, multi_worker_model.trainable_variables))
 
    # compute accuracy
    train_accuracy.update_state(y, predictions)
 
    return loss
 
  # `strategy.run()` invokes step_fn on each GPU
  loss_per_gpu = strategy.run(step_fn, args=(next(iterator),))
 
  # `strategy.reduce()` reduces given value across GPUs and return result on current device
  return strategy.reduce(tf.distribute.ReduceOp.SUM, loss_per_gpu, axis=None)

Remarques:

  1. La distribution des données d'entrée sur les différents processus (data sharding) se fait en amont grâce à l'appel à la fonction strategy.experimental_distribute_dataset (l'appel se fera dans le contexte strategy.scope()) :
    multi_worker_dataset = strategy.experimental_distribute_dataset(train_dataset)
  2. L'étape de validation peut être parallélisée de manière similaire, en distribuant les données d'entrées et en redéfinissant la fonction de validation test_step.

Configuration de l'environnement de calcul

Modules Jean Zay

Pour obtenir des performances optimales en exécution distribuée sur Jean Zay, il faut charger l'un des modules suivants :

  • tensorflow-gpu/py3/2.4.0-noMKL
  • tensorflow-gpu/py3/2.5.0+nccl-2.8.3
  • n'importe quel module de version ≥ 2.6.0

Attention : dans les autres environnements Jean Zay, la librairie TensorFlow a été compilée avec certaines options qui peuvent entraîner des pertes de performance non négligeables.

Configuration de la réservation Slurm

Une réservation Slurm correcte contient autant de tâches Slurm que de workers car le script python doit être exécuté sur chaque worker. Nous associons ici un worker par GPU, donc autant de tâches Slurm que de GPU.

Attention : TensorFlow cherche par défaut à utiliser le protocole HTTP. Pour empêcher cela, il faut désactiver le proxy HTTP de Jean Zay en supprimant les variables d'environnement http_proxy, https_proxy, HTTP_PROXY et HTTPS_PROXY.

Voici un exemple de réservation sur 2 nœuds quadri-GPU :

#!/bin/bash
#SBATCH --job-name=tf_distributed     
#SBATCH --nodes=2 #------------------------- number of nodes 
#SBATCH --ntasks=8 #------------------------ number of tasks / workers
#SBATCH --gres=gpu:4 #---------------------- number of GPUs per node
#SBATCH --cpus-per-task=10           
#SBATCH --hint=nomultithread          
#SBATCH --time=01:00:00              
#SBATCH --output=tf_distributed%j.out 
#SBATCH --error=tf_distributed%j.err  
 
# deactivate the HTTP proxy
unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY
 
# activate the environment
module purge
module load tensorflow-gpu/py3/2.6.0
 
srun python3 -u script.py

Exemple d'application

Exécution multi-GPU et multi-nœuds avec la MultiWorkerMirroredStrategy

Un exemple sous forme de Notebook se trouve dans $DSDIR/examples_IA/Tensorflow_parallel/Example_DataParallelism_TensorFlow.ipynb sur Jean-Zay. Vous pouvez aussi le télécharger en cliquant sur ce lien.

L'exemple est un Notebook dans lequel les entraînements présentés ci-dessus sont implémentés et exécutés sur 1, 2, 4 et 8 GPU. Les exemples se basent sur le modèle ResNet50 et la base de données CIFAR-10.

Vous devez d'abord récupérer le Notebook dans votre espace personnel (par exemple sur votre $WORK) :

$ cp $DSDIR/examples_IA/Tensorflow_parallel/Example_DataParallelism_TensorFlow.ipynb $WORK

Vous pouvez ensuite exécuter le Notebook à partir d'une machine frontale de Jean Zay en chargeant préalablement un module TensorFlow. Par exemple :

$ module load tensorflow-gpu/py3/2.6.0
$ idrlab --notebook-dir=$WORK

Sources et documentation