[CI] Fix broken CI (#6599)
Revert4fb3d5e1b2it breaks E2E Test - vLLM version: v0.15.0 - vLLM main:d7e17aaacd
This commit is contained in:
@@ -1,25 +1,24 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter,
|
||||
)
|
||||
from vllm.distributed import (get_dp_group, get_ep_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
|
||||
from typing import Optional, Tuple
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||
|
||||
|
||||
def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
||||
def _maybe_chunk_residual_impl(x: torch.Tensor,
|
||||
residual: torch.Tensor) -> torch.Tensor:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
@@ -27,7 +26,8 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch
|
||||
|
||||
if x.size(0) != residual.size(0):
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
assert sp_enabled is True, "Currently, this situation only occurs when sp is enabled"
|
||||
assert sp_enabled is True, ("Currently, this situation only occurs "
|
||||
"when sp is enabled")
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
residual = F.pad(residual, (0, 0, 0, pad_size))
|
||||
@@ -38,7 +38,10 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch
|
||||
return residual
|
||||
|
||||
|
||||
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor:
|
||||
def _maybe_all_gather_and_maybe_unpad_impl(
|
||||
x: torch.Tensor,
|
||||
label: bool,
|
||||
is_ep_comm: bool = False) -> torch.Tensor:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
@@ -56,20 +59,24 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
|
||||
x = get_ep_group().all_gather(x, 0)
|
||||
# unpad
|
||||
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
|
||||
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
|
||||
result = torch.empty(
|
||||
(num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
dp_size = get_dp_group().world_size
|
||||
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
|
||||
offset = 0
|
||||
for idx in range(dp_size):
|
||||
num_tokens_dp = num_tokens_across_dp_cpu[idx]
|
||||
result[offset : offset + num_tokens_dp] = x[idx, :num_tokens_dp]
|
||||
result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp]
|
||||
offset += num_tokens_dp
|
||||
x = result
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
|
||||
def _maybe_pad_and_reduce_impl(x: torch.Tensor,
|
||||
is_ep_comm: bool = False) -> torch.Tensor:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
@@ -87,44 +94,63 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
|
||||
else:
|
||||
# padding
|
||||
dp_size = get_dp_group().world_size
|
||||
num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu
|
||||
padded_x = torch.empty((dp_size, forward_context.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype)
|
||||
num_tokens_across_dp_cpu = \
|
||||
get_forward_context().dp_metadata.num_tokens_across_dp_cpu
|
||||
padded_x = torch.empty(
|
||||
(dp_size, forward_context.padded_length, *x.shape[1:]),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
offset = 0
|
||||
for idx in range(dp_size):
|
||||
num_tokens_dp = num_tokens_across_dp_cpu[idx]
|
||||
padded_x[idx, :num_tokens_dp] = x[offset : offset + num_tokens_dp]
|
||||
padded_x[idx, :num_tokens_dp] = x[offset:offset + num_tokens_dp]
|
||||
offset += num_tokens_dp
|
||||
|
||||
return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]), 0)
|
||||
return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]),
|
||||
0)
|
||||
|
||||
|
||||
def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor:
|
||||
def _maybe_all_gather_and_maybe_unpad_fake(
|
||||
x: torch.Tensor,
|
||||
label: bool,
|
||||
is_ep_comm: bool = False) -> torch.Tensor:
|
||||
|
||||
if get_forward_context().sp_enabled and label:
|
||||
return torch.empty(
|
||||
(x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
|
||||
)
|
||||
(x.shape[0] * get_tensor_model_parallel_world_size(),
|
||||
*x.shape[1:]),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
|
||||
def _maybe_pad_and_reduce_fake(x: torch.Tensor,
|
||||
is_ep_comm: bool = False) -> torch.Tensor:
|
||||
if get_forward_context().sp_enabled:
|
||||
return torch.empty(
|
||||
(x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
|
||||
)
|
||||
(x.shape[0] // get_tensor_model_parallel_world_size(),
|
||||
*x.shape[1:]),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor, max_weight_size: int) -> None:
|
||||
def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor,
|
||||
max_weight_size: int) -> None:
|
||||
calculation_stream = torch_npu.npu.current_stream()
|
||||
weight_prefetch_stream = prefetch_stream()
|
||||
weight_prefetch_stream.wait_stream(calculation_stream)
|
||||
with npu_stream_switch(weight_prefetch_stream):
|
||||
maybe_npu_prefetch(inputs=weight, dependency=start_flag, max_size=max_weight_size)
|
||||
maybe_npu_prefetch(inputs=weight,
|
||||
dependency=start_flag,
|
||||
max_size=max_weight_size)
|
||||
|
||||
|
||||
def _prefetch_preprocess_impl_fake(weight: torch.Tensor, start_flag: torch.Tensor, max_weight_size: int) -> None:
|
||||
def _prefetch_preprocess_impl_fake(weight: torch.Tensor,
|
||||
start_flag: torch.Tensor,
|
||||
max_weight_size: int) -> None:
|
||||
return
|
||||
|
||||
|
||||
@@ -138,16 +164,20 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def _maybe_all_reduce_tensor_model_parallel_impl(
|
||||
final_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} or forward_context.sp_enabled:
|
||||
if moe_comm_type in {
|
||||
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2
|
||||
} or forward_context.sp_enabled:
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
|
||||
def _matmul_and_reduce_impl(input_parallel: torch.Tensor, layer_name: str) -> torch.Tensor:
|
||||
def _matmul_and_reduce_impl(input_parallel: torch.Tensor,
|
||||
layer_name: str) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
assert self.custom_op is not None
|
||||
@@ -157,15 +187,16 @@ def _matmul_and_reduce_impl(input_parallel: torch.Tensor, layer_name: str) -> to
|
||||
return output
|
||||
|
||||
|
||||
def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str) -> torch.Tensor:
|
||||
def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
|
||||
layer_name: str) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
num_tokens = input_parallel.size(0)
|
||||
if forward_context.sp_enabled:
|
||||
num_tokens = num_tokens // self.tp_size
|
||||
output = torch.empty(
|
||||
size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype
|
||||
)
|
||||
output = torch.empty(size=(num_tokens, self.output_size_per_partition),
|
||||
device=input_parallel.device,
|
||||
dtype=input_parallel.dtype)
|
||||
|
||||
return output
|
||||
|
||||
@@ -176,96 +207,77 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str)
|
||||
# pass input_scale and input_scale_reciprocal at the same time to avoid redundant
|
||||
# reciprocal calculation in fussion pass. We shall remove this once
|
||||
# aclnnAddRmsNormQuantV2 supports div_moe=False.
|
||||
def _quantize_impl(
|
||||
in_tensor: torch.Tensor, input_scale: torch.Tensor, input_scale_reciprocal: torch.Tensor, input_offset: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, input_offset, torch.qint8, -1, False)
|
||||
|
||||
|
||||
def _quantize_impl_fake(
|
||||
in_tensor: torch.Tensor, input_scale: torch.Tensor, input_scale_reciprocal: torch.Tensor, input_offset: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal, input_offset, torch.qint8, -1, False)
|
||||
def _quantize_impl(in_tensor: torch.Tensor, input_scale: torch.Tensor,
|
||||
input_scale_reciprocal: torch.Tensor,
|
||||
input_offset: torch.Tensor) -> torch.Tensor:
|
||||
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal,
|
||||
input_offset, torch.qint8, -1, False)
|
||||
|
||||
|
||||
def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor,
|
||||
input_scale_reciprocal: torch.Tensor,
|
||||
input_offset: torch.Tensor) -> torch.Tensor:
|
||||
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal,
|
||||
input_offset, torch.qint8, -1, False)
|
||||
def _rope_forward_triton_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
rope_dim: int = -1,
|
||||
is_neox_style: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
is_neox_style: bool = True
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.empty_like(q), torch.empty_like(k)
|
||||
|
||||
direct_register_custom_op(op_name="maybe_chunk_residual",
|
||||
op_func=_maybe_chunk_residual_impl,
|
||||
fake_impl=lambda x, residual: x,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="maybe_chunk_residual",
|
||||
op_func=_maybe_chunk_residual_impl,
|
||||
fake_impl=lambda x, residual: x,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
|
||||
op_func=_maybe_all_gather_and_maybe_unpad_impl,
|
||||
fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="maybe_all_gather_and_maybe_unpad",
|
||||
op_func=_maybe_all_gather_and_maybe_unpad_impl,
|
||||
fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
direct_register_custom_op(op_name="maybe_pad_and_reduce",
|
||||
op_func=_maybe_pad_and_reduce_impl,
|
||||
fake_impl=_maybe_pad_and_reduce_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="prefetch_preprocess",
|
||||
op_func=_prefetch_preprocess_impl,
|
||||
fake_impl=_prefetch_preprocess_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
direct_register_custom_op(op_name="prefetch_preprocess",
|
||||
op_func=_prefetch_preprocess_impl,
|
||||
fake_impl=_prefetch_preprocess_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="prefetch_preprocess",
|
||||
op_func=_prefetch_preprocess_impl,
|
||||
fake_impl=_prefetch_preprocess_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
direct_register_custom_op(op_name="prefetch_postprocess",
|
||||
op_func=_prefetch_postprocess_impl,
|
||||
fake_impl=_prefetch_postprocess_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="prefetch_postprocess",
|
||||
op_func=_prefetch_postprocess_impl,
|
||||
fake_impl=_prefetch_postprocess_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
|
||||
op_func=_maybe_all_reduce_tensor_model_parallel_impl,
|
||||
fake_impl=lambda x: x,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="maybe_all_reduce_tensor_model_parallel",
|
||||
op_func=_maybe_all_reduce_tensor_model_parallel_impl,
|
||||
fake_impl=lambda x: x,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
direct_register_custom_op(op_name="matmul_and_reduce",
|
||||
op_func=_matmul_and_reduce_impl,
|
||||
fake_impl=_matmul_and_reduce_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="matmul_and_reduce",
|
||||
op_func=_matmul_and_reduce_impl,
|
||||
fake_impl=_matmul_and_reduce_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="quantize",
|
||||
op_func=_quantize_impl,
|
||||
fake_impl=_quantize_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
direct_register_custom_op(
|
||||
op_name="rope_forward_triton",
|
||||
op_func=rope_forward_triton,
|
||||
fake_impl=_rope_forward_triton_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1",
|
||||
)
|
||||
direct_register_custom_op(op_name="quantize",
|
||||
op_func=_quantize_impl,
|
||||
fake_impl=_quantize_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
direct_register_custom_op(op_name="rope_forward_triton",
|
||||
op_func=rope_forward_triton,
|
||||
fake_impl=_rope_forward_triton_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
Reference in New Issue
Block a user