将 Gemma 开放模型与 Dataflow 搭配使用

Gemma 是一系列先进的轻量级开放模型,基于用于创建 Gemini 模型的研究和技术构建而成。您可以在 Apache Beam 推理流水线中使用 Gemma 模型。术语“开放权重”表示模型的预训练参数(即权重)会被释放。不提供原始数据集、模型架构和训练代码等详细信息。

应用场景

您可以将 Gemma 模型与 Dataflow 配合使用,以进行情感分析。借助 Dataflow 和 Gemma 模型,您可以在事件(例如客户评价)到达时处理事件。通过模型运行评价以进行分析,然后生成建议。通过将 Gemma 与 Apache Beam 结合使用,您可以无缝完成此工作流。

支持和限制

Apache Beam 和 Dataflow 支持 Gemma 开放模型,不过具有以下要求:

  • 适用于使用 Apache Beam Python SDK 2.46.0 版及更高版本的批处理和流处理流水线。
  • Dataflow 作业必须使用 Runner v2
  • Dataflow 作业必须使用 GPU。如需查看 Dataflow 支持的 GPU 类型列表,请参阅可用性。建议使用 T4 和 L4 GPU 类型。
  • 模型必须以 .keras 文件格式进行下载并保存。
  • 建议使用 TensorFlow 模型处理程序,但这不是必需的。

前提条件

  • 通过 Kaggle 访问 Gemma 模型。
  • 填写同意书并接受条款及条件。
  • 下载 Gemma 模型。以 .keras 文件格式将其保存在 Dataflow 作业可以访问的位置,例如 Cloud Storage 存储桶。为模型路径变量指定值时,请使用此存储位置的路径。
  • 如需在 Dataflow 上运行作业,请创建自定义容器映像。此步骤让您可以在 Dataflow 服务上使用 GPU 运行流水线。

在流水线中使用 Gemma

如需在 Apache Beam 流水线中使用 Gemma 模型,请按以下步骤操作。

  1. 在 Apache Beam 代码中,导入流水线依赖项后,请添加指向已保存的模型的路径:

    model_path = "MODEL_PATH"
    

    MODEL_PATH 替换为您用于保存下载的模型的路径。例如,如果您将模型保存到 Cloud Storage 存储桶,则路径的格式为 gs://STORAGE_PATH/FILENAME.keras

  2. Gemma 模型的 Keras 实现具有一个 generate() 方法,该方法基于提示生成文本。如需将元素传递给 generate() 方法,请使用自定义推理函数。

    def gemma_inference_function(model, batch, inference_args, model_id):
      vectorized_batch = np.stack(batch, axis=0)
      # The only inference_arg expected here is a max_length parameter to
      # determine how many words are included in the output.
      predictions = model.generate(vectorized_batch, **inference_args)
      return utils._convert_to_result(batch, predictions, model_id)
    
  3. 运行流水线,并指定经过训练的模型的路径。此示例使用 TensorFlow 模型处理程序。

    class FormatOutput(beam.DoFn):
      def process(self, element, *args, **kwargs):
        yield "Input: {input}, Output: {output}".format(input=element.example, output=element.inference)
    
    # Instantiate a NumPy array of string prompts for the model.
    examples = np.array(["Tell me the sentiment of the phrase 'I like pizza': "])
    # Specify the model handler, providing a path and the custom inference function.
    model_handler = TFModelHandlerNumpy(model_path, inference_fn=gemma_inference_function)
    with beam.Pipeline() as p:
      _ = (p | beam.Create(examples) # Create a PCollection of the prompts.
             | RunInference(model_handler, inference_args={'max_length': 32}) # Send the prompts to the model and get responses.
             | beam.ParDo(FormatOutput()) # Format the output.
             | beam.Map(print) # Print the formatted output.
      )
    

后续步骤