[Performance] Qwen3-Next: speed up update_mamba_state_after_mtp_verify by 10x; e2e up to 3.54% faster (#10586)
This commit is contained in:
@@ -583,36 +583,15 @@ class HybridLinearAttnBackend(AttentionBackend):
|
||||
|
||||
# Compute common indices once to avoid duplication
|
||||
last_steps_all = (accepted_length - 1).to(torch.int64)
|
||||
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
|
||||
last_steps = last_steps_all[valid_mask].to(torch.int64)
|
||||
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64) # [N]
|
||||
last_steps = last_steps_all[valid_mask].to(torch.int64) # [N]
|
||||
|
||||
if valid_state_indices.numel() > 0:
|
||||
chunk = 256
|
||||
num_valid = valid_state_indices.numel()
|
||||
# scatter into ssm_states at the chosen cache lines
|
||||
ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
|
||||
:, valid_state_indices, last_steps
|
||||
].to(ssm_states.dtype, copy=False)
|
||||
|
||||
# SSM state updates
|
||||
for i in range(0, num_valid, chunk):
|
||||
idx = valid_state_indices[i : i + chunk]
|
||||
steps = last_steps[i : i + chunk]
|
||||
# per (cache line, step)
|
||||
for j in range(idx.numel()):
|
||||
ci = idx[j].item()
|
||||
st = steps[j].item()
|
||||
ssm_states[:, ci, :].copy_(
|
||||
intermediate_state_cache[:, ci, st].to(
|
||||
ssm_states.dtype, copy=False
|
||||
)
|
||||
)
|
||||
|
||||
# Conv window updates
|
||||
for i in range(0, num_valid, chunk):
|
||||
idx = valid_state_indices[i : i + chunk]
|
||||
steps = last_steps[i : i + chunk]
|
||||
for j in range(idx.numel()):
|
||||
ci = idx[j].item()
|
||||
st = steps[j].item()
|
||||
conv_states[:, ci, :, :].copy_(
|
||||
intermediate_conv_window_cache[:, ci, st].to(
|
||||
conv_states.dtype, copy=False
|
||||
)
|
||||
)
|
||||
# Scatter into conv_states at the chosen cache lines
|
||||
conv_states[:, valid_state_indices, :, :] = intermediate_conv_window_cache[
|
||||
:, valid_state_indices, last_steps
|
||||
].to(conv_states.dtype, copy=False)
|
||||
|
||||
Reference in New Issue
Block a user