Melatih model menggunakan TPU v5e

Dengan footprint 256 chip yang lebih kecil per Pod, TPU v5e dioptimalkan untuk menjadi produk bernilai tinggi untuk pelatihan, fine-tuning, dan penayangan transformer, text-to-image, dan Convolutional Neural Network (CNN). Untuk mengetahui informasi selengkapnya tentang penggunaan Cloud TPU v5e untuk penayangan, lihat Inferensi menggunakan v5e.

Untuk mengetahui informasi selengkapnya tentang hardware dan konfigurasi TPU v5e Cloud TPU, lihat TPU v5e.

Mulai

Bagian berikut menjelaskan cara mulai menggunakan TPU v5e.

Kuota permintaan

Anda memerlukan kuota untuk menggunakan TPU v5e untuk pelatihan. Ada berbagai jenis kuota untuk TPU on-demand, TPU yang dicadangkan, dan VM Spot TPU. Ada kuota terpisah yang diperlukan jika Anda menggunakan TPU v5e untuk inferensi. Untuk mengetahui informasi selengkapnya tentang kuota, lihat Kuota. Untuk meminta kuota TPU v5e, hubungi Bagian Penjualan Cloud.

Buat akun dan project Google Cloud

Anda memerlukan akun dan project untuk menggunakan Cloud TPU. Google Cloud Untuk mengetahui informasi selengkapnya, lihat Menyiapkan lingkungan Cloud TPU.

Buat Cloud TPU

Praktik terbaiknya adalah menyediakan Cloud TPU v5es sebagai resource dalam antrean menggunakan perintah queued-resource create. Untuk mengetahui informasi selengkapnya, lihat Mengelola resource dalam antrean.

Anda juga dapat menggunakan Create Node API (gcloud compute tpus tpu-vm create) untuk menyediakan Cloud TPU v5es. Untuk mengetahui informasi selengkapnya, lihat Mengelola resource TPU.

Untuk mengetahui informasi selengkapnya tentang konfigurasi v5e yang tersedia untuk pelatihan, lihat Jenis Cloud TPU v5e untuk pelatihan.

Penyiapan framework

Bagian ini menjelaskan proses penyiapan umum untuk pelatihan model kustom menggunakan JAX atau PyTorch dengan TPU v5e.

Untuk petunjuk penyiapan inferensi, lihat pengantar inferensi v5e.

Tentukan beberapa variabel lingkungan:

export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id

Penyiapan untuk JAX

Jika Anda memiliki bentuk irisan yang lebih besar dari 8 chip, Anda akan memiliki beberapa VM dalam satu irisan. Dalam hal ini, Anda perlu menggunakan tanda --worker=all untuk menjalankan penginstalan di semua VM TPU dalam satu langkah tanpa menggunakan SSH untuk login ke setiap VM secara terpisah:

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Deskripsi tanda perintah

Variabel Deskripsi
TPU_NAME ID teks yang ditetapkan pengguna dari TPU yang dibuat saat permintaan resource dalam antrean dialokasikan.
PROJECT_ID Google Cloud Nama Project. Gunakan project yang ada atau buat project baru di Siapkan project Google Cloud
ZONE Lihat dokumen Region dan zona TPU untuk mengetahui zona yang didukung.
pekerja VM TPU yang memiliki akses ke TPU pokok.

Anda dapat menjalankan perintah berikut untuk memeriksa jumlah perangkat (output yang ditampilkan di sini dihasilkan dengan slice v5litepod-16). Kode ini menguji apakah semuanya diinstal dengan benar dengan memeriksa apakah JAX melihat TensorCore Cloud TPU dan dapat menjalankan operasi dasar:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

Outputnya akan seperti ini:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() menampilkan jumlah total chip dalam slice yang diberikan. jax.local_device_count() menunjukkan jumlah chip yang dapat diakses oleh satu VM dalam slice ini.

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

Outputnya akan seperti ini:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

Coba Tutorial JAX dalam dokumen ini untuk mulai melatih v5e menggunakan JAX.

Penyiapan untuk PyTorch

Perhatikan bahwa v5e hanya mendukung runtime PJRT dan PyTorch 2.1+ akan menggunakan PJRT sebagai runtime default untuk semua versi TPU.

Bagian ini menjelaskan cara mulai menggunakan PJRT di v5e dengan PyTorch/XLA menggunakan perintah untuk semua pekerja.

Menginstal dependensi

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip install mkl mkl-include
      pip install tf-nightly tb-nightly tbp-nightly
      pip install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

Ganti PYTORCH_VERSION dengan versi PyTorch yang ingin Anda gunakan. PYTORCH_VERSION digunakan untuk menentukan versi yang sama untuk PyTorch/XLA. 2.6.0 direkomendasikan.

