Menjalankan penghitungan di VM Cloud TPU menggunakan JAX

Dokumen ini memberikan pengantar singkat tentang cara menggunakan JAX dan Cloud TPU.

Sebelum memulai

Sebelum menjalankan perintah dalam dokumen ini, Anda harus membuat akun Google Cloud, menginstal Google Cloud CLI, dan mengonfigurasi perintah gcloud. Untuk informasi selengkapnya, lihat Menyiapkan lingkungan Cloud TPU.

Membuat VM Cloud TPU menggunakan gcloud

  1. Tentukan beberapa variabel lingkungan agar perintah lebih mudah digunakan.

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-8
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    Deskripsi variabel lingkungan

    PROJECT_ID
    ID project Google Cloud Anda.
    ACCELERATOR_TYPE
    Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU.
    ZONE
    Zona tempat Anda berencana membuat Cloud TPU.
    RUNTIME_VERSION
    Versi runtime Cloud TPU. Untuk mengetahui informasi selengkapnya, lihat Gambar VM TPU.
    TPU_NAME
    Nama yang ditetapkan pengguna untuk Cloud TPU Anda.
  2. Buat VM TPU dengan menjalankan perintah berikut dari Cloud Shell atau terminal komputer tempat Google Cloud CLI diinstal.

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$RUNTIME_VERSION

Menghubungkan ke VM Cloud TPU

Hubungkan ke VM TPU Anda melalui SSH menggunakan perintah berikut:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
--project=$PROJECT_ID \
--zone=$ZONE

Menginstal JAX di VM Cloud TPU

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Pemeriksaan sistem

Verifikasi bahwa JAX dapat mengakses TPU dan dapat menjalankan operasi dasar:

  1. Mulai penafsir Python 3:

    (vm)$ python3
    >>> import jax
  2. Menampilkan jumlah core TPU yang tersedia:

    >>> jax.device_count()

Jumlah core TPU ditampilkan. Jumlah core yang ditampilkan bergantung pada versi TPU yang Anda gunakan. Untuk mengetahui informasi selengkapnya, lihat versi TPU.

Lakukan penghitungan:

>>> jax.numpy.add(1, 1)

Hasil penambahan numpy ditampilkan:

Output dari perintah:

Array(2, dtype=int32, weak_type=true)

Keluar dari penafsiran Python:

>>> exit()

Menjalankan kode JAX di VM TPU

Sekarang Anda dapat menjalankan kode JAX yang diinginkan. Contoh flax adalah tempat yang tepat untuk memulai menjalankan model ML standar di JAX. Misalnya, untuk melatih jaringan konvolusi MNIST dasar:

  1. Menginstal dependensi contoh Flax

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. Menginstal FLAX

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Menjalankan skrip pelatihan FLAX MNIST

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

Skrip akan mendownload set data dan memulai pelatihan. Output skrip akan terlihat seperti ini:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Pembersihan

Agar akun Google Cloud Anda tidak dikenai biaya untuk resource yang digunakan di halaman ini, ikuti langkah-langkah berikut.

Setelah selesai menggunakan VM TPU, ikuti langkah-langkah berikut untuk membersihkan resource.

  1. Putuskan koneksi dari instance Compute Engine, jika Anda belum melakukannya:

    (vm)$ exit
  2. Hapus Cloud TPU Anda.

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
      --project=$PROJECT_ID \
      --zone=$ZONE
  3. Pastikan resource telah dihapus dengan menjalankan perintah berikut. Pastikan TPU Anda tidak lagi tercantum. Proses penghapusan mungkin memerlukan waktu beberapa menit.

    $ gcloud compute tpus tpu-vm list \
      --zone=$ZONE

Catatan performa

Berikut beberapa detail penting yang sangat relevan dengan penggunaan TPU di JAX.

Padding

Salah satu penyebab paling umum untuk performa lambat di TPU adalah memperkenalkan padding yang tidak disengaja:

  • Array di Cloud TPU disusun dalam ubin. Hal ini memerlukan padding salah satu dimensi ke kelipatan 8, dan dimensi yang berbeda ke kelipatan 128.
  • Unit perkalian matriks berperforma terbaik dengan pasangan matriks besar yang meminimalkan kebutuhan padding.

dtype bfloat16

Secara default, perkalian matriks di JAX pada TPU menggunakan bfloat16 dengan akumulasi float32. Hal ini dapat dikontrol dengan argumen presisi pada panggilan fungsi jax.numpy yang relevan (matmul, dot, einsum, dll.). Pada khususnya:

  • precision=jax.lax.Precision.DEFAULT: menggunakan presisi bfloat16 campuran (tercepat)
  • precision=jax.lax.Precision.HIGH: menggunakan beberapa kartu MXU untuk mencapai presisi yang lebih tinggi
  • precision=jax.lax.Precision.HIGHEST: menggunakan lebih banyak kartu MXU untuk mencapai presisi float32 penuh

JAX juga menambahkan dtype bfloat16, yang dapat Anda gunakan untuk secara eksplisit mentransmisikan array ke bfloat16, misalnya, jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Langkah berikutnya

Untuk informasi selengkapnya tentang Cloud TPU, lihat: