diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py
index d2fda0b..fc074d5 100644
--- a/vllm_ascend/worker/eagle_proposer_v1.py
+++ b/vllm_ascend/worker/eagle_proposer_v1.py
@@ -14,7 +14,7 @@ from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.sample.metadata import SamplingMetadata
 
-from vllm_ascend.attention.attention import AttentionMaskBuilder
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 
 logger = init_logger(__name__)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 38e1e33..c77fce0 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -69,10 +69,12 @@ from vllm.v1.worker.utils import (gather_mm_placeholders,
                                   scatter_mm_placeholders)
 
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.attention.attention import AttentionMaskBuilder
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                 AscendMetadata)
-from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
+from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
+from vllm_ascend.attention.mla_v1 import (AscendMLAMetadata,
+                                          CommonAttentionMetadata)
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.pool.metadata import PoolingMetadata
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -193,10 +195,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                          device=self.device)
 
         self.graph_block_tables = np.zeros(
-            (self.max_num_reqs,
-             (self.model_config.max_model_len + self.block_size - 1) //
-             self.block_size),
-            dtype=np.int32)
+            (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
 
         # Set up Attention
         self.attn_backend = get_attn_backend(
@@ -209,13 +208,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         )
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
+        self.attn_mask_builder = AttentionMaskBuilder(
+            min(self.model_config.max_model_len,
+                int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
 
         # Set up speculative decoding.
         self.use_aux_hidden_state_outputs = False
         self.use_spec_decode = False
         self.spec_attn_mask = None
         self.use_eagle = False
-        self.drafter = None
+        self.drafter: Optional[Union[NgramProposer, EagleProposer,
+                                     MtpProposer]] = None
         if self.speculative_config:
             self.use_spec_decode = True
             self.spec_attn_mask = torch.triu(torch.ones(2048,
@@ -315,19 +318,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             reversed(
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
 
-        # NOTE: Pre-construct a mask matrix to improve the efficiency of
-        # attention mask construction during inference.
-        # Note that the length of the matrix needs to be carefully balanced: a
-        # matrix that is too large will consume excessive VRAM, while a matrix
-        # that is too small will require dynamic concatenation during inference,
-        # leading to performance degradation.
-        # Therefore, an environment variable is added here to dynamically set
-        # the size of the pre-constructed mask matrix based on requirements.
-        mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
-        attn_mask_len = min(self.model_config.max_model_len, int(mask_len))
-        self.attn_mask_builder = AttentionMaskBuilder(attn_mask_len,
-                                                      self.dtype)
-
         self.new_kv_cache_bytes = -1
         self.torchair_compiled_model = None  # type: ignore
         self.torchair_compiled_models = {}  # type: ignore
@@ -566,7 +556,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
     def get_eagle_atten_dict(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> dict[str, AscendMetadata]:
+    ) -> dict[str, Union[AscendMetadata, AscendMLAMetadata,
+                         AscendTorchairMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -677,7 +668,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.seq_lens[num_reqs:].fill_(0)
         self.query_start_loc[num_reqs + 1:].fill_(-1)
 
-        attn_metadata: dict[str, AscendMetadata] = {}
+        attn_metadata: dict[str, Union[AscendMetadata, AscendMLAMetadata,
+                                       AscendTorchairMetadata]] = {}
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -880,7 +872,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata,
+    ) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
+                     AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
                torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
         # Check input valid
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -990,11 +983,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         else:
             attn_state = AscendAttentionState.PrefillCacheHit
 
-        attn_mask = self._make_attention_mask(seq_lens=seq_lens,
-                                              query_lens=num_scheduled_tokens,
-                                              position=positions,
-                                              attn_state=attn_state)
-        self.attn_mask = attn_mask
+        self.attn_mask = self._make_attention_mask(
+            seq_lens=seq_lens,
+            query_lens=num_scheduled_tokens,
+            position=positions,
+            attn_state=attn_state)
         self.attn_state = attn_state  # type: ignore
 
         extra_builder_kwargs = {}
@@ -1010,10 +1003,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.seq_lens[num_reqs:].fill_(0)
         self.query_start_loc[num_reqs + 1:].fill_(-1)
 
-        query_start_loc = self.query_start_loc[:num_reqs + 1]
-        seq_lens = self.seq_lens[:num_reqs]
-        common_attn_metadata = CommonAttentionMetadata(
-            query_start_loc=query_start_loc, seq_lens=seq_lens)
         with_prefill = attn_state not in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
@@ -1037,6 +1026,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
 
         if self.vllm_config.model_config.use_mla:
+            query_start_loc = self.query_start_loc[:num_reqs + 1]
+            seq_lens = self.seq_lens[:num_reqs]
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=query_start_loc, seq_lens=seq_lens)
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                 num_reqs=num_reqs,
                 num_actual_tokens=total_num_scheduled_tokens,
@@ -1326,98 +1319,24 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         positions: torch.Tensor,
         num_scheduled_tokens: int,
         hidden_states: torch.Tensor,
-        attn_metadata: SpecDecodeMetadata,
+        attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
+                             AscendTorchairMetadata],
         aux_hidden_states: torch.Tensor = None,
     ) -> Optional[list[list[int]]]:
         if not self.use_spec_decode:
             # Speculative decoding is not enabled.
             spec_token_ids = None
         elif self.speculative_config.method == "ngram":
-            assert isinstance(self.drafter, NgramProposer)
-            spec_token_ids = self._generate_draft_token_ids(
-                valid_sampled_token_ids, sampling_metadata)
+            spec_token_ids = self._generate_ngram_token_ids(
+                valid_sampled_token_ids)
         elif self.speculative_config.method == "eagle":
             raise NotImplementedError("Eagle Is Not Supported Yet.")
         elif self.speculative_config.method == "eagle3":
-            assert isinstance(self.drafter, EagleProposer)
-            if self.speculative_config.use_eagle():
-                next_token_ids: list[int] = []
-                for i, token_ids in enumerate(valid_sampled_token_ids):
-                    if token_ids:
-                        # Common case.
-                        next_token_id = token_ids[-1]
-                    else:
-                        # Partial prefill (rare case).
-                        # Get the next token id from the request state.
-                        req_id = self.input_batch.req_ids[i]
-                        req_state = self.requests[req_id]
-                        seq_len = (
-                            req_state.num_computed_tokens +
-                            scheduler_output.num_scheduled_tokens[req_id])
-
-                        next_token_id = req_state.get_token_id(seq_len)
-                    next_token_ids.append(next_token_id)
-                next_token_ids = torch.tensor(next_token_ids,
-                                              dtype=torch.int32,
-                                              device=self.device)
-                eagle_attn_metadata = attn_metadata[
-                    self.drafter.attn_layer_name]
-                num_input_tokens = scheduler_output.total_num_scheduled_tokens
-                if spec_decode_metadata is None:
-                    # input_ids can be None for multimodal models.
-                    target_token_ids = self.input_ids[:num_scheduled_tokens]
-                    target_positions = positions[:num_scheduled_tokens]
-                    if self.use_aux_hidden_state_outputs:
-                        target_hidden_states = torch.cat([
-                            h[:num_scheduled_tokens] for h in aux_hidden_states
-                        ],
-                                                         dim=-1)
-                    else:
-                        target_hidden_states = hidden_states[:
-                                                             num_scheduled_tokens]
-                    target_slot_mapping = eagle_attn_metadata.slot_mapping
-                    cu_num_tokens = eagle_attn_metadata.query_start_loc
-                else:
-                    num_draft_tokens = spec_decode_metadata.num_draft_tokens
-                    num_rejected_tokens = [
-                        n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
-                        for i, n in enumerate(num_draft_tokens)
-                    ]
-                    num_rejected_tokens = torch.tensor(
-                        num_rejected_tokens,
-                        dtype=torch.int32,
-                        device=self.device,
-                    )
-                    num_tokens = num_scheduled_tokens - sum(
-                        num_rejected_tokens)
-                    cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                        eagle_attn_metadata.query_start_loc,
-                        num_rejected_tokens, num_tokens)
-                    target_token_ids = self.input_ids[token_indices]
-                    target_positions = positions[token_indices]
-                    if self.use_aux_hidden_state_outputs:
-                        target_hidden_states = torch.cat(
-                            [h[token_indices] for h in aux_hidden_states],
-                            dim=-1)
-                    else:
-                        target_hidden_states = hidden_states[token_indices]
-                    target_slot_mapping = eagle_attn_metadata.slot_mapping[
-                        token_indices]
-
-                positions = self.positions[:num_input_tokens]
-                draft_token_ids = self.drafter.propose(
-                    target_token_ids=target_token_ids,
-                    target_positions=target_positions,
-                    target_hidden_states=target_hidden_states,
-                    target_slot_mapping=target_slot_mapping,
-                    next_token_ids=next_token_ids,
-                    cu_num_tokens=cu_num_tokens,
-                    block_table=eagle_attn_metadata.block_tables,
-                    sampling_metadata=sampling_metadata,
-                )
-                spec_token_ids = draft_token_ids.tolist()
+            spec_token_ids = self._generate_eagle3_token_ids(
+                valid_sampled_token_ids, sampling_metadata, scheduler_output,
+                spec_decode_metadata, positions, num_scheduled_tokens,
+                hidden_states, aux_hidden_states)
         elif self.speculative_config.method == 'deepseek_mtp':
-            assert isinstance(self.drafter, MtpProposer)
             spec_token_ids = self._generate_mtp_token_ids(
                 valid_sampled_token_ids, sampling_metadata, scheduler_output,
                 spec_decode_metadata, positions, num_scheduled_tokens,
@@ -1483,14 +1402,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     scheduler_output, intermediate_tensors))
 
         with ProfileExecuteDuration().capture_async("post process"):
-
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
                                   num_scheduled_tokens_np)
             logits = self.model.compute_logits(hidden_states[sample_indices],
                                                None)
-            if self.use_eagle:
-                attn_metadata = self.get_eagle_atten_dict(scheduler_output)
             # Apply structured output bitmasks if present
             if scheduler_output.grammar_bitmask is not None:
                 logits = self.apply_grammar_bitmask(scheduler_output, logits)
@@ -1630,96 +1546,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         return model_runner_output
 
-    def _profile_multimodal(self) -> None:
-        # TODO: handle encoder-decoder models once we support them.
-        # NOTE: Currently model is profiled with a single non-text
-        # modality with the max possible input tokens even when
-        # it supports multiple.
-
-        if (not self.is_multimodal_model
-                or self.max_num_encoder_input_tokens <= 0
-                or self.encoder_cache_size <= 0):
-            return
-
-        max_tokens_by_modality_dict = (
-            MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(
-                self.model_config))
-        dummy_data_modality, max_tokens_per_mm_item = max(
-            max_tokens_by_modality_dict.items(), key=lambda item: item[1])
-
-        # Check how many items of this modality can be supported by
-        # the encoder budget.
-        encoder_budget = min(self.max_num_encoder_input_tokens,
-                             self.encoder_cache_size)
-
-        max_num_mm_items_encoder_budget = cdiv(encoder_budget,
-                                               max_tokens_per_mm_item)
-
-        # Check how many items of this modality can be supported by
-        # the decoder budget.
-        max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
-            self.model_config)[dummy_data_modality]
-
-        # NOTE: We do not consider max_num_batched_tokens on purpose
-        # because the multimodal embeddings can be generated in advance
-        # and chunked prefilled.
-        max_num_mm_items_decoder_budget = self.max_num_reqs * \
-            max_mm_items_per_req
-
-        max_num_mm_items = min(max_num_mm_items_encoder_budget,
-                               max_num_mm_items_decoder_budget)
-
-        logger.info(
-            "Encoder cache will be initialized with a budget of %s tokens,"
-            " and profiled with %s %s items of the maximum feature size.",
-            encoder_budget, max_num_mm_items, dummy_data_modality)
-
-        # Create dummy batch of multimodal inputs.
-        dummy_request_data = self.input_registry.dummy_data_for_profiling(
-            model_config=self.model_config,
-            seq_len=self.max_num_tokens,
-            mm_registry=self.mm_registry,
-        )
-        dummy_mm_data = dummy_request_data.multi_modal_data
-
-        if not isinstance(dummy_mm_data, MultiModalKwargs):
-            # TODO: Delete this check once input mapper is fully removed.
-            raise RuntimeError("Legacy input mapper is not supported in V1")
-
-        # Dummy data definition in V0 may contain multiple multimodal items
-        # (e.g, multiple images) for a single request, therefore here we
-        # always replicate first item by max_num_mm_items times since in V1
-        # they are scheduled to be processed separately.
-
-        dummy_mm_item = dummy_mm_data.get_item(modality=dummy_data_modality,
-                                               item_index=0)
-        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
-
-        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
-                                                         max_num_mm_items)
-        batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
-            batched_dummy_mm_inputs, device=self.device)
-
-        # Run multimodal encoder.
-        dummy_encoder_outputs = self.model.get_multimodal_embeddings(
-            **batched_dummy_mm_inputs)
-        assert len(dummy_encoder_outputs) == max_num_mm_items, (
-            "Expected dimension 0 of encoder outputs to match the number "
-            f"of multimodal data items: {max_num_mm_items}, got "
-            f"{len(dummy_encoder_outputs)=} instead. This is most likely "
-            "due to the 'get_multimodal_embeddings' method of the model "
-            "not implemented correctly.")
-
-        # Cache the dummy encoder outputs.
-        self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
-
     @torch.inference_mode()
     def _dummy_run(
         self,
         num_tokens: int,
         is_compile: bool = False,
         with_prefill: bool = True,
-        skip_attn: bool = True,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1729,19 +1561,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         min_tokens_per_req = num_tokens // num_reqs
         num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
         num_scheduled_tokens_list[-1] += num_tokens % num_reqs
-        assert sum(num_scheduled_tokens_list) == num_tokens
-        assert len(num_scheduled_tokens_list) == num_reqs
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)
-        if skip_attn:
-            attn_metadata = None
-        else:
-            attn_metadata = self.attn_metadata_builder.build(
-                num_reqs=num_tokens,
-                num_actual_tokens=num_tokens,
-                max_query_len=num_tokens,
-                common_prefix_len=0,
-            )
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
@@ -1819,48 +1640,32 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     hidden_states, _ = hidden_states
                 else:
                     hidden_states = hidden_states
