[Feature] Support FA3 backend for MLA (#4831)
This commit is contained in:
@@ -13,7 +13,9 @@ from typing import TYPE_CHECKING, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.configs.model_config import AttentionArch
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -58,6 +60,9 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self.decode_cuda_graph_metadata = {}
|
self.decode_cuda_graph_metadata = {}
|
||||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
self.page_size = model_runner.page_size
|
self.page_size = model_runner.page_size
|
||||||
|
self.use_mla = (
|
||||||
|
model_runner.model_config.attention_arch == AttentionArch.MLA
|
||||||
|
) and (not global_server_args_dict["disable_mla"])
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Initialize forward metadata to cache repetitive calculations."""
|
"""Initialize forward metadata to cache repetitive calculations."""
|
||||||
@@ -117,23 +122,30 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
):
|
):
|
||||||
cache_loc = (
|
|
||||||
forward_batch.out_cache_loc
|
|
||||||
if not layer.is_cross_attention
|
|
||||||
else forward_batch.encoder_out_cache_loc
|
|
||||||
)
|
|
||||||
|
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
cache_loc = (
|
||||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
forward_batch.out_cache_loc
|
||||||
|
if not layer.is_cross_attention
|
||||||
|
else forward_batch.encoder_out_cache_loc
|
||||||
)
|
)
|
||||||
|
if not self.use_mla:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer,
|
||||||
|
cache_loc,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
)
|
||||||
|
|
||||||
# Use precomputed metadata
|
# Use precomputed metadata
|
||||||
metadata = self.forward_metadata
|
metadata = self.forward_metadata
|
||||||
|
|
||||||
# # Use Flash Attention for prefill
|
|
||||||
# Calculate window size (can be moved to metadata if layer properties don't change)
|
# Calculate window size (can be moved to metadata if layer properties don't change)
|
||||||
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
||||||
# here is two side inclusive
|
# here is two side inclusive
|
||||||
@@ -142,36 +154,72 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
if layer.sliding_window_size is not None
|
if layer.sliding_window_size is not None
|
||||||
else (-1, -1)
|
else (-1, -1)
|
||||||
)
|
)
|
||||||
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
|
||||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
|
||||||
|
|
||||||
key_cache = key_cache.view(
|
|
||||||
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
|
|
||||||
)
|
|
||||||
value_cache = value_cache.view(
|
|
||||||
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
page_table = metadata.page_table
|
page_table = metadata.page_table
|
||||||
|
|
||||||
o = flash_attn_with_kvcache(
|
# # Use Flash Attention for prefill
|
||||||
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
if not self.use_mla:
|
||||||
k_cache=key_cache,
|
# Do multi-head attention
|
||||||
v_cache=value_cache,
|
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||||
page_table=page_table,
|
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||||
cache_seqlens=metadata.cache_seqlens_int32,
|
key_cache = key_cache.view(
|
||||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
|
||||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
)
|
||||||
max_seqlen_q=metadata.max_seq_len_q,
|
value_cache = value_cache.view(
|
||||||
softmax_scale=layer.scaling,
|
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
|
||||||
causal=True,
|
)
|
||||||
window_size=window_size,
|
o = flash_attn_with_kvcache(
|
||||||
softcap=layer.logit_cap,
|
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
k_descale=layer.k_scale,
|
k_cache=key_cache,
|
||||||
v_descale=layer.v_scale,
|
v_cache=value_cache,
|
||||||
)
|
page_table=page_table,
|
||||||
|
cache_seqlens=metadata.cache_seqlens_int32,
|
||||||
|
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||||
|
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||||
|
max_seqlen_q=metadata.max_seq_len_q,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=True,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=layer.logit_cap,
|
||||||
|
k_descale=layer.k_scale,
|
||||||
|
v_descale=layer.v_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Do absorbed multi-latent attention
|
||||||
|
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
k_rope = kv_cache[:, :, layer.v_head_dim :]
|
||||||
|
c_kv = kv_cache[:, :, : layer.v_head_dim]
|
||||||
|
k_rope_cache = k_rope.view(
|
||||||
|
-1,
|
||||||
|
self.page_size,
|
||||||
|
layer.tp_k_head_num,
|
||||||
|
layer.head_dim - layer.v_head_dim,
|
||||||
|
)
|
||||||
|
c_kv_cache = c_kv.view(
|
||||||
|
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
|
||||||
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||||
|
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||||
|
o = flash_attn_with_kvcache(
|
||||||
|
q=q_rope,
|
||||||
|
k_cache=k_rope_cache,
|
||||||
|
v_cache=c_kv_cache,
|
||||||
|
qv=q_nope,
|
||||||
|
page_table=page_table,
|
||||||
|
cache_seqlens=metadata.cache_seqlens_int32,
|
||||||
|
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||||
|
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||||
|
max_seqlen_q=metadata.max_seq_len_q,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=True,
|
||||||
|
softcap=layer.logit_cap,
|
||||||
|
k_descale=layer.k_scale,
|
||||||
|
v_descale=layer.v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
|
||||||
def forward_decode(
|
def forward_decode(
|
||||||
self,
|
self,
|
||||||
@@ -184,24 +232,29 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention using precomputed metadata."""
|
"""Forward pass with FlashAttention using precomputed metadata."""
|
||||||
# Save KV cache if needed
|
# Save KV cache if needed
|
||||||
if k is not None and v is not None and save_kv_cache:
|
if k is not None:
|
||||||
cache_loc = (
|
assert v is not None
|
||||||
forward_batch.out_cache_loc
|
if save_kv_cache:
|
||||||
if not layer.is_cross_attention
|
cache_loc = (
|
||||||
else forward_batch.encoder_out_cache_loc
|
forward_batch.out_cache_loc
|
||||||
)
|
if not layer.is_cross_attention
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
else forward_batch.encoder_out_cache_loc
|
||||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
)
|
||||||
)
|
if not self.use_mla:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer,
|
||||||
|
cache_loc,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
)
|
||||||
|
|
||||||
# Get KV cache
|
|
||||||
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
|
||||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
|
||||||
# Use precomputed metadata
|
# Use precomputed metadata
|
||||||
metadata = self.forward_metadata
|
metadata = self.forward_metadata
|
||||||
|
|
||||||
# Pre-reshape query tensor
|
|
||||||
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
|
||||||
# Calculate window size (can be moved to metadata if layer properties don't change)
|
# Calculate window size (can be moved to metadata if layer properties don't change)
|
||||||
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
||||||
# here is two side inclusive
|
# here is two side inclusive
|
||||||
@@ -210,33 +263,79 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
if layer.sliding_window_size is not None
|
if layer.sliding_window_size is not None
|
||||||
else (-1, -1)
|
else (-1, -1)
|
||||||
)
|
)
|
||||||
# Run attention with precomputed values
|
|
||||||
key_cache = key_cache.view(
|
|
||||||
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
|
|
||||||
)
|
|
||||||
value_cache = value_cache.view(
|
|
||||||
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
page_table = metadata.page_table
|
page_table = metadata.page_table
|
||||||
|
|
||||||
o = flash_attn_with_kvcache(
|
if not self.use_mla:
|
||||||
q=q_reshaped,
|
# Do multi-head attention
|
||||||
k_cache=key_cache,
|
|
||||||
v_cache=value_cache,
|
# Get KV cache
|
||||||
page_table=page_table,
|
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||||
cache_seqlens=metadata.cache_seqlens_int32,
|
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
key_cache = key_cache.view(
|
||||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
|
||||||
max_seqlen_q=1,
|
)
|
||||||
softmax_scale=layer.scaling,
|
value_cache = value_cache.view(
|
||||||
causal=True,
|
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
|
||||||
window_size=window_size,
|
)
|
||||||
softcap=layer.logit_cap,
|
|
||||||
k_descale=layer.k_scale,
|
# Pre-reshape query tensor
|
||||||
v_descale=layer.v_scale,
|
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
)
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
# Run attention with precomputed values
|
||||||
|
o = flash_attn_with_kvcache(
|
||||||
|
q=q_reshaped,
|
||||||
|
k_cache=key_cache,
|
||||||
|
v_cache=value_cache,
|
||||||
|
page_table=page_table,
|
||||||
|
cache_seqlens=metadata.cache_seqlens_int32,
|
||||||
|
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||||
|
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||||
|
max_seqlen_q=1,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=True,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=layer.logit_cap,
|
||||||
|
k_descale=layer.k_scale,
|
||||||
|
v_descale=layer.v_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Do absorbed multi-latent attention
|
||||||
|
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
k_rope = kv_cache[:, :, layer.v_head_dim :]
|
||||||
|
c_kv = kv_cache[:, :, : layer.v_head_dim]
|
||||||
|
k_rope_cache = k_rope.view(
|
||||||
|
-1,
|
||||||
|
self.page_size,
|
||||||
|
layer.tp_k_head_num,
|
||||||
|
layer.head_dim - layer.v_head_dim,
|
||||||
|
)
|
||||||
|
c_kv_cache = c_kv.view(
|
||||||
|
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||||
|
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||||
|
|
||||||
|
o = flash_attn_with_kvcache(
|
||||||
|
q=q_rope,
|
||||||
|
k_cache=k_rope_cache,
|
||||||
|
v_cache=c_kv_cache,
|
||||||
|
qv=q_nope,
|
||||||
|
page_table=page_table,
|
||||||
|
cache_seqlens=metadata.cache_seqlens_int32,
|
||||||
|
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||||
|
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||||
|
max_seqlen_q=1,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=True,
|
||||||
|
softcap=layer.logit_cap,
|
||||||
|
k_descale=layer.k_scale,
|
||||||
|
v_descale=layer.v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
"""Initialize CUDA graph state for the attention backend.
|
"""Initialize CUDA graph state for the attention backend.
|
||||||
@@ -286,7 +385,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
|
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
|
||||||
req_pool_indices, :
|
req_pool_indices, :
|
||||||
]
|
]
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
if forward_mode == ForwardMode.DECODE:
|
||||||
# Precompute cumulative sequence lengths
|
# Precompute cumulative sequence lengths
|
||||||
metadata.cu_seqlens_q = torch.arange(
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
|
|||||||
@@ -230,6 +230,10 @@ class ModelRunner:
|
|||||||
elif server_args.enable_flashmla:
|
elif server_args.enable_flashmla:
|
||||||
logger.info("MLA optimization is turned on. Use flashmla decode.")
|
logger.info("MLA optimization is turned on. Use flashmla decode.")
|
||||||
server_args.attention_backend = "flashmla"
|
server_args.attention_backend = "flashmla"
|
||||||
|
elif server_args.attention_backend == "fa3":
|
||||||
|
logger.info(
|
||||||
|
f"MLA optimization is turned on. Use flash attention 3 backend."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("MLA optimization is turned on. Use triton backend.")
|
logger.info("MLA optimization is turned on. Use triton backend.")
|
||||||
server_args.attention_backend = "triton"
|
server_args.attention_backend = "triton"
|
||||||
@@ -879,7 +883,7 @@ class ModelRunner:
|
|||||||
"Please use `--attention-backend flashinfer`."
|
"Please use `--attention-backend flashinfer`."
|
||||||
)
|
)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported."
|
"FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.attention.flashattention_backend import (
|
from sglang.srt.layers.attention.flashattention_backend import (
|
||||||
FlashAttentionBackend,
|
FlashAttentionBackend,
|
||||||
|
|||||||
@@ -655,6 +655,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
self.flashinfer_mla_disable_ragged = global_server_args_dict[
|
self.flashinfer_mla_disable_ragged = global_server_args_dict[
|
||||||
"flashinfer_mla_disable_ragged"
|
"flashinfer_mla_disable_ragged"
|
||||||
]
|
]
|
||||||
|
self.attention_backend = global_server_args_dict["attention_backend"]
|
||||||
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
|
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
|
||||||
|
|
||||||
def no_absorb(self, forward_batch: ForwardBatch) -> bool:
|
def no_absorb(self, forward_batch: ForwardBatch) -> bool:
|
||||||
@@ -667,6 +668,9 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
and not forward_batch.forward_mode.is_draft_extend()
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
and sum(forward_batch.extend_prefix_lens_cpu) == 0
|
and sum(forward_batch.extend_prefix_lens_cpu) == 0
|
||||||
)
|
)
|
||||||
|
elif self.attention_backend == "fa3":
|
||||||
|
# Flash Attention: Keep absorbing for all extend/decode
|
||||||
|
return False
|
||||||
else:
|
else:
|
||||||
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
||||||
return (
|
return (
|
||||||
|
|||||||
Reference in New Issue
Block a user