[Performance] Qwen3-Next: speed up update_mamba_state_after_mtp_verify by 10x; e2e up to 3.54% faster (#10586)
This commit is contained in:
@@ -583,36 +583,15 @@ class HybridLinearAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
# Compute common indices once to avoid duplication
|
# Compute common indices once to avoid duplication
|
||||||
last_steps_all = (accepted_length - 1).to(torch.int64)
|
last_steps_all = (accepted_length - 1).to(torch.int64)
|
||||||
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
|
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64) # [N]
|
||||||
last_steps = last_steps_all[valid_mask].to(torch.int64)
|
last_steps = last_steps_all[valid_mask].to(torch.int64) # [N]
|
||||||
|
|
||||||
if valid_state_indices.numel() > 0:
|
# Scatter into ssm_states at the chosen cache lines
|
||||||
chunk = 256
|
ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
|
||||||
num_valid = valid_state_indices.numel()
|
:, valid_state_indices, last_steps
|
||||||
|
].to(ssm_states.dtype, copy=False)
|
||||||
|
|
||||||
# SSM state updates
|
# Scatter into conv_states at the chosen cache lines
|
||||||
for i in range(0, num_valid, chunk):
|
conv_states[:, valid_state_indices, :, :] = intermediate_conv_window_cache[
|
||||||
idx = valid_state_indices[i : i + chunk]
|
:, valid_state_indices, last_steps
|
||||||
steps = last_steps[i : i + chunk]
|
].to(conv_states.dtype, copy=False)
|
||||||
# per (cache line, step)
|
|
||||||
for j in range(idx.numel()):
|
|
||||||
ci = idx[j].item()
|
|
||||||
st = steps[j].item()
|
|
||||||
ssm_states[:, ci, :].copy_(
|
|
||||||
intermediate_state_cache[:, ci, st].to(
|
|
||||||
ssm_states.dtype, copy=False
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Conv window updates
|
|
||||||
for i in range(0, num_valid, chunk):
|
|
||||||
idx = valid_state_indices[i : i + chunk]
|
|
||||||
steps = last_steps[i : i + chunk]
|
|
||||||
for j in range(idx.numel()):
|
|
||||||
ci = idx[j].item()
|
|
||||||
st = steps[j].item()
|
|
||||||
conv_states[:, ci, :, :].copy_(
|
|
||||||
intermediate_conv_window_cache[:, ci, st].to(
|
|
||||||
conv_states.dtype, copy=False
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user