Usa modelos abiertos de Gemma con Dataflow

Gemma es una familia de modelos abiertos ligeros y de vanguardia creados a partir de la investigación y la tecnología que se usan para crear los modelos de Gemini. Puedes usar modelos de Gemma en tus canalizaciones de inferencia de Apache Beam. El término peso abierto significa que se lanzan los parámetros o pesos entrenados previamente de un modelo. No se proporcionan detalles como el conjunto de datos original, la arquitectura del modelo y el código de entrenamiento.

Casos de uso

Puedes usar modelos de Gemma con Dataflow para el análisis de opiniones. Con Dataflow y los modelos de Gemma, puedes procesar eventos, como las opiniones de los clientes, a medida que llegan. Ejecuta las revisiones a través del modelo para analizarlas y, luego, generar recomendaciones. Si combinas Gemma con Apache Beam, puedes completar este flujo de trabajo sin problemas.

Asistencia y limitaciones

Los modelos abiertos de Gemma son compatibles con Apache Beam y Dataflow con los siguientes requisitos:

  • Está disponible para canalizaciones por lotes y de transmisión que usan las versiones 2.46.0 y posteriores del SDK de Apache Beam para Python.
  • Los trabajos de Dataflow deben usar Runner v2.
  • Los trabajos de Dataflow deben usar GPUs. Para obtener una lista de los tipos de GPU compatibles con Dataflow, consulta Disponibilidad. Se recomiendan los tipos de GPU T4 y L4.
  • El modelo debe descargarse y guardarse en el formato de archivo .keras.
  • Se recomienda usar el controlador de modelos de TensorFlow, pero no es obligatorio.

Requisitos previos

  • Accede a los modelos de Gemma a través de Kaggle.
  • Completa el formulario de consentimiento y acepta los Términos y Condiciones.
  • Descarga el modelo de Gemma. Guárdalo en el formato de archivo .keras en una ubicación a la que pueda acceder el trabajo de Dataflow, como un bucket de Cloud Storage. Cuando especifiques un valor para la variable de ruta de acceso del modelo, usa la ruta a esta ubicación de almacenamiento.
  • Para ejecutar tu trabajo en Dataflow, crea una imagen de contenedor personalizada. Este paso permite ejecutar la canalización con GPU en el servicio de Dataflow.

Usa Gemma en tu canalización

Para usar un modelo de Gemma en tu canalización de Apache Beam, sigue estos pasos.

  1. En tu código de Apache Beam, después de importar tus dependencias de canalización, incluye una ruta de acceso a tu modelo guardado:

    model_path = "MODEL_PATH"
    

    Reemplaza MODEL_PATH por la ruta de acceso en la que guardaste el modelo descargado. Por ejemplo, si guardas tu modelo en un bucket de Cloud Storage, la ruta tiene el formato gs://STORAGE_PATH/FILENAME.keras.

  2. La implementación de Keras de los modelos de Gemma tiene un método generate() que genera texto basado en una instrucción. Para pasar elementos al método generate(), usa una función de inferencia personalizada.

    def gemma_inference_function(model, batch, inference_args, model_id):
      vectorized_batch = np.stack(batch, axis=0)
      # The only inference_arg expected here is a max_length parameter to
      # determine how many words are included in the output.
      predictions = model.generate(vectorized_batch, **inference_args)
      return utils._convert_to_result(batch, predictions, model_id)
    
  3. Ejecuta tu canalización y especifica la ruta de acceso al modelo entrenado. En este ejemplo, se usa un controlador de modelos de TensorFlow.

    class FormatOutput(beam.DoFn):
      def process(self, element, *args, **kwargs):
        yield "Input: {input}, Output: {output}".format(input=element.example, output=element.inference)
    
    # Instantiate a NumPy array of string prompts for the model.
    examples = np.array(["Tell me the sentiment of the phrase 'I like pizza': "])
    # Specify the model handler, providing a path and the custom inference function.
    model_handler = TFModelHandlerNumpy(model_path, inference_fn=gemma_inference_function)
    with beam.Pipeline() as p:
      _ = (p | beam.Create(examples) # Create a PCollection of the prompts.
             | RunInference(model_handler, inference_args={'max_length': 32}) # Send the prompts to the model and get responses.
             | beam.ParDo(FormatOutput()) # Format the output.
             | beam.Map(print) # Print the formatted output.
      )
    

¿Qué sigue?