Usar modelos abiertos de Gemma con Dataflow

Gemma es una familia de modelos abiertos, ligeros y de vanguardia creados a partir de la investigación y la tecnología que se utilizaron para crear los modelos Gemini. Puedes usar modelos Gemma en tus pipelines de inferencia de Apache Beam. El término peso abierto significa que se han publicado los parámetros o pesos preentrenados de un modelo. No se proporcionan detalles como el conjunto de datos original, la arquitectura del modelo y el código de entrenamiento.

Casos prácticos

Puedes usar modelos Gemma con Dataflow para hacer análisis de sentimiento. Con Dataflow y los modelos de Gemma, puedes procesar eventos, como reseñas de clientes, a medida que se producen. Analiza las reseñas con el modelo y, a continuación, genera recomendaciones. Si combinas Gemma con Apache Beam, podrás completar este flujo de trabajo sin problemas.

Compatibilidad y limitaciones

Los modelos abiertos de Gemma son compatibles con Apache Beam y Dataflow, y cumplen los siguientes requisitos:

  • Disponible para flujos de procesamiento por lotes y de streaming que usen las versiones 2.46.0 y posteriores del SDK de Apache Beam para Python.
  • Las tareas de Dataflow deben usar Runner v2.
  • Las tareas de Dataflow deben usar GPUs. Para ver una lista de los tipos de GPU compatibles con Dataflow, consulta la sección Disponibilidad. Se recomiendan los tipos de GPU T4 y L4.
  • El modelo debe descargarse y guardarse en formato de archivo .keras.
  • Se recomienda usar el gestor de modelos de TensorFlow, pero no es obligatorio.

Requisitos previos

  • Accede a los modelos de Gemma a través de Kaggle.
  • Rellena el formulario de consentimiento y acepta los términos y condiciones.
  • Descarga el modelo de Gemma. Guárdalo en formato de archivo .keras en una ubicación a la que pueda acceder tu trabajo de Dataflow, como un segmento de Cloud Storage. Cuando especifiques un valor para la variable de ruta del modelo, usa la ruta a esta ubicación de almacenamiento.
  • Para ejecutar tu trabajo en Dataflow, crea una imagen de contenedor personalizada. Este paso permite ejecutar la canalización con GPUs en el servicio Dataflow.

Usar Gemma en tu flujo de trabajo

Para usar un modelo de Gemma en tu canalización de Apache Beam, sigue estos pasos.

  1. En tu código de Apache Beam, después de importar las dependencias de tu canalización, incluye una ruta a tu modelo guardado:

    model_path = "MODEL_PATH"
    

    Sustituye MODEL_PATH por la ruta donde guardaste el modelo descargado. Por ejemplo, si guardas tu modelo en un segmento de Cloud Storage, la ruta tendrá el formato gs://STORAGE_PATH/FILENAME.keras.

  2. La implementación de Keras de los modelos Gemma tiene un método generate() que genera texto a partir de una petición. Para transferir elementos al método generate(), usa una función de inferencia personalizada.

    def gemma_inference_function(model, batch, inference_args, model_id):
      vectorized_batch = np.stack(batch, axis=0)
      # The only inference_arg expected here is a max_length parameter to
      # determine how many words are included in the output.
      predictions = model.generate(vectorized_batch, **inference_args)
      return utils._convert_to_result(batch, predictions, model_id)
    
  3. Ejecuta tu flujo de trabajo especificando la ruta al modelo entrenado. En este ejemplo se usa un controlador de modelos de TensorFlow.

    class FormatOutput(beam.DoFn):
      def process(self, element, *args, **kwargs):
        yield "Input: {input}, Output: {output}".format(input=element.example, output=element.inference)
    
    # Instantiate a NumPy array of string prompts for the model.
    examples = np.array(["Tell me the sentiment of the phrase 'I like pizza': "])
    # Specify the model handler, providing a path and the custom inference function.
    model_handler = TFModelHandlerNumpy(model_path, inference_fn=gemma_inference_function)
    with beam.Pipeline() as p:
      _ = (p | beam.Create(examples) # Create a PCollection of the prompts.
             | RunInference(model_handler, inference_args={'max_length': 32}) # Send the prompts to the model and get responses.
             | beam.ParDo(FormatOutput()) # Format the output.
             | beam.Map(print) # Print the formatted output.
      )
    

Siguientes pasos