Introduzione a Trillium (v6e)
v6e viene utilizzato per fare riferimento a Trillium in questa documentazione, nell'API TPU e nei log. v6e rappresenta la sesta generazione di TPU di Google.
Con 256 chip per pod, l'architettura v6e condivide molte somiglianze con la v5e. Questo sistema è ottimizzato per l'addestramento, la messa a punto e la pubblicazione di trasformatori, conversione di testo in immagini e reti neurali convoluzionali (CNN).
Consulta il documento v6e per informazioni sull'architettura e sulle configurazioni del sistema v6e.
Questo documento introduttivo si concentra sulle procedure di addestramento e pubblicazione dei modelli utilizzando i framework JAX, PyTorch o TensorFlow. Con ogni framework, puoi eseguire il provisioning delle TPU utilizzando risorse in coda o Google Kubernetes Engine (GKE). La configurazione di GKE può essere eseguita utilizzando i comandi XPK o GKE.
Procedura generale per addestrare o eseguire il servizio di un modello utilizzando la versione 6e
- Preparare un Google Cloud progetto
- Capacità sicura
- Configurare l'ambiente TPU
- Esegui il provisioning dell'ambiente Cloud TPU
- Esegui un carico di lavoro di addestramento o inferenza del modello
- Pulizia
Prepara un Google Cloud progetto
- Accedi al tuo Account Google. Se non l'hai ancora fatto, registrati per creare un nuovo account.
- Nella console Google Cloud, seleziona o crea un progetto Cloud dalla pagina del selettore di progetti.
- Abilita la fatturazione per il tuo progetto Google Cloud. La fatturazione è obbligatoria per tutto l'utilizzo di Google Cloud.
- Installa i componenti gcloud alpha.
Esegui il seguente comando per installare la versione più recente dei componenti
gcloud
.gcloud components update
Abilita l'API TPU tramite il seguente comando
gcloud
in Cloud Shell. Puoi anche attivarlo dalla console Google Cloud.gcloud services enable tpu.googleapis.com
Abilita le autorizzazioni con l'account di servizio TPU per l'API Compute Engine
Gli account di servizio consentono al servizio Cloud TPU di accedere ad altri servizi Google Cloud. Un account di servizio gestito dall'utente è una best practice di Google Cloud. Segui queste guide per creare e concedere i ruoli. Sono necessari i seguenti ruoli:
- TPU Admin
- Amministratore Storage
- Logs Writer
- Monitoring Metric Writer
a. Configura le autorizzazioni XPK con il tuo account utente per GKE: XPK.
Esegui l'autenticazione con il tuo Account Google e imposta l'ID progetto e la zona predefinite.
auth login
autorizza gcloud ad accedere Google Cloud con le credenziali dell'utente Google.
PROJECT_ID
è il Google Cloud nome del progetto.
ZONE
è la zona in cui vuoi creare la TPU.gcloud auth login gcloud config set project ${PROJECT_ID} gcloud config set compute/zone ${ZONE}
Crea un'identità di servizio per la VM TPU.
gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
Capacità sicura
Contatta il team di assistenza di vendita/account Cloud TPU per richiedere una quota TPU e per rispondere a eventuali domande sulla capacità.
Esegui il provisioning dell'ambiente Cloud TPU
È possibile eseguire il provisioning e la gestione delle TPU v6e con GKE, con GKE e XPK (uno strumento CLI wrapper su GKE) o come risorse in coda.
Prerequisiti
- Verifica che il tuo progetto disponga di una quota
TPUS_PER_TPU_FAMILY
sufficiente, che specifica il numero massimo di chip a cui puoi accedere all'interno del progettoGoogle Cloud . - La versione 6e è stata testata con la seguente configurazione:
- Python
3.10
o versioni successive - Versioni software Nightly:
- a notte JAX
0.4.32.dev20240912
- nightly LibTPU
0.1.dev20240912+nightly
- a notte JAX
- Versioni software stabili:
- JAX + JAX Lib della versione 0.4.37
- Python
Verifica che il tuo progetto disponga di una quota TPU sufficiente per:
- Quota VM TPU
- Quota di indirizzi IP
Quota Hyperdisk-balance
Autorizzazioni per i progetti utente
- Se utilizzi GKE con XPK, consulta Autorizzazioni di Cloud Console per l'account utente o di servizio per conoscere le autorizzazioni necessarie per eseguire XPK.
Variabili di ambiente
In Cloud Shell, crea le seguenti variabili di ambiente:
export NODE_ID=TPU_NODE_ID # TPU name export PROJECT_ID=PROJECT_ID export ACCELERATOR_TYPE=v6e-16 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID export VALID_DURATION=VALID_DURATION # Additional environment variable needed for provisioning Multislice: export NUM_SLICES=NUM_SLICES # Use a custom network for better performance as well as to avoid having the default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Descrizioni dei flag dei comandi
Variabile | Descrizione |
NODE_ID | L'ID assegnato dall'utente della TPU che viene creato quando viene allocata la richiesta di risorsa in coda. |
PROJECT_ID | Google Cloud Nome progetto. Utilizza un progetto esistente o creane uno nuovo su |
ZONA | Consulta il documento Regioni e zone TPU per le zone supportate. |
ACCELERATOR_TYPE | Consulta la sezione Tipi di acceleratore. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Si tratta dell'indirizzo email del tuo account di servizio che puoi trovare in
Google Cloud Console -> IAM -> Account di servizio
Ad esempio: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com |
NUM_SLICES | Il numero di slice da creare (necessario solo per Multislice). |
QUEUED_RESOURCE_ID | L'ID testo assegnato dall'utente della richiesta di risorsa in coda. |
VALID_DURATION | La durata di validità della richiesta di risorse in coda. |
NETWORK_NAME | Il nome di una rete secondaria da utilizzare. |
NETWORK_FW_NAME | Il nome di un firewall di rete secondario da utilizzare. |
Ottimizzazioni delle prestazioni di rete
Per ottenere le migliori prestazioni,utilizza una rete con 8896 MTU (unità massima di trasmissione).
Per impostazione predefinita, un Virtual Private Cloud (VPC) fornisce solo un MTU di 1460 byte,che offrirà prestazioni di rete non ottimali. Puoi impostare l'MTU di una rete VPC su qualsiasi valore compreso tra 1300 e 8896 byte (inclusi). Le dimensioni MTU personalizzate comuni sono 1500 byte (Ethernet standard) o 8896 byte (il massimo possibile). Per saperne di più, consulta Dimensioni MTU valide per le reti VPC.
Per saperne di più sulla modifica dell'impostazione MTU per una rete esistente o predefinita, consulta Modificare l'impostazione MTU di una rete VPC.
L'esempio seguente crea una rete con 8896 MTU.
export RESOURCE_NAME=RESOURCE_NAME export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall export PROJECT=X gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT}
Utilizzo di più NIC (opzione per il multislice)
Le seguenti variabili di ambiente sono necessarie per una subnet secondaria quando utilizzi un ambiente Multislice.
export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2
Utilizza i seguenti comandi per creare il routing IP personalizzato per la rete e la subnet.
gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
--bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
--network="${NETWORK_NAME_2}" \
--range=10.10.0.0/18 --region="${REGION}" \
--project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
--network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
gcloud compute routers create "${ROUTER_NAME}" \
--project="${PROJECT_ID}" \
--network="${NETWORK_NAME_2}" \
--region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
--router="${ROUTER_NAME}" \
--region="${REGION}" \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project="${PROJECT_ID}" \
--enable-logging
Una volta creato uno slice multirete, puoi verificare che entrambe le NIC siano in uso configurando un cluster XPK ed eseguendo --command ifconfig
nell'ambito del workload XPK.
Utilizza il seguente comando xpk workload
per visualizzare l'output del comando ifconfig
nei log della console Cloud e controlla che sia eth0 che eth1 abbiano mtu=8896.
python3 xpk.py workload create \ --cluster CLUSTER_NAME \ (--base-docker-image maxtext_base_image|--docker-image CLOUD_IMAGE_NAME \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command "ifconfig"
Verifica che sia eth0 che eth1 abbiano mtu=8896. un modo per verificare che sia in esecuzione il multi-NIC è eseguire il comando --command "ifconfig" nell'ambito del carico di lavoro XPK. Poi controlla l'output stampato del carico di lavoro xpk nei log della console cloud e verifica che sia eth0 che eth1 abbiano mtu=8896.
Impostazioni TCP migliorate
Per le TPU create utilizzando l'interfaccia delle risorse in coda, puoi eseguire il seguente comando per migliorare le prestazioni della rete aumentando i limiti del buffer di ricezione TCP.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "$PROJECT" \ --zone "$ZONE" \ --node=all \ --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \ --worker=all
Provisioning con risorse in coda
È possibile eseguire il provisioning della capacità allocata utilizzando il comando create
queued-resource.
Crea una richiesta di risorse TPU in coda.
Il flag
--reserved
è necessario solo per le risorse riservate, non per le risorse on demand.gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ [--reserved] # The following flags are only needed if you are using Multislice. --node-count node-count # Number of slices in a Multislice \ --node-prefix node-prefix # An optional user-defined node prefix; the default is QUEUED_RESOURCE_ID.
Se la richiesta di risorse in coda viene creata correttamente, lo stato nel campo "response" sarà "WAITING_FOR_RESOURCES" o "FAILED". Se la richiesta di risorse in coda è nello stato "WAITING_FOR_RESOURCES", la risorsa è stata aggiunta alla coda e verrà eseguita il provisioning quando sarà disponibile una capacità TPU sufficiente. Se la richiesta di risorse in coda è in stato "FAILED", il motivo dell'errore sarà nell'output. La richiesta di risorse in coda scadrà se non viene eseguito il provisioning di un v6e entro la durata specificata e lo stato diventa "FAILED". Per ulteriori informazioni, consulta la documentazione pubblica relativa alle risorse in coda.
Quando la richiesta di risorse in coda è in stato "ATTIVO", puoi collegarti alle VM TPU tramite SSH. Utilizza i comandi
list
odescribe
per eseguire query sullo stato della risorsa in coda.gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE}
Quando la risorsa in coda è nello stato "ATTIVO", l'output è simile al seguente:
state: state: ACTIVE
Gestisci le VM TPU. Per le opzioni di gestione delle VM TPU, consulta la sezione sulla gestione delle VM TPU.
Connettiti alle VM TPU tramite SSH
Puoi installare i binari su ogni VM TPU nella sezione TPU ed eseguire il codice. Consulta la sezione Tipi di VM per determinare quante VM avrà il tuo slice.
Per installare i binari o eseguire codice, puoi utilizzare SSH per connetterti a una VM utilizzando il comando
tpu-vm ssh
.gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --node=all # add this flag if you are using Multislice
Per utilizzare SSH per connetterti a una VM specifica, utilizza il flag
--worker
che segue un indice basato su 0:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
Se le forme delle sezioni sono superiori a 8 chip, avrai più VM in una sezione. In questo caso, utilizza i parametri
--worker=all
e--command
nel comandogcloud alpha compute tpus tpu-vm ssh
per eseguire un comando su tutte le VM contemporaneamente. Ad esempio:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \ --zone ${ZONE} --worker=all \ --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Elimina una risorsa in coda
Elimina una risorsa in coda alla fine della sessione o rimuovi le richieste di risorse in coda nello stato "FAILED". Per eliminare una risorsa in coda, elimina il segmento e poi la richiesta della risorsa in coda in due passaggi:
gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \ --zone=${ZONE} --quiet gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet
gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
Eseguire il provisioning di TPU v6e con GKE o XPK
Se utilizzi i comandi GKE con la versione 6e, puoi utilizzare i comandi Kubernetes o XPK per eseguire il provisioning delle TPU e addestrare o pubblicare i modelli. Consulta la sezione Pianificare le TPU in GKE per scoprire come pianificare le configurazioni TPU nei cluster GKE. Le sezioni seguenti forniscono i comandi per creare un cluster XPK con supporto per una singola NIC e per più NIC.
Comandi per creare un cluster XPK con supporto di una singola NIC
export CLUSTER_NAME xpk-cluster-name export ZONE=us-central2-b export PROJECT=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network ${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster $CLUSTER_NAME \ --cluster-cpu-machine-type=n1-standard-8 \ --num-slices=$NUM_SLICES \ --tpu-type=$TPU_TYPE \ --zone=$ZONE \ --project=$PROJECT \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --create-vertex-tensorboard
Descrizioni dei flag dei comandi
Variabile | Descrizione |
CLUSTER_NAME | Il nome assegnato dall'utente al cluster XPK. |
PROJECT_ID | Google Cloud Nome progetto. Utilizza un progetto esistente o creane uno nuovo su |
ZONA | Consulta il documento Regioni e zone TPU per le zone supportate. |
TPU_TYPE | Consulta la sezione Tipi di acceleratore. |
NUM_SLICES | Il numero di slice che vuoi creare |
CLUSTER_ARGUMENTS | La rete e la subnet da utilizzare.
Ad esempio: "--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}" |
NUM_SLICES | Il numero di sezioni da creare. |
NETWORK_NAME | Il nome di una rete secondaria da utilizzare. |
NETWORK_FW_NAME | Il nome di un firewall di rete secondario da utilizzare. |
Comandi per creare un cluster XPK con supporto multi NIC
export CLUSTER_NAME xpk-cluster-name export ZONE=us-central2-b export PROJECT=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export exportSUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create "${NETWORK_NAME_1}" \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_1}" \ --network="${NETWORK_NAME_1}" \ --range=10.11.0.0/18 \ --region="${REGION}" \ --project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \ --network "${NETWORK_NAME_1}" \ --allow tcp,icmp,udp \ --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \ --project="${PROJECT}" \ --network="${NETWORK_NAME_1}" \ --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \ --router="${ROUTER_NAME}" \ --region="${REGION}" \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project="${PROJECT}" \ --enable-logging
Secondary subnet for multi-nic experience. Need custom ip routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create "${NETWORK_NAME_2}" \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \ --network="${NETWORK_NAME_2}" \ --range=10.10.0.0/18 \ --region="${REGION}" \ --project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \ --network "${NETWORK_NAME_2}" \ --allow tcp,icmp,udp \ --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \ --project="${PROJECT}" \ --network="${NETWORK_NAME_2}" \ --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \ --router="${ROUTER_NAME}" \ --region="${REGION}" \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project="${PROJECT}" \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster $CLUSTER_NAME \
--num-slices=$NUM_SLICES \
--tpu-type=$TPU_TYPE \
--zone=$ZONE \
--project=$PROJECT \
--on-demand \
--custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
--custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
--create-vertex-tensorboard
Descrizioni dei flag dei comandi
Variabile | Descrizione |
CLUSTER_NAME | Il nome assegnato dall'utente al cluster XPK. |
PROJECT_ID | Google Cloud Nome progetto. Utilizza un progetto esistente o creane uno nuovo su |
ZONA | Consulta il documento Regioni e zone TPU per le zone supportate. |
TPU_TYPE | Consulta la sezione Tipi di acceleratore. |
NUM_SLICES | Il numero di slice che vuoi creare |
CLUSTER_ARGUMENTS | La rete e la subnet da utilizzare.
Ad esempio: "--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}" |
NODE_POOL_ARGUMENTS | La rete del nodo aggiuntivo da utilizzare.
Ad esempio: "--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}" |
NUM_SLICES | Il numero di slice da creare (necessario solo per Multislice). |
NETWORK_NAME | Il nome di una rete secondaria da utilizzare. |
NETWORK_FW_NAME | Il nome di un firewall di rete secondario da utilizzare. |
Configurazione del framework
Questa sezione descrive la procedura di configurazione generale per l'addestramento dei modelli di ML utilizzando i framework JAX, PyTorch o TensorFlow. Puoi eseguire il provisioning delle TPU utilizzando le risorse in coda o GKE. La configurazione di GKE può essere eseguita utilizzando i comandi XPK o Kubernetes.
Configurazione per JAX
Questa sezione fornisce esempi per l'esecuzione di workload JAX su GKE, con o senza XPK, nonché per l'utilizzo di risorse in coda.
Configura JAX utilizzando GKE
L'esempio seguente configura un singolo host 2x2 utilizzando un file YAML di Kubernetes.
Singolo slice su un singolo host
apiVersion: v1
kind: Pod
metadata:
name: tpu-pod-jax-v6e-a
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
containers:
- name: tpu-job
image: python:3.10
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Al termine dell'operazione, dovresti visualizzare il seguente messaggio nel log GKE:
Total TPU chips: 4
Singolo slice su più host
L'esempio seguente configura un pool di nodi multi-host 4x4 utilizzando un file YAML di Kubernetes.
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-available-chips
spec:
backoffLimit: 0
completions: 4
parallelism: 4
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: tpu-job
image: python:3.10
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Al termine dell'operazione, dovresti visualizzare il seguente messaggio nel log GKE:
Total TPU chips: 16
Multislice su più host
L'esempio seguente configura due pool di nodi multi-host 4x4 utilizzando un file YAML Kubernetes.
Come prerequisito, devi installare JobSet versione 0.2.3 o successive.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: multislice-job
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
failurePolicy:
maxRestarts: 4
replicatedJobs:
- name: slice
replicas: 2
template:
spec:
parallelism: 4
completions: 4
backoffLimit: 0
template:
spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
hostNetwork: true
containers:
- name: jax-tpu
image: python:3.10
ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 8431
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
limits:
google.com/tpu: 4
requests:
google.com/tpu: 4
Al termine dell'operazione, dovresti visualizzare il seguente messaggio nel log GKE:
Total TPU chips: 32
Per saperne di più, consulta Eseguire un workload multislice nella documentazione di GKE.
Per migliorare le prestazioni, abilita hostNetwork.
Multi-NIC
Per sfruttare la multi-NIC in GKE, il manifest del pod Kubernetes deve avere annotazioni aggiuntive. Di seguito è riportato un manifest di esempio per un carico di lavoro non TPU con più NIC.
apiVersion: v1
kind: Pod
metadata:
name: sample-netdevice-pod-1
annotations:
networking.gke.io/default-interface: 'eth0'
networking.gke.io/interfaces: |
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"netdevice-network"}
]
spec:
containers:
- name: sample-netdevice-pod
image: busybox
command: ["sleep", "infinity"]
ports:
- containerPort: 80
restartPolicy: Always
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
Se esegui exec
nel pod Kubernetes, dovresti vedere la NIC aggiuntiva
utilizzando il seguente codice.
$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
inet 172.24.0.4/32 scope global eth1
valid_lft forever preferred_lft forever
Configura JAX utilizzando GKE con XPK
Consulta un esempio nel file README di xpk.
Per configurare ed eseguire XPK con MaxText, consulta: Come eseguire MaxText.
Configura JAX utilizzando le risorse in coda
Installa JAX su tutte le VM TPU del tuo o dei tuoi slice contemporaneamente utilizzando
gcloud alpha compute tpus tpu-vm ssh
. Per Multislice, aggiungi --node=all
.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'
Puoi eseguire il seguente codice Python per controllare quanti core TPU sono disponibili nel tuo slice e per verificare che sia tutto installato correttamente (gli output mostrati qui sono stati prodotti con un slice v6e-16):
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
L'output è simile al seguente:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count() mostra il numero totale di chip nel determinato slice. jax.local_device_count() indica il numero di chip accessibili da una singola VM in questo slice.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
&& pip install -r requirements.txt && pip install . '
Risoluzione dei problemi di configurazione di JAX
Un suggerimento generale è abilitare la registrazione dettagliata nel manifest del carico di lavoro GKE. Quindi, fornisci i log all'assistenza GKE.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Messaggi di errore
no endpoints available for service 'jobset-webhook-service'
Questo errore indica che il jobset non è stato installato correttamente. Controlla se i pod Kubernetes del deployment jobset-controller-manager sono in esecuzione. Per ulteriori informazioni, consulta la documentazione sulla risoluzione dei problemi relativi a JobSet.
TPU initialization failed: Failed to connect
Assicurati che la versione del nodo GKE sia 1.30.4-gke.1348000 o successiva (GKE 1.31 non è supportato).
Configurazione per PyTorch
Questa sezione descrive come iniziare a utilizzare PJRT nella versione 6e con PyTorch/XLA. Python 3.10 è la versione consigliata.
Configurare PyTorch utilizzando GKE con XPK
Puoi utilizzare il seguente contenitore Docker con XPK in cui sono già installate le dipendenze di PyTorch:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Per creare un carico di lavoro XPK, utilizza il seguente comando:
python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES} \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
L'utilizzo di --base-docker-image
crea una nuova immagine Docker con la directory di lavoro corrente integrata nel nuovo Docker.
Configurare PyTorch utilizzando le risorse in coda
Segui questi passaggi per installare PyTorch utilizzando le risorse in coda e per eseguire un piccolo script sulla versione 6e.
Installa le dipendenze utilizzando SSH per accedere alle VM
Per Multislice, aggiungi --node=all
:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo apt install -y libopenblas-base pip3 \
install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
--index-url https://download.pytorch.org/whl/nightly/cpu
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Migliora le prestazioni dei modelli con allocazioni significative e frequenti
Per i modelli con allocazioni frequenti e di grandi dimensioni, abbiamo osservato che l'utilizzo di tcmalloc
migliora notevolmente le prestazioni rispetto all'implementazione predefinita di malloc
, pertanto il valore malloc
predefinito utilizzato sulla VM TPU è tcmalloc
. Tuttavia, a seconda del tuo
caricamento di lavoro (ad esempio, con DLRM che ha allocazioni molto grandi per le sue
tabelle di embedding), tcmalloc
potrebbe causare un rallentamento, nel qual caso potresti provare a
reimpostare la seguente variabile utilizzando malloc
predefinito:
unset LD_PRELOAD
Utilizza uno script Python per eseguire un calcolo sulla VM v6e:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--project ${PROJECT_ID} \
--zone ${ZONE} --worker all --command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'
Viene generato un output simile al seguente:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
Configurazione per TensorFlow
Per l'Anteprima pubblica della versione 6e, è supportata solo la versione del runtime tf-nightly.
Puoi reimpostare tpu-runtime
con la versione di TensorFlow compatibile con v6e eseguendo i seguenti comandi:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Utilizza SSH per accedere a worker-0:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE}
Installa TensorFlow su worker-0:
sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Esporta la variabile di ambiente TPU_NAME
:
export TPU_NAME=v6e-16
Puoi eseguire il seguente script Python per controllare quanti core TPU sono disponibili nel tuo slice e per verificare che tutto sia installato correttamente (gli output mostrati sono stati generati con un slice v6e-16):
import TensorFlow as tf
print("TensorFlow version " + tf.__version__)
@tf.function
def add_fn(x,y):
z = x + y
return z
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x,y))
print(z)
L'output è simile al seguente:
PerReplica:{
0: tf.Tensor(2.0, shape=(), dtype=float32),
1: tf.Tensor(2.0, shape=(), dtype=float32),
2: tf.Tensor(2.0, shape=(), dtype=float32),
3: tf.Tensor(2.0, shape=(), dtype=float32),
4: tf.Tensor(2.0, shape=(), dtype=float32),
5: tf.Tensor(2.0, shape=(), dtype=float32),
6: tf.Tensor(2.0, shape=(), dtype=float32),
7: tf.Tensor(2.0, shape=(), dtype=float32)
}
v6e con SkyPilot
Puoi utilizzare TPU v6e con SkyPilot. Segui i passaggi che seguono per aggiungere a SkyPilot le informazioni su prezzi/località relativi a v6e.
Aggiungi quanto segue alla fine di
~/.sky/catalogs/v5/gcp/vms.csv
:,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
Specifica le seguenti risorse in un file YAML:
# tpu_v6.yaml resources: accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use accelerator_args: runtime_version: v2-alpha-tpuv6e # Official suggested runtime
Avvia un cluster con TPU v6e:
sky launch tpu_v6.yaml -c tpu_v6
Connettiti alla TPU v6e tramite SSH:
ssh tpu_v6
Tutorial sull'inferenza
I seguenti tutorial mostrano come eseguire l'inferenza su TPU v6e:
Esempi di addestramento
Le sezioni seguenti forniscono esempi per l'addestramento di modelli MaxText, MaxDiffusion e PyTorch su TPU v6e.
Addestramento di MaxText e MaxDiffusion su VM Cloud TPU v6e
Le sezioni seguenti illustrano il ciclo di vita dell'addestramento dei modelli MaxText e MaxDiffusion.
In generale, i passaggi di alto livello sono:
- Crea l'immagine di base del workload.
- Esegui il tuo carico di lavoro utilizzando XPK.
- Crea il comando di addestramento per il carico di lavoro.
- Esegui il deployment del carico di lavoro.
- Monitora il carico di lavoro e visualizza le metriche.
- Elimina il workload XPK se non è necessario.
- Elimina il cluster XPK quando non è più necessario.
Crea l'immagine di base
Installa MaxText o MaxDiffusion e crea l'immagine Docker:
Clona il repository che vuoi utilizzare e passa alla directory del repository:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
Configura Docker in modo che utilizzi Google Cloud CLI:
gcloud auth configure-docker
Crea l'immagine Docker utilizzando il seguente comando o JAX Stable Stack. Per ulteriori informazioni su JAX Stable Stack, consulta Creare un'immagine Docker con JAX Stable Stack.
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.37
Se avvii il carico di lavoro da una macchina su cui l'immagine non è stata compilata localmente, caricala:
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
Crea un'immagine Docker con JAX Stable Stack
Puoi creare le immagini Docker MaxText e MaxDiffusion utilizzando l'immagine di base JAX Stable Stack.
JAX Stable Stack fornisce un ambiente coerente per MaxText e MaxDiffusion
combinando JAX con pacchetti di base come orbax
, flax
e optax
, insieme a
una libtpu.so ben qualificata che gestisce le utilità di programmazione TPU e altri strumenti essenziali. Queste librerie vengono testate per garantire la compatibilità e fornire una base stabile per creare ed eseguire MaxText e MaxDiffusion. In questo modo, vengono eliminati i potenziali conflitti dovuti a versioni del pacchetto incompatibili.
JAX Stable Stack include una libreria libtpu.so completamente rilasciata e qualificata, la libreria di base che gestisce la compilazione, l'esecuzione e la configurazione della rete ICI dei programmi TPU. La release libtpu sostituisce la build notturna precedentemente utilizzata da JAX e garantisce la funzionalità coerente dei calcoli XLA su TPU con test di qualificazione a livello di PJRT negli IR HLO/StableHLO.
Per creare l'immagine Docker di MaxText e MaxDiffusion con JAX Stable Stack, quando
esegui lo script docker_build_dependency_image.sh
, imposta la variabile MODE
su stable_stack
e la variabile BASEIMAGE
sull'immagine di base che vuoi
utilizzare.
L'esempio seguente specifica us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
come immagine di base:
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
Per un elenco delle immagini di base JAX Stable Stack disponibili, consulta Immagini JAX Stable Stack in Artifact Registry.
Esegui il carico di lavoro utilizzando XPK
Imposta le seguenti variabili di ambiente se non utilizzi i valori predefiniti impostati da MaxText o MaxDiffusion:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
Crea lo script del modello. Questo script verrà copiato come comando di addestramento in un passaggio successivo.
Non eseguire ancora lo script del modello.
MaxText
MaxText è un LLM open source ad alte prestazioni e altamente scalabile scritto in Python e JAX puro e che ha come target Google Cloud TPU e GPU per l'addestramento e l'inferenza.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma è una famiglia di LLM con pesi aperti sviluppati da Google DeepMind, basata sulla ricerca e sulla tecnologia di Gemini.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral è un modello di IA all'avanguardia sviluppato da Mistral AI, che utilizza un'architettura sparse mixture-of-experts (MoE).
python3 MaxText/train.py MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama è una famiglia di LLM con pesi aperti sviluppati da Meta.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=llama3-8b \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ tokenizer_path=assets/tokenizer_llama3.tiktoken \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ attention=flash"
MaxDiffusion
MaxDiffusion è una raccolta di implementazioni di riferimento di vari modelli di diffusione latente scritti in puro Python e JAX che vengono eseguiti su dispositivi XLA, tra cui Cloud TPU e GPU. La diffusione stabile è un modello di testo a immagine latente che genera immagini fotorealistiche da qualsiasi input di testo.
Per eseguire MaxDiffusion, devi installare un ramo Git specifico come mostrato nel seguente comando
git checkout
.git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && pip install -r requirements.txt && pip install .
Script di addestramento:
cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \ python src/maxdiffusion/train_sdxl.py \ src/maxdiffusion/configs/base_xl.yml \ revision=refs/pr/95 \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ resolution=1024 \ per_device_batch_size=1 \ output_dir=${OUT_DIR} \ jax_cache_dir=${OUT_DIR}/cache_dir/ \ max_train_steps=200 \ attention=flash run_name=sdxl-ddp-v6e
Esegui il modello utilizzando lo script creato nel passaggio precedente. Devi specificare il flag
--base-docker-image
per utilizzare l'immagine di base MaxText o il flag--docker-image
e l'immagine che vuoi utilizzare.(Facoltativo) Puoi attivare la registrazione di log di debug includendo il flag
--enable-debug-logs
. Per ulteriori informazioni, consulta Eseguire il debug di JAX su MaxText.(Facoltativo) Puoi creare un esperimento Vertex AI per caricare i dati in Vertex AI TensorBoard includendo il flag
--use-vertex-tensorboard
. Per ulteriori informazioni, consulta Monitorare JAX su MaxText utilizzando Vertex AI.python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=$ACCELERATOR_TYPE \ --num-slices=$NUM_SLICES \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command $YOUR-MODEL-SCRIPT
Esporta le seguenti variabili:
export CLUSTER_NAME=CLUSTER_NAME: il nome del cluster XPK. export ACCELERATOR_TYPEACCELERATOR_TYPE: la versione e le dimensioni della TPU. Ad esempio:
v6e-256
. export NUM_SLICES=NUM_SLICES: il numero di slice TPU. export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: lo script del modello da eseguire come comando di addestramento.L'output include un link per monitorare il carico di lavoro, simile al seguente:
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
Apri il link e fai clic sulla scheda Log per monitorare il carico di lavoro in tempo reale.
Eseguire il debug di JAX su MaxText
Utilizza i comandi XPK supplementari per diagnosticare il motivo per cui il cluster o il carico di lavoro non è in esecuzione.
- Elenco dei carichi di lavoro XPK
- XPK inspector
- Attiva il logging dettagliato nei log del carico di lavoro utilizzando il flag
--enable-debug-logs
quando crei il carico di lavoro XPK.
Monitorare JAX su MaxText utilizzando Vertex AI
Visualizza i dati scalari e di profilo tramite TensorBoard gestito da Vertex AI.
- Aumenta le richieste di gestione delle risorse (CRUD) per la zona in uso da 600 a 5000. Questo potrebbe non essere un problema per i carichi di lavoro di piccole dimensioni che utilizzano meno di 16 VM.
Installa le dipendenze come
cloud-accelerator-diagnostics
per Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Crea il tuo cluster XPK utilizzando il flag
--create-vertex-tensorboard
, come documentato in Creare Vertex AI TensorBoard. Puoi anche eseguire questo comando sui cluster esistenti.Crea l'esperimento Vertex AI quando esegui il tuo carico di lavoro XPK utilizzando il flag
--use-vertex-tensorboard
e il flag facoltativo--experiment-name
. Per l'elenco completo dei passaggi, consulta Creare un esperimento Vertex AI per caricare i dati su Vertex AI TensorBoard.
I log includono un link a un Vertex AI TensorBoard, simile al seguente:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
Puoi trovare il link a Vertex AI TensorBoard anche nella console Google Cloud. Vai a Vertex AI Experiments nella console Google Cloud. Seleziona la regione appropriata dal menu a discesa.
La directory TensorBoard viene scritta anche nel bucket Cloud Storage specificato con ${BASE_OUTPUT_DIR}
.
Eliminare i workload XPK
Utilizza il comando xpk workload delete
per eliminare uno o più workload in base al prefisso o allo stato del job. Questo comando potrebbe essere utile se hai inviato carichi di lavoro XPK che non devono più essere eseguiti o se hai job bloccati nella coda.
Elimina il cluster XPK
Utilizza il comando xpk cluster delete
per eliminare un cluster:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone $ZONE --project $PROJECT_ID
Addestramento di Llama e PyTorch/XLA su una VM Cloud TPU v6e
Questo tutorial descrive come addestrare i modelli Llama utilizzando PyTorch/XLA su TPU v6e utilizzando il set di dati WikiText.
Accedi a Hugging Face e al modello Llama 3
Per eseguire questo tutorial, devi disporre di un token di accesso utente Hugging Face. Per informazioni sulla creazione e sui token di accesso utente, consulta la documentazione di Hugging Face sui token di accesso utente.
Devi anche disporre dell'autorizzazione per accedere al modello Llama 3 8B su Hugging Face. Per ottenere l'accesso, vai al modello Meta-Llama-3-8B su Hugging Face e richiedi l'accesso.
Crea una VM TPU
Crea una TPU v6e con 8 chip per eseguire il tutorial.
Configura le variabili di ambiente:
export ACCELERATOR_TYPE=v6e-8 export VERSION=v2-alpha-tpuv6e export TPU_NAME=$USER-$ACCELERATOR_TYPE export PROJECT=YOUR_PROJECT export ZONE=YOUR_ZONE
Crea una VM TPU:
gcloud alpha compute tpus tpu-vm create $TPU_NAME --version=$VERSION \ --accelerator-type=$ACCELERATOR_TYPE --zone=$ZONE --project=$PROJECT
Installazione
Installa il pytorch-tpu/transformers
fork di
Hugging Face Transformers e le dipendenze. Questo tutorial è stato testato con le seguenti versioni delle dipendenze utilizzate in questo esempio:
torch
: compatibile con 2.5.0torch_xla[tpu]
: compatibile con 2.5.0jax
: 0.4.33jaxlib
: 0.4.33
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT --zone $ZONE \ --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Configura le configurazioni del modello
Il comando di addestramento nella sezione successiva, Esegui il modello, utilizza due file di configurazione JSON per definire i parametri del modello e la configurazione FSDP (Fully Sharded Data Parallel). La suddivisione FSDP viene utilizzata per adattare i pesi del modello a un batch di dimensioni maggiori durante l'addestramento. Quando si esegue l'addestramento con modelli più piccoli, potrebbe essere sufficiente utilizzare il parallelismo dei dati e replicare i pesi su ogni dispositivo. Per ulteriori informazioni su come suddividere i tensori su più dispositivi in PyTorch/XLA, consulta la Guida dell'utente SPMD di PyTorch/XLA.
Crea il file di configurazione dei parametri del modello. Di seguito è riportata la configurazione del parametro del modello per Llama3-8B. Per altri modelli, trova la configurazione su Hugging Face. Ad esempio, consulta Llama2-7B config.
cat > llama-config.json <
{ "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF Crea il file di configurazione FSDP:
cat > fsdp-config.json <
{ "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF Per ulteriori informazioni su FSDP, consulta FSDPv2.
Carica i file di configurazione nelle VM TPU utilizzando il seguente comando:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \ --worker=all \ --project=$PROJECT \ --zone $ZONE
Esegui il modello
Utilizzando i file di configurazione creati nella sezione precedente, esegui lo script run_clm.py
per addestrare il modello Llama 3 8B sul set di dati WikiText. L'esecuzione dello script di addestramento su una TPU v6e-8 richiede circa 10 minuti.
Accedi a Hugging Face sulla tua TPU utilizzando il seguente comando:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
Esegui l'addestramento del modello:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
Risoluzione dei problemi relativi a PyTorch/XLA
Se imposti le variabili facoltative per il debug nella sezione precedente,
il profilo del modello verrà archiviato nella posizione specificata dalla
variabile PROFILE_LOGDIR
. Puoi estrarre il file xplane.pb
memorizzato in questa posizione e utilizzare tensorboard
per visualizzare i profili nel browser seguendo le istruzioni di TensorBoard. Se il rendimento di PyTorch/XLA non è quello previsto, consulta la guida alla risoluzione dei problemi, che contiene suggerimenti per il debug, il profiling e l'ottimizzazione del modello.
Addestramento di DLRM DCN v2 su v6e
Questo tutorial mostra come addestrare il modello DLRM DCN v2 su TPU v6e. Devi eseguire il provisioning di una TPU v6e con 64, 128 o 256 chip.
Se esegui l'operazione su più host, reimposta tpu-runtime
con la versione di TensorFlow appropriata eseguendo il seguente comando. Se esegui l'operazione su un singolo host, non è necessario eseguire i due comandi riportati di seguito.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID}
--zone ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Accedi tramite SSH a worker-0
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}
Imposta il nome della TPU
export TPU_NAME=${TPU_NAME}
Esegui DLRM v2
pip install --user setuptools==65.5.0
pip install cloud-tpu-client
pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'
TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'mixed_bfloat16'
task:
use_synthetic_data: false
use_tf_record_reader: true
train_data:
input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
global_batch_size: 16384
use_cached_data: true
validation_data:
input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
global_batch_size: 16384
use_cached_data: true
model:
num_dense_features: 13
bottom_mlp: [512, 256, 128]
embedding_dim: 128
interaction: 'multi_layer_dcn'
dcn_num_layers: 3
dcn_low_rank_dim: 512
size_threshold: 8000
top_mlp: [1024, 1024, 512, 256, 1]
use_multi_hot: true
concat_dense: false
dcn_use_bias: true
vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
max_ids_per_chip_per_sample: 128
max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
use_partial_tpu_embedding: false
size_threshold: 0
initialize_tables_on_host: true
trainer:
train_steps: 10000
validation_interval: 1000
validation_steps: 660
summary_interval: 1000
steps_per_loop: 1000
checkpoint_interval: 0
optimizer_config:
embedding_optimizer: 'Adagrad'
dense_optimizer: 'Adagrad'
lr_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.025
warmup_steps: 0
dense_sgd_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.00025
warmup_steps: 8000
train_tf_function: true
train_tf_while_loop: true
eval_tf_while_loop: true
use_orbit: true
pipeline_sparse_and_dense_execution: true"
Esegui script.sh
:
chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
I seguenti flag sono necessari per eseguire i carichi di lavoro dei consigli (DLRM DCN):
ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
Risultati del benchmarking
La sezione seguente contiene i risultati del benchmarking per DLRM DCN v2 e MaxDiffusion su v6e.
DLRM DCN v2
Lo script di addestramento DLRM DCN v2 è stato eseguito su scale diverse. Consulta le portate nella tabella seguente.
v6e-64 | v6e-128 | v6e-256 | |
Passaggi di addestramento | 7000 | 7000 | 7000 |
Dimensione del batch globale | 131072 | 262144 | 524288 |
Velocità effettiva (esempi/sec) | 2975334 | 5111808 | 10066329 |
MaxDiffusion
Abbiamo eseguito lo script di addestramento per MaxDiffusion su una v6e-4, una v6e-16 e una 2xv6e-16. Consulta le portate nella tabella seguente.
v6e-4 | v6e-16 | Due v6e-16 | |
Passaggi di addestramento | 0,069 | 0,073 | 0,13 |
Dimensione del batch globale | 8 | 32 | 64 |
Velocità effettiva (esempi/sec) | 115,9 | 438,4 | 492,3 |
Pianificazione della raccolta
Trillium (v6e) include una nuova funzionalità chiamata "pianificazione delle raccolte". Questa funzionalità offre un modo per gestire più slice TPU che eseguono un singolo caricamento di lavoro di inferenza su un host sia su GKE che sull'API Cloud TPU. Raggruppare questi slice in una raccolta semplifica la regolazione del numero di repliche in base alla domanda. Gli aggiornamenti software vengono controllati attentamente per garantire che una parte degli slice all'interno della raccolta sia sempre disponibile per gestire il traffico in entrata.
Per ulteriori informazioni sull'utilizzo della pianificazione delle raccolte con GKE, consulta la documentazione di GKE.
La funzionalità di pianificazione della raccolta si applica solo alla versione 6e.
Utilizzare la pianificazione delle raccolte dall'API Cloud TPU
Una raccolta a host singolo nell'API Cloud TPU è una risorsa in coda su cui è impostato un flag speciale (--workload-type = availability-optimized
) per indicare all'infrastruttura di base che deve essere utilizzata per l'erogazione dei carichi di lavoro.
Il seguente comando esegue il provisioning di una raccolta con un solo host utilizzando l'API Cloud TPU:
gcloud alpha compute tpus queued-resources create my-collection \ --project=$PROJECT_ID \ --zone=${ZONE} \ --accelerator-type $ACCELERATOR_TYPE \ --node-count ${NODE_COUNT} \ --workload-type=availability-optimized
Monitoraggio e profilazione
Cloud TPU v6e supporta il monitoraggio e la profilazione utilizzando gli stessi metodi delle generazioni precedenti di Cloud TPU. Per ulteriori informazioni sul monitoraggio, consulta Monitorare le VM TPU.