Skip to content

Commit

Permalink
mx cleanup [2/x]: refactor mx gemm (pytorch#1593)
Browse files Browse the repository at this point in the history
* Update

[ghstack-poisoned]

* Update

[ghstack-poisoned]

* Update

[ghstack-poisoned]

* Update

[ghstack-poisoned]

* Update

[ghstack-poisoned]

* Update

[ghstack-poisoned]
  • Loading branch information
vkuzo authored Jan 24, 2025
1 parent 11440c2 commit 6b472e5
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 48 deletions.
31 changes: 22 additions & 9 deletions test/prototype/mx_formats/test_mx_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def run_around_tests():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("input_shape", [(2, 4), (1, 2, 4), (1, 1, 2, 4)])
@pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)])
def test_linear_eager(elem_dtype, bias, input_shape):
"""
Smoke test for training linear module with mx weight
Expand All @@ -48,7 +48,7 @@ def test_linear_eager(elem_dtype, bias, input_shape):
grad_shape[-1] = 6

m = nn.Sequential(
nn.Linear(4, 6, bias=bias, device="cuda"),
nn.Linear(8, 6, bias=bias, device="cuda"),
)
m_mx = copy.deepcopy(m)
block_size = 2
Expand All @@ -71,7 +71,7 @@ def test_linear_eager(elem_dtype, bias, input_shape):
if elem_dtype is torch.float8_e4m3fn:
assert y_sqnr >= 18.0
assert w_g_sqnr >= 18.0
assert x_g_sqnr >= 14.0
assert x_g_sqnr >= 12.0
else:
assert y_sqnr >= 8.0
assert w_g_sqnr >= 10.0
Expand Down Expand Up @@ -101,28 +101,41 @@ def test_activation_checkpointing():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
@pytest.mark.parametrize("bias", [False, True])
def test_linear_compile(elem_dtype, bias):
# TODO(future PR): figure out why torch.compile does not match eager when
# autocast is on
@pytest.mark.parametrize(
"use_autocast",
[
False,
],
)
def test_linear_compile(elem_dtype, bias, use_autocast):
"""
Verify that compile does not change numerics of MX linear fw + bw
"""
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
input_shape = (2, 4)
grad_shape = (2, 6)
M, K, N = 4, 8, 6
input_shape = (M, K)
grad_shape = (M, N)
m_mx = nn.Sequential(
nn.Linear(4, 6, bias=bias, device="cuda"),
nn.Linear(K, N, bias=bias, device="cuda"),
)
block_size = 2
swap_linear_with_mx_linear(m_mx, elem_dtype, block_size)
m_mx_c = copy.deepcopy(m_mx)
m_mx_c = torch.compile(m_mx_c, fullgraph=True)
m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")

x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
x = copy.deepcopy(x_ref)
g = torch.randn(*grad_shape, device="cuda")

with torch.autocast("cuda", dtype=torch.bfloat16):
if use_autocast:
with torch.autocast("cuda", dtype=torch.bfloat16):
y_ref = m_mx(x_ref)
y = m_mx_c(x)
else:
y_ref = m_mx(x_ref)
y = m_mx_c(x)
torch.testing.assert_close(y_ref, y, atol=0, rtol=0)
Expand Down
3 changes: 2 additions & 1 deletion test/prototype/mx_formats/test_mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ def test_transpose(elem_dtype, fp4_triton):
if elem_dtype != DTYPE_FP4 and fp4_triton:
pytest.skip("unsupported configuration")

tensor_hp = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
M, K = 128, 256
block_size = 32
tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
config.use_fp4_custom_triton_dequant_kernel = fp4_triton
tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t()
Expand Down
101 changes: 75 additions & 26 deletions torchao/prototype/mx_formats/mx_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,81 @@
# LICENSE file in the root directory of this source tree.

"""
Defines the UX for converting a model to use mx weights
For now, this is a module swap for speed of iteration.
Eventually we plan to move this to a tensor subclass weight wrapper for
inference, and to a tensor subclass weight wrapper + module hooks for training.
Defines the prototype UX for converting a model to use mx weights
"""

from typing import Any

import torch
import torch.nn.functional as F

from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx
from torchao.prototype.mx_formats.mx_tensor import MXTensor


@torch._dynamo.allow_in_graph
class NoopFwToMXBw(torch.autograd.Function):
"""
Forward: no-op
Backward: cast grad to MX
"""
class mx_mm(torch.autograd.Function):
# There are three gemms in a forward + backward of a Linear layer:
#
# 1. input @ weight_t = output (forward pass)
# 2. grad_output @ weight = grad_input (backward pass)
# 3. input_t @ grad_output = grad_weight (backward pass)

@staticmethod
def forward(ctx, x, elem_dtype, block_size):
def forward(
ctx,
input_hp: torch.Tensor,
weight_hp: torch.Tensor,
elem_dtype: Any,
block_size: int,
):
ctx.save_for_backward(input_hp, weight_hp)
ctx.elem_dtype = elem_dtype
ctx.block_size = block_size
return x

# input @ weight_t = output
input_orig_shape = input_hp.shape
input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])

input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size)
weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size)
output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
output = output.reshape(*input_orig_shape[:-1], output.shape[-1])

return output

@staticmethod
def backward(ctx, g):
scale, data = to_mx(g, ctx.elem_dtype, ctx.block_size)
return (
MXTensor(scale, data, ctx.elem_dtype, ctx.block_size, g.dtype),
None,
None,
def backward(ctx, grad_output_hp: torch.Tensor):
input_hp, weight_hp = ctx.saved_tensors
weight_hp_t_c = weight_hp.t().contiguous()
elem_dtype = ctx.elem_dtype
block_size = ctx.block_size

grad_output_orig_shape = grad_output_hp.shape
grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])

input_hp_orig_shape = input_hp.shape
input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1])

# grad_output @ weight = grad_input
grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size)
weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size)
grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
grad_input = grad_input.reshape(
*grad_output_orig_shape[:-1], grad_input.shape[-1]
)

# input_t @ grad_output = grad_weight
grad_output_mx_dim1 = MXTensor.to_mx(
grad_output_hp_r.t().contiguous(), elem_dtype, block_size
)
input_t_mx_dim0_tmp = MXTensor.to_mx(
input_hp_r.t().contiguous(), elem_dtype, block_size
)
input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)

return grad_input, grad_weight, None, None


class MXLinear(torch.nn.Linear):
"""
Expand All @@ -59,16 +98,26 @@ def from_float(cls, mod, elem_dtype, block_size):
return mod