Untuk mengetahui informasi selengkapnya tentang versi PyTorch dan PyTorch/XLA, lihat PyTorch - Mulai Menggunakan dan Rilis PyTorch/XLA.

Untuk mengetahui informasi selengkapnya tentang cara menginstal PyTorch/XLA, lihat Penginstalan PyTorch/XLA.

Jika Anda mendapatkan error saat menginstal wheel untuk torch, torch_xla, atau torchvision seperti pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222, turunkan versi Anda dengan perintah ini:

pip3 install setuptools==62.1.0

Menjalankan skrip dengan PJRT

unset LD_PRELOAD

Berikut adalah contoh penggunaan skrip Python untuk melakukan perhitungan di VM v5e:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      unset LD_PRELOAD
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

Tindakan ini menghasilkan output yang mirip dengan berikut ini:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')

Coba Tutorial PyTorch dalam dokumen ini untuk mulai pelatihan v5e menggunakan PyTorch.

Hapus TPU dan resource dalam antrean di akhir sesi Anda. Untuk menghapus resource yang diantrekan, hapus slice, lalu hapus resource yang diantrekan dalam 2 langkah:

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Dua langkah ini juga dapat digunakan untuk menghapus permintaan resource yang diantrekan dan berada dalam status FAILED.

Contoh JAX/FLAX

Bagian berikut menjelaskan contoh cara melatih model JAX dan FLAX di TPU v5e.

Melatih ImageNet di v5e

Tutorial ini menjelaskan cara melatih ImageNet di v5e menggunakan data input palsu. Jika Anda ingin menggunakan data sebenarnya, lihat file README di GitHub.

Siapkan

  1. Buat variabel lingkungan:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Deskripsi variabel lingkungan

    Variabel Deskripsi
    PROJECT_ID ID project Google Cloud Anda. Gunakan project yang ada atau buat project baru.
    TPU_NAME Nama TPU.
    ZONE Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU.
    ACCELERATOR_TYPE Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat Versi TPU.
    RUNTIME_VERSION Versi software Cloud TPU.
    SERVICE_ACCOUNT Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud .

    Contoh: tpu-service-account@PROJECT_ID.iam.

    QUEUED_RESOURCE_ID ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan.

  2. Buat resource TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Anda dapat melakukan SSH ke VM TPU setelah resource yang diantrekan berada dalam status ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Jika QueuedResource dalam status ACTIVE, outputnya akan mirip dengan berikut ini:

     state: ACTIVE
    
  3. Instal JAX dan jaxlib versi terbaru:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Clone model ImageNet dan instal persyaratan yang sesuai:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
    
  5. Untuk membuat data palsu, model memerlukan informasi tentang dimensi set data. Data ini dapat dikumpulkan dari metadata set data ImageNet:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
    

Melatih model

Setelah semua langkah sebelumnya selesai, Anda dapat melatih model.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"

Menghapus TPU dan resource dalam antrean

Hapus TPU dan resource dalam antrean di akhir sesi Anda.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Model FLAX Hugging Face

Model Hugging Face yang diimplementasikan di FLAX dapat langsung digunakan di Cloud TPU v5e. Bagian ini memberikan petunjuk untuk menjalankan model populer.

Melatih ViT di Imagenette

Tutorial ini menunjukkan cara melatih model Vision Transformer (ViT) dari HuggingFace menggunakan set data Imagenette Fast AI di Cloud TPU v5e.

Model ViT adalah model pertama yang berhasil melatih encoder Transformer di ImageNet dengan hasil yang sangat baik dibandingkan dengan jaringan konvolusional. Untuk informasi selengkapnya, lihat referensi berikut:

Siapkan

  1. Buat variabel lingkungan:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Deskripsi variabel lingkungan

    Variabel Deskripsi
    PROJECT_ID ID project Google Cloud Anda. Gunakan project yang ada atau buat project baru.
    TPU_NAME Nama TPU.
    ZONE Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU.
    ACCELERATOR_TYPE Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat Versi TPU.
    RUNTIME_VERSION Versi software Cloud TPU.
    SERVICE_ACCOUNT Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud .

    Contoh: tpu-service-account@PROJECT_ID.iam.

    QUEUED_RESOURCE_ID ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan.

  2. Buat resource TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Anda akan dapat melakukan SSH ke VM TPU setelah resource dalam antrean Anda berada dalam status ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Jika resource dalam antrean berada dalam status ACTIVE, output-nya akan mirip dengan berikut ini:

     state: ACTIVE
    
  3. Instal JAX dan library-nya:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Download repositori Hugging Face dan instal persyaratan:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
    
  5. Download set data Imagenette:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
    

Melatih model

Latih model dengan buffer yang dipetakan sebelumnya sebesar 4 GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'

Menghapus TPU dan resource dalam antrean

Hapus TPU dan resource dalam antrean di akhir sesi Anda.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Hasil benchmark ViT

