[TPU] remove transpose ops in moe kernel (#18923)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
Chengji Yao
2025-05-29 16:00:11 -07:00
committed by GitHub
parent a521ef06e5
commit a1cc9f33a3
3 changed files with 8 additions and 13 deletions

View File

@ -26,7 +26,7 @@ TOP_KS = [2, 6]
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)