diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
index 8921e1214..ebb5b85da 100644
--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -17,7 +17,7 @@ if is_cuda():
     except ImportError as e:
         deep_gemm = e
 
-from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
+from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM
 from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization import deep_gemm_wrapper
@@ -168,43 +168,6 @@ class Indexer(CustomOp):
         self.scale_fmt = scale_fmt
         self.softmax_scale = self.head_dim**-0.5
 
-    def _forward_fake(
-        self,
-        x: torch.Tensor,
-        q_lora: torch.Tensor,
-        positions: torch.Tensor,
-        forward_batch: ForwardBatch,
-        layer_id: int,
-    ):
-        bs = x.shape[0]
-        assert self.index_topk == 2048
-        ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
-            None, ...
-        ].repeat(bs, 1)
-        if forward_batch.forward_mode.is_extend():
-            assert (
-                forward_batch.extend_seq_lens_cpu is not None
-                and forward_batch.seq_lens_cpu is not None
-            )
-            which = 0
-            for i, (kv_len, qo_len) in enumerate(
-                zip(
-                    forward_batch.seq_lens_cpu.tolist(),
-                    forward_batch.extend_seq_lens_cpu,
-                    strict=True,
-                )
-            ):
-                for j in range(kv_len - qo_len, kv_len):
-                    ans[which, j + 1 :] = -1
-                which += 1
-            assert which == ans.shape[0]
-        else:
-            assert forward_batch.seq_lens_cpu is not None
-            for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
-                ans[i, seq_len:] = -1
-
-        return ans
-
     @torch.compile(dynamic=True)
     def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
         weights, _ = self.weights_proj(x)
@@ -404,7 +367,7 @@ class Indexer(CustomOp):
 
         return topk_result
 
-    def forward_indexer_bs_1(
+    def forward_indexer(
         self,
         q_fp8: torch.Tensor,
         weights: torch.Tensor,
@@ -485,20 +448,9 @@ class Indexer(CustomOp):
             q_len_start = q_len_end
 
         topk_indices = torch.cat(topk_indices_list, dim=0)
-
         return topk_indices
 
-    def forward_indexer(
-        self,
-        q_fp8: torch.Tensor,
-        weights: torch.Tensor,
-        forward_batch: ForwardBatch,
-        topk: int,
-        layer_id: int,
-    ) -> Optional[torch.Tensor]:
-        return self.forward_indexer_bs_1(q_fp8, weights, forward_batch, topk, layer_id)
-
-    def _forward(
+    def forward_cuda(
         self,
         x: torch.Tensor,
         q_lora: torch.Tensor,
@@ -530,9 +482,6 @@ class Indexer(CustomOp):
         if metadata is None:
             return None
 
-        if not NSA_USE_REAL_INDEXER:  # temporary
-            return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
-
         query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
 
         if enable_dual_stream:
@@ -588,19 +537,8 @@ class Indexer(CustomOp):
             topk=self.index_topk,
             layer_id=layer_id,
         )
-
         return topk_result
 
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        q_lora: torch.Tensor,
-        positions: torch.Tensor,
-        forward_batch: ForwardBatch,
-        layer_id: int,
-    ) -> Optional[torch.Tensor]:
-        return self._forward(x, q_lora, positions, forward_batch, layer_id)
-
     def forward_npu(
         self,
         x: torch.Tensor,
diff --git a/python/sglang/srt/layers/attention/nsa/utils.py b/python/sglang/srt/layers/attention/nsa/utils.py
index 348f1b736..e2d0da583 100644
--- a/python/sglang/srt/layers/attention/nsa/utils.py
+++ b/python/sglang/srt/layers/attention/nsa/utils.py
@@ -1,7 +1,6 @@
 # temp NSA debugging environ
 from sglang.srt.utils import get_bool_env_var
 
-NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
 NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
 NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")
 