[Misc] Code clean up (#1679)

Make model_runner_v1 more readable

- vLLM version: v0.9.2
- vLLM main:
baed180aa0

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-07-09 14:33:40 +08:00
committed by GitHub
parent 392fd7239b
commit b979ee353d
2 changed files with 138 additions and 254 deletions

View File

@@ -14,7 +14,7 @@ from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.sample.metadata import SamplingMetadata
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
logger = init_logger(__name__)

View File

@@ -69,10 +69,12 @@ from vllm.v1.worker.utils import (gather_mm_placeholders,
scatter_mm_placeholders)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
from vllm_ascend.attention.mla_v1 import (AscendMLAMetadata,
CommonAttentionMetadata)
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.pool.metadata import PoolingMetadata
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -193,10 +195,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
device=self.device)
self.graph_block_tables = np.zeros(
(self.max_num_reqs,
(self.model_config.max_model_len + self.block_size - 1) //
self.block_size),
dtype=np.int32)
(self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
# Set up Attention
self.attn_backend = get_attn_backend(
@@ -209,13 +208,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
self.attn_mask_builder = AttentionMaskBuilder(
min(self.model_config.max_model_len,
int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
# Set up speculative decoding.
self.use_aux_hidden_state_outputs = False
self.use_spec_decode = False
self.spec_attn_mask = None
self.use_eagle = False
self.drafter = None
self.drafter: Optional[Union[NgramProposer, EagleProposer,
MtpProposer]] = None
if self.speculative_config:
self.use_spec_decode = True
self.spec_attn_mask = torch.triu(torch.ones(2048,
@@ -315,19 +318,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
# NOTE: Pre-construct a mask matrix to improve the efficiency of
# attention mask construction during inference.
# Note that the length of the matrix needs to be carefully balanced: a
# matrix that is too large will consume excessive VRAM, while a matrix
# that is too small will require dynamic concatenation during inference,
# leading to performance degradation.
# Therefore, an environment variable is added here to dynamically set
# the size of the pre-constructed mask matrix based on requirements.
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
attn_mask_len = min(self.model_config.max_model_len, int(mask_len))
self.attn_mask_builder = AttentionMaskBuilder(attn_mask_len,
self.dtype)
self.new_kv_cache_bytes = -1
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
@@ -566,7 +556,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def get_eagle_atten_dict(
self,
scheduler_output: "SchedulerOutput",
) -> dict[str, AscendMetadata]:
) -> dict[str, Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@@ -677,7 +668,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens[num_reqs:].fill_(0)
self.query_start_loc[num_reqs + 1:].fill_(-1)
attn_metadata: dict[str, AscendMetadata] = {}
attn_metadata: dict[str, Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata]] = {}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -880,7 +872,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata,
) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
# Check input valid
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -990,11 +983,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else:
attn_state = AscendAttentionState.PrefillCacheHit
attn_mask = self._make_attention_mask(seq_lens=seq_lens,
query_lens=num_scheduled_tokens,
position=positions,
attn_state=attn_state)
self.attn_mask = attn_mask
self.attn_mask = self._make_attention_mask(
seq_lens=seq_lens,
query_lens=num_scheduled_tokens,
position=positions,
attn_state=attn_state)
self.attn_state = attn_state # type: ignore
extra_builder_kwargs = {}
@@ -1010,10 +1003,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens[num_reqs:].fill_(0)
self.query_start_loc[num_reqs + 1:].fill_(-1)
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
with_prefill = attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
@@ -1037,6 +1026,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
if self.vllm_config.model_config.use_mla:
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
@@ -1326,98 +1319,24 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: SpecDecodeMetadata,
attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata],
aux_hidden_states: torch.Tensor = None,
) -> Optional[list[list[int]]]:
if not self.use_spec_decode:
# Speculative decoding is not enabled.
spec_token_ids = None
elif self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self._generate_draft_token_ids(
valid_sampled_token_ids, sampling_metadata)
spec_token_ids = self._generate_ngram_token_ids(
valid_sampled_token_ids)
elif self.speculative_config.method == "eagle":
raise NotImplementedError("Eagle Is Not Supported Yet.")
elif self.speculative_config.method == "eagle3":
assert isinstance(self.drafter, EagleProposer)
if self.speculative_config.use_eagle():
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (
req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
eagle_attn_metadata = attn_metadata[
self.drafter.attn_layer_name]
num_input_tokens = scheduler_output.total_num_scheduled_tokens
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([
h[:num_scheduled_tokens] for h in aux_hidden_states
],
dim=-1)
else:
target_hidden_states = hidden_states[:
num_scheduled_tokens]
target_slot_mapping = eagle_attn_metadata.slot_mapping
cu_num_tokens = eagle_attn_metadata.query_start_loc
else:
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
num_tokens = num_scheduled_tokens - sum(
num_rejected_tokens)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc,
num_rejected_tokens, num_tokens)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]
positions = self.positions[:num_input_tokens]
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=eagle_attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
spec_token_ids = self._generate_eagle3_token_ids(
valid_sampled_token_ids, sampling_metadata, scheduler_output,
spec_decode_metadata, positions, num_scheduled_tokens,
hidden_states, aux_hidden_states)
elif self.speculative_config.method == 'deepseek_mtp':
assert isinstance(self.drafter, MtpProposer)
spec_token_ids = self._generate_mtp_token_ids(
valid_sampled_token_ids, sampling_metadata, scheduler_output,
spec_decode_metadata, positions, num_scheduled_tokens,
@@ -1483,14 +1402,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output, intermediate_tensors))
with ProfileExecuteDuration().capture_async("post process"):
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np)
logits = self.model.compute_logits(hidden_states[sample_indices],
None)
if self.use_eagle:
attn_metadata = self.get_eagle_atten_dict(scheduler_output)
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
logits = self.apply_grammar_bitmask(scheduler_output, logits)
@@ -1630,96 +1546,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return model_runner_output
def _profile_multimodal(self) -> None:
    """Profile the multimodal encoder with a dummy batch and cache its outputs.

    Selects the non-text modality with the largest per-item token count,
    computes how many items of that modality fit within both the encoder
    and decoder budgets, runs the model's multimodal encoder on that many
    replicated dummy items, and stores the resulting embeddings in
    ``self.encoder_cache["tmp"]``.

    No-op for non-multimodal models or when either the encoder input
    token budget or the encoder cache size is non-positive.
    """
    # TODO: handle encoder-decoder models once we support them.
    # NOTE: Currently model is profiled with a single non-text
    # modality with the max possible input tokens even when
    # it supports multiple.
    if (not self.is_multimodal_model
            or self.max_num_encoder_input_tokens <= 0
            or self.encoder_cache_size <= 0):
        return

    max_tokens_by_modality_dict = (
        MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(
            self.model_config))
    # Pick the single most token-hungry modality to profile with.
    dummy_data_modality, max_tokens_per_mm_item = max(
        max_tokens_by_modality_dict.items(), key=lambda item: item[1])

    # Check how many items of this modality can be supported by
    # the encoder budget.
    encoder_budget = min(self.max_num_encoder_input_tokens,
                         self.encoder_cache_size)

    max_num_mm_items_encoder_budget = cdiv(encoder_budget,
                                           max_tokens_per_mm_item)

    # Check how many items of this modality can be supported by
    # the decoder budget.
    max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
        self.model_config)[dummy_data_modality]

    # NOTE: We do not consider max_num_batched_tokens on purpose
    # because the multimodal embeddings can be generated in advance
    # and chunked prefilled.
    max_num_mm_items_decoder_budget = self.max_num_reqs * \
        max_mm_items_per_req

    # Profile with the tighter of the two budgets.
    max_num_mm_items = min(max_num_mm_items_encoder_budget,
                           max_num_mm_items_decoder_budget)

    logger.info(
        "Encoder cache will be initialized with a budget of %s tokens,"
        " and profiled with %s %s items of the maximum feature size.",
        encoder_budget, max_num_mm_items, dummy_data_modality)

    # Create dummy batch of multimodal inputs.
    dummy_request_data = self.input_registry.dummy_data_for_profiling(
        model_config=self.model_config,
        seq_len=self.max_num_tokens,
        mm_registry=self.mm_registry,
    )
    dummy_mm_data = dummy_request_data.multi_modal_data

    if not isinstance(dummy_mm_data, MultiModalKwargs):
        # TODO: Delete this check once input mapper is fully removed.
        raise RuntimeError("Legacy input mapper is not supported in V1")

    # Dummy data definition in V0 may contain multiple multimodal items
    # (e.g, multiple images) for a single request, therefore here we
    # always replicate first item by max_num_mm_items times since in V1
    # they are scheduled to be processed separately.
    dummy_mm_item = dummy_mm_data.get_item(modality=dummy_data_modality,
                                           item_index=0)
    dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])

    batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
                                                     max_num_mm_items)
    batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
        batched_dummy_mm_inputs, device=self.device)

    # Run multimodal encoder.
    dummy_encoder_outputs = self.model.get_multimodal_embeddings(
        **batched_dummy_mm_inputs)
    assert len(dummy_encoder_outputs) == max_num_mm_items, (
        "Expected dimension 0 of encoder outputs to match the number "
        f"of multimodal data items: {max_num_mm_items}, got "
        f"{len(dummy_encoder_outputs)=} instead. This is most likely "
        "due to the 'get_multimodal_embeddings' method of the model "
        "not implemented correctly.")

    # Cache the dummy encoder outputs.
    self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
