diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 48fb779b..94e70e57 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -359,7 +359,8 @@ class EagleProposer(Proposer): actual_seq_lengths_q=self.runner.actual_seq_lengths_q, block_table_tensor=self.runner.input_batch.block_table[0]. get_device_tensor(), - slot_mapping=self.runner.slot_mapping, + slot_mapping=self.runner.input_batch.block_table[0]. + slot_mapping, positions=self.runner.positions, attn_mask=self.runner.attn_mask, spec_attn_mask=self.runner.spec_attn_mask, diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index b16bb041..d8333abd 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -83,7 +83,7 @@ class BlockTable: pin_memory=self.pin_memory) self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.slot_mapping = torch.zeros(self.max_num_batched_tokens, - dtype=torch.int64, + dtype=torch.int32, device=self.device) try: self.pcp_world_size = get_pcp_group( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 1f2ba64c..8fe6df2a 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -404,9 +404,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.seq_lens = torch.zeros(self.max_num_reqs, dtype=torch.int32, device=self.device) - self.slot_mapping = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) if self.vllm_config.model_config.use_mla and \ self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: @@ -470,11 +467,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): pin_memory=True) self.positions_np = self.positions_cpu.numpy() - self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=True) - self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, device="cpu", @@ -1673,12 +1665,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): else: blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor() - slot_mapping = blk_table.slot_mapping_cpu[:slot_mapping_size] - self.slot_mapping[:slot_mapping_size].copy_( - slot_mapping[:slot_mapping_size], - non_blocking=True, - ) - self.slot_mapping[slot_mapping_size:].fill_(0) + slot_mapping = blk_table.slot_mapping[:slot_mapping_size] + blk_table.slot_mapping[slot_mapping_size:].fill_(0) if self.pcp_size > 1: assert pcp_unpad_mask is not None pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[: @@ -1688,9 +1676,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): 0]] pcp_padded_slot_mapping.fill_(-1) pcp_padded_slot_mapping[ - pcp_unpad_mask] = self.slot_mapping[:slot_mapping_size] - self.slot_mapping[:long_seq_metadata. - num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping + pcp_unpad_mask] = blk_table.slot_mapping[: + slot_mapping_size] + blk_table.slot_mapping[:long_seq_metadata. + num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping # Make AscendCommonAttentionMetadata common_attn_metadata = AscendCommonAttentionMetadata( @@ -1704,7 +1693,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): actual_seq_lengths_q=self.actual_seq_lengths_q, # TODO: change this to the right block table for linear attn block_table_tensor=blk_table_tensor[:num_reqs], - slot_mapping=self.slot_mapping, + slot_mapping=slot_mapping, num_computed_tokens_cpu=num_computed_tokens_cpu, positions=self.positions, attn_mask=self.attn_mask, @@ -2517,6 +2506,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.kv_cache_config.kv_cache_groups): block_table_tensor = self.input_batch.block_table[ kv_cache_group_id].get_device_tensor() + slot_mapping = self.input_batch.block_table[ + kv_cache_group_id].slot_mapping common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=torch.tensor( [0] + self.actual_seq_lengths_q[:num_reqs], @@ -2530,7 +2521,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_actual_tokens=num_tokens, actual_seq_lengths_q=self.actual_seq_lengths_q, block_table_tensor=block_table_tensor[:num_reqs], - slot_mapping=self.slot_mapping, + slot_mapping=slot_mapping, num_computed_tokens_cpu=num_computed_tokens_cpu, positions=self.positions, attn_mask=self.attn_mask,