[Feat] Support async_scheduler and disable_padded_drafter_batch in eagle (#4893)

### What this PR does / why we need it?
We refactored the eagle_proposer.py to adapt the framework of eagle.py
in vllm-v0.12.0, to support the logit of padded drafter batch and
async-scheduler.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
Co-authored-by: drslark <slarksblood@qq.com>
This commit is contained in:
anon189Ty
2025-12-16 22:06:40 +08:00
committed by GitHub
parent cee521bad5
commit 5b1da4e914
6 changed files with 577 additions and 403 deletions

View File

@@ -7,9 +7,10 @@ import random
from typing import Any
import pytest
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from tests.e2e.conftest import VllmRunner
from tests.e2e.conftest import VllmRunner, cleanup_dist_env_and_memory
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -115,41 +116,67 @@ def test_eagle_correctness(
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle speculative decoding.
'''
pytest.skip("To be aligned with GPU")
ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
# NOTE: e2e of eagle has many problems before.
# We first check whether it is functioning properly.
# Should fix the e2e with VllmRunner in future.
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
with VllmRunner(
model_name,
max_num_seqs=1,
max_num_batched_tokens=2048,
gpu_memory_utilization=0.6,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"max_model_len": 128,
},
max_model_len=128,
enforce_eager=False,
) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompts = [{
"role": "user",
"content": "Hello, my name is"
}, {
"role": "user",
"content": "The president of the United States is"
}, {
"role": "user",
"content": "The capital of France is"
}, {
"role": "user",
"content": "The future of AI is"
}]
prompts = [
tokenizer.apply_chat_template(
[prompt],
tokenize=False,
add_generation_prompt=True,
) for prompt in prompts
]
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
sampling_params = SamplingParams(
max_tokens=300,
temperature=0.0,
ignore_eos=False,
)
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
# Create an LLM.
llm = LLM(
model=model_name,
tensor_parallel_size=1,
pipeline_parallel_size=1,
data_parallel_size=1,
disable_log_stats=False,
max_model_len=4096,
seed=1024,
async_scheduling=True,
compilation_config={
"level": 3,
"cudagraph_mode": "FULL_DECODE_ONLY",
"cudagraph_num_of_warmups": 1,
"cudagraph_capture_sizes": [12],
},
speculative_config={
"disable_padded_drafter_batch": False,
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"max_model_len": 128,
"draft_vocab_size": 128256,
},
)
llm.generate(prompts, sampling_params)
cleanup_dist_env_and_memory()
del llm
@pytest.mark.skip(

View File

@@ -26,6 +26,13 @@ class TestEagleProposerInitialization(TestBase):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.mock_cpugpubuffer = patch(
"vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
self.mock_cpugpubuffer.start()
def tearDown(self):
self.mock_cpugpubuffer.stop()
def test_initialization_eagle(self):
self.vllm_config.speculative_config.method = "eagle"
self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
@@ -44,7 +51,7 @@ class TestEagleProposerInitialization(TestBase):
self.assertEqual(proposer.input_ids.shape, (1024, ))
self.assertEqual(proposer.positions.shape, (1024, ))
self.assertEqual(proposer.hidden_states.shape, (1024, 4096))
self.assertEqual(proposer.arange.shape, (33, ))
self.assertEqual(proposer.arange.shape, (1024, ))
def test_initialization_eagle3(self):
self.vllm_config.speculative_config.method = "eagle3"
@@ -77,10 +84,16 @@ class TestEagleProposerLoadModel(TestBase):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.mock_cpugpubuffer = patch(
"vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
self.mock_cpugpubuffer.start()
self.proposer = EagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
def tearDown(self):
self.mock_cpugpubuffer.stop()
@patch(
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
@patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
@@ -172,11 +185,17 @@ class TestEagleProposerDummyRun(TestBase):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.mock_cpugpubuffer = patch(
"vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
self.mock_cpugpubuffer.start()
self.proposer = EagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
self.proposer.model = MagicMock()
def tearDown(self):
self.mock_cpugpubuffer.stop()
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_basic(self, mock_context):
num_tokens = 32
@@ -216,6 +235,9 @@ class TestEagleProposerGenerateTokenIds(TestBase):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.mock_cpugpubuffer = patch(
"vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
self.mock_cpugpubuffer.start()
self.proposer = EagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
@@ -223,7 +245,12 @@ class TestEagleProposerGenerateTokenIds(TestBase):
self.proposer._propose = MagicMock(
return_value=torch.tensor([[1, 2], [3, 4], [5, 6]]))
def test_generate_token_ids_without_metadata(self):
def tearDown(self):
self.mock_cpugpubuffer.stop()
# TODO: This is equivalent to disable_padded_drafter_batch=True.
# We need to add some cases about disable_padded_drafter_batch=False in future.
def test_generate_token_ids(self):
valid_sampled = [[20, 30, 40]]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = [2, 1, 3]
@@ -239,7 +266,7 @@ class TestEagleProposerGenerateTokenIds(TestBase):
return_value={"layer_0": mock_attn_metadata})
result = self.proposer.generate_token_ids(
valid_sampled_token_ids=valid_sampled,
sampled_token_ids=valid_sampled,
scheduler_output=scheduler_output,
positions=positions,
num_scheduled_tokens=num_scheduled,
@@ -247,36 +274,13 @@ class TestEagleProposerGenerateTokenIds(TestBase):
)
self.proposer._propose.assert_called_once()
self.assertEqual(result, [[1, 2], [3, 4], [5, 6]])
def test_generate_token_ids_with_metadata(self):
valid_sampled = [[5], [6, 7], [8, 9, 10]]
spec_metadata = MagicMock()
spec_metadata.num_draft_tokens = [2, 3, 4]
mock_attn_metadata = MagicMock()
mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
mock_attn_metadata.query_start_loc = torch.tensor([0, 1, 3, 6])
mock_attn_metadata.block_tables = MagicMock()
self.proposer._get_eagle_atten_dict = MagicMock(
return_value={"layer_0": mock_attn_metadata})
self.proposer._prepare_inputs = MagicMock(
return_value=(torch.tensor([0, 2, 5]), torch.tensor([1, 3, 5])))
result = self.proposer.generate_token_ids(
valid_sampled_token_ids=valid_sampled,
spec_decode_metadata=spec_metadata,
positions=torch.randn(6, 1),
hidden_states=torch.randn(6, 4096),
)
self.proposer._prepare_inputs.assert_called_once()
self.assertEqual(self.proposer._propose.call_count, 1)
self.assertEqual(len(result), 3)
self.assertEqual(result.numpy().tolist(), [[1, 2], [3, 4], [5, 6]])
class TestEagleProposerHelperMethods(TestBase):
# TODO: Can add some tests about prepare_next_token_ids in future.
def setUp(self):
self.vllm_config = MagicMock(spec=VllmConfig)
self.vllm_config.scheduler_config = MagicMock(max_num_seqs=3)
@@ -293,21 +297,29 @@ class TestEagleProposerHelperMethods(TestBase):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.mock_cpugpubuffer = patch(
"vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
self.mock_cpugpubuffer.start()
self.proposer = EagleProposer(vllm_config=self.vllm_config,
device=self.device,
runner=self.runner)
def tearDown(self):
self.mock_cpugpubuffer.stop()
# TODO: This is equivalent to disable_padded_drafter_batch=True.
# We need to add a test_prepare_inputs_padded in future.
def test_prepare_inputs(self):
self.proposer.token_arange_np = np.arange(10)
mock_attn = MagicMock()
mock_attn.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
num_rejected = torch.tensor([1, 0, 1], device=self.device)
mock_return_attn = MagicMock()
with patch.object(self.proposer,
'_prepare_inputs',
return_value=(torch.tensor([0, 2, 5]),
'prepare_inputs',
return_value=(mock_return_attn,
torch.tensor([1, 2, 4]))):
cu_num_tokens, indices = self.proposer._prepare_inputs(
return_attn, indices = self.proposer.prepare_inputs(
mock_attn, num_rejected)
self.assertEqual(cu_num_tokens.tolist(), [0, 2, 5])
self.assertEqual(indices.tolist(), [1, 2, 4])

View File

@@ -730,6 +730,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
if get_ascend_device_type() == AscendDeviceType._910_95:
# TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping.
# Should check if the 0 dim of slot_mapping must equal to the 0 dim of key.
# If it's necessary, the slots should be sliced.
torch_npu.npu_scatter_pa_kv_cache(
key=key[:attn_metadata.num_actual_tokens],
value=value[:attn_metadata.num_actual_tokens].contiguous(),
@@ -742,7 +745,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
value=value[:attn_metadata.num_actual_tokens],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
slot_indices=slots[:attn_metadata.num_actual_tokens])
return key, value
def forward_impl(

View File

@@ -119,6 +119,35 @@ class AscendCommonAttentionMetadata:
prefill_context_parallel_metadata: Optional[
AscendPrefillContextParallelMetadata] = None
# TODO: Remove it when vLLM no longer uses this function.
def unpadded(self, num_actual_tokens: int,
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
# This only use to eagle now. It will be use to enforce_eager in future.
return AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_actual_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_actual_reqs + 1],
seq_lens=self.seq_lens[:num_actual_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
num_computed_tokens_cpu=self.
num_computed_tokens_cpu[:num_actual_reqs],
num_reqs=num_actual_reqs,
num_actual_tokens=num_actual_tokens,
max_query_len=self.max_query_len,
decode_token_per_req=self.decode_token_per_req,
block_table_tensor=self.block_table_tensor[:num_actual_reqs],
slot_mapping=self.slot_mapping[:num_actual_tokens],
actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
positions=self.positions[:num_actual_tokens],
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
is_only_prefill=self.is_only_prefill,
graph_pad_size=-1, # It should be -1 when not run in fullgraph mode.
num_input_tokens=num_actual_tokens,
prefill_context_parallel_metadata=self.
prefill_context_parallel_metadata,
)
def filter_chunked_req_indices(
seq_len: torch.Tensor,

View File

@@ -14,19 +14,25 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
PADDING_SLOT_ID = -1
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
class EagleProposer(Proposer):
@@ -54,6 +60,19 @@ class EagleProposer(Proposer):
sorted(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
max_batch_size = vllm_config.scheduler_config.max_num_seqs
# Currently we do not use pcp. This is used to adapt the pcp branch.
self.pcp_size = 0
self.backup_next_token_ids = CpuGpuBuffer(
max_batch_size,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True,
)
self.decode_threshold = 1 + \
self.vllm_config.speculative_config.num_speculative_tokens
# persistent buffers for cuda graph
self.input_ids = torch.zeros(
self.vllm_config.scheduler_config.max_num_batched_tokens,
@@ -71,12 +90,13 @@ class EagleProposer(Proposer):
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
self.token_arange_np = np.arange(self.max_num_tokens)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
1,
max_num_slots_for_arange = max(self.max_num_tokens, max_batch_size + 1)
self.arange = torch.arange(max_num_slots_for_arange,
device=device,
dtype=torch.int32)
self.arange_cpu = torch.arange(max_num_slots_for_arange,
device="cpu",
dtype=torch.int32)
self.attn_mask_builder = AttentionMaskBuilder(self.device)
def load_model(self, model: nn.Module) -> None:
@@ -135,8 +155,7 @@ class EagleProposer(Proposer):
dummy_compute_logits(self.hidden_states)
def generate_token_ids(self,
valid_sampled_token_ids: torch.Tensor
| list[list[int]],
sampled_token_ids: torch.Tensor | list[list[int]],
sampling_metadata: SamplingMetadata = None,
scheduler_output: SchedulerOutput = None,
spec_decode_metadata: SpecDecodeMetadata = None,
@@ -144,273 +163,155 @@ class EagleProposer(Proposer):
num_scheduled_tokens: int = 0,
hidden_states: torch.Tensor = None,
aux_hidden_states: torch.Tensor = None):
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
attn_metadata = self._get_eagle_atten_dict(scheduler_output)
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.runner.input_batch.req_ids[i]
req_state = self.runner.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
eagle_attn_metadata = attn_metadata[self.attn_layer_name]
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids.gpu[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.name == SpecDcodeType.EAGLE3:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = eagle_attn_metadata.slot_mapping
cu_num_tokens = eagle_attn_metadata.query_start_loc
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists.
assert isinstance(sampled_token_ids, list), \
"sampled_token_ids should be a python list when" \
"padded-batch is disabled."
next_token_ids = self.prepare_next_token_ids_cpu(
sampled_token_ids, self.runner.requests,
self.runner.input_batch, scheduler_output.num_scheduled_tokens)
else:
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
cu_num_tokens, token_indices =\
self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens)
target_token_ids = self.runner.input_ids.gpu[token_indices]
target_positions = positions[token_indices]
if self.name == SpecDcodeType.EAGLE3:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
# When using padded-batch, the sampled_token_ids should be
# the gpu tensor of sampled tokens for each request, of shape
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
# value -1.
assert isinstance(sampled_token_ids, torch.Tensor), \
"sampled_token_ids should be a torch.Tensor when" \
"padded-batch is enabled."
next_token_ids, valid_sampled_tokens_count = \
self.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.runner.requests,
self.runner.input_batch,
self.runner.discard_request_indices.gpu,
self.runner.num_discarded_requests
)
self._copy_valid_sampled_token_count(next_token_ids,
valid_sampled_tokens_count)
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.pcp_size > 1:
long_seq_metadata = self.runner.long_seq_metadata
input_ids_pcp_full = self.runner.input_ids_pcp_full
query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu
num_reqs = self.runner.input_batch.num_reqs
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
query_start_loc_pcp_full_cpu[:num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
else:
long_seq_metadata = None
num_prefill_reqs = 0
num_decode_reqs = 0
if spec_decode_metadata is None:
# update pcp related params
if self.pcp_size > 1:
token_indices_to_sample = \
query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states
else:
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids.gpu[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.name == SpecDcodeType.EAGLE3:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
if self.pcp_size > 1:
common_attn_metadata.query_start_loc_cpu = \
query_start_loc_pcp_full_cpu[:num_reqs + 1]
common_attn_metadata.query_start_loc = \
query_start_loc_pcp_full[:num_reqs + 1]
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# NOTE: Currently, MTP-fullgraph is incompatibility with pcp
token_indices_to_sample = None
common_attn_metadata, token_indices =\
self.prepare_inputs(
common_attn_metadata,
sampled_token_ids,
spec_decode_metadata.num_draft_tokens)
else:
common_attn_metadata, token_indices, \
token_indices_to_sample =\
self.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count)
if self.pcp_size > 1:
target_token_ids = input_ids_pcp_full[token_indices]
target_positions = positions
target_hidden_states = hidden_states
else:
target_token_ids = self.runner.input_ids.gpu[token_indices]
target_positions = positions[token_indices]
if self.name == SpecDcodeType.EAGLE3:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
draft_token_ids = self._propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=eagle_attn_metadata.block_tables,
last_token_indices=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
req_scheduled_tokens=req_scheduled_tokens,
long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
scheduler_output=scheduler_output,
num_scheduled_tokens=num_scheduled_tokens,
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
def _get_eagle_atten_dict(
self,
scheduler_output: "SchedulerOutput",
):
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.runner.input_batch.num_reqs
assert num_reqs > 0
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.runner.input_batch.block_table.commit_block_table(num_reqs)
# Get the number of scheduled tokens for each request.
req_ids = self.runner.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
self.runner.query_lens = torch.from_numpy(num_scheduled_tokens)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.runner.arange_np[:num_reqs],
num_scheduled_tokens)
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
# Get positions.
positions_np = self.runner.positions.np[:total_num_scheduled_tokens]
np.add(self.runner.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.runner.uses_mrope:
self.runner._calc_mrope_positions(scheduler_output)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
token_indices = (
positions_np +
req_indices * self.runner.input_batch.token_ids_cpu.shape[1])
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
torch.index_select(
self.runner.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
out=self.runner.input_ids.cpu[:total_num_scheduled_tokens])
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
# NOTE(Chen): there is exactly one KV cache group that contains all
# attetnion layers in the model for now, so the current logic for
# getting attn_metadata is not related to kv_cache_group information.
# Will extend this part to support multiple KV cache groups later.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.runner.kv_cache_config.kv_cache_groups):
block_size = kv_cache_group_spec.kv_cache_spec.block_size
block_table = self.runner.input_batch.block_table[
kv_cache_group_id]
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
block_table_indices = (
req_indices * block_table.max_num_blocks_per_req +
positions_np // block_size)
block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten(
)[block_table_indices].numpy()
block_offsets = positions_np % block_size
np.add(
block_numbers * block_size,
block_offsets,
out=block_table.slot_mapping.np[:total_num_scheduled_tokens])
# Prepare the attention metadata.
self.runner.query_start_loc.np[0] = 0
self.runner.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens
self.runner.seq_lens.np[:num_reqs] = (
self.runner.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
# Copy the tensors to the NPU.
self.runner.input_ids.gpu[:total_num_scheduled_tokens].copy_(
self.runner.input_ids.cpu[:total_num_scheduled_tokens],
non_blocking=True)
if self.runner.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.runner.mrope_positions.gpu[:, :total_num_scheduled_tokens] \
.copy_(
self.runner.
mrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True)
else:
# Common case (1D positions)
self.runner.positions.gpu[:total_num_scheduled_tokens].copy_(
self.runner.positions.cpu[:total_num_scheduled_tokens],
non_blocking=True)
self.runner.query_start_loc.gpu[:num_reqs + 1].copy_(
self.runner.query_start_loc.cpu[:num_reqs + 1], non_blocking=True)
self.runner.seq_lens.gpu[:num_reqs].copy_(
self.runner.seq_lens.cpu[:num_reqs], non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache
self.runner.seq_lens.gpu[num_reqs:].fill_(0)
self.runner.query_start_loc.gpu[num_reqs + 1:].fill_(-1)
attn_metadata = {}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.runner.kv_cache_config.kv_cache_groups):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc.gpu[:num_reqs + 1],
query_start_loc_cpu=self.runner.query_start_loc.cpu[:num_reqs +
1],
seq_lens_cpu=self.runner.seq_lens.cpu,
num_reqs=num_reqs,
max_query_len=max_num_scheduled_tokens,
num_actual_tokens=total_num_scheduled_tokens,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(),
slot_mapping=self.runner.input_batch.block_table[0].
slot_mapping.gpu,
positions=self.runner.positions.gpu,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
num_computed_tokens_cpu=None,
seq_lens=None)
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata_i = builder.build(0, common_attn_metadata,
self.runner.get_model())
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
return attn_metadata
def _get_cumsum_and_arange(
self,
num_tokens: np.ndarray,
cumsum_dtype: Optional[np.dtype] = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Get the cumulative sum and batched arange of the given array.
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
# Equivalent to but faster than:
# np.concatenate([np.arange(n) for n in num_tokens])
"""
# Step 1. [2, 5, 3] -> [2, 7, 10]
cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
total_num_tokens = cu_num_tokens[-1]
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = self.runner.arange_np[:total_num_tokens] - cumsums_offsets
return cu_num_tokens, arange
return draft_token_ids
def _propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
# [num_tokens] or [3, num_tokens] when M-RoPE is enabled
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
last_token_indices: Optional[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embed_inputs: Optional[tuple[list[torch.Tensor],
torch.Tensor]] = None,
req_scheduled_tokens=None,
long_seq_metadata=None,
num_prefill_reqs=0,
num_decode_reqs=0,
scheduler_output: SchedulerOutput = None,
num_scheduled_tokens: int = 0,
) -> torch.Tensor:
device = cu_num_tokens.device
cu_num_tokens = cu_num_tokens.cpu()
block_table = block_table.cpu()
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
target_positions = target_positions.cpu()
if last_token_indices is None:
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
if self.name == SpecDcodeType.EAGLE3:
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
@@ -423,34 +324,7 @@ class EagleProposer(Proposer):
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
seq_lens = (target_positions[last_token_indices] + 1).int()
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()
attn_mask = self.runner.attn_mask
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=cu_num_tokens.to(device),
query_start_loc_cpu=cu_num_tokens,
seq_lens_cpu=seq_lens.cpu(),
max_query_len=max_query_len,
num_reqs=batch_size,
num_actual_tokens=num_tokens,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(),
slot_mapping=target_slot_mapping,
positions=target_positions,
attn_mask=attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
num_computed_tokens_cpu=None,
seq_lens=None)
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = builder.build(0, common_attn_metadata,
self.runner.get_model())
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -458,9 +332,14 @@ class EagleProposer(Proposer):
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions.to(device)
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
attn_metadata.block_tables = block_table.to(device)
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = builder.build(0, common_attn_metadata,
self.runner.get_model())
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
@@ -482,12 +361,14 @@ class EagleProposer(Proposer):
draft_token_ids_tensor = torch.zeros(
(self.vllm_config.speculative_config.num_speculative_tokens,
*draft_token_ids.shape),
dtype=draft_token_ids.dtype)
dtype=draft_token_ids.dtype,
device=self.device)
draft_token_ids_tensor[0] = draft_token_ids
positions_cpu = target_positions[last_token_indices].cpu().to(
torch.int64)
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
last_token_indices = self.arange[:batch_size]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
@@ -496,16 +377,14 @@ class EagleProposer(Proposer):
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
attn_metadata.query_start_loc = self.arange_cpu[:batch_size + 1]
attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
1:].tolist()
attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
query_lens.fill_(1)
attn_metadata.query_lens = query_lens
attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)]
attn_metadata.seq_lens_list = seq_lens.tolist()
attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist()
attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
for now_speculative in range(
self.vllm_config.speculative_config.num_speculative_tokens -
@@ -513,8 +392,8 @@ class EagleProposer(Proposer):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_tensor[now_speculative].to(device)
positions_cpu += 1
input_ids = draft_token_ids_tensor[now_speculative]
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
@@ -522,16 +401,15 @@ class EagleProposer(Proposer):
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions_cpu >= self.vllm_config.model_config.max_model_len
exceeds_max_model_len = positions >= self.vllm_config.model_config.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions_cpu = torch.where(exceeds_max_model_len, 0,
positions_cpu)
clamped_positions = clamped_positions_cpu.to(device)
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# TODO: Increment the sequence lengths.
attn_metadata.seq_lens += 1
attn_metadata.seq_lens = attn_metadata.seq_lens + 1
attn_metadata.seq_lens_list = [
_ + 1 for _ in attn_metadata.seq_lens_list
]
@@ -542,22 +420,22 @@ class EagleProposer(Proposer):
# TODO: sequence length to 1 to minimize their overheads in attention.
# Compute the slot mapping.
block_numbers = (clamped_positions_cpu // self.block_size)
block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1))
block_numbers = (clamped_positions // self.block_size)
block_ids = attn_metadata.block_tables.gather(
dim=1, index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
slot_mapping_cpu = (
slot_mapping_tmp = (
block_ids * self.vllm_config.cache_config.block_size +
clamped_positions_cpu % self.block_size)
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
slot_mapping_cpu.masked_fill_(exceeds_max_model_len,
slot_mapping_tmp.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID)
# NOTE: ASCEND slot_mapping must on cpu
attn_metadata.slot_mapping = slot_mapping_cpu.to(
torch.int32).to(device)
attn_metadata.slot_mapping[:slot_mapping_tmp.shape[0]].copy_(
slot_mapping_tmp.to(torch.int32))
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
@@ -565,7 +443,6 @@ class EagleProposer(Proposer):
attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
attn_metadata.attn_mask = attn_mask
attn_metadata.block_tables = block_table.to(device)
# Run the model.
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
@@ -581,49 +458,188 @@ class EagleProposer(Proposer):
# TODO(wenlong): get more than one token for tree attention
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_tensor[now_speculative + 1] = draft_token_ids.cpu()
draft_token_ids_tensor[now_speculative + 1] = draft_token_ids
# [batch_size, num_speculative_tokens]
draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)
return draft_token_ids
def _prepare_inputs(
def _get_attn_metadata(self, attn_metadata):
if attn_metadata is not None and isinstance(attn_metadata, dict):
architecture = self.vllm_config.model_config.architecture
layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
attn_metadata = attn_metadata[layer_name]
return attn_metadata
def prepare_next_token_ids_cpu(
self,
eagle_attn_metadata: AscendMetadata,
# [batch_size]
num_rejected_tokens: torch.Tensor,
sampled_token_ids: list[list[int]],
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
num_scheduled_tokens: dict[str, int],
) -> torch.Tensor:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids for each request based on the sampled
token ids from the CPU. If a request has no sampled token ids (e.g.,
during the initial decoding steps), it falls back to using the request
state to get the next token id.
"""
req_ids = gpu_input_batch.req_ids
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = req_ids[i]
req_state = requests[req_id]
seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
req_id]
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.input_ids.device)
return next_token_ids
def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_indices: torch.Tensor,
num_discarded_requests: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for the spec decode.
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids and the number of valid sampled tokens
for each request, considering the "discarded" requests whose next token
is not sampled and comes from `request.get_token_id()` instead.
It also accounts for the rejected tokens in `sampled_token_ids`.
This function must use device functions to operate on the inputs, and
should not introduce any blocking CPU-GPU synchronization.
"""
# TODO(Ben): Combine this into a custom fused kernel
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
self.backup_next_token_ids.np[:num_reqs] = np.array([
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item())
for i in range(num_reqs)
])
self.backup_next_token_ids.copy_to_gpu(num_reqs)
# Mask out the sampled tokens indices that should not be sampled.
discard_sampled_tokens_req_indices = discard_request_indices[:
num_discarded_requests]
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
valid_sampled_token_ids_gpu.index_fill_(
0, discard_sampled_tokens_req_indices, -1)
# Generate a mask for all valid tokens within those requests
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)
# Count the number of valid tokens in each request
valid_sampled_tokens_count = valid_mask.sum(dim=1)
# Get the rightmost valid index per row
last_valid_indices = valid_sampled_tokens_count - 1
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
# Get last valid token from each row
# (assume undefined state where there is no valid token)
selected_tokens = torch.gather(
valid_sampled_token_ids_gpu, 1,
last_valid_indices_safe.unsqueeze(1)).squeeze(1)
# Use last token if valid, pre-computed backup if not
batch_size = valid_sampled_token_ids_gpu.shape[0]
next_token_ids = torch.where(
last_valid_indices != -1,
selected_tokens,
self.backup_next_token_ids.gpu[:batch_size],
)
return next_token_ids, valid_sampled_tokens_count
def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor,
valid_sampled_tokens_count: torch.Tensor) -> None:
if self.runner.valid_sampled_token_count_event is not None:
default_stream = torch.npu.current_stream()
# initialize a new stream to overlap the copy operation with
# prepare_input of draft model.
with torch.npu.stream(
self.runner.valid_sampled_token_count_copy_stream):
self.runner.valid_sampled_token_count_copy_stream.wait_stream(
default_stream) # type: ignore
self.runner.valid_sampled_token_count_cpu[:
valid_sampled_tokens_count
.shape[0]].copy_(
valid_sampled_tokens_count,
non_blocking=True
)
self.runner.valid_sampled_token_count_event.record()
self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
1)
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: list[list[int]],
num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It updates to the common_attn_metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
"""
# E.g.
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1, q1 + q2, q1 + q2 + q3]
# [0, q1, q1 + q2, q1 + q2 + q3]
# common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
# num_rejected_tokens: [n1, n2, n3]
# This function computes the intermediate values:
# num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
# And returns:
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
# common_attn_metadata.seq_lens{_cpu}:
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
# token_indices: [0, 1, ..., q1 - n1 - 1,
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
num_rejected_tokens_cpu = num_rejected_tokens.to("cpu")
cu_target_query_lens = eagle_attn_metadata.query_start_loc
device = eagle_attn_metadata.query_start_loc.device
query_start_loc_cpu = cu_target_query_lens.to("cpu")
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
num_actual_reqs = len(num_draft_tokens)
num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(num_rejected_tokens,
dtype=torch.int32)
device = common_attn_metadata.query_start_loc.device
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_actual_reqs
+ 1]
seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_actual_reqs]
new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
new_query_len_per_req = (query_start_loc_cpu[1:] -
query_start_loc_cpu[:-1])
new_query_len_per_req = query_start_loc_cpu[
1:] - query_start_loc_cpu[:-1]
# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens_cpu
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
# [q1 - n1, q2 - n2, q3 - n3] ->
@@ -631,7 +647,8 @@ class EagleProposer(Proposer):
new_query_start_loc_cpu = torch.zeros(
query_start_loc_cpu.shape,
dtype=torch.int32,
pin_memory=is_pin_memory_available())
pin_memory=is_pin_memory_available(),
)
new_query_start_loc_np = new_query_start_loc_cpu.numpy()
np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
@@ -646,8 +663,8 @@ class EagleProposer(Proposer):
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
# _r1_ ____r2____ ___r3__
token_offests = self.token_arange_np[:total_num_tokens] \
- new_query_start_locs_expanded
token_offests = (self.token_arange_np[:total_num_tokens] -
new_query_start_locs_expanded)
# Expand starting positions to match token pattern
# [0, q1, q1 + q2] ->
@@ -656,21 +673,101 @@ class EagleProposer(Proposer):
old_query_start_locs_expanded = np.repeat(
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
# Final token indices are:
# [0, 1, // req 1
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
# [0, 1, // req 1
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
token_indices_np = token_offests + old_query_start_locs_expanded
token_indices = torch.from_numpy(token_indices_np).to(
device, non_blocking=True)
# need use npu
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
num_tokens_per_req = query_len_per_req - num_rejected_tokens
common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_(
common_attn_metadata.slot_mapping[token_indices])
common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1)
# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward
# so we do not need to fixed them. But if they are used in the future,
# we should fixed them.
spec_common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=new_query_start_loc_cpu.to(device,
non_blocking=True),
query_start_loc_cpu=new_query_start_loc_cpu,
seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
seq_lens_cpu=new_seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata.
num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
num_input_tokens=common_attn_metadata.num_input_tokens,
max_query_len=new_query_len_per_req.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
positions=common_attn_metadata.positions[token_indices],
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
)
return spec_common_attn_metadata, token_indices
return cu_num_tokens, token_indices
def prepare_inputs_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
spec_decode_metadata: SpecDecodeMetadata,
valid_sampled_tokens_count: torch.Tensor,
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding
It updates the common_attn_metadata for speculative decoding,
but does not consider the rejected tokens. Instead, all tokens
are included as inputs to the speculator, with the rejected tokens
used as padding and filtered out later by `token_indices_to_sample`.
No blocking CPU operations should be introduced in this function.
"""
num_draft_tokens_gpu = torch.cat([
spec_decode_metadata.cu_num_draft_tokens[0:1],
spec_decode_metadata.cu_num_draft_tokens[1:] -
spec_decode_metadata.cu_num_draft_tokens[:-1],
])
num_rejected_tokens_gpu = torch.where(
num_draft_tokens_gpu > 0,
num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
torch.zeros_like(num_draft_tokens_gpu),
)
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
new_query_len_per_req = query_start_loc_cpu[
1:] - query_start_loc_cpu[:-1]
total_num_tokens = query_start_loc_cpu[-1].item()
token_indices = self.arange[:total_num_tokens]
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward
# so we do not need to fixed them. But if they are used in the future,
# we should fixed them.
spec_common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=common_attn_metadata.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
num_input_tokens=common_attn_metadata.num_input_tokens,
max_query_len=new_query_len_per_req.max().item(),
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
positions=common_attn_metadata.positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
num_computed_tokens_cpu=common_attn_metadata.
num_computed_tokens_cpu,
seq_lens=common_attn_metadata.seq_lens)
token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] -
1 - num_rejected_tokens_gpu)
return spec_common_attn_metadata, token_indices, token_indices_to_sample

View File

@@ -801,7 +801,8 @@ class NPUModelRunner(GPUModelRunner):
self.requests[r].num_tokens for r in self.input_batch.req_ids
]
num_tokens_np = np.array(num_tokens, dtype=np.int32)
num_reqs = self.input_batch.num_reqs
base_num_reqs = self.input_batch.num_reqs
num_reqs = base_num_reqs
if self.pcp_size > 1:
# while pcp > 1, we need the original num_scheduled_tokens before split
# to calculate discard_requests_mask
@@ -1106,6 +1107,11 @@ class NPUModelRunner(GPUModelRunner):
if self.speculative_config and \
self.spec_decode_common_attn_metadata is None:
self.spec_decode_common_attn_metadata = common_attn_metadata
if self.speculative_config.method in ("eagle", "eagle3") and \
self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.spec_decode_common_attn_metadata = \
self.spec_decode_common_attn_metadata.unpadded(
total_num_scheduled_tokens, base_num_reqs)
for attn_group in self.attn_groups[kv_cache_group_id]:
common_prefix_len = 0
@@ -1591,7 +1597,7 @@ class NPUModelRunner(GPUModelRunner):
with ProfileExecuteDuration().capture_async("Draft"):
if self.speculative_config:
use_padded_batch_for_eagle = self.speculative_config and \
self.speculative_config.method == "mtp" and \
self.speculative_config.use_eagle() and \
not self.speculative_config.disable_padded_drafter_batch
if use_padded_batch_for_eagle:
# EAGLE speculative decoding can use the GPU sampled tokens