# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from copy import copy
from typing import TYPE_CHECKING, Dict, Optional, List, Tuple, Any

import torch
import numpy as np
import cnpx

from vllm.distributed.parallel_state import (
    get_tp_group, get_pp_group, get_tensor_model_parallel_rank)
from vllm.distributed.kv_transfer import has_kv_transfer_group, get_kv_transfer_group
from vllm.distributed import (
    divide, get_moe_expert_parallel_world_size
)
from vllm.config import VllmConfig, CUDAGraphMode
from vllm.forward_context import set_forward_context, BatchDescriptor
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.gpu_model_runner import ExecuteModelState

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput, GrammarOutput

from vllm_mlu.v1.attention.backends.flash_attn import pad_attn_metadata
from vllm_mlu.v1.attention.backends.utils import (
    MLUCommonAttentionMetadata, unpad_common_attn_metadata,
    get_common_metadata_from_attn_metadata, MLUInferMode)
from vllm_mlu.v1.worker.gpu_model_runner import (
    MLUModelRunner, AsyncMLUModelRunnerOutput, apply_grammar_bitmask)
from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.models.dp_utils import (
    DataParallelRuntimeParams,
    enable_emb_logits_custom_parallel,
    get_runtime_infos_per_dp_group,
    get_deepseek_layer_split_list,
)
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm_mlu.distributed.parallel_state import (
    init_cnclep, get_cnclep
)

from vllm_mlu._mlu_utils import *
import vllm_mlu._mlu_utils as mlu_envs

logger = init_logger(__name__)


class DPMLUModelRunner(MLUModelRunner):

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        vllm_config.mlu_config.enable_custom_data_parallel_opt = True
        super().__init__(vllm_config, device)
        self.use_cuda_graph = (
            self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and not self.model_config.enforce_eager)
        if not self.use_cuda_graph and not self.model_config.enforce_eager:
            logger.warning(
                "Cannot use cudagraph for the DP MLU model runner; it only "
                "supports cudagraph_mode CUDAGraphMode.FULL_DECODE_ONLY.")
        self.use_all2all = self.mlu_config.decode_dispatch_combine_use_all2all
        if self.use_all2all:
            assert get_moe_expert_parallel_world_size() > 1, (
                "all2all requires that expert parallel is enabled")
            kwargs = self.make_cnclep_kwargs()
            init_cnclep(**kwargs)
            if self.model_config.is_longcat_flash:
                # LongCat-Flash also needs an unquantized (bf16) dispatch
                # communicator in addition to the default quantized one.
                kwargs_bf16 = self.make_cnclep_kwargs(use_quant_dispatch=False)
                init_cnclep(**kwargs_bf16)
        self.dp_metadata = None
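
    # Every DP rank must agree on the global token/batch split before the
    # forward pass so that MoE dispatch/combine buffers and the logits split
    # are sized consistently across ranks; get_runtime_infos_per_dp_group
    # performs that cross-DP exchange.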
    def _get_data_parallel_metadata(
        self,
        num_tokens: int,
        num_reqs: int,
        is_decode_only: bool,
        query_len_per_batch: Optional[List[int]],
    ) -> "MLUDPMetadata":
        (dp_query_lens, dp_group_bs, dp_is_prefill,
         seq_len_per_batch) = get_runtime_infos_per_dp_group(
             num_tokens,
             num_reqs,
             not is_decode_only,
             query_len_per_batch,
             self.device,
             self.vllm_config,
         )
        (emb_query_lens, logits_batch_sizes,
         dense_attn_token_split_list) = get_deepseek_layer_split_list(
             dp_query_lens,
             dp_group_bs,
         )
        return MLUDPMetadata.make_oot(
            self.parallel_config.data_parallel_rank,
            self.parallel_config.data_parallel_size,
            self.parallel_config.tensor_parallel_size,
            dp_query_lens,
            dp_is_prefill,
            self.vllm_config.mlu_config.prefill_dispatch_use_RS_AG,
            seq_lens=(seq_len_per_batch if all(dp_is_prefill) else None),
            batch_sizes=dp_group_bs,
            emb_query_lens=emb_query_lens,
            logits_batch_sizes=logits_batch_sizes,
            dense_attn_token_split_list=dense_attn_token_split_list,
        )

    def _get_dp_graph_info(self,
                           K: int,
                           num_scheduled_tokens: int,
                           dp_metadata: "MLUDPMetadata"):
        """
        Decide whether the DeepSeek model can enter graph mode, and compute
        the padded input token and batch counts to use.

        This also applies to other eligible MoE models with DP enabled,
        reusing the same graph mode compatibility logic.

        Returns:
            tuple: Contains three elements:
                num_input_tokens: padded number of input tokens
                num_input_batches: padded number of input batches
                use_graph: whether the model can use graph mode
        """
        if (self.use_cuda_graph
                and all(not prefill for prefill in dp_metadata.dp_is_prefill)
                and all(token_num <= self.cudagraph_batch_sizes[-1]
                        for token_num in dp_metadata.token_split_list)):
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                max(dp_metadata.token_split_list))
            assert num_input_tokens % (K + 1) == 0, \
                f"num_input_tokens ({num_input_tokens}) must be divisible by (K + 1) = {K + 1}"
            num_input_batches = num_input_tokens // (1 + K)
            use_graph = True
        else:
            num_input_batches = self.input_batch.num_reqs
            num_input_tokens = num_scheduled_tokens
            use_graph = False
        return num_input_tokens, num_input_batches, use_graph
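
    # Illustration of the padding above (hypothetical numbers): with K = 1
    # speculative token per request and token_split_list = [6, 10] across DP
    # ranks, graph mode pads every rank to pad_for_cudagraph(10); if that
    # returns 12, then num_input_tokens = 12 and num_input_batches = 12 // 2 = 6.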

    @torch.inference_mode()
    def moe_dp_execute_dummy_batch(
        self, num_tokens: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        max_num_reqs = self.scheduler_config.max_num_seqs
        num_reqs = min(num_tokens, max_num_reqs)
        min_tokens_per_req = num_tokens // num_reqs
        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs
        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)

        # MUST do comm across dp group first when data parallel is enabled.
        # Here we set the dummy run state as prefill-only to prevent other dp
        # groups from using graph mode. Pass the actual per-request split
        # (including the remainder on the last request) so it sums to
        # num_tokens.
        dp_metadata = self._get_data_parallel_metadata(
            num_tokens, num_reqs, False, num_scheduled_tokens_list
        )

        # always skip attn compute
        attn_metadata: Optional[Dict[str, Any]] = None

        input_ids = self.input_ids.gpu[:num_tokens]
        positions = self.positions.gpu[:num_tokens]
        with self.maybe_randomize_inputs(input_ids), set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_tokens,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                batch_descriptor=None):
            hidden_states = self._model_forward(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=None,
                inputs_embeds=None,
                dp_params=dp_metadata,
            )

        kwargs = ({"dp_params": dp_metadata}
                  if enable_emb_logits_custom_parallel() else {})
        self.model.compute_logits(
            hidden_states[:num_tokens], **kwargs)

        if self.speculative_config and self.speculative_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            target_token_ids = input_ids
            target_positions = positions
            # hidden_states does not need to be sliced
            target_hidden_states = hidden_states
            self.drafter.propose_ds_execute_dummy_batch(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                dp_params=dp_metadata)
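
        # cumsum - 1 yields, for each request, the index of its last
        # scheduled token: the position logits would be sampled from.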
        logit_indices = np.cumsum(num_scheduled_tokens) - 1
        logit_indices_device = torch.from_numpy(logit_indices).to(
            self.device, non_blocking=True
        )
        return hidden_states, hidden_states[logit_indices_device]

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: IntermediateTensors | None = None,
    ) -> ModelRunnerOutput | IntermediateTensors | None:
        if self.execute_model_state is not None:
            raise RuntimeError(
                "State error: sample_tokens() must be called "
                "after execute_model() returns None."
            )

        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        with record_function_or_nullcontext("dp_gpu_model_runner: preprocess"):
            with self.synchronize_input_prep():
                # Update persistent batch states.
                self._update_states(scheduler_output)

            if not num_scheduled_tokens:
                if not has_kv_transfer_group():
                    # Return empty ModelRunnerOutput if no work to do.
                    return EMPTY_MODEL_RUNNER_OUTPUT
                return self.kv_connector_no_forward(
                    scheduler_output, self.vllm_config
                )
            if self.cache_config.kv_sharing_fast_prefill:
                assert not self.input_batch.num_prompt_logprobs, (
                    "--kv-sharing-fast-prefill produces incorrect "
                    "logprobs for prompt tokens; please disable it "
                    "when requests need prompt logprobs."
                )

            num_reqs = self.input_batch.num_reqs
            req_ids = self.input_batch.req_ids
            tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
            num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
            max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())

            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: add mlu_infer_mode.
            @brief: prepare mlu dp metadata in _prepare_inputs instead of
            ubatch_slices and num_tokens_across_dp.
            '''
            max_computed_tokens = np.max(
                self.input_batch.num_computed_tokens_cpu[:num_reqs])
            self.mlu_infer_mode = MLUInferMode.build(
                max_query_len=max_num_scheduled_tokens,
                max_computed_tokens=max_computed_tokens,
                uniform_decode_query_len=self.uniform_decode_query_len,
            )

            num_tokens_across_dp = None
            (
                logits_indices,
                spec_decode_metadata,
                ubatch_slices,
                dp_metadata,
            ) = self._prepare_inputs(
                scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
            )
            self.dp_metadata = dp_metadata
            '''
            ==================
            End of MLU Hijack
            ==================
            '''

            cascade_attn_prefix_lens = None
            # Disable cascade attention when using microbatching (DBO)
            if self.cascade_attn_enabled and ubatch_slices is None:
                # Pre-compute cascade attention prefix lengths
                # NOTE: Must be AFTER _prepare_inputs uses self.input_batch state
                cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
                    num_scheduled_tokens_np,
                    scheduler_output.num_common_prefix_blocks,
                )

            # TODO(lucas): move cudagraph dispatching here:
            # https://github.com/vllm-project/vllm/issues/23789

            total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
            use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
            attn_metadata, spec_decode_common_attn_metadata = (
                self._build_attention_metadata(
                    total_num_scheduled_tokens=total_num_scheduled_tokens,
                    max_num_scheduled_tokens=max_num_scheduled_tokens,
                    num_reqs=num_reqs,
                    ubatch_slices=ubatch_slices,
                    logits_indices=logits_indices,
                    use_spec_decode=use_spec_decode,
                    scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
                    cascade_attn_prefix_lens=cascade_attn_prefix_lens,
                    mlu_infer_mode=self.mlu_infer_mode,
                )
            )

            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: pad attn metadata for mlu graph.
            @brief: pad num_input_tokens based on all dp groups and spec decode.
            @brief: add dp_params to model_kwargs.
            '''
            dp_can_use_graph = False
            if self.use_cuda_graph:
                num_input_tokens_dp, num_padded_reqs, dp_can_use_graph = \
                    self._get_dp_graph_info(
                        self.num_spec_tokens, num_scheduled_tokens, dp_metadata)
                if dp_can_use_graph:
                    # all layers share the same attn_metadata
                    assert len(self.kv_cache_config.kv_cache_groups) == 1
                    attn_metadata_val = next(iter(attn_metadata.values()))
                    common_metadata = get_common_metadata_from_attn_metadata(
                        attn_metadata)
                    block_table = self.input_batch.block_table[0]
                    pad_attn_metadata(
                        attn_metadata_val, common_metadata, block_table, self,
                        num_scheduled_tokens, num_input_tokens_dp, num_reqs,
                        num_padded_reqs)
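
            # NOTE: num_tokens_across_dp stays None in this runner; the
            # cross-DP token split is carried by dp_metadata instead, so the
            # ubatch_slices branches below are effectively inert here.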
            dp_rank = self.parallel_config.data_parallel_rank
            if ubatch_slices:
                assert num_tokens_across_dp is not None
                num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
                self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
            elif num_tokens_across_dp is not None:
                num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
            else:
                num_input_tokens = (
                    num_input_tokens_dp if dp_can_use_graph else num_scheduled_tokens)

            (
                input_ids,
                inputs_embeds,
                positions,
                intermediate_tensors,
                model_kwargs,
                ec_connector_output,
            ) = self._preprocess(
                scheduler_output, num_input_tokens, intermediate_tensors
            )

            model_kwargs["dp_params"] = dp_metadata
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
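
            # uniform_decode holds only when every request was scheduled
            # exactly uniform_decode_query_len tokens, i.e. a pure decode
            # batch with no chunked-prefill remainder.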
            uniform_decode = (
                max_num_scheduled_tokens == self.uniform_decode_query_len
            ) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
            batch_descriptor = BatchDescriptor(
                num_tokens=num_input_tokens,
                uniform_decode=uniform_decode,
                has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
            )
            cudagraph_runtime_mode, batch_descriptor = (
                self.cudagraph_dispatcher.dispatch(
                    batch_descriptor,
                    use_cascade_attn=cascade_attn_prefix_lens is not None,
                )
            )

            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: gate cudagraph usage on dp_can_use_graph.
            '''
            if not dp_can_use_graph:
                cudagraph_runtime_mode = CUDAGraphMode.NONE
                batch_descriptor = None
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
            # Set cudagraph mode to none if calc_kv_scales is true.
            # KV scales calculation involves dynamic operations that are
            # incompatible with CUDA graph capture.
            if self.calculate_kv_scales:
                cudagraph_runtime_mode = CUDAGraphMode.NONE
                # Mark KV scales as calculated after the first forward pass
                self.calculate_kv_scales = False

            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: debug disagg cnpx.
            '''
            if mlu_envs.VLLM_DISAGG_CNPX_EXECUTE:
                self.execute_cnpx_mark = cnpx.rangeStart(
                    f"DP_{self.parallel_config.data_parallel_rank}"
                    f"_TP_{get_tensor_model_parallel_rank()}_execute_model"
                    + ("_no_graph"
                       if cudagraph_runtime_mode == CUDAGraphMode.NONE else ""))
            if mlu_envs.VLLM_DISAGG_CNPX_REQUEST:
                self.request_cnpx_mark.clear()
                for req in scheduler_output.scheduled_new_reqs:
                    self.request_cnpx_mark[req.req_id] = cnpx.rangeStart(req.req_id)
                for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
                    self.request_cnpx_mark[req_id] = cnpx.rangeStart(req_id)
            '''
            ==================
            End of MLU Hijack
            ==================
            '''

            if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
                start = torch.mlu.Event(enable_timing=True)
                start.record()

        # Run the model.
        # Use persistent buffers for CUDA graphs.
        with (
            set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                batch_descriptor=batch_descriptor,
                ubatch_slices=ubatch_slices,
            ),
            record_function_or_nullcontext("dp_gpu_model_runner: forward"),
            self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
        ):
            model_output = self._model_forward(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                **model_kwargs,
            )
with record_function_or_nullcontext("dp_gpu_model_runner: postprocess"):
|
|
if self.use_aux_hidden_state_outputs:
|
|
# True when EAGLE 3 is used.
|
|
hidden_states, aux_hidden_states = model_output
|
|
else:
|
|
# Common case.
|
|
hidden_states = model_output
|
|
aux_hidden_states = None
|
|
|
|
if not self.broadcast_pp_output:
|
|
# Common case.
|
|
if not get_pp_group().is_last_rank:
|
|
# Return the intermediate tensors.
|
|
assert isinstance(hidden_states, IntermediateTensors)
|
|
hidden_states.kv_connector_output = kv_connector_output
|
|
return hidden_states
|
|
|
|
if self.is_pooling_model:
|
|
# Return the pooling output.
|
|
output = self._pool(
|
|
hidden_states, num_scheduled_tokens, num_scheduled_tokens_np
|
|
)
|
|
output.kv_connector_output = kv_connector_output
|
|
return output
|
|

                '''
                =============================
                Modify by vllm_mlu
                =============================
                @brief: support embed logits custom parallel.
                '''
                sample_hidden_states = hidden_states[logits_indices]
                logits_kwargs = ({"dp_params": dp_metadata}
                                 if enable_emb_logits_custom_parallel() else {})
                logits = self.model.compute_logits(sample_hidden_states,
                                                   **logits_kwargs)
                '''
                ==================
                End of MLU Hijack
                ==================
                '''
            else:
                # Rare case.
                assert not self.is_pooling_model

                sample_hidden_states = hidden_states[logits_indices]
                if not get_pp_group().is_last_rank:
                    all_gather_tensors = {
                        "residual": not is_residual_scattered_for_sp(
                            self.vllm_config, num_input_tokens
                        )
                    }
                    get_pp_group().send_tensor_dict(
                        hidden_states.tensors,
                        all_gather_group=get_tp_group(),
                        all_gather_tensors=all_gather_tensors,
                    )
                    logits = None
                else:
                    logits = self.model.compute_logits(sample_hidden_states)

                model_output_broadcast_data = {}
                if logits is not None:
                    model_output_broadcast_data["logits"] = logits.contiguous()

                model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
                    model_output_broadcast_data, src=len(get_pp_group().ranks) - 1
                )
                assert model_output_broadcast_data is not None
                logits = model_output_broadcast_data["logits"]

            self.time_markers = []
            if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
                end = torch.mlu.Event(enable_timing=True)
                end.record()
                self.time_markers.append([start, end])

            # Stash the KV connector output for sample_tokens(), which reads
            # self.kv_connector_output; the final ExecuteModelState slot holds
            # the EC connector output.
            self.kv_connector_output = kv_connector_output
            self.execute_model_state = ExecuteModelState(
                scheduler_output,
                logits,
                spec_decode_metadata,
                spec_decode_common_attn_metadata,
                hidden_states,
                sample_hidden_states,
                aux_hidden_states,
                ec_connector_output,
            )
        return None
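
    # NOTE: execute_model() intentionally returns None after stashing its
    # results in self.execute_model_state; sample_tokens() below consumes and
    # clears that state, which is why the calling order is enforced at the
    # top of execute_model().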
    @torch.inference_mode
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncMLUModelRunnerOutput | IntermediateTensors:
        kv_connector_output = self.kv_connector_output
        self.kv_connector_output = None

        if self.execute_model_state is None:
            # Nothing to do (PP non-final rank case), output isn't used.
            if not kv_connector_output:
                return None  # noqa

            # In case of PP with kv transfer, we need to pass through the
            # kv_connector_output
            if kv_connector_output.is_empty():
                return EMPTY_MODEL_RUNNER_OUTPUT

            output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
            output.kv_connector_output = kv_connector_output
            return output

        # Unpack ephemeral state.
        (
            scheduler_output,
            logits,
            spec_decode_metadata,
            spec_decode_common_attn_metadata,
            hidden_states,
            sample_hidden_states,
            aux_hidden_states,
            ec_connector_output,
        ) = self.execute_model_state
        # Clear ephemeral state.
        self.execute_model_state = None

        # Apply structured output bitmasks if present.
        if grammar_output is not None:
            apply_grammar_bitmask(
                scheduler_output, grammar_output, self.input_batch, logits
            )

        with record_function_or_nullcontext("gpu_model_runner: sample"):
            sampler_output = self._sample(logits, spec_decode_metadata)

        self.input_batch.prev_sampled_token_ids = None

        def propose_draft_token_ids(
            sampled_token_ids: torch.Tensor | list[np.ndarray],
        ) -> None:
            assert spec_decode_common_attn_metadata is not None
            with record_function_or_nullcontext("gpu_model_runner: draft"):
                self._draft_token_ids = self.propose_draft_token_ids(
                    scheduler_output,
                    sampled_token_ids,
                    self.input_batch.sampling_metadata,
                    hidden_states,
                    sample_hidden_states,
                    aux_hidden_states,
                    spec_decode_metadata,
                    spec_decode_common_attn_metadata,
                    whole_block_table=self.input_batch.block_table[0],
                    main_model_dp_params=self.dp_metadata,
                )

        use_padded_batch_for_eagle = (
            self.speculative_config
            and self.speculative_config.use_eagle()
            and not self.speculative_config.disable_padded_drafter_batch
        )
        effective_drafter_max_model_len = self.max_model_len
        if effective_drafter_max_model_len is None:
            effective_drafter_max_model_len = self.model_config.max_model_len
        if (
            self.speculative_config
            and self.speculative_config.draft_model_config is not None
            and self.speculative_config.draft_model_config.max_model_len is not None
        ):
            effective_drafter_max_model_len = (
                self.speculative_config.draft_model_config.max_model_len
            )
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: Force `input_fits_in_drafter` to True so that
        `self.uniform_decode_query_len` tokens are scheduled per batch during
        model execution. This is required for graph validation and keeps the
        batch token count consistent with `self.uniform_decode_query_len`
        immediately after the prefill stage.
        '''
        # input_fits_in_drafter = spec_decode_common_attn_metadata and (
        #     spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
        #     <= effective_drafter_max_model_len
        # )
        input_fits_in_drafter = True
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        if use_padded_batch_for_eagle:
            sampled_token_ids = sampler_output.sampled_token_ids
            if input_fits_in_drafter:
                # EAGLE speculative decoding can use the GPU sampled tokens
                # as inputs, and does not need to wait for bookkeeping to finish.
                propose_draft_token_ids(sampled_token_ids)
            elif self.valid_sampled_token_count_event is not None:
                next_token_ids, valid_sampled_tokens_count = (
                    self.drafter.prepare_next_token_ids_padded(
                        spec_decode_common_attn_metadata,
                        sampled_token_ids,
                        self.requests,
                        self.input_batch,
                        self.discard_request_indices.gpu,
                        self.num_discarded_requests,
                    )
                )
                self._copy_valid_sampled_token_count(
                    next_token_ids, valid_sampled_tokens_count
                )
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
|
|
(
|
|
num_nans_in_logits,
|
|
logprobs_lists,
|
|
valid_sampled_token_ids,
|
|
prompt_logprobs_dict,
|
|
req_ids_output_copy,
|
|
req_id_to_index_output_copy,
|
|
invalid_req_indices,
|
|
) = self._bookkeeping_sync(
|
|
scheduler_output,
|
|
sampler_output,
|
|
logits,
|
|
hidden_states,
|
|
scheduler_output.total_num_scheduled_tokens,
|
|
spec_decode_metadata,
|
|
)
|
|
|
|
if (
|
|
self.speculative_config
|
|
and not use_padded_batch_for_eagle
|
|
and input_fits_in_drafter
|
|
):
|
|
# ngram and other speculative decoding methods use the sampled
|
|
# tokens on the CPU, so they are run after bookkeeping.
|
|
propose_draft_token_ids(valid_sampled_token_ids)
|
|
|
|
with record_function_or_nullcontext("gpu_model_runner: eplb"):
|
|
self.eplb_step()
|
|
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
|
|
output = ModelRunnerOutput(
|
|
req_ids=req_ids_output_copy,
|
|
req_id_to_index=req_id_to_index_output_copy,
|
|
sampled_token_ids=valid_sampled_token_ids,
|
|
logprobs=logprobs_lists,
|
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
|
pooler_output=[],
|
|
kv_connector_output=kv_connector_output,
|
|
ec_connector_output=ec_connector_output
|
|
if self.supports_mm_inputs
|
|
else None,
|
|
num_nans_in_logits=num_nans_in_logits,
|
|
)
|
|
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: support disagg for mlu.
        '''
        if has_kv_transfer_group():
            get_kv_transfer_group().wait_for_save()
            get_kv_transfer_group().clear_connector_metadata()

        if mlu_envs.VLLM_DISAGG_CNPX_EXECUTE:
            current_stream = torch.mlu.current_stream()
            current_stream.synchronize()
            cnpx.rangeEnd(self.execute_cnpx_mark)
        if mlu_envs.VLLM_DISAGG_CNPX_REQUEST:
            current_stream = torch.mlu.current_stream()
            current_stream.synchronize()
            for req in scheduler_output.scheduled_new_reqs:
                cnpx.rangeEnd(self.request_cnpx_mark[req.req_id])
            for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
                cnpx.rangeEnd(self.request_cnpx_mark[req_id])
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        if not self.use_async_scheduling:
            return output
        with record_function_or_nullcontext(
            "gpu_model_runner: AsyncGPUModelRunnerOutput"
        ):
            async_output = AsyncMLUModelRunnerOutput(
                model_runner_output=output,
                sampled_token_ids=sampler_output.sampled_token_ids,
                logprobs_tensors=sampler_output.logprobs_tensors,
                invalid_req_indices=invalid_req_indices,
                async_output_copy_stream=self.async_output_copy_stream,
                vocab_size=self.input_batch.vocab_size,
            )
        with record_function_or_nullcontext(
            "gpu_model_runner: set_async_sampled_token_ids"
        ):
            # Save ref of sampled_token_ids CPU tensor if the batch contains
            # any requests with sampling params that require output ids.
            self.input_batch.set_async_sampled_token_ids(
                async_output.sampled_token_ids_cpu,
                async_output.async_copy_ready_event,
            )

        return async_output
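
    # NOTE: This override follows the usual vLLM draft-proposal flow but
    # threads the MLU-specific arguments (whole_block_table,
    # main_model_dp_params) through to the drafter.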
    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        sampled_token_ids: torch.Tensor | list[list[int]],
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: Optional[torch.Tensor],
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        common_attn_metadata: MLUCommonAttentionMetadata,
        whole_block_table: torch.Tensor,
        main_model_dp_params: Optional[DataParallelRuntimeParams] = None,
    ) -> list[list[int]]:
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: the draft model builds a new FlashMLAMetadata,
        so just unpad common_attn_metadata here.
        '''
        unpad_common_attn_metadata(
            common_metadata=common_attn_metadata,
            num_reqs=self.input_batch.num_reqs,
            num_scheduled_tokens=num_scheduled_tokens,
        )
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
if self.speculative_config.method == "ngram":
|
|
assert isinstance(self.drafter, NgramProposer)
|
|
spec_token_ids = self.propose_ngram_draft_token_ids(
|
|
sampled_token_ids)
|
|
elif self.speculative_config.method == "medusa":
|
|
assert isinstance(self.drafter, MedusaProposer)
|
|
if sample_hidden_states.shape[0] == len(sampled_token_ids):
|
|
# The input to the target model does not include draft tokens.
|
|
hidden_states = sample_hidden_states
|
|
else:
|
|
indices = []
|
|
offset = 0
|
|
for num_draft, tokens in zip(
|
|
spec_decode_metadata.num_draft_tokens,
|
|
sampled_token_ids):
|
|
indices.append(offset + len(tokens) - 1)
|
|
offset += num_draft + 1
|
|
indices = torch.tensor(indices, device=self.device)
|
|
hidden_states = sample_hidden_states[indices]
|
|
|
|

            spec_token_ids = self.drafter.propose(
                target_hidden_states=hidden_states,
                sampling_metadata=sampling_metadata,
            )
        elif self.speculative_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            # TODO(woosuk): Refactor the loop.
            if self.speculative_config.disable_padded_drafter_batch:
                # When padded-batch is disabled, the sampled_token_ids should be
                # the cpu-side list[list[int]] of valid sampled tokens for each
                # request, with invalid requests having empty lists.
                assert isinstance(sampled_token_ids, list), (
                    "sampled_token_ids should be a python list when "
                    "padded-batch is disabled."
                )
                next_token_ids = self.drafter.prepare_next_token_ids_cpu(
                    sampled_token_ids,
                    self.requests,
                    self.input_batch,
                    scheduler_output.num_scheduled_tokens,
                )
            else:
                # When using padded-batch, the sampled_token_ids should be
                # the gpu tensor of sampled tokens for each request, of shape
                # (num_reqs, num_spec_tokens + 1) with rejected tokens having
                # value -1.
                assert isinstance(sampled_token_ids, torch.Tensor), (
                    "sampled_token_ids should be a torch.Tensor when "
                    "padded-batch is enabled."
                )
                next_token_ids, valid_sampled_tokens_count = (
                    self.drafter.prepare_next_token_ids_padded(
                        common_attn_metadata,
                        sampled_token_ids,
                        self.requests,
                        self.input_batch,
                        self.discard_request_indices.gpu,
                        self.num_discarded_requests,
                    )
                )
                self._copy_valid_sampled_token_count(
                    next_token_ids, valid_sampled_tokens_count
                )
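
            # Two input-prep paths: with no scheduled draft tokens the target
            # model consumed exactly the scheduled tokens, so they can be
            # reused as-is; otherwise rejected draft positions must be
            # filtered or repacked first.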
            if spec_decode_metadata is None:
                token_indices_to_sample = None
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self._get_positions(num_scheduled_tokens)
                if self.use_aux_hidden_state_outputs:
                    assert aux_hidden_states is not None
                    target_hidden_states = torch.cat(
                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
                        dim=-1)
                else:
                    target_hidden_states = hidden_states[:num_scheduled_tokens]
                num_rejected_tokens_gpu = None
                token_indices = None
            else:
                if self.speculative_config.disable_padded_drafter_batch:
                    token_indices_to_sample = None
                    # prepare_inputs does not report rejected tokens on this
                    # path, so leave num_rejected_tokens_gpu unset (None).
                    num_rejected_tokens_gpu = None
                    common_attn_metadata, token_indices = self.drafter.prepare_inputs(
                        common_attn_metadata,
                        sampled_token_ids,
                        spec_decode_metadata.num_draft_tokens,
                    )
                else:
                    common_attn_metadata, token_indices, token_indices_to_sample, num_rejected_tokens_gpu = (
                        self.drafter.prepare_inputs_padded(
                            common_attn_metadata,
                            spec_decode_metadata,
                            valid_sampled_tokens_count,
                        )
                    )
                target_token_ids = self.input_ids.gpu[token_indices]
                target_positions = self._get_positions(token_indices)
                if self.use_aux_hidden_state_outputs:
                    assert aux_hidden_states is not None
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states], dim=-1
                    )
                else:
                    target_hidden_states = hidden_states[token_indices]
                '''
                =============================
                Modify by vllm_mlu
                =============================
                @brief: add debug info for the draft acceptance rate.
                '''
                if mlu_envs.VLLM_MTP_DEBUG:
                    batch_total_draft = sum(spec_decode_metadata.num_draft_tokens)
                    batch_total_rejected = (
                        int(num_rejected_tokens_gpu.sum().item())
                        if num_rejected_tokens_gpu is not None else 0)
                    self.total_draft_tokens += batch_total_draft
                    self.total_accepted_tokens += (
                        batch_total_draft - batch_total_rejected)
                    if batch_total_draft > 0:
                        batch_accept_rate = (
                            batch_total_draft - batch_total_rejected
                        ) / batch_total_draft
                        print(f"Batch Accept Rate: {batch_accept_rate:.4f}, "
                              f"Total Accept Rate: {self.get_accept_rate():.4f}")
                '''
                ==================
                End of MLU Hijack
                ==================
                '''
            if self.supports_mm_inputs:
                mm_embed_inputs = self._gather_mm_embeddings(
                    scheduler_output,
                    shift_computed_tokens=1,
                )
            else:
                mm_embed_inputs = None
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: truncate padded inputs back to the scheduled token count
            for the draft model compute.
            '''
            target_token_ids = target_token_ids[:num_scheduled_tokens]
            target_positions = target_positions[:num_scheduled_tokens]
            target_hidden_states = target_hidden_states[:num_scheduled_tokens]
            '''
            ==================
            End of MLU Hijack
            ==================
            '''

            spec_token_ids = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                next_token_ids=next_token_ids,
                last_token_indices=token_indices_to_sample,
                sampling_metadata=sampling_metadata,
                common_attn_metadata=common_attn_metadata,
                num_rejected_tokens=num_rejected_tokens_gpu,
                token_indices=token_indices,
                whole_block_table=whole_block_table,
                main_model_dp_params=main_model_dp_params,
                time_markers=self.time_markers,
            )
        return spec_token_ids
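
    # The helper below sizes the cnclep all2all buffers. A hypothetical worked
    # example of the arithmetic (numbers are illustrative, not defaults): with
    # hidden_size = 7168, bf16 activations, int8 quantized dispatch, topk = 8,
    # K = 1 draft token, max_num_seqs = 32 and TP = 8:
    #   dispatch_token_size = 7168 * 1 + 4 = 7172 bytes (int8 payload + fp32 scale)
    #   combine_token_size  = 7168 * 2     = 14336 bytes (bf16)
    #   max_num_tokens_per_rank = 32 * (1 + 1) * 8 / 8 = 64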
    def make_cnclep_kwargs(self, use_quant_dispatch: bool = True) -> dict[Any, Any]:
        K = (self.drafter.num_speculative_tokens
             if hasattr(self, "drafter") and isinstance(self.drafter, EagleProposer)
             else 0)
        seq_len = K + 1
        config = self.model_config.hf_config
        num_experts = (config.n_routed_experts if hasattr(config, "n_routed_experts")
                       else config.num_experts)
        topk = getattr(config, "num_experts_per_tok", None) or getattr(
            config, "moe_topk", None)
        assert topk is not None, "failed to get topk from config"
        hidden_size = config.hidden_size
        dispatch_token_size = hidden_size * get_dtype_size(self.dtype)
        if use_quant_dispatch:
            # int8 payload plus one fp32 scale per token.
            dispatch_token_size = (hidden_size * get_dtype_size(torch.int8)
                                   + get_dtype_size(torch.float32))
        combine_token_size = hidden_size * get_dtype_size(self.dtype)

        max_num_seqs_per_dp = self.scheduler_config.max_num_seqs
        # max number of tokens that an ep rank could send
        max_num_tokens_per_rank = divide(max_num_seqs_per_dp * seq_len * topk,
                                         self.parallel_config.tensor_parallel_size)

        return dict(dispatch_token_size=dispatch_token_size,
                    combine_token_size=combine_token_size,
                    max_num_tokens_per_rank=max_num_tokens_per_rank,
                    num_global_experts=num_experts,
                    use_quant_dispatch=use_quant_dispatch)

    def prepare_all2all_buffer_for_model(
            self, model: torch.nn.Module) -> None:
        """
        Prepare the all2all buffers for the given model.
        """
        if not self.use_all2all:
            return

        # Collect the MoE modules from the given model.
        moe_modules = [
            module for module in model.modules()
            if isinstance(module, SparseMoeMlp)
        ]
        if hasattr(self, "drafter") and isinstance(self.drafter, EagleProposer):
            draft_moes = [
                module for module in self.drafter.model.modules()
                if isinstance(module, SparseMoeMlp) and not mlu_envs.VLLM_MTP_NO_QUANT
            ]
            moe_modules.extend(draft_moes)
        for module in moe_modules:
            if self.load_config.load_format == "dummy":
                module.pack_params()
            module.pack_params_after_loading()
            use_quant_dispatch = module.quant_config is not None
            module.prepare_for_cnclep(
                get_cnclep(use_quant_dispatch=use_quant_dispatch))

    def load_model(self, eep_scale_up: bool = False) -> None:
        super().load_model()
        if self.use_all2all:
            self.prepare_all2all_buffer_for_model(self.model)