Remove one kernel in per_tensor_quant_mla_fp8 (#5549)

This commit is contained in:
fzyzcjy
2025-04-20 06:08:15 +08:00
committed by GitHub
parent d58e354472
commit 613b197e57
4 changed files with 62 additions and 18 deletions

View File

@@ -40,7 +40,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import add_prefix, is_cuda, is_hip
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
_is_hip = is_hip()
_is_cuda = is_cuda()
@@ -91,6 +91,12 @@ class DeepseekModelNextN(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
zero_allocator = BumpAllocator(
buffer_size=2,
dtype=torch.float32,
device=input_ids.device,
)
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
@@ -108,7 +114,7 @@ class DeepseekModelNextN(nn.Module):
residual = None
hidden_states, residual = self.decoder(
positions, hidden_states, forward_batch, residual
positions, hidden_states, forward_batch, residual, zero_allocator
)
if not forward_batch.forward_mode.is_idle():

View File

@@ -76,7 +76,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip
_is_hip = is_hip()
_is_cuda = is_cuda()
@@ -97,7 +97,6 @@ logger = logging.getLogger(__name__)
class AttnForwardMethod(IntEnum):
# Use multi-head attention
MHA = auto()
@@ -588,6 +587,7 @@ class DeepseekV2AttentionMLA(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
assert (
@@ -613,9 +613,13 @@ class DeepseekV2AttentionMLA(nn.Module):
positions, hidden_states, forward_batch
)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)
return self.forward_absorb(
positions, hidden_states, forward_batch, zero_allocator
)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)
return self.forward_absorb(
positions, hidden_states, forward_batch, zero_allocator
)
def forward_normal(
self,
@@ -664,6 +668,7 @@ class DeepseekV2AttentionMLA(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
q_len = hidden_states.shape[0]
q_input = hidden_states.new_empty(
@@ -688,6 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module):
elif self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
q_nope.transpose(0, 1),
zero_allocator.allocate(1),
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -719,6 +725,7 @@ class DeepseekV2AttentionMLA(nn.Module):
elif self.w_vc.dtype == torch.float8_e4m3fn:
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
attn_output.transpose(0, 1),
zero_allocator.allocate(1),
)
attn_bmm_output = bmm_fp8(
attn_output_val,
@@ -739,6 +746,7 @@ class DeepseekV2AttentionMLA(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
enable_rope_fusion = (
os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -765,7 +773,9 @@ class DeepseekV2AttentionMLA(nn.Module):
)
elif self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
q_nope.transpose(0, 1),
zero_allocator.allocate(1),
dtype=torch.float8_e4m3fn,
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -861,7 +871,9 @@ class DeepseekV2AttentionMLA(nn.Module):
)
elif self.w_vc.dtype == torch.float8_e4m3fn:
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
attn_output.transpose(0, 1),
zero_allocator.allocate(1),
dtype=torch.float8_e4m3fn,
)
attn_bmm_output = bmm_fp8(
attn_output_val,
@@ -1113,14 +1125,15 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
return self.forward_ffn_with_scattered_input(
positions, hidden_states, forward_batch, residual
positions, hidden_states, forward_batch, residual, zero_allocator
)
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
return self.forward_ffn_with_full_input(
positions, hidden_states, forward_batch, residual
positions, hidden_states, forward_batch, residual, zero_allocator
)
else:
raise NotImplementedError
@@ -1131,6 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
@@ -1151,6 +1165,7 @@ class DeepseekV2DecoderLayer(nn.Module):
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
)
# Gather
@@ -1198,6 +1213,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
@@ -1223,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
)
if self.attn_tp_size != 1:
@@ -1310,6 +1327,12 @@ class DeepseekV2Model(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
zero_allocator = BumpAllocator(
# TODO for two-batch-overlap, we need a larger buffer size
buffer_size=len(self.layers) * 2,
dtype=torch.float32,
device=input_ids.device,
)
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
@@ -1321,7 +1344,7 @@ class DeepseekV2Model(nn.Module):
expert_distribution_recorder.set_current_layer(i)
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual
positions, hidden_states, forward_batch, residual, zero_allocator
)
if not forward_batch.forward_mode.is_idle():
if residual is None: