[Bugfix] Fix zero attention output in qwen3-next (#3572)
### What this PR does / why we need it?

Since Attention and LinearAttention share the same `slot_mapping`, and the `slot_mapping` for LinearAttention is all zeros, the `slot_mapping` for Attention gets overwritten, resulting in the computed attention output being all zeros. This PR removes the uniformly managed `self.slot_mapping` and instead passes the per-group `slot_mapping` from `input_batch.block_table` directly to `attn_metadata`, updating the relevant references accordingly. Due to hardware constraints, the data type of `block_table.slot_mapping` needs to be set to int32.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed with the existing tests.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: QilaiZhang <245706640@qq.com>
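For readers unfamiliar with the failure mode, here is a minimal, self-contained sketch of the aliasing bug and the fix. The names and shapes are simplified stand-ins, not the actual vllm-ascend code:

```python
import torch

MAX_TOKENS = 8

# Buggy setup: one runner-owned buffer shared by every kv-cache group.
shared_slot_mapping = torch.zeros(MAX_TOKENS, dtype=torch.int32)

def build_full_attn_metadata(slots: torch.Tensor) -> torch.Tensor:
    # Full attention writes real kv-cache slot ids for the scheduled tokens.
    slots[:4] = torch.tensor([10, 11, 12, 13], dtype=torch.int32)
    return slots

def build_linear_attn_metadata(slots: torch.Tensor) -> torch.Tensor:
    # Linear attention has no kv-cache slots, so its builder zeros the buffer.
    slots.zero_()
    return slots

full_meta = build_full_attn_metadata(shared_slot_mapping)
build_linear_attn_metadata(shared_slot_mapping)
# full_meta aliases the shared buffer, so its slot ids are now all zeros:
# every token reads/writes kv-cache slot 0 and the attention output collapses.
assert full_meta.sum() == 0

# Fixed setup: each group owns its own slot_mapping, as in this PR.
full_slots = build_full_attn_metadata(torch.zeros(MAX_TOKENS, dtype=torch.int32))
build_linear_attn_metadata(torch.zeros(MAX_TOKENS, dtype=torch.int32))
assert full_slots.sum() != 0  # full attention keeps its mapping
```

Giving each kv-cache group's block table its own tensor is exactly the isolation the diffs below apply.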
```diff
@@ -359,7 +359,8 @@ class EagleProposer(Proposer):
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=self.runner.input_batch.block_table[0].
             get_device_tensor(),
-            slot_mapping=self.runner.slot_mapping,
+            slot_mapping=self.runner.input_batch.block_table[0].
+            slot_mapping,
             positions=self.runner.positions,
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
```
```diff
@@ -83,7 +83,7 @@ class BlockTable:
                                              pin_memory=self.pin_memory)
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
         self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
-                                        dtype=torch.int64,
+                                        dtype=torch.int32,
                                         device=self.device)
         try:
             self.pcp_world_size = get_pcp_group(
```
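The int64 to int32 change above reflects that the NPU kernels consuming `slot_mapping` expect 32-bit indices. A hedged sketch of a narrowing helper with an overflow guard, a hypothetical utility and not part of this PR:

```python
import torch

INT32_MAX = torch.iinfo(torch.int32).max

def narrow_slot_mapping(slot_mapping_cpu: torch.Tensor,
                        device: str = "cpu") -> torch.Tensor:
    """Narrow an int64 slot mapping to the int32 the NPU kernels expect.

    Guard against silent truncation before the cast: every slot id
    must fit in 32 bits.
    """
    assert int(slot_mapping_cpu.max()) <= INT32_MAX, "slot id overflows int32"
    return slot_mapping_cpu.to(dtype=torch.int32, device=device,
                               non_blocking=True)

# Example: a small int64 mapping narrowed safely.
print(narrow_slot_mapping(torch.tensor([0, 127, 4096], dtype=torch.int64)))
```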
```diff
@@ -404,9 +404,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.seq_lens = torch.zeros(self.max_num_reqs,
                                     dtype=torch.int32,
                                     device=self.device)
-        self.slot_mapping = torch.zeros(self.max_num_tokens,
-                                        dtype=torch.int32,
-                                        device=self.device)
 
         if self.vllm_config.model_config.use_mla and \
                 self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
```
```diff
@@ -470,11 +467,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                          pin_memory=True)
         self.positions_np = self.positions_cpu.numpy()
 
-        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int32,
-                                            device="cpu",
-                                            pin_memory=True)
-        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
         self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                                dtype=torch.int32,
                                                device="cpu",
```
```diff
@@ -1673,12 +1665,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         else:
             blk_table = self.input_batch.block_table[kv_cache_group_id]
             blk_table_tensor = blk_table.get_device_tensor()
-            slot_mapping = blk_table.slot_mapping_cpu[:slot_mapping_size]
-            self.slot_mapping[:slot_mapping_size].copy_(
-                slot_mapping[:slot_mapping_size],
-                non_blocking=True,
-            )
-            self.slot_mapping[slot_mapping_size:].fill_(0)
+            slot_mapping = blk_table.slot_mapping[:slot_mapping_size]
+            blk_table.slot_mapping[slot_mapping_size:].fill_(0)
         if self.pcp_size > 1:
             assert pcp_unpad_mask is not None
             pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:
```
```diff
@@ -1688,8 +1676,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                                              0]]
             pcp_padded_slot_mapping.fill_(-1)
             pcp_padded_slot_mapping[
-                pcp_unpad_mask] = self.slot_mapping[:slot_mapping_size]
-            self.slot_mapping[:long_seq_metadata.
+                pcp_unpad_mask] = blk_table.slot_mapping[:
+                                                         slot_mapping_size]
+            blk_table.slot_mapping[:long_seq_metadata.
                               num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
 
         # Make AscendCommonAttentionMetadata
```
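The pcp (prefill context parallel) hunk above scatters the real slot ids into a padded buffer, using -1 for padded positions so they are skipped when writing the kv cache. A standalone illustration of that masking pattern, with toy shapes rather than the runner's real buffers:

```python
import torch

# Three real tokens on this rank; five positions after pcp padding.
slot_mapping = torch.tensor([5, 6, 7], dtype=torch.int32)
pcp_unpad_mask = torch.tensor([True, False, True, False, True])

padded = torch.full((pcp_unpad_mask.numel(),), -1, dtype=torch.int32)
padded[pcp_unpad_mask] = slot_mapping  # scatter real ids; padding stays -1
print(padded)  # tensor([ 5, -1,  6, -1,  7], dtype=torch.int32)
```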
```diff
@@ -1704,7 +1693,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             actual_seq_lengths_q=self.actual_seq_lengths_q,
             # TODO: change this to the right block table for linear attn
             block_table_tensor=blk_table_tensor[:num_reqs],
-            slot_mapping=self.slot_mapping,
+            slot_mapping=slot_mapping,
             num_computed_tokens_cpu=num_computed_tokens_cpu,
             positions=self.positions,
             attn_mask=self.attn_mask,
```
```diff
@@ -2517,6 +2506,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.kv_cache_config.kv_cache_groups):
             block_table_tensor = self.input_batch.block_table[
                 kv_cache_group_id].get_device_tensor()
+            slot_mapping = self.input_batch.block_table[
+                kv_cache_group_id].slot_mapping
             common_attn_metadata = AscendCommonAttentionMetadata(
                 query_start_loc=torch.tensor(
                     [0] + self.actual_seq_lengths_q[:num_reqs],
```
```diff
@@ -2530,7 +2521,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 num_actual_tokens=num_tokens,
                 actual_seq_lengths_q=self.actual_seq_lengths_q,
                 block_table_tensor=block_table_tensor[:num_reqs],
-                slot_mapping=self.slot_mapping,
+                slot_mapping=slot_mapping,
                 num_computed_tokens_cpu=num_computed_tokens_cpu,
                 positions=self.positions,
                 attn_mask=self.attn_mask,
```
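Taken together, the last two hunks give each kv-cache group its own `slot_mapping` view when building metadata. A simplified sketch of that wiring, with hypothetical dataclasses standing in for `AscendCommonAttentionMetadata` and the real `BlockTable`:

```python
from dataclasses import dataclass

import torch

@dataclass
class BlockTable:
    slot_mapping: torch.Tensor  # per-group device tensor, dtype=torch.int32

@dataclass
class CommonAttnMetadata:
    slot_mapping: torch.Tensor

def build_metadata(block_tables: list[BlockTable],
                   num_tokens: int) -> list[CommonAttnMetadata]:
    # One metadata object per kv-cache group, each viewing only its own buffer.
    return [CommonAttnMetadata(slot_mapping=bt.slot_mapping[:num_tokens])
            for bt in block_tables]

# Usage: the linear-attention group zeroing its mapping no longer
# affects the full-attention group's mapping.
full = BlockTable(torch.arange(8, dtype=torch.int32))
linear = BlockTable(torch.zeros(8, dtype=torch.int32))
metas = build_metadata([full, linear], num_tokens=4)
assert metas[0].slot_mapping.sum() != 0
```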