utils.generation.sft

Sample generation utilities for SFT/Pretrain training.

Functions

| Name | Description |
| --- | --- |
| format_generation_for_logging | Format a generation sample for pretty logging. |
| generate_samples | Generate samples from the model during training for monitoring. |

format_generation_for_logging

utils.generation.sft.format_generation_for_logging(sample, sample_idx, step)

Format a generation sample for pretty logging.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict | Dict with ‘prompt’, ‘generated’, and ‘full_text’ keys | required |
| sample_idx | int | Index of the sample | required |
| step | int | Current training step | required |

Returns

| Name | Type | Description |
| --- | --- | --- |
| | tuple[str, str] | Tuple of (console_text, wandb_text) |
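
To illustrate the documented input and output shapes, here is a minimal sketch of a formatter with the same signature. The function name `format_sample_sketch` and the exact layout of both strings are assumptions made for illustration; the library's actual formatting may differ.

```python
def format_sample_sketch(sample: dict, sample_idx: int, step: int) -> tuple[str, str]:
    """Hypothetical stand-in for format_generation_for_logging (layout is assumed)."""
    header = f"Sample {sample_idx} (step {step})"

    # Plain-text block intended for the console.
    console_text = "\n".join(
        [
            "=" * 40,
            header,
            "-" * 40,
            f"Prompt: {sample['prompt']}",
            f"Generated: {sample['generated']}",
            "=" * 40,
        ]
    )

    # Markdown-flavoured text intended for an experiment tracker such as W&B.
    wandb_text = (
        f"**{header}**\n\n"
        f"**Prompt:** {sample['prompt']}\n\n"
        f"**Generated:** {sample['generated']}"
    )
    return console_text, wandb_text


example = {
    "prompt": "Once upon",
    "generated": " a time",
    "full_text": "Once upon a time",
}
console_text, wandb_text = format_sample_sketch(example, sample_idx=0, step=100)
print(console_text)
```

Returning two strings keeps the console output free of markup while still allowing richer rendering in the tracker.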

generate_samples

utils.generation.sft.generate_samples(
    model,
    tokenizer,
    dataloader,
    num_generation_samples=3,
    max_new_tokens=50,
    temperature=0.7,
    top_p=None,
    top_k=None,
    do_sample=True,
    prompt_ratio=0.5,
)

Generate samples from the model during training for monitoring.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | torch.nn.Module | The model to generate from | required |
| tokenizer | Any | The tokenizer to use for encoding/decoding | required |
| dataloader | Any | Dataloader to sample prompts from | required |
| num_generation_samples | int | Number of samples to generate | 3 |
| max_new_tokens | int | Maximum new tokens to generate | 50 |
| temperature | float | Sampling temperature (0.0 = greedy) | 0.7 |
| top_p | Optional[float] | Nucleus sampling parameter | None |
| top_k | Optional[int] | Top-k sampling parameter | None |
| do_sample | bool | Whether to use sampling vs greedy decoding | True |
| prompt_ratio | float | Ratio of sequence to use as prompt (0.0-1.0) | 0.5 |
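
The sketch below shows one plausible reading of how `prompt_ratio` and the sampling parameters could be applied. Both helpers (`split_prompt`, `build_generation_kwargs`) are hypothetical names, and the rounding rule, the "at least one prompt token" floor, and the mapping onto Hugging Face-style `generate` keyword arguments are assumptions rather than the library's confirmed behavior.

```python
from typing import Optional


def split_prompt(token_ids: list[int], prompt_ratio: float = 0.5) -> tuple[list[int], list[int]]:
    """Split a token sequence into (prompt, remainder) by ratio (assumed rule)."""
    if not 0.0 <= prompt_ratio <= 1.0:
        raise ValueError("prompt_ratio must be between 0.0 and 1.0")
    # Assumption: truncate toward zero, but keep at least one prompt token.
    cut = max(1, int(len(token_ids) * prompt_ratio))
    return token_ids[:cut], token_ids[cut:]


def build_generation_kwargs(
    max_new_tokens: int = 50,
    temperature: float = 0.7,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    do_sample: bool = True,
) -> dict:
    """Translate the documented parameters into generate-style kwargs (assumed mapping)."""
    # A temperature of 0.0 is documented as greedy, so sampling is switched off.
    use_sampling = do_sample and temperature > 0.0
    kwargs = {"max_new_tokens": max_new_tokens, "do_sample": use_sampling}
    if use_sampling:
        kwargs["temperature"] = temperature
        # Only forward the optional filters that were actually set.
        if top_p is not None:
            kwargs["top_p"] = top_p
        if top_k is not None:
            kwargs["top_k"] = top_k
    return kwargs


prompt_ids, held_out_ids = split_prompt(list(range(10)), prompt_ratio=0.5)
print(prompt_ids, held_out_ids)
print(build_generation_kwargs(temperature=0.0))
```

Dropping unset `top_p` and `top_k` values, and omitting sampling options entirely under greedy decoding, avoids passing parameters that would have no effect.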

Returns

| Name | Type | Description |
| --- | --- | --- |
| | List[dict] | List of dicts with ‘prompt’, ‘generated’, and ‘full_text’ keys |
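
Because the returned dicts carry the same keys that `format_generation_for_logging` expects, the two functions can be chained. The following is a rough sketch of that wiring: `collect_log_rows` is a hypothetical helper, and the formatter is passed in as a callable so the example runs without the library installed.

```python
from typing import Callable


def collect_log_rows(
    samples: list[dict],
    step: int,
    formatter: Callable[[dict, int, int], tuple[str, str]],
    emit: Callable[[str], None] = print,
) -> list[str]:
    """Format each sample, emit the console text, and return the tracker texts."""
    rows = []
    for sample_idx, sample in enumerate(samples):
        # formatter follows the documented (sample, sample_idx, step) signature.
        console_text, wandb_text = formatter(sample, sample_idx, step)
        emit(console_text)
        rows.append(wandb_text)
    return rows


# Stand-in data shaped like the documented return value of generate_samples.
fake_samples = [
    {"prompt": "The sky is", "generated": " blue", "full_text": "The sky is blue"},
    {"prompt": "Two plus two", "generated": " is four", "full_text": "Two plus two is four"},
]


def fake_formatter(sample: dict, sample_idx: int, step: int) -> tuple[str, str]:
    # Placeholder formatter used only for this illustration.
    line = f"[step {step}] #{sample_idx}: {sample['full_text']}"
    return line, line


rows = collect_log_rows(fake_samples, step=200, formatter=fake_formatter)
```

In a real training loop, `fake_samples` would be replaced by the output of `generate_samples` and `fake_formatter` by `format_generation_for_logging`.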