Run inference with a Gemma open model


Gemma is a family of lightweight, state-of-the-art open models built from research and technology used to create the Gemini models. You can use Gemma models in your Apache Beam inference pipelines with the RunInference transform.

This notebook demonstrates how to load the preconfigured Gemma 2B model and then use it in your Apache Beam inference pipeline. The pipeline runs examples by using a custom model handler that calls the model's generate method.

For more information about using RunInference, see Get started with AI/ML pipelines in the Apache Beam documentation.


Serving and using Gemma models requires a substantial amount of RAM. To run this example, we recommend that you use a notebook instance with GPUs. At a minimum, use a machine that has the T4 GPU type. This configuration provides sufficient memory for running inference with a saved model.

Before you begin

  • To use a fine-tuned version of the model, follow the steps in Gemma fine-tuning.
  • For testing this workflow, we recommend using the instruction-tuned model in your Apache Beam workflow. For example, if you use the Gemma 2B model in your pipeline, change the GemmaCausalLM.from_preset() argument from gemma_2b_en to gemma_instruct_2b_en when you load the model. For more information, see Create a model in "Get started with Gemma using KerasNLP". For a list of models, see Gemma models.

Install dependencies

To use the RunInference transform with the built-in TensorFlow model handler, install Apache Beam version 2.46.0 or later. The model class is contained in the Keras natural language processing (NLP) package versions 0.8.0 and later.

!pip install -q -U protobuf
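# Quote the requirement specifiers so that the shell does not treat
# ">" as an output redirection or "[...]" as a glob pattern.
!pip install -q -U "apache_beam[gcp]"
!pip install -q -U "keras_nlp>=0.8.0"
!pip install -q -U "keras>3"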

# To use the newly installed versions, restart the runtime.
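
After the runtime restarts, it can help to confirm that the installed packages meet the minimum versions named above. The following is a minimal sketch that uses only the Python standard library. The helper names (parse_version, meets_minimum, check_installed) are hypothetical and not part of Apache Beam or Keras, and the comparison looks only at the leading numeric parts of each version string.

```python
from importlib import metadata


def parse_version(text):
    """Return the leading numeric components of a version string as a list."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return parts


def meets_minimum(installed, minimum):
    """Compare two version strings after padding them to the same length."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += [0] * (width - len(a))
    b += [0] * (width - len(b))
    return a >= b


# Minimum versions taken from the text above (assumption: distribution names
# on PyPI are apache-beam, keras-nlp, and keras).
MINIMUMS = {"apache-beam": "2.46.0", "keras-nlp": "0.8.0", "keras": "3"}


def check_installed(minimums=MINIMUMS):
    """Return a short status string for each required package."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
            continue
        ok = meets_minimum(installed, minimum)
        report[name] = f"ok ({installed})" if ok else f"too old ({installed})"
    return report


print(check_installed())
```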

Authenticate with Kaggle

The pipeline defined here automatically pulls the model weights from Kaggle. First, accept the terms of use for Gemma models on the Keras Gemma page. Next, generate an API token by following the instructions in How to use Kaggle. Provide your username and token.

import kagglehub

# Prompt for your Kaggle username and API token.
kagglehub.login()

Kaggle credentials set.
Kaggle credentials successfully validated.
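
If an interactive login prompt is not practical, for example when the code runs outside a notebook, credentials can usually be supplied through environment variables instead. The sketch below assumes that kagglehub reads KAGGLE_USERNAME and KAGGLE_KEY from the environment, as the Kaggle documentation describes. The helper name set_kaggle_credentials and the placeholder values are hypothetical.

```python
import os


def set_kaggle_credentials(username, key, environ=os.environ):
    """Store Kaggle credentials in the variables that kagglehub is assumed to read."""
    if not username or not key:
        raise ValueError("Both a Kaggle username and an API key are required.")
    environ["KAGGLE_USERNAME"] = username
    environ["KAGGLE_KEY"] = key


# Demonstrate with a plain dictionary so that no real credentials are touched.
# In a real session, omit the third argument to write to os.environ.
demo_env = {}
set_kaggle_credentials("your-username", "your-api-key", demo_env)
print(sorted(demo_env))
```

Avoid committing real tokens to source control. Read them from a secret store or from the notebook's secrets feature where one is available.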

Import dependencies and provide a model preset

Use the following code to import dependencies.

Replace the value for the model_preset variable with the name of the Gemma preset to use. For example, to use the default English weights, use the value gemma_2b_en. This example uses the instruction-tuned preset gemma_instruct_2b_en. Optionally, to run the model at half precision and reduce GPU memory usage, set the default Keras float type to bfloat16 by using keras.config.set_floatx.

import numpy as np

import apache_beam as beam
import keras_nlp
import keras
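from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy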
from apache_beam.options.pipeline_options import PipelineOptions

model_preset = "gemma_instruct_2b_en"
# Optionally set the model to run at half-precision
# (recommended for smaller GPUs)
keras.config.set_floatx("bfloat16")
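
To see why half precision matters on a smaller GPU, a rough estimate of the memory needed for the model weights alone can be sketched as follows. The parameter count of about 2.5 billion for Gemma 2B is an approximation used here as an assumption, and the estimate ignores activations, the key-value cache, and framework overhead, so real usage is higher. The function name weight_memory_gib is hypothetical.

```python
# Bytes used to store one parameter in each floating-point format.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}

# Assumption: approximate parameter count for the Gemma 2B model.
APPROX_GEMMA_2B_PARAMS = 2.5e9


def weight_memory_gib(num_params, dtype):
    """Estimate the memory for model weights only, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


for dtype in ("float32", "bfloat16"):
    gib = weight_memory_gib(APPROX_GEMMA_2B_PARAMS, dtype)
    print(f"{dtype}: roughly {gib:.1f} GiB for weights")
```

Under these assumptions, bfloat16 halves the weight footprint compared with float32, which leaves more headroom on a 16 GB T4.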

Run the pipeline

To run the pipeline, use a custom model handler.

Provide a custom model handler

To simplify model loading, this notebook defines a custom model handler that loads the model by pulling the model weights directly from Kaggle presets. To customize the behavior of the handler, implement load_model, validate_inference_args, and share_model_across_processes. The Keras implementation of the Gemma models has a generate method that generates text based on a prompt. To route the prompts properly, use this function in the run_inference method.

# To load the model and perform the inference, define `GemmaModelHandler`.

from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence
from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM

class GemmaModelHandler(ModelHandler[str, PredictionResult, GemmaCausalLM]):
    def __init__(
        self,
        model_name: str = "gemma_2b_en",
    ):
        """ Implementation of the ModelHandler interface for Gemma using text as input.

        Example Usage::

          pcoll | RunInference(GemmaModelHandler())

        Args:
          model_name: The Gemma model preset. Default is gemma_2b_en.
        """
        self._model_name = model_name
        self._env_vars = {}

    def share_model_across_processes(self) -> bool:
        """Loads one copy of the model per worker and shares it across processes."""
        return True

    def load_model(self) -> GemmaCausalLM:
        """Loads and initializes a model for processing."""
        return keras_nlp.models.GemmaCausalLM.from_preset(self._model_name)

    def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
        """Validates the inference arguments. Only max_length is accepted."""
        # Nothing to validate when no arguments are passed.
        if not inference_args:
            return
        for key in inference_args:
            if key != "max_length":
                raise ValueError(f"Invalid inference argument: {key}")

    def run_inference(
        self,
        batch: Sequence[str],
        model: GemmaCausalLM,
        inference_args: Optional[Dict[str, Any]] = None
    ) -> Iterable[PredictionResult]:
        """Runs inferences on a batch of text strings.

        Args:
          batch: A sequence of examples as text strings.
          model: The loaded Gemma model.
          inference_args: Any additional arguments for an inference.

        Returns:
          An Iterable of type PredictionResult.
        """
        # Fall back to an empty dictionary when no arguments are passed.
        inference_args = inference_args or {}
        # Loop over each text string, and use a list to store the inference results.
        predictions = []
        for one_text in batch:
            result = model.generate(one_text, **inference_args)
            predictions.append(result)
        return utils._convert_to_result(batch, predictions, self._model_name)
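
The handler's validation and inference logic can be exercised without downloading Gemma or installing Apache Beam. The following dependency-free sketch mirrors the argument check and the generate loop. StubModel and StubResult are hypothetical stand-ins for GemmaCausalLM and PredictionResult, and the stub truncates by characters only to keep the example small. It does not reproduce how the real model counts length.

```python
from typing import Any, Dict, List, NamedTuple, Optional, Sequence


class StubResult(NamedTuple):
    """Hypothetical stand-in for Beam's PredictionResult."""
    example: str
    inference: str


class StubModel:
    """Hypothetical stand-in for GemmaCausalLM. It echoes the prompt."""

    def generate(self, prompt: str, max_length: int = 64) -> str:
        # The response starts with the prompt, as described for Gemma.
        return (prompt + "<generated text>")[:max_length]


def validate_inference_args(inference_args: Optional[Dict[str, Any]]) -> None:
    """Reject any inference argument other than max_length."""
    for key in (inference_args or {}):
        if key != "max_length":
            raise ValueError(f"Invalid inference argument: {key}")


def run_inference(
    batch: Sequence[str],
    model: StubModel,
    inference_args: Optional[Dict[str, Any]] = None,
) -> List[StubResult]:
    """Generate one response per prompt and pair it with its input."""
    validate_inference_args(inference_args)
    kwargs = inference_args or {}
    predictions = [model.generate(text, **kwargs) for text in batch]
    return [StubResult(ex, pred) for ex, pred in zip(batch, predictions)]


results = run_inference(["Hello: "], StubModel(), {"max_length": 12})
print(results)
```

A stub like this is convenient for unit tests of the surrounding pipeline code, because the test does not have to download the model.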

Execute the pipeline

Use the following code to run the pipeline. The code passes the model preset name to the custom model handler. This cell can take a few minutes to run, because the model is downloaded and then loaded onto the worker. This delay is a one-time cost per worker.

The max_length argument determines how long the response from Gemma is. The response includes your input, so the response length includes your input and the output. For longer prompts, use a larger maximum length. Longer lengths require more time to generate.

class FormatOutput(beam.DoFn):
  def process(self, element, *args, **kwargs):
    yield "Input: {input}, Output: {output}".format(input=element.example, output=element.inference)

# Instantiate a NumPy array of string prompts for the model.
examples = np.array(["Tell me the sentiment of the phrase 'I like pizza': "])
# Specify the custom model handler, providing the model preset name.
model_handler = GemmaModelHandler(model_preset)
with beam.Pipeline() as p:
  _ = (p | beam.Create(examples) # Create a PCollection of the prompts.
         | RunInference(model_handler, inference_args={'max_length': 32}) # Send the prompts to the model and get responses.
         | beam.ParDo(FormatOutput()) # Format the output.
         | beam.Map(print) # Print the formatted output.
      )
Input: Tell me the sentiment of the phrase 'I like pizza': , Output: Tell me the sentiment of the phrase 'I like pizza': 

The sentiment of the phrase "I like pizza" is positive. It expresses a personal
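
Because the response repeats the prompt, you might want to remove the echoed input before displaying or storing the result. The following is a minimal sketch of one way to do that. The helper name strip_prompt is hypothetical, and the sample response is an illustrative string modeled on the output above rather than real model output.

```python
def strip_prompt(prompt: str, response: str) -> str:
    """Remove the echoed prompt from the start of a response, if present."""
    if response.startswith(prompt):
        return response[len(prompt):].lstrip()
    return response


prompt = "Tell me the sentiment of the phrase 'I like pizza': "
# Illustrative response text, not captured model output.
response = prompt + "\n\nThe sentiment of the phrase \"I like pizza\" is positive."

print(strip_prompt(prompt, response))
```

In the pipeline, a call such as strip_prompt(element.example, element.inference) could replace element.inference inside FormatOutput.process.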