Introduzione a Trillium (v6e)

v6e viene utilizzato per fare riferimento a Trillium in questa documentazione, nell'API TPU e nei log. v6e rappresenta la sesta generazione di TPU di Google.

Con 256 chip per pod, la versione v6e condivide molte somiglianze con la versione v5e. Questo sistema è ottimizzato per essere il prodotto di maggior valore per l'addestramento, la messa a punto e la pubblicazione di transformer, conversione di testo in immagini e reti neurali convoluzionali (CNN).

Architettura di sistema v6e

Per informazioni sulla configurazione di Cloud TPU, consulta la documentazione della versione 6e.

Questo documento si concentra sulla procedura di configurazione per l'addestramento dei modelli utilizzando i framework JAX, PyTorch o TensorFlow. Con ogni framework, puoi eseguire il provisioning delle TPU utilizzando risorse in coda o Google Kubernetes Engine (GKE). La configurazione di GKE può essere eseguita utilizzando i comandi XPK o GKE.

Preparare un progetto Google Cloud

  1. Accedi al tuo Account Google. Se non l'hai ancora fatto, registrati per creare un nuovo account.
  2. Nella console Google Cloud, seleziona o create un progetto Cloud dalla pagina del selettore di progetti.
  3. Abilita la fatturazione per il tuo progetto Google Cloud. La fatturazione è obbligatoria per tutto l'utilizzo di Google Cloud.
  4. Installa i componenti gcloud alpha.
  5. Esegui il seguente comando per installare la versione più recente dei componenti gcloud.

    gcloud components update
    
  6. Abilita l'API TPU tramite il seguente comando gcloud in Cloud Shell. Puoi anche attivarlo dalla console Google Cloud.

    gcloud services enable tpu.googleapis.com
    
  7. Abilita le autorizzazioni con l'account di servizio TPU per l'API Compute Engine

    Gli account di servizio consentono al servizio Cloud TPU di accedere ad altri servizi Google Cloud. Un account di servizio gestito dall'utente è una best practice di Google Cloud. Segui queste guide per create e concedere i ruoli. Sono necessari i seguenti ruoli:

    • TPU Admin
    • Amministratore Storage
    • Logs Writer
    • Monitoring Metric Writer

    a. Configura le autorizzazioni XPK con il tuo account utente per GKE: XPK.

  8. Crea variabili di ambiente per l'ID progetto e la zona.

     gcloud auth login
     gcloud config set project ${PROJECT_ID}
     gcloud config set compute/zone ${ZONE}
    
  9. Crea un'identità di servizio per la VM TPU.

     gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    

Capacità sicura

Contatta il team di vendita/account di assistenza Cloud TPU per richiedere una quota TPU e per rispondere a eventuali domande sulla capacità.

Esegui il provisioning dell'ambiente Cloud TPU

È possibile eseguire il provisioning e la gestione delle TPU v6e con GKE, con GKE e XPK (uno strumento CLI wrapper su GKE) o come risorse in coda.

Prerequisiti

  • Verifica che il tuo progetto disponga di una quota TPUS_PER_TPU_FAMILY sufficiente, che specifica il numero massimo di chip a cui puoi accedere nel tuo progetto Google Cloud.
  • La versione 6e è stata testata con la seguente configurazione:
    • Python 3.10 o versioni successive
    • Versioni software Nightly:
      • a notte JAX 0.4.32.dev20240912
      • nightly LibTPU 0.1.dev20240912+nightly
    • Versioni software stabili:
      • JAX + JAX Lib della versione 0.4.35
  • Verifica che il tuo progetto disponga di una quota TPU sufficiente per:
    • Quota VM TPU
    • Quota di indirizzi IP
    • Quota Hyperdisk bilanciato
  • Autorizzazioni del progetto per gli utenti

Variabili di ambiente

In Cloud Shell, crea le seguenti variabili di ambiente:

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-central2-b
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the
# default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Descrizioni dei flag dei comandi

Variabile Descrizione
NODE_ID L'ID assegnato dall'utente della TPU che viene creato quando viene allocata la richiesta di risorsa in coda.
PROJECT_ID Nome del progetto Google Cloud. Utilizza un progetto esistente o creane uno nuovo su
ZONA Consulta il documento Regioni e zone TPU per le zone supportate.
ACCELERATOR_TYPE Consulta la sezione Tipi di acceleratore.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Si tratta dell'indirizzo email del tuo account di servizio che puoi trovare in Google Cloud Console -> IAM -> Account di servizio

Ad esempio: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

NUM_SLICES Il numero di sezioni da creare (obbligatorio solo per Multislice)
QUEUED_RESOURCE_ID L'ID testo assegnato dall'utente della richiesta di risorsa in coda.
VALID_DURATION La durata di validità della richiesta di risorse in coda.
NETWORK_NAME Il nome di una rete secondaria da utilizzare.
NETWORK_FW_NAME Il nome di un firewall di rete secondario da utilizzare.

