Resnet50 auf Cloud TPU mit PyTorch trainieren

In dieser Anleitung erfahren Sie, wie Sie das ResNet-50-Modell auf einem Cloud TPU-Gerät mit PyTorch trainieren. Sie können dasselbe Muster auf andere TPU-optimierte Bildklassifikationsmodelle anwenden, die PyTorch und das ImageNet-Dataset verwenden.

Das Modell in dieser Anleitung basiert auf dem Framework Deep Residual Learning for Image Recognition, in dem erstmalig die Residualnetzwerkarchitektur (ResNet-Architektur) eingeführt wurde. In der Anleitung wird die 50-Layer-Variante ResNet-50 verwendet und das Training des Modells mit PyTorch/XLA veranschaulicht.

TPU-VM erstellen

  1. Öffnen Sie ein Cloud Shell-Fenster.

    Cloud Shell öffnen

  2. TPU-VM erstellen

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
  3. Stellen Sie eine SSH-Verbindung zu Ihrer WordPress-VM her.

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. Installieren Sie PyTorch/XLA auf Ihrer TPU-VM:

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. Klonen Sie das PyTorch/XLA-GitHub-Repository.

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. Führen Sie das Trainings-Script mit fiktiven Daten aus.

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1