Remove one kernel in per_tensor_quant_mla_fp8 (#5549)
This commit is contained in:
@@ -58,10 +58,8 @@ if _is_cuda:
|
||||
):
|
||||
_enable_jit_deepgemm = True
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if supports_custom_op():
|
||||
|
||||
def deep_gemm_fp8_fp8_bf16_nt(
|
||||
@@ -897,16 +895,20 @@ def _per_tensor_quant_mla_fp8_stage2(
|
||||
|
||||
|
||||
def per_tensor_quant_mla_fp8(
|
||||
x: torch.Tensor, eps: float = 1e-12
|
||||
x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function quantizes input values to float8 values with tensor-wise quantization
|
||||
and specialized for mla absorbed case.
|
||||
"""
|
||||
assert x.dim() == 3, "`x` is not a 3d-tensor"
|
||||
assert (
|
||||
x_s_out.shape == (1,)
|
||||
and x_s_out.dtype == torch.float32
|
||||
and x_s_out.device == x.device
|
||||
)
|
||||
|
||||
x_q = x.new_empty(x.size(), dtype=_fp8_type)
|
||||
x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
|
||||
|
||||
num_head, num_seq, head_size = x.shape
|
||||
BLOCK_SIZE = triton.next_power_of_2(head_size)
|
||||
@@ -914,7 +916,7 @@ def per_tensor_quant_mla_fp8(
|
||||
|
||||
_per_tensor_quant_mla_fp8_stage1[grid](
|
||||
x,
|
||||
x_s,
|
||||
x_s_out,
|
||||
head_size,
|
||||
x.stride(0),
|
||||
x.stride(1),
|
||||
@@ -924,7 +926,7 @@ def per_tensor_quant_mla_fp8(
|
||||
)
|
||||
_per_tensor_quant_mla_fp8_stage2[grid](
|
||||
x,
|
||||
x_s,
|
||||
x_s_out,
|
||||
x_q,
|
||||
num_seq,
|
||||
head_size,
|
||||
@@ -935,7 +937,7 @@ def per_tensor_quant_mla_fp8(
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
|
||||
return x_q, x_s
|
||||
return x_q, x_s_out
|
||||
|
||||
|
||||
def scaled_fp8_quant(
|
||||
|
||||
@@ -40,7 +40,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
||||
from sglang.srt.utils import add_prefix, is_cuda, is_hip
|
||||
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
@@ -91,6 +91,12 @@ class DeepseekModelNextN(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
zero_allocator = BumpAllocator(
|
||||
buffer_size=2,
|
||||
dtype=torch.float32,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
else:
|
||||
@@ -108,7 +114,7 @@ class DeepseekModelNextN(nn.Module):
|
||||
|
||||
residual = None
|
||||
hidden_states, residual = self.decoder(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
positions, hidden_states, forward_batch, residual, zero_allocator
|
||||
)
|
||||
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
|
||||
@@ -76,7 +76,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
|
||||
from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
@@ -97,7 +97,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AttnForwardMethod(IntEnum):
|
||||
|
||||
# Use multi-head attention
|
||||
MHA = auto()
|
||||
|
||||
@@ -588,6 +587,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
zero_allocator: BumpAllocator,
|
||||
) -> torch.Tensor:
|
||||
if hidden_states.shape[0] == 0:
|
||||
assert (
|
||||
@@ -613,9 +613,13 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
positions, hidden_states, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||
return self.forward_absorb(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
)
|
||||
else:
|
||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||
return self.forward_absorb(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
)
|
||||
|
||||
def forward_normal(
|
||||
self,
|
||||
@@ -664,6 +668,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
zero_allocator: BumpAllocator,
|
||||
) -> torch.Tensor:
|
||||
q_len = hidden_states.shape[0]
|
||||
q_input = hidden_states.new_empty(
|
||||
@@ -688,6 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
elif self.w_kc.dtype == torch.float8_e4m3fn:
|
||||
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
|
||||
q_nope.transpose(0, 1),
|
||||
zero_allocator.allocate(1),
|
||||
)
|
||||
q_nope_out = bmm_fp8(
|
||||
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
|
||||
@@ -719,6 +725,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
elif self.w_vc.dtype == torch.float8_e4m3fn:
|
||||
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
|
||||
attn_output.transpose(0, 1),
|
||||
zero_allocator.allocate(1),
|
||||
)
|
||||
attn_bmm_output = bmm_fp8(
|
||||
attn_output_val,
|
||||
@@ -739,6 +746,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
zero_allocator: BumpAllocator,
|
||||
) -> torch.Tensor:
|
||||
enable_rope_fusion = (
|
||||
os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
|
||||
@@ -765,7 +773,9 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
)
|
||||
elif self.w_kc.dtype == torch.float8_e4m3fn:
|
||||
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
|
||||
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
|
||||
q_nope.transpose(0, 1),
|
||||
zero_allocator.allocate(1),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
)
|
||||
q_nope_out = bmm_fp8(
|
||||
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
|
||||
@@ -861,7 +871,9 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
)
|
||||
elif self.w_vc.dtype == torch.float8_e4m3fn:
|
||||
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
|
||||
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
|
||||
attn_output.transpose(0, 1),
|
||||
zero_allocator.allocate(1),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
)
|
||||
attn_bmm_output = bmm_fp8(
|
||||
attn_output_val,
|
||||
@@ -1113,14 +1125,15 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
zero_allocator: BumpAllocator,
|
||||
) -> torch.Tensor:
|
||||
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
|
||||
return self.forward_ffn_with_scattered_input(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
positions, hidden_states, forward_batch, residual, zero_allocator
|
||||
)
|
||||
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
|
||||
return self.forward_ffn_with_full_input(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
positions, hidden_states, forward_batch, residual, zero_allocator
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
@@ -1131,6 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
zero_allocator: BumpAllocator,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if hidden_states.shape[0] == 0:
|
||||
@@ -1151,6 +1165,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
zero_allocator=zero_allocator,
|
||||
)
|
||||
|
||||
# Gather
|
||||
@@ -1198,6 +1213,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
zero_allocator: BumpAllocator,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if hidden_states.shape[0] == 0:
|
||||
@@ -1223,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
zero_allocator=zero_allocator,
|
||||
)
|
||||
|
||||
if self.attn_tp_size != 1:
|
||||
@@ -1310,6 +1327,12 @@ class DeepseekV2Model(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
zero_allocator = BumpAllocator(
|
||||
# TODO for two-batch-overlap, we need a larger buffer size
|
||||
buffer_size=len(self.layers) * 2,
|
||||
dtype=torch.float32,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
@@ -1321,7 +1344,7 @@ class DeepseekV2Model(nn.Module):
|
||||
expert_distribution_recorder.set_current_layer(i)
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
positions, hidden_states, forward_batch, residual, zero_allocator
|
||||
)
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
if residual is None:
|
||||
|
||||
@@ -1932,3 +1932,16 @@ def is_fa3_default_architecture(hf_config):
|
||||
"MistralForCausalLM",
|
||||
}
|
||||
return architectures[0] in default_archs
|
||||
|
||||
|
||||
# Can be more general if it is used in multiple places (keep it simple and thus not general now)
|
||||
class BumpAllocator:
|
||||
def __init__(self, buffer_size: int, dtype, device):
|
||||
self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
|
||||
self._pointer = 0
|
||||
|
||||
def allocate(self, size: int):
|
||||
assert self._pointer + size <= len(self._buffer)
|
||||
output = self._buffer[self._pointer : self._pointer + size]
|
||||
self._pointer += size
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user