Usar modelos abertos do Gemma com o Dataflow

O Gemma é uma família de modelos abertos leves e de última geração, criados com base em pesquisas e tecnologias usadas para criar os modelos do Gemini. É possível usar modelos Gemma nos pipelines de inferência do Apache Beam. O termo peso aberto significa que os parâmetros pré-treinados de um modelo, ou pesos, são liberados. Detalhes como o conjunto de dados original, a arquitetura do modelo e o código de treinamento não são fornecidos.

Casos de uso

Você pode usar modelos Gemma com o Dataflow para análise de sentimento. Com o Dataflow e os modelos Gemma, é possível processar eventos, como avaliações de clientes, conforme eles chegam. Passe as revisões pelo modelo para analisá-las e gerar recomendações. Ao combinar o Gemma com o Apache Beam, é possível concluir esse fluxo de trabalho sem problemas.

Suporte e limitações

Os modelos abertos do Gemma são compatíveis com o Apache Beam e o Dataflow com os seguintes requisitos:

  • Disponível para pipelines de lote e streaming que usam o SDK do Apache Beam para Python na versão 2.46.0 e mais recentes.
  • Os jobs do Dataflow precisam usar o Runner v2.
  • Os jobs do Dataflow precisam usar GPUs. Para uma lista de tipos de GPU compatíveis com o Dataflow, consulte Disponibilidade. Os tipos de GPU T4 e L4 são recomendados.
  • É preciso fazer o download do modelo e salvá-lo no formato de arquivo .keras.
  • O gerenciador de modelos do TensorFlow é recomendado, mas não obrigatório.

Pré-requisitos

  • Acesse modelos Gemma por meio do Kaggle (link em inglês).
  • Preencha o formulário de consentimento e aceite os Termos e Condições.
  • Faça o download do modelo Gemma. Salve-o no formato de arquivo .keras em um local que seu job do Dataflow possa acessar, como um bucket do Cloud Storage. Ao especificar um valor para a variável de caminho do modelo, use o caminho para esse local de armazenamento.
  • Para executar o job no Dataflow, crie uma imagem de contêiner personalizada. Esta etapa possibilita a execução do pipeline com GPUs no serviço Dataflow.

Usar o Gemma no pipeline

Para usar um modelo Gemma no pipeline do Apache Beam, siga estas etapas.

  1. No código do Apache Beam, depois de importar as dependências do pipeline, inclua um caminho para o modelo salvo:

    model_path = "MODEL_PATH"
    

    Substitua MODEL_PATH pelo caminho em que você salvou o modelo salvo. Por exemplo, se você salvar o modelo em um bucket do Cloud Storage, o caminho terá o formato gs://STORAGE_PATH/FILENAME.keras.

  2. A implementação do Keras dos modelos Gemma tem um método generate() que gera texto com base em um comando. Para transmitir elementos ao método generate(), use uma função de inferência personalizada.

    def gemma_inference_function(model, batch, inference_args, model_id):
      vectorized_batch = np.stack(batch, axis=0)
      # The only inference_arg expected here is a max_length parameter to
      # determine how many words are included in the output.
      predictions = model.generate(vectorized_batch, **inference_args)
      return utils._convert_to_result(batch, predictions, model_id)
    
  3. Executar o pipeline especificando o caminho para o modelo treinado; Este exemplo usa um gerenciador de modelos do TensorFlow.

    class FormatOutput(beam.DoFn):
      def process(self, element, *args, **kwargs):
        yield "Input: {input}, Output: {output}".format(input=element.example, output=element.inference)
    
    # Instantiate a NumPy array of string prompts for the model.
    examples = np.array(["Tell me the sentiment of the phrase 'I like pizza': "])
    # Specify the model handler, providing a path and the custom inference function.
    model_handler = TFModelHandlerNumpy(model_path, inference_fn=gemma_inference_function)
    with beam.Pipeline() as p:
      _ = (p | beam.Create(examples) # Create a PCollection of the prompts.
             | RunInference(model_handler, inference_args={'max_length': 32}) # Send the prompts to the model and get responses.
             | beam.ParDo(FormatOutput()) # Format the output.
             | beam.Map(print) # Print the formatted output.
      )
    

A seguir