Refactor attention into multiple stages (#6477)
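This change splits the DeepseekV2 MLA attention forward pass into two stages. forward_prepare handles dispatch, projections, RoPE and KV-cache writes, and forward_core runs the attention kernel plus the output projection; each backend path (forward_normal, forward_normal_chunked_kv, forward_absorb, forward_absorb_fused_mla_rope) is split into matching *_prepare / *_core methods. The decoder layer's single op_attn operation is replaced by self_attn.op_prepare and self_attn.op_core in compute_layer_operations, while the plain forward() entry point is kept and simply chains the two stages. The sketch below illustrates the same two-stage pattern in isolation; TwoStageAttention, IntermediateState and the toy math are illustrative stand-ins, not the sglang implementation.

from typing import NamedTuple, Optional

import torch


class IntermediateState(NamedTuple):
    # Mirrors the spirit of the (hidden_states, method, batch, inner_state) tuple
    # used in the diff, trimmed down to what this toy needs.
    short_circuit_output: Optional[torch.Tensor]
    inner_state: Optional[tuple]


class TwoStageAttention:
    """Toy illustration of the prepare/core split, not the real sglang module."""

    def forward_prepare(self, hidden_states: torch.Tensor) -> IntermediateState:
        if hidden_states.shape[0] == 0:
            # Empty batch: hand the input straight through; the core stage short-circuits.
            return IntermediateState(hidden_states, None)
        # "Prepare" work: the projections that can be overlapped with other work.
        q = hidden_states * 2.0
        k = hidden_states * 0.5
        v = hidden_states
        return IntermediateState(None, (q, k, v))

    def forward_core(self, state: IntermediateState) -> torch.Tensor:
        if state.inner_state is None:
            return state.short_circuit_output
        q, k, v = state.inner_state
        # "Core" work: the attention computation itself.
        scores = torch.softmax(q @ k.transpose(-1, -2), dim=-1)
        return scores @ v

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Existing callers keep a single-call API; forward just chains the two stages,
        # exactly like the refactored DeepseekV2AttentionMLA.forward.
        return self.forward_core(self.forward_prepare(hidden_states))


if __name__ == "__main__":
    attn = TwoStageAttention()
    print(attn.forward(torch.randn(4, 8)).shape)   # torch.Size([4, 8])
    print(attn.forward(torch.empty(0, 8)).shape)   # torch.Size([0, 8])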
@@ -677,44 +677,94 @@ class DeepseekV2AttentionMLA(nn.Module):
        else:
            return _dispatch_mla_subtype()

    def op_prepare(self, state):
        state.attn_intermediate_state = self.forward_prepare(
            positions=state.positions,
            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
        )

    def op_core(self, state):
        state.hidden_states_after_attn = self.forward_core(
            state.pop("attn_intermediate_state")
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states
    ):
        s = self.forward_prepare(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )
        return self.forward_core(s)

        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

        if attn_forward_method == AttnForwardMethod.MHA:
            return self.forward_normal(positions, hidden_states, forward_batch)
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            return self.forward_normal_chunked_kv(
                positions, hidden_states, forward_batch
            )
        elif attn_forward_method == AttnForwardMethod.MLA:
            return self.forward_absorb(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            return self.forward_absorb_fused_mla_rope(
                positions, hidden_states, forward_batch
            )
        else:
            raise NotImplementedError

    def forward_normal(
    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        zero_allocator: BumpAllocator,
    ):
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states, None, forward_batch, None

        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

        if attn_forward_method == AttnForwardMethod.MHA:
            inner_state = self.forward_normal_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            inner_state = self.forward_normal_chunked_kv_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA:
            inner_state = self.forward_absorb_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            inner_state = self.forward_absorb_fused_mla_rope_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        else:
            raise NotImplementedError
        return None, attn_forward_method, forward_batch, inner_state

    def forward_core(self, intermediate_state):
        hidden_states, attn_forward_method, forward_batch, inner_state = (
            intermediate_state
        )
        if inner_state is None:
            return hidden_states

        if attn_forward_method == AttnForwardMethod.MHA:
            return self.forward_normal_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            return self.forward_normal_chunked_kv_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA:
            return self.forward_absorb_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            return self.forward_absorb_fused_mla_rope_core(*inner_state)
        else:
            raise NotImplementedError

    def forward_normal_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
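forward_prepare packs everything the second stage needs into a plain tuple (hidden_states, attn_forward_method, forward_batch, inner_state): an empty batch short-circuits with inner_state set to None, and otherwise forward_core unpacks inner_state with * into the *_core method selected by attn_forward_method. A minimal, self-contained sketch of that dispatch shape follows; the Method enum, the prepare/core helpers and the list arithmetic are illustrative stand-ins, not the sglang code.

from enum import Enum, auto
from typing import Any, Optional, Tuple


class Method(Enum):
    # Stand-ins for AttnForwardMethod.MHA, MLA, etc.
    MHA = auto()
    MLA = auto()


# Mirrors the (hidden_states, method, batch, inner_state) tuple from the diff.
IntermediateState = Tuple[Optional[Any], Optional[Method], Optional[Any], Optional[tuple]]


def prepare(x: list, method: Method) -> IntermediateState:
    if len(x) == 0:
        # Empty batch: return the input directly; core will short-circuit.
        return x, None, None, None
    # Each branch builds the inner_state tuple its matching core expects.
    inner_state = (x, 2) if method is Method.MHA else (x, 3)
    return None, method, None, inner_state


def core(state: IntermediateState):
    hidden, method, _batch, inner_state = state
    if inner_state is None:
        return hidden
    if method is Method.MHA:
        # *inner_state mirrors forward_normal_core(*inner_state) in the diff.
        return mha_core(*inner_state)
    elif method is Method.MLA:
        return mla_core(*inner_state)
    raise NotImplementedError


def mha_core(x: list, scale: int) -> list:
    return [v * scale for v in x]


def mla_core(x: list, scale: int) -> list:
    return [v * scale for v in x]


print(core(prepare([1, 2, 3], Method.MLA)))  # [3, 6, 9]
print(core(prepare([], Method.MHA)))         # []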
@@ -749,18 +799,22 @@ class DeepseekV2AttentionMLA(nn.Module):
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )

        return q, k, v, forward_batch

    def forward_normal_core(self, q, k, v, forward_batch):
        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

    def forward_absorb(
    def forward_absorb_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
    ):
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

        if self.q_lora_rank is not None:
@@ -829,6 +883,11 @@ class DeepseekV2AttentionMLA(nn.Module):
            q_nope_out = q_nope_out.transpose(0, 1)
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator

    def forward_absorb_core(
        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
    ):
        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
            attn_output = self.attn_mqa(
                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -881,13 +940,13 @@ class DeepseekV2AttentionMLA(nn.Module):

        return output

    def forward_absorb_fused_mla_rope(
    def forward_absorb_fused_mla_rope_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
    ):
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
@@ -976,6 +1035,44 @@ class DeepseekV2AttentionMLA(nn.Module):
        )
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        return (
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            enable_rope_fusion,
            k_input,
            forward_batch,
            zero_allocator,
        )

    def forward_absorb_fused_mla_rope_core(
        self,
        q_input,
        key_cache_buf,
        val_cache_buf,
        attn_output,
        kv_indptr,
        kv_indices,
        k_pe_output,
        cos_sin_cache,
        positions,
        attn_logits,
        num_kv_split,
        sm_scale,
        enable_rope_fusion,
        k_input,
        forward_batch,
        zero_allocator,
    ):
        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
@@ -1082,12 +1179,13 @@ class DeepseekV2AttentionMLA(nn.Module):

        return accum_output

    def forward_normal_chunked_kv(
    def forward_normal_chunked_kv_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        zero_allocator: BumpAllocator,
    ):
        # In normal mha, the k and v tensors will become overly large when the prefix length is long.
        # To avoid this, we split the kv cache into chunks and process them one after another.
        # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
@@ -1130,6 +1228,9 @@ class DeepseekV2AttentionMLA(nn.Module):
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )

        return q, k, v, forward_batch

    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
        # Do mha for extended part without prefix
        forward_batch.set_attn_attend_prefix_cache(False)
        attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
@@ -1283,14 +1384,6 @@ class DeepseekV2DecoderLayer(nn.Module):
            )
        )

    def op_attn(self, state):
        state.hidden_states_after_attn = self.self_attn(
            positions=state.positions,
            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
        )

    def op_comm_prepare_mlp(self, state):
        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
            self.layer_communicator.prepare_mlp(
@@ -7,7 +7,8 @@ def compute_layer_operations(
    if not layer.is_layer_sparse:
        return [
            layer.op_comm_prepare_attn,
            layer.op_attn,
            layer.self_attn.op_prepare,
            layer.self_attn.op_core,
            layer.op_comm_prepare_mlp,
            layer.op_mlp,
            layer.op_comm_postprocess_layer,
@@ -16,7 +17,8 @@ def compute_layer_operations(
    # Will add TBO operation orders here
    return [
        layer.op_comm_prepare_attn,
        layer.op_attn,
        layer.self_attn.op_prepare,
        layer.self_attn.op_core,
        layer.op_comm_prepare_mlp,
        layer.mlp.op_gate,
        layer.mlp.op_shared_experts,
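In the operation lists above, op_prepare writes its result onto the shared per-layer state as attn_intermediate_state and op_core pops it back off, so the two halves of attention can be scheduled as separate steps between the communication ops. A hedged sketch of how such a list could be driven is below; the State class, ToyLayer and run_operations driver are illustrative assumptions, not the sglang scheduler.

class State(dict):
    """Illustrative stand-in for the per-layer state object: attribute access plus pop()."""

    def __getattr__(self, name):
        return self[name]

    def __setattr__(self, name, value):
        self[name] = value


class ToyAttention:
    # Same two-method contract as DeepseekV2AttentionMLA.op_prepare / op_core,
    # with the heavy lifting replaced by trivial arithmetic.
    def op_prepare(self, state):
        state.attn_intermediate_state = state.pop("hidden_states_after_comm_pre_attn") + 1

    def op_core(self, state):
        state.hidden_states_after_attn = state.pop("attn_intermediate_state") * 2


class ToyLayer:
    def __init__(self):
        self.self_attn = ToyAttention()

    def op_comm_prepare_attn(self, state):
        state.hidden_states_after_comm_pre_attn = state.pop("input")

    def op_mlp(self, state):
        state.output = state.pop("hidden_states_after_attn") + 10


def run_operations(operations, state):
    # Hypothetical driver: call each op in order against the shared state.
    for op in operations:
        op(state)
    return state


layer = ToyLayer()
ops = [
    layer.op_comm_prepare_attn,
    layer.self_attn.op_prepare,   # occupies the slot op_attn used to fill
    layer.self_attn.op_core,
    layer.op_mlp,
]
state = State(input=5)
print(run_operations(ops, state).output)  # (5 + 1) * 2 + 10 = 22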