Entrena Resnet50 en Cloud TPU con PyTorch


En este instructivo, se muestra cómo entrenar el modelo ResNet-50 en un dispositivo Cloud TPU con PyTorch. Puedes aplicar el mismo patrón a otros modelos de clasificación de imágenes optimizados con TPU que usen PyTorch y el conjunto de datos ImageNet.

El modelo de este instructivo se basa en Aprendizaje residual profundo para el reconocimiento de imágenes, que presentó por primera vez la arquitectura de la red residual (ResNet). En el instructivo, se usa la variante de 50 capas, ResNet-50, y se muestra cómo entrenar el modelo con PyTorch/XLA.

Objetivos

  • Preparar el conjunto de datos
  • Ejecutar el trabajo de entrenamiento
  • Verificar los resultados de salida

Costos

En este documento, usarás los siguientes componentes facturables de Google Cloud:

  • Compute Engine
  • Cloud TPU

Para generar una estimación de costos en función del uso previsto, usa la calculadora de precios. Es posible que los usuarios nuevos de Google Cloud califiquen para obtener una prueba gratuita.

Antes de comenzar

Antes de comenzar este instructivo, verifica que tu proyecto de Google Cloud esté configurado correctamente.

  1. Accede a tu cuenta de Google Cloud. Si eres nuevo en Google Cloud, crea una cuenta para evaluar el rendimiento de nuestros productos en situaciones reales. Los clientes nuevos también obtienen $300 en créditos gratuitos para ejecutar, probar y, además, implementar cargas de trabajo.
  2. En la página del selector de proyectos de la consola de Google Cloud, selecciona o crea un proyecto de Google Cloud.

    Ir al selector de proyectos

  3. Asegúrate de que la facturación esté habilitada para tu proyecto de Google Cloud.

  4. En la página del selector de proyectos de la consola de Google Cloud, selecciona o crea un proyecto de Google Cloud.

    Ir al selector de proyectos

  5. Asegúrate de que la facturación esté habilitada para tu proyecto de Google Cloud.

  6. En esta explicación, se usan componentes facturables de Google Cloud. Consulta la página de precios de Cloud TPU para calcular los costos. Asegúrate de limpiar los recursos que crees cuando hayas terminado de usarlos para evitar cargos innecesarios.

Crea una VM de TPU

  1. Abre una ventana de Cloud Shell.

    Abra Cloud Shell

  2. Crea una VM de TPU

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central2-b \
    --project=your-project
    
  3. Conéctate a la VM de TPU con SSH:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central2-b
    
  4. Instala PyTorch/XLA en la VM de TPU:

    (vm)$ pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
    
  5. Clona el repositorio de GitHub PyTorch/XLA

    (vm)$ git clone --depth=1 --branch r2.2 https://github.com/pytorch/xla.git
    
  6. Ejecuta la secuencia de comandos de entrenamiento con datos falsos

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
    

Si puedes entrenar el modelo con datos falsos, puedes intentar entrenarlo con datos reales, como ImageNet. Si deseas obtener instrucciones para descargar ImageNet, consulta Cómo descargar ImageNet. En la secuencia de comandos de entrenamiento, la marca --datadir especifica la ubicación del conjunto de datos en el que se realizará el entrenamiento. En el siguiente comando, se supone que el conjunto de datos de ImageNet se encuentra en ~/imagenet.

   (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py  --datadir=~/imagenet --batch_size=256 --num_epochs=1
   

Limpia

Para evitar que se apliquen cargos a tu cuenta de Google Cloud por los recursos usados en este instructivo, borra el proyecto que contiene los recursos o conserva el proyecto y borra los recursos individuales.

  1. Desconéctate de la VM de TPU:

    (vm) $ exit
    

    El mensaje ahora debería mostrar username@projectname, que indica que estás en Cloud Shell.

  2. Borra la VM de TPU.

    $ gcloud compute tpus tpu-vm delete resnet50-tutorial \
       --zone=us-central2-b
    

¿Qué sigue?

Prueba los siguientes colaboradores de PyTorch: