Updates for CUTLASS 3.5.0 (#1468)

This commit is contained in:
Vijay Thakkar
2024-04-11 21:33:40 -04:00
committed by GitHub
parent a40e08e9d5
commit 7d49e6c7e2
171 changed files with 7526 additions and 1888 deletions

View File

@ -124,7 +124,6 @@ class PyTorchExtensionTest(unittest.TestCase):
dtype = torch.float16
plan = cutlass.op.Gemm(element=dtype, layout=cutlass.LayoutType.RowMajor)
plan.activation = cutlass.epilogue.relu
op = plan.construct()
with tempfile.TemporaryDirectory() as tmpdir:
@ -132,7 +131,7 @@ class PyTorchExtensionTest(unittest.TestCase):
A, B, C, _ = _initialize(dtype, 1024, 256, 512)
D_ref = torch.nn.functional.relu(A @ B)
D_ref = A @ B
D = mod.run(A, B)
assert torch.allclose(D, D_ref)
@ -147,7 +146,7 @@ class PyTorchExtensionTest(unittest.TestCase):
alpha = 2.0
beta = -1.0
D_ref = torch.nn.functional.relu((A @ B) * alpha + (beta * C))
D_ref = (A @ B) * alpha + (beta * C)
D = mod.run(A, B, C, alpha, beta)
assert torch.allclose(D, D_ref)