kernels.geglu
Module defining the GEGLU Triton kernels.
See “GLU Variants Improve Transformer” (https://arxiv.org/abs/2002.05202).
Credit to unsloth (https://unsloth.ai/) for inspiring this implementation.
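For reference, GEGLU combines a GELU-gated branch with a linear "up" branch: h = GELU(gate) ⊙ up. Below is a minimal pure-PyTorch sketch of that computation; the Triton kernels fuse it into a single pass over the tensors, and whether they use the exact or tanh-approximate GELU is an implementation detail not stated here.

```python
import torch
import torch.nn.functional as F

def geglu_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """Reference GEGLU: elementwise GELU(gate) * up.

    A plain-PyTorch sketch of what the fused Triton kernel computes;
    the kernel may use the tanh approximation of GELU rather than the
    exact erf-based form used here.
    """
    return F.gelu(gate) * up
```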
Functions
Name | Description |
---|---|
geglu_backward | GEGLU backward pass using in-place operations. |
geglu_forward | GEGLU forward pass. |
geglu_backward
kernels.geglu.geglu_backward(grad_output, gate, up)
GEGLU backward pass using in-place operations.
Parameters
Name | Type | Description | Default |
---|---|---|---|
grad_output | torch.Tensor | Gradient of the loss with respect to the output, shape [batch, seq_len, hidden_dim]. | required |
gate | torch.Tensor | Gate tensor from the forward pass, shape [batch, seq_len, hidden_dim]. | required |
up | torch.Tensor | Up-projection tensor from the forward pass, shape [batch, seq_len, hidden_dim]. | required |
Returns
Type | Description |
---|---|
tuple[torch.Tensor, torch.Tensor, torch.Tensor] | Tuple containing the recomputed GEGLU activation output (h), the gradient with respect to gate (grad_gate), and the gradient with respect to up (grad_up). |
Note
This function modifies its input tensors in-place to store results.
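A hedged usage sketch, assuming the documented signature. Since h = GELU(gate) ⊙ up, the chain rule gives grad_up = grad_output ⊙ GELU(gate) and grad_gate = grad_output ⊙ up ⊙ GELU'(gate). The shapes, device, and clone-before-call pattern below are illustrative assumptions, not requirements stated by the API:

```python
import torch
from kernels.geglu import geglu_forward, geglu_backward

# Illustrative shapes; any [batch, seq_len, hidden_dim] tensors work.
gate = torch.randn(2, 128, 512, device="cuda")
up = torch.randn_like(gate)

out = geglu_forward(gate, up)  # h = GELU(gate) * up

# Gradient of the loss w.r.t. the output (all ones for a sum() loss).
grad_output = torch.ones_like(out)

# geglu_backward overwrites its inputs to store results, so pass
# clones if the original gate/up values are needed afterwards.
h, grad_gate, grad_up = geglu_backward(grad_output, gate.clone(), up.clone())
```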
geglu_forward
kernels.geglu.geglu_forward(gate, up)
GEGLU forward pass.
Parameters
Name | Type | Description | Default |
---|---|---|---|
gate | torch.Tensor | Input gate tensor of shape [batch, seq_len, hidden_dim]. | required |
up | torch.Tensor | Up-projection tensor of shape [batch, seq_len, hidden_dim]. | required |
Returns
Type | Description |
---|---|
torch.Tensor | Output tensor of shape [batch, seq_len, hidden_dim]. |
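When integrating the kernel, a quick sanity check against plain PyTorch can be useful. This sketch assumes a CUDA device and uses loose tolerances in case the kernel applies the tanh approximation of GELU:

```python
import torch
import torch.nn.functional as F
from kernels.geglu import geglu_forward

gate = torch.randn(2, 64, 256, device="cuda")
up = torch.randn_like(gate)

out_kernel = geglu_forward(gate, up)
out_ref = F.gelu(gate) * up  # exact (erf-based) GELU

# Loose tolerances: the fused kernel may use the tanh approximation.
torch.testing.assert_close(out_kernel, out_ref, rtol=1e-2, atol=1e-2)
```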