[Spec Decode] Clean up spec decode interface (#6947)
This pull request refactors the speculative decoding proposer interface
to align with upstream vLLM, removing the local `Proposer` interface and
renaming the entry-point methods to `propose`.
This is the first step. In the future we should remove the class
registry and just add a few Ascend-specific methods once the architecture
in vLLM is ready.
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
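For orientation, a minimal sketch of the renamed entry point, based on the hunks below; `drafter` stands for any proposer returned by `get_spec_decode_method`, and the surrounding objects are assumed to already exist in the model runner.

# Before (local Proposer interface, now removed):
#     draft_token_ids = drafter.generate_token_ids(valid_sampled_token_ids, ...)
# After (aligned with the upstream vLLM proposers); the ngram and suffix
# proposers take only the accepted token ids per request:
draft_token_ids = drafter.propose(valid_sampled_token_ids)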
@@ -47,7 +47,6 @@ mtp_proposer.py
 ├── Proposer
 │   ├── load_model
 │   ├── dummy_run
-│   ├── generate_token_ids
 │   ├── _prepare_inputs
 │   ├── _propose
 ```
@@ -86,11 +85,11 @@ def get_spec_decode_method(method,
                            device,
                            runner):
     if method == "ngram":
-        return NgramProposer(vllm_config, device, runner)
+        return AscendNgramProposer(vllm_config, device, runner)
     elif method in ["eagle", "eagle3"]:
-        return EagleProposer(vllm_config, device, runner)
+        return AscendEagleProposer(vllm_config, device, runner)
     elif method == 'mtp':
-        return MtpProposer(vllm_config, device, runner)
+        return AscendMtpProposer(vllm_config, device, runner)
     else:
         raise ValueError("Unknown speculative decoding method: "
                          f"{method}")
@@ -6,8 +6,7 @@ from vllm.config import CacheConfig, CompilationMode, CUDAGraphMode, VllmConfig,

 from tests.ut.base import TestBase
 from vllm_ascend.ascend_config import init_ascend_config
-from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
-from vllm_ascend.spec_decode.interface import SpecDcodeType
+from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer


 class TestEagleProposerInitialization(TestBase):
@@ -79,7 +78,7 @@ class TestEagleProposerInitialization(TestBase):
         init_ascend_config(self.vllm_config)

         with set_current_vllm_config(self.vllm_config):
-            proposer = EagleProposer(vllm_config=self.vllm_config,
+            proposer = AscendEagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)

@@ -102,7 +101,7 @@ class TestEagleProposerInitialization(TestBase):
         init_ascend_config(self.vllm_config)

         with set_current_vllm_config(self.vllm_config):
-            proposer = EagleProposer(vllm_config=self.vllm_config,
+            proposer = AscendEagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)

@@ -121,7 +120,7 @@ class TestEagleProposerInitialization(TestBase):
         init_ascend_config(self.vllm_config)

         with set_current_vllm_config(self.vllm_config):
-            proposer = EagleProposer(vllm_config=self.vllm_config,
+            proposer = AscendEagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)

@@ -140,7 +139,7 @@ class TestEagleProposerInitialization(TestBase):
         init_ascend_config(self.vllm_config)

         with set_current_vllm_config(self.vllm_config):
-            proposer = EagleProposer(vllm_config=self.vllm_config,
+            proposer = AscendEagleProposer(vllm_config=self.vllm_config,
                                      device=self.device,
                                      runner=self.runner)

@@ -196,7 +195,7 @@ class TestEagleProposerLoadModel(TestBase):

         # Set the current vllm config
         set_current_vllm_config(self.vllm_config)
-        self.proposer = EagleProposer(vllm_config=self.vllm_config,
+        self.proposer = AscendEagleProposer(vllm_config=self.vllm_config,
                                       device=self.device,
                                       runner=self.runner)

@@ -235,7 +234,6 @@ class TestEagleProposerLoadModel(TestBase):
         mock_model.model.embed_tokens = MagicMock()
         mock_model.model.embed_tokens.weight = weight

-        self.proposer.name = SpecDcodeType.EAGLE
         mock_get_model.return_value = MagicMock()
         mock_get_model.return_value.model.embed_tokens.weight = weight

@@ -301,7 +299,6 @@ class TestEagleProposerLoadModel(TestBase):
         mock_pp_group.return_value.world_size = 2

         self.proposer.model = MagicMock()
-        self.proposer.name = SpecDcodeType.EAGLE

         with set_current_vllm_config(self.vllm_config):
             self.proposer.load_model(mock_model)
@@ -373,7 +370,7 @@ class TestEagleProposerDummyRun(TestBase):

         # Set the current vllm config
         set_current_vllm_config(self.vllm_config)
-        self.proposer = EagleProposer(vllm_config=self.vllm_config,
+        self.proposer = AscendEagleProposer(vllm_config=self.vllm_config,
                                       device=self.device,
                                       runner=self.runner)
         self.proposer.model = MagicMock()
@@ -514,7 +511,7 @@ class TestEagleProposerHelperMethods(TestBase):

         # Set the current vllm config
         set_current_vllm_config(self.vllm_config)
-        self.proposer = EagleProposer(vllm_config=self.vllm_config,
+        self.proposer = AscendEagleProposer(vllm_config=self.vllm_config,
                                       device=self.device,
                                       runner=self.runner)

@@ -12,7 +12,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

 from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
-from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
+from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer


 class TestMtpProposer:
@@ -96,7 +96,7 @@ class TestMtpProposer:

         # Test basic initialization
         with set_current_vllm_config(vllm_config):
-            proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
+            proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner)

         assert proposer.vllm_config == vllm_config
         assert proposer.device == torch.device("cpu")
@@ -118,7 +118,7 @@ class TestMtpProposer:
         vllm_config.scheduler_config.async_scheduling = False
         vllm_config.speculative_config.enforce_eager = False
         with set_current_vllm_config(vllm_config):
-            proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
+            proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner)

         assert proposer.use_cuda_graph is True

@@ -133,7 +133,7 @@ class TestMtpProposer:
         mock_cpu_gpu_buffer.return_value = mock_buffer_instance
         mock_dp_group.return_value.world_size = 1
         with set_current_vllm_config(vllm_config):
-            proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
+            proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner)

         # Mock _runnable to prevent actual execution
         proposer._runnable = MagicMock()
@@ -165,7 +165,7 @@ class TestMtpProposer:
         mock_cpu_gpu_buffer.return_value = mock_buffer_instance
         mock_dp_group.return_value.world_size = 1
         with set_current_vllm_config(vllm_config):
-            proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
+            proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner)

         # Mock _runnable to prevent actual execution
         proposer._runnable = MagicMock()
@@ -197,9 +197,9 @@ class TestMtpProposer:
         mock_gpu_batch.req_ids = ["req1", "req2", "req3"]
         mock_num_scheduled = {"req1": 0, "req2": 0, "req3": 0}

-        proposer = MagicMock(spec=MtpProposer)
+        proposer = MagicMock(spec=AscendMtpProposer)
         proposer.input_ids = MagicMock(device=torch.device("cpu"))
-        proposer.prepare_next_token_ids_cpu = MtpProposer.prepare_next_token_ids_cpu.__get__(
+        proposer.prepare_next_token_ids_cpu = AscendMtpProposer.prepare_next_token_ids_cpu.__get__(
             proposer)
         result = proposer.prepare_next_token_ids_cpu(
             sampled_token_ids=sampled_token_ids,
@@ -253,10 +253,10 @@ class TestMtpProposer:
         mock_backup.copy_to_gpu = MagicMock()
         mock_cpu_gpu_buffer.return_value = mock_backup

-        proposer = MagicMock(spec=MtpProposer)
+        proposer = MagicMock(spec=AscendMtpProposer)
         proposer.backup_next_token_ids = mock_backup
         proposer.input_ids = MagicMock(device=torch.device("cpu"))
-        proposer.prepare_next_token_ids_padded = MtpProposer.prepare_next_token_ids_padded.__get__(
+        proposer.prepare_next_token_ids_padded = AscendMtpProposer.prepare_next_token_ids_padded.__get__(
            proposer)

         discard_request_indices = torch.tensor([1, 3], dtype=torch.int64)
@@ -327,11 +327,11 @@ class TestMtpProposer:
         mock_runner.pcp_size = 1
         mock_runner.decode_token_per_req = MagicMock()

-        proposer = MagicMock(spec=MtpProposer)
+        proposer = MagicMock(spec=AscendMtpProposer)
         proposer.runner = mock_runner
         proposer.pcp_size = 1
         proposer.arange = torch.arange(100, dtype=torch.int32)
-        proposer.prepare_inputs_padded = MtpProposer.prepare_inputs_padded.__get__(
+        proposer.prepare_inputs_padded = AscendMtpProposer.prepare_inputs_padded.__get__(
            proposer)

         mock_valid_sampled_tokens_count = torch.tensor([2, 1, 2],
@@ -16,23 +16,23 @@
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
 #
-from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
-from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
-from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
-from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
-from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer
+from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
+from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
+from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer
+from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
+from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer


 def get_spec_decode_method(method, vllm_config, device, runner):
     if method == "ngram":
-        return NgramProposer(vllm_config, device, runner)
-    elif method in ("eagle", "eagle3"):
-        return EagleProposer(vllm_config, device, runner)
-    elif method == "mtp":
-        return MtpProposer(vllm_config, device, runner)
+        return AscendNgramProposer(vllm_config, runner)
     elif method == "suffix":
-        return SuffixDecodingProposer(vllm_config, device, runner)
+        return AscendSuffixDecodingProposer(vllm_config, runner)
     elif method == "medusa":
-        return MedusaProposer(vllm_config, device, runner)
+        return AscendMedusaProposer(vllm_config, device)
+    elif method in ("eagle", "eagle3"):
+        return AscendEagleProposer(vllm_config, device, runner)
+    elif method == "mtp":
+        return AscendMtpProposer(vllm_config, device, runner)
     else:
         raise ValueError(f"Unknown speculative decoding method: {method}")
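A hypothetical usage sketch (not part of the diff) for the dispatch above. `vllm_config`, `device`, and `runner` are assumed to come from the NPU model runner; note that the constructors now follow the upstream signatures, so ngram/suffix take `(vllm_config, runner)`, medusa takes `(vllm_config, device)`, and eagle/mtp take `(vllm_config, device, runner)`.

from vllm_ascend.spec_decode import get_spec_decode_method

# vllm_config, device, runner: assumed to be provided by the model runner.
drafter = get_spec_decode_method("ngram", vllm_config, device, runner)
drafter.load_model()  # no-op for the ngram proposer
draft_token_ids = drafter.propose(valid_sampled_token_ids)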
@@ -30,8 +30,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID
-from vllm.v1.spec_decode.eagle import EagleProposer as VllmEagleProposer
+from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

@@ -81,7 +80,7 @@ def split_inputs_tp_to_sp(hidden_states, out):
     return out[:padded_num_tokens_per_rank]


-class EagleProposer(VllmEagleProposer):
+class AscendEagleProposer(EagleProposer):
     _runnable: ACLGraphWrapper | Callable

     def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
@@ -1,53 +0,0 @@
-import enum
-
-import torch
-from vllm.config import CUDAGraphMode, VllmConfig
-from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-
-
-class SpecDcodeType(enum.Enum):
-    NGRAM = 0
-    EAGLE = 1
-    EAGLE3 = 2
-    MTP = 4
-    SUFFIX = 5
-    MEDUSA = 6
-
-
-class Proposer:
-    def __init__(self, vllm_config: VllmConfig, device: torch.device = None, runner=None):
-        pass
-
-    def load_model(self, model):
-        """Called by load_model in model_runner"""
-        raise NotImplementedError
-
-    @torch.inference_mode()
-    def dummy_run(
-        self,
-        num_tokens: int,
-        with_prefill: bool = False,
-        in_graph_capturing: bool = False,
-        num_reqs: int = 0,
-        num_tokens_across_dp: torch.Tensor | None = None,
-        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor=None,
-    ):
-        """Called by dummy_run in model_runner"""
-        raise NotImplementedError
-
-    def generate_token_ids(
-        self,
-        valid_sampled_token_ids: list[list[int]],
-        sampling_metadata: SamplingMetadata = None,
-        scheduler_output: SchedulerOutput = None,
-        spec_decode_metadata: SpecDecodeMetadata = None,
-        positions: torch.Tensor = None,
-        num_scheduled_tokens: int = 0,
-        hidden_states: torch.Tensor = None,
-        aux_hidden_states: torch.Tensor = None,
-    ):
-        """Called by execute_model in model_runner"""
-        raise NotImplementedError
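With the local base class gone, each Ascend proposer derives directly from its upstream vLLM counterpart. A condensed view of the resulting hierarchy, taken from the hunks in this commit (bodies omitted):

from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer

class AscendNgramProposer(NgramProposer): ...
class AscendSuffixDecodingProposer(SuffixDecodingProposer): ...
class AscendMedusaProposer(MedusaProposer): ...
class AscendEagleProposer(EagleProposer): ...
class AscendMtpProposer(AscendEagleProposer): ...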
@@ -1,36 +1,20 @@
 import torch
-from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.config import CUDAGraphMode
 from vllm.logger import init_logger
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.medusa import MedusaProposer as VllmMedusaProposer
+from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
-from vllm_ascend.spec_decode.interface import SpecDcodeType

 logger = init_logger(__name__)


-class MedusaProposer(VllmMedusaProposer):
+class AscendMedusaProposer(MedusaProposer):
     """
     Medusa proposer class for generating token sequences
     """
-
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        device: torch.device,
-        runner,
-    ):
-        # Save config parameters
-        self.name = SpecDcodeType.MEDUSA
-        self.vllm_config = vllm_config
-        self.device = device
-        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-        self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size()
-        self.dtype = vllm_config.model_config.dtype
-        self.runner = runner

     @torch.inference_mode()
     def dummy_run(
         self,
@@ -62,14 +46,12 @@ class MedusaProposer(VllmMedusaProposer):
         self.model(hidden_states)
         dummy_compute_logits(hidden_states)

-    def generate_token_ids(
+    def propose(
         self,
         valid_sampled_token_ids: list[list[int]],
         sampling_metadata: SamplingMetadata,
         spec_decode_metadata: SpecDecodeMetadata,
         sample_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
     ):
         if sample_hidden_states.shape[0] == len(valid_sampled_token_ids):
             # The input to the target model does not include draft tokens.
@@ -84,7 +66,7 @@ class MedusaProposer(VllmMedusaProposer):
         indices = offsets + num_accepted_tokens - 1
         hidden_states = sample_hidden_states[indices]

-        spec_token_ids = self.propose(
+        spec_token_ids = super().propose(
             target_hidden_states=hidden_states,
             sampling_metadata=sampling_metadata,
         )
@@ -15,11 +15,11 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla, update_cos_sin
-from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
+from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
 from vllm_ascend.utils import lmhead_tp_enable


-class MtpProposer(EagleProposer):
+class AscendMtpProposer(AscendEagleProposer):
     # TODO: Find out why ModelRunner does not this explicit typing?
     model: nn.Module | ACLGraphWrapper

@@ -1,16 +1,11 @@
 import torch
-from vllm.config import CUDAGraphMode
-from vllm.v1.spec_decode.ngram_proposer import NgramProposer as VllmNgramProposer
-
-from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+from vllm.v1.spec_decode.ngram_proposer import NgramProposer


-class NgramProposer(VllmNgramProposer, Proposer):
-    def __init__(self, vllm_config, device, runner):
-        super().__init__(vllm_config)
-        self.name = SpecDcodeType.NGRAM
-        self.device = device
+class AscendNgramProposer(NgramProposer):
+    def __init__(self, vllm_config, runner):
         self.runner = runner
+        super().__init__(vllm_config)

     def load_model(self, *args, **kwargs):
         # No model to load.
@@ -24,26 +19,22 @@ class NgramProposer(VllmNgramProposer, Proposer):
         in_graph_capturing=None,
         num_reqs=None,
         num_tokens_across_dp=None,
-        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        aclgraph_runtime_mode=None,
         batch_descriptor=None,
         dummy_compute_logits=lambda hidden_states: None,
         is_profile=False,
     ):
         pass

-    def generate_token_ids(
+    def propose(
         self,
-        valid_sampled_token_ids,
-        sampling_metadata=None,
-        scheduler_output=None,
-        spec_decode_metadata=None,
-        positions=None,
-        num_scheduled_tokens=None,
-        hidden_states=None,
-        aux_hidden_states=None,
+        sampled_token_ids: list[list[int]],
+        num_tokens_no_spec=None,
+        token_ids_cpu=None,
+        slot_masks: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
     ) -> list[list[int]]:
         valid_ngram_requests = []
-        for i, sampled_ids in enumerate(valid_sampled_token_ids):
+        for i, sampled_ids in enumerate(sampled_token_ids):
             num_sampled_ids = len(sampled_ids)
             if not num_sampled_ids:
                 continue
@@ -64,7 +55,7 @@ class NgramProposer(VllmNgramProposer, Proposer):
             valid_ngram_requests.append(i)

         draft_token_ids = self.batch_propose(
-            len(valid_sampled_token_ids),
+            len(sampled_token_ids),
             valid_ngram_requests,
             self.runner.input_batch.num_tokens_no_spec,
             self.runner.input_batch.token_ids_cpu,
@@ -1,22 +1,11 @@
-import torch
-from vllm.config import CUDAGraphMode
-from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer as VllmSuffixDecodingProposer
-
-from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer


-class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
-    def __init__(self, vllm_config, device, runner):
+class AscendSuffixDecodingProposer(SuffixDecodingProposer):
+    def __init__(self, vllm_config, runner):
         super().__init__(vllm_config)
-        self.name = SpecDcodeType.SUFFIX
-        self.device = device
         self.runner = runner

-    def load_model(self, *args, **kwargs):
-        # No model to load.
-        pass
-
-    @torch.inference_mode()
     def dummy_run(
         self,
         num_tokens,
@@ -24,23 +13,12 @@ class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
         in_graph_capturing=None,
         num_reqs=None,
         num_tokens_across_dp=None,
-        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        aclgraph_runtime_mode=None,
         batch_descriptor=None,
         dummy_compute_logits=lambda hidden_states: None,
         is_profile=False,
     ):
         pass

-    def generate_token_ids(
-        self,
-        valid_sampled_token_ids,
-        sampling_metadata=None,
-        scheduler_output=None,
-        spec_decode_metadata=None,
-        positions=None,
-        num_scheduled_tokens=None,
-        hidden_states=None,
-        aux_hidden_states=None,
-    ) -> list[list[int]]:
-        draft_token_ids = self.propose(self.runner.input_batch, valid_sampled_token_ids)
-        return draft_token_ids
+    def propose(self, valid_sampled_token_ids):
+        return super().propose(self.runner.input_batch, valid_sampled_token_ids)
@@ -74,8 +74,6 @@ from vllm.v1.sample.logits_processor import build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import RejectionSampler
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
 from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.utils import record_function_or_nullcontext
 from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner
@@ -109,9 +107,11 @@ from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
 from vllm_ascend.patch.worker.patch_qwen3_quarot import patch_load_weights
 from vllm_ascend.sample.sampler import AscendSampler
 from vllm_ascend.spec_decode import get_spec_decode_method
-from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
-from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
-from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
+from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
+from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
+from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer
+from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
+from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer
 from vllm_ascend.utils import (
     check_gdn_layer,
     enable_sp,
@@ -402,9 +402,14 @@ class NPUModelRunner(GPUModelRunner):

     def _set_up_drafter(self):
         # Set up speculative decoding.
-        self.drafter: NgramProposer | EagleProposer | MtpProposer | SuffixDecodingProposer | MedusaProposer | None = (
-            None
-        )
+        self.drafter: (
+            AscendNgramProposer
+            | AscendEagleProposer
+            | AscendMtpProposer
+            | AscendSuffixDecodingProposer
+            | AscendMedusaProposer
+            | None
+        ) = None
         self.actual_seq_lengths_q: list[int] = []
         self.decode_token_per_req = 1
         if self.speculative_config:
@@ -414,7 +419,7 @@ class NPUModelRunner(GPUModelRunner):
         if get_pp_group().is_last_rank:
             self.drafter = self._get_drafter()
             if self.speculative_config.method == "eagle3":
-                assert isinstance(self.drafter, EagleProposer)
+                assert isinstance(self.drafter, AscendEagleProposer)
                 self.use_aux_hidden_state_outputs = self.drafter.eagle3_use_aux_hidden_state
             self.rejection_sampler = RejectionSampler(self.sampler)
             self.actual_seq_lengths_q = list(
@@ -946,152 +951,134 @@ class NPUModelRunner(GPUModelRunner):
         positions: torch.Tensor,
         num_scheduled_tokens: int,
         hidden_states: torch.Tensor,
-        attn_metadata: list[dict[str, Any]] | dict[str, Any],
         aux_hidden_states: torch.Tensor = None,
         sample_hidden_states: torch.Tensor = None,
     ) -> list[list[int]] | None:
         if not self.drafter:
             # Speculative decoding is not enabled.
             draft_token_ids = None
-        else:
-            if self.speculative_config.method in ("suffix", "ngram"):
-                draft_token_ids = self.drafter.generate_token_ids(
-                    valid_sampled_token_ids,
-                    sampling_metadata,
-                    scheduler_output,
-                    spec_decode_metadata,
-                    positions,
-                    num_scheduled_tokens,
-                    hidden_states,
-                    aux_hidden_states,
-                )
-            elif isinstance(self.drafter, MedusaProposer):
-                draft_token_ids = self.drafter.generate_token_ids(
-                    valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states
-                )
-            elif self.speculative_config.use_eagle():
-                common_attn_metadata = spec_decode_common_attn_metadata
-                sampled_token_ids = valid_sampled_token_ids
-
-                if self.vllm_config.speculative_config.disable_padded_drafter_batch:
-                    # When padded-batch is disabled, the sampled_token_ids should be
-                    # the cpu-side list[list[int]] of valid sampled tokens for each
-                    # request, with invalid requests having empty lists.
-                    assert isinstance(sampled_token_ids, list), (
-                        "sampled_token_ids should be a python list whenpadded-batch is disabled."
-                    )
-                    assert self.drafter is not None
-                    next_token_ids = self.drafter.prepare_next_token_ids_cpu(
-                        sampled_token_ids, self.requests, self.input_batch, scheduler_output.num_scheduled_tokens
-                    )
-                else:
-                    # When using padded-batch, the sampled_token_ids should be
-                    # the gpu tensor of sampled tokens for each request, of shape
-                    # (num_reqs, num_spec_tokens + 1) with rejected tokens having
-                    # value -1.
-                    assert isinstance(sampled_token_ids, torch.Tensor), (
-                        "sampled_token_ids should be a torch.Tensor whenpadded-batch is enabled."
-                    )
-                    assert self.drafter is not None
-                    next_token_ids, valid_sampled_tokens_count = self.drafter.prepare_next_token_ids_padded(
-                        common_attn_metadata,
-                        sampled_token_ids,
-                        self.requests,
-                        self.input_batch,
-                        self.discard_request_indices.gpu,
-                        self.num_discarded_requests,
-                    )
-                    self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
-
-                req_scheduled_tokens = scheduler_output.num_scheduled_tokens
-                if self.use_cp:
-                    long_seq_metadata = self.long_seq_metadata # type: ignore
-                    input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
-                    query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
-                    query_start_loc_pcp_full_cpu = self.pcp_manager.query_start_loc_pcp_full.cpu
-                    num_reqs = self.input_batch.num_reqs
-                    num_prefill_reqs = self.pcp_manager.num_prefill_reqs
-                    num_decode_reqs = self.pcp_manager.num_decode_reqs
-                else:
-                    long_seq_metadata = None # type: ignore
-                    num_prefill_reqs = 0
-                    num_decode_reqs = 0
-                if spec_decode_metadata is None:
-                    # update pcp related params
-                    if self.pcp_size > 1:
-                        token_indices_to_sample = query_start_loc_pcp_full[1 : num_reqs + 1] - 1
-                        target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
-                        target_positions = self._get_positions(num_scheduled_tokens)
-                        target_hidden_states = hidden_states
-                        if self.use_aux_hidden_state_outputs:
-                            target_hidden_states = torch.cat(
-                                [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
-                            )
-                    else:
-                        token_indices_to_sample = None
-                        # input_ids can be None for multimodal models.
-                        target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
-                        target_positions = self._get_positions(num_scheduled_tokens)
-                        if self.use_aux_hidden_state_outputs:
-                            target_hidden_states = torch.cat(
-                                [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
-                            )
-                        else:
-                            target_hidden_states = hidden_states[:num_scheduled_tokens]
-                else:
-                    if self.pcp_size > 1:
-                        assert common_attn_metadata is not None
-                        common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] = query_start_loc_pcp_full_cpu[
-                            : num_reqs + 1
-                        ]
-                        assert common_attn_metadata is not None
-                        common_attn_metadata.query_start_loc[: num_reqs + 1] = query_start_loc_pcp_full[: num_reqs + 1]
-                    if self.vllm_config.speculative_config.disable_padded_drafter_batch:
-                        # NOTE: Currently, MTP-fullgraph is incompatibility with pcp
-                        token_indices_to_sample = None
-                        assert self.drafter is not None
-                        common_attn_metadata, token_indices = self.drafter.prepare_inputs(
-                            common_attn_metadata, sampled_token_ids, spec_decode_metadata.num_draft_tokens
-                        )
-                    else:
-                        assert self.drafter is not None
-                        common_attn_metadata, token_indices, token_indices_to_sample = (
-                            self.drafter.prepare_inputs_padded(
-                                common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
-                            )
-                        )
-                    if self.pcp_size > 1:
-                        target_token_ids = input_ids_pcp_full[token_indices]
-                        target_positions = positions
-                        target_hidden_states = hidden_states
-                        if self.use_aux_hidden_state_outputs:
-                            target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
-                    else:
-                        target_token_ids = self.input_ids.gpu[token_indices]
-                        target_positions = self._get_positions(token_indices)
-                        if self.use_aux_hidden_state_outputs:
-                            target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
-                        else:
-                            target_hidden_states = hidden_states[token_indices]
-                assert self.drafter is not None
-                draft_token_ids = self.drafter._propose(
-                    target_token_ids=target_token_ids,
-                    target_positions=target_positions,
-                    target_hidden_states=target_hidden_states,
-                    next_token_ids=next_token_ids,
-                    last_token_indices=token_indices_to_sample,
-                    common_attn_metadata=common_attn_metadata,
-                    sampling_metadata=sampling_metadata,
-                    req_scheduled_tokens=req_scheduled_tokens,
-                    long_seq_metadata=long_seq_metadata,
-                    num_prefill_reqs=num_prefill_reqs,
-                    num_decode_reqs=num_decode_reqs,
-                    scheduler_output=scheduler_output,
-                    num_scheduled_tokens=num_scheduled_tokens,
-                )
-
-            else:
-                raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}")
+        elif isinstance(self.drafter, (AscendNgramProposer, AscendSuffixDecodingProposer)):
+            draft_token_ids = self.drafter.propose(valid_sampled_token_ids)
+        elif isinstance(self.drafter, AscendMedusaProposer):
+            draft_token_ids = self.drafter.propose(
+                valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states
+            )
+        elif self.speculative_config.use_eagle():
+            common_attn_metadata = spec_decode_common_attn_metadata
+            sampled_token_ids = valid_sampled_token_ids
+
+            if self.vllm_config.speculative_config.disable_padded_drafter_batch:
+                # When padded-batch is disabled, the sampled_token_ids should be
+                # the cpu-side list[list[int]] of valid sampled tokens for each
+                # request, with invalid requests having empty lists.
+                assert isinstance(sampled_token_ids, list), (
+                    "sampled_token_ids should be a python list whenpadded-batch is disabled."
+                )
+                assert self.drafter is not None
+                next_token_ids = self.drafter.prepare_next_token_ids_cpu(
+                    sampled_token_ids, self.requests, self.input_batch, scheduler_output.num_scheduled_tokens
+                )
+
+            else:
+                # When using padded-batch, the sampled_token_ids should be
+                # the gpu tensor of sampled tokens for each request, of shape
+                # (num_reqs, num_spec_tokens + 1) with rejected tokens having
+                # value -1.
+                assert isinstance(sampled_token_ids, torch.Tensor), (
+                    "sampled_token_ids should be a torch.Tensor whenpadded-batch is enabled."
+                )
+                assert self.drafter is not None
+                next_token_ids, valid_sampled_tokens_count = self.drafter.prepare_next_token_ids_padded(
+                    common_attn_metadata,
+                    sampled_token_ids,
+                    self.requests,
+                    self.input_batch,
+                    self.discard_request_indices.gpu,
+                    self.num_discarded_requests,
+                )
+                self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
+
+            req_scheduled_tokens = scheduler_output.num_scheduled_tokens
+            if self.use_cp:
+                long_seq_metadata = self.long_seq_metadata # type: ignore
+                input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
+                query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
+                query_start_loc_pcp_full_cpu = self.pcp_manager.query_start_loc_pcp_full.cpu
+                num_reqs = self.input_batch.num_reqs
+                num_prefill_reqs = self.pcp_manager.num_prefill_reqs
+                num_decode_reqs = self.pcp_manager.num_decode_reqs
+            else:
+                long_seq_metadata = None # type: ignore
+                num_prefill_reqs = 0
+                num_decode_reqs = 0
+            if spec_decode_metadata is None:
+                # update pcp related params
+                if self.pcp_size > 1:
+                    token_indices_to_sample = query_start_loc_pcp_full[1 : num_reqs + 1] - 1
+                    target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
+                    target_positions = self._get_positions(num_scheduled_tokens)
+                    target_hidden_states = hidden_states
+                    if self.use_aux_hidden_state_outputs:
+                        target_hidden_states = torch.cat([h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1)
+                else:
+                    token_indices_to_sample = None
+                    # input_ids can be None for multimodal models.
+                    target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
+                    target_positions = self._get_positions(num_scheduled_tokens)
+                    if self.use_aux_hidden_state_outputs:
+                        target_hidden_states = torch.cat([h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1)
+                    else:
+                        target_hidden_states = hidden_states[:num_scheduled_tokens]
+            else:
+                if self.pcp_size > 1:
+                    assert common_attn_metadata is not None
+                    common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] = query_start_loc_pcp_full_cpu[
+                        : num_reqs + 1
+                    ]
+                    assert common_attn_metadata is not None
+                    common_attn_metadata.query_start_loc[: num_reqs + 1] = query_start_loc_pcp_full[: num_reqs + 1]
+                if self.vllm_config.speculative_config.disable_padded_drafter_batch:
+                    # NOTE: Currently, MTP-fullgraph is incompatibility with pcp
+                    token_indices_to_sample = None
+                    assert self.drafter is not None
+                    common_attn_metadata, token_indices = self.drafter.prepare_inputs(
+                        common_attn_metadata, sampled_token_ids, spec_decode_metadata.num_draft_tokens
+                    )
+                else:
+                    assert self.drafter is not None
+                    common_attn_metadata, token_indices, token_indices_to_sample = self.drafter.prepare_inputs_padded(
+                        common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
+                    )
+                if self.pcp_size > 1:
+                    target_token_ids = input_ids_pcp_full[token_indices]
+                    target_positions = positions
+                    target_hidden_states = hidden_states
+                    if self.use_aux_hidden_state_outputs:
+                        target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
+                else:
+                    target_token_ids = self.input_ids.gpu[token_indices]
+                    target_positions = self._get_positions(token_indices)
+                    if self.use_aux_hidden_state_outputs:
+                        target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
+                    else:
+                        target_hidden_states = hidden_states[token_indices]
+            assert self.drafter is not None
+            draft_token_ids = self.drafter._propose(
+                target_token_ids=target_token_ids,
+                target_positions=target_positions,
+                target_hidden_states=target_hidden_states,
+                next_token_ids=next_token_ids,
+                last_token_indices=token_indices_to_sample,
+                common_attn_metadata=common_attn_metadata,
+                sampling_metadata=sampling_metadata,
+                req_scheduled_tokens=req_scheduled_tokens,
+                long_seq_metadata=long_seq_metadata,
+                num_prefill_reqs=num_prefill_reqs,
+                num_decode_reqs=num_decode_reqs,
+                scheduler_output=scheduler_output,
+                num_scheduled_tokens=num_scheduled_tokens,
+            )
+        else:
+            raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}")

         return draft_token_ids
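In short, the runner now dispatches on the concrete drafter type instead of the removed interface. A condensed, non-literal view of the branch structure introduced above (argument lists abbreviated):

if not self.drafter:
    draft_token_ids = None
elif isinstance(self.drafter, (AscendNgramProposer, AscendSuffixDecodingProposer)):
    draft_token_ids = self.drafter.propose(valid_sampled_token_ids)
elif isinstance(self.drafter, AscendMedusaProposer):
    draft_token_ids = self.drafter.propose(valid_sampled_token_ids, sampling_metadata,
                                           spec_decode_metadata, sample_hidden_states)
elif self.speculative_config.use_eagle():
    # prepare_next_token_ids_cpu / _padded, gather target tensors, then:
    draft_token_ids = self.drafter._propose(...)  # eagle and MTP drafts
else:
    raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}")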
@@ -1460,7 +1447,6 @@ class NPUModelRunner(GPUModelRunner):
             positions,
             scheduler_output.total_num_scheduled_tokens,
             hidden_states,
-            attn_metadata,
             aux_hidden_states,
             sample_hidden_states,
         )
@@ -2088,7 +2074,7 @@ class NPUModelRunner(GPUModelRunner):
             if kv_cache_gid > 0:
                 cm.block_table_tensor, cm.slot_mapping = _get_block_table_and_slot_mapping(kv_cache_gid)
             if self.speculative_config and spec_decode_common_attn_metadata is None:
-                if isinstance(self.drafter, EagleProposer):
+                if isinstance(self.drafter, AscendEagleProposer):
                     if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
                         spec_decode_common_attn_metadata = cm
                 else: