### What this PR does / why we need it?
#5051 only implement a basic framework for model runner v2, but there
are still some bugs for e2e functionality, this PR aim to enable basic
functionality.
model runner v2 plans:
https://github.com/vllm-project/vllm-ascend/issues/5208
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
372 lines
15 KiB
Python
372 lines
15 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
import torch_npu
|
|
from vllm.distributed import (get_dp_group, get_ep_group,
|
|
get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
tensor_model_parallel_all_gather,
|
|
tensor_model_parallel_all_reduce,
|
|
tensor_model_parallel_reduce_scatter)
|
|
from vllm.forward_context import get_forward_context
|
|
from vllm.utils.torch_utils import direct_register_custom_op
|
|
|
|
import vllm_ascend.envs as envs_ascend
|
|
from vllm_ascend.ascend_forward_context import MoECommType
|
|
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
|
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
|
|
|
|
|
|
def _maybe_chunk_residual_impl(x: torch.Tensor,
|
|
residual: torch.Tensor) -> torch.Tensor:
|
|
try:
|
|
forward_context = get_forward_context()
|
|
except AssertionError:
|
|
return residual
|
|
|
|
if x.size(0) != residual.size(0):
|
|
sp_enabled = forward_context.sp_enabled
|
|
assert sp_enabled is True, ("Currently, this situation only occurs "
|
|
"when sp is enabled")
|
|
pad_size = forward_context.pad_size
|
|
if pad_size > 0:
|
|
residual = F.pad(residual, (0, 0, 0, pad_size))
|
|
tp_size = get_tensor_model_parallel_world_size()
|
|
tp_rank = get_tensor_model_parallel_rank()
|
|
residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
|
|
|
|
return residual
|
|
|
|
|
|
def _maybe_all_gather_and_maybe_unpad_impl(
|
|
x: torch.Tensor,
|
|
label: bool,
|
|
is_ep_comm: bool = False) -> torch.Tensor:
|
|
try:
|
|
forward_context = get_forward_context()
|
|
except AssertionError:
|
|
return x
|
|
|
|
sp_enabled = forward_context.sp_enabled
|
|
if sp_enabled and label:
|
|
dp_metadata = forward_context.dp_metadata
|
|
if dp_metadata is None or not is_ep_comm:
|
|
x = tensor_model_parallel_all_gather(x, 0)
|
|
pad_size = forward_context.pad_size
|
|
if pad_size > 0:
|
|
x = x[:-pad_size]
|
|
else:
|
|
x = get_ep_group().all_gather(x, 0)
|
|
# unpad
|
|
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
|
|
result = torch.empty(
|
|
(num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
|
|
device=x.device,
|
|
dtype=x.dtype)
|
|
dp_size = get_dp_group().world_size
|
|
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
|
|
offset = 0
|
|
for idx in range(dp_size):
|
|
num_tokens_dp = num_tokens_across_dp_cpu[idx]
|
|
result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp]
|
|
offset += num_tokens_dp
|
|
x = result
|
|
|
|
return x
|
|
|
|
|
|
def _maybe_pad_and_reduce_impl(x: torch.Tensor,
|
|
is_ep_comm: bool = False) -> torch.Tensor:
|
|
try:
|
|
forward_context = get_forward_context()
|
|
except AssertionError:
|
|
return tensor_model_parallel_all_reduce(x)
|
|
|
|
if not getattr(forward_context, "sp_enabled", False):
|
|
return tensor_model_parallel_all_reduce(x)
|
|
|
|
dp_metadata = forward_context.dp_metadata
|
|
if dp_metadata is None or not is_ep_comm:
|
|
pad_size = forward_context.pad_size
|
|
if pad_size > 0:
|
|
x = F.pad(x, (0, 0, 0, pad_size))
|
|
return tensor_model_parallel_reduce_scatter(x, 0)
|
|
else:
|
|
# padding
|
|
dp_size = get_dp_group().world_size
|
|
num_tokens_across_dp_cpu = \
|
|
get_forward_context().dp_metadata.num_tokens_across_dp_cpu
|
|
padded_x = torch.empty(
|
|
(dp_size, forward_context.padded_length, *x.shape[1:]),
|
|
device=x.device,
|
|
dtype=x.dtype)
|
|
offset = 0
|
|
for idx in range(dp_size):
|
|
num_tokens_dp = num_tokens_across_dp_cpu[idx]
|
|
padded_x[idx, :num_tokens_dp] = x[offset:offset + num_tokens_dp]
|
|
offset += num_tokens_dp
|
|
|
|
return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]),
|
|
0)
|
|
|
|
|
|
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
|
|
prefix: str) -> None:
|
|
try:
|
|
forward_context = get_forward_context()
|
|
except AssertionError:
|
|
return
|
|
|
|
if not getattr(forward_context, 'prefetch_mlp_enabled', False):
|
|
return
|
|
model_instance = forward_context.model_instance
|
|
weight_prefetch_stream = prefetch_stream()
|
|
layer_idx = int(prefix.split('.')[2])
|
|
|
|
# start point of gate_up_proj weight prefetch
|
|
if prefix.split('.')[-2] == "self_attn":
|
|
forward_context.prefetch_mlp_gate_up_proj = True
|
|
if forward_context.prefetch_mlp_gate_up_proj:
|
|
weight_prefetch_stream.wait_stream(torch.npu.current_stream())
|
|
|
|
with torch.npu.stream(weight_prefetch_stream):
|
|
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
|
|
torch_npu.npu_prefetch(
|
|
model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
|
|
x_dependency, mlp_gate_up_prefetch_size)
|
|
return
|
|
|
|
|
|
def _maybe_all_gather_and_maybe_unpad_fake(
|
|
x: torch.Tensor,
|
|
label: bool,
|
|
is_ep_comm: bool = False) -> torch.Tensor:
|
|
|
|
if get_forward_context().sp_enabled and label:
|
|
return torch.empty(
|
|
(x.shape[0] * get_tensor_model_parallel_world_size(),
|
|
*x.shape[1:]),
|
|
device=x.device,
|
|
dtype=x.dtype)
|
|
|
|
return x
|
|
|
|
|
|
def _maybe_pad_and_reduce_fake(x: torch.Tensor,
|
|
is_ep_comm: bool = False) -> torch.Tensor:
|
|
if get_forward_context().sp_enabled:
|
|
return torch.empty(
|
|
(x.shape[0] // get_tensor_model_parallel_world_size(),
|
|
*x.shape[1:]),
|
|
device=x.device,
|
|
dtype=x.dtype)
|
|
|
|
return x
|
|
|
|
|
|
def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
|
|
prefix: str) -> None:
|
|
return
|
|
|
|
|
|
def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
|
|
try:
|
|
forward_context = get_forward_context()
|
|
except AssertionError:
|
|
return
|
|
|
|
if not getattr(forward_context, 'prefetch_mlp_enabled', False):
|
|
return
|
|
forward_context.prefetch_mlp_down_proj = True
|
|
model_instance = forward_context.model_instance
|
|
weight_prefetch_stream = prefetch_stream()
|
|
layer_idx = forward_context.layer_idx
|
|
|
|
# start point of down_proj weight prefetch
|
|
weight_prefetch_stream.wait_stream(torch.npu.current_stream())
|
|
|
|
with torch.npu.stream(weight_prefetch_stream):
|
|
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
|
|
torch_npu.npu_prefetch(
|
|
model_instance.model.layers[layer_idx].mlp.down_proj.weight,
|
|
x_dependency, mlp_down_prefetch_size)
|
|
forward_context.layer_idx += 1
|
|
return
|
|
|
|
|
|
def _maybe_prefetch_mlp_down_proj_impl_fake(
|
|
x_dependency: torch.Tensor) -> None:
|
|
return
|
|
|
|
|
|
def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
|
|
try:
|
|
forward_context = get_forward_context()
|
|
except AssertionError:
|
|
return
|
|
|
|
if not getattr(forward_context, 'prefetch_mlp_enabled', False):
|
|
return
|
|
if forward_context.prefetch_mlp_gate_up_proj or \
|
|
forward_context.prefetch_mlp_down_proj:
|
|
weight_prefetch_stream = prefetch_stream()
|
|
# wait until prefetch done
|
|
torch.npu.current_stream().wait_stream(weight_prefetch_stream)
|
|
forward_context.prefetch_mlp_gate_up_proj = False
|
|
forward_context.prefetch_mlp_down_proj = False
|
|
return
|
|
|
|
|
|
def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
|
|
return
|
|
|
|
|
|
def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor,
|
|
max_weight_size: int) -> None:
|
|
calculation_stream = torch_npu.npu.current_stream()
|
|
weight_prefetch_stream = prefetch_stream()
|
|
weight_prefetch_stream.wait_stream(calculation_stream)
|
|
with npu_stream_switch(weight_prefetch_stream):
|
|
maybe_npu_prefetch(inputs=weight,
|
|
dependency=start_flag,
|
|
max_size=max_weight_size)
|
|
|
|
|
|
def _prefetch_preprocess_impl_fake(weight: torch.Tensor,
|
|
start_flag: torch.Tensor,
|
|
max_weight_size: int) -> None:
|
|
return
|
|
|
|
|
|
def _prefetch_postprocess_impl(stop_flag: torch.Tensor) -> None:
|
|
calculation_stream = torch_npu.npu.current_stream()
|
|
weight_prefetch_stream = prefetch_stream()
|
|
calculation_stream.wait_stream(weight_prefetch_stream)
|
|
|
|
|
|
def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
|
|
return
|
|
|
|
|
|
def _maybe_all_reduce_tensor_model_parallel_impl(
|
|
final_hidden_states: torch.Tensor) -> torch.Tensor:
|
|
forward_context = get_forward_context()
|
|
moe_comm_type = forward_context.moe_comm_type
|
|
if moe_comm_type in {
|
|
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2
|
|
} or forward_context.sp_enabled:
|
|
return final_hidden_states
|
|
else:
|
|
return tensor_model_parallel_all_reduce(final_hidden_states)
|
|
|
|
|
|
def _matmul_and_reduce_impl(input_parallel: torch.Tensor,
|
|
layer_name: str) -> torch.Tensor:
|
|
forward_context = get_forward_context()
|
|
self = forward_context.no_compile_layers[layer_name]
|
|
assert self.custom_op is not None
|
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
|
output = self.custom_op.matmul_and_reduce(input_parallel, bias_)
|
|
|
|
return output
|
|
|
|
|
|
def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
|
|
layer_name: str) -> torch.Tensor:
|
|
forward_context = get_forward_context()
|
|
self = forward_context.no_compile_layers[layer_name]
|
|
num_tokens = input_parallel.size(0)
|
|
if forward_context.sp_enabled:
|
|
num_tokens = num_tokens // self.tp_size
|
|
output = torch.empty(size=(num_tokens, self.output_size_per_partition),
|
|
device=input_parallel.device,
|
|
dtype=input_parallel.dtype)
|
|
|
|
return output
|
|
|
|
|
|
# TODO(Angazenn): The reason why we use a custom op to encapsulate npu_quantize
|
|
# is that aclnnAscendQuantV3(npu_quantize) use div_mode=False, while
|
|
# aclnnAddRmsNormQuantV2(npu_add_rms_norm_quant) use div_moe=True. We have to
|
|
# pass input_scale and input_scale_reciprocal at the same time to avoid redundant
|
|
# reciprocal calculation in fussion pass. We shall remove this once
|
|
# aclnnAddRmsNormQuantV2 supports div_moe=False.
|
|
def _quantize_impl(in_tensor: torch.Tensor, input_scale: torch.Tensor,
|
|
input_scale_reciprocal: torch.Tensor,
|
|
input_offset: torch.Tensor) -> torch.Tensor:
|
|
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal,
|
|
input_offset, torch.qint8, -1, False)
|
|
|
|
|
|
def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor,
|
|
input_scale_reciprocal: torch.Tensor,
|
|
input_offset: torch.Tensor) -> torch.Tensor:
|
|
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal,
|
|
input_offset, torch.qint8, -1, False)
|
|
|
|
|
|
direct_register_custom_op(op_name="maybe_chunk_residual",
|
|
op_func=_maybe_chunk_residual_impl,
|
|
fake_impl=lambda x, residual: x,
|
|
mutates_args=[],
|
|
dispatch_key="PrivateUse1")
|
|
|
|
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
|
|
op_func=_maybe_all_gather_and_maybe_unpad_impl,
|
|
fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
|
|
mutates_args=[],
|
|
dispatch_key="PrivateUse1")
|
|
|
|
direct_register_custom_op(op_name="maybe_pad_and_reduce",
|
|
op_func=_maybe_pad_and_reduce_impl,
|
|
fake_impl=_maybe_pad_and_reduce_fake,
|
|
mutates_args=[],
|
|
dispatch_key="PrivateUse1")
|
|
|
|
direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
|
|
op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
|
|
fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
|
|
mutates_args=[],
|
|
dispatch_key="PrivateUse1")
|
|
|
|
direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
|
|
op_func=_maybe_prefetch_mlp_down_proj_impl,
|
|
fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
|
|
mutates_args=[],
|
|
dispatch_key="PrivateUse1")
|
|
|
|
direct_register_custom_op(op_name="maybe_wait_prefetch_done",
|
|
op_func=_maybe_wait_prefetch_done_impl,
|
|
fake_impl=_maybe_wait_prefetch_done_impl_fake,
|
|
mutates_args=[],
|
|
dispatch_key="PrivateUse1")
|
|
|
|
direct_register_custom_op(op_name="prefetch_preprocess",
|
|
op_func=_prefetch_preprocess_impl,
|
|
fake_impl=_prefetch_preprocess_impl_fake,
|
|
mutates_args=[],
|
|
dispatch_key="PrivateUse1")
|
|
|
|
direct_register_custom_op(op_name="prefetch_postprocess",
|
|
op_func=_prefetch_postprocess_impl,
|
|
fake_impl=_prefetch_postprocess_impl_fake,
|
|
mutates_args=[],
|
|
dispatch_key="PrivateUse1")
|
|
|
|
direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
|
|
op_func=_maybe_all_reduce_tensor_model_parallel_impl,
|
|
fake_impl=lambda x: x,
|
|
mutates_args=[],
|
|
dispatch_key="PrivateUse1")
|
|
|
|
direct_register_custom_op(op_name="matmul_and_reduce",
|
|
op_func=_matmul_and_reduce_impl,
|
|
fake_impl=_matmul_and_reduce_impl_fake,
|
|
mutates_args=[],
|
|
dispatch_key="PrivateUse1")
|
|
|
|
direct_register_custom_op(op_name="quantize",
|
|
op_func=_quantize_impl,
|
|
fake_impl=_quantize_impl_fake,
|
|
mutates_args=[],
|
|
dispatch_key="PrivateUse1")
|