Ottimizzazioni delle prestazioni di rete

Per ottenere le migliori prestazioni,utilizza una rete con 8896 MTU (unità massima di trasmissione).

Per impostazione predefinita, un Virtual Private Cloud (VPC) fornisce solo un MTU di 1460 byte,il che comporterà prestazioni di rete non ottimali. Puoi impostare l'MTU di una rete VPC su qualsiasi valore compreso tra 1300 e 8896 byte (inclusi). Le dimensioni MTU personalizzate comuni sono 1500 byte (Ethernet standard) o 8896 byte (il massimo possibile). Per saperne di più, consulta Dimensioni MTU valide per le reti VPC.

Per saperne di più sulla modifica dell'impostazione MTU per una rete esistente o predefinita, consulta Modificare l'impostazione MTU di una rete VPC.

L'esempio seguente crea una rete con 8896 MTU

export RESOURCE_NAME=RESOURCE_NAME
export NETWORK_NAME=${RESOURCE_NAME}
export NETWORK_FW_NAME=${RESOURCE_NAME}
export PROJECT=X
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} \

Utilizzo di più NIC (opzione per il multislice)

Le seguenti variabili di ambiente sono necessarie per una subnet secondaria quando utilizzi un ambiente Multislice.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2

Utilizza i seguenti comandi per creare il routing IP personalizzato per la rete e la subnet.

gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
   --bgp-routing-mode=regional --subnet-mode=custom --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 --region="${REGION}" \
   --project=$PROJECT

gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project="${PROJECT}"

gcloud compute routers create "${ROUTER_NAME}" \
  --project="${PROJECT}" \
  --network="${NETWORK_NAME_2}" \
  --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
  --router="${ROUTER_NAME}" \
  --region="${REGION}" \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project="${PROJECT}" \
  --enable-logging

Una volta creato uno slice multirete, puoi verificare che entrambe le NIC siano in uso eseguendo --command ifconfig nell'ambito del carico di lavoro XPK. Poi, controlla l'output stampato del carico di lavoro XPK nei log della console Cloud e verifica che sia eth0 che eth1 abbiano mtu=8896.

python3 xpk.py workload create \
   --cluster ${CLUSTER_NAME} \
   (--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}) \
   --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone $ZONE \
   --project $PROJECT_ID \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

Verifica che sia eth0 che eth1 abbiano mtu=8896. un modo per verificare che sia in esecuzione il multi-NIC è eseguire il comando --command "ifconfig" nell'ambito del carico di lavoro XPK. Poi controlla l'output stampato del carico di lavoro xpk nei log della console cloud e verifica che sia eth0 che eth1 abbiano mtu=8896.

Impostazioni TCP migliorate

Per le TPU create utilizzando l'interfaccia delle risorse in coda, puoi eseguire il seguente comando per migliorare le prestazioni della rete modificando le impostazioni TCP predefinite per rto_min e quickack.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
   --project "$PROJECT" --zone "${ZONE}" \
   --command='ip route show | while IFS= read -r route; do if ! echo $route | \
   grep -q linkdown; then sudo ip route change ${route/lock/} rto_min 5ms quickack 1; fi; done' \
   --worker=all

Provisioning con risorse in coda (API Cloud TPU)

