Rename flashmla kernel options of nsa backend for better readability (#11876)
@@ -228,6 +228,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--sampling-backend` | Choose the kernels for sampling layers. | None |
 | `--grammar-backend` | Choose the backend for grammar-guided decoding. | None |
 | `--mm-attention-backend` | Set multimodal attention backend. | None |
+| `--nsa-prefill-backend` | Prefill attention implementation for nsa backend. | `flashmla_sparse` |
+| `--nsa-decode-backend` | Decode attention implementation for nsa backend. | `flashmla_kv` |
 
 ## Speculative decoding
 
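For quick reference, the commit renames both the CLI flags and the kernel option values. The mapping below is an illustrative summary of this diff only; the dict itself does not exist in the repository.

```python
# Illustrative summary of this commit's renames (not code from the repo).
NSA_OPTION_RENAMES = {
    # CLI flags
    "--nsa-prefill": "--nsa-prefill-backend",
    "--nsa-decode": "--nsa-decode-backend",
    # Kernel option values
    "flashmla_prefill": "flashmla_sparse",
    "flashmla_decode": "flashmla_kv",
}
```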
@@ -140,9 +140,7 @@ def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
     )
 
 
-_NSA_IMPL_T: TypeAlias = Literal[
-    "flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
-]
+_NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]
 
 NSA_PREFILL_IMPL: _NSA_IMPL_T
 NSA_DECODE_IMPL: _NSA_IMPL_T
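Because `_NSA_IMPL_T` is a `Literal` alias, a type checker rejects any implementation name outside the four renamed options at analysis time. A minimal, self-contained sketch of that pattern; the `resolve_impl` helper is hypothetical and not part of this diff:

```python
from typing import Literal, TypeAlias, cast

_NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]

def resolve_impl(name: str) -> _NSA_IMPL_T:
    # Hypothetical runtime guard mirroring the static Literal constraint.
    if name not in ("flashmla_sparse", "flashmla_kv", "fa3", "tilelang"):
        raise ValueError(f"unknown NSA implementation: {name!r}")
    return cast(_NSA_IMPL_T, name)
```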
@@ -181,8 +179,8 @@ class NativeSparseAttnBackend(AttentionBackend):
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
 
         global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
-        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
-        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
+        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
+        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
 
         self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
@@ -336,7 +334,7 @@ class NativeSparseAttnBackend(AttentionBackend):
                     cache_seqlens=nsa_cache_seqlens_int32,
                     seq_len_q=1,
                 )
-                if NSA_DECODE_IMPL == "flashmla_decode"
+                if NSA_DECODE_IMPL == "flashmla_kv"
                 else None
             ),
             nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
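This hunk and the several that follow are the same mechanical change: every guard that materializes FlashMLA metadata now compares against the renamed `flashmla_kv` value. Reduced to a sketch, where `build_flashmla_metadata` is a hypothetical stand-in for the real metadata constructor:

```python
# Sketch of the recurring guard, assuming a hypothetical helper.
flashmla_metadata = (
    build_flashmla_metadata(cache_seqlens=nsa_cache_seqlens_int32, seq_len_q=1)
    if NSA_DECODE_IMPL == "flashmla_kv"
    else None  # other decode implementations need no FlashMLA metadata
)
```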
@@ -383,7 +381,7 @@ class NativeSparseAttnBackend(AttentionBackend):
                     ),
                     seq_len_q=1,
                 )
-                if NSA_DECODE_IMPL == "flashmla_decode"
+                if NSA_DECODE_IMPL == "flashmla_kv"
                 else None
             ),
         }
@@ -421,7 +419,7 @@ class NativeSparseAttnBackend(AttentionBackend):
 
         seqlens_expanded = cache_seqlens_int32
         nsa_extend_seq_lens_list = [1] * num_tokens
-        if NSA_DECODE_IMPL == "flashmla_decode":
+        if NSA_DECODE_IMPL == "flashmla_kv":
             flashmla_metadata = self.decode_cuda_graph_metadata[
                 "flashmla_metadata"
             ].slice(slice(0, num_tokens + 1))
@@ -478,7 +476,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens
 
-        if NSA_DECODE_IMPL == "flashmla_decode":
+        if NSA_DECODE_IMPL == "flashmla_kv":
             flashmla_metadata = self.decode_cuda_graph_metadata[
                 "flashmla_metadata"
             ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
@@ -534,7 +532,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         nsa_extend_seq_lens_list = [1] * bs
 
-        if NSA_DECODE_IMPL == "flashmla_decode":
+        if NSA_DECODE_IMPL == "flashmla_kv":
             flashmla_metadata = self.decode_cuda_graph_metadata[
                 "flashmla_metadata"
             ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
@@ -712,7 +710,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         else:
             assert metadata.real_page_table is metadata.page_table_1
 
-        if NSA_DECODE_IMPL == "flashmla_decode":
+        if NSA_DECODE_IMPL == "flashmla_kv":
             flashmla_metadata = metadata.flashmla_metadata.slice(
                 slice(0, seqlens_expanded_size + 1)
             )
@@ -803,20 +801,20 @@ class NativeSparseAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_PREFILL_IMPL == "flashmla_prefill":
+        elif NSA_PREFILL_IMPL == "flashmla_sparse":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_prefill(
+            return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 page_table_1=page_table_1,
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_PREFILL_IMPL == "flashmla_decode":
+        elif NSA_PREFILL_IMPL == "flashmla_kv":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_decode(
+            return self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 sm_scale=layer.scaling,
@@ -897,20 +895,20 @@ class NativeSparseAttnBackend(AttentionBackend):
             page_size=1,
         )
 
-        if NSA_DECODE_IMPL == "flashmla_prefill":
+        if NSA_DECODE_IMPL == "flashmla_sparse":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_prefill(
+            return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 page_table_1=page_table_1,
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_DECODE_IMPL == "flashmla_decode":
+        elif NSA_DECODE_IMPL == "flashmla_kv":
            if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_decode(
+            return self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 sm_scale=layer.scaling,
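Both the prefill and decode forward paths dispatch on the renamed implementation strings. Condensed to its shape below; this is a sketch, not the actual method, `**kwargs` stands in for the real `q_all` / `kv_cache` / `page_table_1` / `sm_scale` / `v_head_dim` arguments, and the real code has further branches for `fa3` and `tilelang`:

```python
def _dispatch(self, impl: str, **kwargs):
    # Hedged condensation of the dispatch above (renamed values only).
    if impl == "flashmla_sparse":
        return self._forward_flashmla_sparse(**kwargs)
    elif impl == "flashmla_kv":
        return self._forward_flashmla_kv(**kwargs)
    raise NotImplementedError(f"unsupported NSA implementation: {impl!r}")
```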
@@ -998,7 +996,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         return o  # type: ignore
 
-    def _forward_flashmla_prefill(
+    def _forward_flashmla_sparse(
         self,
         q_all: torch.Tensor,
         kv_cache: torch.Tensor,
@@ -1017,7 +1015,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         return o
 
-    def _forward_flashmla_decode(
+    def _forward_flashmla_kv(
         self,
         q_all: torch.Tensor,
         kv_cache: torch.Tensor,
@@ -128,7 +128,7 @@ DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
 
 DEFAULT_LORA_EVICTION_POLICY = "lru"
 
-NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
+NSA_CHOICES = ["flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "aiter"]
 
 RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
 
@@ -324,8 +324,8 @@ class ServerArgs:
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = None
    mm_attention_backend: Optional[str] = None
-    nsa_prefill: str = "flashmla_prefill"
-    nsa_decode: str = "fa3"
+    nsa_prefill_backend: str = "flashmla_sparse"
+    nsa_decode_backend: str = "fa3"
 
     # Speculative decoding
     enable_beta_spec: bool = False
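The renamed fields live on the `ServerArgs` dataclass, so downstream code reads them as plain attributes. A reduced, runnable sketch keeping only the two NSA fields (everything else omitted):

```python
from dataclasses import dataclass

@dataclass
class NsaArgsSketch:  # hypothetical reduction of ServerArgs to the NSA fields
    nsa_prefill_backend: str = "flashmla_sparse"
    nsa_decode_backend: str = "fa3"

args = NsaArgsSketch(nsa_decode_backend="flashmla_kv")
assert args.nsa_prefill_backend == "flashmla_sparse"
```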
@@ -1024,10 +1024,10 @@ class ServerArgs:
             logger.warning("Setting KV cache dtype to fp8.")
 
         if self.kv_cache_dtype == "fp8_e4m3":
-            self.nsa_prefill = "flashmla_decode"
-            self.nsa_decode = "flashmla_decode"
+            self.nsa_prefill_backend = "flashmla_kv"
+            self.nsa_decode_backend = "flashmla_kv"
             logger.warning(
-                "Setting NSA backend to flashmla_decode for FP8 KV Cache."
+                "Setting NSA backend to flashmla_kv for FP8 KV Cache."
             )
 
         # Logging env vars for NSA
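In other words, an `fp8_e4m3` KV cache forces both NSA paths onto the `flashmla_kv` kernels regardless of what was requested. The override, isolated as a sketch using only names from this diff:

```python
# Sketch of the FP8 override above, pulled out of the ServerArgs class.
kv_cache_dtype = "fp8_e4m3"  # example value
nsa_prefill_backend, nsa_decode_backend = "flashmla_sparse", "fa3"

if kv_cache_dtype == "fp8_e4m3":
    # Per the warning in the diff, only flashmla_kv serves an FP8 KV cache.
    nsa_prefill_backend = nsa_decode_backend = "flashmla_kv"
```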
@@ -2356,14 +2356,14 @@ class ServerArgs:
             help="Set multimodal attention backend.",
         )
         parser.add_argument(
-            "--nsa-prefill",
-            default=ServerArgs.nsa_prefill,
+            "--nsa-prefill-backend",
+            default=ServerArgs.nsa_prefill_backend,
             type=str,
             choices=NSA_CHOICES,
         )
         parser.add_argument(
-            "--nsa-decode",
-            default=ServerArgs.nsa_decode,
+            "--nsa-decode-backend",
+            default=ServerArgs.nsa_decode_backend,
             type=str,
             choices=NSA_CHOICES,
         )
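A self-contained argparse sketch of the renamed options, using the defaults and choices taken from this diff (standalone, not the full sglang parser):

```python
import argparse

NSA_CHOICES = ["flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "aiter"]

parser = argparse.ArgumentParser()
parser.add_argument("--nsa-prefill-backend", type=str,
                    default="flashmla_sparse", choices=NSA_CHOICES)
parser.add_argument("--nsa-decode-backend", type=str,
                    default="fa3", choices=NSA_CHOICES)

# argparse turns dashes into underscores on the parsed namespace.
args = parser.parse_args(["--nsa-decode-backend", "flashmla_kv"])
assert args.nsa_prefill_backend == "flashmla_sparse"
assert args.nsa_decode_backend == "flashmla_kv"
```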