monkeypatch.loss.chunked
Chunked cross-entropy (CE) loss.
Classes
| Name | Description |
|---|---|
| CEWithChunkedOutputLoss | Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time. |
CEWithChunkedOutputLoss
monkeypatch.loss.chunked.CEWithChunkedOutputLoss(
num_output_chunks=8,
ignore_index=-100,
)
Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.
For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
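To make the idea concrete, here is a minimal PyTorch sketch of a chunked cross-entropy module. `ChunkedCESketch` is a hypothetical name, not this library's class, and the sketch assumes each chunk's loss is summed over non-ignored tokens so that per-chunk results can simply be added. That is an illustrative assumption, not documented behaviour of `CEWithChunkedOutputLoss`.

```python
from typing import List

import torch
import torch.nn.functional as F


class ChunkedCESketch(torch.nn.Module):
    """Hypothetical illustration of chunked cross-entropy; not the library class."""

    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index

    def forward(self, logits: List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor:
        # Split labels (batch_size, num_tokens) along the token dimension so each
        # piece lines up with the corresponding logit chunk.
        label_chunks = labels.chunk(self.num_output_chunks, dim=1)
        total = torch.zeros((), dtype=torch.float32, device=labels.device)
        for logit_chunk, label_chunk in zip(logits, label_chunks):
            # Upcast only this chunk to fp32 rather than the full sequence of logits.
            total = total + F.cross_entropy(
                logit_chunk.reshape(-1, logit_chunk.size(-1)).float(),
                label_chunk.reshape(-1),
                ignore_index=self.ignore_index,
                reduction="sum",  # assumed: sums add up cleanly across chunks
            )
        return total


# Example usage with small random tensors.
torch.manual_seed(0)
full_logits = torch.randn(2, 16, 32, dtype=torch.bfloat16)
labels = torch.randint(0, 32, (2, 16))
labels[:, :2] = -100  # e.g. masked prompt positions
loss_fn = ChunkedCESketch(num_output_chunks=4)
loss = loss_fn(list(full_logits.chunk(4, dim=1)), labels)
```

Each fp32 copy created here covers one chunk, (batch_size, num_tokens / num_output_chunks, vocab_size), instead of the whole sequence. How much peak memory this saves in practice also depends on what autograd keeps for the backward pass, so treat the sketch as showing the structure of the technique rather than its memory profile.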
Methods
| Name | Description |
|---|---|
| compute_cross_entropy | Upcast logits to fp32 and compute cross entropy loss. |
| forward | Compute the cross-entropy loss over a list of chunked logits. |
compute_cross_entropy
monkeypatch.loss.chunked.CEWithChunkedOutputLoss.compute_cross_entropy(
logits,
labels,
normalize=True,
)
Upcast logits to fp32 and compute cross-entropy loss.
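A minimal sketch of what the upcast-then-compute step amounts to, written as a standalone function. The name `compute_cross_entropy_sketch`, the `reduction="sum"` choice, and the `ignore_index` default are assumptions for illustration. This page does not document how `normalize` is used, so the sketch leaves it out.

```python
import torch
import torch.nn.functional as F


def compute_cross_entropy_sketch(
    logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100
) -> torch.Tensor:
    """Hypothetical stand-in: upcast one flattened chunk to fp32, then sum the CE."""
    # logits: (n_tokens, vocab_size) in a low-precision dtype such as bfloat16
    # labels: (n_tokens,) integer class indices; ignore_index marks masked positions
    return F.cross_entropy(
        logits.float(), labels, ignore_index=ignore_index, reduction="sum"
    )


chunk = torch.randn(8, 32, dtype=torch.bfloat16)
chunk_labels = torch.tensor([3, 7, -100, 1, 0, -100, 31, 5])
loss = compute_cross_entropy_sketch(chunk, chunk_labels)
print(loss.dtype)  # torch.float32
```

The input chunk stays in bfloat16; only the temporary copy passed to `F.cross_entropy` is fp32, so the returned loss is fp32 as well.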
forward
monkeypatch.loss.chunked.CEWithChunkedOutputLoss.forward(
logits,
labels,
reduction='sum',
)
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| logits | List[torch.Tensor] | List of chunked logits of length self.num_output_chunks, where each chunk has shape (batch_size, num_tokens / num_output_chunks, vocab_size). | required |
| labels | torch.Tensor | Ground-truth labels of shape (batch_size, num_tokens). | required |
| reduction | str | The reduction to apply to the output. | 'sum' |
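The `logits` argument is expected to arrive already split into a list. One way to produce chunks in the shape described above is `torch.Tensor.chunk` along the token dimension (`dim=1`); the sizes below are arbitrary example values, and `labels` stays unchunked to match the (batch_size, num_tokens) shape in the table.

```python
import torch

batch_size, num_tokens, vocab_size = 2, 12, 50
num_output_chunks = 4

logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16)
labels = torch.randint(0, vocab_size, (batch_size, num_tokens))  # not chunked

# Split along the token dimension (dim=1) so each chunk keeps the full
# vocabulary and covers num_tokens / num_output_chunks positions.
logit_chunks = list(logits.chunk(num_output_chunks, dim=1))

for chunk in logit_chunks:
    print(tuple(chunk.shape))  # (2, 3, 50) for each of the 4 chunks

# torch.chunk uses ceil(length / chunks) per chunk, so uneven lengths give a
# shorter last chunk and can even yield fewer chunks than requested.
print([c.shape[1] for c in torch.randn(1, 6, 5).chunk(4, dim=1)])  # [2, 2, 2]
```

The shape in the table assumes `num_tokens` is divisible by `num_output_chunks`. When it is not, it is worth checking that the list length equals `num_output_chunks` before calling `forward`.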
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | Cross-entropy loss of shape (1,). |
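The table does not say how `reduction` values other than `'sum'` behave. If you start from a summed loss and want a per-token mean, one common approach (an assumption here, not documented behaviour of this class) is to divide by the number of labels that are not `ignore_index`. The sketch below checks that idea using plain `F.cross_entropy` calls on hypothetical random data; `IGNORE_INDEX` is an illustrative constant.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100

torch.manual_seed(0)
logits = torch.randn(2, 8, 16)           # (batch_size, num_tokens, vocab_size)
labels = torch.randint(0, 16, (2, 8))    # (batch_size, num_tokens)
labels[:, 0] = IGNORE_INDEX              # hypothetical masked positions

# Summed loss accumulated over 4 chunks along the token dimension.
summed = sum(
    F.cross_entropy(
        logit_chunk.reshape(-1, logit_chunk.size(-1)),
        label_chunk.reshape(-1),
        ignore_index=IGNORE_INDEX,
        reduction="sum",
    )
    for logit_chunk, label_chunk in zip(logits.chunk(4, dim=1), labels.chunk(4, dim=1))
)

# Divide by the number of positions that actually contribute to the loss.
num_valid = (labels != IGNORE_INDEX).sum()
mean_loss = summed / num_valid
```

Dividing by the count of non-ignored labels, rather than by batch_size * num_tokens, keeps masked positions from diluting the average; with no class weights this matches what `F.cross_entropy` does for `reduction="mean"`.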