[Qwen3-Next] switch to Triton and cache conv states to accelerate MTP from 300 tok/s to 341 tok/s (#10335)

Co-authored-by: Binyao Jiang <byjiang1996@gmail.com>
This commit was authored by:
Stefan He
2025-09-11 11:59:48 -07:00
committed by GitHub
parent 4a0e0be2a2
commit 6c18ab46a2
5 changed files with 1145 additions and 294 deletions

View File

@@ -125,16 +125,6 @@ class MambaPool:
device=device,
)
if speculative_num_draft_tokens is not None:
mixed_qkv_cache = torch.empty(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
),
dtype=conv_dtype,
device="cuda",
)
# Cache intermediate SSM states per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
intermediate_ssm_state_cache = torch.empty(
@@ -149,11 +139,24 @@ class MambaPool:
dtype=ssm_dtype,
device="cuda",
)
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
intermediate_conv_window_cache = torch.empty(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
conv_state_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
self.mamba_cache = (
conv_state,
temporal_state,
mixed_qkv_cache,
intermediate_ssm_state_cache,
intermediate_conv_window_cache,
)
else:
self.mamba_cache = (conv_state, temporal_state)