monkeypatch.loss.eaft

EAFT (entropy-aware focal training) loss implementation. Weights examples by an entropy approximation computed from the top-k logits.

Reference: https://github.com/ymxyll/LlamaFactory-EAFT/blob/e2ce19e8efcc226450ee8f2b81dfe4e69f1f945d/src/llamafactory/train/trainer_utils.py
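The top-k entropy approximation mentioned above can be illustrated with a small pure-Python sketch. This is not the module's code. It assumes that the k largest logits are renormalized with a softmax over only those k values and that the rest of the vocabulary is ignored. `topk_entropy` is a hypothetical helper name.

```python
import math


def topk_entropy(logits, k=20):
    """Approximate the entropy of softmax(logits) using only the k largest logits.

    Assumption (illustrative only): the top-k logits are renormalized with a
    softmax over just those k values; all other logits are dropped.
    """
    top = sorted(logits, reverse=True)[:k]
    m = top[0]  # subtract the max logit for numerical stability
    exps = [math.exp(x - m) for x in top]
    z = sum(exps)
    probs = [e / z for e in exps]
    # Skip probabilities that underflowed to zero so log() is always defined
    return -sum(p * math.log(p) for p in probs if p > 0)
```

Uniform logits give the maximum value, log(k). A single dominant logit drives the approximation toward 0, so confident predictions receive a small entropy score.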

Functions

Name        Description
eaft_loss   Compute the EAFT loss with entropy weighting.

eaft_loss

monkeypatch.loss.eaft.eaft_loss(
    outputs,
    labels,
    num_items_in_batch=None,
    alpha=1.0,
    k=20,
)

Compute the EAFT loss with entropy weighting.

Parameters

Name                 Description                                                  Default
outputs              Model outputs containing the logits.                         required
labels               Target labels used to compute the loss.                      required
num_items_in_batch   Number of items in the batch, used for sample packing support.  None
alpha                Exponent applied to the entropy weight.                      1.0
k                    Number of top logits used to approximate the entropy.        20
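The parameters above could interact as in the following pure-Python sketch. It is a hedged illustration under several assumptions, not the library's implementation:

* Each token's cross-entropy is multiplied by (H_k / log k) ** alpha, where H_k is the top-k entropy.
* Labels equal to -100 are ignored.
* Labels are already aligned with the logits, so no causal shift is applied.
* The summed loss is divided by `num_items_in_batch` when it is given. Otherwise it is divided by the number of non-ignored tokens.

`eaft_loss_sketch`, `_topk_entropy` and `IGNORE_INDEX` are hypothetical names. The real `eaft_loss` takes a model outputs object rather than plain lists.

```python
import math

IGNORE_INDEX = -100  # assumption: Hugging Face-style index for masked labels


def _topk_entropy(logits, k):
    # Entropy of a softmax restricted to the k largest logits (assumed approximation)
    top = sorted(logits, reverse=True)[:k]
    m = top[0]
    exps = [math.exp(x - m) for x in top]
    z = sum(exps)
    return -sum((e / z) * math.log(e / z) for e in exps if e > 0)


def eaft_loss_sketch(logits, labels, num_items_in_batch=None, alpha=1.0, k=20):
    """Entropy-weighted cross-entropy over a flat list of token positions.

    logits: list of per-position logit vectors (lists of floats)
    labels: list of target ids aligned with logits (no causal shift here)
    """
    total = 0.0
    count = 0
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue
        # Standard cross-entropy: logsumexp(row) - row[label]
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        ce = lse - row[label]
        # Normalize the top-k entropy to [0, 1] by log(k) (assumed), then raise to alpha
        kk = min(k, len(row))
        weight = (_topk_entropy(row, kk) / math.log(kk)) ** alpha if kk > 1 else 1.0
        total += weight * ce
        count += 1
    # Assumed normalization: prefer num_items_in_batch, else the local token count
    denom = num_items_in_batch if num_items_in_batch is not None else count
    return total / denom if denom else 0.0
```

In this sketch, setting `alpha=0.0` makes every weight 1 and recovers plain cross-entropy. Larger `alpha` shrinks the contribution of low-entropy (confident) tokens further, because the normalized entropy lies in [0, 1].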