[bugfix] fix the complex and potentially problematic generate_kv_idx. (#5957)
### What this PR does / why we need it?
In long-sequence scenarios, the chunked-prefill component may encounter
dimension-misalignment issues; this was previously observed during
precision testing on the code_generate_lite dataset. This PR removes
the redundant computation and instead derives the required index
directly from existing results via a straightforward calculation.
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -87,7 +87,6 @@ def test_models_chunked_prefill_mixed_length_prompts_including_1_token(
|
|||||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"
|
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"
|
||||||
})
|
})
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.skip(reason="skip for bad adaptability with main2main")
|
|
||||||
def test_models_chunked_prefill_with_empty_kvcache(model: str):
|
def test_models_chunked_prefill_with_empty_kvcache(model: str):
|
||||||
TEST_ROPE_PARAMETERS = {
|
TEST_ROPE_PARAMETERS = {
|
||||||
"rope_theta": 1000000,
|
"rope_theta": 1000000,
|
||||||
|
|||||||
@@ -162,14 +162,12 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
|||||||
local_total_toks = local_chunked_kv_lens_rank.sum()
|
local_total_toks = local_chunked_kv_lens_rank.sum()
|
||||||
chunked_req_mask = self._get_chunked_req_mask(local_context_lens_allranks)
|
chunked_req_mask = self._get_chunked_req_mask(local_context_lens_allranks)
|
||||||
local_chunk_starts = torch.zeros(
|
local_chunk_starts = torch.zeros(
|
||||||
(len(local_context_lens_allranks)), dtype=torch.int32, device=self.device
|
(len(local_context_lens_allranks),), dtype=torch.int32, device=self.device
|
||||||
)
|
)
|
||||||
cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk
|
kv_inverse_idx_for_chunk = torch.argsort(
|
||||||
kv_inverse_idx_for_chunk = (
|
common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens :].to(torch.float32)
|
||||||
torch.argsort(cp_kv_recover_idx_for_chunk.to(torch.float32))
|
|
||||||
if cp_kv_recover_idx_for_chunk is not None
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
cp_kv_recover_idx_for_chunk = torch.argsort(kv_inverse_idx_for_chunk)
|
||||||
|
|
||||||
batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0
|
batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0
|
||||||
batch_chunk_seq_mask = torch.repeat_interleave(
|
batch_chunk_seq_mask = torch.repeat_interleave(
|
||||||
|
|||||||
@@ -70,8 +70,6 @@ class AscendPrefillContextParallelMetadata:
|
|||||||
|
|
||||||
pcp_allgather_restore_idx: torch.Tensor = None
|
pcp_allgather_restore_idx: torch.Tensor = None
|
||||||
|
|
||||||
cp_kv_recover_idx_for_chunk: torch.Tensor = None
|
|
||||||
|
|
||||||
num_actual_tokens_pcp_padded: int = 0
|
num_actual_tokens_pcp_padded: int = 0
|
||||||
|
|
||||||
num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
|
num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
|
||||||
|
|||||||
@@ -555,9 +555,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
self.num_spec_tokens)
|
self.num_spec_tokens)
|
||||||
|
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
if not self.vllm_config.model_config.use_mla:
|
|
||||||
self.pcp_manager.generate_kv_idx(scheduler_output,
|
|
||||||
self.input_batch)
|
|
||||||
num_scheduled_tokens[:
|
num_scheduled_tokens[:
|
||||||
num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
|
num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
|
||||||
num_scheduled_tokens[:num_reqs],
|
num_scheduled_tokens[:num_reqs],
|
||||||
|
|||||||
@@ -86,9 +86,6 @@ class PCPManager:
|
|||||||
)
|
)
|
||||||
self.num_actual_tokens_pcp_padded = 0
|
self.num_actual_tokens_pcp_padded = 0
|
||||||
self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
|
self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
|
||||||
self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
|
|
||||||
[] for _ in range(self.pcp_world_size)
|
|
||||||
]
|
|
||||||
self.full_indices = list(
|
self.full_indices = list(
|
||||||
range(self.max_num_tokens * self.pcp_world_size *
|
range(self.max_num_tokens * self.pcp_world_size *
|
||||||
self.dcp_world_size + self.pcp_world_size *
|
self.dcp_world_size + self.pcp_world_size *
|
||||||
@@ -563,47 +560,6 @@ class PCPManager:
|
|||||||
[-1, pcp_world_size, dcp_world_size])
|
[-1, pcp_world_size, dcp_world_size])
|
||||||
return dcp_local_seq_lens
|
return dcp_local_seq_lens
|
||||||
|
|
||||||
def generate_kv_idx(self, scheduler_output, input_batch):
|
|
||||||
if not self.pcp_world_size > 1:
|
|
||||||
return
|
|
||||||
self.cp_kv_recover_idx_for_chunk = [[]
|
|
||||||
for _ in range(self.pcp_world_size)
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, req_id in enumerate(input_batch.req_ids):
|
|
||||||
num_scheduled_token = scheduler_output.num_scheduled_tokens[req_id]
|
|
||||||
is_prefill = num_scheduled_token > self.decode_threshold
|
|
||||||
if is_prefill:
|
|
||||||
num_cp_padded_scheduled_tokens = cdiv(
|
|
||||||
num_scheduled_token,
|
|
||||||
2 * self.pcp_world_size) * (2 * self.pcp_world_size)
|
|
||||||
chunk_size = num_cp_padded_scheduled_tokens // (
|
|
||||||
2 * self.pcp_world_size)
|
|
||||||
num_added_recover_tokens = len(
|
|
||||||
self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_world_size
|
|
||||||
for rank in range(self.pcp_world_size):
|
|
||||||
self.cp_kv_recover_idx_for_chunk[rank].extend(
|
|
||||||
self.full_indices[rank * chunk_size +
|
|
||||||
num_added_recover_tokens:(rank + 1) *
|
|
||||||
chunk_size +
|
|
||||||
num_added_recover_tokens])
|
|
||||||
self.cp_kv_recover_idx_for_chunk[rank].extend(
|
|
||||||
self.full_indices[num_cp_padded_scheduled_tokens -
|
|
||||||
(rank + 1) * chunk_size +
|
|
||||||
num_added_recover_tokens:
|
|
||||||
num_cp_padded_scheduled_tokens -
|
|
||||||
rank * chunk_size +
|
|
||||||
num_added_recover_tokens])
|
|
||||||
|
|
||||||
cp_kv_recover_idx_for_chunk = torch.from_numpy(
|
|
||||||
np.concatenate(
|
|
||||||
self.cp_kv_recover_idx_for_chunk)).to(device=self.device)
|
|
||||||
cp_kv_recover_idx_for_chunk.copy_(torch.tensor(
|
|
||||||
np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()),
|
|
||||||
non_blocking=True)
|
|
||||||
self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to(
|
|
||||||
torch.float32).argsort().to(torch.int32)
|
|
||||||
|
|
||||||
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
|
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
|
||||||
input_batch, num_scheduled_tokens):
|
input_batch, num_scheduled_tokens):
|
||||||
from vllm_ascend.attention.utils import \
|
from vllm_ascend.attention.utils import \
|
||||||
@@ -774,7 +730,6 @@ class PCPManager:
|
|||||||
}
|
}
|
||||||
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
|
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
|
||||||
num_actual_tokens_pcp_padded]
|
num_actual_tokens_pcp_padded]
|
||||||
long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
|
|
||||||
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
|
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
|
||||||
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
|
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
|
||||||
long_seq_metadata.q_full_idx = self.q_full_idx
|
long_seq_metadata.q_full_idx = self.q_full_idx
|
||||||
|
|||||||
Reference in New Issue
Block a user