-                if self.use_spec_decode and \
-                    self.speculative_config.method in ('eagle', 'eagle3'):
-                    assert isinstance(self.drafter, EagleProposer)
+                if self.use_spec_decode and isinstance(
+                        self.drafter, EagleProposer):
                    self.drafter.dummy_run(num_tokens)
            return hidden_states
 
     def profile_run(self) -> None:
-        # FIXME Profile with multimodal encoder & encoder cache.
-        # current _profile_multimodal() using PyTorch SDPA backend method not
-        # support for window/full attn to reduce Memcpy operations, so will cause
-        # Out Of Memory problem, so we currently don't use self._profile_multimodal()
-        # self._profile_multimodal()
-
-        # For profile, have maximum num_reqs and that collectively have
-        # maximum num_tokens.
-        min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
-
-        num_scheduled_tokens_list = [min_tokens_per_req] * self.max_num_reqs
-        num_scheduled_tokens_list[
-            -1] += self.max_num_tokens % self.max_num_reqs
-        assert sum(num_scheduled_tokens_list) == self.max_num_tokens
-        assert len(num_scheduled_tokens_list) == self.max_num_reqs
-
-        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
-                                        dtype=np.int32)
-        logit_indices = np.cumsum(num_scheduled_tokens) - 1
-
-        # assert self.lora_manager is not None, "LoRA is not enabled"
-        # TODO: call maybe_profile_with_lora()
-
         # Trigger compilation for general shape.
         hidden_states = self._dummy_run(self.max_num_tokens)
