Jean Zay : Utiliser la précision mixte pour optimiser l'apprentissage d'un modèle

Principe de fonctionnement

Le terme de précision fait référence ici à la manière de stocker des variables réelles en mémoire. La plupart du temps, les réels sont stockés sur 32 bits. On les appelle des réels à virgule flottante 32 bits, ou float32, et on parle de simple précision. Les réels peuvent également être stockés sur 64 ou 16 bits, selon le nombre de chiffres significatifs souhaité. On parle respectivement de float64 ou double précision, et de float16 ou semi précision. Les frameworks TensorFlow et PyTorch utilisent par défaut la simple précision, mais il est possible d'exploiter la semi précision pour optimiser l'étape d'apprentissage.

On parle de précision mixte quand on fait cohabiter plusieurs précisions dans le même modèle lors des étapes de propagation et de rétro-propagation. Autrement dit, quand on réduit la précision de certaines variables du modèle. Cette technique est possible grâce aux Tensor Cores disponibles sur les GPU NVIDIA V100 de Jean Zay. Ces Tensor Cores permettent d'effectuer efficacement des opérations impliquant à la fois des variables en float16 et des variables en float32.

Il a été montré empiriquement que, même en réduisant la précision de certaines variables, on obtient un apprentissage de modèle équivalent en performance (loss, accuracy) tant que les variables “sensibles” conservent la précision float32 par défaut, et ce pour tous les “grands types” de modèles actuels. Dans les frameworks TensorFlow et PyTorch, le caractère “sensible” des variables est déterminé automatiquement grâce à la fonctionnalité Automatic Mixed Precision, ou AMP.

Remarque : la précision mixte est une technique d'optimisation de l'apprentissage. À la fin de celui-ci, le modèle entraîné est reconverti en float32, sa précision initiale.

Intérêt d'implémenter la précision mixte

On remarquera plusieurs avantages à exploiter la précision mixte :

  • Le modèle occupe moins de place en mémoire puisqu'on divise par deux la taille de certaines variables.
  • Les variables étant plus rapides à transférer en mémoire, la bande passante est moins sollicitée.
  • Les opérations sont grandement accélérées (de l'ordre de 2x ou 3x) car, les variables étant plus petites, les calculs sont plus rapides à effectuer.

Il est donc conseillé d'utiliser cette technique en particulier sur des modèles :

  • lourds en mémoire (qui dépassent la taille mémoire d'un GPU, par exemple) ;
  • lourds en calculs (comme les réseaux convolutionnels de grande taille).

Plus le modèle est lourd (en terme de mémoire et d'opérations) plus l'usage de la précision mixte sera efficace. Toutefois, même pour un petit modèle, il existe un gain de performance. En conclusion, c'est une bonne pratique sur Jean Zay d'implémenter la précision mixte dans tous les cas.

Enfin, comme documenté sur le site Nvidia les performances sont meilleures quand les dimensions du modèle (batch size, taille d'image, couches embedded, couches denses) sont des multiples de 8, en raison des spécificités matérielles des Tensor cores.

La figure suivante illustre les gains en mémoire et en temps de calcul fournis par l'usage de la précision mixte. Ces résultats ont été mesurés sur Jean Zay pour l'entraînement d'un modèle Resnet101 implémenté en PyTorch, et exécuté sur 1 GPU.

AMP sur Jean Zay

Gains (AMP factor) en mémoire et temps de calcul de l'AMP par taille de batch mesurés avec Resnet101 en PyTorch

Loss Scaling

Il est important de comprendre l'influence de la conversion en float16 de certaines variables du modèle et de ne pas oublier d'implémenter la technique de Loss Scaling lorsque l'on utilise la précision mixte.

Source : Mixed Precision Training

En effet, la plage des valeurs représentables en précision float16 s'étend sur l'intervalle [2-24,215]. Or, comme on le voit sur la figure ci-contre, dans certains modèles, les valeurs des gradients sont bien inférieures au moment de la mise à jour des poids. Elles se trouvent ainsi en dehors de la zone représentable en précision float16 et sont réduites à une valeur nulle. Si l'on ne fait rien, les calculs risquent d'être faussés alors que la plage de valeurs représentables en float16 restera en grande partie inexploitée.

Pour éviter ce problème, on utilise une technique appelée Loss Scaling. Lors des itérations d'apprentissage, on multiplie la Loss d’entraînement par un facteur S pour déplacer les variables vers des valeurs plus élevées, représentables en float16. Il faudra ensuite les corriger avant la mise à jour des poids du modèle, en divisant les gradients de poids par le même facteur S. On rétablira ainsi les vraies valeurs de gradients.

Pour simplifier ce processus, il existe des solutions dans les frameworks TensorFlow et PyTorch pour mettre le Loss scaling en place.

En TensorFlow

Depuis la version 2.4, TensorFlow inclut une librairie dédiée à la mixed precision :

from tensorflow.keras import mixed_precision

Cette librairie permet d'instancier la précision mixte dans le backend de TensorFlow. L'instruction est la suivante :

mixed_precision.set_global_policy('mixed_float16')

On indique ainsi à TensorFlow de décider, à la création de chaque variable, quelle précision utiliser selon la politique implémentée dans la librairie. Il ne reste plus qu'à implémenter le Loss Scaling après la création de l'optimiseur :

opt = tf.optimizers.Adam(learning_rate)
opt = mixed_precision.LossScaleOptimizer(opt)

Remarque : si vous utilisez une boucle d'apprentissage personnalisée avec un GradientTape, il faut explicitement appliquer les étapes de la Scaled Loss (se référer à la page du guide TensorFlow).

En PyTorch

Depuis la version 1.6, PyTorch inclut des fonctions pour l'AMP (se référer aux exemples de la page Pytorch).

Pour mettre en place la précision mixte et le Loss Scaling, il faut ajouter quelques lignes :

from torch.cuda.amp import autocast, GradScaler
 
scaler = GradScaler()
 
for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
 
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
 
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()