v4.2 tag release. (#2638)
This commit is contained in:
@ -985,33 +985,38 @@ ${compile_guard_end}
|
||||
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
||||
|
||||
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
|
||||
is_no_smem_epilogue = operation.epilogue_schedule in [EpilogueScheduleType.NoSmemWarpSpecialized1Sm, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
|
||||
grouped = is_grouped(operation.gemm_kind)
|
||||
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped):
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
if is_tma_epilogue(operation.epilogue_schedule):
|
||||
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]
|
||||
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped):
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
if is_tma_epilogue(operation.epilogue_schedule):
|
||||
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103:
|
||||
# SM103 FP4 Ultra
|
||||
is_sm103_fp4_ultra_1sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped)
|
||||
]
|
||||
is_sm103_fp4_ultra_2sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped)
|
||||
]
|
||||
if cta_n == 256 and is_sm103_fp4_ultra_1sm_kernel_schedule:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103:
|
||||
if is_tma_epilogue(operation.epilogue_schedule):
|
||||
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]
|
||||
if cta_n == 256 and is_sm103_fp4_ultra_2sm_kernel_schedule:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
|
||||
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
|
||||
if is_tma_epilogue(operation.epilogue_schedule):
|
||||
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
|
||||
|
||||
element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
|
||||
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
|
||||
|
||||
Reference in New Issue
Block a user