Jalankan penghitungan di VM Cloud TPU menggunakan JAX

Dokumen ini memberikan pengantar singkat tentang cara bekerja dengan JAX dan Cloud TPU.

Sebelum mengikuti panduan memulai ini, Anda harus membuat akun Google Cloud Platform, menginstal Google Cloud CLI, dan mengonfigurasi perintah gcloud. Untuk mengetahui informasi selengkapnya, lihat Menyiapkan akun dan project Cloud TPU.

Menginstal Google Cloud CLI.

Google Cloud CLI berisi alat dan library untuk berinteraksi dengan produk dan layanan Google Cloud. Untuk mengetahui informasi selengkapnya, lihat Menginstal Google Cloud CLI.

Mengonfigurasi perintah gcloud

Jalankan perintah berikut untuk mengonfigurasi gcloud agar dapat menggunakan project Google Cloud Anda dan menginstal komponen yang diperlukan untuk pratinjau VM TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

Mengaktifkan Cloud TPU API

  1. Aktifkan Cloud TPU API menggunakan perintah gcloud berikut di Cloud Shell. (Anda juga dapat mengaktifkannya dari Google Cloud Console).

    $ gcloud services enable tpu.googleapis.com
  2. Jalankan perintah berikut untuk membuat identitas layanan.

    $ gcloud beta services identity create --service tpu.googleapis.com

Buat VM Cloud TPU dengan gcloud

Dengan VM Cloud TPU, model dan kode Anda berjalan langsung di mesin host TPU. Anda menjalankan SSH langsung ke host TPU. Anda dapat menjalankan kode arbitrer, menginstal paket, melihat log, dan kode debug langsung di TPU Host.

  1. Buat VM TPU dengan menjalankan perintah berikut dari Cloud Shell atau terminal komputer tempat Google Cloud CLI terinstal.

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base

    Kolom wajib diisi

    zone
    Zona tempat Anda berencana membuat Cloud TPU.
    accelerator-type
    Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU.
    version
    Versi software Cloud TPU. Untuk semua jenis TPU, gunakan tpu-ubuntu2204-base.

Hubungkan ke VM Cloud TPU Anda

Jalankan SSH ke VM TPU Anda dengan menggunakan perintah berikut:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central2-b

Kolom wajib diisi

tpu_name
Nama VM TPU yang terhubung dengan Anda.
zone
Zona tempat Anda membuat Cloud TPU.

Instal JAX di VM Cloud TPU Anda

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Pemeriksaan sistem

Pastikan JAX dapat mengakses TPU dan dapat menjalankan operasi dasar:

Mulai penafsir Python 3:

(vm)$ python3
>>> import jax

Menampilkan jumlah inti TPU yang tersedia:

>>> jax.device_count()

Jumlah inti TPU ditampilkan. Jika Anda menggunakan TPU v4, ID ini harus berupa 4. Jika Anda menggunakan TPU v2 atau v3, ID ini harus berupa 8.

Lakukan penghitungan sederhana:

>>> jax.numpy.add(1, 1)

Hasil penambahan numpy akan ditampilkan:

Output dari perintah:

Array(2, dtype=int32, weak_type=true)

Keluar dari penafsir Python:

>>> exit()

Menjalankan kode JAX di VM TPU

Sekarang Anda dapat menjalankan kode JAX yang diinginkan. Contoh flax adalah tempat yang tepat untuk memulai dengan menjalankan model ML standar di JAX. Misalnya, untuk melatih jaringan konvolusional MNIST dasar:

  1. Menginstal dependensi contoh Flax

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. Instal FLAX

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Menjalankan skrip pelatihan FLAX MNIST

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

Skrip akan mendownload set data dan memulai pelatihan. Output skrip akan terlihat seperti ini:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Pembersihan

Setelah selesai menggunakan VM TPU, ikuti langkah-langkah berikut untuk membersihkan resource Anda.

  1. Putuskan koneksi dari instance Compute Engine jika Anda belum melakukannya:

    (vm)$ exit
  2. Hapus Cloud TPU Anda.

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central2-b
  3. Verifikasi bahwa resource telah dihapus dengan menjalankan perintah berikut. Pastikan TPU Anda tidak lagi tercantum. Proses penghapusan mungkin memerlukan waktu beberapa menit.

    $ gcloud compute tpus tpu-vm list \
      --zone=us-central2-b

Catatan Performa

Berikut ini beberapa detail penting yang sangat relevan dengan penggunaan TPU di JAX.

Padding

Salah satu penyebab paling umum untuk performa yang lambat pada TPU adalah memperkenalkan padding yang tidak disengaja:

  • Array di Cloud TPU disusun dalam tile. Hal ini memerlukan padding salah satu dimensi ke kelipatan 8, dan dimensi yang berbeda ke kelipatan 128.
  • Unit perkalian matriks berfungsi paling baik dengan pasangan matriks besar yang meminimalkan kebutuhan padding.

dtype bfloat16

Secara default, perkalian matriks dalam JAX pada TPU menggunakan bfloat16 dengan akumulasi float32. Hal ini dapat dikontrol dengan argumen presisi pada panggilan fungsi jax.numpy yang relevan (matmul, titik, einsum, dll.). Pada khususnya:

  • precision=jax.lax.Precision.DEFAULT: menggunakan presisi bfloat16 campuran (tercepat)
  • precision=jax.lax.Precision.HIGH: menggunakan beberapa penerusan MXU untuk mencapai presisi yang lebih tinggi
  • precision=jax.lax.Precision.HIGHEST: menggunakan lebih banyak penerusan MXU untuk mencapai presisi float32 penuh

JAX juga menambahkan dtype bfloat16 yang dapat Anda gunakan untuk mentransmisikan array secara eksplisit ke bfloat16, misalnya, jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Menjalankan JAX di Colab

Saat Anda menjalankan kode JAX di notebook Colab, Colab akan otomatis membuat node TPU lama. TPU node memiliki arsitektur yang berbeda. Untuk informasi selengkapnya, lihat Arsitektur Sistem.

Langkah selanjutnya

Untuk mengetahui informasi selengkapnya tentang Cloud TPU, lihat: