monkeypatch.gemma4_loss_kwargs
Flip accepts_loss_kwargs to True on Gemma 4 (Unified) ForConditionalGeneration.
These classes inherit accepts_loss_kwargs = False from PaliGemma, whose loss
filtered logits and labels by attention_mask. Gemma 4’s loss is the stock
ForCausalLMLoss with no such filtering, so the inherited flag is wrong for it:
the Trainer withholds num_items_in_batch and the loss is mis-normalized under
gradient accumulation.
Apply the patch before Trainer.__init__ runs, because that is when the flag is read.
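The normalization problem can be illustrated with plain arithmetic. The sketch below is a simplified model rather than the Trainer's actual code: it assumes that without num_items_in_batch each micro-batch loss is a per-token mean that is then averaged over the accumulation steps, and that with it the summed token losses are divided once by the total token count. The per-token loss values are made up.

```python
# Toy per-token losses for two micro-batches of unequal length
# (made-up numbers, purely illustrative).
micro_batches = [
    [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],  # 6 target tokens
    [8.0, 8.0],                      # 2 target tokens
]
grad_accum_steps = len(micro_batches)

# Without num_items_in_batch (assumed behavior): each micro-batch loss is a
# per-token mean, and the accumulated loss is the mean of those means.
mean_of_means = sum(sum(mb) / len(mb) for mb in micro_batches) / grad_accum_steps

# With num_items_in_batch (assumed behavior): sum every token loss and divide
# once by the number of target tokens in the whole accumulated batch.
num_items_in_batch = sum(len(mb) for mb in micro_batches)
token_weighted = sum(sum(mb) for mb in micro_batches) / num_items_in_batch

print(mean_of_means)   # 5.0
print(token_weighted)  # 3.5
```

The two results differ whenever micro-batches contain different numbers of target tokens; in this toy case the short micro-batch is over-weighted when the total token count is not used.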
Functions
| Name | Description |
|---|---|
| patch_gemma4_accepts_loss_kwargs | Set accepts_loss_kwargs=True on Gemma 4 (Unified) ForConditionalGeneration. |
patch_gemma4_accepts_loss_kwargs
monkeypatch.gemma4_loss_kwargs.patch_gemma4_accepts_loss_kwargs()
Set accepts_loss_kwargs=True on Gemma 4 (Unified) ForConditionalGeneration.
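The general shape of such a patch can be sketched with stand-in classes. This is not the module's implementation: PaliGemmaLike, Gemma4Like, and patch_accepts_loss_kwargs are hypothetical names used only for illustration, and the real patch targets the actual model classes. The sketch shows a class-level override that leaves the parent untouched and is safe to call more than once.

```python
class PaliGemmaLike:
    """Hypothetical stand-in for the parent class that sets the flag to False."""
    accepts_loss_kwargs = False


class Gemma4Like(PaliGemmaLike):
    """Hypothetical stand-in for a Gemma 4 ForConditionalGeneration class."""


def patch_accepts_loss_kwargs(cls):
    """Set the flag on cls itself; return True only if something changed."""
    # Look in cls.__dict__ so an inherited value does not count as "already patched".
    if cls.__dict__.get("accepts_loss_kwargs") is True:
        return False
    cls.accepts_loss_kwargs = True
    return True


changed_first = patch_accepts_loss_kwargs(Gemma4Like)   # applies the override
changed_second = patch_accepts_loss_kwargs(Gemma4Like)  # no-op on repeat calls
```

Setting the attribute on the subclass rather than on the parent keeps the parent's behavior unchanged, which matters if the parent's loss still depends on the flag being False.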