Refactor attention into multiple stages (#6477)
This commit is contained in:
@@ -677,44 +677,94 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return _dispatch_mla_subtype()
|
return _dispatch_mla_subtype()
|
||||||
|
|
||||||
|
def op_prepare(self, state):
    """Run the attention prepare stage as a layer operation.

    Reads ``positions``, ``forward_batch`` and ``zero_allocator`` from
    ``state``, removes ``"hidden_states_after_comm_pre_attn"`` from it via
    ``state.pop`` and hands all four to ``self.forward_prepare``. Whatever
    ``forward_prepare`` returns is stored on ``state`` as
    ``attn_intermediate_state``. Nothing is returned.
    """
    # Resolve the callee first, then gather arguments in the same order the
    # call site lists them, so the pop happens after `positions` is read.
    prepare = self.forward_prepare
    positions = state.positions
    hidden_states = state.pop("hidden_states_after_comm_pre_attn")
    forward_batch = state.forward_batch
    zero_allocator = state.zero_allocator

    intermediate = prepare(
        positions=positions,
        hidden_states=hidden_states,
        forward_batch=forward_batch,
        zero_allocator=zero_allocator,
    )
    state.attn_intermediate_state = intermediate
|
||||||
|
|
||||||
|
def op_core(self, state):
    """Run the attention core stage as a layer operation.

    Removes ``"attn_intermediate_state"`` from ``state`` via ``state.pop``,
    passes it to ``self.forward_core`` and stores the result on ``state`` as
    ``hidden_states_after_attn``. Nothing is returned.
    """
    run_core = self.forward_core
    # Popping (rather than reading) drops the intermediate tuple from the
    # state once it has been consumed.
    intermediate = state.pop("attn_intermediate_state")
    state.hidden_states_after_attn = run_core(intermediate)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
zero_allocator: BumpAllocator,
|
zero_allocator: BumpAllocator,
|
||||||
) -> torch.Tensor:
|
):
|
||||||
if hidden_states.shape[0] == 0:
|
s = self.forward_prepare(
|
||||||
assert (
|
positions=positions,
|
||||||
not self.o_proj.reduce_results
|
hidden_states=hidden_states,
|
||||||
), "short-circuiting allreduce will lead to hangs"
|
forward_batch=forward_batch,
|
||||||
return hidden_states
|
zero_allocator=zero_allocator,
|
||||||
|
)
|
||||||
|
return self.forward_core(s)
|
||||||
|
|
||||||
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
|
def forward_prepare(
|
||||||
|
|
||||||
if attn_forward_method == AttnForwardMethod.MHA:
|
|
||||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
|
||||||
elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
|
|
||||||
return self.forward_normal_chunked_kv(
|
|
||||||
positions, hidden_states, forward_batch
|
|
||||||
)
|
|
||||||
elif attn_forward_method == AttnForwardMethod.MLA:
|
|
||||||
return self.forward_absorb(
|
|
||||||
positions, hidden_states, forward_batch, zero_allocator
|
|
||||||
)
|
|
||||||
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
|
|
||||||
return self.forward_absorb_fused_mla_rope(
|
|
||||||
positions, hidden_states, forward_batch
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def forward_normal(
|
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
) -> torch.Tensor:
|
zero_allocator: BumpAllocator,
|
||||||
|
):
|
||||||
|
if hidden_states.shape[0] == 0:
|
||||||
|
assert (
|
||||||
|
not self.o_proj.reduce_results
|
||||||
|
), "short-circuiting allreduce will lead to hangs"
|
||||||
|
return hidden_states, None, forward_batch, None
|
||||||
|
|
||||||
|
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
|
||||||
|
|
||||||
|
if attn_forward_method == AttnForwardMethod.MHA:
|
||||||
|
inner_state = self.forward_normal_prepare(
|
||||||
|
positions, hidden_states, forward_batch, zero_allocator
|
||||||
|
)
|
||||||
|
elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
|
||||||
|
inner_state = self.forward_normal_chunked_kv_prepare(
|
||||||
|
positions, hidden_states, forward_batch, zero_allocator
|
||||||
|
)
|
||||||
|
elif attn_forward_method == AttnForwardMethod.MLA:
|
||||||
|
inner_state = self.forward_absorb_prepare(
|
||||||
|
positions, hidden_states, forward_batch, zero_allocator
|
||||||
|
)
|
||||||
|
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
|
||||||
|
inner_state = self.forward_absorb_fused_mla_rope_prepare(
|
||||||
|
positions, hidden_states, forward_batch, zero_allocator
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return None, attn_forward_method, forward_batch, inner_state
|
||||||
|
|
||||||
|
def forward_core(self, intermediate_state):
    """Run the core attention stage for a previously prepared state.

    Args:
        intermediate_state: a 4-item sequence
            ``(hidden_states, attn_forward_method, forward_batch, inner_state)``.

    Returns:
        ``hidden_states`` unchanged when ``inner_state`` is ``None``;
        otherwise the result of the ``*_core`` method that matches
        ``attn_forward_method``, called with ``inner_state`` unpacked as
        positional arguments.

    Raises:
        NotImplementedError: if ``attn_forward_method`` is none of the four
            ``AttnForwardMethod`` values handled here.
    """
    hidden_states, attn_forward_method, _forward_batch, inner_state = (
        intermediate_state
    )

    # A missing inner state means there is nothing to compute; hand back the
    # hidden states carried in the tuple.
    if inner_state is None:
        return hidden_states

    # Pick the matching core routine, then invoke it once at the end.
    if attn_forward_method == AttnForwardMethod.MHA:
        core_fn = self.forward_normal_core
    elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
        core_fn = self.forward_normal_chunked_kv_core
    elif attn_forward_method == AttnForwardMethod.MLA:
        core_fn = self.forward_absorb_core
    elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
        core_fn = self.forward_absorb_fused_mla_rope_core
    else:
        raise NotImplementedError

    return core_fn(*inner_state)
|
||||||
|
|
||||||
|
def forward_normal_prepare(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
zero_allocator: BumpAllocator,
|
||||||
|
):
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
|
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
|
||||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
||||||
@@ -749,18 +799,22 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
|
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return q, k, v, forward_batch
|
||||||
|
|
||||||
|
def forward_normal_core(self, q, k, v, forward_batch):
    """Run the MHA attention call and the output projection.

    Calls ``self.attn_mha`` on ``q``, ``k``, ``v`` and ``forward_batch`` with
    ``save_kv_cache=False``, reshapes the result to
    ``(-1, num_local_heads * v_head_dim)`` and feeds it to ``self.o_proj``.

    Returns:
        The first element of the pair returned by ``self.o_proj``.
    """
    attended = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
    flat_width = self.num_local_heads * self.v_head_dim
    flattened = attended.reshape(-1, flat_width)
    # o_proj yields a pair; only the first item is the projected output.
    projected, _ = self.o_proj(flattened)
    return projected
|
||||||
|
|
||||||
def forward_absorb(
|
def forward_absorb_prepare(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
zero_allocator: BumpAllocator,
|
zero_allocator: BumpAllocator,
|
||||||
) -> torch.Tensor:
|
):
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||||
|
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
@@ -829,6 +883,11 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
q_nope_out = q_nope_out.transpose(0, 1)
|
q_nope_out = q_nope_out.transpose(0, 1)
|
||||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||||
|
|
||||||
|
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
||||||
|
|
||||||
|
def forward_absorb_core(
|
||||||
|
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
||||||
|
):
|
||||||
if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
|
if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
|
||||||
attn_output = self.attn_mqa(
|
attn_output = self.attn_mqa(
|
||||||
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
||||||
@@ -881,13 +940,13 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def forward_absorb_fused_mla_rope(
|
def forward_absorb_fused_mla_rope_prepare(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
zero_allocator: BumpAllocator,
|
zero_allocator: BumpAllocator,
|
||||||
) -> torch.Tensor:
|
):
|
||||||
enable_rope_fusion = (
|
enable_rope_fusion = (
|
||||||
os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
|
os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
|
||||||
)
|
)
|
||||||
@@ -976,6 +1035,44 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
)
|
)
|
||||||
val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
|
val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
|
||||||
|
|
||||||
|
return (
|
||||||
|
q_input,
|
||||||
|
key_cache_buf,
|
||||||
|
val_cache_buf,
|
||||||
|
attn_output,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
k_pe_output,
|
||||||
|
cos_sin_cache,
|
||||||
|
positions,
|
||||||
|
attn_logits,
|
||||||
|
num_kv_split,
|
||||||
|
sm_scale,
|
||||||
|
enable_rope_fusion,
|
||||||
|
k_input,
|
||||||
|
forward_batch,
|
||||||
|
zero_allocator,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_absorb_fused_mla_rope_core(
|
||||||
|
self,
|
||||||
|
q_input,
|
||||||
|
key_cache_buf,
|
||||||
|
val_cache_buf,
|
||||||
|
attn_output,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
k_pe_output,
|
||||||
|
cos_sin_cache,
|
||||||
|
positions,
|
||||||
|
attn_logits,
|
||||||
|
num_kv_split,
|
||||||
|
sm_scale,
|
||||||
|
enable_rope_fusion,
|
||||||
|
k_input,
|
||||||
|
forward_batch,
|
||||||
|
zero_allocator,
|
||||||
|
):
|
||||||
decode_attention_fwd_grouped_rope(
|
decode_attention_fwd_grouped_rope(
|
||||||
q_input,
|
q_input,
|
||||||
key_cache_buf,
|
key_cache_buf,
|
||||||
@@ -1082,12 +1179,13 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
|
|
||||||
return accum_output
|
return accum_output
|
||||||
|
|
||||||
def forward_normal_chunked_kv(
|
def forward_normal_chunked_kv_prepare(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
) -> torch.Tensor:
|
zero_allocator: BumpAllocator,
|
||||||
|
):
|
||||||
# In normal mha, the k and v tensors will become overly large when the prefix length is long.
|
# In normal mha, the k and v tensors will become overly large when the prefix length is long.
|
||||||
# To avoid this, we split the kv cache into chunks and process them one after another.
|
# To avoid this, we split the kv cache into chunks and process them one after another.
|
||||||
# Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
|
# Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
|
||||||
@@ -1130,6 +1228,9 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
|
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return q, k, v, forward_batch
|
||||||
|
|
||||||
|
def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
|
||||||
# Do mha for extended part without prefix
|
# Do mha for extended part without prefix
|
||||||
forward_batch.set_attn_attend_prefix_cache(False)
|
forward_batch.set_attn_attend_prefix_cache(False)
|
||||||
attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
|
attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
|
||||||
@@ -1283,14 +1384,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def op_attn(self, state):
    """Run self-attention as a single layer operation.

    Reads ``positions``, ``forward_batch`` and ``zero_allocator`` from
    ``state``, removes ``"hidden_states_after_comm_pre_attn"`` from it via
    ``state.pop`` and calls ``self.self_attn`` with all four as keyword
    arguments. The result is stored on ``state`` as
    ``hidden_states_after_attn``. Nothing is returned.
    """
    # Resolve the callee first, then gather arguments in call-site order.
    attention = self.self_attn
    positions = state.positions
    hidden_states = state.pop("hidden_states_after_comm_pre_attn")
    forward_batch = state.forward_batch
    zero_allocator = state.zero_allocator

    state.hidden_states_after_attn = attention(
        positions=positions,
        hidden_states=hidden_states,
        forward_batch=forward_batch,
        zero_allocator=zero_allocator,
    )
|
|
||||||
|
|
||||||
def op_comm_prepare_mlp(self, state):
|
def op_comm_prepare_mlp(self, state):
|
||||||
state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
|
state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
|
||||||
self.layer_communicator.prepare_mlp(
|
self.layer_communicator.prepare_mlp(
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ def compute_layer_operations(
|
|||||||
if not layer.is_layer_sparse:
|
if not layer.is_layer_sparse:
|
||||||
return [
|
return [
|
||||||
layer.op_comm_prepare_attn,
|
layer.op_comm_prepare_attn,
|
||||||
layer.op_attn,
|
layer.self_attn.op_prepare,
|
||||||
|
layer.self_attn.op_core,
|
||||||
layer.op_comm_prepare_mlp,
|
layer.op_comm_prepare_mlp,
|
||||||
layer.op_mlp,
|
layer.op_mlp,
|
||||||
layer.op_comm_postprocess_layer,
|
layer.op_comm_postprocess_layer,
|
||||||
@@ -16,7 +17,8 @@ def compute_layer_operations(
|
|||||||
# Will add TBO operation orders here
|
# Will add TBO operation orders here
|
||||||
return [
|
return [
|
||||||
layer.op_comm_prepare_attn,
|
layer.op_comm_prepare_attn,
|
||||||
layer.op_attn,
|
layer.self_attn.op_prepare,
|
||||||
|
layer.self_attn.op_core,
|
||||||
layer.op_comm_prepare_mlp,
|
layer.op_comm_prepare_mlp,
|
||||||
layer.mlp.op_gate,
|
layer.mlp.op_gate,
|
||||||
layer.mlp.op_shared_experts,
|
layer.mlp.op_shared_experts,
|
||||||
|
|||||||
Reference in New Issue
Block a user