È possibile eseguire il provisioning della capacità utilizzando il comando create queued-resource.

  1. Crea una richiesta di risorsa TPU in coda.

    Il flag --reserved è necessario solo per le risorse riservate, non per le risorse on demand.

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --valid-until-duration ${VALID_DURATION} \
      --service-account ${SERVICE_ACCOUNT} \
      [--reserved]

    Se la richiesta di risorse in coda viene creata correttamente, lo stato nel campo "response" sarà "WAITING_FOR_RESOURCES" o "FAILED". Se la richiesta di risorse in coda è nello stato "WAITING_FOR_RESOURCES", la risorsa in coda è stata inserita in coda e verrà eseguita il provisioning quando sarà disponibile una capacità TPU sufficiente. Se la richiesta di risorse in coda è in stato "FAILED", il motivo dell'errore sarà nell'output. La richiesta di risorse in coda scadrà se non viene eseguito il provisioning di un v6e entro la durata specificata e lo stato diventa "FAILED". Per ulteriori informazioni, consulta la documentazione pubblica relativa alle risorse in coda.

    Quando la richiesta di risorse in coda è in stato "ACTIVE", puoi collegarti alle VM TPU tramite SSH. Utilizza i comandi list o describe per eseguire query sullo stato della risorsa in coda.

    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project ${PROJECT_ID} --zone ${ZONE}
    

    Quando la risorsa in coda è nello stato "ATTIVO", l'output è simile al seguente:

      state:
       state: ACTIVE
    
  2. Gestisci le VM TPU. Per le opzioni di gestione delle VM TPU, consulta la sezione sulla gestione delle VM TPU.

  3. Connettiti alle VM TPU tramite SSH

    Puoi installare i binari su ogni VM TPU nella sezione TPU ed eseguire il codice. Consulta la sezione Tipi di VM per determinare quante VM avrà il tuo slice.

    Per installare i binari o eseguire codice, puoi utilizzare SSH per connetterti a una VM utilizzando il comando tpu-vm ssh.

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --node=all # add this flag if you are using Multislice
    

    Per utilizzare SSH per connetterti a una VM specifica, utilizza il flag --worker che segue un indice basato su 0:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
    

    Se le forme delle sezioni sono superiori a 8 chip, avrai più VM in una sezione. In questo caso, utilizza i parametri --worker=all e --command nel comando gcloud alpha compute tpus tpu-vm ssh per eseguire un comando su tutte le VM contemporaneamente. Ad esempio:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
      --zone  ${ZONE} --worker=all \
      --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Elimina una risorsa in coda

    Elimina una risorsa in coda alla fine della sessione o rimuovi le richieste di risorse in coda nello stato "FAILED". Per eliminare una risorsa in coda, elimina il segmento e poi la richiesta della risorsa in coda in due passaggi:

    gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \
     --zone=${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
     --project ${PROJECT_ID} --zone ${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
    

Utilizzo di GKE con v6e

Se utilizzi i comandi GKE con la versione 6e, puoi utilizzare i comandi Kubernetes o XPK per eseguire il provisioning delle TPU e addestrare o pubblicare i modelli. Consulta la sezione Pianificare le TPU in GKE per scoprire come utilizzare GKE con le TPU e la versione v6e.

Configurazione del framework

Questa sezione descrive la procedura di configurazione generale per l'addestramento dei modelli di ML utilizzando i framework JAX, PyTorch o TensorFlow. Puoi eseguire il provisioning delle TPU utilizzando le risorse in coda o GKE. La configurazione di GKE può essere eseguita utilizzando i comandi XPK o Kubernetes.

Configura JAX utilizzando le risorse in coda

Installa JAX su tutte le VM TPU del tuo o dei tuoi slice contemporaneamente utilizzando gcloud alpha compute tpus tpu-vm ssh. Per Multislice, aggiungi --node=all.


gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'

Puoi eseguire il seguente codice Python per controllare quanti core TPU sono disponibili nel tuo slice e per verificare che tutto sia installato correttamente (gli output mostrati qui sono stati prodotti con un slice v6e-16):

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

L'output è simile al seguente:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() mostra il numero totale di chip nel determinato slice. jax.local_device_count() indica il numero di chip accessibili da una singola VM in questo slice.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 &&
   && pip install -r requirements.txt  && pip install . '

Risoluzione dei problemi di configurazione di JAX

Un suggerimento generale è abilitare la registrazione dettagliata nel manifest del carico di lavoro GKE. Quindi, fornisci i log all'assistenza GKE.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Messaggi di errore

no endpoints available for service 'jobset-webhook-service'

Questo errore indica che il jobset non è stato installato correttamente. Controlla se i pod Kubernetes del deployment jobset-controller-manager sono in esecuzione. Per ulteriori informazioni, consulta la documentazione sulla risoluzione dei problemi relativi a JobSet.

TPU initialization failed: Failed to connect

Assicurati che la versione del nodo GKE sia 1.30.4-gke.1348000 o successiva (GKE 1.31 non è supportato).

Configurazione per PyTorch

Questa sezione descrive come iniziare a utilizzare PJRT nella versione 6e con PyTorch/XLA. Python 3.10 è la versione consigliata.

Configurare PyTorch utilizzando GKE con XPK

Puoi utilizzare il seguente contenitore Docker con XPK in cui sono già installate le dipendenze di PyTorch:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Per creare un carico di lavoro XPK, utilizza il seguente comando:

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -$NUM_SLICES \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES}  \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

L'utilizzo di --base-docker-image crea una nuova immagine Docker con la directory di lavoro corrente integrata nel nuovo Docker.

Configurare PyTorch utilizzando le risorse in coda

Segui questi passaggi per installare PyTorch utilizzando le risorse in coda e per eseguire un piccolo script sulla versione 6e.

Installa le dipendenze utilizzando SSH per accedere alle VM.

Per Multislice, aggiungi --node=all:

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Migliora le prestazioni dei modelli con allocazioni significative e frequenti

Per i modelli con allocazioni frequenti e di grandi dimensioni, abbiamo osservato che l'utilizzo di tcmalloc migliora notevolmente le prestazioni rispetto all'implementazione predefinita di malloc, pertanto il valore malloc predefinito utilizzato sulla VM TPU è tcmalloc. Tuttavia, a seconda del tuo caricamento di lavoro (ad esempio, con DLRM che ha allocazioni molto grandi per le sue tabelle di embedding), tcmalloc potrebbe causare un rallentamento, nel qual caso potresti provare a reimpostare la seguente variabile utilizzando malloc predefinito:

unset LD_PRELOAD

Utilizza uno script Python per eseguire un calcolo sulla VM v6e:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

Viene generato un output simile al seguente:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

Configurazione per TensorFlow

Per la versione Anteprima pubblica 6e, è supportata solo la versione del runtime tf-nightly.

Puoi reimpostare tpu-runtime con la versione di TensorFlow compatibile con v6e eseguendo i seguenti comandi:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Utilizza SSH per accedere a worker-0:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

Installa TensorFlow su worker-0:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Esporta la variabile di ambiente TPU_NAME:

export TPU_NAME=v6e-16

Puoi eseguire il seguente script Python per verificare quanti core TPU sono disponibili nel tuo slice e per verificare che tutto sia installato correttamente (gli output mostrati sono stati generati con un slice v6e-16):

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

L'output è simile al seguente:

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e con SkyPilot

Puoi utilizzare TPU v6e con SkyPilot. Segui la procedura riportata di seguito per aggiungere a SkyPilot le informazioni su prezzi e località relative a v6e.

  1. Aggiungi quanto segue alla fine di ~/.sky/catalogs/v5/gcp/vms.csv :

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. Specifica le seguenti risorse in un file YAML:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. Avvia un cluster con TPU v6e:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. Connettiti alla TPU v6e tramite SSH: ssh tpu_v6

Tutorial sull'inferenza

Le sezioni seguenti forniscono tutorial per l'erogazione di modelli MaxText e PyTorch utilizzando JetStream, nonché per l'erogazione di modelli MaxDiffusion su TPU v6e.

MaxText su JetStream

Questo tutorial mostra come utilizzare JetStream per pubblicare modelli MaxText (JAX) su TPU v6e. JetStream è un motore ottimizzato per la velocità effettiva e la memoria per l'inferenza dei modelli linguistici di grandi dimensioni (LLM) su dispositivi XLA (TPU). In questo tutorial eseguirai il benchmark di inferenza per il modello Llama2-7B.

Prima di iniziare

  1. Crea una TPU v6e con 4 chip:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Connettiti alla TPU tramite SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Esegui il tutorial

Per configurare JetStream e MaxText, convertire i checkpoint del modello ed eseguire il benchmark di inferenza, segui le istruzioni nel repository GitHub.

Esegui la pulizia

Elimina la TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

vLLM su PyTorch TPU

Di seguito è riportato un semplice tutorial che mostra come iniziare a utilizzare vLLM su VM TPU. Per un esempio di best practice di implementazione di vLLM su Trillium in produzione, pubblicheremo una guida utente di GKE nei prossimi giorni (continua a seguirci).

Prima di iniziare

  1. Crea una TPU v6e con 4 chip:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
       --node-id TPU_NAME \
       --project PROJECT_ID \
       --zone ZONE \
       --accelerator-type v6e-4 \
       --runtime-version v2-alpha-tpuv6e \
       --service-account SERVICE_ACCOUNT

    Descrizioni dei flag dei comandi

    Variabile Descrizione
    NODE_ID L'ID assegnato dall'utente della TPU che viene creato quando viene allocata la richiesta di risorsa in coda.
    PROJECT_ID Nome del progetto Google Cloud. Utilizza un progetto esistente o creane uno nuovo all'indirizzo
    ZONA Consulta il documento Regioni e zone TPU per le zone supportate.
    ACCELERATOR_TYPE Consulta la sezione Tipi di acceleratore.
    RUNTIME_VERSION v2-alpha-tpuv6e
    SERVICE_ACCOUNT Si tratta dell'indirizzo email del tuo account di servizio che puoi trovare in Google Cloud Console -> IAM -> Account di servizio

    Ad esempio: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

  2. Connettiti alla TPU tramite SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME
    

Create a Conda environment

  1. (Recommended) Create a new conda environment for vLLM:

    conda create -n vllm python=3.10 -y
    conda activate vllm

Configurare vLLM su TPU

  1. Clona il repository vLLM e vai alla directory vLLM:

    git clone https://github.com/vllm-project/vllm.git && cd vllm
    
  2. Ripulisci i pacchetti torch e torch-xla esistenti:

    pip uninstall torch torch-xla -y
    
  3. Installa PyTorch e PyTorch XLA:

    pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
    
  4. Installa JAX e Pallas:

    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    
    
  5. Installa altre dipendenze di compilazione:

    pip install -r requirements-tpu.txt
    VLLM_TARGET_DEVICE="tpu" python setup.py develop
    sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
    

Ottieni l'accesso al modello

Devi firmare il contratto di consenso per utilizzare la famiglia di modelli Llama3 nel repo HuggingFace

Genera un nuovo token Abbracciamento se non ne hai già uno:

  1. Fai clic su Il tuo profilo > Impostazioni > Token di accesso.
  2. Seleziona Nuovo token.
  3. Specifica un nome a tua scelta e un ruolo di almeno Read.
  4. Seleziona Genera un token.
  5. Copia il token generato negli appunti, impostalo come variabile di ambiente e autenticati con huggingface-cli:

    export TOKEN=''
    git config --global credential.helper store
    huggingface-cli login --token $TOKEN

Scaricare i dati di benchmarking

  1. Crea una directory /data e scarica il set di dati ShareGPT da Hugging Face.

    mkdir ~/data && cd ~/data
    wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    

Avvia il server vLLM

Il seguente comando scarica i pesi del modello da Hugging Face Model Hub nella directory /tmp della VM TPU, precompila una serie di forme di input e scrive la compilazione del modello in ~/.cache/vllm/xla_cache.

Per ulteriori dettagli, consulta la documentazione di vLLM.

   cd ~/vllm
   vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &

Esegui benchmark vLLM

Esegui lo script di benchmarking vLLM:

   python benchmarks/benchmark_serving.py \
       --backend vllm \
       --model "meta-llama/Meta-Llama-3.1-8B"  \
       --dataset-name sharegpt \
       --dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json  \
       --num-prompts 1000

Esegui la pulizia

Elimina la TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

PyTorch su JetStream

Questo tutorial mostra come utilizzare JetStream per eseguire il servizio di modelli PyTorch su TPU v6e. JetStream è un motore ottimizzato per la velocità effettiva e la memoria per l'inferenza dei modelli linguistici di grandi dimensioni (LLM) su dispositivi XLA (TPU). In questo tutorial eseguirai il benchmark di inferenza per il modello Llama2-7B.

Prima di iniziare

  1. Crea una TPU v6e con 4 chip:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Connettiti alla TPU tramite SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Esegui il tutorial

Per configurare JetStream-PyTorch, convertire i checkpoint del modello ed eseguire il benchmark di inferenza, segui le istruzioni nel repository GitHub.

Esegui la pulizia

Elimina la TPU:

   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --force \
      --async

Inferenza MaxDiffusion

Questo tutorial mostra come pubblicare modelli MaxDiffusion su TPU v6e. In questo tutorial, genererai immagini utilizzando il modello Stable Diffusion XL.

Prima di iniziare

  1. Crea una TPU v6e con 4 chip:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Connettiti alla TPU tramite SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Creare un ambiente Conda

  1. Crea una directory per Miniconda:

    mkdir -p ~/miniconda3
  2. Scarica lo script di installazione di Miniconda:

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. Installa Miniconda:

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. Rimuovi lo script di installazione di Miniconda:

    rm -rf ~/miniconda3/miniconda.sh
  5. Aggiungi Miniconda alla variabile PATH:

    export PATH="$HOME/miniconda3/bin:$PATH"
  6. Ricarica ~/.bashrc per applicare le modifiche alla variabile PATH:

    source ~/.bashrc
  7. Crea un nuovo ambiente Conda:

    conda create -n tpu python=3.10
  8. Attiva l'ambiente Conda:

    source activate tpu

Configurare MaxDiffusion

  1. Clona il repository MaxDiffusion e vai alla directory MaxDiffusion:

    https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. Passa al branch mlperf-4.1:

    git checkout mlperf4.1
  3. Installa MaxDiffusion:

    pip install -e .
  4. Installa le dipendenze:

    pip install -r requirements.txt
  5. Installa JAX:

    pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Genera immagini

  1. Imposta le variabili di ambiente per configurare il runtime TPU:

    LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
  2. Genera le immagini utilizzando il prompt e le configurazioni definite in src/maxdiffusion/configs/base_xl.yml:

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"

Esegui la pulizia

Elimina la TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

Tutorial di formazione

Le seguenti sezioni forniscono tutorial per l'addestramento di MaxText.

Modelli MaxDiffusion e PyTorch su TPU v6e.

MaxText e MaxDiffusion

Le sezioni seguenti illustrano il ciclo di vita dell'addestramento dei modelli MaxText e MaxDiffusion.

In generale, i passaggi di alto livello sono:

  1. Crea l'immagine di base del workload.
  2. Esegui il carico di lavoro utilizzando XPK.
    1. Crea il comando di addestramento per il carico di lavoro.
    2. Esegui il deployment del carico di lavoro.
  3. Monitora il carico di lavoro e visualizza le metriche.
  4. Elimina il workload XPK se non è necessario.
  5. Elimina il cluster XPK quando non è più necessario.

Crea l'immagine di base

Installa MaxText o MaxDiffusion e crea l'immagine Docker:

  1. Clona il repository che vuoi utilizzare e passa alla directory del repository:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Configura Docker in modo che utilizzi Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Crea l'immagine Docker utilizzando il seguente comando o JAX Stable Stack. Per ulteriori informazioni su JAX Stable Stack, consulta Creare un'immagine Docker con JAX Stable Stack.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. Se avvii il carico di lavoro da una macchina su cui l'immagine non è stata compilata localmente, caricala:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
Crea un'immagine Docker con JAX Stable Stack

Puoi creare le immagini Docker MaxText e MaxDiffusion utilizzando l'immagine di base JAX Stable Stack.

JAX Stable Stack fornisce un ambiente coerente per MaxText e MaxDiffusion combinando JAX con pacchetti di base come orbax, flax e optax, oltre a una libtpu.so ben qualificata che gestisce le utilità di programmazione TPU e altri strumenti essenziali. Queste librerie vengono testate per garantire la compatibilità, fornendo una base stabile per la creazione e l'esecuzione di MaxText e MaxDiffusion ed eliminando potenziali conflitti dovuti a versioni del pacchetto incompatibili.

JAX Stable Stack include una libreria libtpu.so completamente rilasciata e qualificata, la libreria di base che gestisce la compilazione, l'esecuzione e la configurazione della rete ICI dei programmi TPU. La release libtpu sostituisce la build notturna precedentemente utilizzata da JAX e garantisce la funzionalità coerente dei calcoli XLA su TPU con test di qualificazione a livello di PJRT negli IR HLO/StableHLO.

Per creare l'immagine Docker di MaxText e MaxDiffusion con JAX Stable Stack, quando esegui lo script docker_build_dependency_image.sh, imposta la variabile MODE su stable_stack e la variabile BASEIMAGE sull'immagine di base che vuoi utilizzare.

L'esempio seguente specifica us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 come immagine di base:

bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

Per un elenco delle immagini di base JAX Stable Stack disponibili, consulta Immagini JAX Stable Stack nel Registry degli elementi.

Esegui il carico di lavoro utilizzando XPK

  1. Imposta le seguenti variabili di ambiente se non utilizzi i valori predefiniti impostati da MaxText o MaxDiffusion:

    BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    PER_DEVICE_BATCH_SIZE=2
    NUM_STEPS=30
    MAX_TARGET_LENGTH=8192
  2. Crea lo script del modello da copiare come comando di addestramento nel passaggio successivo. Non eseguire ancora lo script del modello.

    MaxText

    MaxText è un LLM open source ad alte prestazioni e altamente scalabile scritto in Python e JAX puro e che ha come target le TPU e le GPU di Google Cloud per l'addestramento e l'inferenza.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=$BASE_OUTPUT_DIR \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}"  # attention='dot_product'"
    

    Gemma2

    Gemma è una famiglia di modelli linguistici di grandi dimensioni (LLM) con pesi aperti sviluppati da Google DeepMind, in base alla ricerca e alla tecnologia di Gemini.

    # Requires v6e-256
    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral è un modello di IA all'avanguardia sviluppato da Mistral AI, che utilizza un'architettura sparse mixture-of-experts (MoE).

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama è una famiglia di modelli linguistici di grandi dimensioni (LLM) con pesi aperti sviluppati da Meta.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash"
    

    MaxDiffusion

    MaxDiffusion è una raccolta di implementazioni di riferimento di vari modelli di diffusione latente scritti in puro Python e JAX che vengono eseguiti su dispositivi XLA, tra cui Cloud TPU e GPU. Stable Diffusion è un modello latente di conversione di testo in immagini che genera immagini fotorealistiche da qualsiasi input di testo.

    Per eseguire MaxDiffusion, devi installare un ramo specifico:

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    Script di addestramento:

        cd maxdiffusion && OUT_DIR=${your_own_bucket}
        python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml \
            run_name=v6e-sd2 \
            split_head_dim=True \
            attention=flash \
            train_new_unet=false \
            norm_num_groups=16 \
            output_dir=${BASE_OUTPUT_DIR} \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            [dcn_data_parallelism=2] \
            enable_profiler=True \
            skip_first_n_steps_for_profiler=95 \
            max_train_steps=${NUM_STEPS} ]
            write_metrics=True'
        
  3. Esegui il modello utilizzando lo script creato nel passaggio precedente. Devi specificare il flag --base-docker-image per utilizzare l'immagine di base MaxText o specificare il flag --docker-image e l'immagine che vuoi utilizzare.

    (Facoltativo) Puoi attivare la registrazione di log di debug includendo il flag --enable-debug-logs. Per ulteriori informazioni, consulta Eseguire il debug di JAX su MaxText.

    (Facoltativo) Puoi creare un esperimento Vertex AI per caricare i dati in Vertex AI TensorBoard includendo il flag --use-vertex-tensorboard. Per ulteriori informazioni, consulta Monitorare JAX su MaxText utilizzando Vertex AI.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES \
        --tpu-type=ACCELERATOR_TYPE \
        --num-slices=NUM_SLICES  \
        --on-demand \
        --zone $ZONE \
        --project $PROJECT_ID \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command YOUR_MODEL_SCRIPT

    Sostituisci le seguenti variabili:

    • CLUSTER_NAME: il nome del cluster XPK.
    • ACCELERATOR_TYPE: la versione e le dimensioni della TPU. Ad esempio, v6e-256.
    • NUM_SLICES: il numero di sezioni TPU.
    • YOUR_MODEL_SCRIPT: lo script del modello da eseguire come comando di addestramento.

    L'output include un link per monitorare il carico di lavoro, simile al seguente:

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    Apri il link e fai clic sulla scheda Log per monitorare il tuo carico di lavoro in tempo reale.

Eseguire il debug di JAX su MaxText

Utilizza i comandi XPK supplementari per diagnosticare il motivo per cui il cluster o il carico di lavoro non è in esecuzione:

Monitorare JAX su MaxText utilizzando Vertex AI

Visualizza i dati scalari e di profilo tramite TensorBoard gestito da Vertex AI.

  1. Aumenta le richieste di gestione delle risorse (CRUD) per la zona in uso da 600 a 5000. Questo potrebbe non essere un problema per i carichi di lavoro di piccole dimensioni che utilizzano meno di 16 VM.
  2. Installa le dipendenze, ad esempio cloud-accelerator-diagnostics per Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Crea il tuo cluster XPK utilizzando il flag --create-vertex-tensorboard, come documentato in Creare Vertex AI TensorBoard. Puoi anche eseguire questo comando sui cluster esistenti.

  4. Crea l'esperimento Vertex AI quando esegui il tuo carico di lavoro XPK utilizzando il flag --use-vertex-tensorboard e il flag facoltativo --experiment-name. Per l'elenco completo dei passaggi, consulta Creare un esperimento Vertex AI per caricare i dati su Vertex AI TensorBoard.

I log includono un link a un Vertex AI TensorBoard, simile al seguente:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Puoi trovare il link a Vertex AI TensorBoard anche nella console Google Cloud. Vai a Vertex AI Experiments nella console Google Cloud. Seleziona la regione appropriata dal menu a discesa.

