[V1][eagle3] Support eagle3 proposer for v1 (#1032)
### What this PR does / why we need it? This PR implements the Eagle Pososer feature for vLLM v1, which enables more efficient speculative decoding by using a draft model to predict potential future tokens. - The implementation includes the core Eagle algorithm integration with vLLM's existing architecture, allowing for faster inference while maintaining output quality. - This is needed to significantly improve the generation speed of large language models without compromising on the quality of generated text. ### Does this PR introduce any user-facing change? Yes, this PR introduces a new speculative decoding mode that can be enabled via configuration. - Users can now choose to use Eagle Pososer by setting appropriate flags in the inference configuration. - The API remains backward compatible, with the new functionality being opt-in. ### How was this patch tested? CI passed with new unit tests added for the Eagle Pososer functionality. - Benchmark tests were conducted comparing generation speed and quality with and without Eagle Pososer. - Integration tests were performed with various model architectures to ensure compatibility. - Manual testing was done using different prompt scenarios to verify output quality remains consistent. - we test accept rate on one Ascend 910B npu, The acceptance rate results are basically consistent with those shown here: https://github.com/vllm-project/vllm/pull/16937 - Currently, we support scenarios where num_spec_tokens <= 2. When num_spec_tokens > 2, issues such as insufficient GPU memory and operator computation errors may occur. We will address this in subsequent updates. - We will add support for Eagle v1 in future updates. ### Acceptance Test Script ```bash SCRIPT="/offline/eagle.py" DATASET="ShareGpt" MODEL=Meta-Llama-3.1-8B-Instruct DRAFT=EAGLE3-LLaMA3.1-Instruct-8B CUDA_VISIBLE_DEVICES="0" VLLM_USE_V1=1 $PYTHON $SCRIPT \ --dataset $DATASET \ --num_spec_tokens 2 \ --max_num_seqs 1 \ --model_dir $MODEL \ --eagle_dir $DRAFT \ --tp 1 \ --num_prompts 80 ``` ### Acceptance Test Results ```bash ██████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [21:22<00:00, 16.03s/it, est. speed input: 4.72 toks/s, output: 13.56 toks/s] ------------------------------------------------------------------------------------- mean acceptance length: 1.63 ------------------------------------------------------------------------------------- total_counts: 8062 acceptance at token 0: 1.00 (8062 times) acceptance at token 1: 0.70 (5612 times) acceptance at token 2: 0.47 (3765 times) ``` Closes: https://github.com/vllm-project/vllm-ascend/issues/1004 --------- Signed-off-by: yuancaoyaoHW <a2749322671@gmail.com>
This commit is contained in:
@@ -57,7 +57,6 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_supported
|
||||
@@ -70,11 +69,13 @@ from vllm.v1.worker.utils import (gather_mm_placeholders,
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||
from vllm_ascend.utils import ProfileExecuteDuration, vllm_version_is
|
||||
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
|
||||
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -206,8 +207,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
|
||||
|
||||
# Set up speculative decoding.
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
self.use_spec_decode = False
|
||||
self.spec_attn_mask = None
|
||||
self.use_eagle = False
|
||||
if self.speculative_config:
|
||||
self.use_spec_decode = True
|
||||
self.spec_attn_mask = torch.triu(torch.ones(2048,
|
||||
@@ -217,9 +220,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if get_pp_group().is_last_rank:
|
||||
if self.speculative_config.method == "ngram":
|
||||
self.drafter = NgramProposer(self.vllm_config)
|
||||
elif self.speculative_config.method == "eagle":
|
||||
self.drafter = EagleProposer(self.vllm_config,
|
||||
self.device) # type: ignore
|
||||
elif self.speculative_config.method in ["eagle", "eagle3"]:
|
||||
self.use_eagle = True
|
||||
self.drafter = EagleProposer(self.vllm_config, self.device,
|
||||
self) # type: ignore
|
||||
if self.speculative_config.method == "eagle3":
|
||||
self.use_aux_hidden_state_outputs = True
|
||||
elif self.speculative_config.method == 'deepseek_mtp':
|
||||
self.drafter = MtpProposer(self.vllm_config, self)
|
||||
else:
|
||||
@@ -589,6 +595,140 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
group=get_dp_group().cpu_group)
|
||||
return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
|
||||
|
||||
def get_eagle_atten_dict(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> dict[str, AscendMetadata]:
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
# OPTIMIZATION: Start copying the block table first.
|
||||
# This way, we can overlap the copy with the following CPU operations.
|
||||
self.input_batch.block_table.commit(num_reqs)
|
||||
|
||||
# Get the number of scheduled tokens for each request.
|
||||
req_ids = self.input_batch.req_ids
|
||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||
max_num_scheduled_tokens = max(tokens)
|
||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
# Get request indices.
|
||||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||||
num_scheduled_tokens)
|
||||
|
||||
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
|
||||
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
||||
num_scheduled_tokens)
|
||||
|
||||
# Get positions.
|
||||
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
||||
arange,
|
||||
out=positions_np)
|
||||
|
||||
# Calculate M-RoPE positions.
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
if self.uses_mrope:
|
||||
self._calc_mrope_positions(scheduler_output)
|
||||
|
||||
# Get token indices.
|
||||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
|
||||
# where M is the max_model_len.
|
||||
token_indices = (positions_np +
|
||||
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
||||
|
||||
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
||||
# because torch.index_select is much faster than np.take for large
|
||||
# tensors.
|
||||
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
||||
0,
|
||||
torch.from_numpy(token_indices),
|
||||
out=self.input_ids_cpu[:total_num_scheduled_tokens])
|
||||
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
# NOTE(Chen): there is exactly one KV cache group that contains all
|
||||
# attetnion layers in the model for now, so the current logic for
|
||||
# getting attn_metadata is not related to kv_cache_group information.
|
||||
# Will extend this part to support multiple KV cache groups later.
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
block_size = kv_cache_group_spec.kv_cache_spec.block_size
|
||||
block_table = self.input_batch.block_table[kv_cache_group_id]
|
||||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
|
||||
# where K is the max_num_blocks_per_req and the block size is 2.
|
||||
# NOTE(woosuk): We can't simply use `token_indices // block_size`
|
||||
# here because M (max_model_len) is not necessarily divisible by
|
||||
# block_size.
|
||||
block_table_indices = (
|
||||
req_indices * block_table.max_num_blocks_per_req +
|
||||
positions_np // block_size)
|
||||
block_table_cpu = block_table.get_cpu_tensor()
|
||||
block_numbers = block_table_cpu.flatten(
|
||||
)[block_table_indices].numpy()
|
||||
block_offsets = positions_np % block_size
|
||||
np.add(
|
||||
block_numbers * block_size,
|
||||
block_offsets,
|
||||
out=block_table.slot_mapping_np[:total_num_scheduled_tokens])
|
||||
|
||||
# Prepare the attention metadata.
|
||||
self.query_start_loc_np[0] = 0
|
||||
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
|
||||
|
||||
self.seq_lens_np[:num_reqs] = (
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
num_scheduled_tokens)
|
||||
|
||||
# Copy the tensors to the NPU.
|
||||
self.input_ids[:total_num_scheduled_tokens].copy_(
|
||||
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||
if self.uses_mrope:
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
|
||||
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
|
||||
non_blocking=True)
|
||||
else:
|
||||
# Common case (1D positions)
|
||||
self.positions[:total_num_scheduled_tokens].copy_(
|
||||
self.positions_cpu[:total_num_scheduled_tokens],
|
||||
non_blocking=True)
|
||||
|
||||
self.query_start_loc[:num_reqs + 1].copy_(
|
||||
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
|
||||
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
|
||||
non_blocking=True)
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache
|
||||
self.seq_lens[num_reqs:].fill_(0)
|
||||
self.query_start_loc[num_reqs + 1:].fill_(-1)
|
||||
|
||||
attn_metadata: dict[str, AscendMetadata] = {}
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
|
||||
# Prepare for cascade attention if enabled & beneficial.
|
||||
common_prefix_len = 0
|
||||
|
||||
attn_metadata_i = self.attn_metadata_builder.build(
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
common_prefix_len=common_prefix_len,
|
||||
)
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
@@ -776,7 +916,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata,
|
||||
torch.Tensor, int, torch.Tensor]:
|
||||
torch.Tensor, int, torch.Tensor, torch.Tensor]:
|
||||
# Check input valid
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
@@ -874,7 +1014,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
attn_state = AscendAttentionState.DecodeOnly
|
||||
# Speculative decoding.
|
||||
elif np.all(num_valid_tokens == 1):
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
if self.use_eagle:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
else:
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
# splitfuse
|
||||
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
@@ -1047,8 +1190,32 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_draft_tokens, cu_num_tokens)
|
||||
sample_indices = spec_decode_metadata.logits_indices
|
||||
|
||||
aux_hidden_states = None
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = hidden_states
|
||||
|
||||
return (attn_metadata, hidden_states, spec_decode_metadata, positions,
|
||||
total_num_scheduled_tokens, sample_indices)
|
||||
total_num_scheduled_tokens, sample_indices, aux_hidden_states)
|
||||
|
||||
def _get_cumsum_and_arange(
|
||||
self,
|
||||
num_tokens: np.ndarray,
|
||||
cumsum_dtype: Optional[np.dtype] = None,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Get the cumulative sum and batched arange of the given array.
|
||||
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
|
||||
# Equivalent to but faster than:
|
||||
# np.concatenate([np.arange(n) for n in num_tokens])
|
||||
"""
|
||||
# Step 1. [2, 5, 3] -> [2, 7, 10]
|
||||
cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
|
||||
total_num_tokens = cu_num_tokens[-1]
|
||||
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
|
||||
cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
|
||||
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
arange = self.arange_np[:total_num_tokens] - cumsums_offsets
|
||||
|
||||
return cu_num_tokens, arange
|
||||
|
||||
def _calc_spec_decode_metadata(
|
||||
self,
|
||||
@@ -1189,6 +1356,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_scheduled_tokens: int,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: SpecDecodeMetadata,
|
||||
aux_hidden_states: torch.Tensor = None,
|
||||
) -> Optional[list[list[int]]]:
|
||||
if not self.use_spec_decode:
|
||||
# Speculative decoding is not enabled.
|
||||
@@ -1198,9 +1366,85 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
spec_token_ids = self._generate_draft_token_ids(
|
||||
valid_sampled_token_ids, sampling_metadata)
|
||||
elif self.speculative_config.method == "eagle":
|
||||
raise NotImplementedError(
|
||||
"eagle method for spec decode doesn't work on vllm-ascend currently"
|
||||
)
|
||||
raise NotImplementedError("Eagle Is Not Supported Yet.")
|
||||
elif self.speculative_config.method == "eagle3":
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
if self.speculative_config.use_eagle():
|
||||
next_token_ids: list[int] = []
|
||||
for i, token_ids in enumerate(valid_sampled_token_ids):
|
||||
if token_ids:
|
||||
# Common case.
|
||||
next_token_id = token_ids[-1]
|
||||
else:
|
||||
# Partial prefill (rare case).
|
||||
# Get the next token id from the request state.
|
||||
req_id = self.input_batch.req_ids[i]
|
||||
req_state = self.requests[req_id]
|
||||
seq_len = (
|
||||
req_state.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
next_token_id = req_state.get_token_id(seq_len)
|
||||
next_token_ids.append(next_token_id)
|
||||
next_token_ids = torch.tensor(next_token_ids,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
eagle_attn_metadata = attn_metadata[
|
||||
self.drafter.attn_layer_name]
|
||||
num_input_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if spec_decode_metadata is None:
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
target_hidden_states = torch.cat([
|
||||
h[:num_scheduled_tokens] for h in aux_hidden_states
|
||||
],
|
||||
dim=-1)
|
||||
else:
|
||||
target_hidden_states = hidden_states[:
|
||||
num_scheduled_tokens]
|
||||
target_slot_mapping = eagle_attn_metadata.slot_mapping
|
||||
cu_num_tokens = eagle_attn_metadata.query_start_loc
|
||||
else:
|
||||
num_draft_tokens = spec_decode_metadata.num_draft_tokens
|
||||
num_rejected_tokens = [
|
||||
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
|
||||
for i, n in enumerate(num_draft_tokens)
|
||||
]
|
||||
num_rejected_tokens = torch.tensor(
|
||||
num_rejected_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
num_tokens = num_scheduled_tokens - sum(
|
||||
num_rejected_tokens)
|
||||
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
|
||||
eagle_attn_metadata.query_start_loc,
|
||||
num_rejected_tokens, num_tokens)
|
||||
target_token_ids = self.input_ids[token_indices]
|
||||
target_positions = positions[token_indices]
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
target_hidden_states = torch.cat(
|
||||
[h[token_indices] for h in aux_hidden_states],
|
||||
dim=-1)
|
||||
else:
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
target_slot_mapping = eagle_attn_metadata.slot_mapping[
|
||||
token_indices]
|
||||
|
||||
positions = self.positions[:num_input_tokens]
|
||||
draft_token_ids = self.drafter.propose(
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
target_slot_mapping=target_slot_mapping,
|
||||
next_token_ids=next_token_ids,
|
||||
cu_num_tokens=cu_num_tokens,
|
||||
block_table=eagle_attn_metadata.block_tables,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
spec_token_ids = draft_token_ids.tolist()
|
||||
elif self.speculative_config.method == 'deepseek_mtp':
|
||||
assert isinstance(self.drafter, MtpProposer)
|
||||
spec_token_ids = self._generate_mtp_token_ids(
|
||||
@@ -1222,14 +1466,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Return empty ModelRunnerOuptut if there's no work to do.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
(attn_metadata, hidden_states, spec_decode_metadata, positions,
|
||||
num_scheduled_tokens,
|
||||
sample_indices) = (self._process_reqs(scheduler_output,
|
||||
intermediate_tensors))
|
||||
num_scheduled_tokens, sample_indices,
|
||||
aux_hidden_states) = (self._process_reqs(scheduler_output,
|
||||
intermediate_tensors))
|
||||
|
||||
with ProfileExecuteDuration().capture_async("post process"):
|
||||
|
||||
logits = self.model.compute_logits(hidden_states[sample_indices],
|
||||
None)
|
||||
|
||||
if self.use_eagle:
|
||||
attn_metadata = self.get_eagle_atten_dict(scheduler_output)
|
||||
# Apply structured output bitmasks if present
|
||||
if scheduler_output.grammar_bitmask is not None:
|
||||
logits = self.apply_grammar_bitmask(scheduler_output, logits)
|
||||
@@ -1268,6 +1514,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
sampler_output.sampled_token_ids = output_token_ids
|
||||
|
||||
discard_sampled_tokens_req_indices: list[int] = []
|
||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||
# the requests one by one. Optimize.
|
||||
discard_sampled_tokens_req_indices = []
|
||||
@@ -1314,6 +1561,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_scheduled_tokens,
|
||||
hidden_states,
|
||||
attn_metadata,
|
||||
aux_hidden_states,
|
||||
)
|
||||
if vllm_version_is("0.9.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
@@ -1436,6 +1684,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens: int,
|
||||
is_compile: bool = False,
|
||||
with_prefill: bool = True,
|
||||
skip_attn: bool = True,
|
||||
) -> torch.Tensor:
|
||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||
# for dummy run with LoRA so that the num_reqs collectively
|
||||
@@ -1450,6 +1699,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert len(num_scheduled_tokens_list) == num_reqs
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||
dtype=np.int32)
|
||||
if skip_attn:
|
||||
attn_metadata = None
|
||||
else:
|
||||
attn_metadata = self.attn_metadata_builder.build(
|
||||
num_reqs=num_tokens,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=num_tokens,
|
||||
common_prefix_len=0,
|
||||
)
|
||||
|
||||
with self.maybe_dummy_run_with_lora(self.lora_config,
|
||||
num_scheduled_tokens):
|
||||
model = self.model
|
||||
@@ -1515,7 +1774,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
return hidden_states
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, _ = hidden_states
|
||||
else:
|
||||
hidden_states = hidden_states
|
||||
if self.use_spec_decode and \
|
||||
self.speculative_config.method in ('eagle', 'eagle3'):
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
self.drafter.dummy_run(num_tokens)
|
||||
return hidden_states
|
||||
|
||||
def profile_run(self) -> None:
|
||||
# FIXME Profile with multimodal encoder & encoder cache.
|
||||
@@ -1563,7 +1830,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.model = get_model(vllm_config=self.vllm_config)
|
||||
if hasattr(self, "drafter"):
|
||||
logger.info("Loading drafter model...")
|
||||
self.drafter.load_model()
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
self.drafter.load_model(self.model)
|
||||
else:
|
||||
self.drafter.load_model()
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
self.model.set_aux_hidden_state_layers(
|
||||
self.model.get_eagle3_aux_hidden_state_layers())
|
||||
if self.lora_config:
|
||||
self.model = self.load_lora_model(self.model,
|
||||
self.model_config,
|
||||
@@ -1636,6 +1909,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
kv_cache_config: Configuration for the KV cache, including the KV
|
||||
cache size of each layer
|
||||
"""
|
||||
self.kv_cache_config = kv_cache_config
|
||||
import torch_npu
|
||||
kv_caches: Dict[str, torch.Tensor] = {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user