Support new DeepGEMM (#7172)
This commit is contained in:
@@ -1231,6 +1231,7 @@ class DeepEPMoE(EPMoE):
|
||||
down_input_scale,
|
||||
scale_block_size,
|
||||
masked_m,
|
||||
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
||||
)
|
||||
del gateup_output
|
||||
|
||||
@@ -1238,7 +1239,13 @@ class DeepEPMoE(EPMoE):
|
||||
n = self.w2_weight.size(1)
|
||||
down_input_fp8 = (
|
||||
down_input,
|
||||
deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
|
||||
(
|
||||
down_input_scale
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||
else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
|
||||
down_input_scale
|
||||
)
|
||||
),
|
||||
)
|
||||
down_output = torch.empty(
|
||||
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
|
||||
|
||||
@@ -584,6 +584,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
use_fp8=use_fp8,
|
||||
async_finish=not self.return_recv_hook,
|
||||
return_recv_hook=self.return_recv_hook,
|
||||
round_scale=deep_gemm_wrapper.DEEPGEMM_V202506,
|
||||
use_ue8m0=deep_gemm_wrapper.DEEPGEMM_V202506,
|
||||
)
|
||||
)
|
||||
return packed_recv_hidden, packed_recv_count, event, hook
|
||||
|
||||
@@ -12,6 +12,7 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.math_utils import ceil_div
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_fp8,
|
||||
@@ -518,10 +519,6 @@ def fused_moe_kernel(
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage1(
|
||||
topk_ids_ptr,
|
||||
|
||||
@@ -21,6 +21,12 @@ def _compute_enable_deep_gemm():
|
||||
|
||||
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
|
||||
|
||||
DEEPGEMM_V202506 = False
|
||||
try:
|
||||
from deep_gemm import fp8_gemm_nt
|
||||
|
||||
# They have not given a name to this breaking change
|
||||
DEEPGEMM_V202506 = True
|
||||
except ImportError:
|
||||
DEEPGEMM_V202506 = False
|
||||
|
||||
DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506
|
||||
|
||||
@@ -16,14 +16,24 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
if ENABLE_JIT_DEEPGEMM:
|
||||
import deep_gemm
|
||||
from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
|
||||
from deep_gemm import get_col_major_tma_aligned_tensor
|
||||
from deep_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
|
||||
)
|
||||
from deep_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
|
||||
)
|
||||
|
||||
if DEEPGEMM_V202506:
|
||||
from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
|
||||
from deep_gemm import (
|
||||
fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
|
||||
)
|
||||
from deep_gemm import (
|
||||
m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
|
||||
)
|
||||
else:
|
||||
from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
|
||||
from deep_gemm import get_col_major_tma_aligned_tensor
|
||||
from deep_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
|
||||
)
|
||||
from deep_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
|
||||
)
|
||||
|
||||
|
||||
def grouped_gemm_nt_f8f8bf16_masked(
|
||||
|
||||
@@ -765,7 +765,15 @@ def prepare_block_fp8_matmul_inputs(
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1]
|
||||
assert A.is_contiguous()
|
||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||
|
||||
if As.dtype == torch.float:
|
||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||
elif Bs.dtype == torch.int:
|
||||
assert (
|
||||
triton.cdiv(triton.cdiv(A.shape[-1], block_k), 4) == As.shape[-1]
|
||||
), f"{A.shape=} {As.shape=} {block_size=}"
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
M = A.numel() // A.shape[-1]
|
||||
|
||||
@@ -773,8 +781,17 @@ def prepare_block_fp8_matmul_inputs(
|
||||
assert B.is_contiguous()
|
||||
assert Bs.ndim == 2
|
||||
N, K = B.shape
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
if Bs.dtype == torch.float:
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
elif Bs.dtype == torch.int:
|
||||
assert N == Bs.shape[0], f"{B.shape=} {Bs.shape=} {block_size=}"
|
||||
assert (
|
||||
triton.cdiv(triton.cdiv(K, block_k), 4) == Bs.shape[1]
|
||||
), f"{B.shape=} {Bs.shape=} {block_size=}"
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
C_shape = A.shape[:-1] + (N,)
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
|
||||
@@ -238,6 +238,7 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
|
||||
block_size[1],
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
||||
)
|
||||
|
||||
if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
|
||||
|
||||
@@ -51,7 +51,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
@@ -1932,7 +1932,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
|
||||
self_attn.use_deep_gemm_bmm = True
|
||||
|
||||
if False: # TODO (pr-chain)
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
|
||||
self._weight_requant_ue8m0()
|
||||
|
||||
def _weight_requant_ue8m0(self):
|
||||
|
||||
Reference in New Issue
Block a user