utils.generation.sft
Sample generation utilities for SFT/Pretrain training.
Functions
generate_samples
utils.generation.sft.generate_samples(
model,
tokenizer,
dataloader,
num_generation_samples=3,
max_new_tokens=50,
temperature=0.7,
top_p=None,
top_k=None,
do_sample=True,
prompt_ratio=0.5,
)
Generate samples from the model during training for monitoring.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| model | torch.nn.Module | The model to generate from | required |
| tokenizer | Any | The tokenizer to use for encoding/decoding | required |
| dataloader | Any | Dataloader to sample prompts from | required |
| num_generation_samples | int | Number of samples to generate | 3 |
| max_new_tokens | int | Maximum new tokens to generate | 50 |
| temperature | float | Sampling temperature (0.0 = greedy) | 0.7 |
| top_p | Optional[float] | Nucleus sampling parameter | None |
| top_k | Optional[int] | Top-k sampling parameter | None |
| do_sample | bool | Whether to use sampling vs greedy decoding | True |
| prompt_ratio | float | Ratio of sequence to use as prompt (0.0-1.0) | 0.5 |
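
The interaction between `prompt_ratio` and the decoding parameters can be illustrated with a small sketch. This is not the library's implementation: `split_prompt` and `build_generation_kwargs` are hypothetical helpers, the "keep at least one prompt token" rule is an assumption, and the keyword names in the returned dict assume a Hugging Face style `generate` interface.

```python
from typing import Any, Dict, List, Optional, Tuple


def split_prompt(
    token_ids: List[int], prompt_ratio: float = 0.5
) -> Tuple[List[int], List[int]]:
    """Split a token sequence into (prompt, held-out remainder)."""
    if not 0.0 <= prompt_ratio <= 1.0:
        raise ValueError("prompt_ratio must be in [0.0, 1.0]")
    # Assumption: keep at least one token so there is something to condition on.
    prompt_len = max(1, int(len(token_ids) * prompt_ratio))
    return token_ids[:prompt_len], token_ids[prompt_len:]


def build_generation_kwargs(
    max_new_tokens: int = 50,
    temperature: float = 0.7,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    do_sample: bool = True,
) -> Dict[str, Any]:
    """Collect decoding options, dropping any that are unset."""
    # Assumption: temperature 0.0 or do_sample=False both select greedy decoding.
    sampling = do_sample and temperature > 0.0
    kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "do_sample": sampling,
    }
    if sampling:
        kwargs["temperature"] = temperature
        if top_p is not None:
            kwargs["top_p"] = top_p
        if top_k is not None:
            kwargs["top_k"] = top_k
    return kwargs


# With the defaults, the first half of a 10-token sequence becomes the prompt.
prompt_ids, remainder_ids = split_prompt(list(range(10)), prompt_ratio=0.5)
default_kwargs = build_generation_kwargs()
greedy_kwargs = build_generation_kwargs(temperature=0.0)
```

Under these assumptions, sampling-only options such as `top_p` and `top_k` are omitted entirely when decoding is greedy, which avoids passing options that would have no effect.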
Returns
| Type | Description |
|---|---|
| List[dict] | List of dicts with 'prompt', 'generated', and 'full_text' keys |
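
As a rough sketch of how the returned list might be consumed for monitoring, the snippet below formats each entry for a training log. The `format_samples` helper, the `step` argument, and the sample data are hypothetical; only the three dict keys come from the documentation above.

```python
from typing import Dict, List


def format_samples(samples: List[Dict[str, str]], step: int) -> str:
    """Render generated samples as a plain-text block for a training log."""
    lines = [f"=== generation samples @ step {step} ==="]
    for i, sample in enumerate(samples, start=1):
        # repr() makes leading spaces and newlines in the text visible.
        lines.append(f"[{i}] prompt:    {sample['prompt']!r}")
        lines.append(f"[{i}] generated: {sample['generated']!r}")
    return "\n".join(lines)


# Hypothetical data shaped like the documented return value.
samples = [
    {
        "prompt": "The capital of France",
        "generated": " is Paris.",
        "full_text": "The capital of France is Paris.",
    }
]
report = format_samples(samples, step=100)
print(report)
```

Printing the prompt and the continuation on separate lines makes it easier to see at a glance where the model's own output begins.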