integrations.base

Base class for all plugins.

A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl. Plugins can be used to integrate third-party models, modify the training process, or add new features.

To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
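
As a rough illustration of the inheritance pattern described above, the sketch below subclasses a stand-in base class and overrides a single hook. `BasePluginStandIn`, `LoggingPlugin`, and the `base_model` key are hypothetical names made up for this example. In a real project the base would be the `BasePlugin` class documented on this page, and `cfg` would be a `DictDefault` rather than a plain dict. The sketch also assumes that hooks a plugin does not override behave as no-ops.

```python
class BasePluginStandIn:
    """Hypothetical stand-in for BasePlugin: every hook is a no-op by default."""

    def register(self, cfg):
        pass

    def pre_model_load(self, cfg):
        pass

    def post_trainer_create(self, cfg, trainer):
        pass


class LoggingPlugin(BasePluginStandIn):
    """Hypothetical plugin that overrides only the hook it cares about."""

    def __init__(self):
        self.events = []

    def pre_model_load(self, cfg):
        # Record which base model is about to be loaded (key name is assumed).
        self.events.append(f"loading {cfg.get('base_model', '<unknown>')}")


plugin = LoggingPlugin()
plugin.register({})  # inherited no-op
plugin.pre_model_load({"base_model": "my-org/my-model"})
print(plugin.events)
```

Keeping state on the plugin instance, as `events` does here, is one simple way to carry information from an early hook to a later one.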

Classes

Name Description
BaseOptimizerFactory Base class for factories to create custom optimizers
BasePlugin Base class for all plugins. Defines the interface for plugin methods.
PluginManager The PluginManager class is responsible for loading and managing plugins. It should be a singleton so it can be accessed from anywhere in the codebase.

BaseOptimizerFactory

integrations.base.BaseOptimizerFactory()

Base class for factories to create custom optimizers

BasePlugin

integrations.base.BasePlugin()

Base class for all plugins. Defines the interface for plugin methods.

A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl. Plugins can be used to integrate third-party models, modify the training process, or add new features.

To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.

Note

Plugin methods include:
- register(cfg): Registers the plugin with the given configuration.
- load_datasets(cfg): Loads and preprocesses the dataset for training.
- pre_model_load(cfg): Performs actions before the model is loaded.
- post_model_build(cfg, model): Performs actions after the model is built/loaded, but before any adapters are applied.
- pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
- post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
- post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
- post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
- create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
- create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
- add_callbacks_pre_trainer(cfg, model): Sets up callbacks before the trainer is created.
- add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after the trainer is created.
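
The hook descriptions above imply an ordering for the model-loading hooks. The sketch below traces that implied order with a recording plugin. `RecordingPlugin` and `load_model_with_hooks` are hypothetical names for this example, and the exact call sites inside Axolotl's own loader are an assumption, not something this page confirms.

```python
class RecordingPlugin:
    """Hypothetical plugin that records the order in which its hooks run."""

    def __init__(self):
        self.calls = []

    def pre_model_load(self, cfg):
        self.calls.append("pre_model_load")

    def post_model_build(self, cfg, model):
        self.calls.append("post_model_build")

    def pre_lora_load(self, cfg, model):
        self.calls.append("pre_lora_load")

    def post_lora_load(self, cfg, model):
        self.calls.append("post_lora_load")

    def post_model_load(self, cfg, model):
        self.calls.append("post_model_load")


def load_model_with_hooks(cfg, plugin):
    """Hypothetical driver; the real loader's call sites may differ."""
    plugin.pre_model_load(cfg)
    model = object()  # placeholder for the base model
    plugin.post_model_build(cfg, model)
    plugin.pre_lora_load(cfg, model)
    # Adapters would be attached to the model at this point.
    plugin.post_lora_load(cfg, model)
    plugin.post_model_load(cfg, model)
    return model


plugin = RecordingPlugin()
load_model_with_hooks({}, plugin)
print(plugin.calls)
```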

Methods

Name Description
add_callbacks_post_trainer Adds callbacks to the trainer after creating the trainer. This is useful for callbacks that require access to the model or trainer.
add_callbacks_pre_trainer Set up callbacks before creating the trainer.
create_lr_scheduler Creates and returns a learning rate scheduler.
create_optimizer Creates and returns an optimizer for training.
get_input_args Returns a pydantic model for the plugin’s input arguments.
get_trainer_cls Returns a custom class for the trainer.
load_datasets Loads and preprocesses the dataset for training.
post_lora_load Performs actions after LoRA weights are loaded.
post_model_build Performs actions after the model is built/loaded, but before any adapters are applied.
post_model_load Performs actions after the model is loaded.
post_train Performs actions after training is complete.
post_train_unload Performs actions after training is complete and the model is unloaded.
post_trainer_create Performs actions after the trainer is created.
pre_lora_load Performs actions before LoRA weights are loaded.
pre_model_load Performs actions before the model is loaded.
register Registers the plugin with the given configuration.
add_callbacks_post_trainer
integrations.base.BasePlugin.add_callbacks_post_trainer(cfg, trainer)

Adds callbacks to the trainer after creating the trainer. This is useful for callbacks that require access to the model or trainer.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required
trainer Trainer The trainer object for training. required
Returns
Name Type Description
list[Callable] A list of callback functions to be added to the trainer.
add_callbacks_pre_trainer
integrations.base.BasePlugin.add_callbacks_pre_trainer(cfg, model)

Set up callbacks before creating the trainer.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required
model PreTrainedModel The loaded model. required
Returns
Name Type Description
list[Callable] A list of callback functions to be added to the TrainingArgs.
create_lr_scheduler
integrations.base.BasePlugin.create_lr_scheduler(
    cfg,
    trainer,
    optimizer,
    num_training_steps,
)

Creates and returns a learning rate scheduler.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required
trainer Trainer The trainer object for training. required
optimizer Optimizer The optimizer for training. required
num_training_steps int The total number of training steps. required
Returns
Name Type Description
LRScheduler | None The created learning rate scheduler.
create_optimizer
integrations.base.BasePlugin.create_optimizer(cfg, trainer)

Creates and returns an optimizer for training.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required
trainer Trainer The trainer object for training. required
Returns
Name Type Description
Optimizer | None The created optimizer.
get_input_args
integrations.base.BasePlugin.get_input_args()

Returns a pydantic model for the plugin’s input arguments.

get_trainer_cls
integrations.base.BasePlugin.get_trainer_cls(cfg)

Returns a custom class for the trainer.

Parameters
Name Type Description Default
cfg DictDefault The global axolotl configuration. required
Returns
Name Type Description
Trainer | None The custom trainer class, or None if the plugin does not provide one.
load_datasets
integrations.base.BasePlugin.load_datasets(cfg, preprocess=False)

Loads and preprocesses the dataset for training.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required
preprocess bool Whether this is the preprocess step of the datasets. False
Returns
Name Type Description
dataset_meta Union['TrainDatasetMeta', None] The metadata for the training dataset.
post_lora_load
integrations.base.BasePlugin.post_lora_load(cfg, model)

Performs actions after LoRA weights are loaded.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required
model PreTrainedModel | PeftModel The loaded model. required
post_model_build
integrations.base.BasePlugin.post_model_build(cfg, model)

Performs actions after the model is built/loaded, but before any adapters are applied.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required
model PreTrainedModel The loaded model. required
post_model_load
integrations.base.BasePlugin.post_model_load(cfg, model)

Performs actions after the model is loaded.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required
model PreTrainedModel | PeftModel The loaded model. required
post_train
integrations.base.BasePlugin.post_train(cfg, model)

Performs actions after training is complete.

Parameters
Name Type Description Default
cfg DictDefault The axolotl configuration. required
model PreTrainedModel | PeftModel The loaded model. required
post_train_unload
integrations.base.BasePlugin.post_train_unload(cfg)

Performs actions after training is complete and the model is unloaded.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required
post_trainer_create
integrations.base.BasePlugin.post_trainer_create(cfg, trainer)

Performs actions after the trainer is created.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required
trainer Trainer The trainer object for training. required
pre_lora_load
integrations.base.BasePlugin.pre_lora_load(cfg, model)

Performs actions before LoRA weights are loaded.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required
model PreTrainedModel The loaded model. required
pre_model_load
integrations.base.BasePlugin.pre_model_load(cfg)

Performs actions before the model is loaded.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required
register
integrations.base.BasePlugin.register(cfg)

Registers the plugin with the given configuration.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugin. required

PluginManager

integrations.base.PluginManager()

The PluginManager class is responsible for loading and managing plugins. It should be a singleton so it can be accessed from anywhere in the codebase.

Attributes

Name Type Description
plugins OrderedDict[str, BasePlugin] An ordered mapping of plugin names to loaded plugin instances.

Note

Key methods include:
- get_instance(): Static method to get the singleton instance of PluginManager.
- register(plugin_name: str): Registers a new plugin by its name.
- pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
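
To make the singleton and fan-out behaviour concrete, here is a minimal sketch of a manager with the same three key methods. `TinyPluginManager` and `EchoPlugin` are hypothetical classes written for this example and are not the real PluginManager. In particular, the real `register` takes only a plugin name, while this sketch accepts an instance directly so that it stays self-contained.

```python
from collections import OrderedDict


class TinyPluginManager:
    """Hypothetical, simplified manager; not the real PluginManager."""

    _instance = None

    def __init__(self):
        self.plugins = OrderedDict()

    @staticmethod
    def get_instance():
        # Create the shared instance on first use, then keep returning it.
        if TinyPluginManager._instance is None:
            TinyPluginManager._instance = TinyPluginManager()
        return TinyPluginManager._instance

    def register(self, name, plugin):
        # Assumption: the instance is passed in instead of being imported by name.
        self.plugins[name] = plugin

    def pre_model_load(self, cfg):
        # Fan the hook out to every plugin in registration order.
        for plugin in self.plugins.values():
            plugin.pre_model_load(cfg)


class EchoPlugin:
    """Hypothetical plugin that remembers every cfg it was given."""

    def __init__(self):
        self.seen = []

    def pre_model_load(self, cfg):
        self.seen.append(dict(cfg))


manager = TinyPluginManager.get_instance()
echo = EchoPlugin()
manager.register("echo", echo)
# A second lookup returns the same manager, so the plugin is still registered.
TinyPluginManager.get_instance().pre_model_load({"seed": 1})
print(echo.seen)
```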

Methods

Name Description
add_callbacks_post_trainer Calls the add_callbacks_post_trainer method of all registered plugins.
add_callbacks_pre_trainer Calls the add_callbacks_pre_trainer method of all registered plugins.
create_lr_scheduler Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
create_optimizer Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
get_input_args Returns a list of Pydantic classes for all registered plugins’ input arguments.
get_instance Returns the singleton instance of PluginManager. If the instance doesn’t exist, it creates a new one.
get_trainer_cls Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
load_datasets Calls the load_datasets method of each registered plugin.
post_lora_load Calls the post_lora_load method of all registered plugins.
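post_model_build Calls the post_model_build method of all registered plugins after the model has been built/loaded, but before any adapters have been applied.
post_model_load Calls the post_model_load method of all registered plugins after the model has been loaded, inclusive of any adapters.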
post_train Calls the post_train method of all registered plugins.
post_train_unload Calls the post_train_unload method of all registered plugins.
post_trainer_create Calls the post_trainer_create method of all registered plugins.
pre_lora_load Calls the pre_lora_load method of all registered plugins.
pre_model_load Calls the pre_model_load method of all registered plugins.
register Registers a new plugin by its name.
add_callbacks_post_trainer
integrations.base.PluginManager.add_callbacks_post_trainer(cfg, trainer)

Calls the add_callbacks_post_trainer method of all registered plugins.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugins. required
trainer Trainer The trainer object for training. required
Returns
Name Type Description
list[Callable] A list of callback functions to be added to the TrainingArgs.
add_callbacks_pre_trainer
integrations.base.PluginManager.add_callbacks_pre_trainer(cfg, model)

Calls the add_callbacks_pre_trainer method of all registered plugins.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugins. required
model PreTrainedModel The loaded model. required
Returns
Name Type Description
list[Callable] A list of callback functions to be added to the TrainingArgs.
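
One plausible way for a manager to build this list is to concatenate whatever each plugin returns. The sketch below assumes exactly that. `collect_callbacks` and both plugin classes are hypothetical, and treating a `None` return as "no callbacks" is a choice made for this example, not documented behaviour.

```python
def collect_callbacks(plugins, cfg, model):
    """Concatenate the callback lists returned by each plugin (assumed behaviour)."""
    callbacks = []
    for plugin in plugins:
        # Treat a None return as "no callbacks" so a plugin may skip the hook.
        callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model) or [])
    return callbacks


def log_step():
    """Placeholder callback."""


def save_stats():
    """Placeholder callback."""


class StepLoggerPlugin:
    def add_callbacks_pre_trainer(self, cfg, model):
        return [log_step]


class QuietPlugin:
    def add_callbacks_pre_trainer(self, cfg, model):
        return None  # contributes nothing


class StatsPlugin:
    def add_callbacks_pre_trainer(self, cfg, model):
        return [save_stats]


plugins = [StepLoggerPlugin(), QuietPlugin(), StatsPlugin()]
callbacks = collect_callbacks(plugins, cfg={}, model=None)
print([cb.__name__ for cb in callbacks])
```
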
create_lr_scheduler
integrations.base.PluginManager.create_lr_scheduler(
    trainer,
    optimizer,
    num_training_steps,
)

Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.

Parameters
Name Type Description Default
trainer Trainer The trainer object for training. required
optimizer Optimizer The optimizer for training. required
num_training_steps int The total number of training steps. required
Returns
Name Type Description
LRScheduler | None The created learning rate scheduler, or None if not found.
create_optimizer
integrations.base.PluginManager.create_optimizer(trainer)

Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.

Parameters
Name Type Description Default
trainer Trainer The trainer object for training. required
Returns
Name Type Description
Optimizer | None The created optimizer, or None if none was found.
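
The "first non-None" rule can be sketched as a short loop over the registered plugins. `first_non_none` and the two plugin classes below are hypothetical, and a string stands in for a real Optimizer object. Under this rule, registration order decides which plugin wins when more than one supplies an optimizer.

```python
def first_non_none(plugins, hook_name, *args):
    """Return the first non-None result of calling hook_name on each plugin."""
    for plugin in plugins:
        result = getattr(plugin, hook_name)(*args)
        if result is not None:
            return result
    return None


class NoOptimizerPlugin:
    def create_optimizer(self, cfg, trainer):
        return None  # this plugin does not supply an optimizer


class CustomOptimizerPlugin:
    def create_optimizer(self, cfg, trainer):
        return "custom-optimizer"  # placeholder for a real Optimizer object


plugins = [NoOptimizerPlugin(), CustomOptimizerPlugin()]
chosen = first_non_none(plugins, "create_optimizer", {}, None)
missing = first_non_none([NoOptimizerPlugin()], "create_optimizer", {}, None)
print(chosen, missing)
```
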
get_input_args
integrations.base.PluginManager.get_input_args()

Returns a list of Pydantic classes for all registered plugins’ input arguments.

Returns
Name Type Description
list[str] A list of Pydantic classes for all registered plugins’ input arguments.
get_instance
integrations.base.PluginManager.get_instance()

Returns the singleton instance of PluginManager. If the instance doesn’t exist, it creates a new one.

get_trainer_cls
integrations.base.PluginManager.get_trainer_cls(cfg)

Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugins. required
Returns
Name Type Description
Trainer | None The first non-None trainer class returned by a plugin.
load_datasets
integrations.base.PluginManager.load_datasets(cfg, preprocess=False)

Calls the load_datasets method of each registered plugin.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugins. required
preprocess bool Whether this is the preprocess step of the datasets. False
Returns
Name Type Description
Union['TrainDatasetMeta', None] The dataset metadata loaded from all registered plugins.
post_lora_load
integrations.base.PluginManager.post_lora_load(cfg, model)

Calls the post_lora_load method of all registered plugins.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugins. required
model PreTrainedModel | PeftModel The loaded model. required
post_model_build
integrations.base.PluginManager.post_model_build(cfg, model)

Calls the post_model_build method of all registered plugins after the model has been built / loaded, but before any adapters have been applied.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugins. required
model PreTrainedModel The loaded model. required
post_model_load
integrations.base.PluginManager.post_model_load(cfg, model)

Calls the post_model_load method of all registered plugins after the model has been loaded inclusive of any adapters.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugins. required
model PreTrainedModel | PeftModel The loaded model. required
post_train
integrations.base.PluginManager.post_train(cfg, model)

Calls the post_train method of all registered plugins.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugins. required
model PreTrainedModel | PeftModel The loaded model. required
post_train_unload
integrations.base.PluginManager.post_train_unload(cfg)

Calls the post_train_unload method of all registered plugins.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugins. required
post_trainer_create
integrations.base.PluginManager.post_trainer_create(cfg, trainer)

Calls the post_trainer_create method of all registered plugins.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugins. required
trainer Trainer The trainer object for training. required
pre_lora_load
integrations.base.PluginManager.pre_lora_load(cfg, model)

Calls the pre_lora_load method of all registered plugins.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugins. required
model PreTrainedModel The loaded model. required
pre_model_load
integrations.base.PluginManager.pre_model_load(cfg)

Calls the pre_model_load method of all registered plugins.

Parameters
Name Type Description Default
cfg DictDefault The configuration for the plugins. required
register
integrations.base.PluginManager.register(plugin_name)

Registers a new plugin by its name.

Parameters
Name Type Description Default
plugin_name str The name of the plugin to be registered. required
Raises
Name Type Description
ImportError If the plugin module cannot be imported.

Functions

Name Description
load_plugin Loads a plugin based on the given plugin name.

load_plugin

integrations.base.load_plugin(plugin_name)

Loads a plugin based on the given plugin name.

The plugin name should be in the format “module_name.class_name”. This function splits the plugin name into module and class, imports the module, retrieves the class from the module, and creates an instance of the class.

Parameters

Name Type Description Default
plugin_name str The name of the plugin to be loaded. The name should be in the format “module_name.class_name”. required

Returns

Name Type Description
BasePlugin An instance of the loaded plugin.

Raises

Name Type Description
ImportError If the plugin module cannot be imported.
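
The steps described for load_plugin (split the name, import the module, look up the class, instantiate it) can be sketched with the standard library's `importlib`. `load_plugin_sketch` is a hypothetical re-implementation for illustration only. Splitting at the last dot is an assumption, and `collections.OrderedDict` is used as a stand-in "plugin" simply because it is always importable.

```python
import importlib


def load_plugin_sketch(plugin_name):
    """Hypothetical re-implementation of the steps described above."""
    # Split "module_name.class_name" at the last dot so dotted module paths work.
    module_name, class_name = plugin_name.rsplit(".", 1)
    module = importlib.import_module(module_name)  # raises ImportError if missing
    plugin_cls = getattr(module, class_name)
    return plugin_cls()


instance = load_plugin_sketch("collections.OrderedDict")
print(type(instance).__name__)

try:
    load_plugin_sketch("no_such_module_for_this_sketch.Plugin")
except ImportError as err:
    print(f"import failed: {err}")
```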