Support new DeepGEMM (#7172)
@@ -1231,6 +1231,7 @@ class DeepEPMoE(EPMoE):
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output

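Context for the new `scale_ue8m0` flag: ue8m0 stores each quantization scale as an unsigned 8-bit exponent with no mantissa, so only power-of-two scales are representable. A minimal illustrative sketch of that encoding (not the kernel's actual packing code; the round-up direction and the 127 bias are assumptions based on the E8M0 format):

```python
# Illustrative only: round scales up to powers of two and encode them as
# biased 8-bit exponents (ue8m0 = unsigned, 8 exponent bits, 0 mantissa bits).
import torch

scales = torch.tensor([0.75, 1.0, 3.1])
pow2 = torch.exp2(torch.ceil(torch.log2(scales)))      # tensor([1., 1., 4.])
exponents = (torch.log2(pow2) + 127).to(torch.uint8)   # biased e8m0 encoding
```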
@@ -1238,7 +1239,13 @@ class DeepEPMoE(EPMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
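The branch above exists because the new DeepGEMM consumes the ue8m0 scales as-is, while the older API wants the column-major, TMA-aligned layout. A hypothetical one-line helper capturing the same choice (the helper name is illustrative, not part of this commit):

```python
from sglang.srt.layers.quantization import deep_gemm_wrapper

def _prepare_down_scale(down_input_scale, use_ue8m0: bool):
    # New DeepGEMM takes the scales directly; the old API expects the
    # column-major, TMA-aligned layout.
    if use_ue8m0:
        return down_input_scale
    return deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale)
```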
@@ -584,6 +584,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 use_fp8=use_fp8,
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
+                round_scale=deep_gemm_wrapper.DEEPGEMM_V202506,
+                use_ue8m0=deep_gemm_wrapper.DEEPGEMM_V202506,
             )
         )
         return packed_recv_hidden, packed_recv_count, event, hook
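Both new kwargs are keyed off `DEEPGEMM_V202506`, so a DeepEP build that predates these arguments would fail at this call. A hedged compatibility sketch (the helper and its use are assumptions, not part of this commit):

```python
import inspect

def build_dispatch_kwargs(dispatch_fn, new_deep_gemm: bool) -> dict:
    # Forward round_scale/use_ue8m0 only when the installed DeepEP accepts them.
    params = inspect.signature(dispatch_fn).parameters
    if new_deep_gemm and {"round_scale", "use_ue8m0"} <= params.keys():
        return {"round_scale": True, "use_ue8m0": True}
    return {}
```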
@@ -12,6 +12,7 @@ import torch
 import triton
 import triton.language as tl

+from sglang.math_utils import ceil_div
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -518,10 +519,6 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)


-def ceil_div(a, b):
-    return (a + b - 1) // b
-
-
 @triton.jit
 def moe_align_block_size_stage1(
     topk_ids_ptr,
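The local `ceil_div` removed above is replaced by the shared `sglang.math_utils.ceil_div` imported in the previous hunk; assuming that utility keeps the usual definition, the two are interchangeable:

```python
def ceil_div(a: int, b: int) -> int:
    # Ceiling division on non-negative integers, e.g. ceil_div(7168, 128) == 56.
    return (a + b - 1) // b
```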
@@ -21,6 +21,12 @@ def _compute_enable_deep_gemm():

 ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()

-DEEPGEMM_V202506 = False
+try:
+    from deep_gemm import fp8_gemm_nt
+
+    # They have not given a name to this breaking change
+    DEEPGEMM_V202506 = True
+except ImportError:
+    DEEPGEMM_V202506 = False

 DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506
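Version detection piggybacks on the renamed entry point: if `fp8_gemm_nt` can be imported, the installed DeepGEMM is the 2025-06 API. An equivalent hedged sketch that probes the attribute instead of importing it (same assumption that `fp8_gemm_nt` is a reliable marker of the breaking change):

```python
import importlib.util

def _detect_new_deep_gemm() -> bool:
    # Mirrors the try/except above without triggering an ImportError.
    if importlib.util.find_spec("deep_gemm") is None:
        return False
    import deep_gemm
    return hasattr(deep_gemm, "fp8_gemm_nt")
```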
@@ -16,14 +16,24 @@ logger = logging.getLogger(__name__)

 if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
-    from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
-    from deep_gemm import get_col_major_tma_aligned_tensor
-    from deep_gemm import (
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-    )
-    from deep_gemm import (
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-    )
+
+    if DEEPGEMM_V202506:
+        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import (
+            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )
+        from deep_gemm import (
+            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+    else:
+        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import get_col_major_tma_aligned_tensor
+        from deep_gemm import (
+            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+        from deep_gemm import (
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )


 def grouped_gemm_nt_f8f8bf16_masked(
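The wrapper pins stable internal aliases (`_gemm_nt_f8f8bf16_raw`, `_grouped_gemm_nt_f8f8bf16_contig_raw`, `_grouped_gemm_nt_f8f8bf16_masked_raw`) so callers never see the renames. A hypothetical attribute-based alternative to the version branch, sketched under the assumption that `deep_gemm` is installed:

```python
import deep_gemm

# Prefer the 2025-06 names, fall back to the older ones.
_gemm_nt_f8f8bf16_raw = getattr(
    deep_gemm, "fp8_gemm_nt", getattr(deep_gemm, "gemm_fp8_fp8_bf16_nt", None)
)
_grouped_gemm_nt_f8f8bf16_masked_raw = getattr(
    deep_gemm,
    "fp8_m_grouped_gemm_nt_masked",
    getattr(deep_gemm, "m_grouped_gemm_fp8_fp8_bf16_nt_masked", None),
)
```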
@@ -765,7 +765,15 @@ def prepare_block_fp8_matmul_inputs(
     assert A.shape[-1] == B.shape[-1]
     assert A.shape[:-1] == As.shape[:-1]
     assert A.is_contiguous()
-    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+
+    if As.dtype == torch.float:
+        assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    elif Bs.dtype == torch.int:
+        assert (
+            triton.cdiv(triton.cdiv(A.shape[-1], block_k), 4) == As.shape[-1]
+        ), f"{A.shape=} {As.shape=} {block_size=}"
+    else:
+        raise NotImplementedError

     M = A.numel() // A.shape[-1]

@@ -773,8 +781,17 @@ def prepare_block_fp8_matmul_inputs(
     assert B.is_contiguous()
     assert Bs.ndim == 2
     N, K = B.shape
-    assert triton.cdiv(N, block_n) == Bs.shape[0]
-    assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+    if Bs.dtype == torch.float:
+        assert triton.cdiv(N, block_n) == Bs.shape[0]
+        assert triton.cdiv(K, block_k) == Bs.shape[1]
+    elif Bs.dtype == torch.int:
+        assert N == Bs.shape[0], f"{B.shape=} {Bs.shape=} {block_size=}"
+        assert (
+            triton.cdiv(triton.cdiv(K, block_k), 4) == Bs.shape[1]
+        ), f"{B.shape=} {Bs.shape=} {block_size=}"
+    else:
+        raise NotImplementedError

     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)
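Both relaxed assertions follow from the ue8m0 packing: when the scale tensors arrive as `torch.int`, each 32-bit element presumably carries four 8-bit exponents, so the scale dimension along K shrinks by a factor of 4 (and `Bs` is no longer block-grouped along N). A small worked sketch of the expected widths under that assumption:

```python
import triton  # triton.cdiv(a, b) == (a + b - 1) // b

K, block_k = 7168, 128
n_float_scales = triton.cdiv(K, block_k)      # 56 fp32 scales along K
n_int_words = triton.cdiv(n_float_scales, 4)  # 14 int32 words when ue8m0-packed
```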
@@ -238,6 +238,7 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         block_size[1],
         column_major_scales=True,
         scale_tma_aligned=True,
+        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )

     if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
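The new `scale_ue8m0` argument mirrors the MoE path; the sanity check in the following context line is opt-in via an environment variable. A hedged usage sketch, assuming `get_bool_env_var` treats `"1"` as true:

```python
import os

# Enable the optional UE8M0 sanity check before sglang initializes.
os.environ["SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"] = "1"
```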
@@ -51,7 +51,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization import deep_gemm_wrapper
@@ -1932,7 +1932,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                 self_attn.use_deep_gemm_bmm = True

-        if False:  # TODO (pr-chain)
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
             self._weight_requant_ue8m0()

     def _weight_requant_ue8m0(self):
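`_weight_requant_ue8m0`, previously dead behind `if False`, now runs whenever the new DeepGEMM's ue8m0 scaling is active. Its body is not shown in this diff; a minimal hedged sketch of the scale side of such a re-quantization (not the author's implementation):

```python
import torch

def round_scales_to_ue8m0(scales: torch.Tensor) -> torch.Tensor:
    # Round each per-block scale up to the nearest power of two so it is
    # exactly representable by an 8-bit exponent; the fp8 weight data would
    # then have to be re-quantized against the rounded scales.
    return torch.exp2(torch.ceil(torch.log2(scales)))
```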