La directory di TensorBoard viene scritta anche nel bucket Cloud Storage che hai specificato con ${BASE_OUTPUT_DIR}.

Eliminare i workload XPK

Utilizza il xpk workload delete comando per eliminare uno o più workload in base al prefisso o allo stato del job. Questo comando potrebbe essere utile se hai inviato carichi di lavoro XPK che non devono più essere eseguiti o se hai job bloccati nella coda.

Elimina il cluster XPK

Utilizza il comando xpk cluster delete per eliminare un cluster:

python3 xpk.py cluster delete --cluster CLUSTER_NAME --zone $ZONE --project $PROJECT_ID

Llama e PyTorch

Questo tutorial descrive come addestrare i modelli Llama utilizzando PyTorch/XLA su TPU v6e utilizzando il set di dati WikiText. Inoltre, gli utenti possono accedere alle descrizioni dei modelli TPU PyTorch come immagini Docker qui.

Installazione

Installa il pytorch-tpu/transformers fork di Hugging Face Transformers e le dipendenze in un ambiente virtuale:

git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate

Configura le configurazioni del modello

Il comando di addestramento nella sezione successiva, Creare lo script del modello, utilizza due file di configurazione JSON per definire i parametri del modello e la configurazione FSDP (Fully Sharded Data Parallel). Lo sharding FSDP viene utilizzato per adattare i pesi del modello a un batch di dimensioni maggiori durante l'addestramento. Quando si esegue l'addestramento con modelli più piccoli, potrebbe essere sufficiente utilizzare il parallelismo dei dati e replicare i pesi su ogni dispositivo. Per ulteriori dettagli su come suddividere i tensori tra i dispositivi in PyTorch/XLA, consulta la Guida dell'utente di PyTorch/XLA SPMD.

  1. Crea il file di configurazione dei parametri del modello. Di seguito è riportata la configurazione del parametro del modello per Llama3-8B. Per altri modelli, trova la configurazione su Hugging Face. Per example, consulta la configurazione Llama2-7B.

    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
  2. Crea il file di configurazione FSDP:

    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }

    Per ulteriori dettagli su FSDP, consulta FSDPv2.

  3. Carica i file di configurazione nelle VM TPU utilizzando il seguente comando:

        gcloud alpha compute tpus tpu-vm scp YOUR_CONFIG_FILE.json $TPU_NAME:. \
            --worker=all \
            --project=$PROJECT \
            --zone $ZONE

    Puoi anche creare i file di configurazione nella directory di lavoro attuale e utilizzare il flag --base-docker-image in XPK.

Creare lo script del modello

Crea lo script del modello, specificando il file di configurazione dei parametri del modello utilizzando il flag --config_name e il file di configurazione FSDP utilizzando il flag --fsdp_config. Eseguirai questo script sulla TPU nella sezione successiva, Esegui il modello. Non eseguire ancora lo script del modello.

    PJRT_DEVICE=TPU
    XLA_USE_SPMD=1
    ENABLE_PJRT_COMPATIBILITY=true
    # Optional variables for debugging:
    XLA_IR_DEBUG=1
    XLA_HLO_DEBUG=1
    PROFILE_EPOCH=0
    PROFILE_STEP=3
    PROFILE_DURATION_MS=100000
    PROFILE_LOGDIR=local VM path or gs://my-bucket/profile_path
    python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 8 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/config-8B.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp_config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20

Esegui il modello

Esegui il modello utilizzando lo script che hai creato nel passaggio precedente, Creare lo script del modello.

Se utilizzi una VM TPU a un solo host (ad esempio v6e-4), puoi eseguire il comando di addestramento direttamente sulla VM TPU. Se utilizzi una VM TPU multi-host, utilizza il seguente comando per eseguire lo script contemporaneamente su tutti gli host:

gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
    --zone $ZONE \
    --worker=all \
    --command=YOUR_COMMAND

Risoluzione dei problemi relativi a PyTorch/XLA

Se imposti le variabili facoltative per il debug nella sezione precedente, il profilo del modello verrà archiviato nella posizione specificata dalla variabile PROFILE_LOGDIR. Puoi estrarre il file xplane.pb memorizzato in questa posizione e utilizzare tensorboard per visualizzare i profili nel browser seguendo le istruzioni di TensorBoard. Se PyTorch/XLA non funziona come previsto, consulta la guida alla risoluzione dei problemi, che contiene suggerimenti per il debug, il profiling e l'ottimizzazione dei modelli.

Tutorial su DLRM DCN v2

Questo tutorial mostra come addestrare il modello DLRM DCN v2 su TPU v6e.

Se esegui l'operazione su più host, reimposta tpu-runtime con la versione di TensorFlow appropriata eseguendo il seguente comando. Se esegui l'operazione su un singolo host, non è necessario eseguire i due comandi riportati di seguito.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Accedi tramite SSH a worker-0

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

Imposta il nome della TPU

export TPU_NAME=${TPU_NAME}

Esegui DLRM v2

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

Esegui script.sh:

chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

I seguenti flag sono necessari per eseguire i carichi di lavoro dei consigli (DLRM DCN):

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

Risultati del benchmarking

La sezione seguente contiene i risultati del benchmarking per DLRM DCN v2 e MaxDiffusion su v6e.

