A step-by-step guide to fine-tuning MedGemma for breast tumor classification

Shir Meir Lador
Head of AI, Product DevRel
Artificial intelligence (AI) is revolutionizing healthcare, but how do you take a powerful, general-purpose AI model and teach it the specialized skills of a pathologist? This journey from prototype to production often begins in a notebook, which is exactly where we'll start.
In this guide, we'll take the crucial first step. We'll walk through the complete process of fine-tuning MedGemma, Google's Gemma 3-based family of open models for the medical community, to classify breast cancer histopathology images. We're using the full-precision (non-quantized) MedGemma model because that's what you'll need to get maximum performance on many clinical tasks. If you're concerned about compute costs, you can quantize and fine-tune by using MedGemma's pre-configured fine-tuning notebook instead.
To complete our first step, we'll use the Finetune Notebook. The notebook provides you with all of the code and a step-by-step explanation of the process, so it's the perfect environment for experimentation. I'll also share the key insights that I learned along the way, including a critical choice in data types that made all the difference.
After we've perfected our model in this prototyping phase, we'll be ready for the next step. In an upcoming post, we'll show you how to take this exact workflow and move it to a scalable, production-ready environment using Cloud Run jobs.
Setting the stage: Our goal, model, and data
Before we get to the code, let's set the stage. Our goal is to classify microscope images of breast tissue into one of eight categories: four benign (non-cancerous) and four malignant (cancerous). This type of classification represents one of many crucial tasks that pathologists perform in order to make an accurate diagnosis, and we have a great set of tools for the job.
We'll be using MedGemma, a powerful family of open models from Google that's built on the same research and technology that powers our Gemini models. What makes MedGemma special is that it isn't just a general model: it's been specifically tuned for the medical domain.
The MedGemma vision component, MedSigLIP, was pre-trained on a vast amount of de-identified medical imagery, including the exact type of histopathology slides that we're using. If you don't need MedGemma's text-generation capabilities, you can use MedSigLIP alone as a more cost-effective option for predictive tasks like image classification. There are multiple MedSigLIP tutorial notebooks that you can use for fine-tuning.
The MedGemma language component was also trained on a diverse set of medical texts, making the google/medgemma-4b-it version that we're using perfect for following our text-based prompts. Google provides MedGemma as a strong foundation, but it requires fine-tuning for specific use cases—which is exactly what we're about to do.
To train our model, we'll use the Breast Cancer Histopathological Image Classification (BreakHis) dataset. BreakHis is a public collection of thousands of microscope images of breast tumor tissue, collected from 82 patients at four magnification factors (40X, 100X, 200X, and 400X). The dataset is publicly available for non-commercial research, and it's described in the paper by F. A. Spanhol, L. S. Oliveira, C. Petitjean, and L. Heutte, "A dataset for breast cancer histopathological image classification."¹
Handling a 4-billion-parameter model requires a capable GPU, so I used an NVIDIA A100 with 40 GB of VRAM on Vertex AI Workbench. This GPU has the necessary power, and its NVIDIA Tensor Cores excel with modern data formats such as bfloat16, which we'll leverage for faster training. In an upcoming post, we'll explain how to calculate the VRAM that's required for your fine-tuning.
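Until then, here's a rough back-of-the-envelope sketch: the weights alone need the parameter count times the bytes per parameter. It deliberately ignores activations, LoRA gradients and optimizer state, and framework overhead, all of which add to the total, and it isn't necessarily the method the upcoming post will use. The helper name is illustrative.

```python
# Bytes used to store a single parameter in each data type.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Memory for the model weights alone, in GB (1 GB = 10**9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


# A nominal 4-billion-parameter model in two formats.
for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: {weight_memory_gb(4e9, dtype):.0f} GB")
```

Halving the bytes per parameter halves the weight footprint (roughly 16 GB versus 8 GB here), which leaves more of a 40 GB card for activations and training state.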
My float16 disaster: A crucial lesson in stability
My first attempt to load the model used the common float16 data type to save memory. It failed spectacularly. The model's outputs were complete garbage, and a quick debugging check revealed that every internal value had collapsed into NaN (Not a Number).
The culprit was a classic numerical overflow.
To understand why, you need to know the critical difference between these 16-bit formats:
- float16 (FP16): Has a narrow numerical range: the largest value it can represent is 65,504. During the millions of calculations in a transformer, intermediate values can easily exceed this limit and overflow to infinity, and later operations on infinities (such as inf - inf) produce NaN. When a NaN appears, it contaminates every subsequent calculation.
- bfloat16 (BF16): This format, developed at Google Brain, makes a crucial trade-off. It keeps float32's 8 exponent bits, and therefore roughly the same numerical range as the full 32-bit format, but sacrifices precision by storing only 7 mantissa bits instead of float16's 10.
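To see the difference concretely, here's a small standard-library sketch that round-trips one value through each format. Python has no built-in bfloat16 type, so `to_bfloat16` emulates it by keeping the upper 16 bits of the float32 encoding (real hardware rounds to nearest rather than truncating). Both helpers are illustrative, not part of any library.

```python
import math
import struct


def to_float16(x: float) -> float:
    """Round-trip x through IEEE 754 half precision (FP16)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses values outside FP16's range; GPU arithmetic instead
        # overflows silently to +/-inf, which is what we return here.
        return math.copysign(math.inf, x)


def to_bfloat16(x: float) -> float:
    """Approximate bfloat16 by keeping the top 16 bits of the float32 encoding.

    Truncation is a simplification: hardware rounds to nearest even.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


activation = 70_000.0  # an intermediate value just past FP16's limit
fp16 = to_float16(activation)
bf16 = to_bfloat16(activation)
print("FP16:", fp16)                   # overflows to inf
print("BF16:", bf16)                   # stays finite, slightly rounded
print("FP16 inf - inf:", fp16 - fp16)  # nan, which then spreads
```

The BF16 value loses a little precision (69632 instead of 70000), but it stays a usable number, while the FP16 path turns into inf and then NaN.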
bfloat16's wide range prevents these overflows, which keeps the training process stable. The fix was a simple one-line change, but it rests on this critical concept.
The successful code loads the model with torch_dtype=torch.bfloat16 instead of torch.float16 (see Step 4 below).
Lesson learned: For fine-tuning large models, always prefer bfloat16 for its stability. It's a small change that saves you from a world of NaN-related headaches.
The code walkthrough: A step-by-step guide
Now, let's get to the code. I'll break down my Finetune Notebook into clear, logical steps.
Step 1: Setup and installations
First, you need to install the necessary libraries from the Hugging Face ecosystem and log into your account to download the model.
Hugging Face authentication and the recommended approach to handling your secrets
⚠️ Important security note: You should never hardcode secrets like API keys or tokens directly into your code or notebooks, especially in a production environment. This practice is insecure and it creates a significant security risk.
In Vertex AI Workbench, the most secure and enterprise-grade approach to handling secrets (like your Hugging Face token) is to use Google Cloud's Secret Manager.
If you're just experimenting and you don't want to set up Secret Manager yet, you can use the interactive login widget. The widget saves the token temporarily in the instance's file system.
In our upcoming post, where we move this process to Cloud Run jobs, we'll show you the correct and secure way to handle this token by using Secret Manager.
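Until then, here's a minimal sketch of the idea, assuming the token lives either in an `HF_TOKEN` environment variable or in a Secret Manager secret you've already created. The helper name `get_hf_token` and the `PROJECT_ID`/`SECRET_ID` parts of the secret path are placeholders, and the upcoming post's approach may differ.

```python
import os
from typing import Optional


def get_hf_token(secret_version: Optional[str] = None) -> str:
    """Fetch a Hugging Face token without hardcoding it in the notebook.

    1. Use the HF_TOKEN environment variable if it's set.
    2. Otherwise, if a full Secret Manager version name is given
       (for example "projects/PROJECT_ID/secrets/SECRET_ID/versions/latest"),
       read the token from Google Cloud Secret Manager.
    """
    token = os.environ.get("HF_TOKEN")
    if token:
        return token
    if secret_version:
        # Needs the google-cloud-secret-manager package and permission to
        # access the secret (for example, the Secret Manager Secret Accessor role).
        from google.cloud import secretmanager

        client = secretmanager.SecretManagerServiceClient()
        response = client.access_secret_version(request={"name": secret_version})
        return response.payload.data.decode("UTF-8")
    raise RuntimeError("No Hugging Face token found: set HF_TOKEN or pass a secret version name.")
```

You could then authenticate with `huggingface_hub.login(token=get_hf_token())`, so the token never appears in the notebook itself.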
Step 2: Load and prepare the dataset
Next, we download the BreakHis dataset from Kaggle using the kagglehub library. This dataset includes a Folds.csv file, which outlines how the data is split for experiments. The original study used 5-fold cross-validation, but to keep the training time manageable for this demonstration, we'll focus on Fold 1 and we'll only use images with 100X magnification. You can explore using other folds and magnifications for more extensive experiments.
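Here's a minimal standard-library sketch of that filtering. It assumes Folds.csv has `fold`, `mag`, `grp`, and `filename` columns with numeric `fold` and `mag` values, and it derives the eight-class label from the BreakHis filename convention (`SOB_<B|M>_<type>-...`, where `A`, `F`, `PT`, and `TA` are the benign types and `DC`, `LC`, `MC`, and `PC` are the malignant ones). Check both assumptions against your download. The sample rows are made up for illustration.

```python
import csv
import io
import os

# Tumor-type codes used in BreakHis filenames (four benign, four malignant).
TUMOR_TYPES = {
    "A": "adenosis",
    "F": "fibroadenoma",
    "PT": "phyllodes_tumor",
    "TA": "tubular_adenoma",
    "DC": "ductal_carcinoma",
    "LC": "lobular_carcinoma",
    "MC": "mucinous_carcinoma",
    "PC": "papillary_carcinoma",
}


def tumor_type(path: str) -> str:
    """Map e.g. 'SOB_B_A-14-22549AB-100-001.png' to 'adenosis'."""
    prefix = os.path.basename(path).split("-")[0]  # 'SOB_B_A'
    code = prefix.split("_")[2]                     # 'A'
    return TUMOR_TYPES[code]


def select_split(rows, fold: int = 1, mag: int = 100, grp: str = "train"):
    """Keep rows from one fold, magnification, and split, and attach the label."""
    return [
        {"path": row["filename"], "label": tumor_type(row["filename"])}
        for row in rows
        if int(row["fold"]) == fold and int(row["mag"]) == mag and row["grp"] == grp
    ]


# Made-up rows in the assumed Folds.csv layout.
SAMPLE = """fold,mag,grp,filename
1,100,train,benign/SOB_B_A-14-22549AB-100-001.png
1,40,train,benign/SOB_B_A-14-22549AB-40-001.png
1,100,test,malignant/SOB_M_DC-14-2523-100-001.png
2,100,train,malignant/SOB_M_LC-14-12204-100-001.png
"""
rows = list(csv.DictReader(io.StringIO(SAMPLE)))
train = select_split(rows, grp="train")
test = select_split(rows, grp="test")
print(train)
print(test)
```

With the real file, you'd replace the `StringIO` with `open(path_to_folds_csv, newline="")`, where `path_to_folds_csv` is wherever kagglehub placed the download.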
Step 2.1: Balance the dataset
The initial train and test splits for the 100X magnification show an imbalance between benign and malignant classes. To address this, we'll undersample the majority class in both the training and testing sets in order to create balanced datasets with a 50/50 distribution.
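One way to implement random undersampling with the standard library, assuming each example carries a field you can group on (here a hypothetical `malignant` flag). The notebook's own implementation may differ. Applying it separately to the training and test sets gives each a 50/50 split.

```python
import random
from collections import defaultdict


def undersample(examples, group_of, seed=42):
    """Randomly drop examples from larger groups until all groups match the smallest one."""
    groups = defaultdict(list)
    for example in examples:
        groups[group_of(example)].append(example)
    target = min(len(members) for members in groups.values())
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    balanced = []
    for members in groups.values():
        balanced.extend(rng.sample(members, target))
    rng.shuffle(balanced)  # avoid long runs of one class
    return balanced


# Hypothetical split: 30 malignant vs. 10 benign examples.
train = [{"id": i, "malignant": i < 30} for i in range(40)]
balanced_train = undersample(train, group_of=lambda ex: ex["malignant"])
print(len(balanced_train), sum(ex["malignant"] for ex in balanced_train))  # 20 10
```

Undersampling throws data away, but with a binary grouping it keeps every minority-class example and avoids duplicating images, as oversampling would.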
Step 2.2: Create a Hugging Face dataset
We're converting our data into the Hugging Face datasets format because it's the easiest way to work with the SFTTrainer from Hugging Face's TRL library. This format is optimized for handling large datasets, especially images, because it can load them efficiently when needed. And it gives us handy tools for preprocessing, like applying our formatting function to all examples.
Step 3: Prompt engineering
This step is where we tell the model what we want it to do. We create a clear, structured prompt that instructs the model to analyze an image and to return only the number that corresponds to a class. This prompt makes the output simple and easy to parse. We then map this format across our entire dataset.
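Here's a sketch of such a prompt and of the matching output parser. The wording, the class order (0-3 benign, 4-7 malignant), and the chat-message layout (an image placeholder followed by text, as Gemma 3-style processors typically expect) are illustrative assumptions; the notebook's own prompt may differ.

```python
import re
from typing import Optional

# Illustrative class order: 0-3 benign, 4-7 malignant.
CLASSES = [
    "adenosis (benign)",
    "fibroadenoma (benign)",
    "phyllodes tumor (benign)",
    "tubular adenoma (benign)",
    "ductal carcinoma (malignant)",
    "lobular carcinoma (malignant)",
    "mucinous carcinoma (malignant)",
    "papillary carcinoma (malignant)",
]

PROMPT = (
    "Analyze this breast tissue histopathology image and classify it.\n"
    + "\n".join(f"{i}: {name}" for i, name in enumerate(CLASSES))
    + "\nRespond with only the number of the correct class."
)


def format_example(example: dict) -> dict:
    """Build chat messages: the user turn holds the image and prompt,
    and the assistant turn holds the target class number."""
    return {
        "messages": [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": PROMPT}]},
            {"role": "assistant", "content": [{"type": "text", "text": str(example["label"])}]},
        ]
    }


def parse_prediction(text: str) -> Optional[int]:
    """Return the first integer in the model's reply if it's a valid class index."""
    match = re.search(r"\d+", text)
    if match is None:
        return None
    value = int(match.group())
    return value if 0 <= value < len(CLASSES) else None
```

With a Hugging Face dataset, `dataset.map(format_example)` adds the `messages` column to every example while keeping the image column. Asking for a bare number keeps `parse_prediction` trivial and makes malformed replies easy to count.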
Step 4: Load the model and processor
Here, we load the MedGemma model and its associated processor. The processor is a handy tool that prepares both the images and text for the model. We'll also make two key parameter choices, one for stability and one for speed:
- torch_dtype=torch.bfloat16: As we mentioned earlier, this format ensures numerical stability.
- attn_implementation="sdpa": Scaled dot product attention (SDPA) is a highly optimized attention implementation that's built into PyTorch 2.0 and later. Think of this setting as telling the model to use a fast, built-in engine for its most important calculation. It speeds up training and inference, and it can automatically dispatch to more advanced backends like FlashAttention if your hardware supports them.
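For intuition, here's the computation that `torch.nn.functional.scaled_dot_product_attention` performs, written out in plain Python for a single head with no masking. Fused backends such as FlashAttention compute the same result but avoid materializing the full score matrix, which is where the speed and memory savings come from. This loop version is only a readable reference, not something you'd train with.

```python
import math


def softmax(xs):
    """Numerically stable softmax: subtracting the max keeps exp() from overflowing."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q, k, v):
    """softmax(Q K^T / sqrt(d)) V for lists of query, key, and value vectors."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        weights = softmax(scores)  # how much each position contributes
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


# Two identical keys get equal weight, so the output averages the two values.
print(scaled_dot_product_attention([[1.0, 0.0]], [[0.5, 0.5], [0.5, 0.5]], [[0.0, 0.0], [2.0, 4.0]]))
```

The 1/sqrt(d) scaling and the max-subtraction in the softmax both exist to keep intermediate values in a safe range, the same concern that motivated switching to bfloat16.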
Step 5: Evaluate the baseline model
Before we invest time and compute in fine-tuning, let's see how the pre-trained model performs on its own. This step gives us a baseline to measure our improvement against.
The performance of the baseline model was evaluated on both 8-class and binary (benign/malignant) classification:
- 8-Class accuracy: 32.6%
- 8-Class F1 score (weighted): 0.241
- Binary accuracy: 59.6%
- Binary F1 score (malignant): 0.639
This output shows that the model performs better than random guessing (12.5% accuracy for eight classes and 50% for the binary task), but there's significant room for improvement, especially in the fine-grained 8-class classification.
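The notebook presumably computes these numbers with a library such as scikit-learn (`accuracy_score` and `f1_score` with `average="weighted"`); the dependency-free sketch below just spells out what each metric means. It assumes every reply has already been parsed into a class index and that indices 0-3 are benign and 4-7 malignant, which is an illustrative ordering.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def f1_for_class(y_true, y_pred, cls):
    """F1 for one class treated as positive (0.0 when there are no true positives)."""
    tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
    fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
    fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def weighted_f1(y_true, y_pred):
    """Per-class F1 averaged with weights equal to each class's share of y_true."""
    support = Counter(y_true)
    return sum(f1_for_class(y_true, y_pred, c) * n for c, n in support.items()) / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every label seen in y_true or y_pred."""
    labels = set(y_true) | set(y_pred)
    return sum(f1_for_class(y_true, y_pred, c) for c in labels) / len(labels)


def to_binary(labels, first_malignant=4):
    """Collapse 8-class indices to 0 = benign, 1 = malignant."""
    return [int(y >= first_malignant) for y in labels]


y_true = [0, 3, 4, 7, 5, 1]
y_pred = [0, 4, 4, 6, 5, 1]
print(accuracy(y_true, y_pred), weighted_f1(y_true, y_pred))
print(f1_for_class(to_binary(y_true), to_binary(y_pred), cls=1))  # binary F1 for "malignant"
```

Weighted F1 lets frequent classes count more, while macro F1 treats every class equally, so the two can differ noticeably when the eight subtypes are unevenly represented.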
A quick detour: Few-shot learning vs. fine-tuning
Before we start training, it's worth asking: is fine-tuning the only way? Another popular technique is few-shot learning.
Few-shot learning is like giving a smart student a few examples of a new math problem right before a test. You aren't re-teaching them algebra; you're just showing them the specific pattern you want them to follow by providing examples directly in the prompt. This is a powerful technique, especially when you're using a closed model through an API where you can't access the internal weights.
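To make the contrast concrete, here's a sketch of what a few-shot prompt could look like: labeled example images placed ahead of the query in a single chat turn. The message layout mirrors the chat format commonly used with Gemma 3-style processors, the function name and image file names are placeholders, and this isn't the approach the notebook uses.

```python
def build_few_shot_messages(examples, query_image, instruction):
    """Put labeled (image, answer) pairs ahead of the query in a single user turn.

    The model's weights never change; it only sees the pattern in the prompt.
    """
    content = [{"type": "text", "text": instruction}]
    for image, answer in examples:
        content.append({"type": "image", "image": image})
        content.append({"type": "text", "text": f"Answer: {answer}"})
    content.append({"type": "image", "image": query_image})
    content.append({"type": "text", "text": "Answer:"})
    return [{"role": "user", "content": content}]


messages = build_few_shot_messages(
    examples=[("example_benign.png", 0), ("example_malignant.png", 4)],
    query_image="query.png",
    instruction="Classify each breast tissue image. Reply with the class number only.",
)
```

Every example adds image tokens to every request, so the prompt grows with each shot, while fine-tuning pays that cost once during training.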
So why did we choose fine-tuning?
- We can host the model: Because MedGemma is an open model, we have direct access to its architecture and weights. This access lets us perform fine-tuning to create a new, permanently updated version of the model.
- We have a good dataset: Fine-tuning lets the model learn the deep, underlying patterns in our hundreds of training images far more effectively than just showing it a few examples in a prompt.
In short, fine-tuning creates a true specialist model for our task, which is exactly what we want.
Step 6: Configure and run fine-tuning with LoRA
This is the main event! We'll use Low-Rank Adaptation (LoRA), which is much faster and more memory-efficient than full fine-tuning. LoRA freezes the original model weights and trains only a small set of new low-rank adapter weights. Here's a breakdown of our parameter choices:
- r=8: The LoRA rank. A lower rank means fewer trainable parameters, which is faster but less expressive. A higher rank has more capacity, but it risks overfitting on a small dataset. Rank 8 is a great starting point that balances performance and efficiency.
- lora_alpha=16: A scaling factor for the LoRA weights; the adapter's output is multiplied by lora_alpha / r. A common rule of thumb is to set it to twice the rank (2 × r).
- lora_dropout=0.1: A regularization technique. During training, it randomly zeroes 10% of the inputs to the LoRA adapter layers, which discourages the model from becoming overly specialized and failing to generalize.
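Here's a minimal, dependency-free sketch of a LoRA-wrapped linear layer that uses the parameters above. It follows the original LoRA formulation, h = W x + (alpha / r) * B A x, with B initialized to zero so training starts exactly from the pre-trained weights. The class, its small random initialization for A, and the helper names are illustrative, not PEFT's implementation; in PEFT, the same choices are expressed as `LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, ...)`.

```python
import random


def matvec(m, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]


class LoRALinear:
    """Frozen linear layer W plus a trainable low-rank update (alpha / r) * B @ A."""

    def __init__(self, weight, r=8, alpha=16, dropout=0.1, seed=0):
        self.rng = random.Random(seed)
        d_out, d_in = len(weight), len(weight[0])
        self.weight = weight  # frozen, d_out x d_in
        self.A = [[self.rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(r)]  # trainable, r x d_in
        self.B = [[0.0] * r for _ in range(d_out)]  # trainable, d_out x r; zeros so the update starts at 0
        self.scale = alpha / r  # 16 / 8 = 2.0 with the defaults
        self.dropout = dropout

    def forward(self, x, training=False):
        base = matvec(self.weight, x)
        if training and self.dropout > 0:
            # Inverted dropout on the adapter's input only; the frozen path sees the full x.
            keep = 1.0 - self.dropout
            x = [xi / keep if self.rng.random() < keep else 0.0 for xi in x]
        update = matvec(self.B, matvec(self.A, x))
        return [b + self.scale * u for b, u in zip(base, update)]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_out x d_in matrix."""
    return r * (d_in + d_out)


layer = LoRALinear([[1.0, 2.0], [3.0, 4.0]])
print(layer.forward([1.0, -1.0]))  # equals the frozen layer's output, W @ x, until B is trained
```

For a hypothetical 2560 x 2560 projection, `lora_param_count(2560, 2560, 8)` gives 40,960 trainable parameters versus 6,553,600 in the frozen matrix, about 0.6%, which is why LoRA needs so much less memory for gradients and optimizer state.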
The training took about 80 minutes on the 40 GB A100 GPU. The results looked promising, with the validation loss decreasing steadily.
Important (time-saving!) tip: If your training gets interrupted for any reason (like a connection issue or exceeding resource limits), you can resume from a saved checkpoint by passing the resume_from_checkpoint argument to trainer.train(). Checkpoints are saved every save_steps steps, as defined in TrainingArguments, so an interruption only costs you the progress made since the last save.
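The Hugging Face Trainer writes each checkpoint to a `checkpoint-<global step>` folder inside `output_dir`, and passing `resume_from_checkpoint=True` makes it pick the latest one itself (Transformers also ships `transformers.trainer_utils.get_last_checkpoint` for this). The standard-library sketch below, with a hypothetical `latest_checkpoint` helper, just shows the logic, including why the step number must be compared numerically rather than as text.

```python
import os
import re


def latest_checkpoint(output_dir):
    """Return the newest "checkpoint-<step>" directory in output_dir, or None."""
    pattern = re.compile(r"^checkpoint-(\d+)$")
    candidates = []
    for name in os.listdir(output_dir):
        match = pattern.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            # Compare steps as integers: "checkpoint-1000" must beat "checkpoint-900".
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    return os.path.join(output_dir, max(candidates)[1])
```

A call such as `trainer.train(resume_from_checkpoint=latest_checkpoint(training_args.output_dir))` then resumes when a checkpoint exists and starts fresh (the argument is `None`) when it doesn't; `training_args` stands for your own `TrainingArguments` object.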
Step 7: The final verdict - evaluating our fine-tuned model
After training, it's time for the moment of truth. We'll load our new LoRA adapter weights, merge them with the base model, and then run the same evaluation that we ran for the baseline.
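With PEFT, this step is typically `PeftModel.from_pretrained(base_model, adapter_path)` followed by `merge_and_unload()`, where `adapter_path` is a placeholder for your saved adapter. The arithmetic behind merging is small enough to show directly: dropout is off at inference, so the adapter's contribution is a fixed matrix (alpha / r) * B A that can be added to W once. The merged model then produces the same outputs with no extra adapter computation. The tiny matrices below are made up for illustration.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def matvec(m, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def merge_lora(weight, A, B, alpha, r):
    """Fold a LoRA adapter into the frozen weight: W' = W + (alpha / r) * B @ A."""
    scale = alpha / r
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(weight, delta)]


W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]  # frozen base weight (2 x 3)
A = [[1.0, 2.0, 3.0]]                    # rank-1 adapter "down" matrix (1 x 3)
B = [[0.5], [-1.0]]                      # rank-1 adapter "up" matrix (2 x 1)
x = [1.0, 1.0, 1.0]

merged = merge_lora(W, A, B, alpha=2, r=1)
unmerged = [w + 2.0 * u for w, u in zip(matvec(W, x), matvec(B, matvec(A, x)))]
print(matvec(merged, x), unmerged)  # identical outputs, no separate adapter step needed
```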
Final results
So, how did fine-tuning affect performance? Let's compare the baseline and fine-tuned results for accuracy and F1.
The results are great! After fine-tuning, we see a dramatic improvement:
- 8-Class: Accuracy jumped from 32.6% to 87.2% (+54.6 percentage points) and F1 from 0.241 to 0.865.
- Binary: Accuracy increased from 59.6% to 99.0% (+39.4 percentage points) and F1 from 0.639 to 0.991.
This project shows the incredible power of fine-tuning modern foundation models. We took a generalist AI that was already pre-trained on relevant medical data, gave it a small, specialized dataset, and taught it a new skill with remarkable efficiency. The journey from a generic model to a specialized classifier is more accessible than ever, opening up exciting possibilities for AI in medicine and beyond.
All of the code and explanations are available in the Finetune Notebook. You can run it on a GPU instance in Vertex AI Workbench.
Want to take it to production? Don't miss the upcoming post, which shows you how to bring the fine-tuning and evaluation to Cloud Run jobs.
I hope this guide was helpful. Happy coding!
Special thanks to Fereshteh Mahvar and Dave Steiner from the MedGemma team for their helpful review and feedback on this post.
¹ IEEE Transactions on Biomedical Engineering, vol. 63, no. 7, pp. 1455-1462, 2016.


