Addestra un modello utilizzando TPU v6e

Questo documento ti guida nell'addestramento dei modelli su Cloud TPU v6e (chiamata anche Trillium), trattando la configurazione dell'ambiente, l'ottimizzazione delle prestazioni ed esempi pratici di addestramento utilizzando JAX e PyTorch/XLA.

La TPU v6e, chiamata anche Trillium, è la TPU di sesta generazione di Google. Su tutte le piattaforme tecniche, come l'API e i log, e in tutto questo documento, Trillium verrà indicato come v6e. Con 256 chip per pod, l'architettura di TPU v6e condivide molte somiglianze con v5e. La TPU v6e è ottimizzata per l'addestramento, l'ottimizzazione e la pubblicazione di trasformatori, modelli di sintesi di immagini dal testo e reti neurali convoluzionali (CNN). Per saperne di più sull'architettura e sulle configurazioni del sistema TPU v6e, consulta TPU v6e.

Per informazioni sull'esecuzione dell'inferenza su Cloud TPU v6e, consulta i seguenti tutorial:

Prima di iniziare

Prima di iniziare, devi:

  • Crea un Google Cloud account e un progetto con la fatturazione abilitata
  • Installa i componenti alpha di Google Cloud CLI
  • Abilita l'API Cloud TPU
  • Crea un service agent Cloud TPU
  • Crea un account di servizio Cloud TPU e concedi le autorizzazioni

Per saperne di più, vedi Configurare l'ambiente Cloud TPU.

Verificare la quota e le autorizzazioni

Verifica che il tuo progetto disponga delle seguenti quote:

Se utilizzi GKE con XPK, devi disporre di autorizzazioni aggiuntive nella console Google Cloud . Per maggiori informazioni, consulta Autorizzazioni necessarie nella consoleGoogle Cloud .

Provisioning delle TPU

Puoi eseguire il provisioning e gestire TPU v6e utilizzando i seguenti metodi:

  • GKE: puoi utilizzare GKE per eseguire il provisioning e gestire le TPU come pool di acceleratori per i carichi di lavoro di machine learning containerizzati. Per maggiori informazioni, consulta Informazioni sulle TPU in GKE.
  • GKE e XPK: XPK è uno strumento a riga di comando che semplifica la creazione di cluster e l'esecuzione dei carichi di lavoro su GKE. È progettato per consentire ai professionisti del machine learning di eseguire il provisioning delle TPU ed eseguire job di addestramento senza richiedere una profonda conoscenza di Kubernetes. Per maggiori informazioni, consulta il repository GitHub di XPK.
  • Risorse Cloud TPU in coda: le risorse in coda ti consentono di richiedere capacità TPU di cui viene eseguito il provisioning quando diventa disponibile. È ideale per job batch e carichi di lavoro a tolleranza di errore che possono attendere in una coda. Puoi specificare un intervallo di tempo per la tua richiesta. Per saperne di più, consulta Gestire le risorse in coda.

Provisioning di Cloud TPU v6e con GKE e XPK

Se utilizzi i comandi GKE con v6e, puoi utilizzare i comandi Kubernetes o XPK per eseguire il provisioning delle Cloud TPU e addestrare o pubblicare modelli. Consulta Pianificare le Cloud TPU in GKE per scoprire come pianificare le configurazioni Cloud TPU nei cluster GKE. Le seguenti sezioni forniscono comandi per creare un cluster XPK con supporto di una singola NIC e di più NIC.

Crea un cluster XPK con supporto per una singola NIC

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --create-vertex-tensorboard

Descrizioni dei flag dei comandi

Variabile Descrizione
CLUSTER_NAME Il nome assegnato dall'utente per il cluster XPK.
PROJECT_ID Google Cloud nome del progetto. Utilizza un progetto esistente o creane uno nuovo. Per saperne di più, vedi Configurare il progetto Google Cloud .
ZONE Consulta il documento Regioni e zone di Cloud TPU per le zone supportate.
TPU_TYPE Vedi Tipi di acceleratore.
NUM_SLICES Il numero di sezioni che vuoi creare
CLUSTER_ARGUMENTS La rete e la subnet da utilizzare.

Ad esempio: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES Il numero di sezioni da creare.
NETWORK_NAME Il nome di una rete secondaria da utilizzare.
NETWORK_FW_NAME Il nome di un firewall di rete secondario da utilizzare.

Crea un cluster XPK con supporto multi-NIC

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_1} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
   --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
   --create-vertex-tensorboard

Descrizioni dei flag dei comandi

Variabile Descrizione
CLUSTER_NAME Il nome assegnato dall'utente per il cluster XPK.
PROJECT_ID Google Cloud nome del progetto. Utilizza un progetto esistente o creane uno nuovo. Per saperne di più, vedi Configurare il progetto Google Cloud .
ZONE Consulta il documento Regioni e zone di Cloud TPU per le zone supportate.
TPU_TYPE Vedi Tipi di acceleratore.
NUM_SLICES Il numero di sezioni che vuoi creare
CLUSTER_ARGUMENTS La rete e la subnet da utilizzare.

Ad esempio: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS La rete di nodi aggiuntiva da utilizzare.

Ad esempio: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES Il numero di sezioni da creare (necessario solo per Multislice).
NETWORK_NAME Il nome di una rete secondaria da utilizzare.
NETWORK_FW_NAME Il nome di un firewall di rete secondario da utilizzare.

Configurare JAX o PyTorch

Le seguenti risorse mostrano come configurare JAX o PyTorch sulla tua Cloud TPU, a seconda del metodo di provisioning e gestione che utilizzi:

Per configurare ed eseguire XPK con MaxText, consulta Esecuzione di MaxText su larga scala con XPK .

Ottimizzare le prestazioni di rete

Questa sezione descrive come ottimizzare le prestazioni della rete configurando l'unità massima di trasmissione (MTU), utilizzando più NIC per gli ambienti Multislice e migliorando le impostazioni TCP.

Configura MTU

Per ottenere le migliori prestazioni di rete, utilizza una rete con MTU (unità massima di trasmissione) di 8896.

Per impostazione predefinita, un Virtual Private Cloud (VPC) fornisce solo un'MTU di 1460 byte, che offre prestazioni di rete non ottimali. Puoi impostare l'MTU di una rete VPC su qualsiasi valore compreso tra 1300 byte e 8896 byte (inclusi). Le dimensioni MTU personalizzate comuni sono 1500 byte (Ethernet standard) o 8896 byte (il massimo possibile). Per maggiori informazioni, consulta Dimensioni MTU valide per le reti VPC.

Per saperne di più su come modificare l'impostazione MTU per una rete esistente o predefinita, consulta Modificare l'impostazione MTU di una rete VPC.

L'esempio seguente crea una rete con MTU 8896 e una regola firewall corrispondente che consente il traffico TCP, ICMP e UDP all'interno della rete.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
    --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
    --allow tcp,icmp,udp --project=${PROJECT_ID}

Sostituisci your-resource-name con un nome di base per la rete e il firewall.

Utilizzare l'opzione multi-NIC per Multislice

Se utilizzi un ambiente Multislice, imposta le seguenti variabili di ambiente, necessarie per una subnet secondaria:

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Utilizza i seguenti comandi per creare il routing IP personalizzato per la rete e la subnet.

  1. Crea la rete secondaria.

    gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
    --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
    
  2. Crea una subnet per la rete secondaria.

    gcloud compute networks subnets create ${SUBNET_NAME_2} \
    --network=${NETWORK_NAME_2} \
    --range=10.10.0.0/18 --region=${REGION} \
    --project=${PROJECT_ID}
    
  3. Crea una regola firewall per consentire il traffico all'interno della nuova subnet.

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
    --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
    
  4. Crea un router Cloud per la rete secondaria.

    gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_2} \
    --region=${REGION}
    
  5. Crea una configurazione NAT per il router Cloud.

    gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
    

Dopo aver creato una sezione di rete multipla, puoi verificare che entrambe le schede di interfaccia di rete (NIC) vengano utilizzate configurando un cluster XPK e aggiungendo il flag --command ifconfig al comando di creazione del workload XPK.

  1. Utilizza il seguente comando workload create per visualizzare l'output del comando ifconfig nei log della console Google Cloud e verifica che sia eth0 sia eth1 abbiano MTU impostato su 8896.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --command "ifconfig"

    Se vuoi attivare i log di debug o utilizzare Vertex AI TensorBoard, aggiungi i seguenti argomenti facoltativi al comando:

    --enable-debug-logs \
    --use-vertex-tensorboard
  2. Verifica che sia eth0 sia eth1 abbiano MTU impostato su 8896 controllando l'output del carico di lavoro XPK nei log della console Google Cloud .

Migliorare le impostazioni TCP

Se hai eseguito il provisioning delle Cloud TPU utilizzando le risorse in coda, puoi eseguire il comando seguente per migliorare le prestazioni di rete aumentando i limiti del buffer di ricezione TCP.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --node=all \
    --worker=all \
    --command='
    sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

Ottimizzare il rendimento dell'allocazione della memoria

La libreria tcmalloc viene utilizzata per impostazione predefinita sulle VM Cloud TPU per migliorare le prestazioni dei modelli con allocazioni di memoria frequenti e di grandi dimensioni. Questa impostazione viene configurata tramite la variabile di ambiente LD_PRELOAD.

Tuttavia, per alcuni carichi di lavoro (ad esempio DLRM con allocazioni di tabelle di incorporamento molto grandi), tcmalloc può causare un rallentamento. In questi casi, puoi ripristinare la funzione malloc standard annullando l'impostazione della variabile LD_PRELOAD nella sessione shell prima di eseguire lo script di addestramento:

unset LD_PRELOAD

Utilizzare SkyPilot

Puoi utilizzare Cloud TPU v6e con SkyPilot. SkyPilot è un framework open source che semplifica il processo di esecuzione, gestione e scalabilità dei carichi di lavoro AI. Puoi aggiungere a SkyPilot informazioni su prezzi e località correlate a v6e. Per ulteriori informazioni, consulta l'esempio di TPU v6e di SkyPilot.

Esempi di addestramento

Le sezioni seguenti forniscono esempi per l'addestramento di modelli MaxText, MaxDiffusion e PyTorch su Cloud TPU v6e.

Questi esempi sono stati testati con le seguenti versioni software:

  • Python 3.10 o versioni successive
  • Versioni software Nightly:
    • Nightly JAX 0.4.32.dev20240912
    • LibTPU notturna 0.1.dev20240912+nightly
  • Versioni software stabili:
    • JAX + JAX Lib v0.4.37

Addestra MaxText e MaxDiffusion su Cloud TPU v6e

Le sezioni seguenti illustrano il ciclo di vita dell'addestramento dei modelli MaxText e MaxDiffusion.

In generale, i passaggi di alto livello sono:

  1. Crea l'immagine di base del workload.
  2. Esegui il workload utilizzando XPK.
    1. Crea il comando di addestramento per il workload.
    2. Esegui il deployment del carico di lavoro.
  3. Segui il carico di lavoro e visualizza le metriche.
  4. Elimina il workload XPK se non è necessario.
  5. Elimina il cluster XPK quando non è più necessario.

Crea l'immagine di base

Installa MaxText o MaxDiffusion e crea l'immagine Docker:

  1. Clona il repository che vuoi utilizzare e passa alla directory del repository:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. Configura Docker in modo che utilizzi Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Crea l'immagine Docker utilizzando il seguente comando o un'immagine JAX AI. Per saperne di più sulle immagini AI JAX, consulta Immagini AI JAX.

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. Imposta l'ID progetto nella configurazione dell'interfaccia alla gcloud CLI attiva:

    gcloud config set project ${PROJECT_ID}
    
  5. Se avvii il workload da una macchina in cui l'immagine non è creata localmente, carica l'immagine.

    1. Imposta la variabile di ambiente CLOUD_IMAGE_NAME:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. Carica l'immagine:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

