monkeypatch.transformers.trainer_loss_calc

Module for patching transformers Trainer loss calculation to use nanmean.

This is needed for context parallelism: a chunk of an input sequence may be fully masked, in which case its loss evaluates to NaN, and a plain mean over the chunk losses would then be NaN as well.
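The difference between the two reductions can be illustrated with a minimal sketch. It uses plain Python rather than the module's actual code, which is not shown here; the helper names `mean` and `nanmean` and the sample values in `chunk_losses` are hypothetical. In PyTorch, the corresponding reduction is `torch.nanmean`.

```python
import math


def mean(values):
    # Ordinary mean: a single NaN poisons the whole result.
    return sum(values) / len(values)


def nanmean(values):
    # Mean over the non-NaN entries only; NaN if nothing is left.
    finite = [v for v in values if not math.isnan(v)]
    return sum(finite) / len(finite) if finite else math.nan


# Hypothetical per-chunk losses; the middle chunk is fully masked.
chunk_losses = [0.5, math.nan, 1.5]

plain = mean(chunk_losses)      # NaN, because one chunk was NaN
robust = nanmean(chunk_losses)  # averages only 0.5 and 1.5
```

Skipping NaN entries means a fully masked chunk contributes nothing to the reported loss instead of invalidating it.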

The module also includes a patch for FSDP2 + torch.compile. It is bundled with the other evaluation_loop patch because the same code cannot be patched twice without raising an OSError.
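One plausible reason for the OSError, assuming the patches work by fetching a method's source text, editing it, and re-executing it: a function created with `exec` has no source file, so a later `inspect.getsource` call on it fails. The sketch below demonstrates only that standard-library behavior. The names `ORIGINAL_SOURCE`, `reduce_name`, `build_from_source`, and the `<patched-demo>` filename are hypothetical and are not taken from this module.

```python
import inspect

# Hypothetical stand-in for the source text of a method to be patched.
ORIGINAL_SOURCE = 'def reduce_name():\n    return "mean"\n'


def build_from_source(source, name):
    # Compile and execute the source, then pull the function out of
    # the namespace. The filename is synthetic, so no file backs it.
    namespace = {"__name__": "patched_demo"}
    exec(compile(source, "<patched-demo>", "exec"), namespace)
    return namespace[name]


def can_get_source(fn):
    # inspect.getsource raises OSError when the source is unavailable.
    try:
        inspect.getsource(fn)
        return True
    except OSError:
        return False


# First patch: edit the text and rebuild the function.
patched_source = ORIGINAL_SOURCE.replace('"mean"', '"nanmean"')
patched = build_from_source(patched_source, "reduce_name")

# A second source-based patch would need the source again, which fails.
source_available = can_get_source(patched)
```

Under that assumption, applying every edit to the source text in a single pass avoids the second lookup entirely.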

Functions

Name                             Description
patch_evaluation_loop            Patch the evaluation_loop method.
patch_maybe_log_save_evaluate    Patch the _maybe_log_save_evaluate method.

patch_evaluation_loop

monkeypatch.transformers.trainer_loss_calc.patch_evaluation_loop()

Patch the evaluation_loop method.

patch_maybe_log_save_evaluate

monkeypatch.transformers.trainer_loss_calc.patch_maybe_log_save_evaluate()

Patch the _maybe_log_save_evaluate method.