[Refactor][EAGLE] 2/N: load model and generate token (#5437)
### What this PR does / why we need it?
1. Refactor eagle and mtp function: load_model and generate_token_ids
2. Remove redundant code in mtp and eagle file
3. Refactor the UT of file
2/N of Refactor and merge mtp and eagle
Relational RFC: https://github.com/vllm-project/vllm-ascend/issues/5467
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
ut and tests
- vLLM version: release/v0.13.0
- vLLM main:
81786c8774
---------
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
@@ -1,25 +1,16 @@
|
||||
import importlib
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from vllm.config import (CUDAGraphMode, get_layers_from_vllm_config,
|
||||
set_current_vllm_config)
|
||||
from vllm.config import CUDAGraphMode
|
||||
from vllm.distributed import get_pcp_group
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.utils import \
|
||||
process_weights_after_loading
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
@@ -54,15 +45,6 @@ _MTP_MODELS = {
|
||||
}
|
||||
|
||||
|
||||
def _load_model(architecture):
|
||||
if architecture not in _MTP_MODELS:
|
||||
raise ValueError("Invalid architecture for mtp.")
|
||||
module_name, model_name = _MTP_MODELS[architecture]
|
||||
module = importlib.import_module(module_name)
|
||||
model = getattr(module, model_name)
|
||||
return model
|
||||
|
||||
|
||||
class MtpProposer(EagleProposer):
|
||||
|
||||
# TODO: Find out why ModelRunner does not this explicit typing?
|
||||
@@ -86,64 +68,6 @@ class MtpProposer(EagleProposer):
|
||||
update_attn_params(self.update_stream, forward_context,
|
||||
num_tokens, self.vllm_config)
|
||||
|
||||
def load_model(self, model) -> None:
|
||||
loader = get_model_loader(self.vllm_config.load_config)
|
||||
|
||||
target_attn_layer_names = set(
|
||||
get_layers_from_vllm_config(self.vllm_config,
|
||||
AttentionLayerBase).keys())
|
||||
target_indexer_layer_names = set(
|
||||
get_layers_from_vllm_config(self.vllm_config,
|
||||
DeepseekV32IndexerCache).keys())
|
||||
draft_model_config = \
|
||||
self.vllm_config.speculative_config.draft_model_config
|
||||
target_device = self.vllm_config.device_config.device
|
||||
|
||||
with set_default_torch_dtype(
|
||||
draft_model_config.dtype), set_current_vllm_config(
|
||||
self.vllm_config):
|
||||
self._init_mtp_model()
|
||||
draft_attn_layer_names = (get_layers_from_vllm_config(
|
||||
self.vllm_config, AttentionLayerBase).keys() -
|
||||
target_attn_layer_names)
|
||||
indexer_layers = get_layers_from_vllm_config(self.vllm_config,
|
||||
DeepseekV32IndexerCache)
|
||||
draft_indexer_layer_names = indexer_layers.keys(
|
||||
) - target_indexer_layer_names
|
||||
# NOTE: Currently we don't have specific attention backend and attention metadata
|
||||
# for deepseek v3.2 indexer, so we just exclude the indexer layers here.
|
||||
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
|
||||
|
||||
assert len(draft_attn_layer_names) == 1
|
||||
self.attn_layer_name = list(draft_attn_layer_names)
|
||||
|
||||
self.model.load_weights(
|
||||
loader.get_all_weights(
|
||||
self.vllm_config.speculative_config.draft_model_config,
|
||||
self.model))
|
||||
process_weights_after_loading(self.model, draft_model_config,
|
||||
target_device)
|
||||
|
||||
if self.vllm_config.model_config.is_deepseek_mla:
|
||||
# check if mtp model use main model's embedding and LMhead
|
||||
main_model = model
|
||||
if get_pp_group().world_size == 1:
|
||||
# If pp>1, the weights of mtp and the main model's embedding are not on the same device.
|
||||
if torch.equal(self.model.model.embed_tokens.weight,
|
||||
main_model.model.embed_tokens.weight):
|
||||
self.model.model.embed_tokens = main_model.model.embed_tokens
|
||||
for _, layer_module in self.model.model.layers.items():
|
||||
if torch.equal(layer_module.shared_head.head.weight,
|
||||
main_model.lm_head.weight):
|
||||
layer_module.shared_head.head = main_model.lm_head
|
||||
|
||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
||||
):
|
||||
self.update_stream: torch.npu.Stream = torch.npu.Stream()
|
||||
self.model = ACLGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(self,
|
||||
num_tokens: int,
|
||||
@@ -256,153 +180,6 @@ class MtpProposer(EagleProposer):
|
||||
if with_prefill:
|
||||
break
|
||||
|
||||
def generate_token_ids(self,
|
||||
sampled_token_ids: torch.Tensor | list[list[int]],
|
||||
sampling_metadata: SamplingMetadata = None,
|
||||
scheduler_output: SchedulerOutput = None,
|
||||
spec_decode_metadata: SpecDecodeMetadata = None,
|
||||
positions: torch.Tensor = None,
|
||||
num_scheduled_tokens: int = 0,
|
||||
hidden_states: torch.Tensor = None,
|
||||
aux_hidden_states: torch.Tensor = None):
|
||||
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
|
||||
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
# When padded-batch is disabled, the sampled_token_ids should be
|
||||
# the cpu-side list[list[int]] of valid sampled tokens for each
|
||||
# request, with invalid requests having empty lists.
|
||||
assert isinstance(sampled_token_ids, list), \
|
||||
"sampled_token_ids should be a python list when" \
|
||||
"padded-batch is disabled."
|
||||
next_token_ids = self.prepare_next_token_ids_cpu(
|
||||
sampled_token_ids, self.runner.requests,
|
||||
self.runner.input_batch, scheduler_output.num_scheduled_tokens)
|
||||
else:
|
||||
# When using padded-batch, the sampled_token_ids should be
|
||||
# the gpu tensor of sampled tokens for each request, of shape
|
||||
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
|
||||
# value -1.
|
||||
assert isinstance(sampled_token_ids, torch.Tensor), \
|
||||
"sampled_token_ids should be a torch.Tensor when" \
|
||||
"padded-batch is enabled."
|
||||
next_token_ids, valid_sampled_tokens_count = \
|
||||
self.prepare_next_token_ids_padded(
|
||||
common_attn_metadata,
|
||||
sampled_token_ids,
|
||||
self.runner.requests,
|
||||
self.runner.input_batch,
|
||||
self.runner.discard_request_indices.gpu,
|
||||
self.runner.num_discarded_requests
|
||||
)
|
||||
self._copy_valid_sampled_token_count(next_token_ids,
|
||||
valid_sampled_tokens_count)
|
||||
|
||||
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
long_seq_metadata = self.runner.long_seq_metadata
|
||||
input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
|
||||
query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
|
||||
query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
|
||||
num_reqs = self.runner.input_batch.num_reqs
|
||||
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
||||
query_start_loc_pcp_full_cpu[:num_reqs]
|
||||
num_prefill_reqs = (ori_query_lens
|
||||
> self.decode_threshold).sum().item()
|
||||
num_decode_reqs = num_reqs - num_prefill_reqs
|
||||
else:
|
||||
long_seq_metadata = None
|
||||
num_prefill_reqs = 0
|
||||
num_decode_reqs = 0
|
||||
if spec_decode_metadata is None:
|
||||
# update pcp related params
|
||||
if self.pcp_size > 1:
|
||||
token_indices_to_sample = \
|
||||
query_start_loc_pcp_full[1:num_reqs + 1] - 1
|
||||
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states
|
||||
else:
|
||||
token_indices_to_sample = None
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.runner.input_ids.gpu[:
|
||||
num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
else:
|
||||
if self.pcp_size > 1:
|
||||
common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \
|
||||
query_start_loc_pcp_full_cpu[:num_reqs + 1]
|
||||
common_attn_metadata.query_start_loc[:num_reqs + 1] = \
|
||||
query_start_loc_pcp_full[:num_reqs + 1]
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
token_indices_to_sample = None
|
||||
common_attn_metadata, token_indices =\
|
||||
self._prepare_inputs(
|
||||
common_attn_metadata,
|
||||
sampled_token_ids,
|
||||
spec_decode_metadata.num_draft_tokens)
|
||||
else:
|
||||
common_attn_metadata, token_indices, \
|
||||
token_indices_to_sample =\
|
||||
self.prepare_inputs_padded(
|
||||
common_attn_metadata,
|
||||
spec_decode_metadata,
|
||||
valid_sampled_tokens_count)
|
||||
if self.pcp_size > 1:
|
||||
target_token_ids = input_ids_pcp_full[token_indices]
|
||||
target_positions = positions
|
||||
target_hidden_states = hidden_states
|
||||
else:
|
||||
target_token_ids = self.runner.input_ids.gpu[token_indices]
|
||||
target_positions = positions[token_indices]
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
|
||||
draft_token_ids = self._propose(
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
next_token_ids=next_token_ids,
|
||||
last_token_indices=token_indices_to_sample,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata,
|
||||
req_scheduled_tokens=req_scheduled_tokens,
|
||||
long_seq_metadata=long_seq_metadata,
|
||||
num_prefill_reqs=num_prefill_reqs,
|
||||
num_decode_reqs=num_decode_reqs,
|
||||
scheduler_output=scheduler_output,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
)
|
||||
|
||||
return draft_token_ids
|
||||
|
||||
def _copy_valid_sampled_token_count(
|
||||
self, next_token_ids: torch.Tensor,
|
||||
valid_sampled_tokens_count: torch.Tensor) -> None:
|
||||
if self.runner.valid_sampled_token_count_event is not None:
|
||||
default_stream = torch.npu.current_stream()
|
||||
# initialize a new stream to overlap the copy operation with
|
||||
# prepare_input of draft model.
|
||||
with torch.npu.stream(
|
||||
self.runner.valid_sampled_token_count_copy_stream):
|
||||
self.runner.valid_sampled_token_count_copy_stream.wait_stream(
|
||||
default_stream) # type: ignore
|
||||
self.runner.valid_sampled_token_count_cpu[:
|
||||
valid_sampled_tokens_count
|
||||
.shape[0]].copy_(
|
||||
valid_sampled_tokens_count,
|
||||
non_blocking=True
|
||||
)
|
||||
self.runner.valid_sampled_token_count_event.record()
|
||||
|
||||
self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
|
||||
1)
|
||||
|
||||
def _init_mtp_model(self):
|
||||
architecture = self.vllm_config.model_config.architecture
|
||||
target_device = self.vllm_config.device_config.device
|
||||
model = _load_model(architecture)
|
||||
self.model = model(vllm_config=self.vllm_config).to(target_device)
|
||||
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
|
||||
Reference in New Issue
Block a user