@torch.inference_mode()
def _dummy_run(
self,
num_tokens: int,
is_compile: bool = False,
with_prefill: bool = True,
skip_attn: bool = True,
) -> torch.Tensor:
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
@@ -1729,19 +1561,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
if skip_attn:
attn_metadata = None
else:
attn_metadata = self.attn_metadata_builder.build(
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
common_prefix_len=0,
)
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
@@ -1819,48 +1640,32 @@ class NPUModelRunner(LoRAModelRunnerMixin):
hidden_states, _ = hidden_states
else:
hidden_states = hidden_states
if self.use_spec_decode and \
self.speculative_config.method in ('eagle', 'eagle3'):
assert isinstance(self.drafter, EagleProposer)
if self.use_spec_decode and isinstance(
self.drafter, EagleProposer):
self.drafter.dummy_run(num_tokens)
return hidden_states
def profile_run(self) -> None:
# FIXME Profile with multimodal encoder & encoder cache.
# The current _profile_multimodal() uses the PyTorch SDPA backend, which
# does not support window/full attention to reduce memcpy operations and
# can therefore cause an out-of-memory problem, so we currently don't use it.
# self._profile_multimodal()
# For profile, have maximum num_reqs and that collectively have
# maximum num_tokens.
min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * self.max_num_reqs
num_scheduled_tokens_list[
-1] += self.max_num_tokens % self.max_num_reqs
assert sum(num_scheduled_tokens_list) == self.max_num_tokens
assert len(num_scheduled_tokens_list) == self.max_num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
# assert self.lora_manager is not None, "LoRA is not enabled"
# TODO: call maybe_profile_with_lora()
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens)
output = None
if get_pp_group().is_last_rank:
if self.is_pooling_model:
output = self._dummy_pooler_run(hidden_states)
else:
# For profile, have maximum num_reqs and that collectively have
# maximum num_tokens.
min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
num_scheduled_tokens_list = [min_tokens_per_req
] * self.max_num_reqs
num_scheduled_tokens_list[
-1] += self.max_num_tokens % self.max_num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
# TODO: need to run a dummy sampler for generate task
hidden_states = hidden_states[logit_indices]
output = self.model.compute_logits(hidden_states, None)
else:
output = None
NPUPlatform.synchronize()
del hidden_states, output
@@ -1879,8 +1684,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
hidden_states_list = list(
torch.split(hidden_states, num_scheduled_tokens_list))
@@ -1929,10 +1732,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.logits_processing_needs_token_ids = True
if self.drafter:
logger.info("Loading drafter model...")
if self.use_aux_hidden_state_outputs:
self.drafter.load_model(self.model)
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
if isinstance(self.drafter, EagleProposer):
if self.use_aux_hidden_state_outputs:
self.drafter.load_model(self.model)
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
else:
self.drafter.load_model()
if self.lora_config:
@@ -2240,10 +2044,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, npu_graph_size / (1 << 30))
def _generate_draft_token_ids(
def _generate_ngram_token_ids(
self,
sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
# TODO(woosuk): Optimize.
draft_token_ids: list[list[int]] = []
@@ -2264,7 +2067,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
assert self.drafter is not None
assert isinstance(self.drafter, NgramProposer)
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx])
if drafter_output is None or len(drafter_output) == 0:
@@ -2273,6 +2076,86 @@ class NPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids
def _generate_eagle3_token_ids(self,
                               valid_sampled_token_ids: list[list[int]],
                               sampling_metadata: SamplingMetadata,
                               scheduler_output: "SchedulerOutput",
                               spec_decode_metadata: SpecDecodeMetadata,
                               positions: torch.Tensor,
                               num_scheduled_tokens: int,
                               hidden_states: torch.Tensor,
                               aux_hidden_states: torch.Tensor = None):
    """Propose draft token ids with the EAGLE3 drafter.

    Derives the next accepted token for every request (falling back to the
    request state during partial prefill), gathers the drafter's target
    inputs — either the full scheduled prefix when ``spec_decode_metadata``
    is ``None``, or only the non-rejected token positions otherwise — and
    runs ``self.drafter.propose`` to obtain draft tokens.

    Returns the proposed draft token ids as a nested Python list
    (``draft_token_ids.tolist()``).
    """
    assert isinstance(self.drafter, EagleProposer)
    # Attention metadata per layer, keyed by layer name; the drafter's
    # own attention layer is looked up below.
    attn_metadata = self.get_eagle_atten_dict(scheduler_output)
    next_token_ids: list[int] = []
    for i, token_ids in enumerate(valid_sampled_token_ids):
        if token_ids:
            # Common case.
            next_token_id = token_ids[-1]
        else:
            # Partial prefill (rare case).
            # Get the next token id from the request state.
            req_id = self.input_batch.req_ids[i]
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            next_token_id = req_state.get_token_id(seq_len)
        next_token_ids.append(next_token_id)
    next_token_ids = torch.tensor(next_token_ids,
                                  dtype=torch.int32,
                                  device=self.device)
    eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
    if spec_decode_metadata is None:
        # No speculative tokens were scheduled: target the whole
        # scheduled prefix directly.
        # input_ids can be None for multimodal models.
        target_token_ids = self.input_ids[:num_scheduled_tokens]
        target_positions = positions[:num_scheduled_tokens]
        if self.use_aux_hidden_state_outputs:
            # EAGLE3 consumes concatenated auxiliary hidden states
            # rather than the final hidden state alone.
            target_hidden_states = torch.cat(
                [h[:num_scheduled_tokens] for h in aux_hidden_states],
                dim=-1)
        else:
            target_hidden_states = hidden_states[:num_scheduled_tokens]
        target_slot_mapping = eagle_attn_metadata.slot_mapping
        cu_num_tokens = eagle_attn_metadata.query_start_loc
    else:
        # Drop positions whose draft tokens were rejected by the
        # verifier before building the drafter inputs.
        num_draft_tokens = spec_decode_metadata.num_draft_tokens
        num_rejected_tokens = [
            n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(
            num_rejected_tokens,
            dtype=torch.int32,
            device=self.device,
        )
        num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
        cu_num_tokens, token_indices = self.drafter.prepare_inputs(
            eagle_attn_metadata.query_start_loc, num_rejected_tokens,
            num_tokens)
        target_token_ids = self.input_ids[token_indices]
        target_positions = positions[token_indices]
        if self.use_aux_hidden_state_outputs:
            target_hidden_states = torch.cat(
                [h[token_indices] for h in aux_hidden_states], dim=-1)
        else:
            target_hidden_states = hidden_states[token_indices]
        target_slot_mapping = eagle_attn_metadata.slot_mapping[
            token_indices]
    draft_token_ids = self.drafter.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        target_slot_mapping=target_slot_mapping,
        next_token_ids=next_token_ids,
        cu_num_tokens=cu_num_tokens,
        block_table=eagle_attn_metadata.block_tables,
        sampling_metadata=sampling_metadata,
    )
    spec_token_ids = draft_token_ids.tolist()
    return spec_token_ids
def _generate_mtp_token_ids(
self,
valid_sampled_token_ids: list[list[int]],
@@ -2282,8 +2165,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: SpecDecodeMetadata,
attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata],
):
assert isinstance(self.drafter, MtpProposer)
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
@@ -2321,7 +2206,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32,
device=self.device,
)
assert self.drafter is not None
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
attn_metadata.query_start_loc,
num_rejected_tokens,
@@ -2330,7 +2214,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
assert self.drafter is not None
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,