integrations.base
Base class for all plugins.
A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl.
Plugins can be used to integrate third-party models, modify the training process, or add new features.
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
Classes
BaseOptimizerFactory
Base class for factories to create custom optimizers
BasePlugin
Base class for all plugins. Defines the interface for plugin methods.
PluginManager
The PluginManager class is responsible for loading and managing plugins.
BaseOptimizerFactory
integrations.base.BaseOptimizerFactory()
Base class for factories to create custom optimizers
BasePlugin
integrations.base.BasePlugin()
Base class for all plugins. Defines the interface for plugin methods.
A plugin is a reusable, modular, and self-contained piece of code that extends
the functionality of Axolotl. Plugins can be used to integrate third-party models,
modify the training process, or add new features.
To create a new plugin, you need to inherit from the BasePlugin class and
implement the required methods.
Note
Plugin methods include:
- register(cfg): Registers the plugin with the given configuration.
- load_datasets(cfg): Loads and preprocesses the dataset for training.
- pre_model_load(cfg): Performs actions before the model is loaded.
- post_model_build(cfg, model): Performs actions after the model is loaded, but
before LoRA adapters are applied.
- pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
- post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
- post_model_load(cfg, model): Performs actions after the model is loaded,
inclusive of any adapters.
- post_trainer_create(cfg, trainer): Performs actions after the trainer is
created.
- create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
- create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and
returns a learning rate scheduler.
- add_callbacks_pre_trainer(cfg, model): Returns callbacks to set up before the
trainer is created.
- add_callbacks_post_trainer(cfg, trainer): Returns callbacks to add after the
trainer is created.
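As a minimal sketch of the interface listed above, the following plugin overrides two hooks and leaves the rest untouched. To stay runnable without Axolotl installed, `BasePlugin` here is a hypothetical stand-in with no-op hooks rather than the real class, a plain dict stands in for `DictDefault`, and `LoggingPlugin` with its `events` attribute is an illustrative name only.

```python
class BasePlugin:
    """Stand-in for integrations.base.BasePlugin (assumption: hooks default to no-ops)."""

    def register(self, cfg):
        pass

    def pre_model_load(self, cfg):
        pass

    def post_model_load(self, cfg, model):
        pass


class LoggingPlugin(BasePlugin):
    """Hypothetical plugin that records which hooks were invoked."""

    def __init__(self):
        self.events = []

    def pre_model_load(self, cfg):
        # Runs before the model is loaded, so only the config is available.
        self.events.append(("pre_model_load", cfg.get("base_model")))

    def post_model_load(self, cfg, model):
        # Runs after the model (inclusive of any adapters) is loaded.
        self.events.append(("post_model_load", type(model).__name__))


plugin = LoggingPlugin()
cfg = {"base_model": "example/model"}  # plain dict standing in for DictDefault
plugin.register(cfg)                   # inherited no-op
plugin.pre_model_load(cfg)
plugin.post_model_load(cfg, object())  # object() stands in for a loaded model
print(plugin.events)
```

Only the hooks a plugin cares about need to be overridden; in this sketch every other hook falls back to the stand-in's no-op.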
Methods
add_callbacks_post_trainer
integrations.base.BasePlugin.add_callbacks_post_trainer(cfg, trainer)
Adds callbacks to the trainer after creating the trainer. This is useful for
callbacks that require access to the model or trainer.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
trainer
Trainer
The trainer object for training.
required
Returns
list[Callable]
A list of callback functions to be added
add_callbacks_pre_trainer
integrations.base.BasePlugin.add_callbacks_pre_trainer(cfg, model)
Set up callbacks before creating the trainer.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
model
PreTrainedModel
The loaded model.
required
Returns
list[Callable]
A list of callback functions to be added to the TrainingArgs.
create_lr_scheduler
integrations.base.BasePlugin.create_lr_scheduler(
cfg,
trainer,
optimizer,
num_training_steps,
)
Creates and returns a learning rate scheduler.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
trainer
Trainer
The trainer object for training.
required
optimizer
Optimizer
The optimizer for training.
required
num_training_steps
int
Total number of training steps
required
Returns
LRScheduler | None
The created learning rate scheduler.
create_optimizer
integrations.base.BasePlugin.create_optimizer(cfg, trainer)
Creates and returns an optimizer for training.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
trainer
Trainer
The trainer object for training.
required
Returns
Optimizer | None
The created optimizer.
get_trainer_cls
integrations.base.BasePlugin.get_trainer_cls(cfg)
Returns a custom class for the trainer.
Parameters
cfg
DictDefault
The global axolotl configuration.
required
Returns
Trainer | None
The first non-None trainer class returned by a plugin.
load_datasets
integrations.base.BasePlugin.load_datasets(cfg, preprocess=False)
Loads and preprocesses the dataset for training.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
preprocess
bool
Whether this is the preprocess step of the datasets.
False
Returns
dataset_meta
Union['TrainDatasetMeta', None]
The metadata for the training dataset.
post_lora_load
integrations.base.BasePlugin.post_lora_load(cfg, model)
Performs actions after LoRA weights are loaded.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
model
PreTrainedModel | PeftModel
The loaded model.
required
post_model_build
integrations.base.BasePlugin.post_model_build(cfg, model)
Performs actions after the model is built/loaded, but before any adapters are applied.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
model
PreTrainedModel
The loaded model.
required
post_model_load
integrations.base.BasePlugin.post_model_load(cfg, model)
Performs actions after the model is loaded.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
model
PreTrainedModel | PeftModel
The loaded model.
required
post_train
integrations.base.BasePlugin.post_train(cfg, model)
Performs actions after training is complete.
Parameters
cfg
DictDefault
The axolotl configuration.
required
model
PreTrainedModel | PeftModel
The loaded model.
required
post_train_unload
integrations.base.BasePlugin.post_train_unload(cfg)
Performs actions after training is complete and the model is unloaded.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
post_trainer_create
integrations.base.BasePlugin.post_trainer_create(cfg, trainer)
Performs actions after the trainer is created.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
trainer
Trainer
The trainer object for training.
required
pre_lora_load
integrations.base.BasePlugin.pre_lora_load(cfg, model)
Performs actions before LoRA weights are loaded.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
model
PreTrainedModel
The loaded model.
required
pre_model_load
integrations.base.BasePlugin.pre_model_load(cfg)
Performs actions before the model is loaded.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
register
integrations.base.BasePlugin.register(cfg)
Registers the plugin with the given configuration.
Parameters
cfg
DictDefault
The configuration for the plugin.
required
PluginManager
integrations.base.PluginManager()
The PluginManager class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.
Attributes
plugins
OrderedDict[str, BasePlugin]
An ordered mapping of the loaded plugins, keyed by plugin name.
Note
Key methods include:
- get_instance(): Static method to get the singleton instance of PluginManager.
- register(plugin_name: str): Registers a new plugin by its name.
- pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
Methods
add_callbacks_post_trainer
Calls the add_callbacks_post_trainer method of all registered plugins.
add_callbacks_pre_trainer
Calls the add_callbacks_pre_trainer method of all registered plugins.
create_lr_scheduler
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
create_optimizer
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
get_input_args
Returns a list of Pydantic classes for all registered plugins’ input arguments.
get_instance
Returns the singleton instance of PluginManager. If the instance doesn’t exist, it creates a new one.
get_trainer_cls
Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
load_datasets
Calls the load_datasets method of each registered plugin.
post_lora_load
Calls the post_lora_load method of all registered plugins.
post_model_build
Calls the post_model_build method of all registered plugins after the model has been built / loaded, but before any adapters have been applied.
post_model_load
Calls the post_model_load method of all registered plugins after the model has been loaded, inclusive of any adapters.
post_train
Calls the post_train method of all registered plugins.
post_train_unload
Calls the post_train_unload method of all registered plugins.
post_trainer_create
Calls the post_trainer_create method of all registered plugins.
pre_lora_load
Calls the pre_lora_load method of all registered plugins.
pre_model_load
Calls the pre_model_load method of all registered plugins.
register
Registers a new plugin by its name.
add_callbacks_post_trainer
integrations.base.PluginManager.add_callbacks_post_trainer(cfg, trainer)
Calls the add_callbacks_post_trainer method of all registered plugins.
Parameters
cfg
DictDefault
The configuration for the plugins.
required
trainer
Trainer
The trainer object for training.
required
Returns
list[Callable]
A list of callback functions to be added to the TrainingArgs.
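For a hook that returns a list, one plausible reading of "calls the method of all registered plugins" is that each plugin's callbacks are concatenated in registration order. The sketch below assumes exactly that. `collect_callbacks` and the two stand-in plugin classes are hypothetical names, strings stand in for real callback objects, and the actual manager may treat empty or `None` results differently.

```python
from collections import OrderedDict


class PluginA:
    """Hypothetical plugin that contributes one callback."""

    def add_callbacks_post_trainer(self, cfg, trainer):
        return ["callback_a"]  # strings stand in for real callback objects


class PluginB:
    """Hypothetical plugin that contributes nothing."""

    def add_callbacks_post_trainer(self, cfg, trainer):
        return []


def collect_callbacks(plugins, cfg, trainer):
    """Concatenate the callbacks returned by every plugin, in registration order."""
    callbacks = []
    for plugin in plugins.values():
        plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
        if plugin_callbacks:  # assumption: None or empty results are skipped
            callbacks.extend(plugin_callbacks)
    return callbacks


plugins = OrderedDict(a=PluginA(), b=PluginB())
callbacks = collect_callbacks(plugins, cfg={}, trainer=None)
print(callbacks)
```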
add_callbacks_pre_trainer
integrations.base.PluginManager.add_callbacks_pre_trainer(cfg, model)
Calls the add_callbacks_pre_trainer method of all registered plugins.
Parameters
cfg
DictDefault
The configuration for the plugins.
required
model
PreTrainedModel
The loaded model.
required
Returns
list[Callable]
A list of callback functions to be added to the TrainingArgs.
create_lr_scheduler
integrations.base.PluginManager.create_lr_scheduler(
trainer,
optimizer,
num_training_steps,
)
Calls the create_lr_scheduler method of all registered plugins and returns
the first non-None scheduler.
Parameters
trainer
Trainer
The trainer object for training.
required
optimizer
Optimizer
The optimizer for training.
required
num_training_steps
int
Total number of training steps
required
Returns
LRScheduler | None
The created learning rate scheduler, or None if not found.
create_optimizer
integrations.base.PluginManager.create_optimizer(trainer)
Calls the create_optimizer method of all registered plugins and returns
the first non-None optimizer.
Parameters
trainer
Trainer
The trainer object for training.
required
Returns
Optimizer | None
The created optimizer, or None if none was found.
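The "first non-None" rule can be illustrated with a short sketch. It is not the library's implementation: `first_optimizer` and both plugin classes are hypothetical, strings stand in for optimizer objects, and `cfg` is passed explicitly here even though the manager's own signature takes only `trainer`.

```python
class NoOptimizerPlugin:
    """Hypothetical plugin that does not provide an optimizer."""

    def create_optimizer(self, cfg, trainer):
        return None


class CustomOptimizerPlugin:
    """Hypothetical plugin that provides an optimizer (a string stands in for it)."""

    def __init__(self, label):
        self.label = label

    def create_optimizer(self, cfg, trainer):
        return f"optimizer-from-{self.label}"


def first_optimizer(plugins, cfg, trainer):
    """Return the first non-None result in registration order, else None."""
    for plugin in plugins:
        optimizer = plugin.create_optimizer(cfg, trainer)
        if optimizer is not None:
            return optimizer
    return None


plugins = [NoOptimizerPlugin(), CustomOptimizerPlugin("b"), CustomOptimizerPlugin("c")]
chosen = first_optimizer(plugins, cfg={}, trainer=None)
print(chosen)
```

Under this rule the order of registration matters: a later plugin's optimizer is ignored once an earlier plugin has returned one.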
get_instance
integrations.base.PluginManager.get_instance()
Returns the singleton instance of PluginManager. If the instance doesn’t
exist, it creates a new one.
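The create-on-first-use behaviour described here is commonly written as a class-level attribute guarded by a static accessor. The following is a sketch of that pattern under that assumption. `ManagerSketch` and `_instance` are illustrative names, not the real class or its internals.

```python
from collections import OrderedDict


class ManagerSketch:
    """Hypothetical stand-in showing lazy singleton creation."""

    _instance = None  # assumption: the singleton is cached on the class

    def __init__(self):
        self.plugins = OrderedDict()

    @staticmethod
    def get_instance():
        # Create the instance on first access, then keep returning the same one.
        if ManagerSketch._instance is None:
            ManagerSketch._instance = ManagerSketch()
        return ManagerSketch._instance


first = ManagerSketch.get_instance()
second = ManagerSketch.get_instance()
print(first is second)  # prints True
```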
get_trainer_cls
integrations.base.PluginManager.get_trainer_cls(cfg)
Calls the get_trainer_cls method of all registered plugins and returns the
first non-None trainer class.
Parameters
cfg
DictDefault
The configuration for the plugins.
required
Returns
Trainer | None
The first non-None trainer class returned by a plugin.
load_datasets
integrations.base.PluginManager.load_datasets(cfg, preprocess=False)
Calls the load_datasets method of each registered plugin.
Parameters
cfg
DictDefault
The configuration for the plugins.
required
preprocess
bool
Whether this is the preprocess step of the datasets.
False
Returns
Union['TrainDatasetMeta', None]
The dataset metadata loaded from all registered plugins.
post_lora_load
integrations.base.PluginManager.post_lora_load(cfg, model)
Calls the post_lora_load method of all registered plugins.
Parameters
cfg
DictDefault
The configuration for the plugins.
required
model
PreTrainedModel | PeftModel
The loaded model.
required
post_model_build
integrations.base.PluginManager.post_model_build(cfg, model)
Calls the post_model_build method of all registered plugins after the
model has been built / loaded, but before any adapters have been applied.
Parameters
cfg
DictDefault
The configuration for the plugins.
required
model
PreTrainedModel
The loaded model.
required
post_model_load
integrations.base.PluginManager.post_model_load(cfg, model)
Calls the post_model_load method of all registered plugins after the model
has been loaded, inclusive of any adapters.
Parameters
cfg
DictDefault
The configuration for the plugins.
required
model
PreTrainedModel | PeftModel
The loaded model.
required
post_train
integrations.base.PluginManager.post_train(cfg, model)
Calls the post_train method of all registered plugins.
Parameters
cfg
DictDefault
The configuration for the plugins.
required
model
PreTrainedModel | PeftModel
The loaded model.
required
post_train_unload
integrations.base.PluginManager.post_train_unload(cfg)
Calls the post_train_unload method of all registered plugins.
Parameters
cfg
DictDefault
The configuration for the plugins.
required
post_trainer_create
integrations.base.PluginManager.post_trainer_create(cfg, trainer)
Calls the post_trainer_create method of all registered plugins.
Parameters
cfg
DictDefault
The configuration for the plugins.
required
trainer
Trainer
The trainer object for training.
required
pre_lora_load
integrations.base.PluginManager.pre_lora_load(cfg, model)
Calls the pre_lora_load method of all registered plugins.
Parameters
cfg
DictDefault
The configuration for the plugins.
required
model
PreTrainedModel
The loaded model.
required
pre_model_load
integrations.base.PluginManager.pre_model_load(cfg)
Calls the pre_model_load method of all registered plugins.
Parameters
cfg
DictDefault
The configuration for the plugins.
required
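Hooks that return nothing, such as `pre_model_load`, amount to a broadcast over the registered plugins. The sketch below assumes the manager iterates its ordered `plugins` mapping and calls each plugin in turn. `BroadcastSketch` and `RecordingPlugin` are hypothetical names used only for illustration.

```python
from collections import OrderedDict


class RecordingPlugin:
    """Hypothetical plugin that appends its name to a shared log when called."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def pre_model_load(self, cfg):
        self.log.append(self.name)


class BroadcastSketch:
    """Hypothetical stand-in for the manager's broadcast-style hooks."""

    def __init__(self):
        self.plugins = OrderedDict()

    def pre_model_load(self, cfg):
        # Assumption: plugins are invoked in registration order.
        for plugin in self.plugins.values():
            plugin.pre_model_load(cfg)


log = []
manager = BroadcastSketch()
manager.plugins["first"] = RecordingPlugin("first", log)
manager.plugins["second"] = RecordingPlugin("second", log)
manager.pre_model_load(cfg={})
print(log)
```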
register
integrations.base.PluginManager.register(plugin_name)
Registers a new plugin by its name.
Parameters
plugin_name
str
The name of the plugin to be registered.
required
Raises
ImportError
If the plugin module cannot be imported.
Functions
load_plugin
Loads a plugin based on the given plugin name.
load_plugin
integrations.base.load_plugin(plugin_name)
Loads a plugin based on the given plugin name.
The plugin name should be in the format “module_name.class_name”. This function
splits the plugin name into module and class, imports the module, retrieves the
class from the module, and creates an instance of the class.
Parameters
plugin_name
str
The name of the plugin to be loaded. The name should be in the format “module_name.class_name”.
required
Returns
BasePlugin
An instance of the loaded plugin.
Raises
ImportError
If the plugin module cannot be imported.
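The steps described for load_plugin (split the name, import the module, look up the class, instantiate it) can be sketched with the standard library's `importlib`. This is an illustration, not the library's code: `load_plugin_sketch` is a hypothetical name, and splitting on the last dot is an assumption so that dotted module paths work.

```python
import importlib
from collections import OrderedDict


def load_plugin_sketch(plugin_name: str):
    """Instantiate the class named by a "module_name.class_name" string."""
    # Assumption: everything before the last dot is the module path.
    module_name, class_name = plugin_name.rsplit(".", 1)
    module = importlib.import_module(module_name)  # raises ImportError if missing
    plugin_class = getattr(module, class_name)
    return plugin_class()


# A standard-library class is used here only so the example runs anywhere.
instance = load_plugin_sketch("collections.OrderedDict")
print(type(instance).__name__)

try:
    load_plugin_sketch("no_such_module_for_this_sketch.Plugin")
except ImportError as exc:
    print("import failed:", type(exc).__name__)
```

In this sketch a missing class on an importable module surfaces as `AttributeError` from `getattr`, not `ImportError`. Whether the real function behaves the same way is not stated in this reference.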