monkeypatch.gemma4_loss_kwargs

Flip accepts_loss_kwargs to True on Gemma 4 (Unified) ForConditionalGeneration.

These classes inherit accepts_loss_kwargs = False from PaliGemma, whose loss filtered logits and labels by attention_mask. Gemma 4's loss is the stock ForCausalLMLoss with no such filtering. The inherited flag is therefore wrong: it makes the Trainer withhold num_items_in_batch, which mis-normalizes the loss under gradient accumulation. Apply the patch before Trainer.__init__ reads the flag.
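The normalization problem can be illustrated with a simplified model of the two reductions. This is a sketch only. The per-token loss values are made up, the functions are hypothetical, and the model ignores distributed training and other details of the real Trainer. It contrasts averaging each micro-batch and then dividing by the number of accumulation steps with dividing summed losses by the total supervised-token count (num_items_in_batch). Only the second reduction matches the loss of the full, unsplit batch.

```python
# Hypothetical per-token losses for two micro-batches accumulated into one step.
# Each value stands for one token whose label is not ignored.
micro_batches = [
    [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],  # 6 supervised tokens
    [0.5, 0.5],                      # 2 supervised tokens
]

def accumulate_mean_of_means(batches):
    """Simplified model of the path without num_items_in_batch: each
    micro-batch is mean-reduced, then divided by the accumulation steps."""
    steps = len(batches)
    return sum(sum(b) / len(b) / steps for b in batches)

def accumulate_token_normalized(batches):
    """Simplified model of the path with num_items_in_batch: each micro-batch
    loss is summed, then divided by the token count across all micro-batches."""
    num_items_in_batch = sum(len(b) for b in batches)
    return sum(sum(b) / num_items_in_batch for b in batches)

# Loss of the same tokens processed as one unsplit batch.
full_batch = [x for b in micro_batches for x in b]
reference = sum(full_batch) / len(full_batch)

print(accumulate_mean_of_means(micro_batches))     # 1.25
print(accumulate_token_normalized(micro_batches))  # 1.625
print(reference)                                   # 1.625
```

The short micro-batch gets as much weight as the long one under mean-of-means. That is why withholding num_items_in_batch skews the loss when micro-batches contain different numbers of supervised tokens.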

Functions

| Name | Description |
|------|-------------|
| patch_gemma4_accepts_loss_kwargs | Set accepts_loss_kwargs=True on Gemma 4 (Unified) ForConditionalGeneration. |

patch_gemma4_accepts_loss_kwargs

monkeypatch.gemma4_loss_kwargs.patch_gemma4_accepts_loss_kwargs()

Set accepts_loss_kwargs=True on Gemma 4 (Unified) ForConditionalGeneration.
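The general mechanism of a class-attribute patch like this can be sketched with stand-in classes. The class names, patch_accepts_loss_kwargs, and trainer_reads_flag below are hypothetical placeholders. They are not the real Gemma 4 or PaliGemma classes, this module's implementation, or the Trainer's code. The sketch shows two things. Assigning the attribute on the subclass shadows the inherited value without changing the parent. The value must also already be in place by the time the Trainer reads it.

```python
class PaliGemmaLikeForConditionalGeneration:
    # Hypothetical stand-in for the parent class, which opts out of loss kwargs.
    accepts_loss_kwargs = False

class Gemma4LikeForConditionalGeneration(PaliGemmaLikeForConditionalGeneration):
    # Hypothetical stand-in for a Gemma 4 class that inherits the flag as-is.
    pass

def patch_accepts_loss_kwargs(*classes):
    """Hypothetical patch: set the flag directly on each target class,
    shadowing the inherited False while leaving the parent untouched."""
    for cls in classes:
        cls.accepts_loss_kwargs = True

def trainer_reads_flag(model):
    """Hypothetical, simplified stand-in for the Trainer reading the flag."""
    return getattr(model, "accepts_loss_kwargs", False)

model = Gemma4LikeForConditionalGeneration()
before = trainer_reads_flag(model)  # inherited from the parent stand-in
patch_accepts_loss_kwargs(Gemma4LikeForConditionalGeneration)
after = trainer_reads_flag(model)   # subclass attribute now shadows the parent's
parent = trainer_reads_flag(PaliGemmaLikeForConditionalGeneration())

print(before, after, parent)  # False True False
```

Because the Trainer reads the flag during Trainer.__init__, any patch of this kind has to run before the Trainer is constructed. Patching afterwards would not change the value the Trainer already read.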