monkeypatch.torchao_optim
Patch for torchao optim subclasses that crash under torch.compile.
torchao 0.17.0 PR #3934 added an “appearance dtype” to `OptimState{4,8}bit` and `OptimStateFp8`, so these subclasses can report a dtype such as bf16 while internally storing quantized codes. This introduced three issues:

1. `aten.view.default` doesn’t propagate the appearance dtype, so views (e.g. from `DTensor.from_local()`) revert to float32 while the base is bf16. torch.compile’s fake-tensor metadata check then fails (`AssertionError: torch.bfloat16 != torch.float32`).
2. `aten._to_copy` doesn’t clone the internal tensors, so same-device dtype changes (e.g. `.float()`) create an accidental view relationship that hits the same dtype mismatch.
3. `aten.view.dtype` is unimplemented, so if the dtype-view path *is* taken, it crashes with `NotImplementedError`.
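The first issue can be reproduced in miniature without torch. This is a hypothetical, torch-free model: `ToyQuantState`, `view_default_unpatched`, and `check_view_metadata` are illustrative names, not torchao APIs. A view that shares the base's codes but falls back to the default dtype fails a simplified stand-in for torch.compile's metadata check.

```python
from dataclasses import dataclass


@dataclass
class ToyQuantState:
    """Hypothetical stand-in for an OptimState subclass: internal codes
    plus an appearance dtype that callers see."""
    codes: list  # stand-in for the internal quantized tensors
    scale: float
    dtype: str = "float32"  # appearance dtype; float32 is the fallback


def view_default_unpatched(s: ToyQuantState) -> ToyQuantState:
    # Models issue 1: the view shares the codes but drops the appearance dtype.
    return ToyQuantState(s.codes, s.scale)


def check_view_metadata(base: ToyQuantState, view: ToyQuantState) -> None:
    # Simplified analogue of the fake-tensor metadata check (hypothetical).
    assert base.dtype == view.dtype, f"{base.dtype} != {view.dtype}"


base = ToyQuantState(codes=[0, 127, 255], scale=0.01, dtype="bfloat16")
view = view_default_unpatched(base)
try:
    check_view_metadata(base, view)
except AssertionError as e:
    print("AssertionError:", e)  # AssertionError: bfloat16 != float32
```

The view and its base share the same codes, so nothing about the stored data differs; only the reported dtype diverges, which is exactly what the metadata check trips on.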
Fix: propagate the appearance dtype in `view.default` (the primary fix), clone the internal tensors in `_to_copy`, and register a handler for `view.dtype`.
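The three fixes can be sketched as a toy op table that a patch function rewrites in place. Everything here is hypothetical (`ToyQuantState`, `OPS`, `patch_toy_ops`, `dispatch`): it illustrates the shape of the patch, not torchao's actual `__torch_dispatch__` registrations or the upstream code. The `view.dtype` handler simply relabels the appearance dtype, which is an assumption about the intended behavior.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class ToyQuantState:
    """Hypothetical stand-in for an OptimState subclass."""
    codes: list  # stand-in for the internal quantized tensors
    scale: float
    dtype: str = "float32"  # appearance dtype reported to callers


def _view_default(s: ToyQuantState) -> ToyQuantState:
    # Unpatched: shares codes but drops the appearance dtype.
    return ToyQuantState(s.codes, s.scale)


def _to_copy(s: ToyQuantState, dtype: str) -> ToyQuantState:
    # Unpatched: the result aliases the source's codes.
    return ToyQuantState(s.codes, s.scale, dtype)


# Hypothetical op table standing in for the subclass's aten registrations;
# note there is no "view.dtype" entry yet.
OPS: Dict[str, Callable[..., ToyQuantState]] = {
    "view.default": _view_default,
    "_to_copy": _to_copy,
}
_PATCHED = False


def patch_toy_ops() -> None:
    """Apply the three fixes in place; repeated calls are no-ops."""
    global _PATCHED
    if _PATCHED:
        return
    # 1. view.default: carry the base's appearance dtype over to the view.
    OPS["view.default"] = lambda s: ToyQuantState(s.codes, s.scale, s.dtype)
    # 2. _to_copy: copy the codes so the result no longer aliases the source.
    OPS["_to_copy"] = lambda s, dtype: ToyQuantState(list(s.codes), s.scale, dtype)
    # 3. view.dtype: register a handler instead of leaving it unimplemented.
    #    Assumption: it only relabels the appearance dtype over shared codes.
    OPS["view.dtype"] = lambda s, dtype: ToyQuantState(s.codes, s.scale, dtype)
    _PATCHED = True


def dispatch(op: str, *args) -> ToyQuantState:
    # Unregistered ops raise, like an unimplemented aten op on the subclass.
    if op not in OPS:
        raise NotImplementedError(op)
    return OPS[op](*args)
```

The `_PATCHED` guard is a common monkeypatching convention that makes repeated calls harmless; whether `patch_torchao_optim_state_8bit` guards itself the same way is not stated on this page.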
Upstream fix: https://github.com/pytorch/ao/pull/4216
Functions
| Name | Description |
|---|---|
| patch_torchao_optim_state_8bit | Patch torchao optim subclasses for torch.compile compatibility. |
patch_torchao_optim_state_8bit
`monkeypatch.torchao_optim.patch_torchao_optim_state_8bit()`

Patch torchao optim subclasses for torch.compile compatibility.