integrations.diffusion.trainer


Custom trainer for diffusion LM training.

Classes

Name              Description
DiffusionTrainer  Custom trainer for diffusion LM training that overrides loss computation.

DiffusionTrainer

integrations.diffusion.trainer.DiffusionTrainer(*args, **kwargs)

Custom trainer for diffusion LM training that overrides loss computation.

Methods

Name                  Description
compute_loss          Override compute_loss to use the diffusion loss.
post_set_axolotl_cfg  Set the configuration for diffusion training.

compute_loss
integrations.diffusion.trainer.DiffusionTrainer.compute_loss(
    model,
    inputs,
    return_outputs=False,
    num_items_in_batch=None,
)

Override compute_loss so that the trainer computes the diffusion loss instead of the parent trainer's default loss.
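This page does not say which diffusion objective compute_loss implements. Purely as an illustration, the sketch below assumes a masked (absorbing-state) diffusion objective in the style of LLaDA: sample a noise level t, replace each token with a mask token with probability t, and score cross-entropy only on the masked positions, weighted by 1/t. Every name here (MASK_ID, mask_tokens, masked_diffusion_loss, DiffusionLossSketch, eps) is hypothetical. The stand-in class mirrors only the compute_loss signature shown above and is not the Axolotl implementation.

```python
import math
import random
from typing import List, Sequence, Tuple

# Hypothetical mask-token id; a real tokenizer defines its own.
MASK_ID = 0


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of vocabulary logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def mask_tokens(tokens: Sequence[int], t: float, rng: random.Random) -> Tuple[List[int], List[bool]]:
    """Noising step (assumed): replace each token with MASK_ID independently with probability t."""
    noisy: List[int] = []
    is_masked: List[bool] = []
    for tok in tokens:
        masked = rng.random() < t
        noisy.append(MASK_ID if masked else tok)
        is_masked.append(masked)
    return noisy, is_masked


def masked_diffusion_loss(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    is_masked: Sequence[bool],
    t: float,
) -> float:
    """Cross-entropy on masked positions only, weighted by 1/t and averaged over sequence length."""
    total = 0.0
    for row, target, masked in zip(logits, targets, is_masked):
        if masked:
            total -= log_softmax(row)[target]
    return total / (t * len(targets))


class DiffusionLossSketch:
    """Hypothetical stand-in showing the shape of a compute_loss override."""

    def __init__(self, eps: float = 1e-3, seed: int = 0) -> None:
        self.eps = eps  # lower bound on t so the 1/t weight stays finite
        self.rng = random.Random(seed)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # num_items_in_batch is accepted for signature compatibility but unused in this sketch.
        tokens = inputs["input_ids"]
        t = self.eps + (1.0 - self.eps) * self.rng.random()  # t sampled from [eps, 1)
        noisy, is_masked = mask_tokens(tokens, t, self.rng)
        logits = model(noisy)  # assumed: one row of vocabulary logits per position
        loss = masked_diffusion_loss(logits, tokens, is_masked, t)
        return (loss, logits) if return_outputs else loss
```

With a toy "model" that returns uniform logits, each masked position contributes log(V) to the sum, where V is the vocabulary size; unmasked positions contribute nothing, which is the defining difference from an ordinary causal LM loss.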

post_set_axolotl_cfg
integrations.diffusion.trainer.DiffusionTrainer.post_set_axolotl_cfg()

Set the trainer's diffusion-training configuration once its Axolotl config (axolotl_cfg) has been set.
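The method name suggests a hook that runs after the config is attached to the trainer. The sketch below shows one way such a post-set hook can work, using a Python property setter. It is an assumption for illustration only: the class HookedTrainerSketch, the dict-based config, the "diffusion" section, and the "eps" key are all hypothetical and are not Axolotl's actual config schema or wiring.

```python
from typing import Any, Dict, Optional


class HookedTrainerSketch:
    """Hypothetical stand-in: assigning axolotl_cfg triggers post_set_axolotl_cfg."""

    def __init__(self) -> None:
        self._axolotl_cfg: Optional[Dict[str, Any]] = None
        self.diffusion_settings: Dict[str, Any] = {}

    @property
    def axolotl_cfg(self) -> Optional[Dict[str, Any]]:
        return self._axolotl_cfg

    @axolotl_cfg.setter
    def axolotl_cfg(self, cfg: Dict[str, Any]) -> None:
        self._axolotl_cfg = cfg
        self.post_set_axolotl_cfg()  # run the hook right after the config is stored

    def post_set_axolotl_cfg(self) -> None:
        # Copy the (hypothetical) "diffusion" section onto the trainer,
        # falling back to a default for any key the config omits.
        section = (self._axolotl_cfg or {}).get("diffusion", {})
        self.diffusion_settings = {"eps": 1e-3, **section}
```

Deferring this setup to a hook keeps the constructor signature generic (*args, **kwargs) while still letting the trainer react to diffusion-specific settings once they are available.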