Esegui il carico di lavoro utilizzando XPK

  1. Imposta le seguenti variabili di ambiente se non utilizzi i valori predefiniti impostati da MaxText o MaxDiffusion:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Crea lo script del modello. Questo script verrà copiato come comando di addestramento in un passaggio successivo.

    Non eseguire ancora lo script del modello.

    MaxText

    MaxText è un LLM open source ad alte prestazioni e altamente scalabile scritto in Python e JAX puri e destinato a TPU e GPU per l'addestramento e l'inferenza. Google Cloud

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma è una famiglia di LLM con pesi aperti sviluppati da Google DeepMind, basati sulla ricerca e sulla tecnologia Gemini.

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral è un modello di AI all'avanguardia sviluppato da Mistral AI, che utilizza un'architettura di tipo mixture-of-experts (MoE) sparsa.

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama è una famiglia di LLM con pesi aperti sviluppati da Meta.

    Per un esempio di come eseguire Llama3 su PyTorch, consulta i modelli torch_xla nel repository torchprime.

    MaxDiffusion

    MaxDiffusion è una raccolta di implementazioni di riferimento di vari modelli di diffusione latente scritti in Python e JAX puri che vengono eseguiti su dispositivi XLA, tra cui Cloud TPU e GPU. Stable Diffusion è un modello latente da testo a immagine che genera immagini fotorealistiche da qualsiasi input di testo.

    Per eseguire MaxDiffusion, devi installare un ramo Git specifico come mostrato nel seguente script di addestramento.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    && pip install -r requirements.txt && pip install .
    && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR}
    && python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR} \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash \
        run_name=sdxl-ddp-v6e
    
  3. Esporta le seguenti variabili:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Descrizioni delle variabili di ambiente

    Variabile Descrizione
    CLUSTER_NAME Il nome del tuo cluster XPK.
    ACCELERATOR_TYPE Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta Versioni di TPU.
    NUM_SLICES Il numero di sezioni TPU.
    YOUR_MODEL_SCRIPT Lo script del modello da eseguire come comando di addestramento.
  4. Esegui il modello utilizzando lo script creato nel passaggio precedente. Devi specificare il flag --base-docker-image per utilizzare l'immagine di base MaxText oppure specificare il flag --docker-image e l'immagine che vuoi utilizzare.

    Puoi scegliere di aggiungere i seguenti flag facoltativi:

    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --command="${YOUR_MODEL_SCRIPT}"

    L'output include un link per seguire il carico di lavoro. Apri il link e fai clic sulla scheda Log per monitorare il carico di lavoro in tempo reale.

Eseguire il debug di JAX su MaxText

Utilizza i comandi XPK supplementari per diagnosticare il motivo per cui il cluster o il carico di lavoro non è in esecuzione:

Monitorare JAX su MaxText utilizzando Vertex AI

Per utilizzare TensorBoard, il tuo account utente Google Cloud deve disporre del ruolo aiplatform.user. Esegui questo comando per concedere il ruolo:

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

Visualizza i dati scalari e di profilo tramite TensorBoard gestito da Vertex AI.

  1. Aumenta le richieste di gestione delle risorse (CRUD) per la zona che utilizzi da 600 a 5000. Questo potrebbe non essere un problema per i piccoli carichi di lavoro che utilizzano meno di 16 VM.

  2. Installa le dipendenze come cloud-accelerator-diagnostics per Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Crea il cluster XPK utilizzando il flag --create-vertex-tensorboard, come documentato in Crea Vertex AI TensorBoard. Puoi eseguire questo comando anche sui cluster esistenti.

  4. Crea l'esperimento Vertex AI quando esegui il carico di lavoro XPK utilizzando il flag --use-vertex-tensorboard e il flag facoltativo --experiment-name. Per l'elenco completo dei passaggi, vedi Crea Vertex AI Experiment per caricare i dati in Vertex AI TensorBoard.

I log includono un link a Vertex AI TensorBoard, simile al seguente:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Puoi trovare il link a Vertex AI TensorBoard anche nella console Google Cloud . Vai a Vertex AI Experiments nella console Google Cloud . Seleziona la regione appropriata dal menu a discesa.

La directory TensorBoard viene scritta anche nel bucket Cloud Storage che hai specificato con ${BASE_OUTPUT_DIR}.

Elimina il workload XPK

Utilizza il comando xpk workload delete per eliminare uno o più workload in base al prefisso del job o allo stato del job. Questo comando potrebbe essere utile se hai inviato carichi di lavoro XPK che non devono più essere eseguiti o se hai job bloccati nella coda.

Elimina il cluster XPK

Utilizza il comando xpk cluster delete per eliminare il cluster:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
    --zone=${ZONE} --project=${PROJECT_ID}