def forward(self, x):
x_mx = MXTensor.to_mx(x, self.elem_dtype, self.block_size)
w_mx = MXTensor.to_mx(self.weight, self.elem_dtype, self.block_size)
y = F.linear(x_mx, w_mx, self.bias)
y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size)
if torch.is_autocast_enabled():
# special case autocast
autocast_dtype = torch.get_autocast_dtype("cuda")
x = x.to(autocast_dtype)
w = self.weight.to(autocast_dtype)
else:
w = self.weight

y = mx_mm.apply(x, w, self.elem_dtype, self.block_size)
if self.bias is not None:
y = y + self.bias
return y


class MXInferenceLinear(torch.nn.Linear):
"""
Inference version of MXLinear, with the weight pre-quantized to MX.
Note: this is weight-only quantization, with the gemm being executed
in high precision.
"""

@classmethod
Expand All @@ -84,8 +133,8 @@ def from_float(cls, mod, elem_dtype, block_size):
# TODO(future PR): set to new_mod.weight directly, will need to work
# through some errors
new_mod.weight_mx = MXTensor.to_mx(
mod.weight.t().contiguous(), elem_dtype, block_size=block_size
).t()
mod.weight, elem_dtype, block_size=block_size
)
new_mod.bias = mod.bias
new_mod.elem_dtype = elem_dtype
return new_mod
Expand Down
15 changes: 3 additions & 12 deletions torchao/prototype/mx_formats/mx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,13 @@ def mx_mm(aten_op, args, kwargs=None):
assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
a_hp = a.to_dtype(a._orig_dtype)
b_hp = b.to_dtype(b._orig_dtype)
# assert memory layout we expect to be required in hardware
assert a_hp.is_contiguous()
assert b_hp.t().is_contiguous()
res = aten_op(a_hp, b_hp)
return res


@implements([aten.addmm.default])
def mx_addmm(aten_op, args, kwargs=None):
a = args[0]
b = args[1]
c = args[2]
assert isinstance(b, MXTensor) and isinstance(c, MXTensor)
b_hp = b.to_dtype(b._orig_dtype)
c_hp = c.to_dtype(c._orig_dtype)
res = aten_op(a, b_hp, c_hp)
return res


@implements([aten.t.default])
def mx_t(aten_op, args, kwargs=None):
# For now, only transpose(input, 0, 1) is supported.
Expand Down
7 changes: 7 additions & 0 deletions torchao/prototype/mx_formats/mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,20 @@ def __new__(
new_size = data_bits.size()
if elem_dtype == DTYPE_FP4:
# set the tensor size to what it would be without 2x4 packing
# Note: `is_contiguous` is going to return True for a tensor of size
# (M, 1) regardless or the order of dims, so this logic is currently
# broken for tensors of size (M, 1) or (1, M). Leaving broken until
# a time when fixing this becomes important.
new_size = tensor_size_fp4x2_to_hp(
new_size,
data_bits.is_contiguous(),
)
self = torch.Tensor._make_wrapper_subclass(
cls,
new_size,
strides=data_bits.stride(),
storage_offset=data_bits.storage_offset(),
layout=data_bits.layout,
dtype=orig_dtype,
device=data_bits.device,
)
Expand Down

0 comments on commit 6b472e5

Please sign in to comment.