[Performance] Qwen3-Next: replace arange to cached query_start_loc_li… (#10553)
@@ -61,18 +61,15 @@ class MambaAttnBackend(AttentionBackend):
         self.forward_metadata: ForwardMetadata = None
         self.state_indices_list = []
         self.query_start_loc_list = []
-
-    @classmethod
-    @lru_cache(maxsize=128)
-    def _get_cached_arange(cls, bs: int, device_str: str) -> torch.Tensor:
-        """Cache torch.arange tensors for common batch sizes to avoid repeated allocation."""
-        device = torch.device(device_str)
-        return torch.arange(0, bs + 1, dtype=torch.int32, device=device)
+        self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
+        self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         bs = forward_batch.batch_size
         if forward_batch.forward_mode.is_decode_or_idle():
-            query_start_loc = self._get_cached_arange(bs, str(self.device))
+            query_start_loc = torch.arange(
+                0, bs + 1, dtype=torch.int32, device=self.device
+            )
         elif forward_batch.forward_mode.is_extend():
             if forward_batch.forward_mode.is_target_verify():
                 query_start_loc = torch.arange(
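Note on this hunk: the old code memoized `torch.arange` results through `functools.lru_cache` keyed on `(bs, device_str)`, which pins up to 128 GPU tensors for the life of the process and pays a hash lookup per call. The commit drops the helper, builds the arange directly in the eager path, and (per the later hunks) precomputes two reusable query_start_loc tensors for the CUDA graph paths. A minimal sketch, with illustrative values not taken from the commit, of what `query_start_loc` encodes for a decode batch:

```python
import torch

# During decode, each sequence contributes exactly one token, so the
# cumulative query start offsets for a batch of bs sequences are 0..bs.
bs = 4
query_start_loc = torch.arange(0, bs + 1, dtype=torch.int32)
print(query_start_loc)  # tensor([0, 1, 2, 3, 4], dtype=torch.int32)
```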
@@ -102,6 +99,10 @@ class MambaAttnBackend(AttentionBackend):
             )
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        assert (
+            max_num_tokens % max_bs == 0
+        ), f"max_num_tokens={max_num_tokens} must be divisible by max_bs={max_bs}"
+        verify_step = max_num_tokens / max_bs
         for i in range(max_bs):
             self.state_indices_list.append(
                 torch.full(
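Note: the added assert makes the relationship between the CUDA graph token budget and the batch size explicit. During target verification every sequence carries the same number of draft tokens, so `max_num_tokens` must be an exact multiple of `max_bs`, and `verify_step` is that per-sequence count. Because true division is used, `verify_step` is a whole-valued float, which keeps the later arange offsets exact. A tiny sketch with assumed values:

```python
# Illustrative values, not from the commit: 8 sequences x 4 draft tokens.
max_bs, max_num_tokens = 8, 32
assert max_num_tokens % max_bs == 0
verify_step = max_num_tokens / max_bs  # 4.0 -- whole-valued by the assert
```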
@@ -111,6 +112,16 @@ class MambaAttnBackend(AttentionBackend):
             self.query_start_loc_list.append(
                 torch.empty((i + 2,), dtype=torch.int32, device=self.device)
             )
+        self.cached_cuda_graph_decode_query_start_loc = torch.arange(
+            0, max_bs + 1, dtype=torch.int32, device=self.device
+        )
+        self.cached_cuda_graph_verify_query_start_loc = torch.arange(
+            0,
+            max_bs * verify_step + 1,
+            step=verify_step,
+            dtype=torch.int32,
+            device=self.device,
+        )
 
     def init_forward_metadata_capture_cuda_graph(
         self,
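Note: both cached tensors are allocated once, at their maximum sizes, when CUDA graph state is initialized. The decode tensor holds `0..max_bs`; the verify tensor holds multiples of `verify_step` up to `max_bs * verify_step`. Capture and replay then only read prefix slices of them. A sketch with assumed values (`max_bs=4`, `verify_step=4`, written as an int here for clarity):

```python
import torch

max_bs, verify_step = 4, 4
decode_cache = torch.arange(0, max_bs + 1, dtype=torch.int32)
verify_cache = torch.arange(
    0, max_bs * verify_step + 1, step=verify_step, dtype=torch.int32
)
print(decode_cache)  # tensor([0, 1, 2, 3, 4], dtype=torch.int32)
print(verify_cache)  # tensor([ 0,  4,  8, 12, 16], dtype=torch.int32)
```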
@@ -123,16 +134,12 @@ class MambaAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
-            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
+            self.query_start_loc_list[bs - 1].copy_(
+                self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
+            )
         elif forward_mode.is_target_verify():
             self.query_start_loc_list[bs - 1].copy_(
-                torch.arange(
-                    0,
-                    bs * spec_info.draft_token_num + 1,
-                    step=spec_info.draft_token_num,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
+                self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
             )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
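Note: in the capture path, the preallocated per-batch-size buffer now receives a prefix slice of a cached tensor via `copy_` instead of a freshly materialized `torch.arange`. The values are identical; what changes is that no temporary tensor is allocated at capture time. A hedged sketch of the pattern, with assumed sizes:

```python
import torch

# Assumed shapes: buf is the preallocated buffer for batch size bs
# (bs + 1 offsets); cached was built once for max_bs = 8.
bs = 3
cached = torch.arange(0, 8 + 1, dtype=torch.int32)
buf = torch.empty(bs + 1, dtype=torch.int32)
buf.copy_(cached[: bs + 1])  # copies [0, 1, 2, 3]; no new allocation
```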
@@ -163,23 +170,29 @@ class MambaAttnBackend(AttentionBackend):
         mamba_indices[bs - num_padding :] = -1
         self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
         if forward_mode.is_decode_or_idle():
-            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
-            if num_padding > 0:
-                self.query_start_loc_list[bs - 1][bs - num_padding :] = bs - num_padding
-        elif forward_mode.is_target_verify():
-            self.query_start_loc_list[bs - 1].copy_(
-                torch.arange(
-                    0,
-                    bs * spec_info.draft_token_num + 1,
-                    step=spec_info.draft_token_num,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
-            )
-            if num_padding > 0:
-                self.query_start_loc_list[bs - 1][bs - num_padding :] = (
-                    bs - num_padding
-                ) * spec_info.draft_token_num
+            if num_padding == 0:
+                self.query_start_loc_list[bs - 1].copy_(
+                    self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
+                )
+            else:
+                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
+                    self.cached_cuda_graph_decode_query_start_loc[: bs - num_padding]
+                )
+                self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
+                    bs - num_padding
+                )
+        elif forward_mode.is_target_verify():
+            if num_padding == 0:
+                self.query_start_loc_list[bs - 1].copy_(
+                    self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
+                )
+            else:
+                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
+                    self.cached_cuda_graph_verify_query_start_loc[: bs - num_padding]
+                )
+                self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
+                    (bs - num_padding) * spec_info.draft_token_num
+                )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
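Note: the replay path gains explicit handling for padded batches. With `num_padding` padded slots, only the first `bs - num_padding` offsets come from the cached tensor; the tail is pinned to the last real offset (scaled by `spec_info.draft_token_num` in verify mode), so every padded slot describes an empty query span. A sketch of the decode case with assumed values, using plain indexed assignment for the tail:

```python
import torch

bs, num_padding = 4, 1
cached = torch.arange(0, bs + 1, dtype=torch.int32)
buf = torch.empty(bs + 1, dtype=torch.int32)
buf[: bs - num_padding].copy_(cached[: bs - num_padding])  # [0, 1, 2]
buf[bs - num_padding :] = bs - num_padding                 # padded tail -> 3
print(buf)  # tensor([0, 1, 2, 3, 3]): the padded slot spans [3, 3), i.e. empty
```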