[Feat] Support async_scheduler and disable_padded_drafter_batch in eagle (#4893)
### What this PR does / why we need it?
We refactored eagle_proposer.py to follow the eagle.py framework in vLLM v0.12.0,
adding support for padded drafter batch handling and the async scheduler.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
Co-authored-by: drslark <slarksblood@qq.com>
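
For context, a minimal sketch of how the two new options are exercised by the e2e test added in this PR; the target and drafter model names below are placeholders, and the remaining values mirror the test in this diff.

```python
from vllm import LLM, SamplingParams

# Sketch based on the e2e test in this PR; model names are placeholders.
llm = LLM(
    model="<target-model>",                 # e.g. a LLaMA-3.1 style checkpoint
    max_model_len=4096,
    async_scheduling=True,                  # new: works with the async scheduler
    speculative_config={
        "method": "eagle3",                 # or "eagle"
        "model": "<eagle3-drafter-model>",
        "num_speculative_tokens": 2,
        # False keeps the padded drafter batch path added by this PR;
        # True falls back to the CPU-side (unpadded) preparation path.
        "disable_padded_drafter_batch": False,
    },
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=32, temperature=0.0),
)
print(outputs[0].outputs[0].text)
```

And a small, purely illustrative NumPy sketch of the rejected-token accounting that the unpadded `prepare_inputs` path performs (the real code operates on torch tensors and pinned CPU buffers):

```python
import numpy as np

# Illustrative only: drop rejected draft tokens per request and rebuild
# query_start_loc plus the token indices fed to the drafter.
query_lens          = np.array([3, 5, 4])   # q1, q2, q3 scheduled tokens per request
num_rejected_tokens = np.array([1, 0, 2])   # n1, n2, n3 rejected draft tokens per request

new_num_tokens = query_lens - num_rejected_tokens            # [2, 5, 2]
new_query_start_loc = np.zeros(len(query_lens) + 1, dtype=np.int64)
np.cumsum(new_num_tokens, out=new_query_start_loc[1:])       # [0, 2, 7, 9]

old_query_start_loc = np.zeros_like(new_query_start_loc)
np.cumsum(query_lens, out=old_query_start_loc[1:])           # [0, 3, 8, 12]
# Keep the first (q_i - n_i) tokens of each request.
token_indices = np.concatenate([
    old_query_start_loc[i] + np.arange(new_num_tokens[i])
    for i in range(len(query_lens))
])                                                            # [0, 1, 3, 4, 5, 6, 7, 8, 9]
print(new_query_start_loc, token_indices)
```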
@@ -7,9 +7,10 @@ import random
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
from tests.e2e.conftest import VllmRunner, cleanup_dist_env_and_memory
|
||||
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
@@ -115,41 +116,67 @@ def test_eagle_correctness(
|
||||
Compare the outputs of an original LLM and a speculative LLM;
they should be the same when using eagle speculative decoding.
|
||||
'''
|
||||
pytest.skip("To be aligned with GPU")
|
||||
ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
|
||||
# NOTE: The e2e test of eagle had many problems before.
# We first check whether it is functioning properly.
# Should fix the e2e with VllmRunner in the future.
|
||||
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
|
||||
with VllmRunner(
|
||||
model_name,
|
||||
max_num_seqs=1,
|
||||
max_num_batched_tokens=2048,
|
||||
gpu_memory_utilization=0.6,
|
||||
speculative_config={
|
||||
"method": "eagle3" if use_eagle3 else "eagle",
|
||||
"model": spec_model_name,
|
||||
"num_speculative_tokens": 2,
|
||||
"max_model_len": 128,
|
||||
},
|
||||
max_model_len=128,
|
||||
enforce_eager=False,
|
||||
) as runner:
|
||||
spec_outputs = runner.model.chat(test_prompts, sampling_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
prompts = [{
|
||||
"role": "user",
|
||||
"content": "Hello, my name is"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "The president of the United States is"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "The capital of France is"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "The future of AI is"
|
||||
}]
|
||||
prompts = [
|
||||
tokenizer.apply_chat_template(
|
||||
[prompt],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
) for prompt in prompts
|
||||
]
|
||||
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=300,
|
||||
temperature=0.0,
|
||||
ignore_eos=False,
|
||||
)
|
||||
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.66 * len(ref_outputs))
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tensor_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
data_parallel_size=1,
|
||||
disable_log_stats=False,
|
||||
max_model_len=4096,
|
||||
seed=1024,
|
||||
async_scheduling=True,
|
||||
compilation_config={
|
||||
"level": 3,
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
"cudagraph_num_of_warmups": 1,
|
||||
"cudagraph_capture_sizes": [12],
|
||||
},
|
||||
speculative_config={
|
||||
"disable_padded_drafter_batch": False,
|
||||
"method": "eagle3" if use_eagle3 else "eagle",
|
||||
"model": spec_model_name,
|
||||
"num_speculative_tokens": 2,
|
||||
"max_model_len": 128,
|
||||
"draft_vocab_size": 128256,
|
||||
},
|
||||
)
|
||||
llm.generate(prompts, sampling_params)
|
||||
cleanup_dist_env_and_memory()
|
||||
del llm
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
|
||||
@@ -26,6 +26,13 @@ class TestEagleProposerInitialization(TestBase):
|
||||
self.vllm_config.model_config.dtype = torch.float16
|
||||
self.vllm_config.model_config.max_model_len = 2048
|
||||
|
||||
self.mock_cpugpubuffer = patch(
|
||||
"vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
|
||||
self.mock_cpugpubuffer.start()
|
||||
|
||||
def tearDown(self):
|
||||
self.mock_cpugpubuffer.stop()
|
||||
|
||||
def test_initialization_eagle(self):
|
||||
self.vllm_config.speculative_config.method = "eagle"
|
||||
self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
|
||||
@@ -44,7 +51,7 @@ class TestEagleProposerInitialization(TestBase):
|
||||
self.assertEqual(proposer.input_ids.shape, (1024, ))
|
||||
self.assertEqual(proposer.positions.shape, (1024, ))
|
||||
self.assertEqual(proposer.hidden_states.shape, (1024, 4096))
|
||||
self.assertEqual(proposer.arange.shape, (33, ))
|
||||
self.assertEqual(proposer.arange.shape, (1024, ))
|
||||
|
||||
def test_initialization_eagle3(self):
|
||||
self.vllm_config.speculative_config.method = "eagle3"
|
||||
@@ -77,10 +84,16 @@ class TestEagleProposerLoadModel(TestBase):
|
||||
self.vllm_config.model_config.dtype = torch.float16
|
||||
self.vllm_config.model_config.max_model_len = 2048
|
||||
|
||||
self.mock_cpugpubuffer = patch(
|
||||
"vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
|
||||
self.mock_cpugpubuffer.start()
|
||||
self.proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
def tearDown(self):
|
||||
self.mock_cpugpubuffer.stop()
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
|
||||
@@ -172,11 +185,17 @@ class TestEagleProposerDummyRun(TestBase):
|
||||
self.vllm_config.model_config.dtype = torch.float16
|
||||
self.vllm_config.model_config.max_model_len = 2048
|
||||
|
||||
self.mock_cpugpubuffer = patch(
|
||||
"vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
|
||||
self.mock_cpugpubuffer.start()
|
||||
self.proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
self.proposer.model = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
self.mock_cpugpubuffer.stop()
|
||||
|
||||
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
|
||||
def test_dummy_run_basic(self, mock_context):
|
||||
num_tokens = 32
|
||||
@@ -216,6 +235,9 @@ class TestEagleProposerGenerateTokenIds(TestBase):
|
||||
self.vllm_config.model_config.dtype = torch.float16
|
||||
self.vllm_config.model_config.max_model_len = 2048
|
||||
|
||||
self.mock_cpugpubuffer = patch(
|
||||
"vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
|
||||
self.mock_cpugpubuffer.start()
|
||||
self.proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
@@ -223,7 +245,12 @@ class TestEagleProposerGenerateTokenIds(TestBase):
|
||||
self.proposer._propose = MagicMock(
|
||||
return_value=torch.tensor([[1, 2], [3, 4], [5, 6]]))
|
||||
|
||||
def test_generate_token_ids_without_metadata(self):
|
||||
def tearDown(self):
|
||||
self.mock_cpugpubuffer.stop()
|
||||
|
||||
# TODO: This is equivalent to disable_padded_drafter_batch=True.
# We need to add cases for disable_padded_drafter_batch=False in the future.
|
||||
def test_generate_token_ids(self):
|
||||
valid_sampled = [[20, 30, 40]]
|
||||
scheduler_output = MagicMock()
|
||||
scheduler_output.num_scheduled_tokens = [2, 1, 3]
|
||||
@@ -239,7 +266,7 @@ class TestEagleProposerGenerateTokenIds(TestBase):
|
||||
return_value={"layer_0": mock_attn_metadata})
|
||||
|
||||
result = self.proposer.generate_token_ids(
|
||||
valid_sampled_token_ids=valid_sampled,
|
||||
sampled_token_ids=valid_sampled,
|
||||
scheduler_output=scheduler_output,
|
||||
positions=positions,
|
||||
num_scheduled_tokens=num_scheduled,
|
||||
@@ -247,36 +274,13 @@ class TestEagleProposerGenerateTokenIds(TestBase):
|
||||
)
|
||||
|
||||
self.proposer._propose.assert_called_once()
|
||||
self.assertEqual(result, [[1, 2], [3, 4], [5, 6]])
|
||||
|
||||
def test_generate_token_ids_with_metadata(self):
|
||||
valid_sampled = [[5], [6, 7], [8, 9, 10]]
|
||||
spec_metadata = MagicMock()
|
||||
spec_metadata.num_draft_tokens = [2, 3, 4]
|
||||
|
||||
mock_attn_metadata = MagicMock()
|
||||
mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
|
||||
mock_attn_metadata.query_start_loc = torch.tensor([0, 1, 3, 6])
|
||||
mock_attn_metadata.block_tables = MagicMock()
|
||||
self.proposer._get_eagle_atten_dict = MagicMock(
|
||||
return_value={"layer_0": mock_attn_metadata})
|
||||
self.proposer._prepare_inputs = MagicMock(
|
||||
return_value=(torch.tensor([0, 2, 5]), torch.tensor([1, 3, 5])))
|
||||
|
||||
result = self.proposer.generate_token_ids(
|
||||
valid_sampled_token_ids=valid_sampled,
|
||||
spec_decode_metadata=spec_metadata,
|
||||
positions=torch.randn(6, 1),
|
||||
hidden_states=torch.randn(6, 4096),
|
||||
)
|
||||
|
||||
self.proposer._prepare_inputs.assert_called_once()
|
||||
self.assertEqual(self.proposer._propose.call_count, 1)
|
||||
self.assertEqual(len(result), 3)
|
||||
self.assertEqual(result.numpy().tolist(), [[1, 2], [3, 4], [5, 6]])
|
||||
|
||||
|
||||
class TestEagleProposerHelperMethods(TestBase):
|
||||
|
||||
# TODO: Add some tests for prepare_next_token_ids in the future.
|
||||
|
||||
def setUp(self):
|
||||
self.vllm_config = MagicMock(spec=VllmConfig)
|
||||
self.vllm_config.scheduler_config = MagicMock(max_num_seqs=3)
|
||||
@@ -293,21 +297,29 @@ class TestEagleProposerHelperMethods(TestBase):
|
||||
self.vllm_config.model_config.dtype = torch.float16
|
||||
self.vllm_config.model_config.max_model_len = 2048
|
||||
|
||||
self.mock_cpugpubuffer = patch(
|
||||
"vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
|
||||
self.mock_cpugpubuffer.start()
|
||||
self.proposer = EagleProposer(vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
runner=self.runner)
|
||||
|
||||
def tearDown(self):
|
||||
self.mock_cpugpubuffer.stop()
|
||||
|
||||
# TODO: This is equivalent to disable_padded_drafter_batch=True.
# We need to add a test_prepare_inputs_padded in the future.
|
||||
def test_prepare_inputs(self):
|
||||
self.proposer.token_arange_np = np.arange(10)
|
||||
mock_attn = MagicMock()
|
||||
mock_attn.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
|
||||
num_rejected = torch.tensor([1, 0, 1], device=self.device)
|
||||
mock_return_attn = MagicMock()
|
||||
|
||||
with patch.object(self.proposer,
|
||||
'_prepare_inputs',
|
||||
return_value=(torch.tensor([0, 2, 5]),
|
||||
'prepare_inputs',
|
||||
return_value=(mock_return_attn,
|
||||
torch.tensor([1, 2, 4]))):
|
||||
cu_num_tokens, indices = self.proposer._prepare_inputs(
|
||||
return_attn, indices = self.proposer.prepare_inputs(
|
||||
mock_attn, num_rejected)
|
||||
self.assertEqual(cu_num_tokens.tolist(), [0, 2, 5])
|
||||
self.assertEqual(indices.tolist(), [1, 2, 4])
|
||||
|
||||
@@ -730,6 +730,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
if get_ascend_device_type() == AscendDeviceType._910_95:
|
||||
# TODO: Once eagle reaches this path, it may fail because of the 0th dim of slot_mapping.
# Should check whether the 0th dim of slot_mapping must equal the 0th dim of key.
# If so, the slots should be sliced.
|
||||
torch_npu.npu_scatter_pa_kv_cache(
|
||||
key=key[:attn_metadata.num_actual_tokens],
|
||||
value=value[:attn_metadata.num_actual_tokens].contiguous(),
|
||||
@@ -742,7 +745,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
value=value[:attn_metadata.num_actual_tokens],
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=slots)
|
||||
slot_indices=slots[:attn_metadata.num_actual_tokens])
|
||||
return key, value
|
||||
|
||||
def forward_impl(
|
||||
|
||||
@@ -119,6 +119,35 @@ class AscendCommonAttentionMetadata:
|
||||
prefill_context_parallel_metadata: Optional[
|
||||
AscendPrefillContextParallelMetadata] = None
|
||||
|
||||
# TODO: Remove it when vLLM no longer uses this function.
|
||||
def unpadded(self, num_actual_tokens: int,
|
||||
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
|
||||
# This is only used by eagle now. It will also be used by enforce_eager in the future.
|
||||
return AscendCommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc[:num_actual_reqs + 1],
|
||||
query_start_loc_cpu=self.query_start_loc_cpu[:num_actual_reqs + 1],
|
||||
seq_lens=self.seq_lens[:num_actual_reqs],
|
||||
seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
|
||||
num_computed_tokens_cpu=self.
|
||||
num_computed_tokens_cpu[:num_actual_reqs],
|
||||
num_reqs=num_actual_reqs,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=self.max_query_len,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
block_table_tensor=self.block_table_tensor[:num_actual_reqs],
|
||||
slot_mapping=self.slot_mapping[:num_actual_tokens],
|
||||
actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
|
||||
positions=self.positions[:num_actual_tokens],
|
||||
attn_mask=self.attn_mask,
|
||||
spec_attn_mask=self.spec_attn_mask,
|
||||
attn_state=self.attn_state,
|
||||
is_only_prefill=self.is_only_prefill,
|
||||
graph_pad_size=-1, # It should be -1 when not running in fullgraph mode.
|
||||
num_input_tokens=num_actual_tokens,
|
||||
prefill_context_parallel_metadata=self.
|
||||
prefill_context_parallel_metadata,
|
||||
)
|
||||
|
||||
|
||||
def filter_chunked_req_indices(
|
||||
seq_len: torch.Tensor,
|
||||
|
||||
@@ -14,19 +14,25 @@ from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import supports_multimodal
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
|
||||
|
||||
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
|
||||
|
||||
|
||||
class EagleProposer(Proposer):
|
||||
|
||||
@@ -54,6 +60,19 @@ class EagleProposer(Proposer):
|
||||
sorted(
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes))
|
||||
|
||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
# Currently we do not use pcp. This is kept only to adapt to the pcp code path.
|
||||
self.pcp_size = 0
|
||||
self.backup_next_token_ids = CpuGpuBuffer(
|
||||
max_batch_size,
|
||||
dtype=torch.int32,
|
||||
pin_memory=is_pin_memory_available(),
|
||||
device=device,
|
||||
with_numpy=True,
|
||||
)
|
||||
self.decode_threshold = 1 + \
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
|
||||
# persistent buffers for cuda graph
|
||||
self.input_ids = torch.zeros(
|
||||
self.vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
@@ -71,12 +90,13 @@ class EagleProposer(Proposer):
|
||||
self.max_num_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens)
|
||||
self.token_arange_np = np.arange(self.max_num_tokens)
|
||||
# We need +1 here because the arange is used to set query_start_loc,
|
||||
# which has one more element than batch_size.
|
||||
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
|
||||
1,
|
||||
max_num_slots_for_arange = max(self.max_num_tokens, max_batch_size + 1)
|
||||
self.arange = torch.arange(max_num_slots_for_arange,
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
self.arange_cpu = torch.arange(max_num_slots_for_arange,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||
|
||||
def load_model(self, model: nn.Module) -> None:
|
||||
@@ -135,8 +155,7 @@ class EagleProposer(Proposer):
|
||||
dummy_compute_logits(self.hidden_states)
|
||||
|
||||
def generate_token_ids(self,
|
||||
valid_sampled_token_ids: torch.Tensor
|
||||
| list[list[int]],
|
||||
sampled_token_ids: torch.Tensor | list[list[int]],
|
||||
sampling_metadata: SamplingMetadata = None,
|
||||
scheduler_output: SchedulerOutput = None,
|
||||
spec_decode_metadata: SpecDecodeMetadata = None,
|
||||
@@ -144,273 +163,155 @@ class EagleProposer(Proposer):
|
||||
num_scheduled_tokens: int = 0,
|
||||
hidden_states: torch.Tensor = None,
|
||||
aux_hidden_states: torch.Tensor = None):
|
||||
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
|
||||
|
||||
attn_metadata = self._get_eagle_atten_dict(scheduler_output)
|
||||
next_token_ids: list[int] = []
|
||||
for i, token_ids in enumerate(valid_sampled_token_ids):
|
||||
if token_ids:
|
||||
# Common case.
|
||||
next_token_id = token_ids[-1]
|
||||
else:
|
||||
# Partial prefill (rare case).
|
||||
# Get the next token id from the request state.
|
||||
req_id = self.runner.input_batch.req_ids[i]
|
||||
req_state = self.runner.requests[req_id]
|
||||
seq_len = (req_state.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
next_token_id = req_state.get_token_id(seq_len)
|
||||
next_token_ids.append(next_token_id)
|
||||
next_token_ids = torch.tensor(next_token_ids,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
eagle_attn_metadata = attn_metadata[self.attn_layer_name]
|
||||
if spec_decode_metadata is None:
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.runner.input_ids.gpu[:num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
if self.name == SpecDcodeType.EAGLE3:
|
||||
target_hidden_states = torch.cat(
|
||||
[h[:num_scheduled_tokens] for h in aux_hidden_states],
|
||||
dim=-1)
|
||||
else:
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
target_slot_mapping = eagle_attn_metadata.slot_mapping
|
||||
cu_num_tokens = eagle_attn_metadata.query_start_loc
|
||||
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
|
||||
# When padded-batch is disabled, the sampled_token_ids should be
|
||||
# the cpu-side list[list[int]] of valid sampled tokens for each
|
||||
# request, with invalid requests having empty lists.
|
||||
assert isinstance(sampled_token_ids, list), \
|
||||
"sampled_token_ids should be a python list when" \
|
||||
"padded-batch is disabled."
|
||||
next_token_ids = self.prepare_next_token_ids_cpu(
|
||||
sampled_token_ids, self.runner.requests,
|
||||
self.runner.input_batch, scheduler_output.num_scheduled_tokens)
|
||||
else:
|
||||
num_draft_tokens = spec_decode_metadata.num_draft_tokens
|
||||
num_rejected_tokens = [
|
||||
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
|
||||
for i, n in enumerate(num_draft_tokens)
|
||||
]
|
||||
num_rejected_tokens = torch.tensor(
|
||||
num_rejected_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
cu_num_tokens, token_indices =\
|
||||
self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens)
|
||||
target_token_ids = self.runner.input_ids.gpu[token_indices]
|
||||
target_positions = positions[token_indices]
|
||||
if self.name == SpecDcodeType.EAGLE3:
|
||||
target_hidden_states = torch.cat(
|
||||
[h[token_indices] for h in aux_hidden_states], dim=-1)
|
||||
# When using padded-batch, the sampled_token_ids should be
|
||||
# the gpu tensor of sampled tokens for each request, of shape
|
||||
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
|
||||
# value -1.
|
||||
assert isinstance(sampled_token_ids, torch.Tensor), \
|
||||
"sampled_token_ids should be a torch.Tensor when" \
|
||||
"padded-batch is enabled."
|
||||
next_token_ids, valid_sampled_tokens_count = \
|
||||
self.prepare_next_token_ids_padded(
|
||||
common_attn_metadata,
|
||||
sampled_token_ids,
|
||||
self.runner.requests,
|
||||
self.runner.input_batch,
|
||||
self.runner.discard_request_indices.gpu,
|
||||
self.runner.num_discarded_requests
|
||||
)
|
||||
self._copy_valid_sampled_token_count(next_token_ids,
|
||||
valid_sampled_tokens_count)
|
||||
|
||||
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
if self.pcp_size > 1:
|
||||
long_seq_metadata = self.runner.long_seq_metadata
|
||||
input_ids_pcp_full = self.runner.input_ids_pcp_full
|
||||
query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
|
||||
query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu
|
||||
num_reqs = self.runner.input_batch.num_reqs
|
||||
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
||||
query_start_loc_pcp_full_cpu[:num_reqs]
|
||||
num_prefill_reqs = (ori_query_lens
|
||||
> self.decode_threshold).sum().item()
|
||||
num_decode_reqs = num_reqs - num_prefill_reqs
|
||||
else:
|
||||
long_seq_metadata = None
|
||||
num_prefill_reqs = 0
|
||||
num_decode_reqs = 0
|
||||
if spec_decode_metadata is None:
|
||||
# update pcp related params
|
||||
if self.pcp_size > 1:
|
||||
token_indices_to_sample = \
|
||||
query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1
|
||||
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states
|
||||
else:
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
target_slot_mapping = eagle_attn_metadata.slot_mapping[
|
||||
token_indices]
|
||||
token_indices_to_sample = None
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.runner.input_ids.gpu[:
|
||||
num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
if self.name == SpecDcodeType.EAGLE3:
|
||||
target_hidden_states = torch.cat(
|
||||
[h[:num_scheduled_tokens] for h in aux_hidden_states],
|
||||
dim=-1)
|
||||
else:
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
else:
|
||||
if self.pcp_size > 1:
|
||||
common_attn_metadata.query_start_loc_cpu = \
|
||||
query_start_loc_pcp_full_cpu[:num_reqs + 1]
|
||||
common_attn_metadata.query_start_loc = \
|
||||
query_start_loc_pcp_full[:num_reqs + 1]
|
||||
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
|
||||
# NOTE: Currently, MTP-fullgraph is incompatible with pcp
|
||||
token_indices_to_sample = None
|
||||
common_attn_metadata, token_indices =\
|
||||
self.prepare_inputs(
|
||||
common_attn_metadata,
|
||||
sampled_token_ids,
|
||||
spec_decode_metadata.num_draft_tokens)
|
||||
else:
|
||||
common_attn_metadata, token_indices, \
|
||||
token_indices_to_sample =\
|
||||
self.prepare_inputs_padded(
|
||||
common_attn_metadata,
|
||||
spec_decode_metadata,
|
||||
valid_sampled_tokens_count)
|
||||
if self.pcp_size > 1:
|
||||
target_token_ids = input_ids_pcp_full[token_indices]
|
||||
target_positions = positions
|
||||
target_hidden_states = hidden_states
|
||||
else:
|
||||
target_token_ids = self.runner.input_ids.gpu[token_indices]
|
||||
target_positions = positions[token_indices]
|
||||
if self.name == SpecDcodeType.EAGLE3:
|
||||
target_hidden_states = torch.cat(
|
||||
[h[token_indices] for h in aux_hidden_states], dim=-1)
|
||||
else:
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
|
||||
draft_token_ids = self._propose(
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
target_slot_mapping=target_slot_mapping,
|
||||
next_token_ids=next_token_ids,
|
||||
cu_num_tokens=cu_num_tokens,
|
||||
block_table=eagle_attn_metadata.block_tables,
|
||||
last_token_indices=token_indices_to_sample,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata,
|
||||
req_scheduled_tokens=req_scheduled_tokens,
|
||||
long_seq_metadata=long_seq_metadata,
|
||||
num_prefill_reqs=num_prefill_reqs,
|
||||
num_decode_reqs=num_decode_reqs,
|
||||
scheduler_output=scheduler_output,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
)
|
||||
spec_token_ids = draft_token_ids.tolist()
|
||||
return spec_token_ids
|
||||
|
||||
def _get_eagle_atten_dict(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
):
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.runner.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
# OPTIMIZATION: Start copying the block table first.
|
||||
# This way, we can overlap the copy with the following CPU operations.
|
||||
self.runner.input_batch.block_table.commit_block_table(num_reqs)
|
||||
|
||||
# Get the number of scheduled tokens for each request.
|
||||
req_ids = self.runner.input_batch.req_ids
|
||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||
max_num_scheduled_tokens = max(tokens)
|
||||
self.runner.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
# Get request indices.
|
||||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||
req_indices = np.repeat(self.runner.arange_np[:num_reqs],
|
||||
num_scheduled_tokens)
|
||||
|
||||
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
|
||||
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
||||
num_scheduled_tokens)
|
||||
|
||||
# Get positions.
|
||||
positions_np = self.runner.positions.np[:total_num_scheduled_tokens]
|
||||
np.add(self.runner.input_batch.num_computed_tokens_cpu[req_indices],
|
||||
arange,
|
||||
out=positions_np)
|
||||
|
||||
# Calculate M-RoPE positions.
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
if self.runner.uses_mrope:
|
||||
self.runner._calc_mrope_positions(scheduler_output)
|
||||
|
||||
# Get token indices.
|
||||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
|
||||
# where M is the max_model_len.
|
||||
token_indices = (
|
||||
positions_np +
|
||||
req_indices * self.runner.input_batch.token_ids_cpu.shape[1])
|
||||
|
||||
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
||||
# because torch.index_select is much faster than np.take for large
|
||||
# tensors.
|
||||
torch.index_select(
|
||||
self.runner.input_batch.token_ids_cpu_tensor.flatten(),
|
||||
0,
|
||||
torch.from_numpy(token_indices),
|
||||
out=self.runner.input_ids.cpu[:total_num_scheduled_tokens])
|
||||
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
# NOTE(Chen): there is exactly one KV cache group that contains all
|
||||
# attention layers in the model for now, so the current logic for
|
||||
# getting attn_metadata is not related to kv_cache_group information.
|
||||
# Will extend this part to support multiple KV cache groups later.
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.runner.kv_cache_config.kv_cache_groups):
|
||||
block_size = kv_cache_group_spec.kv_cache_spec.block_size
|
||||
block_table = self.runner.input_batch.block_table[
|
||||
kv_cache_group_id]
|
||||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
|
||||
# where K is the max_num_blocks_per_req and the block size is 2.
|
||||
# NOTE(woosuk): We can't simply use `token_indices // block_size`
|
||||
# here because M (max_model_len) is not necessarily divisible by
|
||||
# block_size.
|
||||
block_table_indices = (
|
||||
req_indices * block_table.max_num_blocks_per_req +
|
||||
positions_np // block_size)
|
||||
block_table_cpu = block_table.get_cpu_tensor()
|
||||
block_numbers = block_table_cpu.flatten(
|
||||
)[block_table_indices].numpy()
|
||||
block_offsets = positions_np % block_size
|
||||
np.add(
|
||||
block_numbers * block_size,
|
||||
block_offsets,
|
||||
out=block_table.slot_mapping.np[:total_num_scheduled_tokens])
|
||||
|
||||
# Prepare the attention metadata.
|
||||
self.runner.query_start_loc.np[0] = 0
|
||||
self.runner.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens
|
||||
|
||||
self.runner.seq_lens.np[:num_reqs] = (
|
||||
self.runner.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
num_scheduled_tokens)
|
||||
|
||||
# Copy the tensors to the NPU.
|
||||
self.runner.input_ids.gpu[:total_num_scheduled_tokens].copy_(
|
||||
self.runner.input_ids.cpu[:total_num_scheduled_tokens],
|
||||
non_blocking=True)
|
||||
if self.runner.uses_mrope:
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
self.runner.mrope_positions.gpu[:, :total_num_scheduled_tokens] \
|
||||
.copy_(
|
||||
self.runner.
|
||||
mrope_positions.cpu[:, :total_num_scheduled_tokens],
|
||||
non_blocking=True)
|
||||
else:
|
||||
# Common case (1D positions)
|
||||
self.runner.positions.gpu[:total_num_scheduled_tokens].copy_(
|
||||
self.runner.positions.cpu[:total_num_scheduled_tokens],
|
||||
non_blocking=True)
|
||||
|
||||
self.runner.query_start_loc.gpu[:num_reqs + 1].copy_(
|
||||
self.runner.query_start_loc.cpu[:num_reqs + 1], non_blocking=True)
|
||||
self.runner.seq_lens.gpu[:num_reqs].copy_(
|
||||
self.runner.seq_lens.cpu[:num_reqs], non_blocking=True)
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache
|
||||
self.runner.seq_lens.gpu[num_reqs:].fill_(0)
|
||||
self.runner.query_start_loc.gpu[num_reqs + 1:].fill_(-1)
|
||||
|
||||
attn_metadata = {}
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.runner.kv_cache_config.kv_cache_groups):
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=self.runner.query_start_loc.gpu[:num_reqs + 1],
|
||||
query_start_loc_cpu=self.runner.query_start_loc.cpu[:num_reqs +
|
||||
1],
|
||||
seq_lens_cpu=self.runner.seq_lens.cpu,
|
||||
num_reqs=num_reqs,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
||||
get_device_tensor(),
|
||||
slot_mapping=self.runner.input_batch.block_table[0].
|
||||
slot_mapping.gpu,
|
||||
positions=self.runner.positions.gpu,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata_i = builder.build(0, common_attn_metadata,
|
||||
self.runner.get_model())
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def _get_cumsum_and_arange(
|
||||
self,
|
||||
num_tokens: np.ndarray,
|
||||
cumsum_dtype: Optional[np.dtype] = None,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Get the cumulative sum and batched arange of the given array.
|
||||
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
|
||||
# Equivalent to but faster than:
|
||||
# np.concatenate([np.arange(n) for n in num_tokens])
|
||||
"""
|
||||
# Step 1. [2, 5, 3] -> [2, 7, 10]
|
||||
cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
|
||||
total_num_tokens = cu_num_tokens[-1]
|
||||
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
|
||||
cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
|
||||
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
arange = self.runner.arange_np[:total_num_tokens] - cumsums_offsets
|
||||
|
||||
return cu_num_tokens, arange
|
||||
return draft_token_ids
|
||||
|
||||
def _propose(
|
||||
self,
|
||||
# [num_tokens]
|
||||
target_token_ids: torch.Tensor,
|
||||
# [num_tokens]
|
||||
# [num_tokens] or [3, num_tokens] when M-RoPE is enabled
|
||||
target_positions: torch.Tensor,
|
||||
# [num_tokens, hidden_size]
|
||||
target_hidden_states: torch.Tensor,
|
||||
# [num_tokens]
|
||||
target_slot_mapping: torch.Tensor,
|
||||
# [batch_size]
|
||||
next_token_ids: torch.Tensor,
|
||||
# [batch_size + 1] starting with 0
|
||||
cu_num_tokens: torch.Tensor,
|
||||
# [batch_size, max_num_blocks_per_req]
|
||||
block_table: torch.Tensor,
|
||||
last_token_indices: Optional[torch.Tensor],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
mm_embed_inputs: Optional[tuple[list[torch.Tensor],
|
||||
torch.Tensor]] = None,
|
||||
req_scheduled_tokens=None,
|
||||
long_seq_metadata=None,
|
||||
num_prefill_reqs=0,
|
||||
num_decode_reqs=0,
|
||||
scheduler_output: SchedulerOutput = None,
|
||||
num_scheduled_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
device = cu_num_tokens.device
|
||||
cu_num_tokens = cu_num_tokens.cpu()
|
||||
block_table = block_table.cpu()
|
||||
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
last_token_indices = cu_num_tokens[1:] - 1
|
||||
target_positions = target_positions.cpu()
|
||||
|
||||
if last_token_indices is None:
|
||||
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
|
||||
|
||||
if self.name == SpecDcodeType.EAGLE3:
|
||||
assert isinstance(self.model, Eagle3LlamaForCausalLM)
|
||||
target_hidden_states = self.model.combine_hidden_states(
|
||||
@@ -423,34 +324,7 @@ class EagleProposer(Proposer):
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
seq_lens = (target_positions[last_token_indices] + 1).int()
|
||||
|
||||
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
|
||||
max_query_len = query_lens.max().item()
|
||||
attn_mask = self.runner.attn_mask
|
||||
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=cu_num_tokens.to(device),
|
||||
query_start_loc_cpu=cu_num_tokens,
|
||||
seq_lens_cpu=seq_lens.cpu(),
|
||||
max_query_len=max_query_len,
|
||||
num_reqs=batch_size,
|
||||
num_actual_tokens=num_tokens,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
||||
get_device_tensor(),
|
||||
slot_mapping=target_slot_mapping,
|
||||
positions=target_positions,
|
||||
attn_mask=attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata = builder.build(0, common_attn_metadata,
|
||||
self.runner.get_model())
|
||||
if self.use_cuda_graph and \
|
||||
num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
@@ -458,9 +332,14 @@ class EagleProposer(Proposer):
|
||||
num_input_tokens = num_tokens
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.positions[:num_tokens] = target_positions.to(device)
|
||||
self.positions[:num_tokens] = target_positions
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
attn_metadata.block_tables = block_table.to(device)
|
||||
|
||||
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata = builder.build(0, common_attn_metadata,
|
||||
self.runner.get_model())
|
||||
|
||||
with set_ascend_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
@@ -482,12 +361,14 @@ class EagleProposer(Proposer):
|
||||
draft_token_ids_tensor = torch.zeros(
|
||||
(self.vllm_config.speculative_config.num_speculative_tokens,
|
||||
*draft_token_ids.shape),
|
||||
dtype=draft_token_ids.dtype)
|
||||
dtype=draft_token_ids.dtype,
|
||||
device=self.device)
|
||||
draft_token_ids_tensor[0] = draft_token_ids
|
||||
|
||||
positions_cpu = target_positions[last_token_indices].cpu().to(
|
||||
torch.int64)
|
||||
positions = target_positions[last_token_indices]
|
||||
hidden_states = hidden_states[last_token_indices]
|
||||
last_token_indices = self.arange[:batch_size]
|
||||
|
||||
if self.use_cuda_graph and \
|
||||
batch_size <= self.cudagraph_batch_sizes[-1]:
|
||||
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
||||
@@ -496,16 +377,14 @@ class EagleProposer(Proposer):
|
||||
|
||||
attn_metadata.num_actual_tokens = batch_size
|
||||
attn_metadata.max_query_len = 1
|
||||
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||
attn_metadata.query_start_loc = self.arange_cpu[:batch_size + 1]
|
||||
attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
|
||||
1:].tolist()
|
||||
attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
|
||||
attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
|
||||
query_lens.fill_(1)
|
||||
attn_metadata.query_lens = query_lens
|
||||
|
||||
attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)]
|
||||
attn_metadata.seq_lens_list = seq_lens.tolist()
|
||||
attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist()
|
||||
attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
||||
for now_speculative in range(
|
||||
self.vllm_config.speculative_config.num_speculative_tokens -
|
||||
@@ -513,8 +392,8 @@ class EagleProposer(Proposer):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
input_ids = draft_token_ids_tensor[now_speculative].to(device)
|
||||
positions_cpu += 1
|
||||
input_ids = draft_token_ids_tensor[now_speculative]
|
||||
positions += 1
|
||||
|
||||
# NOTE(woosuk): We should handle the case where the draft model
|
||||
# generates tokens beyond the max model length. Since it is complex
|
||||
@@ -522,16 +401,15 @@ class EagleProposer(Proposer):
|
||||
# but adjust the position ids and slot mappings to avoid the
|
||||
# out-of-range access during the model execution. The draft tokens
|
||||
# generated with this adjustment should be ignored.
|
||||
exceeds_max_model_len = positions_cpu >= self.vllm_config.model_config.max_model_len
|
||||
exceeds_max_model_len = positions >= self.vllm_config.model_config.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_positions_cpu = torch.where(exceeds_max_model_len, 0,
|
||||
positions_cpu)
|
||||
clamped_positions = clamped_positions_cpu.to(device)
|
||||
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
||||
positions)
|
||||
|
||||
# TODO: Increment the sequence lengths.
|
||||
|
||||
attn_metadata.seq_lens += 1
|
||||
attn_metadata.seq_lens = attn_metadata.seq_lens + 1
|
||||
attn_metadata.seq_lens_list = [
|
||||
_ + 1 for _ in attn_metadata.seq_lens_list
|
||||
]
|
||||
@@ -542,22 +420,22 @@ class EagleProposer(Proposer):
|
||||
# TODO: sequence length to 1 to minimize their overheads in attention.
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_numbers = (clamped_positions_cpu // self.block_size)
|
||||
block_ids = block_table.gather(dim=1,
|
||||
index=block_numbers.view(-1, 1))
|
||||
block_numbers = (clamped_positions // self.block_size)
|
||||
block_ids = attn_metadata.block_tables.gather(
|
||||
dim=1, index=block_numbers.view(-1, 1))
|
||||
block_ids = block_ids.view(-1)
|
||||
slot_mapping_cpu = (
|
||||
slot_mapping_tmp = (
|
||||
block_ids * self.vllm_config.cache_config.block_size +
|
||||
clamped_positions_cpu % self.block_size)
|
||||
clamped_positions % self.block_size)
|
||||
|
||||
# Mask out the slot mappings that exceed the max model length.
|
||||
# Otherwise, the KV cache will be inadvertently updated with the
|
||||
# padding tokens.
|
||||
slot_mapping_cpu.masked_fill_(exceeds_max_model_len,
|
||||
slot_mapping_tmp.masked_fill_(exceeds_max_model_len,
|
||||
PADDING_SLOT_ID)
|
||||
# NOTE: ASCEND slot_mapping must be on cpu
|
||||
attn_metadata.slot_mapping = slot_mapping_cpu.to(
|
||||
torch.int32).to(device)
|
||||
attn_metadata.slot_mapping[:slot_mapping_tmp.shape[0]].copy_(
|
||||
slot_mapping_tmp.to(torch.int32))
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
@@ -565,7 +443,6 @@ class EagleProposer(Proposer):
|
||||
attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||
|
||||
attn_metadata.attn_mask = attn_mask
|
||||
attn_metadata.block_tables = block_table.to(device)
|
||||
# Run the model.
|
||||
with set_ascend_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
@@ -581,49 +458,188 @@ class EagleProposer(Proposer):
|
||||
|
||||
# TODO(wenlong): get more than one token for tree attention
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids_tensor[now_speculative + 1] = draft_token_ids.cpu()
|
||||
draft_token_ids_tensor[now_speculative + 1] = draft_token_ids
|
||||
|
||||
# [batch_size, num_speculative_tokens]
|
||||
draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)
|
||||
return draft_token_ids
|
||||
|
||||
def _prepare_inputs(
|
||||
def _get_attn_metadata(self, attn_metadata):
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
architecture = self.vllm_config.model_config.architecture
|
||||
layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def prepare_next_token_ids_cpu(
|
||||
self,
|
||||
eagle_attn_metadata: AscendMetadata,
|
||||
# [batch_size]
|
||||
num_rejected_tokens: torch.Tensor,
|
||||
sampled_token_ids: list[list[int]],
|
||||
requests: dict[str, CachedRequestState],
|
||||
gpu_input_batch: InputBatch,
|
||||
num_scheduled_tokens: dict[str, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function is used to prepare the inputs for speculative decoding.
|
||||
It calculates the next token ids for each request based on the sampled
|
||||
token ids from the CPU. If a request has no sampled token ids (e.g.,
|
||||
during the initial decoding steps), it falls back to using the request
|
||||
state to get the next token id.
|
||||
"""
|
||||
req_ids = gpu_input_batch.req_ids
|
||||
next_token_ids: list[int] = []
|
||||
for i, token_ids in enumerate(sampled_token_ids):
|
||||
if token_ids:
|
||||
# Common case.
|
||||
next_token_id = token_ids[-1]
|
||||
else:
|
||||
# Partial prefill (rare case).
|
||||
# Get the next token id from the request state.
|
||||
req_id = req_ids[i]
|
||||
req_state = requests[req_id]
|
||||
seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
|
||||
req_id]
|
||||
next_token_id = req_state.get_token_id(seq_len)
|
||||
next_token_ids.append(next_token_id)
|
||||
next_token_ids = torch.tensor(next_token_ids,
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device)
|
||||
return next_token_ids
|
||||
|
||||
def prepare_next_token_ids_padded(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
requests: dict[str, CachedRequestState],
|
||||
gpu_input_batch: InputBatch,
|
||||
discard_request_indices: torch.Tensor,
|
||||
num_discarded_requests: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function is used to prepare the inputs for the spec decode.
|
||||
This function is used to prepare the inputs for speculative decoding.
|
||||
It calculates the next token ids and the number of valid sampled tokens
|
||||
for each request, considering the "discarded" requests whose next token
|
||||
is not sampled and comes from `request.get_token_id()` instead.
|
||||
It also accounts for the rejected tokens in `sampled_token_ids`.
|
||||
This function must use device functions to operate on the inputs, and
|
||||
should not introduce any blocking CPU-GPU synchronization.
|
||||
"""
|
||||
# TODO(Ben): Combine this into a custom fused kernel
|
||||
|
||||
# Precompute get_token_id for when there is no valid next token
|
||||
num_reqs = gpu_input_batch.num_reqs
|
||||
self.backup_next_token_ids.np[:num_reqs] = np.array([
|
||||
requests[gpu_input_batch.req_ids[i]].get_token_id(
|
||||
common_attn_metadata.seq_lens_cpu[i].item())
|
||||
for i in range(num_reqs)
|
||||
])
|
||||
self.backup_next_token_ids.copy_to_gpu(num_reqs)
|
||||
|
||||
# Mask out the sampled tokens indices that should not be sampled.
|
||||
discard_sampled_tokens_req_indices = discard_request_indices[:
|
||||
num_discarded_requests]
|
||||
|
||||
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
|
||||
valid_sampled_token_ids_gpu.index_fill_(
|
||||
0, discard_sampled_tokens_req_indices, -1)
|
||||
|
||||
# Generate a mask for all valid tokens within those requests
|
||||
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
|
||||
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)
|
||||
|
||||
# Count the number of valid tokens in each request
|
||||
valid_sampled_tokens_count = valid_mask.sum(dim=1)
|
||||
|
||||
# Get the rightmost valid index per row
|
||||
last_valid_indices = valid_sampled_tokens_count - 1
|
||||
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
|
||||
|
||||
# Get last valid token from each row
|
||||
# (assume undefined state where there is no valid token)
|
||||
selected_tokens = torch.gather(
|
||||
valid_sampled_token_ids_gpu, 1,
|
||||
last_valid_indices_safe.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Use last token if valid, pre-computed backup if not
|
||||
batch_size = valid_sampled_token_ids_gpu.shape[0]
|
||||
next_token_ids = torch.where(
|
||||
last_valid_indices != -1,
|
||||
selected_tokens,
|
||||
self.backup_next_token_ids.gpu[:batch_size],
|
||||
)
|
||||
|
||||
return next_token_ids, valid_sampled_tokens_count
|
||||
|
||||
def _copy_valid_sampled_token_count(
|
||||
self, next_token_ids: torch.Tensor,
|
||||
valid_sampled_tokens_count: torch.Tensor) -> None:
|
||||
if self.runner.valid_sampled_token_count_event is not None:
|
||||
default_stream = torch.npu.current_stream()
|
||||
# initialize a new stream to overlap the copy operation with
|
||||
# prepare_input of draft model.
|
||||
with torch.npu.stream(
|
||||
self.runner.valid_sampled_token_count_copy_stream):
|
||||
self.runner.valid_sampled_token_count_copy_stream.wait_stream(
|
||||
default_stream) # type: ignore
|
||||
self.runner.valid_sampled_token_count_cpu[:
|
||||
valid_sampled_tokens_count
|
||||
.shape[0]].copy_(
|
||||
valid_sampled_tokens_count,
|
||||
non_blocking=True
|
||||
)
|
||||
self.runner.valid_sampled_token_count_event.record()
|
||||
|
||||
self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
|
||||
1)
|
||||
|
||||
def prepare_inputs(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampled_token_ids: list[list[int]],
|
||||
num_draft_tokens: list[int],
|
||||
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
|
||||
"""
|
||||
This function is used to prepare the inputs for speculative decoding.
|
||||
It updates to the common_attn_metadata to account for the rejected
|
||||
tokens (and newly sampled tokens). It also returns the token indices
|
||||
of the tokens that should be fed to the speculator.
|
||||
"""
|
||||
# E.g.
|
||||
# common_attn_metadata.query_start_loc{_cpu}:
|
||||
# [0, q1, q1 + q2, q1 + q2 + q3]
|
||||
# [0, q1, q1 + q2, q1 + q2 + q3]
|
||||
# common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
|
||||
# num_rejected_tokens: [n1, n2, n3]
|
||||
# This function computes the intermediate values:
|
||||
# num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
|
||||
# And returns:
|
||||
# common_attn_metadata.query_start_loc{_cpu}:
|
||||
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
|
||||
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
|
||||
# common_attn_metadata.seq_lens{_cpu}:
|
||||
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
|
||||
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
|
||||
# token_indices: [0, 1, ..., q1 - n1 - 1,
|
||||
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
|
||||
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
|
||||
num_rejected_tokens_cpu = num_rejected_tokens.to("cpu")
|
||||
cu_target_query_lens = eagle_attn_metadata.query_start_loc
|
||||
device = eagle_attn_metadata.query_start_loc.device
|
||||
query_start_loc_cpu = cu_target_query_lens.to("cpu")
|
||||
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
|
||||
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
|
||||
|
||||
num_actual_reqs = len(num_draft_tokens)
|
||||
num_rejected_tokens = [
|
||||
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
|
||||
for i, n in enumerate(num_draft_tokens)
|
||||
]
|
||||
num_rejected_tokens = torch.tensor(num_rejected_tokens,
|
||||
dtype=torch.int32)
|
||||
|
||||
device = common_attn_metadata.query_start_loc.device
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_actual_reqs
|
||||
+ 1]
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_actual_reqs]
|
||||
new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens
|
||||
|
||||
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
|
||||
new_query_len_per_req = (query_start_loc_cpu[1:] -
|
||||
query_start_loc_cpu[:-1])
|
||||
new_query_len_per_req = query_start_loc_cpu[
|
||||
1:] - query_start_loc_cpu[:-1]
|
||||
# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
|
||||
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens_cpu
|
||||
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
|
||||
new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
|
||||
|
||||
# [q1 - n1, q2 - n2, q3 - n3] ->
|
||||
@@ -631,7 +647,8 @@ class EagleProposer(Proposer):
|
||||
new_query_start_loc_cpu = torch.zeros(
|
||||
query_start_loc_cpu.shape,
|
||||
dtype=torch.int32,
|
||||
pin_memory=is_pin_memory_available())
|
||||
pin_memory=is_pin_memory_available(),
|
||||
)
|
||||
new_query_start_loc_np = new_query_start_loc_cpu.numpy()
|
||||
np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
|
||||
|
||||
@@ -646,8 +663,8 @@ class EagleProposer(Proposer):
|
||||
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
|
||||
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
|
||||
# _r1_ ____r2____ ___r3__
|
||||
token_offests = self.token_arange_np[:total_num_tokens] \
|
||||
- new_query_start_locs_expanded
|
||||
token_offests = (self.token_arange_np[:total_num_tokens] -
|
||||
new_query_start_locs_expanded)
|
||||
|
||||
# Expand starting positions to match token pattern
|
||||
# [0, q1, q1 + q2] ->
|
||||
@@ -656,21 +673,101 @@ class EagleProposer(Proposer):
|
||||
old_query_start_locs_expanded = np.repeat(
|
||||
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
|
||||
# Final token indices are:
|
||||
# [0, 1, // req 1
|
||||
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
|
||||
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
|
||||
# [0, 1, // req 1
|
||||
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
|
||||
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
|
||||
token_indices_np = token_offests + old_query_start_locs_expanded
|
||||
token_indices = torch.from_numpy(token_indices_np).to(
|
||||
device, non_blocking=True)
|
||||
|
||||
# need use npu
|
||||
query_len_per_req = (cu_target_query_lens[1:] -
|
||||
cu_target_query_lens[:-1])
|
||||
num_tokens_per_req = query_len_per_req - num_rejected_tokens
|
||||
common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_(
|
||||
common_attn_metadata.slot_mapping[token_indices])
|
||||
common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1)
|
||||
|
||||
# [a - n1, b - n2, c - n3] ->
|
||||
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
|
||||
cu_num_tokens = torch.zeros_like(cu_target_query_lens)
|
||||
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
|
||||
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward,
# so we do not need to fix them. But if they are used in the future,
# we should fix them.
|
||||
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=new_query_start_loc_cpu.to(device,
|
||||
non_blocking=True),
|
||||
query_start_loc_cpu=new_query_start_loc_cpu,
|
||||
seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
|
||||
seq_lens_cpu=new_seq_lens_cpu,
|
||||
num_computed_tokens_cpu=common_attn_metadata.
|
||||
num_computed_tokens_cpu,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
num_actual_tokens=total_num_tokens,
|
||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||
max_query_len=new_query_len_per_req.max().item(),
|
||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
positions=common_attn_metadata.positions[token_indices],
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
)
|
||||
return spec_common_attn_metadata, token_indices
|
||||
|
||||
return cu_num_tokens, token_indices
|
||||
def prepare_inputs_padded(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
spec_decode_metadata: SpecDecodeMetadata,
|
||||
valid_sampled_tokens_count: torch.Tensor,
|
||||
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function is used to prepare the inputs for speculative decoding
|
||||
It updates the common_attn_metadata for speculative decoding,
|
||||
but does not consider the rejected tokens. Instead, all tokens
|
||||
are included as inputs to the speculator, with the rejected tokens
|
||||
used as padding and filtered out later by `token_indices_to_sample`.
|
||||
No blocking CPU operations should be introduced in this function.
|
||||
"""
|
||||
num_draft_tokens_gpu = torch.cat([
|
||||
spec_decode_metadata.cu_num_draft_tokens[0:1],
|
||||
spec_decode_metadata.cu_num_draft_tokens[1:] -
|
||||
spec_decode_metadata.cu_num_draft_tokens[:-1],
|
||||
])
|
||||
|
||||
num_rejected_tokens_gpu = torch.where(
|
||||
num_draft_tokens_gpu > 0,
|
||||
num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
|
||||
torch.zeros_like(num_draft_tokens_gpu),
|
||||
)
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
|
||||
new_query_len_per_req = query_start_loc_cpu[
|
||||
1:] - query_start_loc_cpu[:-1]
|
||||
|
||||
total_num_tokens = query_start_loc_cpu[-1].item()
|
||||
token_indices = self.arange[:total_num_tokens]
|
||||
|
||||
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward,
# so we do not need to fix them. But if they are used in the future,
# we should fix them.
|
||||
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
num_actual_tokens=total_num_tokens,
|
||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||
max_query_len=new_query_len_per_req.max().item(),
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
positions=common_attn_metadata.positions,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
num_computed_tokens_cpu=common_attn_metadata.
|
||||
num_computed_tokens_cpu,
|
||||
seq_lens=common_attn_metadata.seq_lens)
|
||||
|
||||
token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] -
|
||||
1 - num_rejected_tokens_gpu)
|
||||
|
||||
return spec_common_attn_metadata, token_indices, token_indices_to_sample
|
||||
|
||||
@@ -801,7 +801,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self.requests[r].num_tokens for r in self.input_batch.req_ids
|
||||
]
|
||||
num_tokens_np = np.array(num_tokens, dtype=np.int32)
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
base_num_reqs = self.input_batch.num_reqs
|
||||
num_reqs = base_num_reqs
|
||||
if self.pcp_size > 1:
|
||||
# When pcp > 1, we need the original num_scheduled_tokens before the split
# to calculate discard_requests_mask.
|
||||
@@ -1106,6 +1107,11 @@ class NPUModelRunner(GPUModelRunner):
|
||||
if self.speculative_config and \
|
||||
self.spec_decode_common_attn_metadata is None:
|
||||
self.spec_decode_common_attn_metadata = common_attn_metadata
|
||||
if self.speculative_config.method in ("eagle", "eagle3") and \
|
||||
self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.spec_decode_common_attn_metadata = \
|
||||
self.spec_decode_common_attn_metadata.unpadded(
|
||||
total_num_scheduled_tokens, base_num_reqs)
|
||||
|
||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||
common_prefix_len = 0
|
||||
@@ -1591,7 +1597,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
with ProfileExecuteDuration().capture_async("Draft"):
|
||||
if self.speculative_config:
|
||||
use_padded_batch_for_eagle = self.speculative_config and \
|
||||
self.speculative_config.method == "mtp" and \
|
||||
self.speculative_config.use_eagle() and \
|
||||
not self.speculative_config.disable_padded_drafter_batch
|
||||
if use_padded_batch_for_eagle:
|
||||
# EAGLE speculative decoding can use the GPU sampled tokens