-
+        output = None
         if get_pp_group().is_last_rank:
             if self.is_pooling_model:
                 output = self._dummy_pooler_run(hidden_states)
             else:
+                # For profile, have maximum num_reqs and that collectively have
+                # maximum num_tokens.
+                min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
+                num_scheduled_tokens_list = [min_tokens_per_req
+                                             ] * self.max_num_reqs
+                num_scheduled_tokens_list[
+                    -1] += self.max_num_tokens % self.max_num_reqs
+                num_scheduled_tokens = np.array(num_scheduled_tokens_list,
+                                                dtype=np.int32)
+                logit_indices = np.cumsum(num_scheduled_tokens) - 1
                 # TODO: need to rum a dummy sampler for generate task
                 hidden_states = hidden_states[logit_indices]
                 output = self.model.compute_logits(hidden_states, None)
-        else:
-            output = None
 
         NPUPlatform.synchronize()
         del hidden_states, output
@@ -1879,8 +1684,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         min_tokens_per_req = num_tokens // num_reqs
         num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
         num_scheduled_tokens_list[-1] += num_tokens % num_reqs
-        assert sum(num_scheduled_tokens_list) == num_tokens
-        assert len(num_scheduled_tokens_list) == num_reqs
 
         hidden_states_list = list(
             torch.split(hidden_states, num_scheduled_tokens_list))
