Updates for CUTLASS 3.5.0 (#1468)
This commit is contained in:
@ -124,7 +124,6 @@ class PyTorchExtensionTest(unittest.TestCase):
|
||||
|
||||
dtype = torch.float16
|
||||
plan = cutlass.op.Gemm(element=dtype, layout=cutlass.LayoutType.RowMajor)
|
||||
plan.activation = cutlass.epilogue.relu
|
||||
op = plan.construct()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@ -132,7 +131,7 @@ class PyTorchExtensionTest(unittest.TestCase):
|
||||
|
||||
A, B, C, _ = _initialize(dtype, 1024, 256, 512)
|
||||
|
||||
D_ref = torch.nn.functional.relu(A @ B)
|
||||
D_ref = A @ B
|
||||
D = mod.run(A, B)
|
||||
assert torch.allclose(D, D_ref)
|
||||
|
||||
@ -147,7 +146,7 @@ class PyTorchExtensionTest(unittest.TestCase):
|
||||
|
||||
alpha = 2.0
|
||||
beta = -1.0
|
||||
D_ref = torch.nn.functional.relu((A @ B) * alpha + (beta * C))
|
||||
D_ref = (A @ B) * alpha + (beta * C)
|
||||
D = mod.run(A, B, C, alpha, beta)
|
||||
assert torch.allclose(D, D_ref)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user