[ROCm] Enable Fused MLA Triton kernel for DeepSeekV3 (#3237)
Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
committed by
GitHub
parent
3758d209a0
commit
6ce9dbe828
160
python/sglang/srt/models/deepseek_v2.py
Normal file → Executable file
160
python/sglang/srt/models/deepseek_v2.py
Normal file → Executable file
@@ -16,6 +16,7 @@
|
||||
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
|
||||
"""Inference-only DeepseekV2 model."""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -31,6 +32,9 @@ from sglang.srt.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
|
||||
decode_attention_fwd_grouped_rope,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
@@ -533,7 +537,18 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
if no_absorb():
|
||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||
else:
|
||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||
if is_hip_:
|
||||
if (
|
||||
os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
|
||||
and forward_batch.forward_mode.is_decode()
|
||||
):
|
||||
return self.forward_absorb_fused_mla_rope(
|
||||
positions, hidden_states, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||
else:
|
||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||
|
||||
def forward_normal(
|
||||
self,
|
||||
@@ -652,6 +667,149 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
return output
|
||||
|
||||
def forward_absorb_fused_mla_rope(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
enable_rope_fusion = (
|
||||
os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
|
||||
)
|
||||
q_len = hidden_states.shape[0]
|
||||
q_input = hidden_states.new_empty(
|
||||
q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
|
||||
)
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
q = self.q_a_layernorm(q)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)[0].view(
|
||||
-1, self.num_local_heads, self.qk_head_dim
|
||||
)
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
if self.w_kc.dtype == torch.float8_e4m3fnuz:
|
||||
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
||||
q_nope_out = torch.bmm(
|
||||
q_nope.to(torch.bfloat16).transpose(0, 1),
|
||||
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
||||
)
|
||||
elif self.w_kc.dtype == torch.float8_e4m3fn:
|
||||
q_nope_val, q_nope_scale = input_to_float8(
|
||||
q_nope.transpose(0, 1), torch.float8_e4m3fn
|
||||
)
|
||||
q_nope_out = bmm_fp8(
|
||||
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
|
||||
)
|
||||
else:
|
||||
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
||||
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
|
||||
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
v_input = latent_cache[..., : self.kv_lora_rank]
|
||||
v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
|
||||
k_input = latent_cache.unsqueeze(1)
|
||||
k_input[..., : self.kv_lora_rank] = v_input
|
||||
|
||||
if not enable_rope_fusion:
|
||||
k_pe = k_input[..., self.kv_lora_rank :]
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
q_input[..., self.kv_lora_rank :] = q_pe
|
||||
k_input[..., self.kv_lora_rank :] = k_pe
|
||||
k_pe_output = None
|
||||
else:
|
||||
k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])
|
||||
|
||||
q_input[..., self.kv_lora_rank :] = q_pe
|
||||
|
||||
# attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
|
||||
# Use Fused ROPE with use_rope=OFF.
|
||||
attn_output = torch.empty(
|
||||
(q_len, self.num_local_heads, self.kv_lora_rank),
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
)
|
||||
attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
|
||||
forward_batch.attn_backend.forward_metadata
|
||||
)
|
||||
cos_sin_cache = self.rotary_emb.cos_sin_cache
|
||||
num_kv_split = forward_batch.attn_backend.num_kv_splits
|
||||
sm_scale = self.attn_mqa.scaling
|
||||
if attn_logits is None:
|
||||
attn_logits = torch.empty(
|
||||
(
|
||||
forward_batch.batch_size,
|
||||
self.num_local_heads,
|
||||
num_kv_split,
|
||||
self.kv_lora_rank + 1,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=q.device,
|
||||
)
|
||||
|
||||
# save current latent cache.
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
self.attn_mqa, forward_batch.out_cache_loc, k_input, None
|
||||
)
|
||||
key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
|
||||
self.attn_mqa.layer_id
|
||||
)
|
||||
val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
|
||||
|
||||
decode_attention_fwd_grouped_rope(
|
||||
q_input,
|
||||
key_cache_buf,
|
||||
val_cache_buf,
|
||||
attn_output,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
k_pe_output,
|
||||
self.kv_lora_rank,
|
||||
self.rotary_emb.rotary_dim,
|
||||
cos_sin_cache,
|
||||
positions,
|
||||
attn_logits,
|
||||
num_kv_split,
|
||||
sm_scale,
|
||||
logit_cap=self.attn_mqa.logit_cap,
|
||||
use_rope=enable_rope_fusion,
|
||||
is_neox_style=self.rotary_emb.is_neox_style,
|
||||
)
|
||||
|
||||
if enable_rope_fusion:
|
||||
k_input[..., self.kv_lora_rank :] = k_pe_output
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
self.attn_mqa, forward_batch.out_cache_loc, k_input, None
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
|
||||
if self.w_vc.dtype == torch.float8_e4m3fnuz:
|
||||
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
||||
attn_bmm_output = torch.bmm(
|
||||
attn_output.to(torch.bfloat16).transpose(0, 1),
|
||||
self.w_vc.to(torch.bfloat16) * self.w_scale,
|
||||
)
|
||||
elif self.w_vc.dtype == torch.float8_e4m3fn:
|
||||
attn_output_val, attn_output_scale = input_to_float8(
|
||||
attn_output.transpose(0, 1), torch.float8_e4m3fn
|
||||
)
|
||||
attn_bmm_output = bmm_fp8(
|
||||
attn_output_val,
|
||||
self.w_vc,
|
||||
attn_output_scale,
|
||||
self.w_scale,
|
||||
torch.bfloat16,
|
||||
)
|
||||
else:
|
||||
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
|
||||
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def all_gather(
|
||||
input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
|
||||
|
||||
Reference in New Issue
Block a user