Pengantar Trillium (v6e)
v6e digunakan untuk merujuk ke Trillium dalam dokumentasi, TPU API, dan log ini. v6e mewakili TPU generasi ke-6 Google.
Dengan 256 chip per Pod, v6e memiliki banyak kesamaan dengan v5e. Sistem ini dioptimalkan untuk menjadi produk dengan nilai tertinggi untuk pelatihan, penyesuaian, dan penayangan transformer, teks ke gambar, dan jaringan saraf konvolusi (CNN).
Arsitektur sistem v6e
Untuk mengetahui informasi tentang konfigurasi Cloud TPU, lihat dokumentasi v6e.
Dokumen ini berfokus pada proses penyiapan untuk pelatihan model menggunakan framework JAX, PyTorch, atau TensorFlow. Dengan setiap framework, Anda dapat menyediakan TPU menggunakan resource yang diantrekan atau Google Kubernetes Engine (GKE). Penyiapan GKE dapat dilakukan menggunakan perintah XPK atau GKE.
Menyiapkan project Google Cloud
- Login ke Akun Google Anda. Jika Anda belum melakukannya, daftar untuk membuat akun baru.
- Di konsol Google Cloud, pilih atau buat project Cloud dari halaman pemilih project.
- Aktifkan penagihan untuk project Google Cloud Anda. Penagihan diperlukan untuk semua penggunaan Google Cloud.
- Instal komponen gcloud alpha.
Jalankan perintah berikut untuk menginstal versi terbaru komponen
gcloud
.gcloud components update
Aktifkan TPU API melalui perintah
gcloud
berikut di Cloud Shell. Anda juga dapat mengaktifkannya dari Konsol Google Cloud.gcloud services enable tpu.googleapis.com
Mengaktifkan izin dengan akun layanan TPU untuk Compute Engine API
Akun layanan memungkinkan layanan Cloud TPU mengakses layanan Google Cloud lainnya. Akun layanan yang dikelola pengguna adalah praktik Google Cloud yang direkomendasikan. Ikuti panduan ini untuk membuat dan memberikan peran. Peran berikut diperlukan:
- TPU Admin
- Storage Admin
- Penulis Log
- Penulis Metrik Pemantauan
a. Siapkan izin XPK dengan akun pengguna Anda untuk GKE: XPK.
Buat variabel lingkungan untuk project ID dan zona.
gcloud auth login gcloud config set project ${PROJECT_ID} gcloud config set compute/zone ${ZONE}
Buat identitas layanan untuk VM TPU.
gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
Kapasitas aman
Hubungi tim penjualan/akun dukungan Cloud TPU Anda untuk meminta kuota TPU dan menjawab pertanyaan apa pun tentang kapasitas.
Menyediakan lingkungan Cloud TPU
TPU v6e dapat disediakan dan dikelola dengan GKE, dengan GKE dan XPK (alat CLI wrapper melalui GKE), atau sebagai resource dalam antrean.
Prasyarat
- Pastikan project Anda memiliki kuota
TPUS_PER_TPU_FAMILY
yang cukup, yang menentukan jumlah maksimum chip yang dapat Anda akses dalam project Google Cloud. - v6e telah diuji dengan konfigurasi berikut:
- python
3.10
atau yang lebih baru - Versi software harian:
- JAX harian
0.4.32.dev20240912
- LibTPU harian
0.1.dev20240912+nightly
- JAX harian
- Versi software stabil:
- JAX + JAX Lib v0.4.35
- python
- Pastikan project Anda memiliki cukup kuota TPU untuk:
- Kuota VM TPU
- Kuota Alamat IP
- Kuota hyperdisk-balance
- Izin project pengguna
- Jika Anda menggunakan GKE dengan XPK, lihat Izin Konsol Cloud di akun pengguna atau layanan untuk mengetahui izin yang diperlukan untuk menjalankan XPK.
Variabel lingkungan
Di Cloud Shell, buat variabel lingkungan berikut:
export NODE_ID=TPU_NODE_ID # TPU name export PROJECT_ID=PROJECT_ID export ACCELERATOR_TYPE=v6e-16 export ZONE=us-central2-b export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID export VALID_DURATION=VALID_DURATION # Additional environment variable needed for Multislice: export NUM_SLICES=NUM_SLICES # Use a custom network for better performance as well as to avoid having the # default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Deskripsi flag perintah
Variabel | Deskripsi |
NODE_ID | ID TPU yang ditetapkan pengguna yang dibuat saat permintaan resource yang diantrekan dialokasikan. |
PROJECT_ID | Nama Project Google Cloud. Gunakan project yang ada atau buat project baru di |
ZONA | Lihat dokumen Region dan zona TPU untuk zona yang didukung. |
ACCELERATOR_TYPE | Lihat Jenis Akselerator. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Ini adalah alamat email untuk akun layanan yang dapat Anda temukan di
Konsol Google Cloud -> IAM -> Akun Layanan
Misalnya: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com |
NUM_SLICES | Jumlah slice yang akan dibuat (hanya diperlukan untuk Multislice) |
QUEUED_RESOURCE_ID | ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan. |
VALID_DURATION | Durasi validitas permintaan resource yang diantrekan. |
NETWORK_NAME | Nama jaringan sekunder yang akan digunakan. |
NETWORK_FW_NAME | Nama firewall jaringan sekunder yang akan digunakan. |
Pengoptimalan performa jaringan
Untuk performa terbaik,gunakan jaringan dengan MTU 8.896 (unit transmisi maksimum).
Secara default, Virtual Private Cloud (VPC) hanya menyediakan MTU sebesar 1.460 byte yang akan memberikan performa jaringan yang kurang optimal. Anda dapat menetapkan MTU jaringan VPC ke nilai apa pun antara 1.300 byte dan 8.896 byte (inklusif). Ukuran MTU kustom umum adalah 1.500 byte (Ethernet standar) atau 8.896 byte (maksimum yang memungkinkan). Untuk mengetahui informasi selengkapnya, lihat Ukuran MTU jaringan VPC yang valid.
Untuk informasi selengkapnya tentang cara mengubah setelan MTU untuk jaringan yang ada atau default, lihat Mengubah setelan MTU jaringan VPC.
Contoh berikut membuat jaringan dengan 8.896 MTU
export RESOURCE_NAME=RESOURCE_NAME export NETWORK_NAME=${RESOURCE_NAME} export NETWORK_FW_NAME=${RESOURCE_NAME} export PROJECT=X gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} \
Menggunakan multi-NIC (Opsi untuk Multislice)
Variabel lingkungan berikut diperlukan untuk subnet sekunder saat Anda menggunakan lingkungan Multislice.
export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2
Gunakan perintah berikut untuk membuat perutean IP kustom untuk jaringan dan subnet.
gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
--bgp-routing-mode=regional --subnet-mode=custom --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
--network="${NETWORK_NAME_2}" \
--range=10.10.0.0/18 --region="${REGION}" \
--project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
--network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \
--project="${PROJECT}" \
--network="${NETWORK_NAME_2}" \
--region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
--router="${ROUTER_NAME}" \
--region="${REGION}" \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project="${PROJECT}" \
--enable-logging
Setelah slice multi-jaringan dibuat, Anda dapat memvalidasi bahwa
kedua NIC sedang digunakan dengan menjalankan --command ifconfig
sebagai bagian
dari workload XPK. Kemudian, lihat output
yang dicetak dari workload XPK tersebut di log konsol Cloud dan periksa
apakah eth0 dan eth1 memiliki mtu=8896.
python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ (--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}) \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command "ifconfig"
Pastikan eth0 dan eth1 memiliki mtu=8.896. cara untuk memverifikasi bahwa Anda memiliki multi-nic yang berjalan adalah dengan menjalankan perintah --command "ifconfig" sebagai bagian dari beban kerja XPK. Kemudian, lihat output yang dicetak dari workload xpk tersebut di log konsol cloud dan pastikan eth0 dan eth1 memiliki mtu=8896.
Peningkatan setelan TCP
Untuk TPU yang dibuat menggunakan antarmuka resource yang diantrekan, Anda dapat menjalankan perintah berikut untuk meningkatkan performa jaringan dengan mengubah setelan TCP default untuk rto_min
dan quickack
.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "$PROJECT" --zone "${ZONE}" \ --command='ip route show | while IFS= read -r route; do if ! echo $route | \ grep -q linkdown; then sudo ip route change ${route/lock/} rto_min 5ms quickack 1; fi; done' \ --worker=all
Penyediaan dengan resource yang diantrekan (Cloud TPU API)
Kapasitas dapat disediakan menggunakan perintah create
resource antrean.
Buat permintaan resource yang diantrekan TPU.
Flag
--reserved
hanya diperlukan untuk resource yang dicadangkan, bukan untuk resource on demand.gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ [--reserved]
Jika permintaan resource dalam antrean berhasil dibuat, status dalam kolom "response" akan menjadi "WAITING_FOR_RESOURCES" atau "FAILED". Jika permintaan resource dalam antrean berada dalam status "WAITING_FOR_RESOURCES", resource dalam antrean telah dimasukkan ke dalam antrean dan akan disediakan jika kapasitas TPU memadai. Jika permintaan resource yang diantrekan berada dalam status "FAILED", alasan kegagalan akan ada dalam output. Masa berlaku permintaan resource dalam antrean akan berakhir jika v6e tidak disediakan dalam durasi yang ditentukan, dan statusnya menjadi "GAGAL". Lihat dokumentasi publik Resource dalam Antrean untuk mengetahui informasi selengkapnya.
Saat permintaan resource yang diantrekan berada dalam status "AKTIF", Anda dapat terhubung ke VM TPU menggunakan SSH. Gunakan perintah
list
ataudescribe
untuk membuat kueri status resource yang diantrekan.gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE}
Jika resource yang diantrekan berada dalam status "ACTIVE", output-nya akan mirip dengan berikut ini:
state: state: ACTIVE
Mengelola VM TPU Anda. Untuk mengetahui opsi guna mengelola VM TPU, lihat mengelola VM TPU.
Menghubungkan ke VM TPU menggunakan SSH
Anda dapat menginstal biner di setiap VM TPU dalam slice TPU dan menjalankan kode. Lihat bagian Jenis VM untuk menentukan jumlah VM yang akan dimiliki slice Anda.
Untuk menginstal biner atau menjalankan kode, Anda dapat menggunakan SSH untuk terhubung ke VM menggunakan perintah
tpu-vm ssh
.gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --node=all # add this flag if you are using Multislice
Untuk menggunakan SSH guna terhubung ke VM tertentu, gunakan flag
--worker
yang mengikuti indeks berbasis 0:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
Jika memiliki bentuk slice lebih dari 8 chip, Anda akan memiliki beberapa VM dalam satu slice. Dalam hal ini, gunakan parameter
--worker=all
dan--command
dalam perintahgcloud alpha compute tpus tpu-vm ssh
untuk menjalankan perintah di semua VM secara bersamaan. Contoh:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \ --zone ${ZONE} --worker=all \ --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Menghapus resource yang diantrekan
Menghapus resource yang diantrekan di akhir sesi atau menghapus permintaan resource yang diantrekan yang berada dalam status "GAGAL". Untuk menghapus resource dalam antrean, hapus slice, lalu permintaan resource dalam antrean dalam 2 langkah:
gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \ --zone=${ZONE} --quiet gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet
gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
Menggunakan GKE dengan v6e
Jika menggunakan perintah GKE dengan v6e, Anda dapat menggunakan perintah Kubernetes atau XPK untuk menyediakan TPU dan melatih atau menayangkan model. Lihat Merencanakan TPU di GKE untuk mempelajari cara menggunakan GKE dengan TPU dan v6e.
Penyiapan framework
Bagian ini menjelaskan proses penyiapan umum untuk pelatihan model ML menggunakan framework JAX, PyTorch, atau TensorFlow. Anda dapat menyediakan TPU menggunakan resource dalam antrean atau GKE. Penyiapan GKE dapat dilakukan menggunakan perintah XPK atau Kubernetes.
Menyiapkan JAX menggunakan resource yang diantrekan
Instal JAX di semua VM TPU dalam satu atau beberapa slice secara bersamaan menggunakan
gcloud alpha compute tpus tpu-vm ssh
. Untuk Multislice, tambahkan --node=all
.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'
Anda dapat menjalankan kode Python berikut untuk memeriksa jumlah core TPU yang tersedia di slice dan menguji apakah semuanya diinstal dengan benar (output yang ditampilkan di sini dihasilkan dengan slice v6e-16):
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
Outputnya mirip dengan hal berikut ini:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count() menunjukkan jumlah total chip dalam slice yang diberikan. jax.local_device_count() menunjukkan jumlah chip yang dapat diakses oleh satu VM dalam slice ini.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 &&
&& pip install -r requirements.txt && pip install . '
Memecahkan masalah penyiapan JAX
Tips umum adalah mengaktifkan logging panjang di manifes beban kerja GKE Anda. Kemudian, berikan log ke dukungan GKE.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Pesan error
no endpoints available for service 'jobset-webhook-service'
Error ini berarti set tugas tidak diinstal dengan benar. Periksa untuk melihat apakah Pod Kubernetes deployment jobset-controller-manager sedang berjalan. Untuk mengetahui informasi selengkapnya, lihat dokumentasi pemecahan masalah JobSet untuk mengetahui detailnya.
TPU initialization failed: Failed to connect
Pastikan versi node GKE Anda adalah 1.30.4-gke.1348000 atau yang lebih baru (GKE 1.31 tidak didukung).
Penyiapan untuk PyTorch
Bagian ini menjelaskan cara mulai menggunakan PJRT di v6e dengan PyTorch/XLA. Python 3.10 adalah versi Python yang direkomendasikan.
Menyiapkan PyTorch menggunakan GKE dengan XPK
Anda dapat menggunakan penampung Docker berikut dengan XPK yang telah menginstal dependensi PyTorch:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Untuk membuat beban kerja XPK, gunakan perintah berikut:
python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -$NUM_SLICES \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES} \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
Menggunakan --base-docker-image
akan membuat image Docker baru dengan direktori kerja
saat ini yang di-build ke dalam Docker baru.
Menyiapkan PyTorch menggunakan resource yang diantrekan
Ikuti langkah-langkah berikut untuk menginstal PyTorch menggunakan resource yang diantrekan dan menjalankan skrip kecil di v6e.
Instal dependensi menggunakan SSH untuk mengakses VM.
Untuk Multislice, tambahkan --node=all
:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo apt install -y libopenblas-base pip3 \
install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
--index-url https://download.pytorch.org/whl/nightly/cpu
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Meningkatkan performa model dengan alokasi yang cukup besar dan sering
Untuk model yang memiliki alokasi yang cukup besar dan sering, kami telah mengamati bahwa penggunaan
tcmalloc
meningkatkan performa secara signifikan dibandingkan dengan
implementasi malloc
default,
sehingga malloc
default yang digunakan di VM TPU adalah tcmalloc
. Namun, bergantung pada
workload Anda (misalnya, dengan DLRM yang memiliki alokasi yang sangat besar untuk
tabel penyematan), tcmalloc
dapat menyebabkan pelambatan. Dalam hal ini, Anda dapat mencoba
menghapus setelan variabel berikut menggunakan malloc
default:
unset LD_PRELOAD
Gunakan skrip Python untuk melakukan penghitungan di VM v6e:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--project ${PROJECT_ID} \
--zone ${ZONE} --worker all --command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'
Tindakan ini akan menghasilkan output yang mirip dengan berikut ini:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
Penyiapan untuk TensorFlow
Untuk Pratinjau Publik v6e, hanya versi runtime tf-nightly yang didukung.
Anda dapat mereset tpu-runtime
dengan versi TensorFlow
yang kompatibel dengan v6e dengan menjalankan perintah berikut:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Gunakan SSH untuk mengakses worker-0:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE}
Instal TensorFlow di pekerja-0:
sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Ekspor variabel lingkungan TPU_NAME
:
export TPU_NAME=v6e-16
Anda dapat menjalankan skrip Python berikut untuk memeriksa jumlah core TPU yang tersedia di slice Anda dan untuk menguji apakah semuanya diinstal dengan benar (output yang ditampilkan dibuat dengan slice v6e-16):
import TensorFlow as tf
print("TensorFlow version " + tf.__version__)
@tf.function
def add_fn(x,y):
z = x + y
return z
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x,y))
print(z)
Outputnya mirip dengan hal berikut ini:
PerReplica:{
0: tf.Tensor(2.0, shape=(), dtype=float32),
1: tf.Tensor(2.0, shape=(), dtype=float32),
2: tf.Tensor(2.0, shape=(), dtype=float32),
3: tf.Tensor(2.0, shape=(), dtype=float32),
4: tf.Tensor(2.0, shape=(), dtype=float32),
5: tf.Tensor(2.0, shape=(), dtype=float32),
6: tf.Tensor(2.0, shape=(), dtype=float32),
7: tf.Tensor(2.0, shape=(), dtype=float32)
}
v6e dengan SkyPilot
Anda dapat menggunakan TPU v6e dengan SkyPilot. Gunakan langkah-langkah berikut untuk menambahkan informasi lokasi/harga terkait v6e ke SkyPilot.
Tambahkan kode berikut ke akhir
~/.sky/catalogs/v5/gcp/vms.csv
:,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
Tentukan resource berikut dalam file YAML:
# tpu_v6.yaml resources: accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use accelerator_args: runtime_version: v2-alpha-tpuv6e # Official suggested runtime
Luncurkan cluster dengan TPU v6e:
sky launch tpu_v6.yaml -c tpu_v6
Hubungkan ke TPU v6e menggunakan SSH:
ssh tpu_v6
Tutorial inferensi
Bagian berikut memberikan tutorial untuk menayangkan model MaxText dan PyTorch menggunakan JetStream, serta menayangkan model MaxDiffusion di TPU v6e.
MaxText di JetStream
Tutorial ini menunjukkan cara menggunakan JetStream untuk menayangkan model MaxText (JAX) di TPU v6e. JetStream adalah mesin yang dioptimalkan untuk throughput dan memori untuk inferensi model bahasa besar (LLM) di perangkat XLA (TPU). Dalam tutorial ini, Anda akan menjalankan benchmark inferensi untuk model Llama2-7B.
Sebelum memulai
Buat TPU v6e dengan 4 chip:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Hubungkan ke TPU menggunakan SSH:
gcloud compute tpus tpu-vm ssh TPU_NAME
Menjalankan tutorial
Untuk menyiapkan JetStream dan MaxText, mengonversi checkpoint model, dan menjalankan benchmark inferensi, ikuti petunjuk di repositori GitHub.
Pembersihan
Hapus TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async
vLLM di PyTorch TPU
Berikut adalah tutorial sederhana yang menunjukkan cara memulai vLLM di VM TPU. Untuk contoh praktik terbaik kami dalam men-deploy vLLM di Trillium dalam produksi, kami akan memublikasikan panduan pengguna GKE dalam beberapa hari ke depan (nantikan kabar terbarunya).
Sebelum memulai
Buat TPU v6e dengan 4 chip:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Deskripsi flag perintah
Variabel Deskripsi NODE_ID ID TPU yang ditetapkan pengguna yang dibuat saat permintaan resource dalam antrean dialokasikan. PROJECT_ID Nama Project Google Cloud. Gunakan project yang ada atau buat project baru di ZONA Lihat dokumen Region dan zona TPU untuk zona yang didukung. ACCELERATOR_TYPE Lihat Jenis Akselerator. RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Ini adalah alamat email untuk akun layanan yang dapat Anda temukan di Konsol Google Cloud -> IAM -> Akun Layanan Misalnya: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com
Hubungkan ke TPU menggunakan SSH:
gcloud compute tpus tpu-vm ssh TPU_NAME
Create a Conda environment
(Recommended) Create a new conda environment for vLLM:
conda create -n vllm python=3.10 -y conda activate vllm
Menyiapkan vLLM di TPU
Clone repositori vLLM dan buka direktori vLLM:
git clone https://github.com/vllm-project/vllm.git && cd vllm
Bersihkan paket torch dan torch-xla yang ada:
pip uninstall torch torch-xla -y
Instal PyTorch dan PyTorch XLA:
pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
Instal JAX dan Pallas:
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Instal dependensi build lainnya:
pip install -r requirements-tpu.txt VLLM_TARGET_DEVICE="tpu" python setup.py develop sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
Mendapatkan akses ke model
Anda harus menandatangani perjanjian izin untuk menggunakan keluarga model Llama3 di repo HuggingFace
Buat token Hugging Face baru jika Anda belum memilikinya:
- Klik Profil Anda > Setelan > Token Akses.
- Pilih New Token.
- Tentukan Nama pilihan Anda dan Peran minimal
Read
. - Pilih Buat token.
Salin token yang dihasilkan ke papan klip Anda, tetapkan sebagai variabel lingkungan, lalu autentikasi dengan huggingface-cli:
export TOKEN='' git config --global credential.helper store huggingface-cli login --token $TOKEN
Mendownload Data Benchmark
Buat direktori /data dan download set data ShareGPT dari Hugging Face.
mkdir ~/data && cd ~/data wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
Meluncurkan server vLLM
Perintah berikut mendownload bobot model dari
Hugging Face Model Hub
ke direktori /tmp VM TPU, mengompilasi berbagai bentuk input terlebih dahulu, dan menulis kompilasi model ke ~/.cache/vllm/xla_cache
.
Untuk mengetahui detail selengkapnya, lihat dokumen vLLM.
cd ~/vllm
vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &
Menjalankan Benchmark vLLM
Jalankan skrip benchmark vLLM:
python benchmarks/benchmark_serving.py \
--backend vllm \
--model "meta-llama/Meta-Llama-3.1-8B" \
--dataset-name sharegpt \
--dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json \
--num-prompts 1000
Pembersihan
Hapus TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async
PyTorch di JetStream
Tutorial ini menunjukkan cara menggunakan JetStream untuk menayangkan model PyTorch di TPU v6e. JetStream adalah mesin yang dioptimalkan untuk throughput dan memori untuk inferensi model bahasa besar (LLM) di perangkat XLA (TPU). Dalam tutorial ini, Anda akan menjalankan benchmark inferensi untuk model Llama2-7B.
Sebelum memulai
Buat TPU v6e dengan 4 chip:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Hubungkan ke TPU menggunakan SSH:
gcloud compute tpus tpu-vm ssh TPU_NAME
Menjalankan tutorial
Untuk menyiapkan JetStream-PyTorch, mengonversi checkpoint model, dan menjalankan benchmark inferensi, ikuti petunjuk di repositori GitHub.
Pembersihan
Hapus TPU:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--force \
--async
Inferensi MaxDiffusion
Tutorial ini menunjukkan cara menayangkan model MaxDiffusion di TPU v6e. Dalam tutorial ini, Anda akan membuat gambar menggunakan model Stable Diffusion XL.
Sebelum memulai
Buat TPU v6e dengan 4 chip:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Hubungkan ke TPU menggunakan SSH:
gcloud compute tpus tpu-vm ssh TPU_NAME
Membuat lingkungan Conda
Buat direktori untuk Miniconda:
mkdir -p ~/miniconda3
Download skrip penginstal Miniconda:
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
Instal Miniconda:
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
Hapus skrip penginstal Miniconda:
rm -rf ~/miniconda3/miniconda.sh
Tambahkan Miniconda ke variabel
PATH
Anda:export PATH="$HOME/miniconda3/bin:$PATH"
Muat ulang
~/.bashrc
untuk menerapkan perubahan pada variabelPATH
:source ~/.bashrc
Buat lingkungan Conda baru:
conda create -n tpu python=3.10
Aktifkan lingkungan Conda:
source activate tpu
Menyiapkan MaxDiffusion
Clone repositori MaxDiffusion dan buka direktori MaxDiffusion:
https://github.com/google/maxdiffusion.git && cd maxdiffusion
Beralih ke cabang
mlperf-4.1
:git checkout mlperf4.1
Instal MaxDiffusion:
pip install -e .
Instal dependensi:
pip install -r requirements.txt
Menginstal JAX:
pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Buat gambar
Tetapkan variabel lingkungan untuk mengonfigurasi runtime TPU:
LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
Buat gambar menggunakan perintah dan konfigurasi yang ditentukan di
src/maxdiffusion/configs/base_xl.yml
:python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"
Pembersihan
Hapus TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async
Tutorial pelatihan
Bagian berikut memberikan tutorial untuk melatih MaxText,
Model MaxDiffusion dan PyTorch di TPU v6e.
MaxText dan MaxDiffusion
Bagian berikut membahas siklus proses pelatihan model MaxText dan MaxDiffusion.
Secara umum, langkah-langkah tingkat tingginya adalah:
- Build image dasar workload.
- Jalankan workload Anda menggunakan XPK.
- Buat perintah pelatihan untuk beban kerja.
- Deploy workload.
- Ikuti beban kerja dan lihat metrik.
- Hapus workload XPK jika tidak diperlukan.
- Hapus cluster XPK jika tidak diperlukan lagi.
Mem-build image dasar
Instal MaxText atau MaxDiffusion dan build image Docker:
Clone repositori yang ingin Anda gunakan dan ubah ke direktori untuk repositori:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
Konfigurasikan Docker agar menggunakan Google Cloud CLI:
gcloud auth configure-docker
Build image Docker menggunakan perintah berikut atau menggunakan JAX Stable Stack. Untuk mengetahui informasi selengkapnya tentang JAX Stable Stack, lihat Mem-build image Docker dengan JAX Stable Stack.
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
Jika Anda meluncurkan beban kerja dari mesin yang tidak memiliki image yang di-build secara lokal, upload image:
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
Mem-build image Docker dengan JAX Stable Stack
Anda dapat mem-build image Docker MaxText dan MaxDiffusion menggunakan image dasar JAX Stable Stack.
JAX Stable Stack menyediakan lingkungan yang konsisten untuk MaxText dan MaxDiffusion
dengan memaketkan JAX dengan paket inti seperti orbax
, flax
, dan optax
, beserta
libtpu.so yang memenuhi syarat yang mendorong utilitas program
TPU dan alat penting lainnya. Library ini diuji untuk memastikan kompatibilitas, menyediakan fondasi yang stabil untuk mem-build dan menjalankan MaxText dan MaxDiffusion, serta menghilangkan potensi konflik karena versi paket yang tidak kompatibel.
JAX Stable Stack menyertakan libtpu.so yang dirilis sepenuhnya dan memenuhi syarat, library inti yang mendorong kompilasi, eksekusi, dan konfigurasi jaringan ICI program TPU. Rilis libtpu menggantikan build harian yang sebelumnya digunakan oleh JAX, dan memastikan fungsi komputasi XLA yang konsisten di TPU dengan pengujian kualifikasi tingkat PJRT di IR HLO/StableHLO.
Untuk mem-build image Docker MaxText dan MaxDiffusion dengan JAX Stable Stack, saat
Anda menjalankan skrip docker_build_dependency_image.sh
, tetapkan variabel MODE
ke stable_stack
dan tetapkan variabel BASEIMAGE
ke image dasar
yang ingin Anda gunakan.
Contoh berikut menentukan
us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
sebagai
gambar dasar:
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
Untuk mengetahui daftar image dasar JAX Stable Stack yang tersedia, lihat Image JAX Stable Stack di Artifact Registry.
Menjalankan workload menggunakan XPK
Tetapkan variabel lingkungan berikut jika Anda tidak menggunakan nilai default yang ditetapkan oleh MaxText atau MaxDiffusion:
BASE_OUTPUT_DIR=gs://YOUR_BUCKET PER_DEVICE_BATCH_SIZE=2 NUM_STEPS=30 MAX_TARGET_LENGTH=8192
Build skrip model Anda untuk disalin sebagai perintah pelatihan di langkah berikutnya. Jangan jalankan skrip model terlebih dahulu.
MaxText
MaxText adalah LLM open source berperforma tinggi dan sangat skalabel yang ditulis dalam Python dan JAX murni serta menargetkan TPU dan GPU Google Cloud untuk pelatihan dan inferensi.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \ base_output_directory=$BASE_OUTPUT_DIR \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS}" # attention='dot_product'"
Gemma2
Gemma adalah serangkaian model bahasa besar (LLM) dengan bobot terbuka yang dikembangkan oleh Google DeepMind, berdasarkan riset dan teknologi Gemini.
# Requires v6e-256 python3 MaxText/train.py MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral adalah model AI tercanggih yang dikembangkan oleh Mistral AI, yang memanfaatkan arsitektur campuran ahli (MoE) yang jarang.
python3 MaxText/train.py MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama adalah serangkaian model bahasa besar (LLM) dengan bobot terbuka yang dikembangkan oleh Meta.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=llama3-8b \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ tokenizer_path=assets/tokenizer_llama3.tiktoken \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ attention=flash"
MaxDiffusion
MaxDiffusion adalah kumpulan implementasi referensi dari berbagai model difusi laten yang ditulis dalam Python murni dan JAX yang berjalan di perangkat XLA, termasuk Cloud TPU dan GPU. Stable Diffusion adalah model teks ke gambar laten yang menghasilkan gambar fotorealistik dari input teks apa pun.
Anda perlu menginstal cabang tertentu untuk menjalankan MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && pip install -r requirements.txt && pip install .
Skrip pelatihan:
cd maxdiffusion && OUT_DIR=${your_own_bucket} python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml \ run_name=v6e-sd2 \ split_head_dim=True \ attention=flash \ train_new_unet=false \ norm_num_groups=16 \ output_dir=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ [dcn_data_parallelism=2] \ enable_profiler=True \ skip_first_n_steps_for_profiler=95 \ max_train_steps=${NUM_STEPS} ] write_metrics=True'
Jalankan model menggunakan skrip yang Anda buat pada langkah sebelumnya. Anda harus menentukan flag
--base-docker-image
untuk menggunakan image dasar MaxText atau menentukan flag--docker-image
dan image yang ingin Anda gunakan.Opsional: Anda dapat mengaktifkan logging debug dengan menyertakan flag
--enable-debug-logs
. Untuk informasi selengkapnya, lihat Men-debug JAX di MaxText.Opsional: Anda dapat membuat Eksperimen Vertex AI untuk mengupload data ke Vertex AI TensorBoard dengan menyertakan tanda
--use-vertex-tensorboard
. Untuk mengetahui informasi selengkapnya, lihat Memantau JAX di MaxText menggunakan Vertex AI.python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \ --workload ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES \ --tpu-type=ACCELERATOR_TYPE \ --num-slices=NUM_SLICES \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command YOUR_MODEL_SCRIPT
Ganti variabel berikut:
- CLUSTER_NAME: Nama cluster XPK Anda.
- ACCELERATOR_TYPE: Versi dan ukuran TPU Anda. Contoh,
v6e-256
. - NUM_SLICES: Jumlah slice TPU.
- YOUR_MODEL_SCRIPT: Skrip model yang akan dieksekusi sebagai perintah pelatihan.
Output-nya menyertakan link untuk mengikuti beban kerja Anda, mirip dengan berikut ini:
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
Buka link dan klik tab Logs untuk melacak beban kerja Anda secara real time.
Men-debug JAX di MaxText
Gunakan perintah XPK tambahan untuk mendiagnosis alasan cluster atau beban kerja tidak berjalan:
- Daftar workload XPK
- XPK inspector
- Aktifkan logging panjang di log beban kerja menggunakan flag
--enable-debug-logs
saat Anda membuat beban kerja XPK.
Memantau JAX di MaxText menggunakan Vertex AI
Lihat data skalar dan profil melalui TensorBoard terkelola Vertex AI.
- Tingkatkan permintaan pengelolaan resource (CRUD) untuk zona yang Anda gunakan dari 600 menjadi 5.000. Hal ini mungkin bukan masalah untuk beban kerja kecil yang menggunakan kurang dari 16 VM.
Instal dependensi seperti
cloud-accelerator-diagnostics
untuk Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Buat cluster XPK menggunakan flag
--create-vertex-tensorboard
, seperti yang didokumentasikan dalam Membuat Vertex AI TensorBoard. Anda juga dapat menjalankan perintah ini di cluster yang ada.Buat eksperimen Vertex AI saat menjalankan beban kerja XPK menggunakan tanda
--use-vertex-tensorboard
dan tanda--experiment-name
opsional. Untuk mengetahui daftar lengkap langkah-langkahnya, lihat Membuat Vertex AI Experiment untuk mengupload data ke Vertex AI TensorBoard.
Log menyertakan link ke Vertex AI TensorBoard, mirip dengan hal berikut:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
Anda juga dapat menemukan link Vertex AI TensorBoard di Konsol Google Cloud. Buka Vertex AI Experiments di konsol Google Cloud. Pilih wilayah yang sesuai dari drop-down.
Direktori TensorBoard juga ditulis ke bucket Cloud Storage yang Anda tentukan dengan ${BASE_OUTPUT_DIR}
.
Menghapus workload XPK
Gunakan perintah xpk workload delete
untuk menghapus satu atau beberapa beban kerja berdasarkan awalan tugas atau status tugas. Perintah ini
mungkin berguna jika Anda mengirim beban kerja XPK yang tidak perlu lagi dijalankan, atau jika
Anda memiliki tugas yang macet di antrean.
Menghapus cluster XPK
Gunakan perintah
xpk cluster delete
untuk menghapus cluster:
python3 xpk.py cluster delete --cluster CLUSTER_NAME --zone $ZONE --project $PROJECT_ID
Llama dan PyTorch
Tutorial ini menjelaskan cara melatih model Llama menggunakan PyTorch/XLA di TPU v6e menggunakan set data WikiText. Selain itu, pengguna dapat mengakses rekrip model TPU PyTorch sebagai image docker di sini.
Penginstalan
Instal fork
pytorch-tpu/transformers
dari
Hugging Face Transformers dan dependensi di lingkungan virtual:
git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate
Menyiapkan konfigurasi model
Perintah pelatihan di bagian berikutnya, Mem-build skrip model menggunakan dua file konfigurasi JSON untuk menentukan parameter model dan konfigurasi FSDP (Fully Sharded Data Parallel). Sharding FSDP digunakan untuk bobot model agar sesuai dengan ukuran batch yang lebih besar saat pelatihan. Saat berlatih dengan model yang lebih kecil, mungkin cukup menggunakan paralelisme data dan mereplikasi bobot di setiap perangkat. Lihat Panduan Pengguna SPMD PyTorch/XLA untuk mengetahui detail selengkapnya tentang cara melakukan shard tensor di seluruh perangkat di PyTorch/XLA.
Buat file konfigurasi parameter model. Berikut adalah konfigurasi parameter model untuk Llama3-8B. Untuk model lainnya, temukan konfigurasi di Hugging Face. Misalnya, lihat konfigurasi Llama2-7B.
{ "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 }
Buat file konfigurasi FSDP:
{ "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true }
Lihat FSDPv2 untuk mengetahui detail selengkapnya tentang FSDP.
Upload file konfigurasi ke VM TPU menggunakan perintah berikut:
gcloud alpha compute tpus tpu-vm scp YOUR_CONFIG_FILE.json $TPU_NAME:. \ --worker=all \ --project=$PROJECT \ --zone $ZONE
Anda juga dapat membuat file konfigurasi di direktori kerja saat ini dan menggunakan tanda
--base-docker-image
di XPK.
Membuat skrip model
Build skrip model Anda, yang menentukan file konfigurasi parameter model menggunakan
tanda --config_name
dan file konfigurasi FSDP menggunakan tanda --fsdp_config
.
Anda akan menjalankan skrip ini di TPU di bagian berikutnya, Menjalankan
model. Jangan jalankan skrip model terlebih dahulu.
PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 PROFILE_EPOCH=0 PROFILE_STEP=3 PROFILE_DURATION_MS=100000 PROFILE_LOGDIR=local VM path or gs://my-bucket/profile_path python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 8 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/config-8B.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp_config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20
Menjalankan model
Jalankan model menggunakan skrip yang Anda buat di langkah sebelumnya, Mem-build skrip model.
Jika menggunakan VM TPU host tunggal (seperti v6e-4
), Anda dapat menjalankan perintah pelatihan langsung di VM TPU. Jika Anda menggunakan VM TPU multi-host, gunakan
perintah berikut untuk menjalankan skrip secara bersamaan di semua host:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=YOUR_COMMAND
Memecahkan masalah PyTorch/XLA
Jika Anda menetapkan variabel opsional untuk proses debug di bagian sebelumnya,
profil untuk model akan disimpan di lokasi yang ditentukan oleh
variabel PROFILE_LOGDIR
. Anda dapat mengekstrak file xplane.pb
yang disimpan
di lokasi ini dan menggunakan tensorboard
untuk melihat profil di browser
Anda menggunakan petunjuk
TensorBoard
Jika PyTorch/XLA tidak berperforma seperti yang diharapkan, lihat panduan pemecahan masalah,
yang berisi saran untuk men-debug, membuat profil, dan mengoptimalkan model Anda.
Tutorial DLRM DCN v2
Tutorial ini menunjukkan cara melatih model DLRM DCN v2 di TPU v6e.
Jika Anda menjalankan di multi-host, reset tpu-runtime
dengan versi
TensorFlow yang sesuai dengan menjalankan
perintah berikut. Jika menjalankan di satu host, Anda tidak perlu
menjalankan dua perintah berikut.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID}
--zone ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
SSH ke worker-0
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}
Menetapkan nama TPU
export TPU_NAME=${TPU_NAME}
Menjalankan DLRM v2
pip install cloud-tpu-client
pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'
TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'mixed_bfloat16'
task:
use_synthetic_data: false
use_tf_record_reader: true
train_data:
input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
global_batch_size: 16384
use_cached_data: true
validation_data:
input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
global_batch_size: 16384
use_cached_data: true
model:
num_dense_features: 13
bottom_mlp: [512, 256, 128]
embedding_dim: 128
interaction: 'multi_layer_dcn'
dcn_num_layers: 3
dcn_low_rank_dim: 512
size_threshold: 8000
top_mlp: [1024, 1024, 512, 256, 1]
use_multi_hot: true
concat_dense: false
dcn_use_bias: true
vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
max_ids_per_chip_per_sample: 128
max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
use_partial_tpu_embedding: false
size_threshold: 0
initialize_tables_on_host: true
trainer:
train_steps: 10000
validation_interval: 1000
validation_steps: 660
summary_interval: 1000
steps_per_loop: 1000
checkpoint_interval: 0
optimizer_config:
embedding_optimizer: 'Adagrad'
dense_optimizer: 'Adagrad'
lr_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.025
warmup_steps: 0
dense_sgd_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.00025
warmup_steps: 8000
train_tf_function: true
train_tf_while_loop: true
eval_tf_while_loop: true
use_orbit: true
pipeline_sparse_and_dense_execution: true"
Jalankan script.sh
:
chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Flag berikut diperlukan untuk menjalankan workload rekomendasi (DLRM DCN):
ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
Hasil benchmark
Bagian berikut berisi hasil benchmark untuk DLRM DCN v2 dan MaxDiffusion di v6e.
DLRM DCN v2
Skrip pelatihan DLRM DCN v2 dijalankan pada skala yang berbeda. Lihat throughput dalam tabel berikut.
v6e-64 | v6e-128 | v6e-256 | |
Langkah-langkah pelatihan | 7.000 | 7.000 | 7.000 |
Ukuran batch global | 131072 | 262144 | 524288 |
Throughput (contoh/dtk) | 2975334 | 5111808 | 10066329 |
MaxDiffusion
Kami menjalankan skrip pelatihan untuk MaxDiffusion di v6e-4, v6e-16, dan 2xv6e-16. Lihat throughput dalam tabel berikut.
v6e-4 | v6e-16 | Dua v6e-16 | |
Langkah-langkah pelatihan | 0,069 | 0,073 | 0,13 |
Ukuran batch global | 8 | 32 | 64 |
Throughput (contoh/dtk) | 115,9 | 438,4 | 492,3 |
Koleksi
v6e memperkenalkan fitur baru bernama koleksi untuk kepentingan pengguna yang menjalankan beban kerja penayangan. Fitur koleksi hanya berlaku untuk v6e.
Koleksi memungkinkan Anda menunjukkan kepada Google Cloud node TPU mana yang membentuk bagian dari workload penayangan. Hal ini memungkinkan infrastruktur Google Cloud yang mendasarinya untuk membatasi dan menyederhanakan gangguan yang dapat diterapkan ke workload pelatihan dalam kursus operasi normal.
Menggunakan koleksi dari Cloud TPU API
Koleksi satu host di Cloud TPU API adalah resource yang diantrekan dengan
tanda khusus (--workload-type = availability-optimized
) yang ditetapkan untuk
menunjukkan ke infrastruktur yang mendasarinya bahwa resource tersebut dimaksudkan untuk digunakan untuk
menayangkan workload.
Perintah berikut menyediakan koleksi host tunggal menggunakan Cloud TPU API:
gcloud alpha compute tpus queued-resources create COLLECTION_NAME \ --project=project name \ --zone=zone name \ --accelerator-type=accelerator type \ --node-count=number of nodes \ --workload-type=availability-optimized
Memantau dan membuat profil
Cloud TPU v6e mendukung pemantauan dan pembuatan profil menggunakan metode yang sama dengan Cloud TPU generasi sebelumnya. Untuk mengetahui informasi selengkapnya tentang pemantauan, lihat Memantau VM TPU.