Introduzione a Trillium (v6e)
In questa documentazione, nell'API TPU e nei log, v6e viene utilizzato per fare riferimento a Trillium. v6e rappresenta la sesta generazione di TPU di Google.
Con 256 chip per pod, l'architettura v6e condivide molte somiglianze con v5e. Questo sistema è ottimizzato per l'addestramento, il perfezionamento e la pubblicazione di trasformatori, modelli di sintesi di immagini dal testo e reti neurali convoluzionali (CNN).
Per maggiori informazioni sull'architettura e sulle configurazioni del sistema v6e, consulta TPU v6e.
Questo documento introduttivo si concentra sui processi di addestramento e servizio dei modelli utilizzando i framework JAX o PyTorch. Con ogni framework, puoi eseguire il provisioning delle TPU utilizzando le risorse in coda o GKE. La configurazione di GKE può essere eseguita utilizzando XPK o i comandi GKE.
Procedura generale per addestrare o pubblicare un modello utilizzando v6e
- Preparare un progetto Google Cloud
- Capacità sicura
- Esegui il provisioning dell'ambiente Cloud TPU
- Esegui un carico di lavoro di addestramento o inferenza di un modello
Prepara un progetto Google Cloud
Prima di poter utilizzare Cloud TPU, devi:
- Crea un Google Cloud account e un progetto con la fatturazione abilitata
- Installa i componenti alpha di Google Cloud CLI
- Abilita l'API Cloud TPU
- Crea un service agent Cloud TPU
- Crea un account di servizio Cloud TPU e concedi le autorizzazioni
Per saperne di più, vedi Configurare l'ambiente Cloud TPU.
Capacità sicura
Contatta l'Google Cloud assistenza per richiedere la quota Cloud TPU v6e e per rispondere a qualsiasi domanda sulla capacità.
Esegui il provisioning dell'ambiente Cloud TPU
v6e Cloud TPU può essere sottoposta a provisioning e gestita con GKE, con GKE e XPK (uno strumento CLI wrapper su GKE) o come risorse in coda.
Prerequisiti
- Verifica che il tuo progetto disponga di una quota
TPUS_PER_TPU_FAMILY
sufficiente, che specifica il numero massimo di chip a cui puoi accedere all'interno del tuo progetto Google Cloud. - v6e è stato testato con la seguente configurazione:
- Python
3.10
o versioni successive - Versioni software Nightly:
- Nightly JAX
0.4.32.dev20240912
- LibTPU notturna
0.1.dev20240912+nightly
- Nightly JAX
- Versioni software stabili:
- JAX + JAX Lib v0.4.37
- Python
Verifica che il tuo progetto disponga di una quota sufficiente per:
- Quota VM Cloud TPU
- Quota di indirizzi IP
Quota per Hyperdisk bilanciato e per qualsiasi altro tipo di disco che vuoi utilizzare
Se utilizzi GKE con XPK, consulta Autorizzazioni di Cloud Console per l'utente o il service account per le autorizzazioni necessarie per eseguire XPK.
Crea variabili di ambiente
In Cloud Shell, crea le seguenti variabili di ambiente:
export NODE_ID=your-tpu-name export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v6e-16 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id export VALID_DURATION=your-duration # Additional environment variable needed for Multislice: export NUM_SLICES=number-of-slices # Use a custom network for better performance as well as to avoid having the default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Descrizioni dei flag dei comandi
Variabile | Descrizione |
NODE_ID | L'ID assegnato dall'utente della Cloud TPU creata quando viene allocata la richiesta di risorse in coda. |
PROJECT_ID | Google Cloud nome del progetto. Utilizza un progetto esistente o creane uno nuovo. Per saperne di più, vedi Configurare il progetto Google Cloud . |
ZONE | Consulta il documento Regioni e zone di Cloud TPU per le zone supportate. |
ACCELERATOR_TYPE | Vedi Tipi di acceleratore. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Questo è l'indirizzo email del tuo account di servizio, che puoi trovare in
Google Cloud Console -> IAM -> Service Accounts
Ad esempio: |
NUM_SLICES | Il numero di sezioni da creare (necessario solo per Multislice). |
QUEUED_RESOURCE_ID | L'ID testo assegnato dall'utente della richiesta di risorsa in coda. |
VALID_DURATION | La durata di validità della richiesta di risorse in coda. |
NETWORK_NAME | Il nome di una rete secondaria da utilizzare. |
NETWORK_FW_NAME | Il nome di un firewall di rete secondario da utilizzare. |
Ottimizzare le prestazioni di rete
Per ottenere le migliori prestazioni,utilizza una rete con MTU (unità massima di trasmissione) di 8896.
Per impostazione predefinita, un Virtual Private Cloud (VPC) fornisce solo un MTU di 1460 byte,che offre prestazioni di rete non ottimali. Puoi impostare l'MTU di una rete VPC su qualsiasi valore compreso tra 1300 byte e 8896 byte (inclusi). Le dimensioni MTU personalizzate comuni sono 1500 byte (Ethernet standard) o 8896 byte (il massimo possibile). Per maggiori informazioni, consulta Dimensioni MTU valide per le reti VPC.
Per saperne di più sulla modifica dell'impostazione MTU per una rete esistente o predefinita, consulta Modificare l'impostazione MTU di una rete VPC.
L'esempio seguente crea una rete con MTU 8896.
export RESOURCE_NAME=your-resource-name export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \ --allow tcp,icmp,udp --project=${PROJECT_ID}
Utilizzo di più NIC (opzione per Multislice)
Le seguenti variabili di ambiente sono necessarie per una subnet secondaria quando utilizzi un ambiente multislice.
export NETWORK_NAME_2=${RESOURCE_NAME} export SUBNET_NAME_2=${RESOURCE_NAME} export FIREWALL_RULE_NAME=${RESOURCE_NAME} export ROUTER_NAME=${RESOURCE_NAME}-network-2 export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2 export REGION=your-region
Utilizza i seguenti comandi per creare il routing IP personalizzato per la rete e la subnet.
gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
--bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
--network=${NETWORK_NAME_2} \
--range=10.10.0.0/18 --region=${REGION} \
--project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
--network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
--project=${PROJECT_ID} \
--network=${NETWORK_NAME_2} \
--region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
--router=${ROUTER_NAME} \
--region=${REGION} \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project=${PROJECT_ID} \
--enable-logging
Dopo aver creato una sezione di rete multipla, puoi verificare che entrambe le schede di interfaccia di rete (NIC) vengano utilizzate configurando un cluster XPK e aggiungendo il flag --command ifconfig
al comando di creazione del workload XPK.
Utilizza il seguente comando workload create
per visualizzare l'output del comando ifconfig
nei log della console Google Cloud e verifica che sia eth0 sia eth1 abbiano mtu=8896.
python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --command "ifconfig"
Se vuoi attivare i log di debug o utilizzare Vertex AI TensorBoard, aggiungi i seguenti argomenti facoltativi al comando:
--enable-debug-logs \ --use-vertex-tensorboard
Verifica che sia eth0 sia eth1 abbiano mtu=8896. Puoi verificare che la multi-NIC
sia in esecuzione aggiungendo il flag --command ifconfig
al comando di creazione del workload XPK. Controlla l'output di questo carico di lavoro XPK nei log della console e verifica che sia eth0 sia eth1 abbiano mtu=8896. Google Cloud
Migliorare le impostazioni TCP
Se hai creato le tue Cloud TPU utilizzando l'interfaccia delle risorse in coda, puoi eseguire il seguente comando per migliorare le prestazioni di rete aumentando i limiti del buffer di ricezione TCP.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "${PROJECT_ID}" \ --zone "${ZONE}" \ --node=all \ --worker=all \ --command=' sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'
Provisioning con risorse in coda
Puoi creare una Cloud TPU v6e utilizzando le risorse in coda. Le risorse in coda ti consentono di ricevere capacità non appena diventa disponibile. Puoi specificare un'ora di inizio e di fine facoltativa per quando deve essere soddisfatta la richiesta. Per saperne di più, consulta Gestire le risorse in coda.
Provisioning di Cloud TPU v6e con GKE o XPK
Se utilizzi i comandi GKE con v6e, puoi utilizzare i comandi Kubernetes o XPK per eseguire il provisioning delle Cloud TPU e addestrare o pubblicare modelli. Consulta Pianificare le Cloud TPU in GKE per scoprire come pianificare le configurazioni Cloud TPU nei cluster GKE. Le sezioni seguenti forniscono comandi per creare un cluster XPK con supporto di una singola NIC e di più NIC.
Crea un cluster XPK con supporto NIC singola
export CLUSTER_NAME=xpk-cluster-name export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT_ID} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network=${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --create-vertex-tensorboard
Descrizioni dei flag dei comandi
Variabile | Descrizione |
CLUSTER_NAME | Il nome assegnato dall'utente per il cluster XPK. |
PROJECT_ID | Google Cloud nome del progetto. Utilizza un progetto esistente o creane uno nuovo. Per saperne di più, vedi Configurare il progetto Google Cloud . |
ZONE | Consulta il documento Regioni e zone di Cloud TPU per le zone supportate. |
TPU_TYPE | Vedi Tipi di acceleratore. |
NUM_SLICES | Il numero di sezioni che vuoi creare |
CLUSTER_ARGUMENTS | La rete e la subnet da utilizzare.
Ad esempio: |
NUM_SLICES | Il numero di sezioni da creare. |
NETWORK_NAME | Il nome di una rete secondaria da utilizzare. |
NETWORK_FW_NAME | Il nome di un firewall di rete secondario da utilizzare. |
Crea un cluster XPK con supporto multi-NIC
export CLUSTER_NAME=xpk-cluster-name export REGION=your-region export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \ --network=${NETWORK_NAME_1} \ --range=10.11.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_1} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_1} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \ --network=${NETWORK_NAME_2} \ --range=10.10.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_2} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_2} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \ --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \ --create-vertex-tensorboard
Descrizioni dei flag dei comandi
Variabile | Descrizione |
CLUSTER_NAME | Il nome assegnato dall'utente per il cluster XPK. |
PROJECT_ID | Google Cloud nome del progetto. Utilizza un progetto esistente o creane uno nuovo. Per saperne di più, vedi Configurare il progetto Google Cloud . |
ZONE | Consulta il documento Regioni e zone di Cloud TPU per le zone supportate. |
TPU_TYPE | Vedi Tipi di acceleratore. |
NUM_SLICES | Il numero di sezioni che vuoi creare |
CLUSTER_ARGUMENTS | La rete e la subnet da utilizzare.
Ad esempio: |
NODE_POOL_ARGUMENTS | La rete di nodi aggiuntiva da utilizzare.
Ad esempio: |
NUM_SLICES | Il numero di sezioni da creare (necessario solo per Multislice). |
NETWORK_NAME | Il nome di una rete secondaria da utilizzare. |
NETWORK_FW_NAME | Il nome di un firewall di rete secondario da utilizzare. |
Configurazione del framework
Questa sezione descrive la procedura di configurazione generale per l'addestramento del modello di machine learning utilizzando i framework JAX e PyTorch. Se utilizzi GKE, puoi utilizzare XPK o i comandi Kubernetes per la configurazione del framework.
Configurazione di JAX
Questa sezione fornisce istruzioni di configurazione per l'esecuzione di carichi di lavoro JAX su GKE, con o senza XPK, nonché l'utilizzo di risorse in coda.
Configura JAX utilizzando GKE
Singola sezione su un singolo host
L'esempio seguente configura un pool di nodi a host singolo 2x2 utilizzando un file YAML di Kubernetes.
apiVersion: v1
kind: Pod
metadata:
name: tpu-pod-jax-v6e-a
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
containers:
- name: tpu-job
image: python:3.10
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Al termine dell'operazione, dovresti visualizzare il seguente messaggio nel log GKE:
Total TPU chips: 4
Singola sezione su più host
L'esempio seguente configura un pool di nodi multihost 4x4 utilizzando un file YAML di Kubernetes.
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-available-chips
spec:
backoffLimit: 0
completions: 4
parallelism: 4
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: tpu-job
image: python:3.10
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Al termine dell'operazione, dovresti visualizzare il seguente messaggio nel log GKE:
Total TPU chips: 16
Multislice su più host
L'esempio seguente configura due pool di nodi multihost 4x4 utilizzando un file YAML di Kubernetes.
Come prerequisito, devi installare JobSet v0.2.3 o versioni successive.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: multislice-job
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
failurePolicy:
maxRestarts: 4
replicatedJobs:
- name: slice
replicas: 2
template:
spec:
parallelism: 4
completions: 4
backoffLimit: 0
template:
spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
hostNetwork: true
containers:
- name: jax-tpu
image: python:3.10
ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 8431
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
limits:
google.com/tpu: 4
requests:
google.com/tpu: 4
Al termine dell'operazione, dovresti visualizzare il seguente messaggio nel log GKE:
Total TPU chips: 32
Per saperne di più, consulta Esegui un workload multislice nella documentazione di GKE.
Per prestazioni migliori, attiva hostNetwork.
Multi-NIC
Per utilizzare il seguente manifest multi-NIC, devi configurare le tue reti. Per ulteriori informazioni, consulta Configurare il supporto di più reti per i pod Kubernetes.
Per sfruttare le funzionalità multi-NIC in GKE, devi includere alcune annotazioni aggiuntive nel manifest del pod Kubernetes.
Di seguito è riportato un esempio di manifest del carico di lavoro multi-NIC non TPU.
apiVersion: v1
kind: Pod
metadata:
name: sample-netdevice-pod-1
annotations:
networking.gke.io/default-interface: 'eth0'
networking.gke.io/interfaces: |
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"netdevice-network"}
]
spec:
containers:
- name: sample-netdevice-pod
image: busybox
command: ["sleep", "infinity"]
ports:
- containerPort: 80
restartPolicy: Always
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
Se utilizzi il comando exec
per connetterti al pod Kubernetes, dovresti visualizzare
la NIC aggiuntiva utilizzando il seguente codice:
$ kubectl exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
inet 172.24.0.4/32 scope global eth1
valid_lft forever preferred_lft forever
Configura JAX utilizzando GKE con XPK
Per configurare JAX utilizzando GKE e XPK, consulta il file README di XPK.
Per configurare ed eseguire XPK con MaxText, consulta Come eseguire MaxText.
Configurare JAX utilizzando le risorse in coda
Installa JAX su tutte le VM Cloud TPU nella tua slice o nelle tue slice contemporaneamente utilizzando il comando
gcloud alpha compute tpus tpu-vm ssh
. Per Multislice, aggiungi il flag --node=all
.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Puoi eseguire questo comando per verificare quanti core Cloud TPU sono disponibili nella tua slice e per verificare che tutto sia installato correttamente:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='
python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
L'output è simile al seguente quando viene eseguito su una sezione v6e-16:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count()
mostra il numero totale di chip nella sezione specificata.
jax.local_device_count()
indica il numero di chip accessibili da una singola VM in questa sezione.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 &&
pip install setuptools==59.6.0 &&
pip install -r requirements.txt && pip install .'
Risolvere i problemi di configurazione di JAX
Un suggerimento generale è quello di abilitare la registrazione dettagliata nel manifest del carico di lavoro GKE. Quindi, fornisci i log all'assistenza GKE.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Messaggi di errore
no endpoints available for service 'jobset-webhook-service'
Questo errore indica che il jobset non è stato installato correttamente. Controlla se i pod Kubernetes di deployment di jobset-controller-manager sono in esecuzione. Per saperne di più, consulta la documentazione per la risoluzione dei problemi relativi a JobSet.
TPU initialization failed: Failed to connect
Assicurati che la versione del nodo GKE sia 1.30.4-gke.1348000 o successiva (GKE 1.31 non è supportato).
Configurazione per PyTorch
Questa sezione descrive come iniziare a utilizzare PJRT su v6e con PyTorch/XLA. Python 3.10 è la versione di Python consigliata.
Configura PyTorch utilizzando GKE con XPK
Puoi utilizzare il seguente container Docker con XPK in cui sono già installate le dipendenze di PyTorch:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Per creare un workload XPK, utilizza il seguente comando:
python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \ --workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone ${ZONE} \ --project ${PROJECT_ID} \ --enable-debug-logs \ --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
L'utilizzo di --base-docker-image
crea una nuova immagine Docker con la directory di lavoro corrente integrata nel nuovo Docker.
Configurare PyTorch utilizzando le risorse in coda
Segui questi passaggi per installare PyTorch utilizzando le risorse in coda ed eseguire un piccolo script su v6e.
Installa le dipendenze utilizzando SSH per accedere alle VM
Utilizza il seguente comando per installare le dipendenze su tutte le VM Cloud TPU. Per
Multislice, aggiungi il flag --worker=all
:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
sudo apt update && sudo apt install -y python3-pip libopenblas-base && \
pip3 install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Migliora il rendimento dei modelli con allocazioni frequenti e di grandi dimensioni
Per i modelli con allocazioni frequenti e di grandi dimensioni, l'utilizzo della funzione tcmalloc
migliora significativamente le prestazioni rispetto all'implementazione predefinita della funzione malloc
, pertanto la funzione malloc
predefinita utilizzata sulla VM Cloud TPU è
tcmalloc
. Tuttavia, a seconda del carico di lavoro (ad esempio, con DLRM, che
ha allocazioni molto grandi per le tabelle di incorporamento), la funzione tcmalloc
potrebbe
causare un rallentamento, nel qual caso puoi provare a impostare la seguente variabile
utilizzando invece la funzione malloc
predefinita:
unset LD_PRELOAD
Utilizza uno script Python per eseguire un calcolo sulla VM v6e
Utilizza il seguente comando per eseguire uno script che crea due tensori, li somma e stampa il risultato:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker all \
--command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
Viene generato un output simile al seguente:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
v6e con SkyPilot
Puoi utilizzare Cloud TPU v6e con SkyPilot. Segui questi passaggi per aggiungere a SkyPilot informazioni su località e prezzi relative a v6e. Per saperne di più, consulta l'esempio di TPU v6e di SkyPilot.
Tutorial sull'inferenza
I seguenti tutorial mostrano come eseguire l'inferenza su Cloud TPU v6e:
Esempi di addestramento
Le sezioni seguenti forniscono esempi per l'addestramento di modelli MaxText, MaxDiffusion e PyTorch su Cloud TPU v6e.
Addestramento di MaxText e MaxDiffusion su VM Cloud TPU v6e
Le sezioni seguenti illustrano il ciclo di vita dell'addestramento dei modelli MaxText e MaxDiffusion.
In generale, i passaggi di alto livello sono:
- Crea l'immagine di base del workload.
- Esegui il workload utilizzando XPK.
- Crea il comando di addestramento per il workload.
- Esegui il deployment del carico di lavoro.
- Segui il workload e visualizza le metriche.
- Elimina il workload XPK se non è necessario.
- Elimina il cluster XPK quando non è più necessario.
Crea l'immagine di base
Installa MaxText o MaxDiffusion e crea l'immagine Docker:
Clona il repository che vuoi utilizzare e passa alla directory del repository:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
Configura Docker in modo che utilizzi Google Cloud CLI:
gcloud auth configure-docker
Crea l'immagine Docker utilizzando il seguente comando o JAX Stable Stack. Per saperne di più su JAX Stable Stack, consulta Creare un'immagine Docker con JAX Stable Stack.
MaxText:
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
MaxDiffusion:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
Imposta l'ID progetto nella configurazione dell'interfaccia alla gcloud CLI attiva:
gcloud config set project ${PROJECT_ID}
Se avvii il carico di lavoro da una macchina in cui l'immagine non è creata localmente, carica l'immagine.
Imposta la variabile di ambiente
CLOUD_IMAGE_NAME
:export CLOUD_IMAGE_NAME=${USER}_runner
Carica l'immagine:
bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
Esegui il carico di lavoro utilizzando XPK
Imposta le seguenti variabili di ambiente se non utilizzi i valori predefiniti impostati da MaxText o MaxDiffusion:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
Crea lo script del modello. Questo script verrà copiato come comando di addestramento in un passaggio successivo.
Non eseguire ancora lo script del modello.
MaxText
MaxText è un LLM open source ad alte prestazioni e altamente scalabile scritto in Python e JAX puri e destinato a TPU e GPU per l'addestramento e l'inferenza. Google Cloud
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma è una famiglia di LLM con pesi aperti sviluppati da Google DeepMind, basati sulla ricerca e sulla tecnologia Gemini.
python3 -m MaxText.train MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral è un modello di AI all'avanguardia sviluppato da Mistral AI, che utilizza un'architettura di tipo sparse mixture-of-experts (MoE).
python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama è una famiglia di LLM open-weight sviluppati da Meta.
Per un esempio di come eseguire Llama3 su PyTorch, consulta i modelli torch_xla nel repository torchprime.
MaxDiffusion
MaxDiffusion è una raccolta di implementazioni di riferimento di vari modelli di diffusione latente scritti in Python e JAX puri che vengono eseguiti su dispositivi XLA, tra cui Cloud TPU e GPU. Stable Diffusion è un modello latente da testo a immagine che genera immagini fotorealistiche da qualsiasi input di testo.
Per eseguire MaxDiffusion, devi installare un ramo Git specifico come mostrato nel seguente comando
git clone
.Script di addestramento:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && pip install -r requirements.txt && pip install . && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR} && python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 resolution=1024 per_device_batch_size=1 output_dir=${OUT_DIR} jax_cache_dir=${OUT_DIR}/cache_dir/ max_train_steps=200 attention=flash run_name=sdxl-ddp-v6e
Esporta le seguenti variabili:
export CLUSTER_NAME=CLUSTER_NAME export ACCELERATOR_TYPE=ACCELERATOR_TYPE export NUM_SLICES=NUM_SLICES export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT
Descrizioni delle variabili di ambiente
Variabile Descrizione CLUSTER_NAME
Il nome del tuo cluster XPK. ACCELERATOR_TYPE
Consulta la sezione Tipi di acceleratore. NUM_SLICES
Il numero di sezioni TPU. YOUR_MODEL_SCRIPT
Lo script del modello da eseguire come comando di addestramento. Esegui il modello utilizzando lo script creato nel passaggio precedente. Devi specificare il flag
--base-docker-image
per utilizzare l'immagine di base MaxText o il flag--docker-image
e l'immagine che vuoi utilizzare.(Facoltativo) Puoi attivare la registrazione di debug includendo il flag
--enable-debug-logs
. Per ulteriori informazioni, consulta Debug di JAX su MaxText.(Facoltativo) Puoi creare un esperimento Vertex AI per caricare i dati in Vertex AI TensorBoard includendo il flag
--use-vertex-tensorboard
. Per saperne di più, consulta Monitorare JAX su MaxText utilizzando Vertex AI.python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command="${YOUR_MODEL_SCRIPT}"
L'output include un link per monitorare il carico di lavoro. Apri il link e fai clic sulla scheda Log per monitorare il carico di lavoro in tempo reale.
Eseguire il debug di JAX su MaxText
Utilizza i comandi XPK supplementari per diagnosticare il motivo per cui il cluster o il carico di lavoro non è in esecuzione:
- Elenco dei workload XPK
- XPK inspector
- Abilita la registrazione dettagliata nei log del workload utilizzando il flag
--enable-debug-logs
quando crei il workload XPK
Monitorare JAX su MaxText utilizzando Vertex AI
Per utilizzare TensorBoard, il tuo account utente Google Cloud deve disporre del ruolo aiplatform.user
. Esegui questo comando per concedere il ruolo:
gcloud projects add-iam-policy-binding your-project-id \ --member='user:your-email' \ --role='roles/aiplatform.user'
Visualizza i dati scalari e di profilo tramite TensorBoard gestito di Vertex AI.
Aumenta le richieste di gestione delle risorse (CRUD) per la zona che utilizzi da 600 a 5000. Questo potrebbe non essere un problema per i piccoli carichi di lavoro che utilizzano meno di 16 VM.
Installa le dipendenze come
cloud-accelerator-diagnostics
per Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Crea il cluster XPK utilizzando il flag
--create-vertex-tensorboard
, come documentato in Crea Vertex AI TensorBoard. Puoi eseguire questo comando anche sui cluster esistenti.Crea l'esperimento Vertex AI durante l'esecuzione del carico di lavoro XPK utilizzando il flag
--use-vertex-tensorboard
e il flag facoltativo--experiment-name
. Per l'elenco completo dei passaggi, vedi Crea Vertex AI Experiment per caricare i dati in Vertex AI TensorBoard.
I log includono un link a Vertex AI TensorBoard, simile al seguente:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
Puoi trovare il link a Vertex AI TensorBoard anche nella console Google Cloud . Vai a Vertex AI Experiments nella console Google Cloud . Seleziona la regione appropriata dal menu a discesa.
La directory TensorBoard viene scritta anche nel bucket Cloud Storage
che hai specificato con ${BASE_OUTPUT_DIR}
.
Elimina i workload XPK
Utilizza il comando xpk workload delete
per eliminare uno o più workload in base al prefisso del job o allo stato del job. Questo
comando potrebbe essere utile se hai inviato carichi di lavoro XPK che non devono più essere eseguiti
o se hai job bloccati nella coda.
Elimina il cluster XPK
Utilizza il comando xpk cluster delete
per eliminare un cluster:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone=${ZONE} --project=${PROJECT_ID}
Addestramento di Llama e PyTorch/XLA su una VM Cloud TPU v6e
Questo tutorial descrive come addestrare i modelli Llama utilizzando PyTorch/XLA su Cloud TPU v6e utilizzando il set di dati WikiText.
Accedere a Hugging Face e al modello Llama 3
Per eseguire questo tutorial, devi disporre di un token di accesso utente Hugging Face. Per informazioni sulla creazione di token di accesso utente, consulta la documentazione di Hugging Face sui token di accesso utente.
Devi anche disporre dell'autorizzazione per accedere al modello Llama-3-8B su Hugging Face. Per ottenere l'accesso, vai al modello Meta-Llama-3-8B su HuggingFace e richiedi l'accesso.
Crea una VM Cloud TPU
Crea una Cloud TPU v6e con 8 chip per eseguire il tutorial.
Imposta le variabili di ambiente:
export NODE_ID=your-tpu-name export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v6e-8 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id export VALID_DURATION=your-duration
Crea una VM Cloud TPU:
gcloud alpha compute tpus tpu-vm create ${NODE_ID} --version=${RUNTIME_VERSION} \ --accelerator-type=${ACCELERATOR_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID}
Installazione
Installa il fork pytorch-tpu/transformers
di Hugging Face Transformers e le dipendenze. Questo tutorial è stato testato con le seguenti versioni delle dipendenze utilizzate in questo esempio:
torch
: compatibile con la versione 2.5.0torch_xla[tpu]
: compatibile con la versione 2.5.0jax
: 0.4.33jaxlib
: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'
Configura le configurazioni del modello
Il comando di addestramento nella sezione successiva, Esegui il modello, utilizza due file di configurazione JSON per definire i parametri del modello e la configurazione Fully Sharded Data Parallel (FSDP). Lo sharding FSDP ti consente di utilizzare una dimensione del batch maggiore durante l'addestramento suddividendo i pesi del modello su più TPU. Quando esegui l'addestramento con modelli più piccoli, potrebbe essere sufficiente utilizzare il parallelismo dei dati e replicare i pesi su ogni dispositivo. Per saperne di più su come partizionare i tensori tra i dispositivi in PyTorch/XLA, consulta la guida per l'utente di PyTorch/XLA SPMD.
Crea il file di configurazione dei parametri del modello. Di seguito è riportata la configurazione dei parametri del modello per Llama-3-8B. Per altri modelli, trova la configurazione su Hugging Face. Ad esempio, vedi la configurazione Llama-2-7B.
cat > llama-config.json << EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
Crea il file di configurazione FSDP:
cat > fsdp-config.json << EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
Per ulteriori informazioni su FSDP, consulta FSDPv2.
Carica i file di configurazione nelle tue VM Cloud TPU utilizzando il seguente comando:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${NODE_ID}:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
Esegui il modello
Utilizzando i file di configurazione creati nella sezione precedente, esegui lo script run_clm.py
per addestrare il modello Llama-3-8B sul set di dati WikiText. L'esecuzione dello script di addestramento richiede circa 10 minuti su una Cloud TPU v6e-8.
Accedi a Hugging Face sulla tua Cloud TPU utilizzando il seguente comando:
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
Esegui l'addestramento del modello:
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
Risoluzione dei problemi di PyTorch/XLA
Se hai impostato le variabili facoltative per il debug nella sezione precedente,
il profilo per il modello verrà archiviato nella posizione specificata dalla
variabile PROFILE_LOGDIR
. Puoi estrarre il file xplane.pb
archiviato
in questa posizione e utilizzare tensorboard
per visualizzare i profili nel
browser seguendo le istruzioni di TensorBoard.
Se PyTorch/XLA non funziona come previsto, consulta la Guida alla risoluzione dei problemi, che contiene suggerimenti per il debug, la profilazione e l'ottimizzazione del modello.
Risultati del benchmarking
La sezione seguente contiene i risultati del benchmark per MaxDiffusion su v6e.
MaxDiffusion
Abbiamo eseguito lo script di addestramento per MaxDiffusion su v6e-4, v6e-16 e due v6e-16. Consulta i throughput nella tabella seguente.
v6e-4 | v6e-16 | Due v6e-16 | |
---|---|---|---|
Passaggi di addestramento | 0,069 | 0,073 | 0,13 |
Dimensione batch globale | 8 | 32 | 64 |
Throughput (esempi/sec) | 115,9 | 438,4 | 492,3 |