[Refactor][EAGLE] 1/N delete __init__ in mtp_proposer (#5176)
### What this PR does / why we need it?
This PR aims to refactor the eagle-related modules in vllm-ascend and is
the starting PR of the eagle refactoring. Given vllm-eagle, ascend-eagle,
and ascend-mtp, we first let ascend-mtp inherit from ascend-eagle and let
ascend-eagle inherit from vllm-eagle. As an initial step, we just delete
`__init__` in mtp_proposer and simplify the corresponding logic in
eagle_proposer; see the sketch after this paragraph.
Based on the chain "vllm-eagle <----- ascend-eagle <----- ascend-mtp", our
goal is to gradually delete ascend-mtp and let ascend-eagle converge to
vllm-eagle, so the main workspace is eagle_proposer. In this way, we hope
that contributors can refactor eagle concurrently.
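A minimal sketch of the resulting inheritance chain. The class names below are hypothetical stand-ins; the real classes live in `vllm/v1/spec_decode/eagle.py`, `vllm_ascend/spec_decode/eagle_proposer.py`, and `vllm_ascend/spec_decode/mtp_proposer.py` (see the diff below):

```python
class VllmEagleProposer:  # stands in for vllm.v1.spec_decode.eagle.EagleProposer
    def __init__(self, vllm_config, device, runner=None):
        # common speculative-decode state lives in the vLLM base class
        self.vllm_config = vllm_config
        self.device = device
        self.runner = runner


class AscendEagleProposer(VllmEagleProposer):  # ascend-eagle
    def __init__(self, vllm_config, device, runner=None):
        super().__init__(vllm_config, device, runner)
        # only Ascend-specific state (aclgraph, pcp/dcp, ...) is added here


class AscendMtpProposer(AscendEagleProposer):  # ascend-mtp
    # `__init__` is deleted in this PR, so the inherited one is used;
    # the remaining overrides (load_model, dummy_run, propose) go next.
    pass
```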
Incoming changes:
1. delete the methods duplicated across vllm-eagle, ascend-eagle, and ascend-mtp
2. delete `load_model` in mtp_proposer
3. delete `dummy_run` and `propose` in mtp_proposer
4. ...
RFC: #5467
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
By CI.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Zetong Li <slippersss@126.com>
```diff
--- a/tests/ut/spec_decode/test_eagle_proposer.py
+++ b/tests/ut/spec_decode/test_eagle_proposer.py
@@ -5,6 +5,7 @@ import torch
 from vllm.config import CacheConfig, CompilationMode, CUDAGraphMode, VllmConfig
 
 from tests.ut.base import TestBase
+from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
 from vllm_ascend.spec_decode.interface import SpecDcodeType
 
@@ -25,13 +26,24 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.scheduler_config.max_num_seqs = 32
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
+        self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.speculative_config.num_speculative_tokens = 2
+        self.vllm_config.speculative_config.speculative_token_tree = str([
+            (i + 1) * (0, ) for i in range(2)
+        ])
+        self.vllm_config.additional_config = None
 
         self.mock_cpugpubuffer = patch(
-            "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
+            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
         self.mock_cpugpubuffer.start()
+        self.mock_supports_multimodal_inputs = patch(
+            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
+        )
+        self.mock_supports_multimodal_inputs.start()
 
     def tearDown(self):
         self.mock_cpugpubuffer.stop()
+        self.mock_supports_multimodal_inputs.stop()
 
     def test_initialization_eagle_graph(self):
         self.vllm_config.speculative_config.method = "eagle"
@@ -40,12 +52,12 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.model_config.enforce_eager = False
         self.vllm_config.speculative_config.enforce_eager = False
         self.vllm_config.scheduler_config.async_scheduling = False
+        init_ascend_config(self.vllm_config)
 
         proposer = EagleProposer(vllm_config=self.vllm_config,
                                  device=self.device,
                                  runner=self.runner)
 
-        self.assertEqual(proposer.name, SpecDcodeType.EAGLE)
         self.assertEqual(proposer.block_size, 16)
         self.assertEqual(proposer.hidden_size, 4096)
         self.assertTrue(proposer.use_cuda_graph)
@@ -60,12 +72,12 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048
         self.vllm_config.compilation_config.mode = CompilationMode.NONE
         self.vllm_config.model_config.enforce_eager = True
+        init_ascend_config(self.vllm_config)
 
         proposer = EagleProposer(vllm_config=self.vllm_config,
                                  device=self.device,
                                  runner=self.runner)
 
-        self.assertEqual(proposer.name, SpecDcodeType.EAGLE3)
         self.assertEqual(proposer.hidden_size, 2048)
         self.assertFalse(proposer.use_cuda_graph)
         self.assertEqual(proposer.hidden_states.shape, (1024, 2048))
@@ -77,12 +89,12 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.model_config.enforce_eager = False
         self.vllm_config.speculative_config.enforce_eager = False
         self.vllm_config.scheduler_config.async_scheduling = True
+        init_ascend_config(self.vllm_config)
 
         proposer = EagleProposer(vllm_config=self.vllm_config,
                                  device=self.device,
                                  runner=self.runner)
 
-        self.assertEqual(proposer.name, SpecDcodeType.EAGLE3)
         self.assertEqual(proposer.hidden_size, 2048)
         self.assertFalse(proposer.use_cuda_graph)
         self.assertEqual(proposer.hidden_states.shape, (1024, 2048))
@@ -102,16 +114,28 @@ class TestEagleProposerLoadModel(TestBase):
         self.vllm_config.scheduler_config.max_num_seqs = 32
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
+        self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.speculative_config.num_speculative_tokens = 2
+        self.vllm_config.speculative_config.speculative_token_tree = str([
+            (i + 1) * (0, ) for i in range(2)
+        ])
+        self.vllm_config.additional_config = None
+        init_ascend_config(self.vllm_config)
 
         self.mock_cpugpubuffer = patch(
-            "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
+            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
         self.mock_cpugpubuffer.start()
+        self.mock_supports_multimodal_inputs = patch(
+            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
+        )
+        self.mock_supports_multimodal_inputs.start()
         self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                       device=self.device,
                                       runner=self.runner)
 
     def tearDown(self):
         self.mock_cpugpubuffer.stop()
+        self.mock_supports_multimodal_inputs.stop()
 
     @patch(
         "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
@@ -204,10 +228,20 @@ class TestEagleProposerDummyRun(TestBase):
         self.vllm_config.scheduler_config.max_num_seqs = 32
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
+        self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.speculative_config.speculative_token_tree = str([
+            (i + 1) * (0, ) for i in range(4)
+        ])
+        self.vllm_config.additional_config = None
+        init_ascend_config(self.vllm_config)
 
         self.mock_cpugpubuffer = patch(
-            "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
+            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
         self.mock_cpugpubuffer.start()
+        self.mock_supports_multimodal_inputs = patch(
+            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
+        )
+        self.mock_supports_multimodal_inputs.start()
         self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                       device=self.device,
                                       runner=self.runner)
@@ -216,6 +250,7 @@ class TestEagleProposerDummyRun(TestBase):
 
     def tearDown(self):
         self.mock_cpugpubuffer.stop()
+        self.mock_supports_multimodal_inputs.stop()
 
     @patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
     @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
@@ -287,16 +322,28 @@ class TestEagleProposerGenerateTokenIds(TestBase):
             1: MagicMock(get_token_id=lambda x: 101),
             2: MagicMock(get_token_id=lambda x: 102),
         }
+        self.runner.pcp_size = 1
 
         self.vllm_config.cache_config.block_size = 16
         self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
         self.vllm_config.scheduler_config.max_num_seqs = 32
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
+        self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.speculative_config.num_speculative_tokens = 2
+        self.vllm_config.speculative_config.speculative_token_tree = str([
+            (i + 1) * (0, ) for i in range(2)
+        ])
+        self.vllm_config.additional_config = None
+        init_ascend_config(self.vllm_config)
 
         self.mock_cpugpubuffer = patch(
-            "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
+            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
         self.mock_cpugpubuffer.start()
+        self.mock_supports_multimodal_inputs = patch(
+            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
+        )
+        self.mock_supports_multimodal_inputs.start()
         self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                       device=self.device,
                                       runner=self.runner)
@@ -306,6 +353,7 @@ class TestEagleProposerGenerateTokenIds(TestBase):
 
     def tearDown(self):
         self.mock_cpugpubuffer.stop()
+        self.mock_supports_multimodal_inputs.stop()
 
     # TODO: This is equivalent to disable_padded_drafter_batch=True.
     # We need to add some cases about disable_padded_drafter_batch=False in future.
@@ -355,16 +403,28 @@ class TestEagleProposerHelperMethods(TestBase):
         self.vllm_config.scheduler_config.max_num_seqs = 32
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
+        self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.speculative_config.num_speculative_tokens = 2
+        self.vllm_config.speculative_config.speculative_token_tree = str([
+            (i + 1) * (0, ) for i in range(2)
+        ])
+        self.vllm_config.additional_config = None
+        init_ascend_config(self.vllm_config)
 
         self.mock_cpugpubuffer = patch(
-            "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer")
+            "vllm.v1.spec_decode.eagle.CpuGpuBuffer")
         self.mock_cpugpubuffer.start()
+        self.mock_supports_multimodal_inputs = patch(
+            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
+        )
+        self.mock_supports_multimodal_inputs.start()
         self.proposer = EagleProposer(vllm_config=self.vllm_config,
                                       device=self.device,
                                       runner=self.runner)
 
     def tearDown(self):
         self.mock_cpugpubuffer.stop()
+        self.mock_supports_multimodal_inputs.stop()
 
     # TODO: This is equivalent to disable_padded_drafter_batch=True.
     # We need to add a test_prepare_inputs_padded in future.
--- a/tests/ut/spec_decode/test_mtp_proposer.py
+++ b/tests/ut/spec_decode/test_mtp_proposer.py
@@ -16,12 +16,18 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
-from vllm_ascend.spec_decode.interface import SpecDcodeType
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 
 
 class TestMtpProposer:
 
+    @pytest.fixture(autouse=True)
+    def patch_supports_multimodal_inputs(self):
+        with patch(
+                "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
+        ):
+            yield
+
     @pytest.fixture
     def vllm_config(self):
         config = MagicMock(spec=VllmConfig)
@@ -31,6 +37,9 @@ class TestMtpProposer:
         config.speculative_config.method = "deepseek_mtp"
         config.speculative_config.draft_model_config = MagicMock()
         config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
+        config.speculative_config.speculative_token_tree = str([
+            (i + 1) * (0, ) for i in range(2)
+        ])
 
         config.model_config = MagicMock(spec=ModelConfig)
         config.model_config.dtype = torch.float16
@@ -68,7 +77,7 @@ class TestMtpProposer:
         runner.reserved_mc2_mask = None
         return runner
 
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_init(self, mock_cpu_gpu_buffer, vllm_config, runner):
         mock_buffer_instance = MagicMock()
         mock_cpu_gpu_buffer.return_value = mock_buffer_instance
@@ -76,7 +85,6 @@ class TestMtpProposer:
         # Test basic initialization
         proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
 
-        assert proposer.name == SpecDcodeType.MTP
         assert proposer.vllm_config == vllm_config
         assert proposer.device == torch.device("cpu")
         assert proposer.dtype == torch.float16
@@ -89,7 +97,7 @@ class TestMtpProposer:
         assert not hasattr(proposer, "mrope_positions")
         assert proposer.use_sparse is False
 
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_init_with_aclgraph(self, mock_cpu_gpu_buffer, vllm_config,
                                 runner):
         mock_buffer_instance = MagicMock()
@@ -105,7 +113,7 @@ class TestMtpProposer:
         "vllm_ascend.spec_decode.mtp_proposer.process_weights_after_loading")
     @patch("vllm_ascend.spec_decode.mtp_proposer.set_default_torch_dtype")
     @patch("vllm_ascend.spec_decode.mtp_proposer.set_current_vllm_config")
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_load_model(self, mock_cpu_gpu_buffer, mock_set_config,
                         mock_set_dtype, mock_process_weights, mock_get_loader,
                         mock_get_layers, vllm_config, runner):
@@ -148,7 +156,7 @@ class TestMtpProposer:
 
     @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
     @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_dummy_run(self, mock_cpu_gpu_buffer, mock_set_context,
                        mock_get_forward_context, vllm_config, runner):
         mock_buffer_instance = MagicMock()
@@ -173,7 +181,7 @@ class TestMtpProposer:
 
     @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
     @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_dummy_run_full_graph(self, mock_cpu_gpu_buffer, mock_set_context,
                                   mock_get_forward_context, vllm_config,
                                   runner):
@@ -201,7 +209,7 @@ class TestMtpProposer:
         # Check that model was called correct number of times
         assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
 
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_generate_token_ids(self, mock_cpu_gpu_buffer):
         mock_buffer_instance = MagicMock()
         mock_cpu_gpu_buffer.return_value = mock_buffer_instance
@@ -272,7 +280,7 @@ class TestMtpProposer:
         proposer._propose.assert_called_once()
         assert torch.equal(draft_token_ids, proposer._propose.return_value)
 
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer):
         mock_buffer_instance = MagicMock()
         mock_cpu_gpu_buffer.return_value = mock_buffer_instance
@@ -295,7 +303,7 @@ class TestMtpProposer:
         assert torch.all(
             result == torch.tensor([30, 50, 60], dtype=torch.int32))
 
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_prepare_next_token_ids_padded(self, mock_cpu_gpu_buffer):
         mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata)
         mock_common_attn_metadata.seq_lens_cpu = torch.tensor(
@@ -377,7 +385,7 @@ class TestMtpProposer:
             device=torch.device("cpu"))
         assert torch.equal(next_token_ids, expected_next_tokens)
 
-    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
     def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer):
         mock_buffer_instance = MagicMock()
         mock_cpu_gpu_buffer.return_value = mock_buffer_instance
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -18,8 +18,8 @@ from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.eagle import EagleProposer as VllmEagleProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
@@ -29,7 +29,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                update_attn_params)
 from vllm_ascend.ops.rotary_embedding import update_cos_sin
-from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+from vllm_ascend.utils import shared_expert_dp_enabled
 
 PADDING_SLOT_ID = -1
 
@@ -38,29 +38,15 @@ _DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
 _FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
 
 
-class EagleProposer(Proposer):
+class EagleProposer(VllmEagleProposer):
 
     def __init__(self,
                  vllm_config: VllmConfig,
                  device: torch.device,
                  runner=None):
-        self.name = SpecDcodeType.EAGLE if vllm_config.speculative_config.method == "eagle" else SpecDcodeType.EAGLE3
-        self.vllm_config = vllm_config
-        self.device = device
-        self.runner = runner
-        self.speculative_config = vllm_config.speculative_config
-        self.draft_model_config = self.speculative_config.draft_model_config
-        self.method = self.speculative_config.method
-        self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
+        super().__init__(vllm_config, device, runner)
         self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
 
-        self.block_size = vllm_config.cache_config.block_size
-        # We need to get the hidden size from the draft model config because
-        # the draft model's hidden size can be different from the target model's
-        # hidden size (e.g., Llama 3.3 70B).
-        self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size(
-        )
-
         # there is synchronization between mtp steps when enabling aclgraph,
         # disable aclgraph when use async scheduling to avoid the
         # synchronization overhead.
@@ -77,45 +63,28 @@ class EagleProposer(Proposer):
             sorted(
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
 
-        max_batch_size = vllm_config.scheduler_config.max_num_seqs
-        # Currently we do not use pcp. This is used to adapt the pcp branch.
-        self.pcp_size = 0
-        self.backup_next_token_ids = CpuGpuBuffer(
-            max_batch_size,
-            dtype=torch.int32,
-            pin_memory=is_pin_memory_available(),
-            device=device,
-            with_numpy=True,
-        )
+        self.pcp_size = self.runner.pcp_size
         self.decode_threshold = 1 + self.num_speculative_tokens
 
-        # persistent buffers for cuda graph
-        self.input_ids = torch.zeros(
-            self.vllm_config.scheduler_config.max_num_batched_tokens,
-            dtype=torch.int32,
-            device=device)
-        self.positions = torch.zeros(
-            self.vllm_config.scheduler_config.max_num_batched_tokens,
-            dtype=torch.int64,
-            device=device)
-        self.hidden_states = torch.zeros(
-            (self.vllm_config.scheduler_config.max_num_batched_tokens,
-             self.hidden_size),
-            dtype=self.vllm_config.model_config.dtype,
-            device=device)
-        self.max_num_tokens = (
-            vllm_config.scheduler_config.max_num_batched_tokens)
-        self.token_arange_np = np.arange(self.max_num_tokens)
-        max_num_slots_for_arange = max(self.max_num_tokens, max_batch_size + 1)
-        self.arange = torch.arange(max_num_slots_for_arange,
-                                   device=device,
-                                   dtype=torch.int32)
-        self.arange_cpu = torch.arange(max_num_slots_for_arange,
+        self.arange_cpu = torch.arange(self.arange.shape[0],
                                        device="cpu",
                                        dtype=torch.int32)
         self.attn_mask_builder = AttentionMaskBuilder(self.device)
-        self.eagle3_use_aux_hidden_state: bool = (
-            self._get_eagle3_use_aux_hidden_state_from_config())
+        self.enable_shared_expert_dp = shared_expert_dp_enabled()
+
+        self.dcp_size = self.runner.dcp_size
+        self.pcp_rank = self.runner.pcp_rank
+        self.dcp_rank = self.runner.dcp_rank
+
+        self.use_aclgraph = self.runner._use_aclgraph()
+
+        self.full_indices = range(
+            self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
+            self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
+
+        self.use_sparse = hasattr(vllm_config.model_config.hf_config,
+                                  "index_topk")
 
     def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
         """
@@ -165,7 +134,7 @@ class EagleProposer(Proposer):
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.name == SpecDcodeType.EAGLE and hasattr(model, "lm_head"):
+        if self.method == "eagle" and hasattr(model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
             if supports_multimodal(model):
                 self.model.lm_head = model.get_language_model().lm_head
@@ -337,7 +306,7 @@ class EagleProposer(Proposer):
             target_token_ids = self.runner.input_ids.gpu[:
                                                          num_scheduled_tokens]
             target_positions = positions[:num_scheduled_tokens]
-            if self.name == SpecDcodeType.EAGLE3:
+            if self.method == "eagle3":
                 target_hidden_states = torch.cat(
                     [h[:num_scheduled_tokens] for h in aux_hidden_states],
                     dim=-1)
@@ -371,7 +340,7 @@ class EagleProposer(Proposer):
         else:
             target_token_ids = self.runner.input_ids.gpu[token_indices]
             target_positions = positions[token_indices]
-            if self.name == SpecDcodeType.EAGLE3:
+            if self.method == "eagle3":
                 target_hidden_states = torch.cat(
                     [h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
@@ -424,7 +393,7 @@ class EagleProposer(Proposer):
         if last_token_indices is None:
             last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
 
-        if self.name == SpecDcodeType.EAGLE3:
+        if self.method == "eagle3":
             assert isinstance(self.get_model(), Eagle3LlamaForCausalLM)
             target_hidden_states = self.model.combine_hidden_states(
                 target_hidden_states)
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -5,8 +5,8 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from vllm.config import (CUDAGraphMode, VllmConfig,
-                         get_layers_from_vllm_config, set_current_vllm_config)
+from vllm.config import (CUDAGraphMode, get_layers_from_vllm_config,
+                         set_current_vllm_config)
 from vllm.distributed import get_pcp_group
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
@@ -20,12 +20,10 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.utils.math_utils import cdiv
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.utils.torch_utils import set_default_torch_dtype
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
@@ -35,9 +33,8 @@ from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
-from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
-from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
-                               shared_expert_dp_enabled)
+from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
+from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
 
 logger = init_logger(__name__)
 
@@ -64,102 +61,11 @@ def _load_model(architecture):
     return model
 
 
-class MtpProposer(Proposer):
+class MtpProposer(EagleProposer):
 
     # TODO: Find out why ModelRunner does not this explicit typing?
     model: Union[nn.Module, ACLGraphWrapper]
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        device,
-        runner,
-    ):
-        self.name = SpecDcodeType.MTP
-        self.vllm_config = vllm_config
-        self.speculative_config = vllm_config.speculative_config
-        assert self.speculative_config is not None
-        self.draft_model_config = self.speculative_config.draft_model_config
-        self.method = self.speculative_config.method
-
-        self.runner = runner
-        self.device = device
-        self.dtype = vllm_config.model_config.dtype
-        self.max_model_len = vllm_config.model_config.max_model_len
-        self.block_size = vllm_config.cache_config.block_size
-        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
-        self.decode_threshold = 1 + self.num_speculative_tokens
-        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-        self.token_arange_np = np.arange(self.max_num_tokens)
-        # We need to get the hidden size from the draft model config because
-        # the draft model's hidden size can be different from the target model's
-        # hidden size (e.g., Llama 3.3 70B).
-        self.hidden_size = self.draft_model_config.get_hidden_size()
-        self.enable_shared_expert_dp = shared_expert_dp_enabled()
-
-        self.pcp_size = self.runner.pcp_size
-        self.dcp_size = self.runner.dcp_size
-        self.pcp_rank = self.runner.pcp_rank
-        self.dcp_rank = self.runner.dcp_rank
-
-        self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
-        self.draft_indexer_metadata_builder: Optional[
-            AttentionMetadataBuilder] = None
-        self.attn_layer_names: list[str] = []
-        self.indexer_layer_names: list[str] = []
-
-        self.use_aclgraph = self.runner._use_aclgraph()
-
-        # persistent buffers for aclgraph graph
-        self.input_ids = torch.zeros(self.max_num_tokens,
-                                     dtype=torch.int32,
-                                     device=device)
-        self.uses_mrope = self.vllm_config.model_config.uses_mrope
-        if self.uses_mrope:
-            # M-RoPE need (3, max_num_tokens)
-            self.mrope_positions = torch.zeros((3, self.max_num_tokens),
-                                               dtype=torch.int64,
-                                               device=device)
-        else:
-            # RoPE need (max_num_tokens,)
-            self.positions = torch.zeros(self.max_num_tokens,
-                                         dtype=torch.int64,
-                                         device=device)
-        self.hidden_states = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=device)
-        self.full_indices = range(
-            self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
-            self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
-
-        # We need +1 here because the arange is used to set query_start_loc,
-        # which has one more element than batch_size.
-        max_batch_size = vllm_config.scheduler_config.max_num_seqs
-        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
-        self.arange = torch.arange(max_num_slots_for_arange,
-                                   device=device,
-                                   dtype=torch.int32)
-        self.arange_cpu = torch.arange(max_num_slots_for_arange,
-                                       device="cpu",
-                                       dtype=torch.int32)
-
-        self.inputs_embeds = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=device)
-
-        self.backup_next_token_ids = CpuGpuBuffer(
-            max_batch_size,
-            dtype=torch.int32,
-            pin_memory=is_pin_memory_available(),
-            device=device,
-            with_numpy=True,
-        )
-        self.use_sparse = hasattr(vllm_config.model_config.hf_config,
-                                  "index_topk")
-        self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
-
     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
 
```