[Misc] Code clean up (#1679)
Make model_runner_v1 more readable
- vLLM version: v0.9.2
- vLLM main: baed180aa0
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
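At a glance, the cleanup below types `self.drafter`, widens the attention-metadata annotations, and replaces a long inline speculative-decoding branch with per-method helpers (_generate_ngram_token_ids, _generate_eagle3_token_ids, _generate_mtp_token_ids). A self-contained toy sketch of that dispatch pattern, with illustrative names only (not the vllm-ascend API):

    # Toy illustration of the refactor pattern applied in this commit: one
    # dispatch point, one small helper per speculative-decoding method.
    # All names here are hypothetical.
    from typing import Optional

    class TinyRunner:
        def __init__(self, method: Optional[str]) -> None:
            self.method = method

        def propose(self, sampled: list[list[int]]) -> Optional[list[list[int]]]:
            if self.method is None:
                return None
            if self.method == "ngram":
                return self._ngram(sampled)
            if self.method == "eagle3":
                return self._eagle3(sampled)
            raise NotImplementedError(self.method)

        def _ngram(self, sampled: list[list[int]]) -> list[list[int]]:
            # Stand-in draft: repeat the last sampled token once.
            return [ids[-1:] for ids in sampled]

        def _eagle3(self, sampled: list[list[int]]) -> list[list[int]]:
            # Stand-in draft: repeat the last sampled token twice.
            return [ids[-1:] * 2 for ids in sampled]

    if __name__ == "__main__":
        print(TinyRunner("ngram").propose([[1, 2, 3], [7]]))  # [[3], [7]]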
@@ -14,7 +14,7 @@ from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.sample.metadata import SamplingMetadata

from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState

logger = init_logger(__name__)
@@ -69,10 +69,12 @@ from vllm.v1.worker.utils import (gather_mm_placeholders,
scatter_mm_placeholders)

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
from vllm_ascend.attention.mla_v1 import (AscendMLAMetadata,
CommonAttentionMetadata)
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.pool.metadata import PoolingMetadata
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -193,10 +195,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
device=self.device)

self.graph_block_tables = np.zeros(
(self.max_num_reqs,
(self.model_config.max_model_len + self.block_size - 1) //
self.block_size),
dtype=np.int32)
(self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)

# Set up Attention
self.attn_backend = get_attn_backend(
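The graph_block_tables change above replaces the inline ceiling division with the precomputed self.max_num_blocks_per_req. Assuming max_num_blocks_per_req is derived from the same ceiling division over the block size, the two forms are equivalent; a quick check with illustrative numbers:

    # Illustrative values only; the block layout is unchanged by the refactor.
    max_model_len, block_size = 4096, 128
    max_num_blocks_per_req = (max_model_len + block_size - 1) // block_size
    assert max_num_blocks_per_req == 32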
@@ -209,13 +208,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
self.attn_mask_builder = AttentionMaskBuilder(
min(self.model_config.max_model_len,
int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)

# Set up speculative decoding.
self.use_aux_hidden_state_outputs = False
self.use_spec_decode = False
self.spec_attn_mask = None
self.use_eagle = False
self.drafter = None
self.drafter: Optional[Union[NgramProposer, EagleProposer,
MtpProposer]] = None
if self.speculative_config:
self.use_spec_decode = True
self.spec_attn_mask = torch.triu(torch.ones(2048,
@@ -315,19 +318,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))

# NOTE: Pre-construct a mask matrix to improve the efficiency of
# attention mask construction during inference.
# Note that the length of the matrix needs to be carefully balanced: a
# matrix that is too large will consume excessive VRAM, while a matrix
# that is too small will require dynamic concatenation during inference,
# leading to performance degradation.
# Therefore, an environment variable is added here to dynamically set
# the size of the pre-constructed mask matrix based on requirements.
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
attn_mask_len = min(self.model_config.max_model_len, int(mask_len))
self.attn_mask_builder = AttentionMaskBuilder(attn_mask_len,
self.dtype)

self.new_kv_cache_bytes = -1
self.torchair_compiled_model = None  # type: ignore
self.torchair_compiled_models = {}  # type: ignore
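The NOTE block in the hunk above explains why the attention mask is pre-built once and sized via PAGED_ATTENTION_MASK_LEN. A minimal standalone sketch of that idea, assuming a simple causal template (-inf above the diagonal) in float16 and illustrative sizes; the real AttentionMaskBuilder may construct its mask differently:

    import os

    import torch

    max_model_len = 4096  # assumed config value for illustration
    mask_len = int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))
    attn_mask_len = min(max_model_len, mask_len)

    # Build the template once: large enough to slice from, small enough to
    # fit comfortably in device memory.
    template = torch.triu(
        torch.full((attn_mask_len, attn_mask_len), float("-inf"),
                   dtype=torch.float16),
        diagonal=1)

    # Per prefill step, slice a sub-mask instead of rebuilding it dynamically.
    seq_len = 128  # assumed prefill length
    step_mask = template[:seq_len, :seq_len]
    print(step_mask.shape)  # torch.Size([128, 128])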
@@ -566,7 +556,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def get_eagle_atten_dict(
self,
scheduler_output: "SchedulerOutput",
) -> dict[str, AscendMetadata]:
) -> dict[str, Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@@ -677,7 +668,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens[num_reqs:].fill_(0)
self.query_start_loc[num_reqs + 1:].fill_(-1)

attn_metadata: dict[str, AscendMetadata] = {}
attn_metadata: dict[str, Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata]] = {}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -880,7 +872,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata,
) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
# Check input valid
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -990,11 +983,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else:
attn_state = AscendAttentionState.PrefillCacheHit

attn_mask = self._make_attention_mask(seq_lens=seq_lens,
query_lens=num_scheduled_tokens,
position=positions,
attn_state=attn_state)
self.attn_mask = attn_mask
self.attn_mask = self._make_attention_mask(
seq_lens=seq_lens,
query_lens=num_scheduled_tokens,
position=positions,
attn_state=attn_state)
self.attn_state = attn_state  # type: ignore

extra_builder_kwargs = {}
@@ -1010,10 +1003,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens[num_reqs:].fill_(0)
self.query_start_loc[num_reqs + 1:].fill_(-1)

query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
with_prefill = attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
@@ -1037,6 +1026,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
extra_builder_kwargs['graph_pad_size'] = graph_pad_size

if self.vllm_config.model_config.use_mla:
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
attn_metadata = self.attn_metadata_builder.build(  # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
@@ -1326,98 +1319,24 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: SpecDecodeMetadata,
attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata],
aux_hidden_states: torch.Tensor = None,
) -> Optional[list[list[int]]]:
if not self.use_spec_decode:
# Speculative decoding is not enabled.
spec_token_ids = None
elif self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self._generate_draft_token_ids(
valid_sampled_token_ids, sampling_metadata)
spec_token_ids = self._generate_ngram_token_ids(
valid_sampled_token_ids)
elif self.speculative_config.method == "eagle":
raise NotImplementedError("Eagle Is Not Supported Yet.")
elif self.speculative_config.method == "eagle3":
assert isinstance(self.drafter, EagleProposer)
if self.speculative_config.use_eagle():
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (
req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])

next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
eagle_attn_metadata = attn_metadata[
self.drafter.attn_layer_name]
num_input_tokens = scheduler_output.total_num_scheduled_tokens
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([
h[:num_scheduled_tokens] for h in aux_hidden_states
],
dim=-1)
else:
target_hidden_states = hidden_states[:
num_scheduled_tokens]
target_slot_mapping = eagle_attn_metadata.slot_mapping
cu_num_tokens = eagle_attn_metadata.query_start_loc
else:
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
num_tokens = num_scheduled_tokens - sum(
num_rejected_tokens)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc,
num_rejected_tokens, num_tokens)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]

positions = self.positions[:num_input_tokens]
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=eagle_attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
spec_token_ids = self._generate_eagle3_token_ids(
valid_sampled_token_ids, sampling_metadata, scheduler_output,
spec_decode_metadata, positions, num_scheduled_tokens,
hidden_states, aux_hidden_states)
elif self.speculative_config.method == 'deepseek_mtp':
assert isinstance(self.drafter, MtpProposer)
spec_token_ids = self._generate_mtp_token_ids(
valid_sampled_token_ids, sampling_metadata, scheduler_output,
spec_decode_metadata, positions, num_scheduled_tokens,
@@ -1483,14 +1402,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output, intermediate_tensors))

with ProfileExecuteDuration().capture_async("post process"):

if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np)
logits = self.model.compute_logits(hidden_states[sample_indices],
None)
if self.use_eagle:
attn_metadata = self.get_eagle_atten_dict(scheduler_output)
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
logits = self.apply_grammar_bitmask(scheduler_output, logits)
@@ -1630,96 +1546,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):

return model_runner_output

def _profile_multimodal(self) -> None:
# TODO: handle encoder-decoder models once we support them.
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.

if (not self.is_multimodal_model
or self.max_num_encoder_input_tokens <= 0
or self.encoder_cache_size <= 0):
return

max_tokens_by_modality_dict = (
MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(
self.model_config))
dummy_data_modality, max_tokens_per_mm_item = max(
max_tokens_by_modality_dict.items(), key=lambda item: item[1])

# Check how many items of this modality can be supported by
# the encoder budget.
encoder_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size)

max_num_mm_items_encoder_budget = cdiv(encoder_budget,
max_tokens_per_mm_item)

# Check how many items of this modality can be supported by
# the decoder budget.
max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
self.model_config)[dummy_data_modality]

# NOTE: We do not consider max_num_batched_tokens on purpose
# because the multimodal embeddings can be generated in advance
# and chunked prefilled.
max_num_mm_items_decoder_budget = self.max_num_reqs * \
max_mm_items_per_req

max_num_mm_items = min(max_num_mm_items_encoder_budget,
max_num_mm_items_decoder_budget)

logger.info(
"Encoder cache will be initialized with a budget of %s tokens,"
" and profiled with %s %s items of the maximum feature size.",
encoder_budget, max_num_mm_items, dummy_data_modality)

# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data

if not isinstance(dummy_mm_data, MultiModalKwargs):
# TODO: Delete this check once input mapper is fully removed.
raise RuntimeError("Legacy input mapper is not supported in V1")

# Dummy data definition in V0 may contain multiple multimodal items
# (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1
# they are scheduled to be processed separately.

dummy_mm_item = dummy_mm_data.get_item(modality=dummy_data_modality,
item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])

batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, device=self.device)

# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
assert len(dummy_encoder_outputs) == max_num_mm_items, (
"Expected dimension 0 of encoder outputs to match the number "
f"of multimodal data items: {max_num_mm_items}, got "
f"{len(dummy_encoder_outputs)=} instead. This is most likely "
"due to the 'get_multimodal_embeddings' method of the model "
"not implemented correctly.")

# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

@torch.inference_mode()
def _dummy_run(
self,
num_tokens: int,
is_compile: bool = False,
with_prefill: bool = True,
skip_attn: bool = True,
) -> torch.Tensor:
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
@@ -1729,19 +1561,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
if skip_attn:
attn_metadata = None
else:
attn_metadata = self.attn_metadata_builder.build(
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
common_prefix_len=0,
)

with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
@@ -1819,48 +1640,32 @@ class NPUModelRunner(LoRAModelRunnerMixin):
hidden_states, _ = hidden_states
else:
hidden_states = hidden_states
if self.use_spec_decode and \
self.speculative_config.method in ('eagle', 'eagle3'):
assert isinstance(self.drafter, EagleProposer)
if self.use_spec_decode and isinstance(
self.drafter, EagleProposer):
self.drafter.dummy_run(num_tokens)
return hidden_states

def profile_run(self) -> None:
# FIXME Profile with multimodal encoder & encoder cache.
# current _profile_multimodal() using PyTorch SDPA backend method not
# support for window/full attn to reduce Memcpy operations, so will cause
# Out Of Memory problem, so we currently don't use self._profile_multimodal()
# self._profile_multimodal()

# For profile, have maximum num_reqs and that collectively have
# maximum num_tokens.
min_tokens_per_req = self.max_num_tokens // self.max_num_reqs

num_scheduled_tokens_list = [min_tokens_per_req] * self.max_num_reqs
num_scheduled_tokens_list[
-1] += self.max_num_tokens % self.max_num_reqs
assert sum(num_scheduled_tokens_list) == self.max_num_tokens
assert len(num_scheduled_tokens_list) == self.max_num_reqs

num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1

# assert self.lora_manager is not None, "LoRA is not enabled"
# TODO: call maybe_profile_with_lora()

# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens)

output = None
if get_pp_group().is_last_rank:
if self.is_pooling_model:
output = self._dummy_pooler_run(hidden_states)
else:
# For profile, have maximum num_reqs and that collectively have
# maximum num_tokens.
min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
num_scheduled_tokens_list = [min_tokens_per_req
] * self.max_num_reqs
num_scheduled_tokens_list[
-1] += self.max_num_tokens % self.max_num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
# TODO: need to rum a dummy sampler for generate task
hidden_states = hidden_states[logit_indices]
output = self.model.compute_logits(hidden_states, None)
else:
output = None

NPUPlatform.synchronize()
del hidden_states, output
@@ -1879,8 +1684,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs

hidden_states_list = list(
torch.split(hidden_states, num_scheduled_tokens_list))
@@ -1929,10 +1732,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.logits_processing_needs_token_ids = True
if self.drafter:
logger.info("Loading drafter model...")
if self.use_aux_hidden_state_outputs:
self.drafter.load_model(self.model)
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
if isinstance(self.drafter, EagleProposer):
if self.use_aux_hidden_state_outputs:
self.drafter.load_model(self.model)
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
else:
self.drafter.load_model()
if self.lora_config:
@@ -2240,10 +2044,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, npu_graph_size / (1 << 30))

def _generate_draft_token_ids(
def _generate_ngram_token_ids(
self,
sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
# TODO(woosuk): Optimize.
draft_token_ids: list[list[int]] = []
@@ -2264,7 +2067,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
assert self.drafter is not None
assert isinstance(self.drafter, NgramProposer)
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx])
if drafter_output is None or len(drafter_output) == 0:
@@ -2273,6 +2076,86 @@ class NPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids

def _generate_eagle3_token_ids(self,
valid_sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
scheduler_output: "SchedulerOutput",
spec_decode_metadata: SpecDecodeMetadata,
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
aux_hidden_states: torch.Tensor = None):
assert isinstance(self.drafter, EagleProposer)
attn_metadata = self.get_eagle_atten_dict(scheduler_output)
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])

next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = eagle_attn_metadata.slot_mapping
cu_num_tokens = eagle_attn_metadata.query_start_loc
else:
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc, num_rejected_tokens,
num_tokens)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]

draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=eagle_attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids

def _generate_mtp_token_ids(
self,
valid_sampled_token_ids: list[list[int]],
@@ -2282,8 +2165,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: SpecDecodeMetadata,
attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata],
):
assert isinstance(self.drafter, MtpProposer)
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
@@ -2321,7 +2206,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32,
device=self.device,
)
assert self.drafter is not None
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
attn_metadata.query_start_loc,
num_rejected_tokens,
@@ -2330,7 +2214,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
assert self.drafter is not None

draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,