Cleaning indexer for DeepSeek V3.2 (#11682)
This commit is contained in:
@@ -17,7 +17,7 @@ if is_cuda():
|
||||
except ImportError as e:
|
||||
deep_gemm = e
|
||||
|
||||
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
|
||||
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_group
|
||||
from sglang.srt.layers.linear import ReplicatedLinear
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
@@ -168,43 +168,6 @@ class Indexer(CustomOp):
|
||||
self.scale_fmt = scale_fmt
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
|
||||
def _forward_fake(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
q_lora: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layer_id: int,
|
||||
):
|
||||
bs = x.shape[0]
|
||||
assert self.index_topk == 2048
|
||||
ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
|
||||
None, ...
|
||||
].repeat(bs, 1)
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
assert (
|
||||
forward_batch.extend_seq_lens_cpu is not None
|
||||
and forward_batch.seq_lens_cpu is not None
|
||||
)
|
||||
which = 0
|
||||
for i, (kv_len, qo_len) in enumerate(
|
||||
zip(
|
||||
forward_batch.seq_lens_cpu.tolist(),
|
||||
forward_batch.extend_seq_lens_cpu,
|
||||
strict=True,
|
||||
)
|
||||
):
|
||||
for j in range(kv_len - qo_len, kv_len):
|
||||
ans[which, j + 1 :] = -1
|
||||
which += 1
|
||||
assert which == ans.shape[0]
|
||||
else:
|
||||
assert forward_batch.seq_lens_cpu is not None
|
||||
for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
|
||||
ans[i, seq_len:] = -1
|
||||
|
||||
return ans
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
|
||||
weights, _ = self.weights_proj(x)
|
||||
@@ -404,7 +367,7 @@ class Indexer(CustomOp):
|
||||
|
||||
return topk_result
|
||||
|
||||
def forward_indexer_bs_1(
|
||||
def forward_indexer(
|
||||
self,
|
||||
q_fp8: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
@@ -485,20 +448,9 @@ class Indexer(CustomOp):
|
||||
q_len_start = q_len_end
|
||||
|
||||
topk_indices = torch.cat(topk_indices_list, dim=0)
|
||||
|
||||
return topk_indices
|
||||
|
||||
def forward_indexer(
|
||||
self,
|
||||
q_fp8: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
topk: int,
|
||||
layer_id: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.forward_indexer_bs_1(q_fp8, weights, forward_batch, topk, layer_id)
|
||||
|
||||
def _forward(
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
q_lora: torch.Tensor,
|
||||
@@ -530,9 +482,6 @@ class Indexer(CustomOp):
|
||||
if metadata is None:
|
||||
return None
|
||||
|
||||
if not NSA_USE_REAL_INDEXER: # temporary
|
||||
return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
|
||||
|
||||
query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
|
||||
|
||||
if enable_dual_stream:
|
||||
@@ -588,19 +537,8 @@ class Indexer(CustomOp):
|
||||
topk=self.index_topk,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
return topk_result
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
q_lora: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layer_id: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self._forward(x, q_lora, positions, forward_batch, layer_id)
|
||||
|
||||
def forward_npu(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# temp NSA debugging environ
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
|
||||
NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
|
||||
NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user