DLRM DCN v2

Lo script di addestramento DLRM DCN v2 è stato eseguito su scale diverse. Consulta le portate nella tabella seguente.

v6e-64 v6e-128 v6e-256
Passaggi di addestramento 7000 7000 7000
Dimensione del batch globale 131072 262144 524288
Velocità effettiva (esempi/sec) 2975334 5111808 10066329

MaxDiffusion

Abbiamo eseguito lo script di addestramento per MaxDiffusion su una v6e-4, una v6e-16 e una 2xv6e-16. Consulta le portate nella tabella seguente.

v6e-4 v6e-16 Due v6e-16
Passaggi di addestramento 0,069 0,073 0,13
Dimensione del batch globale 8 32 64
Velocità effettiva (esempi/sec) 115,9 438,4 492,3

Raccolte

La versione 6e introduce una nuova funzionalità denominata raccolte a beneficio degli utenti che gestiscono i workload di pubblicazione. La funzionalità delle raccolte si applica solo alla versione 6e.

Le raccolte ti consentono di indicare a Google Cloud quali dei tuoi nodi TPU fanno parte di un carico di lavoro di pubblicazione. In questo modo, l'infrastruttura di Google Cloud sottostante può limitare e semplificare le interruzioni che potrebbero essere applicate ai carichi di lavoro di addestramento nel normale corso delle operazioni.

Utilizzare le raccolte dall'API Cloud TPU

Una raccolta con un solo host nell'API Cloud TPU è una risorsa in coda su cui è impostato un flag speciale (--workload-type = availability-optimized) per indicare all'infrastruttura di base che deve essere utilizzata per l'erogazione dei carichi di lavoro.

Il seguente comando esegue il provisioning di una raccolta con un solo host utilizzando l'API Cloud TPU:

gcloud alpha compute tpus queued-resources create COLLECTION_NAME \
   --project=project name \
   --zone=zone name \
   --accelerator-type=accelerator type \
   --node-count=number of nodes \
   --workload-type=availability-optimized

Monitoraggio e profilazione

Cloud TPU v6e supporta il monitoraggio e la profilazione utilizzando gli stessi metodi delle generazioni precedenti di Cloud TPU. Per ulteriori informazioni sul monitoraggio, consulta Monitorare le VM TPU.