Remove one kernel in per_tensor_quant_mla_fp8 (#5549)
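
This change threads a BumpAllocator of preallocated float32 zeros through the DeepSeek forward pass, so that per_tensor_quant_mla_fp8 no longer has to zero-initialize its one-element scale tensor on every call: the zero-fill kernel runs once per forward pass instead of once per quantization, which is the "one kernel" the title removes. A minimal sketch of the allocator the diff imports from sglang.srt.utils (illustrative, assuming the obvious bump-pointer semantics; the real implementation may differ in detail):

```python
import torch


class BumpAllocator:
    """Hands out slices of a single preallocated zero-filled buffer."""

    def __init__(self, buffer_size: int, dtype: torch.dtype, device: torch.device):
        # The only zero-fill kernel launch; amortized over every allocate().
        self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
        self._pointer = 0

    def allocate(self, size: int) -> torch.Tensor:
        # Bump-pointer allocation: return the next `size` zeroed elements.
        assert self._pointer + size <= self._buffer.shape[0], "allocator exhausted"
        out = self._buffer[self._pointer : self._pointer + size]
        self._pointer += size
        return out
```

The first three hunks thread the allocator through the NextN (MTP) draft model, DeepseekModelNextN; the remaining hunks make the matching change in the main DeepseekV2 model code.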
@@ -40,7 +40,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -91,6 +91,12 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            buffer_size=2,
+            dtype=torch.float32,
+            device=input_ids.device,
+        )
+
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
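
buffer_size=2 matches what the single NextN draft decoder layer can consume per forward pass: one one-element scale for the q_nope quantization and one for the attention-output quantization (the same per-layer factor of 2 appears in DeepseekV2Model below).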
@@ -108,7 +114,7 @@ class DeepseekModelNextN(nn.Module):
 
         residual = None
         hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual
+            positions, hidden_states, forward_batch, residual, zero_allocator
         )
 
         if not forward_batch.forward_mode.is_idle():
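
The remaining hunks repeat the pattern on the main model side: import BumpAllocator, accept a zero_allocator argument through the attention and decoder-layer call chain, and hand a preallocated scale slice to every per_tensor_quant_mla_fp8 call.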
@@ -76,7 +76,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -97,7 +97,6 @@ logger = logging.getLogger(__name__)
 
 
 class AttnForwardMethod(IntEnum):
-
     # Use multi-head attention
     MHA = auto()
 
@@ -588,6 +587,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
             assert (
@@ -613,9 +613,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                     positions, hidden_states, forward_batch
                 )
             else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
+                return self.forward_absorb(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
         else:
-            return self.forward_absorb(positions, hidden_states, forward_batch)
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
 
     def forward_normal(
         self,
@@ -664,6 +668,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -688,6 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
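
This hunk is the heart of the commit: per_tensor_quant_mla_fp8 now takes the already-zeroed one-element scale tensor as its second argument instead of materializing one internally with torch.zeros. A sketch of the before/after API shape, inferred from the call sites in this diff; _quant_kernel is a pure-PyTorch stand-in for the real fused kernel, and the max-reduction into the scale is an assumption that would explain why the scale must arrive zeroed rather than merely allocated:

```python
import torch


def _quant_kernel(x: torch.Tensor, x_q: torch.Tensor, x_s: torch.Tensor) -> None:
    # Stand-in for the fused quantization kernel (illustrative). Reducing
    # into x_s with a max, as an atomic kernel would, is why x_s must be
    # zero-initialized on entry.
    fp8_max = torch.finfo(x_q.dtype).max
    x_s[0] = torch.maximum(x_s[0], x.abs().amax().float() / fp8_max).clamp_min(
        1e-12  # eps guard against an all-zero input
    )
    x_q.copy_((x.float() / x_s[0]).clamp(-fp8_max, fp8_max).to(x_q.dtype))


def per_tensor_quant_mla_fp8_before(x, dtype=torch.float8_e4m3fn):
    # Old shape (sketch): every call pays a zero-fill kernel for the scale.
    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
    x_q = torch.empty_like(x, dtype=dtype)
    _quant_kernel(x, x_q, x_s)
    return x_q, x_s


def per_tensor_quant_mla_fp8_after(x, x_s, dtype=torch.float8_e4m3fn):
    # New shape (sketch): the caller passes zero_allocator.allocate(1), a slice
    # of a buffer that was zeroed once at the start of the forward pass.
    x_q = torch.empty_like(x, dtype=dtype)
    _quant_kernel(x, x_q, x_s)
    return x_q, x_s
```

The quantized tensor and its scale then feed straight into bmm_fp8 together with the fp8 weight and its scale, as the surrounding context lines show.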
@@ -719,6 +725,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -739,6 +746,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -765,7 +773,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -861,7 +871,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
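
The fused-RoPE variants above differ only in passing dtype=torch.float8_e4m3fn explicitly; they gain the same preallocated scale argument. Each absorbed-attention forward performs at most two of these quantizations, one on the w_kc branch and one on the w_vc branch, which is where the budget of two float32 zeros per layer comes from.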
@@ -1113,14 +1125,15 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
             return self.forward_ffn_with_scattered_input(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         elif self.info.ffn_input_mode == _FFNInputMode.FULL:
             return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         else:
             raise NotImplementedError
@@ -1131,6 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
 
         if hidden_states.shape[0] == 0:
@@ -1151,6 +1165,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )
 
         # Gather
@@ -1198,6 +1213,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
 
         if hidden_states.shape[0] == 0:
@@ -1223,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )
 
         if self.attn_tp_size != 1:
@@ -1310,6 +1327,12 @@ class DeepseekV2Model(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            # TODO for two-batch-overlap, we need a larger buffer size
+            buffer_size=len(self.layers) * 2,
+            dtype=torch.float32,
+            device=input_ids.device,
+        )
 
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
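
Here buffer_size=len(self.layers) * 2 reserves exactly those two one-element scales per decoder layer for one forward pass. The TODO notes that two-batch overlap will need a larger buffer, presumably because each overlapped micro-batch would consume its own pair of scales per layer.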
@@ -1321,7 +1344,7 @@ class DeepseekV2Model(nn.Module):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         if not forward_batch.forward_mode.is_idle():
             if residual is None: