[NVIDIA] Fix missing get_col_major_tma_aligned_tensor for Blackwell deepgemm in EpMoE (#8955)
@@ -55,6 +55,22 @@ if _use_aiter:
 logger = logging.getLogger(__name__)
 
 
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
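For reference only (not part of the diff): a minimal scalar sketch of the rounding rule implemented by _cast_to_e8m0_with_rounding_up. For normal, positive fp32 inputs it returns the biased exponent of the smallest power of two >= x, which is the byte value a UE8M0 scale stores; the subnormal (exp == 0) and exponent-ceiling (0xFE) guards mirror the tensor version above.

import struct

def e8m0_round_up_scalar(x: float) -> int:
    # Reference for one element: biased fp32 exponent of the smallest
    # power of two >= x, i.e. the rounded-up e8m0 scale byte.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    exp, mant = bits >> 23, bits & 0x7FFFFF
    if mant > 0 and exp != 0xFE and not (exp == 0 and mant <= 0x400000):
        exp += 1  # x is not a power of two: round the exponent up
    return exp

assert e8m0_round_up_scalar(2.0) == 128   # 2.0 == 2**1  -> biased 127 + 1
assert e8m0_round_up_scalar(3.0) == 129   # rounded up to 4.0 == 2**2
assert e8m0_round_up_scalar(0.75) == 127  # rounded up to 1.0 == 2**0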
@@ -204,10 +220,22 @@ class EPMoE(FusedMoE):
 
         dispose_tensor(hidden_states)
 
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = gateup_input_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
         # GroupGemm-0
         gateup_input_fp8 = (
             gateup_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+            (
+                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    gateup_input_scale
+                )
+            ),
         )
         num_groups, m, k = gateup_input_fp8[0].size()
         n = self.w13_weight.size(1)
@@ -215,7 +243,12 @@ class EPMoE(FusedMoE):
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+            gateup_input_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -246,6 +279,7 @@ class EPMoE(FusedMoE):
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output
 
@@ -253,13 +287,24 @@ class EPMoE(FusedMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
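A small standalone sanity check (hypothetical shapes, assuming the helper above is in scope; not part of the change): the UE8M0 path packs four uint8 exponents into one int32 via .view(torch.int), which is one reason the new assert requires the scale dimensions to be multiples of 4.

import torch

# Hypothetical (num_groups, s_mn, s_k) fp32 scales with s_k % 4 == 0.
scales = torch.full((2, 8, 4), 3.0)

packed = _cast_to_e8m0_with_rounding_up(scales)
# Each byte holds the rounded-up biased exponent (129 -> scale 4.0 for input 3.0);
# four bytes are reinterpreted as one int32, shrinking the last dim by 4x.
print(packed.dtype, packed.shape)  # torch.int32 torch.Size([2, 8, 1])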