Rename flashmla kernel options of nsa backend for better readability (#11876)
This commit is contained in:
@@ -140,9 +140,7 @@ def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
|
||||
)
|
||||
|
||||
|
||||
_NSA_IMPL_T: TypeAlias = Literal[
|
||||
"flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
|
||||
]
|
||||
_NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]
|
||||
|
||||
NSA_PREFILL_IMPL: _NSA_IMPL_T
|
||||
NSA_DECODE_IMPL: _NSA_IMPL_T
|
||||
@@ -181,8 +179,8 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
|
||||
global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
|
||||
NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
|
||||
NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
|
||||
NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
|
||||
NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
|
||||
|
||||
self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
|
||||
|
||||
@@ -336,7 +334,7 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
cache_seqlens=nsa_cache_seqlens_int32,
|
||||
seq_len_q=1,
|
||||
)
|
||||
if NSA_DECODE_IMPL == "flashmla_decode"
|
||||
if NSA_DECODE_IMPL == "flashmla_kv"
|
||||
else None
|
||||
),
|
||||
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
|
||||
@@ -383,7 +381,7 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
),
|
||||
seq_len_q=1,
|
||||
)
|
||||
if NSA_DECODE_IMPL == "flashmla_decode"
|
||||
if NSA_DECODE_IMPL == "flashmla_kv"
|
||||
else None
|
||||
),
|
||||
}
|
||||
@@ -421,7 +419,7 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
|
||||
seqlens_expanded = cache_seqlens_int32
|
||||
nsa_extend_seq_lens_list = [1] * num_tokens
|
||||
if NSA_DECODE_IMPL == "flashmla_decode":
|
||||
if NSA_DECODE_IMPL == "flashmla_kv":
|
||||
flashmla_metadata = self.decode_cuda_graph_metadata[
|
||||
"flashmla_metadata"
|
||||
].slice(slice(0, num_tokens + 1))
|
||||
@@ -478,7 +476,7 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
)
|
||||
nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens
|
||||
|
||||
if NSA_DECODE_IMPL == "flashmla_decode":
|
||||
if NSA_DECODE_IMPL == "flashmla_kv":
|
||||
flashmla_metadata = self.decode_cuda_graph_metadata[
|
||||
"flashmla_metadata"
|
||||
].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
|
||||
@@ -534,7 +532,7 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
)
|
||||
nsa_extend_seq_lens_list = [1] * bs
|
||||
|
||||
if NSA_DECODE_IMPL == "flashmla_decode":
|
||||
if NSA_DECODE_IMPL == "flashmla_kv":
|
||||
flashmla_metadata = self.decode_cuda_graph_metadata[
|
||||
"flashmla_metadata"
|
||||
].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
|
||||
@@ -712,7 +710,7 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
else:
|
||||
assert metadata.real_page_table is metadata.page_table_1
|
||||
|
||||
if NSA_DECODE_IMPL == "flashmla_decode":
|
||||
if NSA_DECODE_IMPL == "flashmla_kv":
|
||||
flashmla_metadata = metadata.flashmla_metadata.slice(
|
||||
slice(0, seqlens_expanded_size + 1)
|
||||
)
|
||||
@@ -803,20 +801,20 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
sm_scale=layer.scaling,
|
||||
v_head_dim=layer.v_head_dim,
|
||||
)
|
||||
elif NSA_PREFILL_IMPL == "flashmla_prefill":
|
||||
elif NSA_PREFILL_IMPL == "flashmla_sparse":
|
||||
if q_rope is not None:
|
||||
q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||
return self._forward_flashmla_prefill(
|
||||
return self._forward_flashmla_sparse(
|
||||
q_all=q_all,
|
||||
kv_cache=kv_cache,
|
||||
page_table_1=page_table_1,
|
||||
sm_scale=layer.scaling,
|
||||
v_head_dim=layer.v_head_dim,
|
||||
)
|
||||
elif NSA_PREFILL_IMPL == "flashmla_decode":
|
||||
elif NSA_PREFILL_IMPL == "flashmla_kv":
|
||||
if q_rope is not None:
|
||||
q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||
return self._forward_flashmla_decode(
|
||||
return self._forward_flashmla_kv(
|
||||
q_all=q_all,
|
||||
kv_cache=kv_cache,
|
||||
sm_scale=layer.scaling,
|
||||
@@ -897,20 +895,20 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
page_size=1,
|
||||
)
|
||||
|
||||
if NSA_DECODE_IMPL == "flashmla_prefill":
|
||||
if NSA_DECODE_IMPL == "flashmla_sparse":
|
||||
if q_rope is not None:
|
||||
q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||
return self._forward_flashmla_prefill(
|
||||
return self._forward_flashmla_sparse(
|
||||
q_all=q_all,
|
||||
kv_cache=kv_cache,
|
||||
page_table_1=page_table_1,
|
||||
sm_scale=layer.scaling,
|
||||
v_head_dim=layer.v_head_dim,
|
||||
)
|
||||
elif NSA_DECODE_IMPL == "flashmla_decode":
|
||||
elif NSA_DECODE_IMPL == "flashmla_kv":
|
||||
if q_rope is not None:
|
||||
q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||
return self._forward_flashmla_decode(
|
||||
return self._forward_flashmla_kv(
|
||||
q_all=q_all,
|
||||
kv_cache=kv_cache,
|
||||
sm_scale=layer.scaling,
|
||||
@@ -998,7 +996,7 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
)
|
||||
return o # type: ignore
|
||||
|
||||
def _forward_flashmla_prefill(
|
||||
def _forward_flashmla_sparse(
|
||||
self,
|
||||
q_all: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
@@ -1017,7 +1015,7 @@ class NativeSparseAttnBackend(AttentionBackend):
|
||||
)
|
||||
return o
|
||||
|
||||
def _forward_flashmla_decode(
|
||||
def _forward_flashmla_kv(
|
||||
self,
|
||||
q_all: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
|
||||
@@ -128,7 +128,7 @@ DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
|
||||
|
||||
DEFAULT_LORA_EVICTION_POLICY = "lru"
|
||||
|
||||
NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
|
||||
NSA_CHOICES = ["flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "aiter"]
|
||||
|
||||
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
|
||||
|
||||
@@ -324,8 +324,8 @@ class ServerArgs:
|
||||
sampling_backend: Optional[str] = None
|
||||
grammar_backend: Optional[str] = None
|
||||
mm_attention_backend: Optional[str] = None
|
||||
nsa_prefill: str = "flashmla_prefill"
|
||||
nsa_decode: str = "fa3"
|
||||
nsa_prefill_backend: str = "flashmla_sparse"
|
||||
nsa_decode_backend: str = "fa3"
|
||||
|
||||
# Speculative decoding
|
||||
enable_beta_spec: bool = False
|
||||
@@ -1024,10 +1024,10 @@ class ServerArgs:
|
||||
logger.warning("Setting KV cache dtype to fp8.")
|
||||
|
||||
if self.kv_cache_dtype == "fp8_e4m3":
|
||||
self.nsa_prefill = "flashmla_decode"
|
||||
self.nsa_decode = "flashmla_decode"
|
||||
self.nsa_prefill_backend = "flashmla_kv"
|
||||
self.nsa_decode_backend = "flashmla_kv"
|
||||
logger.warning(
|
||||
"Setting NSA backend to flashmla_decode for FP8 KV Cache."
|
||||
"Setting NSA backend to flashmla_kv for FP8 KV Cache."
|
||||
)
|
||||
|
||||
# Logging env vars for NSA
|
||||
@@ -2356,14 +2356,14 @@ class ServerArgs:
|
||||
help="Set multimodal attention backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nsa-prefill",
|
||||
default=ServerArgs.nsa_prefill,
|
||||
"--nsa-prefill-backend",
|
||||
default=ServerArgs.nsa_prefill_backend,
|
||||
type=str,
|
||||
choices=NSA_CHOICES,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nsa-decode",
|
||||
default=ServerArgs.nsa_decode,
|
||||
"--nsa-decode-backend",
|
||||
default=ServerArgs.nsa_decode_backend,
|
||||
type=str,
|
||||
choices=NSA_CHOICES,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user