monkeypatch.loss.chunked

Chunked cross-entropy (CE) loss.

Classes

Name Description
CEWithChunkedOutputLoss Cross-entropy loss over chunked outputs; saves memory by upcasting only one chunk of logits to fp32 at a time.

CEWithChunkedOutputLoss

monkeypatch.loss.chunked.CEWithChunkedOutputLoss(
    num_output_chunks=8,
    ignore_index=-100,
)

Cross-entropy loss over chunked outputs. It saves memory by upcasting only one chunk of logits to fp32 at a time, rather than the full logits tensor.

For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
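To give a feel for the saving, the arithmetic below estimates the size of the fp32 copy of the logits with and without chunking. The shapes (batch of 4, 8192 tokens, vocabulary of 128256) and the helper name `fp32_logits_bytes` are hypothetical and chosen only for illustration; real peak memory also depends on autograd and on whether the loss is compiled.

```python
def fp32_logits_bytes(batch_size, num_tokens, vocab_size, num_chunks=1):
    """Bytes needed for an fp32 copy of one chunk of logits (hypothetical helper)."""
    tokens_per_chunk = num_tokens // num_chunks  # assumes num_tokens divides evenly
    return batch_size * tokens_per_chunk * vocab_size * 4  # 4 bytes per fp32 value


# Hypothetical shapes, for illustration only.
full = fp32_logits_bytes(4, 8192, 128256)                 # upcast everything at once
chunked = fp32_logits_bytes(4, 8192, 128256, num_chunks=8)  # upcast one chunk at a time

print(f"full upcast:    {full / 2**30:.2f} GiB")
print(f"one of 8 chunks: {chunked / 2**30:.2f} GiB")
```

With the default `num_output_chunks=8`, the transient fp32 copy in this estimate is one eighth the size of a full upcast.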

Methods

Name Description
compute_cross_entropy Upcast logits to fp32 and compute cross entropy loss.
forward

compute_cross_entropy
monkeypatch.loss.chunked.CEWithChunkedOutputLoss.compute_cross_entropy(
    logits,
    labels,
    normalize=True,
)

Upcast logits to fp32 and compute cross entropy loss.
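A minimal sketch of the upcast-then-compute step for a single, already flattened chunk. This is not the library's code: the function name `upcast_cross_entropy` is hypothetical, `reduction="sum"` is an assumption, and the sketch does not model the `normalize` argument, whose effect is not described above.

```python
import torch
import torch.nn.functional as F


def upcast_cross_entropy(logits, labels, ignore_index=-100):
    """Hypothetical stand-in: logits (N, vocab_size), labels (N,)."""
    # .float() creates an fp32 copy of this chunk only; the caller's
    # low-precision tensor is left untouched.
    return F.cross_entropy(
        logits.float(),
        labels,
        ignore_index=ignore_index,
        reduction="sum",  # assumption: per-chunk losses are summed by the caller
    )


# One chunk of 6 token positions over a 10-word vocabulary, stored in bf16.
chunk_logits = torch.randn(6, 10).to(torch.bfloat16)
chunk_labels = torch.tensor([1, 4, -100, 7, 0, 9])  # -100 marks an ignored position

loss = upcast_cross_entropy(chunk_logits, chunk_labels)
print(loss.dtype)  # the loss is computed in fp32 even though the input is bf16
```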

forward
monkeypatch.loss.chunked.CEWithChunkedOutputLoss.forward(
    logits,
    labels,
    reduction='sum',
)
Parameters
Name Type Description Default
logits List[torch.Tensor] List of chunked logits of length self.num_output_chunks, where each chunk has shape (batch_size, num_tokens / num_output_chunks, vocab_size). required
labels torch.Tensor Ground truth labels of shape (batch_size, num_tokens). required
reduction str The reduction to apply to the output. 'sum'
Returns
Name Type Description
torch.Tensor Cross-entropy loss of shape (1,).
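The sketch below illustrates how inputs of the documented shapes could be prepared and consumed. It is a standalone approximation, not the class itself: `chunked_forward_sketch` is a hypothetical name, and its handling of `reduction` (returning the summed loss for `'sum'`, otherwise dividing by the number of non-ignored tokens) is an assumption. It also assumes `num_tokens` is divisible by the number of chunks.

```python
from typing import List

import torch
import torch.nn.functional as F


def chunked_forward_sketch(
    logits: List[torch.Tensor],
    labels: torch.Tensor,
    ignore_index: int = -100,
    reduction: str = "sum",
) -> torch.Tensor:
    """Hypothetical sketch of a chunked cross-entropy forward pass."""
    # Split the labels along the token dimension to line up with the logit chunks.
    label_chunks = labels.chunk(len(logits), dim=1)
    total = torch.zeros((), dtype=torch.float32, device=labels.device)
    for logit_chunk, label_chunk in zip(logits, label_chunks):
        # Only the current chunk is upcast to fp32.
        total = total + F.cross_entropy(
            logit_chunk.float().reshape(-1, logit_chunk.size(-1)),
            label_chunk.reshape(-1),
            ignore_index=ignore_index,
            reduction="sum",
        )
    if reduction == "sum":
        return total
    # Assumption: any other reduction averages over non-ignored tokens.
    return total / (labels != ignore_index).sum()


# Usage with small, made-up shapes.
batch_size, num_tokens, vocab_size, num_output_chunks = 2, 16, 32, 8
full_logits = torch.randn(batch_size, num_tokens, vocab_size).to(torch.bfloat16)
labels = torch.randint(0, vocab_size, (batch_size, num_tokens))
labels[0, :3] = -100  # e.g. masked prompt tokens

# Each chunk has shape (batch_size, num_tokens / num_output_chunks, vocab_size).
logit_chunks = list(full_logits.chunk(num_output_chunks, dim=1))
loss_sum = chunked_forward_sketch(logit_chunks, labels)
loss_mean = chunked_forward_sketch(logit_chunks, labels, reduction="mean")
```

In this sketch the chunked result matches an unchunked `F.cross_entropy` call on the full fp32 logits up to floating-point rounding; the difference is only in how much is upcast at once.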