Skrip pelatihan dijalankan di v5litepod-4, v5litepod-16, dan v5litepod-64. Tabel berikut menunjukkan throughput dengan berbagai jenis akselerator.

Jenis akselerator v5litepod-4 v5litepod-16 v5litepod-64
Epoch 3 3 3
Ukuran batch global 32 128 512
Throughput (contoh/dtk) 263,40 429,34 470,71

Melatih Difusi pada Pokémon

Tutorial ini menunjukkan cara melatih model Stable Diffusion dari HuggingFace menggunakan set data Pokémon di Cloud TPU v5e.

Model Stable Diffusion adalah model text-to-image laten yang menghasilkan gambar fotorealistik dari input teks apa pun. Untuk informasi selengkapnya, lihat referensi berikut:

Siapkan

  1. Tetapkan variabel lingkungan untuk nama bucket penyimpanan Anda:

    export GCS_BUCKET_NAME=your_bucket_name
  2. Siapkan bucket penyimpanan untuk output model Anda:

    gcloud storage buckets create gs://GCS_BUCKET_NAME \
        --project=your_project \
        --location=us-west1
  3. Buat variabel lingkungan:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west1-c
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Deskripsi variabel lingkungan

    Variabel Deskripsi
    PROJECT_ID ID project Google Cloud Anda. Gunakan project yang ada atau buat project baru.
    TPU_NAME Nama TPU.
    ZONE Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU.
    ACCELERATOR_TYPE Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat Versi TPU.
    RUNTIME_VERSION Versi software Cloud TPU.
    SERVICE_ACCOUNT Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud .

    Contoh: tpu-service-account@PROJECT_ID.iam.

    QUEUED_RESOURCE_ID ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan.

  4. Buat resource TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Anda dapat menjalankan SSH ke VM TPU setelah resource dalam antrean Anda berada dalam status ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Jika resource dalam antrean berada dalam status ACTIVE, output-nya akan mirip dengan berikut ini:

     state: ACTIVE
    
  5. Instal JAX dan library-nya.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  6. Download repositori HuggingFace dan instal persyaratan.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
         --project=${PROJECT_ID} \
         --zone=${ZONE} \
         --worker=all \
         --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
    

Melatih model

Latih model dengan buffer yang dipetakan sebelumnya sebesar 4 GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
    git clone https://github.com/google/maxdiffusion
    cd maxdiffusion
    pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip3 install -r requirements.txt
    pip3 install .
    pip3 install gcsfs
    export LIBTPU_INIT_ARGS=''
    python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
    jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
    per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
    output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"

Pembersihan

Hapus TPU, resource dalam antrean, dan bucket Cloud Storage di akhir sesi.

  1. Menghapus TPU:

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  2. Hapus resource yang diantrekan:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  3. Hapus bucket Cloud Storage:

    gcloud storage rm -r gs://${GCS_BUCKET_NAME}
    

Hasil benchmarking untuk difusi

Skrip pelatihan dijalankan di v5litepod-4, v5litepod-16, dan v5litepod-64. Tabel berikut menunjukkan throughput.

Jenis akselerator v5litepod-4 v5litepod-16 v5litepod-64
Langkah Pelatihan 1500 1500 1500
Ukuran batch global 32 64 128
Throughput (contoh/dtk) 36,53 43,71 49,36

PyTorch/XLA

Bagian berikut menjelaskan contoh cara melatih model PyTorch/XLA di TPU v5e.

Melatih ResNet menggunakan runtime PJRT

PyTorch/XLA bermigrasi dari XRT ke PjRt dari PyTorch 2.0+. Berikut adalah petunjuk terbaru untuk menyiapkan v5e untuk workload pelatihan PyTorch/XLA.

Siapkan
  1. Buat variabel lingkungan:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Deskripsi variabel lingkungan

    Variabel Deskripsi
    PROJECT_ID ID project Google Cloud Anda. Gunakan project yang ada atau buat project baru.
    TPU_NAME Nama TPU.
    ZONE Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU.
    ACCELERATOR_TYPE Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat Versi TPU.
    RUNTIME_VERSION Versi software Cloud TPU.
    SERVICE_ACCOUNT Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud .

    Contoh: tpu-service-account@PROJECT_ID.iam.

    QUEUED_RESOURCE_ID ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan.

  2. Buat resource TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Anda akan dapat melakukan SSH ke VM TPU setelah QueuedResource Anda berada dalam status ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Jika resource dalam antrean berada dalam status ACTIVE, output-nya akan mirip dengan berikut ini:

     state: ACTIVE
    
  3. Instal dependensi khusus Torch/XLA

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project=${PROJECT_ID} \
      --zone=${ZONE} \
      --worker=all \
      --command='
         sudo apt-get update -y
         sudo apt-get install libomp5 -y
         pip3 install mkl mkl-include
         pip3 install tf-nightly tb-nightly tbp-nightly
         pip3 install numpy
         sudo apt-get install libopenblas-dev -y
         pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

    Ganti PYTORCH_VERSION dengan versi PyTorch yang ingin Anda gunakan. PYTORCH_VERSION digunakan untuk menentukan versi yang sama untuk PyTorch/XLA. 2.6.0 direkomendasikan.

    Untuk mengetahui informasi selengkapnya tentang versi PyTorch dan PyTorch/XLA, lihat PyTorch - Mulai Menggunakan dan Rilis PyTorch/XLA.

    Untuk mengetahui informasi selengkapnya tentang cara menginstal PyTorch/XLA, lihat Penginstalan PyTorch/XLA.

