Menjalankan kode JAX pada slice Pod TPU

Sebelum menjalankan perintah dalam dokumen ini, pastikan Anda telah mengikuti petunjuk di Menyiapkan akun dan project Cloud TPU.

Setelah kode JAX berjalan di satu board TPU, Anda dapat menskalakan kode dengan menjalankannya di slice Pod TPU. Slice Pod TPU adalah beberapa board TPU yang terhubung satu sama lain melalui koneksi jaringan khusus berkecepatan tinggi. Dokumen ini adalah pengantar untuk menjalankan kode JAX di slice Pod TPU; untuk informasi yang lebih mendalam, lihat Menggunakan JAX di lingkungan multi-host dan multi-proses.

Jika ingin menggunakan NFS yang terpasang untuk penyimpanan data, Anda harus menetapkan Login OS untuk semua VM TPU di slice Pod. Untuk informasi selengkapnya, lihat Menggunakan NFS untuk penyimpanan data.

Membuat slice Pod Cloud TPU

  1. Buat beberapa variabel lingkungan:

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=europe-west4-b
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    Deskripsi variabel lingkungan

    PROJECT_ID
    Project ID Google Cloud Anda.
    ACCELERATOR_TYPE
    Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU.
    ZONE
    Zona tempat Anda berencana membuat Cloud TPU.
    RUNTIME_VERSION
    Versi runtime Cloud TPU.
    TPU_NAME
    Nama yang ditetapkan pengguna untuk Cloud TPU Anda.
  2. Buat slice Pod TPU menggunakan perintah gcloud. Misalnya, untuk membuat slice Pod v5p-32, gunakan perintah berikut:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --accelerator-type=${ACCELERATOR_TYPE}  \
    --version=${RUNTIME_VERSION} 

Menginstal JAX di slice Pod

Setelah membuat slice Pod TPU, Anda harus menginstal JAX di semua host dalam slice Pod TPU. Anda dapat melakukannya menggunakan perintah gcloud compute tpus tpu-vm ssh menggunakan parameter --worker=all dan --commamnd.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Menjalankan kode JAX pada slice Pod

Untuk menjalankan kode JAX pada slice Pod TPU, Anda harus menjalankan kode di setiap host dalam slice Pod TPU. Panggilan jax.device_count() berhenti merespons hingga dipanggil di setiap host dalam slice Pod. Contoh berikut mengilustrasikan cara menjalankan penghitungan JAX pada slice Pod TPU.

Menyiapkan kode

Anda memerlukan gcloud versi >= 344.0.0 (untuk perintah scp). Gunakan gcloud --version untuk memeriksa versi gcloud, dan jalankan gcloud components upgrade, jika diperlukan.

Buat file bernama example.py dengan kode berikut:


import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Menyalin example.py ke semua VM pekerja TPU di slice Pod

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

Jika sebelumnya belum pernah menggunakan perintah scp, Anda mungkin melihat error seperti berikut:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Untuk mengatasi error, jalankan perintah ssh-add seperti yang ditampilkan dalam pesan error dan jalankan kembali perintah tersebut.

Menjalankan kode di slice Pod

Luncurkan program example.py di setiap VM:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

Output (dihasilkan dengan slice Pod v4-32):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

Pembersihan

Setelah selesai menggunakan VM TPU, ikuti langkah-langkah berikut untuk membersihkan resource.

  1. Putuskan koneksi dari instance Compute Engine, jika Anda belum melakukannya:

    (vm)$ exit

    Perintah Anda sekarang akan menjadi username@projectname, yang menunjukkan bahwa Anda berada di Cloud Shell.

  2. Hapus resource Cloud TPU dan Compute Engine Anda.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  3. Verifikasi bahwa resource telah dihapus dengan menjalankan gcloud compute tpus execution-groups list. Penghapusan mungkin memerlukan waktu beberapa menit. Output dari perintah berikut tidak boleh menyertakan resource apa pun yang dibuat dalam tutorial ini:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}