From 4e6c4923a0c60ffc81fee72e95d6d159fc6b484b Mon Sep 17 00:00:00 2001 From: Binyao Jiang Date: Thu, 18 Sep 2025 17:13:59 -0700 Subject: [PATCH] [Performance] Qwen3-Next: speed up update_mamba_state_after_mtp_verify by 10x; e2e up to 3.54% faster (#10586) --- .../attention/hybrid_linear_attn_backend.py | 41 +++++-------------- 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 58831a5a6..699e5af68 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -583,36 +583,15 @@ class HybridLinearAttnBackend(AttentionBackend): # Compute common indices once to avoid duplication last_steps_all = (accepted_length - 1).to(torch.int64) - valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64) - last_steps = last_steps_all[valid_mask].to(torch.int64) + valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64) # [N] + last_steps = last_steps_all[valid_mask].to(torch.int64) # [N] - if valid_state_indices.numel() > 0: - chunk = 256 - num_valid = valid_state_indices.numel() + # scatter into ssm_states at the chosen cache lines + ssm_states[:, valid_state_indices, :] = intermediate_state_cache[ + :, valid_state_indices, last_steps + ].to(ssm_states.dtype, copy=False) - # SSM state updates - for i in range(0, num_valid, chunk): - idx = valid_state_indices[i : i + chunk] - steps = last_steps[i : i + chunk] - # per (cache line, step) - for j in range(idx.numel()): - ci = idx[j].item() - st = steps[j].item() - ssm_states[:, ci, :].copy_( - intermediate_state_cache[:, ci, st].to( - ssm_states.dtype, copy=False - ) - ) - - # Conv window updates - for i in range(0, num_valid, chunk): - idx = valid_state_indices[i : i + chunk] - steps = last_steps[i : i + chunk] - for j in range(idx.numel()): - ci = idx[j].item() - st = steps[j].item() - conv_states[:, ci, :, :].copy_( - intermediate_conv_window_cache[:, ci, st].to( - conv_states.dtype, copy=False - ) - ) + # Scatter into conv_states at the chosen cache lines + conv_states[:, valid_state_indices, :, :] = intermediate_conv_window_cache[ + :, valid_state_indices, last_steps + ].to(conv_states.dtype, copy=False)