[perf] introduce deep gemm group_gemm_masked as bmm (#5432)
This commit is contained in:
@@ -44,6 +44,7 @@ else:
|
|||||||
fp8_min = -fp8_max
|
fp8_min = -fp8_max
|
||||||
|
|
||||||
_enable_jit_deepgemm = False
|
_enable_jit_deepgemm = False
|
||||||
|
_enable_jit_deepgemm_bmm = False
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
import deep_gemm
|
import deep_gemm
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
@@ -53,10 +54,11 @@ if _is_cuda:
|
|||||||
)
|
)
|
||||||
|
|
||||||
sm_version = get_device_sm()
|
sm_version = get_device_sm()
|
||||||
if sm_version == 90 and get_bool_env_var(
|
if sm_version == 90:
|
||||||
"SGL_ENABLE_JIT_DEEPGEMM", default="false"
|
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
|
||||||
):
|
_enable_jit_deepgemm = True
|
||||||
_enable_jit_deepgemm = True
|
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
|
||||||
|
_enable_jit_deepgemm_bmm = True
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -940,6 +942,108 @@ def per_tensor_quant_mla_fp8(
|
|||||||
return x_q, x_s_out
|
return x_q, x_s_out
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _per_token_group_quant_mla_deep_gemm_masked_fp8(
|
||||||
|
y_ptr,
|
||||||
|
y_q_ptr,
|
||||||
|
y_s_ptr,
|
||||||
|
masked_m_ptr,
|
||||||
|
group_size,
|
||||||
|
y_stride_b,
|
||||||
|
y_stride_t,
|
||||||
|
y_q_stride_b,
|
||||||
|
y_q_stride_t,
|
||||||
|
y_s_stride_b,
|
||||||
|
y_s_stride_g,
|
||||||
|
eps,
|
||||||
|
fp8_min,
|
||||||
|
fp8_max,
|
||||||
|
NUM_GROUP: tl.constexpr,
|
||||||
|
BLOCK: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""A Triton-accelerated function to perform per-token-group
|
||||||
|
quantization on a tensor for deep_gemm grouped_gemm_masked.
|
||||||
|
This function converts the tensor values into float8 values.
|
||||||
|
y and y_q: (b, t, k)
|
||||||
|
y_s: (b, k//group_size, t)
|
||||||
|
"""
|
||||||
|
t_id = tl.program_id(0)
|
||||||
|
b_id = tl.program_id(1)
|
||||||
|
|
||||||
|
y_ptr += b_id * y_stride_b + t_id * y_stride_t
|
||||||
|
y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
|
||||||
|
y_s_ptr += b_id * y_s_stride_b + t_id
|
||||||
|
|
||||||
|
if t_id == 0:
|
||||||
|
tl.store(masked_m_ptr + b_id, tl.num_programs(0))
|
||||||
|
|
||||||
|
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
|
||||||
|
mask = cols < group_size
|
||||||
|
|
||||||
|
for gid in range(NUM_GROUP):
|
||||||
|
y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
|
||||||
|
tl.float32
|
||||||
|
)
|
||||||
|
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
||||||
|
y_s = _absmax / fp8_max
|
||||||
|
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||||
|
|
||||||
|
tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
|
||||||
|
tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
|
||||||
|
|
||||||
|
|
||||||
|
def per_tensor_quant_mla_deep_gemm_masked_fp8(
|
||||||
|
x: torch.Tensor,
|
||||||
|
group_size: int = 128,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
dtype: torch.dtype = torch.float8_e4m3fn,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
This function quantizes input values to float8 values with per-token-group-quantization
|
||||||
|
for deep_gemm grouped_gemm_masked and specialized for mla absorbed case.
|
||||||
|
"""
|
||||||
|
assert x.dim() == 3, "`x` is not a 3d-tensor"
|
||||||
|
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
fp8_max = finfo.max
|
||||||
|
if _is_hip:
|
||||||
|
dtype = torch.float8_e4m3fnuz
|
||||||
|
fp8_max = 224.0
|
||||||
|
|
||||||
|
b, m, k = x.shape
|
||||||
|
aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel
|
||||||
|
num_tiles_k = k // group_size
|
||||||
|
assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"
|
||||||
|
|
||||||
|
x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
|
||||||
|
x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
|
||||||
|
masked_m = x.new_empty((b,), dtype=torch.int32)
|
||||||
|
|
||||||
|
BLOCK_SIZE = triton.next_power_of_2(group_size)
|
||||||
|
grid = (m, b)
|
||||||
|
|
||||||
|
_per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
|
||||||
|
x,
|
||||||
|
x_q,
|
||||||
|
x_s,
|
||||||
|
masked_m,
|
||||||
|
group_size,
|
||||||
|
x.stride(0),
|
||||||
|
x.stride(1),
|
||||||
|
x_q.stride(0),
|
||||||
|
x_q.stride(1),
|
||||||
|
x_s.stride(0),
|
||||||
|
x_s.stride(1),
|
||||||
|
eps,
|
||||||
|
-fp8_max,
|
||||||
|
fp8_max,
|
||||||
|
num_tiles_k,
|
||||||
|
BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
|
||||||
|
|
||||||
|
|
||||||
def scaled_fp8_quant(
|
def scaled_fp8_quant(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
scale: Optional[torch.Tensor] = None,
|
scale: Optional[torch.Tensor] = None,
|
||||||
|
|||||||
@@ -57,7 +57,11 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
|||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
_enable_jit_deepgemm_bmm,
|
||||||
|
per_tensor_quant_mla_deep_gemm_masked_fp8,
|
||||||
|
per_tensor_quant_mla_fp8,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import (
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
block_quant_to_tensor_quant,
|
block_quant_to_tensor_quant,
|
||||||
channel_quant_to_tensor_quant,
|
channel_quant_to_tensor_quant,
|
||||||
@@ -82,6 +86,7 @@ _is_hip = is_hip()
|
|||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
|
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
||||||
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
||||||
else:
|
else:
|
||||||
from vllm._custom_ops import awq_dequantize
|
from vllm._custom_ops import awq_dequantize
|
||||||
@@ -530,6 +535,10 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
self.w_vc = None
|
self.w_vc = None
|
||||||
self.w_scale = None
|
self.w_scale = None
|
||||||
|
|
||||||
|
self.w_scale_k = None
|
||||||
|
self.w_scale_v = None
|
||||||
|
self.use_deep_gemm_bmm = False
|
||||||
|
|
||||||
self.flashinfer_mla_disable_ragged = global_server_args_dict[
|
self.flashinfer_mla_disable_ragged = global_server_args_dict[
|
||||||
"flashinfer_mla_disable_ragged"
|
"flashinfer_mla_disable_ragged"
|
||||||
]
|
]
|
||||||
@@ -684,7 +693,24 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
)
|
)
|
||||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
|
|
||||||
if self.w_kc.dtype == torch.float8_e4m3fnuz:
|
if self.use_deep_gemm_bmm:
|
||||||
|
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
|
||||||
|
per_tensor_quant_mla_deep_gemm_masked_fp8(
|
||||||
|
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
)
|
||||||
|
q_nope_out = q_nope.new_empty(
|
||||||
|
(self.num_local_heads, aligned_m, self.kv_lora_rank)
|
||||||
|
)
|
||||||
|
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||||
|
(q_nope_val, q_nope_scale),
|
||||||
|
(self.w_kc, self.w_scale_k),
|
||||||
|
q_nope_out,
|
||||||
|
masked_m,
|
||||||
|
expected_m,
|
||||||
|
)
|
||||||
|
q_nope_out = q_nope_out[:, :expected_m, :]
|
||||||
|
elif self.w_kc.dtype == torch.float8_e4m3fnuz:
|
||||||
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
||||||
q_nope_out = torch.bmm(
|
q_nope_out = torch.bmm(
|
||||||
q_nope.to(torch.bfloat16).transpose(0, 1),
|
q_nope.to(torch.bfloat16).transpose(0, 1),
|
||||||
@@ -716,7 +742,24 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
|
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
|
||||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||||
|
|
||||||
if self.w_vc.dtype == torch.float8_e4m3fnuz:
|
if self.use_deep_gemm_bmm:
|
||||||
|
attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
|
||||||
|
per_tensor_quant_mla_deep_gemm_masked_fp8(
|
||||||
|
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
)
|
||||||
|
attn_bmm_output = attn_output.new_empty(
|
||||||
|
(self.num_local_heads, aligned_m, self.v_head_dim)
|
||||||
|
)
|
||||||
|
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||||
|
(attn_output_val, attn_output_scale),
|
||||||
|
(self.w_vc, self.w_scale_v),
|
||||||
|
attn_bmm_output,
|
||||||
|
masked_m,
|
||||||
|
expected_m,
|
||||||
|
)
|
||||||
|
attn_bmm_output = attn_bmm_output[:, :expected_m, :]
|
||||||
|
elif self.w_vc.dtype == torch.float8_e4m3fnuz:
|
||||||
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
||||||
attn_bmm_output = torch.bmm(
|
attn_bmm_output = torch.bmm(
|
||||||
attn_output.to(torch.bfloat16).transpose(0, 1),
|
attn_output.to(torch.bfloat16).transpose(0, 1),
|
||||||
@@ -1439,6 +1482,10 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
w = self_attn.kv_b_proj.weight
|
w = self_attn.kv_b_proj.weight
|
||||||
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
|
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
|
||||||
# This may affect the accuracy of fp8 model.
|
# This may affect the accuracy of fp8 model.
|
||||||
|
# Fix deepseek v3 blockwise bmm by using deep_gemm
|
||||||
|
use_deep_gemm_bmm = False
|
||||||
|
model_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
if w.dtype in (
|
if w.dtype in (
|
||||||
torch.float8_e4m3fn,
|
torch.float8_e4m3fn,
|
||||||
torch.float8_e4m3fnuz,
|
torch.float8_e4m3fnuz,
|
||||||
@@ -1457,10 +1504,20 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
weight = w
|
weight = w
|
||||||
weight_scale = self_attn.kv_b_proj.weight_scale_inv
|
weight_scale = self_attn.kv_b_proj.weight_scale_inv
|
||||||
|
|
||||||
w, scale = block_quant_to_tensor_quant(
|
if (
|
||||||
weight, weight_scale, weight_block_size
|
_is_cuda
|
||||||
)
|
and _enable_jit_deepgemm_bmm
|
||||||
self_attn.w_scale = scale
|
and weight_block_size[0] == 128
|
||||||
|
and weight_block_size[1] == 128
|
||||||
|
and model_dtype == torch.bfloat16
|
||||||
|
):
|
||||||
|
block_scale = weight_scale
|
||||||
|
use_deep_gemm_bmm = True
|
||||||
|
else:
|
||||||
|
w, scale = block_quant_to_tensor_quant(
|
||||||
|
weight, weight_scale, weight_block_size
|
||||||
|
)
|
||||||
|
self_attn.w_scale = scale
|
||||||
else:
|
else:
|
||||||
weight = w
|
weight = w
|
||||||
weight_scale = self_attn.kv_b_proj.weight_scale
|
weight_scale = self_attn.kv_b_proj.weight_scale
|
||||||
@@ -1483,18 +1540,31 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
|
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
|
||||||
torch.bfloat16
|
torch.bfloat16
|
||||||
)
|
)
|
||||||
|
|
||||||
w_kc, w_vc = w.unflatten(
|
w_kc, w_vc = w.unflatten(
|
||||||
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
||||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
||||||
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
if not use_deep_gemm_bmm:
|
||||||
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
|
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
||||||
if (
|
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
|
||||||
hasattr(self_attn.kv_b_proj, "weight_scale")
|
if (
|
||||||
and self_attn.w_scale is None
|
hasattr(self_attn.kv_b_proj, "weight_scale")
|
||||||
):
|
and self_attn.w_scale is None
|
||||||
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
):
|
||||||
if _is_hip:
|
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
||||||
self_attn.w_scale *= 2.0
|
if _is_hip:
|
||||||
|
self_attn.w_scale *= 2.0
|
||||||
|
else:
|
||||||
|
num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
|
||||||
|
num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
|
||||||
|
ws_kc, ws_vc = block_scale.unflatten(
|
||||||
|
0, (-1, (num_tiles_k + num_tiles_n))
|
||||||
|
).split([num_tiles_k, num_tiles_n], dim=1)
|
||||||
|
self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
|
||||||
|
self_attn.w_scale_v = ws_vc.contiguous()
|
||||||
|
self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
|
||||||
|
self_attn.w_vc = w_vc.contiguous()
|
||||||
|
self_attn.use_deep_gemm_bmm = True
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import torch
|
|||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
per_tensor_quant_mla_deep_gemm_masked_fp8,
|
||||||
per_tensor_quant_mla_fp8,
|
per_tensor_quant_mla_fp8,
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
static_quant_fp8,
|
static_quant_fp8,
|
||||||
@@ -212,6 +213,62 @@ class TestPerTensorQuantMlaFP8(CustomTestCase):
|
|||||||
self._per_tensor_quant_mla_fp8(*params)
|
self._per_tensor_quant_mla_fp8(*params)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase):
|
||||||
|
DTYPES = [torch.half, torch.bfloat16, torch.float32]
|
||||||
|
B = [128]
|
||||||
|
NUM_TOKENS = [7, 83, 2048, 1024 * 16]
|
||||||
|
D = [512, 128]
|
||||||
|
GROUP_SIZE = [128]
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
raise unittest.SkipTest("CUDA is not available")
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
def _per_token_group_quant_mla_deep_gemm_masked_fp8(
|
||||||
|
self, b, num_tokens, d, dtype, group_size, seed
|
||||||
|
):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
x = torch.rand(b, num_tokens, d, dtype=dtype)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
|
||||||
|
out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8(
|
||||||
|
x, group_size
|
||||||
|
)
|
||||||
|
out = out[:, :num_tokens, :]
|
||||||
|
scale = scale[:, :num_tokens, :]
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(
|
||||||
|
out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20, atol=1e-2
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(torch.allclose(scale, ref_scale))
|
||||||
|
|
||||||
|
def test_per_token_group_quant_mla_deep_gemm_masked_fp8(self):
|
||||||
|
for params in itertools.product(
|
||||||
|
self.B,
|
||||||
|
self.NUM_TOKENS,
|
||||||
|
self.D,
|
||||||
|
self.DTYPES,
|
||||||
|
self.GROUP_SIZE,
|
||||||
|
self.SEEDS,
|
||||||
|
):
|
||||||
|
with self.subTest(
|
||||||
|
b=params[0],
|
||||||
|
num_tokens=params[1],
|
||||||
|
d=params[2],
|
||||||
|
dtype=params[3],
|
||||||
|
group_size=params[4],
|
||||||
|
seed=params[5],
|
||||||
|
):
|
||||||
|
self._per_token_group_quant_mla_deep_gemm_masked_fp8(*params)
|
||||||
|
|
||||||
|
|
||||||
# For test
|
# For test
|
||||||
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
|
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
|
||||||
"""This function performs matrix multiplication with block-wise quantization using native torch.
|
"""This function performs matrix multiplication with block-wise quantization using native torch.
|
||||||
@@ -485,5 +542,115 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
|
|||||||
self._w8a8_block_fp8_fused_moe(*params)
|
self._w8a8_block_fp8_fused_moe(*params)
|
||||||
|
|
||||||
|
|
||||||
|
# For test
|
||||||
|
def torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_shape, out_dtype):
|
||||||
|
"""This function performs bmm with block-wise quantization using native torch."""
|
||||||
|
|
||||||
|
B, N, _ = w.shape
|
||||||
|
_, M, _ = a.shape
|
||||||
|
out = torch.empty((B, M, N), dtype=out_dtype, device=a.device)
|
||||||
|
|
||||||
|
for i in range(B):
|
||||||
|
out[i] = native_w8a8_block_fp8_matmul(
|
||||||
|
a[i], w[i], a_s[i], w_s[i], block_shape, output_dtype=out_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
|
||||||
|
DTYPES = [torch.bfloat16]
|
||||||
|
M = [1, 33, 64, 222, 8192]
|
||||||
|
N = [128, 512]
|
||||||
|
K = [128, 512]
|
||||||
|
BATCH = [128]
|
||||||
|
BLOCK_SIZE = [[128, 128]]
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
raise unittest.SkipTest("CUDA is not available")
|
||||||
|
try:
|
||||||
|
import deep_gemm
|
||||||
|
except ImportError:
|
||||||
|
raise unittest.SkipTest("DeepGEMM is not available")
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
def _w8a8_block_fp8_batched_deep_gemm(self, M, N, K, B, block_size, dtype, seed):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
factor_for_scale = 1e-2
|
||||||
|
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||||
|
|
||||||
|
a_fp32 = torch.randn((B, M, K), dtype=torch.float32) / 10
|
||||||
|
a = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
w_fp32 = (torch.rand((B, N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||||
|
w = w_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
n_tiles_w = (N + block_n - 1) // block_n
|
||||||
|
k_tiles_w = (K + block_k - 1) // block_k
|
||||||
|
|
||||||
|
w_s = (
|
||||||
|
torch.rand((B, n_tiles_w, k_tiles_w), dtype=torch.float32)
|
||||||
|
* factor_for_scale
|
||||||
|
)
|
||||||
|
a_s = torch.rand((B, M, k_tiles_w), dtype=torch.float32) * factor_for_scale
|
||||||
|
|
||||||
|
ae = a.new_empty(B, (M + 255) // 256 * 256, K)
|
||||||
|
ae_s = a_s.new_empty(B, (M + 255) // 256 * 256, k_tiles_w)
|
||||||
|
oe = torch.empty((B, (M + 255) // 256 * 256, N), dtype=dtype)
|
||||||
|
ae[:, :M, :] = a
|
||||||
|
ae_s[:, :M, :] = a_s
|
||||||
|
|
||||||
|
masked_m = torch.full((B,), M, dtype=torch.int)
|
||||||
|
expected_m = M
|
||||||
|
lhs = (
|
||||||
|
ae,
|
||||||
|
ae_s,
|
||||||
|
)
|
||||||
|
rhs = (
|
||||||
|
w,
|
||||||
|
w_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
|
||||||
|
m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
|
||||||
|
out = oe[:, :M, :]
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
|
||||||
|
/ torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||||
|
< 0.0001
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_w8a8_block_fp8_batched_deep_gemm(self):
|
||||||
|
|
||||||
|
for params in itertools.product(
|
||||||
|
self.M,
|
||||||
|
self.N,
|
||||||
|
self.K,
|
||||||
|
self.BATCH,
|
||||||
|
self.BLOCK_SIZE,
|
||||||
|
self.DTYPES,
|
||||||
|
self.SEEDS,
|
||||||
|
):
|
||||||
|
with self.subTest(
|
||||||
|
M=params[0],
|
||||||
|
N=params[1],
|
||||||
|
K=params[2],
|
||||||
|
B=params[3],
|
||||||
|
block_size=params[4],
|
||||||
|
dtype=params[5],
|
||||||
|
seed=params[6],
|
||||||
|
):
|
||||||
|
self._w8a8_block_fp8_batched_deep_gemm(*params)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(verbosity=2)
|
unittest.main(verbosity=2)
|
||||||
|
|||||||
Reference in New Issue
Block a user