Entraîner Resnet50 sur Cloud TPU avec PyTorch


Ce tutoriel explique comment entraîner le modèle ResNet-50 sur un appareil Cloud TPU avec PyTorch. La même procédure peut s'appliquer à d'autres modèles de classification d'image optimisés pour TPU, qui utilisent PyTorch et l'ensemble de données ImageNet.

Le modèle utilisé dans ce tutoriel est basé sur l'article Deep Residual Learning for Image Recognition (Deep learning résiduel pour la reconnaissance d'images), qui présente l'architecture de réseau résiduel (ResNet). Le tutoriel emploie la variante à 50 couches, ResNet-50, et illustre l'entraînement du modèle à l'aide de PyTorch/XLA.

Objectifs

  • Préparer l'ensemble de données
  • Exécuter la tâche d'entraînement
  • Vérifier les résultats

Coûts

Dans ce document, vous utilisez les composants facturables suivants de Google Cloud :

  • Compute Engine
  • Cloud TPU

Obtenez une estimation des coûts en fonction de votre utilisation prévue à l'aide du simulateur de coût. Les nouveaux utilisateurs de Google Cloud peuvent bénéficier d'un essai gratuit.

Avant de commencer

Avant de commencer ce tutoriel, vérifiez que votre projet Google Cloud est correctement configuré.

  1. Connectez-vous à votre compte Google Cloud. Si vous débutez sur Google Cloud, créez un compte pour évaluer les performances de nos produits en conditions réelles. Les nouveaux clients bénéficient également de 300 $ de crédits gratuits pour exécuter, tester et déployer des charges de travail.
  2. Dans Google Cloud Console, sur la page de sélection du projet, sélectionnez ou créez un projet Google Cloud.

    Accéder au sélecteur de projet

  3. Vérifiez que la facturation est activée pour votre projet Google Cloud.

  4. Dans Google Cloud Console, sur la page de sélection du projet, sélectionnez ou créez un projet Google Cloud.

    Accéder au sélecteur de projet

  5. Vérifiez que la facturation est activée pour votre projet Google Cloud.

  6. Ce tutoriel utilise des composants facturables de Google Cloud. Consultez la grille tarifaire de Cloud TPU pour estimer vos coûts. Veillez à nettoyer les ressources que vous avez créées lorsque vous avez terminé, afin d'éviter des frais inutiles.

Créer une VM TPU

  1. Ouvrez une fenêtre Cloud Shell.

    Ouvrir Cloud Shell

  2. Créer une VM TPU

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central2-b \
    --project=your-project
    
  3. Connectez-vous à votre VM TPU à l'aide de SSH:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central2-b
    
  4. Installez PyTorch/XLA sur votre VM TPU:

    (vm)$ pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
    
  5. Clonez le dépôt GitHub PyTorch/XLA.

    (vm)$ git clone --depth=1 --branch r2.2 https://github.com/pytorch/xla.git
    
  6. Exécuter le script d'entraînement avec des données fictives

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
    

Si vous êtes en mesure d'entraîner le modèle à l'aide de données fictives, vous pouvez essayer d'effectuer l'entraînement sur des données réelles, par exemple ImageNet. Pour savoir comment télécharger ImageNet, consultez la page Télécharger ImageNet. Dans la commande du script d'entraînement, l'option --datadir spécifie l'emplacement de l'ensemble de données sur lequel l'entraînement doit être effectué. La commande suivante suppose que l'ensemble de données ImageNet se trouve dans ~/imagenet.

   (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py  --datadir=~/imagenet --batch_size=256 --num_epochs=1
   

Effectuer un nettoyage

Pour éviter que les ressources utilisées lors de ce tutoriel soient facturées sur votre compte Google Cloud, supprimez le projet contenant les ressources, ou conservez le projet et supprimez les ressources individuelles.

  1. Déconnectez-vous de la VM TPU:

    (vm) $ exit
    

    Votre invite devrait maintenant être username@projectname, indiquant que vous êtes dans Cloud Shell.

  2. Supprimez votre VM TPU.

    $ gcloud compute tpus tpu-vm delete resnet50-tutorial \
       --zone=us-central2-b
    

Étapes suivantes

Essayez les colabs PyTorch :