Exécuter du code JAX sur des tranches de pods Cloud TPU
Une fois que votre code JAX s'exécute sur une seule carte TPU, vous pouvez augmenter la capacité en l'exécutant sur une tranche de pod TPU. Les tranches de pod TPU sont des cartes de TPU interconnectées sur des connexions réseau haut débit dédiées. Ce document est une introduction à l'exécution de code JAX sur des tranches de pods TPU. Pour des informations plus détaillées, consultez la section Utiliser JAX dans des environnements multihôtes et multiprocessus.
Si vous souhaitez utiliser une authentification NFS installée pour le stockage des données, vous devez définir OS Login pour tous VM TPU dans la tranche de pod. Pour en savoir plus, consultez Utiliser un NFS pour le stockage de donnéesCréer une tranche de pod TPU
Avant d'exécuter les commandes de ce document, assurez-vous d'avoir suivi les instructions de la section Configurer un compte et un projet Cloud TPU. Exécutez les commandes suivantes sur votre ordinateur local.
Créez une tranche de pod TPU à l'aide de la commande gcloud
. Par exemple, pour créer
Pour la tranche de pod v4-32, exécutez la commande suivante:
$ gcloud compute tpus tpu-vm create tpu-name \
--zone=us-central2-b \
--accelerator-type=v4-32 \
--version=tpu-ubuntu2204-base
Installer JAX sur la tranche de pod
Après avoir créé la tranche de pod TPU, vous devez installer JAX sur tous les hôtes de la tranche de pod TPU. Vous pouvez installer JAX sur tous les hôtes à l'aide d'une seule commande avec l'option --worker=all
:
gcloud compute tpus tpu-vm ssh tpu-name \ --zone=us-central2-b --worker=all --command="pip install \ --upgrade 'jax[tpu]>0.3.0' \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
Exécuter du code JAX sur la tranche de pod
Pour exécuter du code JAX sur une tranche de pod TPU, vous devez exécuter le code sur chaque hôte de la tranche de pod TPU. L'appel jax.device_count()
ne répond plus tant qu'il n'est pas
sur chaque hôte de la tranche de pod. L'exemple suivant montre comment
exécuter un calcul JAX simple sur une tranche de pod TPU.
Préparer le code
Vous devez disposer de gcloud
344.0.0 ou version ultérieure (pour le paramètre
scp
).
Utilisez gcloud --version
pour vérifier votre version de gcloud
.
Si nécessaire, exécutez gcloud components upgrade
.
Créez un fichier nommé example.py
avec le code suivant:
# The following code snippet will be run on all TPU hosts
import jax
# The total number of TPU cores in the Pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
Copier example.py
sur toutes les VM de nœuds de calcul TPU dans la tranche de pod
$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
--worker=all \
--zone=us-central2-b
Si vous n'avez jamais utilisé la commande scp
, il est possible qu'une erreur
s'affiche comme suit:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try again.
Pour résoudre l'erreur, exécutez la commande ssh-add
comme indiqué dans la
message d'erreur et réexécutez la commande.
Exécuter le code sur la tranche du pod
Lancez le programme example.py
sur chaque VM :
$ gcloud compute tpus tpu-vm ssh tpu-name \
--zone=us-central2-b \
--worker=all \
--command="python3 example.py"
Résultat (produit avec une tranche de pod v4-32):
global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]
Effectuer un nettoyage
Lorsque vous avez terminé, vous pouvez supprimer vos ressources de VM TPU à l'aide de la commande gcloud
:
$ gcloud compute tpus tpu-vm delete tpu-name \
--zone=us-central2-b