monkeypatch.torchao_optim

Patch for torchao optim subclasses that crash under torch.compile.

torchao 0.17.0 (PR #3934) added an “appearance dtype” to OptimState4bit, OptimState8bit, and OptimStateFp8, which lets them report as e.g. bf16 while internally storing quantized codes. This introduced three issues:

  1. aten.view.default doesn’t propagate the appearance dtype, so views (e.g. from DTensor.from_local()) revert to float32 while the base is bf16. torch.compile’s fake-tensor metadata check then fails (AssertionError: torch.bfloat16 != torch.float32).

  2. aten._to_copy doesn’t clone internal tensors, so same-device dtype changes (e.g. .float()) create an accidental view relationship between the input and the output, which triggers the same dtype mismatch as in issue 1.

  3. aten.view.dtype is unimplemented, so if the dtype-view path is taken, it crashes with NotImplementedError.

Fix: propagate the appearance dtype in aten.view.default (the primary fix), clone the internal tensors in aten._to_copy, and register an implementation for aten.view.dtype.
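
The first issue and its fix can be illustrated with a dependency-free toy model. This is only a sketch, not torchao's implementation: ToyOptimState, buggy_view, fixed_view, and check_view_metadata are hypothetical names, and dtypes are represented as plain strings. It shows why a view that drops the appearance dtype fails a base-versus-view consistency check of the kind torch.compile performs, and why carrying the dtype over avoids the failure.

```python
from dataclasses import dataclass


@dataclass
class ToyOptimState:
    """Hypothetical stand-in for a quantized optimizer-state subclass."""

    codes: list              # quantized payload (internal storage)
    dtype: str = "float32"   # "appearance" dtype reported to callers


def buggy_view(state: ToyOptimState) -> ToyOptimState:
    # Mimics issue 1: the view is rebuilt with the default dtype,
    # so the appearance dtype of the base is lost.
    return ToyOptimState(codes=state.codes)


def fixed_view(state: ToyOptimState) -> ToyOptimState:
    # Mimics the fix: carry the base's appearance dtype over to the view.
    return ToyOptimState(codes=state.codes, dtype=state.dtype)


def check_view_metadata(base: ToyOptimState, view: ToyOptimState) -> None:
    # Simplified analogue of a base-versus-view dtype consistency check.
    if base.dtype != view.dtype:
        raise AssertionError(f"{base.dtype} != {view.dtype}")


base = ToyOptimState(codes=[1, 2, 3], dtype="bfloat16")

# The fixed view reports the same dtype as its base, so the check passes.
check_view_metadata(base, fixed_view(base))

# The buggy view reverts to float32, so the check raises AssertionError.
try:
    check_view_metadata(base, buggy_view(base))
    buggy_view_failed = False
except AssertionError:
    buggy_view_failed = True
```

In the real subclasses the equivalent change lives in the handler for aten.view.default; the toy model only captures the metadata bookkeeping, not the dispatch mechanism.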

Upstream fix: https://github.com/pytorch/ao/pull/4216

Functions

| Name | Description |
| --- | --- |
| patch_torchao_optim_state_8bit | Patch torchao optim subclasses for torch.compile compatibility. |

patch_torchao_optim_state_8bit

monkeypatch.torchao_optim.patch_torchao_optim_state_8bit()

Patch torchao optim subclasses for torch.compile compatibility.