Risultati del benchmarking di MaxDiffusion

Abbiamo eseguito lo script di addestramento per MaxDiffusion su v6e-4, v6e-16 e due v6e-16. La tabella seguente mostra i throughput misurati.

v6e-4 v6e-16 Due v6e-16
Passaggi di addestramento 0,069 0,073 0,13
Dimensione batch globale 8 32 64
Throughput (esempi/sec) 115,9 438,4 492,3

Addestramento di modelli Llama utilizzando PyTorch/XLA su Cloud TPU v6e

Questa sezione descrive come addestrare i modelli Llama utilizzando PyTorch/XLA su Cloud TPU v6e utilizzando il set di dati WikiText.

Accedere a Hugging Face e al modello Llama 3

Per questo esempio è necessario un token di accesso utente Hugging Face. Per informazioni sulla creazione di token di accesso utente, consulta la documentazione di Hugging Face sui token di accesso utente.

Devi anche disporre dell'autorizzazione per accedere al modello Llama-3-8B su Hugging Face. Per ottenere l'accesso, vai al modello Meta-Llama-3-8B su HuggingFace e richiedi l'accesso.

Crea una VM Cloud TPU

Crea una Cloud TPU v6e con 8 chip per questo esempio.

  1. Imposta le variabili di ambiente:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    Descrizioni delle variabili di ambiente

    Variabile Descrizione
    PROJECT_ID L'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo.
    TPU_NAME Il nome della TPU.
    ZONE La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU.
    ACCELERATOR_TYPE Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU.
    RUNTIME_VERSION La versione software di Cloud TPU.

  2. Crea una VM Cloud TPU:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

Installazione

Installa il fork di Hugging Face Transformers e le dipendenze.pytorch-tpu/transformers Questo esempio è stato testato con le seguenti versioni delle dipendenze:

  • torch: compatibile con la versione 2.5.0
  • torch_xla[tpu]: compatibile con la versione 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

Configurare i file di configurazione del modello

Il comando di addestramento nella sezione successiva, Esegui il modello, utilizza due file di configurazione JSON per definire i parametri del modello e la configurazione Fully Sharded Data Parallel (FSDP). Lo sharding FSDP ti consente di utilizzare una dimensione batch maggiore durante l'addestramento eseguendo lo sharding dei pesi del modello su più TPU. Quando esegui l'addestramento con modelli più piccoli, potrebbe essere sufficiente utilizzare il parallelismo dei dati e replicare i pesi su ogni dispositivo. Per saperne di più su come partizionare i tensori tra i dispositivi in PyTorch/XLA, consulta la guida per l'utente di PyTorch/XLA SPMD.

  1. Crea il file di configurazione dei parametri del modello. Di seguito è riportata la configurazione dei parametri del modello per Llama-3-8B. Per altri modelli, trova il file di configurazione su Hugging Face. Ad esempio, consulta la configurazione Llama-2-7B.

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. Crea il file di configurazione FSDP:

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Per ulteriori informazioni su FSDP, consulta Parallelismo dei dati completamente partizionati utilizzando SPMD .

  3. Carica i file di configurazione nelle tue VM Cloud TPU utilizzando il seguente comando:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

Esegui il modello

Utilizzando i file di configurazione creati nella sezione precedente, esegui lo script run_clm.py per addestrare il modello Llama-3-8B sul set di dati WikiText. L'esecuzione dello script di addestramento richiede circa 10 minuti su una Cloud TPU v6e-8.

  1. Accedi a Hugging Face sulla tua Cloud TPU utilizzando il seguente comando:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Esegui l'addestramento del modello:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

Risoluzione dei problemi di PyTorch/XLA

Se hai impostato le variabili facoltative per il debug nella sezione precedente, il profilo per il modello verrà memorizzato nella posizione specificata dalla variabile PROFILE_LOGDIR. Puoi estrarre il file xplane.pb archiviato in questa posizione e utilizzare tensorboard per visualizzare i profili nel browser seguendo le istruzioni di TensorBoard.

Se PyTorch/XLA non funziona come previsto, consulta la Guida alla risoluzione dei problemi, che contiene suggerimenti per il debug, la profilazione e l'ottimizzazione del modello.