Latih model ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      date
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export XLA_USE_BF16=1
      export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      git clone https://github.com/pytorch/xla.git
      cd xla/
      git checkout release-r2.6
      python3 test/test_train_mp_imagenet.py --model=resnet50  --fake_data --num_epochs=1 —num_workers=16  --log_steps=300 --batch_size=64 --profile'

Menghapus TPU dan resource dalam antrean

Hapus TPU dan resource dalam antrean di akhir sesi Anda.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
Hasil benchmark

Tabel berikut menunjukkan throughput benchmark.

Jenis akselerator Throughput (contoh/detik)
v5litepod-4 4240 ex/s
v5litepod-16 10.810 ex/s
v5litepod-64 46.154 ex/s

Melatih ViT di v5e

Tutorial ini akan membahas cara menjalankan VIT di v5e menggunakan repositori HuggingFace di PyTorch/XLA pada set data cifar10.

Siapkan

  1. Buat variabel lingkungan:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Deskripsi variabel lingkungan

    Variabel Deskripsi
    PROJECT_ID ID project Google Cloud Anda. Gunakan project yang ada atau buat project baru.
    TPU_NAME Nama TPU.
    ZONE Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU.
    ACCELERATOR_TYPE Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat Versi TPU.
    RUNTIME_VERSION Versi software Cloud TPU.
    SERVICE_ACCOUNT Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud .

    Contoh: tpu-service-account@PROJECT_ID.iam.

    QUEUED_RESOURCE_ID ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan.

  2. Buat resource TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Anda dapat melakukan SSH ke VM TPU setelah QueuedResource Anda berada dalam status ACTIVE:

     gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Jika resource dalam antrean berada dalam status ACTIVE, output-nya akan mirip dengan berikut ini:

     state: ACTIVE
    
  3. Instal dependensi PyTorch/XLA

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
      pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

    Ganti PYTORCH_VERSION dengan versi PyTorch yang ingin Anda gunakan. PYTORCH_VERSION digunakan untuk menentukan versi yang sama untuk PyTorch/XLA. 2.6.0 direkomendasikan.

    Untuk mengetahui informasi selengkapnya tentang versi PyTorch dan PyTorch/XLA, lihat PyTorch - Mulai Menggunakan dan Rilis PyTorch/XLA.

    Untuk mengetahui informasi selengkapnya tentang cara menginstal PyTorch/XLA, lihat Penginstalan PyTorch/XLA.

  4. Download repositori HuggingFace dan instal persyaratan.

       gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="
          git clone https://github.com/suexu1025/transformers.git vittransformers; \
          cd vittransformers; \
          pip3 install .; \
          pip3 install datasets; \
          wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
    

Melatih model

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export TF_CPP_MIN_LOG_LEVEL=0
      export XLA_USE_BF16=1
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      cd vittransformers
      python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
      --remove_unused_columns=False \
      --label_names=pixel_values \
      --mask_ratio=0.75 \
      --norm_pix_loss=True \
      --do_train=true \
      --do_eval=true \
      --base_learning_rate=1.5e-4 \
      --lr_scheduler_type=cosine \
      --weight_decay=0.05 \
      --num_train_epochs=3 \
      --warmup_ratio=0.05 \
      --per_device_train_batch_size=8 \
      --per_device_eval_batch_size=8 \
      --logging_strategy=steps \
      --logging_steps=30 \
      --evaluation_strategy=epoch \
      --save_strategy=epoch \
      --load_best_model_at_end=True \
      --save_total_limit=3 \
      --seed=1337 \
      --output_dir=MAE \
      --overwrite_output_dir=true \
      --logging_dir=./tensorboard-metrics \
      --tpu_metrics_debug=true'

Menghapus TPU dan resource dalam antrean

Hapus TPU dan resource dalam antrean di akhir sesi Anda.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Hasil benchmark

Tabel berikut menunjukkan throughput tolok ukur untuk berbagai jenis akselerator.

v5litepod-4 v5litepod-16 v5litepod-64
Epoch 3 3 3
Ukuran batch global 32 128 512
Throughput (contoh/dtk) 201 657 2.844