kernels.geglu

Module defining the GEGLU Triton kernels.

See “GLU Variants Improve Transformer” (https://arxiv.org/abs/2002.05202).

Credit to unsloth (https://unsloth.ai/) for inspiration for this implementation.

Functions

| Name | Description |
|------|-------------|
| geglu_backward | GEGLU backward pass using in-place operations. |
| geglu_forward | GEGLU forward pass. |

geglu_backward

kernels.geglu.geglu_backward(grad_output, gate, up)

GEGLU backward pass using in-place operations.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| grad_output | torch.Tensor | Gradient of the loss with respect to the output, shape [batch, seq_len, hidden_dim]. | required |
| gate | torch.Tensor | Gate tensor from the forward pass, shape [batch, seq_len, hidden_dim]. | required |
| up | torch.Tensor | Up-projection tensor from the forward pass, shape [batch, seq_len, hidden_dim]. | required |

Returns

| Name | Type | Description |
|------|------|-------------|
|  | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | Tuple containing the GEGLU activation output (`h`), the gradient with respect to `gate` (`grad_gate`), and the gradient with respect to `up` (`grad_up`). |

Note

This function modifies its input tensors in-place to store results.
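For reference, the gradients of h = GELU(gate) * up follow from the product rule: grad_up = grad_output * GELU(gate) and grad_gate = grad_output * up * GELU'(gate). The sketch below is a pure-PyTorch illustration of these formulas (not the Triton kernel, and without the in-place updates); it assumes the exact (erf-based) GELU and verifies the analytic gradients against autograd:

```python
import torch
import torch.nn.functional as F


def geglu_backward_reference(grad_output, gate, up):
    """Pure-PyTorch reference gradients for h = GELU(gate) * up (exact GELU)."""
    gelu_gate = F.gelu(gate)  # GELU(x) = x * Phi(x), Phi = standard normal CDF
    phi = torch.exp(-0.5 * gate**2) / (2 * torch.pi) ** 0.5  # normal PDF
    Phi = 0.5 * (1.0 + torch.erf(gate / 2**0.5))             # normal CDF
    dgelu = Phi + gate * phi                                  # GELU'(x)
    grad_gate = grad_output * up * dgelu
    grad_up = grad_output * gelu_gate
    return gelu_gate * up, grad_gate, grad_up


# Cross-check the analytic formulas against autograd (float64 for tight tolerance).
gate = torch.randn(2, 3, 4, dtype=torch.float64, requires_grad=True)
up = torch.randn(2, 3, 4, dtype=torch.float64, requires_grad=True)
h = F.gelu(gate) * up
g = torch.randn_like(h)
h.backward(g)
_, grad_gate, grad_up = geglu_backward_reference(g, gate.detach(), up.detach())
assert torch.allclose(grad_gate, gate.grad)
assert torch.allclose(grad_up, up.grad)
```

Note that the actual kernel writes its results into the input tensors in-place, so `gate` and `up` are overwritten after the call; the reference above allocates fresh tensors instead.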

geglu_forward

kernels.geglu.geglu_forward(gate, up)

GEGLU forward pass, computing GELU(gate) * up elementwise.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| gate | torch.Tensor | Input gate tensor of shape [batch, seq_len, hidden_dim]. | required |
| up | torch.Tensor | Up-projection tensor of shape [batch, seq_len, hidden_dim]. | required |

Returns

| Name | Type | Description |
|------|------|-------------|
|  | torch.Tensor | Output tensor of shape [batch, seq_len, hidden_dim]. |
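A minimal pure-PyTorch sketch of the same computation (the Triton kernel fuses this into a single pass; whether the kernel uses the exact or tanh-approximated GELU is an assumption here, the exact form is shown):

```python
import torch
import torch.nn.functional as F


def geglu_forward_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # GEGLU: elementwise GELU of the gate projection, scaled by the up projection.
    return F.gelu(gate) * up


gate = torch.randn(2, 4, 8)
up = torch.randn(2, 4, 8)
out = geglu_forward_reference(gate, up)
assert out.shape == (2, 4, 8)
```

The output shape matches the inputs, as both projections and the result share the [batch, seq_len, hidden_dim] layout.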