[Performance] Qwen3-Next: replace arange with cached query_start_loc_li… (#10553)
This commit is contained in:
@@ -61,18 +61,15 @@ class MambaAttnBackend(AttentionBackend):
|
||||
self.forward_metadata: ForwardMetadata = None
|
||||
self.state_indices_list = []
|
||||
self.query_start_loc_list = []
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=128)
|
||||
def _get_cached_arange(cls, bs: int, device_str: str) -> torch.Tensor:
|
||||
"""Cache torch.arange tensors for common batch sizes to avoid repeated allocation."""
|
||||
device = torch.device(device_str)
|
||||
return torch.arange(0, bs + 1, dtype=torch.int32, device=device)
|
||||
self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
|
||||
self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
bs = forward_batch.batch_size
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
query_start_loc = self._get_cached_arange(bs, str(self.device))
|
||||
query_start_loc = torch.arange(
|
||||
0, bs + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
elif forward_batch.forward_mode.is_extend():
|
||||
if forward_batch.forward_mode.is_target_verify():
|
||||
query_start_loc = torch.arange(
|
||||
@@ -102,6 +99,10 @@ class MambaAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
assert (
|
||||
max_num_tokens % max_bs == 0
|
||||
), f"max_num_tokens={max_num_tokens} must be divisible by max_bs={max_bs}"
|
||||
verify_step = max_num_tokens / max_bs
|
||||
for i in range(max_bs):
|
||||
self.state_indices_list.append(
|
||||
torch.full(
|
||||
@@ -111,6 +112,16 @@ class MambaAttnBackend(AttentionBackend):
|
||||
self.query_start_loc_list.append(
|
||||
torch.empty((i + 2,), dtype=torch.int32, device=self.device)
|
||||
)
|
||||
self.cached_cuda_graph_decode_query_start_loc = torch.arange(
|
||||
0, max_bs + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.cached_cuda_graph_verify_query_start_loc = torch.arange(
|
||||
0,
|
||||
max_bs * verify_step + 1,
|
||||
step=verify_step,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
@@ -123,16 +134,12 @@ class MambaAttnBackend(AttentionBackend):
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
|
||||
self.query_start_loc_list[bs - 1].copy_(
|
||||
self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
|
||||
)
|
||||
elif forward_mode.is_target_verify():
|
||||
self.query_start_loc_list[bs - 1].copy_(
|
||||
torch.arange(
|
||||
0,
|
||||
bs * spec_info.draft_token_num + 1,
|
||||
step=spec_info.draft_token_num,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
||||
@@ -163,23 +170,29 @@ class MambaAttnBackend(AttentionBackend):
|
||||
mamba_indices[bs - num_padding :] = -1
|
||||
self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
|
||||
if forward_mode.is_decode_or_idle():
|
||||
self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
|
||||
if num_padding > 0:
|
||||
self.query_start_loc_list[bs - 1][bs - num_padding :] = bs - num_padding
|
||||
elif forward_mode.is_target_verify():
|
||||
self.query_start_loc_list[bs - 1].copy_(
|
||||
torch.arange(
|
||||
0,
|
||||
bs * spec_info.draft_token_num + 1,
|
||||
step=spec_info.draft_token_num,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
if num_padding == 0:
|
||||
self.query_start_loc_list[bs - 1].copy_(
|
||||
self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
|
||||
)
|
||||
)
|
||||
if num_padding > 0:
|
||||
self.query_start_loc_list[bs - 1][bs - num_padding :] = (
|
||||
else:
|
||||
self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
|
||||
self.cached_cuda_graph_decode_query_start_loc[: bs - num_padding]
|
||||
)
|
||||
self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
|
||||
bs - num_padding
|
||||
) * spec_info.draft_token_num
|
||||
)
|
||||
elif forward_mode.is_target_verify():
|
||||
if num_padding == 0:
|
||||
self.query_start_loc_list[bs - 1].copy_(
|
||||
self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
|
||||
)
|
||||
else:
|
||||
self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
|
||||
self.cached_cuda_graph_verify_query_start_loc[: bs - num_padding]
|
||||
)
|
||||
self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
|
||||
(bs - num_padding) * spec_info.draft_token_num
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user