@@ -1929,10 +1732,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.input_batch.logits_processing_needs_token_ids = True
         if self.drafter:
             logger.info("Loading drafter model...")
-            if self.use_aux_hidden_state_outputs:
-                self.drafter.load_model(self.model)
-                self.model.set_aux_hidden_state_layers(
-                    self.model.get_eagle3_aux_hidden_state_layers())
+            if isinstance(self.drafter, EagleProposer):
+                if self.use_aux_hidden_state_outputs:
+                    self.drafter.load_model(self.model)
+                    self.model.set_aux_hidden_state_layers(
+                        self.model.get_eagle3_aux_hidden_state_layers())
             else:
                 self.drafter.load_model()
         if self.lora_config:
@@ -2240,10 +2044,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                     elapsed_time, npu_graph_size / (1 << 30))
 
-    def _generate_draft_token_ids(
+    def _generate_ngram_token_ids(
         self,
         sampled_token_ids: list[list[int]],
-        sampling_metadata: SamplingMetadata,
     ) -> list[list[int]]:
         # TODO(woosuk): Optimize.
         draft_token_ids: list[list[int]] = []
@@ -2264,7 +2067,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             start_idx = self.input_batch.num_tokens_no_spec[i]
             end_idx = start_idx + num_sampled_ids
             self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
-            assert self.drafter is not None
+            assert isinstance(self.drafter, NgramProposer)
             drafter_output = self.drafter.propose(
                 self.input_batch.token_ids_cpu[i, :end_idx])
             if drafter_output is None or len(drafter_output) == 0:
@@ -2273,6 +2076,86 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             draft_token_ids.append(drafter_output.tolist())
         return draft_token_ids
 
+    def _generate_eagle3_token_ids(self,
+                                   valid_sampled_token_ids: list[list[int]],
+                                   sampling_metadata: SamplingMetadata,
+                                   scheduler_output: "SchedulerOutput",
+                                   spec_decode_metadata: SpecDecodeMetadata,
+                                   positions: torch.Tensor,
+                                   num_scheduled_tokens: int,
+                                   hidden_states: torch.Tensor,
+                                   aux_hidden_states: torch.Tensor = None):
+        assert isinstance(self.drafter, EagleProposer)
+        attn_metadata = self.get_eagle_atten_dict(scheduler_output)
+        next_token_ids: list[int] = []
+        for i, token_ids in enumerate(valid_sampled_token_ids):
+            if token_ids:
+                # Common case.
+                next_token_id = token_ids[-1]
+            else:
+                # Partial prefill (rare case).
+                # Get the next token id from the request state.
+                req_id = self.input_batch.req_ids[i]
+                req_state = self.requests[req_id]
+                seq_len = (req_state.num_computed_tokens +
+                           scheduler_output.num_scheduled_tokens[req_id])
+
+                next_token_id = req_state.get_token_id(seq_len)
+            next_token_ids.append(next_token_id)
+        next_token_ids = torch.tensor(next_token_ids,
+                                      dtype=torch.int32,
+                                      device=self.device)
+        eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
+        if spec_decode_metadata is None:
+            # input_ids can be None for multimodal models.
+            target_token_ids = self.input_ids[:num_scheduled_tokens]
+            target_positions = positions[:num_scheduled_tokens]
+            if self.use_aux_hidden_state_outputs:
+                target_hidden_states = torch.cat(
+                    [h[:num_scheduled_tokens] for h in aux_hidden_states],
+                    dim=-1)
+            else:
+                target_hidden_states = hidden_states[:num_scheduled_tokens]
+            target_slot_mapping = eagle_attn_metadata.slot_mapping
+            cu_num_tokens = eagle_attn_metadata.query_start_loc
+        else:
+            num_draft_tokens = spec_decode_metadata.num_draft_tokens
+            num_rejected_tokens = [
+                n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
+                for i, n in enumerate(num_draft_tokens)
+            ]
+            num_rejected_tokens = torch.tensor(
+                num_rejected_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
+            cu_num_tokens, token_indices = self.drafter.prepare_inputs(
+                eagle_attn_metadata.query_start_loc, num_rejected_tokens,
+                num_tokens)
+            target_token_ids = self.input_ids[token_indices]
+            target_positions = positions[token_indices]
+            if self.use_aux_hidden_state_outputs:
+                target_hidden_states = torch.cat(
+                    [h[token_indices] for h in aux_hidden_states], dim=-1)
+            else:
+                target_hidden_states = hidden_states[token_indices]
+            target_slot_mapping = eagle_attn_metadata.slot_mapping[
+                token_indices]
+
+        draft_token_ids = self.drafter.propose(
+            target_token_ids=target_token_ids,
+            target_positions=target_positions,
+            target_hidden_states=target_hidden_states,
+            target_slot_mapping=target_slot_mapping,
+            next_token_ids=next_token_ids,
+            cu_num_tokens=cu_num_tokens,
+            block_table=eagle_attn_metadata.block_tables,
+            sampling_metadata=sampling_metadata,
+        )
+        spec_token_ids = draft_token_ids.tolist()
+        return spec_token_ids
+
     def _generate_mtp_token_ids(
         self,
         valid_sampled_token_ids: list[list[int]],
@@ -2282,8 +2165,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         positions: torch.Tensor,
         num_scheduled_tokens: int,
         hidden_states: torch.Tensor,
-        attn_metadata: SpecDecodeMetadata,
+        attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
+                             AscendTorchairMetadata],
     ):
+        assert isinstance(self.drafter, MtpProposer)
         next_token_ids: list[int] = []
         for i, token_ids in enumerate(valid_sampled_token_ids):
             if token_ids:
@@ -2321,7 +2206,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 dtype=torch.int32,
                 device=self.device,
             )
-            assert self.drafter is not None
             cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                 attn_metadata.query_start_loc,
                 num_rejected_tokens,
@@ -2330,7 +2214,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             target_positions = positions[token_indices]
             target_hidden_states = hidden_states[token_indices]
             target_slot_mapping = attn_metadata.slot_mapping[token_indices]
-        assert self.drafter is not None
+
         draft_token_ids = self.drafter.propose(
             target_token_ids=target_token_ids,
            target_positions=target_positions,