[Qwen3-Next] switch to triton and cache conv states to accelerate MTP from 300 tok/s to 341 tok/s (#10335)
Co-authored-by: Binyao Jiang <byjiang1996@gmail.com>
This commit is contained in:
@@ -125,16 +125,6 @@ class MambaPool:
|
||||
device=device,
|
||||
)
|
||||
if speculative_num_draft_tokens is not None:
|
||||
mixed_qkv_cache = torch.empty(
|
||||
size=(
|
||||
num_mamba_layers,
|
||||
size + 1,
|
||||
speculative_num_draft_tokens,
|
||||
conv_state_shape[0],
|
||||
),
|
||||
dtype=conv_dtype,
|
||||
device="cuda",
|
||||
)
|
||||
# Cache intermediate SSM states per draft token during target verify
|
||||
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
|
||||
intermediate_ssm_state_cache = torch.empty(
|
||||
@@ -149,11 +139,24 @@ class MambaPool:
|
||||
dtype=ssm_dtype,
|
||||
device="cuda",
|
||||
)
|
||||
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
|
||||
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
|
||||
intermediate_conv_window_cache = torch.empty(
|
||||
size=(
|
||||
num_mamba_layers,
|
||||
size + 1,
|
||||
speculative_num_draft_tokens,
|
||||
conv_state_shape[0],
|
||||
conv_state_shape[1],
|
||||
),
|
||||
dtype=conv_dtype,
|
||||
device="cuda",
|
||||
)
|
||||
self.mamba_cache = (
|
||||
conv_state,
|
||||
temporal_state,
|
||||
mixed_qkv_cache,
|
||||
intermediate_ssm_state_cache,
|
||||
intermediate_conv_window_cache,
|
||||
)
|
||||
else:
|
||||
self.mamba_cache = (conv_state, temporal_state)
|
||||
|
||||
Reference in New Issue
Block a user