[main] flashcomm_v1 optim in Qwen Dense Models (#2802)
### What this PR does / why we need it?
Flashcomm_v1 optim in Qwen Dense Models.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: v0.10.1.1
- vLLM main:
5e537f45b4
Co-authored-by: 1024daniel <xxltju324@gmail.com>
This commit is contained in:
@@ -23,6 +23,7 @@ Run `pytest tests/test_offline_inference.py`.
|
|||||||
import os
|
import os
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
from modelscope import snapshot_download # type: ignore
|
from modelscope import snapshot_download # type: ignore
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
|
||||||
@@ -30,6 +31,8 @@ from tests.e2e.conftest import VllmRunner
|
|||||||
|
|
||||||
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
|
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
|
||||||
|
|
||||||
|
QWEN_DENSE_MODELS = ["Qwen/QwQ-32B", "Qwen/Qwen-32B"]
|
||||||
|
|
||||||
|
|
||||||
def test_models_distributed_QwQ():
|
def test_models_distributed_QwQ():
|
||||||
example_prompts = [
|
example_prompts = [
|
||||||
@@ -150,3 +153,23 @@ def test_sp_for_qwen3_moe() -> None:
|
|||||||
enable_expert_parallel=True,
|
enable_expert_parallel=True,
|
||||||
enforce_eager=True) as vllm_model:
|
enforce_eager=True) as vllm_model:
|
||||||
vllm_model.generate(example_prompts, sampling_params)
|
vllm_model.generate(example_prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||||
|
@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
|
||||||
|
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
|
||||||
|
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"})
|
||||||
|
def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
|
||||||
|
example_prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
]
|
||||||
|
max_tokens = 5
|
||||||
|
|
||||||
|
with VllmRunner(
|
||||||
|
snapshot_download(model),
|
||||||
|
max_model_len=8192,
|
||||||
|
enforce_eager=enforce_eager,
|
||||||
|
dtype="auto",
|
||||||
|
tensor_parallel_size=4,
|
||||||
|
) as vllm_model:
|
||||||
|
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
|||||||
@@ -10,6 +10,13 @@ def dummy_tensor():
|
|||||||
return torch.randn(4, 8, dtype=torch.float16)
|
return torch.randn(4, 8, dtype=torch.float16)
|
||||||
|
|
||||||
|
|
||||||
|
def mock_maybe_chunk_residual(x, residual):
|
||||||
|
if x.size(0) != residual.size(0):
|
||||||
|
return residual[:4]
|
||||||
|
|
||||||
|
return residual
|
||||||
|
|
||||||
|
|
||||||
def mock_rms_norm(x, weight, eps):
|
def mock_rms_norm(x, weight, eps):
|
||||||
return x + 1, None
|
return x + 1, None
|
||||||
|
|
||||||
@@ -23,11 +30,13 @@ def mock_add_rms_norm(x, residual, weight, eps):
|
|||||||
[None, torch.randn(4, 8, dtype=torch.float32)])
|
[None, torch.randn(4, 8, dtype=torch.float32)])
|
||||||
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
||||||
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
|
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
|
||||||
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
|
@patch("torch.ops.vllm.maybe_chunk_residual",
|
||||||
residual, dummy_tensor):
|
side_effect=mock_maybe_chunk_residual)
|
||||||
|
def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
|
||||||
|
mock_rmsnorm, is_310p_return, residual, dummy_tensor):
|
||||||
|
|
||||||
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
|
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
|
||||||
layer = RMSNorm(hidden_size=32, eps=1e-05)
|
layer = RMSNorm(hidden_size=8, eps=1e-05)
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
|
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
|
||||||
|
|
||||||
@@ -51,3 +60,25 @@ def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
|
|||||||
|
|
||||||
mock_rmsnorm.assert_called_once()
|
mock_rmsnorm.assert_called_once()
|
||||||
assert torch.allclose(out_x, expected_out_x)
|
assert torch.allclose(out_x, expected_out_x)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||||
|
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
|
||||||
|
@patch("torch.ops.vllm.maybe_chunk_residual",
|
||||||
|
side_effect=mock_maybe_chunk_residual)
|
||||||
|
def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
|
||||||
|
mock_add_rms_norm, mock_is310p):
|
||||||
|
x = torch.randn(4, 512, dtype=torch.bfloat16)
|
||||||
|
residual = torch.randn(16, 512, dtype=torch.bfloat16)
|
||||||
|
layer = RMSNorm(hidden_size=512, eps=1e-05)
|
||||||
|
|
||||||
|
out_x, out_residual = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
|
expected_out_x = 2 * x
|
||||||
|
expected_out_residual = 2 * residual[:4]
|
||||||
|
|
||||||
|
mock_maybe_chunk_residual.assert_called_once()
|
||||||
|
mock_add_rms_norm.assert_called_once()
|
||||||
|
assert out_residual.size(0) == 4
|
||||||
|
assert torch.allclose(out_x, expected_out_x)
|
||||||
|
assert torch.allclose(out_residual, expected_out_residual)
|
||||||
|
|||||||
@@ -303,13 +303,13 @@ class TestUtils(TestBase):
|
|||||||
# ascend custom op is not registered
|
# ascend custom op is not registered
|
||||||
utils.register_ascend_customop()
|
utils.register_ascend_customop()
|
||||||
# should call register_oot three
|
# should call register_oot three
|
||||||
self.assertEqual(mock_customop.register_oot.call_count, 12)
|
self.assertEqual(mock_customop.register_oot.call_count, 13)
|
||||||
self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
|
self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
|
||||||
|
|
||||||
# ascend custom op is already registered
|
# ascend custom op is already registered
|
||||||
utils.register_ascend_customop()
|
utils.register_ascend_customop()
|
||||||
# should not register_oot again, thus only called three in this ut
|
# should not register_oot again, thus only called three in this ut
|
||||||
self.assertEqual(mock_customop.register_oot.call_count, 12)
|
self.assertEqual(mock_customop.register_oot.call_count, 13)
|
||||||
|
|
||||||
|
|
||||||
class TestProfileExecuteDuration(TestBase):
|
class TestProfileExecuteDuration(TestBase):
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ def set_ascend_forward_context(
|
|||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
|
forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
|
||||||
forward_context.with_prefill = with_prefill
|
forward_context.with_prefill = with_prefill
|
||||||
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
ep_size = (get_ep_group().world_size if
|
ep_size = (get_ep_group().world_size if
|
||||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||||
|
|
||||||
@@ -103,6 +104,21 @@ def set_ascend_forward_context(
|
|||||||
# due to multiple warmups before actual capturing
|
# due to multiple warmups before actual capturing
|
||||||
forward_context.capturing = False
|
forward_context.capturing = False
|
||||||
|
|
||||||
|
# set for flashcomm_v1, 1000 is the batchsize concurrency threshold for enabling the flashcomm_v1 feature.
|
||||||
|
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
|
||||||
|
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
|
||||||
|
# the performance may degrade due to the switching of communication methods.
|
||||||
|
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
|
||||||
|
tp_world_size > 1 and \
|
||||||
|
num_tokens is not None and num_tokens > 1000
|
||||||
|
|
||||||
|
if flashcomm_v1_enabled:
|
||||||
|
pad_size = (tp_world_size -
|
||||||
|
(num_tokens % tp_world_size)) % tp_world_size
|
||||||
|
forward_context.pad_size = pad_size
|
||||||
|
|
||||||
|
forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled
|
||||||
|
|
||||||
if num_tokens is None and attn_metadata is not None:
|
if num_tokens is None and attn_metadata is not None:
|
||||||
num_tokens = attn_metadata.num_actual_tokens
|
num_tokens = attn_metadata.num_actual_tokens
|
||||||
|
|
||||||
@@ -118,7 +134,6 @@ def set_ascend_forward_context(
|
|||||||
if num_tokens is not None:
|
if num_tokens is not None:
|
||||||
if num_actual_tokens is None:
|
if num_actual_tokens is None:
|
||||||
num_actual_tokens = num_tokens
|
num_actual_tokens = num_tokens
|
||||||
tp_world_size = get_tensor_model_parallel_world_size()
|
|
||||||
# NOTE: token num which need to pad to when mc2
|
# NOTE: token num which need to pad to when mc2
|
||||||
forward_context.padded_num_tokens = math.ceil(
|
forward_context.padded_num_tokens = math.ceil(
|
||||||
max_tokens_across_dp / tp_world_size) * tp_world_size
|
max_tokens_across_dp / tp_world_size) * tp_world_size
|
||||||
|
|||||||
@@ -131,6 +131,15 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
# this feature is supported in A2, and eager mode will get better performance.
|
# this feature is supported in A2, and eager mode will get better performance.
|
||||||
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
|
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
|
||||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
|
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
|
||||||
|
# Whether to enable FlashComm optimization when tensor parallel is enabled.
|
||||||
|
# This feature will get better performance when concurrency is large.
|
||||||
|
"VLLM_ASCEND_ENABLE_FLASHCOMM":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
|
||||||
|
# Whether to enable dense model and general optimizations for better performance.
|
||||||
|
# Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
|
||||||
|
# However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
|
||||||
|
"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE", '0'))),
|
||||||
# Whether to enable mlp optimize when tensor parallel is enabled.
|
# Whether to enable mlp optimize when tensor parallel is enabled.
|
||||||
# this feature in eager mode will get better performance.
|
# this feature in eager mode will get better performance.
|
||||||
"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":
|
"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import torch
|
|||||||
import vllm_ascend.ops.common_fused_moe # noqa
|
import vllm_ascend.ops.common_fused_moe # noqa
|
||||||
import vllm_ascend.ops.fused_moe # noqa
|
import vllm_ascend.ops.fused_moe # noqa
|
||||||
import vllm_ascend.ops.layernorm # noqa
|
import vllm_ascend.ops.layernorm # noqa
|
||||||
|
import vllm_ascend.ops.register_custom_ops # noqa
|
||||||
import vllm_ascend.ops.vocab_parallel_embedding # noqa
|
import vllm_ascend.ops.vocab_parallel_embedding # noqa
|
||||||
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
||||||
from vllm_ascend.ops.rotary_embedding import (
|
from vllm_ascend.ops.rotary_embedding import (
|
||||||
|
|||||||
@@ -44,6 +44,13 @@ class AddRMSNormW8A8Quant(RMSNorm):
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
|
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
|
||||||
|
# is enabled, without interfering with the normal operation of components like torchair.
|
||||||
|
# The final solution should be to move this check into the operator and support
|
||||||
|
# integration with torchair.
|
||||||
|
if x.size(0) != residual.size(0):
|
||||||
|
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||||
|
assert x.size(0) == residual.size(0)
|
||||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||||
x,
|
x,
|
||||||
residual,
|
residual,
|
||||||
@@ -69,6 +76,13 @@ class AscendRMSNorm(RMSNorm):
|
|||||||
|
|
||||||
from vllm_ascend.utils import is_310p
|
from vllm_ascend.utils import is_310p
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
|
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
|
||||||
|
# is enabled, without interfering with the normal operation of components like torchair.
|
||||||
|
# The final solution should be to move this check into the operator and support
|
||||||
|
# integration with torchair.
|
||||||
|
if x.size(0) != residual.size(0):
|
||||||
|
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||||
|
assert x.size(0) == residual.size(0)
|
||||||
if is_310p():
|
if is_310p():
|
||||||
orig_dtype = residual.dtype
|
orig_dtype = residual.dtype
|
||||||
x = x + residual.to(x.dtype)
|
x = x + residual.to(x.dtype)
|
||||||
|
|||||||
@@ -26,20 +26,18 @@ from torch.nn.parameter import Parameter
|
|||||||
from vllm.distributed import divide, split_tensor_along_last_dim
|
from vllm.distributed import divide, split_tensor_along_last_dim
|
||||||
from vllm.distributed.parallel_state import get_tp_group
|
from vllm.distributed.parallel_state import get_tp_group
|
||||||
from vllm.lora.utils import LinearBase
|
from vllm.lora.utils import LinearBase
|
||||||
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
|
from vllm.model_executor.layers.linear import ( # noqa
|
||||||
ColumnParallelLinear,
|
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear,
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
|
||||||
QuantizeMethodBase,
|
RowParallelLinear, UnquantizedLinearMethod)
|
||||||
RowParallelLinear,
|
|
||||||
UnquantizedLinearMethod)
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import \
|
from vllm.model_executor.layers.quantization.base_config import \
|
||||||
QuantizationConfig
|
QuantizationConfig
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
|
||||||
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
||||||
get_otp_group)
|
get_otp_group)
|
||||||
from vllm_ascend.utils import (matmul_allreduce_enable, mlp_tp_enable,
|
from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
|
||||||
oproj_tp_enable)
|
mlp_tp_enable, oproj_tp_enable)
|
||||||
|
|
||||||
_HCOMM_INFO = None
|
_HCOMM_INFO = None
|
||||||
|
|
||||||
@@ -150,6 +148,9 @@ class AscendRowParallelLinear(RowParallelLinear):
|
|||||||
comm_group = get_tp_group()
|
comm_group = get_tp_group()
|
||||||
self.forward_type = "matmul_allreduce"
|
self.forward_type = "matmul_allreduce"
|
||||||
self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
|
self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
|
||||||
|
elif dense_optim_enable():
|
||||||
|
comm_group = get_tp_group()
|
||||||
|
self.forward_type = "dense_optim"
|
||||||
else:
|
else:
|
||||||
comm_group = get_tp_group()
|
comm_group = get_tp_group()
|
||||||
self.forward_type = "normal"
|
self.forward_type = "normal"
|
||||||
@@ -231,6 +232,8 @@ class AscendRowParallelLinear(RowParallelLinear):
|
|||||||
return self._forward_mlp_tp(input_)
|
return self._forward_mlp_tp(input_)
|
||||||
elif self.forward_type == "matmul_allreduce":
|
elif self.forward_type == "matmul_allreduce":
|
||||||
return self._forward_matmul_allreduce(input_)
|
return self._forward_matmul_allreduce(input_)
|
||||||
|
elif self.forward_type == "dense_optim":
|
||||||
|
return self._forward_dense_optim(input_)
|
||||||
else:
|
else:
|
||||||
return super().forward(input_)
|
return super().forward(input_)
|
||||||
|
|
||||||
@@ -332,6 +335,39 @@ class AscendRowParallelLinear(RowParallelLinear):
|
|||||||
return output
|
return output
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
|
def _forward_dense_optim(
|
||||||
|
self, input_: torch.Tensor
|
||||||
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
|
"""Linear layer with column parallelism.
|
||||||
|
|
||||||
|
Implemented multiple optimization projects for dense models, such as FlashComm and
|
||||||
|
communication-computation fusion.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.input_is_parallel:
|
||||||
|
input_parallel = input_
|
||||||
|
else:
|
||||||
|
splitted_input = split_tensor_along_last_dim(
|
||||||
|
input_, num_partitions=self.tp_size)
|
||||||
|
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||||
|
|
||||||
|
assert self.quant_method is not None
|
||||||
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||||
|
|
||||||
|
if self.tp_size == 1 or not self.reduce_results:
|
||||||
|
output = self.quant_method.apply(self, input_parallel, bias=bias_)
|
||||||
|
else:
|
||||||
|
output_parallel = self.quant_method.apply(self,
|
||||||
|
input_parallel,
|
||||||
|
bias=bias_)
|
||||||
|
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
|
||||||
|
|
||||||
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
|
|
||||||
|
if not self.return_bias:
|
||||||
|
return output
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||||
"""Packed linear layers with column parallelism.
|
"""Packed linear layers with column parallelism.
|
||||||
@@ -357,15 +393,18 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
|||||||
*,
|
*,
|
||||||
return_bias: bool = True,
|
return_bias: bool = True,
|
||||||
):
|
):
|
||||||
self.comm_group = None
|
|
||||||
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
|
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
|
||||||
self.comm_group = get_mlp_tp_group()
|
comm_group = get_mlp_tp_group()
|
||||||
self.forward_type = "mlp_tp"
|
self.forward_type = "mlp_tp"
|
||||||
|
elif dense_optim_enable():
|
||||||
|
comm_group = get_tp_group()
|
||||||
|
self.forward_type = "dense_optim"
|
||||||
else:
|
else:
|
||||||
self.comm_group = get_tp_group()
|
comm_group = get_tp_group()
|
||||||
self.forward_type = "normal_tp"
|
self.forward_type = "normal_tp"
|
||||||
self.tp_rank = self.comm_group.rank_in_group
|
self.comm_group = comm_group
|
||||||
self.tp_size = self.comm_group.world_size
|
self.tp_rank = comm_group.rank_in_group
|
||||||
|
self.tp_size = comm_group.world_size
|
||||||
|
|
||||||
self.output_sizes = output_sizes
|
self.output_sizes = output_sizes
|
||||||
assert all(output_size % self.tp_size == 0
|
assert all(output_size % self.tp_size == 0
|
||||||
@@ -387,6 +426,8 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
|||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
if self.forward_type == "mlp_tp":
|
if self.forward_type == "mlp_tp":
|
||||||
return self._forward_mlp_tp(input_)
|
return self._forward_mlp_tp(input_)
|
||||||
|
elif self.forward_type == "dense_optim":
|
||||||
|
return self._forward_dense_optim(input_)
|
||||||
else:
|
else:
|
||||||
return super().forward(input_)
|
return super().forward(input_)
|
||||||
|
|
||||||
@@ -405,6 +446,138 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
|||||||
return output
|
return output
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
|
def _forward_dense_optim(
|
||||||
|
self, input_: torch.Tensor
|
||||||
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
|
"""Linear layer with column parallelism.
|
||||||
|
|
||||||
|
Implemented multiple optimization projects for dense models, such as FlashComm and
|
||||||
|
communication-computation fusion.
|
||||||
|
"""
|
||||||
|
|
||||||
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
|
|
||||||
|
# Matrix multiply.
|
||||||
|
assert self.quant_method is not None
|
||||||
|
|
||||||
|
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
|
||||||
|
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||||
|
|
||||||
|
if self.gather_output:
|
||||||
|
# All-gather across the partitions.
|
||||||
|
output = self.comm_group.all_gather(output_parallel)
|
||||||
|
else:
|
||||||
|
output = output_parallel
|
||||||
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
|
if not self.return_bias:
|
||||||
|
return output
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
|
class AscendQKVParallelLinear(QKVParallelLinear):
|
||||||
|
"""Linear layers for the attention's QKV transformation.
|
||||||
|
|
||||||
|
Linear layers for the linear transformation of the query, key, and value
|
||||||
|
vectors in the attention layer. The weight matrix is concatenated along
|
||||||
|
the output dimension. The layer is parallelized along the head dimension.
|
||||||
|
When the number of key/value heads is smaller than the number of query
|
||||||
|
heads (e.g., multi-query/grouped-query attention), the key/value head may
|
||||||
|
be replicated while the query heads are partitioned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
head_size: int,
|
||||||
|
total_num_heads: int,
|
||||||
|
total_num_kv_heads: Optional[int] = None,
|
||||||
|
bias: bool = True,
|
||||||
|
skip_bias_add: bool = False,
|
||||||
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
*,
|
||||||
|
return_bias: bool = True,
|
||||||
|
):
|
||||||
|
if dense_optim_enable():
|
||||||
|
self.forward_type = "dense_optim"
|
||||||
|
else:
|
||||||
|
self.forward_type = "normal_tp"
|
||||||
|
self.comm_group = get_tp_group()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.head_size = head_size
|
||||||
|
self.total_num_heads = total_num_heads
|
||||||
|
if total_num_kv_heads is None:
|
||||||
|
total_num_kv_heads = total_num_heads
|
||||||
|
self.total_num_kv_heads = total_num_kv_heads
|
||||||
|
# Divide the weight matrix along the last dimension.
|
||||||
|
tp_size = self.comm_group.world_size
|
||||||
|
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||||
|
if tp_size >= self.total_num_kv_heads:
|
||||||
|
self.num_kv_heads = 1
|
||||||
|
self.num_kv_head_replicas = divide(tp_size,
|
||||||
|
self.total_num_kv_heads)
|
||||||
|
else:
|
||||||
|
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
||||||
|
self.num_kv_head_replicas = 1
|
||||||
|
input_size = self.hidden_size
|
||||||
|
output_size = (self.num_heads +
|
||||||
|
2 * self.num_kv_heads) * tp_size * self.head_size
|
||||||
|
self.output_sizes = [
|
||||||
|
self.num_heads * self.head_size * tp_size, # q_proj
|
||||||
|
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
||||||
|
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
||||||
|
]
|
||||||
|
AscendColumnParallelLinear.__init__(self,
|
||||||
|
input_size=input_size,
|
||||||
|
output_size=output_size,
|
||||||
|
bias=bias,
|
||||||
|
gather_output=False,
|
||||||
|
skip_bias_add=skip_bias_add,
|
||||||
|
params_dtype=params_dtype,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix,
|
||||||
|
return_bias=return_bias)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_,
|
||||||
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
|
if self.forward_type == "dense_optim":
|
||||||
|
return self._forward_dense_optim(input_)
|
||||||
|
else:
|
||||||
|
return super().forward(input_)
|
||||||
|
|
||||||
|
def _forward_dense_optim(
|
||||||
|
self, input_: torch.Tensor
|
||||||
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
|
"""Linear layer with column parallelism.
|
||||||
|
|
||||||
|
Implemented multiple optimization projects for dense models, such as FlashComm and
|
||||||
|
communication-computation fusion.
|
||||||
|
"""
|
||||||
|
|
||||||
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
|
|
||||||
|
# Matrix multiply.
|
||||||
|
assert self.quant_method is not None
|
||||||
|
|
||||||
|
layer_num = self.prefix.split('.')[2]
|
||||||
|
|
||||||
|
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
|
input_, layer_num != '0')
|
||||||
|
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||||
|
|
||||||
|
if self.gather_output:
|
||||||
|
# All-gather across the partitions.
|
||||||
|
output = self.comm_group.all_gather(output_parallel)
|
||||||
|
else:
|
||||||
|
output = output_parallel
|
||||||
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
|
if not self.return_bias:
|
||||||
|
return output
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
class AscendLinearBase(LinearBase):
|
class AscendLinearBase(LinearBase):
|
||||||
|
|
||||||
|
|||||||
63
vllm_ascend/ops/register_custom_ops.py
Normal file
63
vllm_ascend/ops/register_custom_ops.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_gather,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
tensor_model_parallel_reduce_scatter)
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_chunk_residual_impl(x: torch.Tensor,
|
||||||
|
residual: torch.Tensor) -> torch.Tensor:
|
||||||
|
if get_forward_context().flashcomm_v1_enabled:
|
||||||
|
pad_size = get_forward_context().pad_size
|
||||||
|
if pad_size > 0:
|
||||||
|
residual = F.pad(residual, (0, 0, 0, pad_size))
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
|
||||||
|
|
||||||
|
return residual
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
|
||||||
|
label: bool) -> torch.Tensor:
|
||||||
|
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
|
||||||
|
if flashcomm_v1_enabled and label:
|
||||||
|
x = tensor_model_parallel_all_gather(x, 0)
|
||||||
|
pad_size = get_forward_context().pad_size
|
||||||
|
if pad_size > 0:
|
||||||
|
x = x[:-pad_size, :]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
|
||||||
|
if flashcomm_v1_enabled:
|
||||||
|
pad_size = get_forward_context().pad_size
|
||||||
|
if pad_size > 0:
|
||||||
|
x = F.pad(x, (0, 0, 0, pad_size))
|
||||||
|
return tensor_model_parallel_reduce_scatter(x, 0)
|
||||||
|
else:
|
||||||
|
return tensor_model_parallel_all_reduce(x)
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(op_name="maybe_chunk_residual",
|
||||||
|
op_func=_maybe_chunk_residual_impl,
|
||||||
|
fake_impl=lambda x, residual: residual,
|
||||||
|
mutates_args=[],
|
||||||
|
dispatch_key="PrivateUse1")
|
||||||
|
|
||||||
|
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
|
||||||
|
op_func=_maybe_all_gather_and_maybe_unpad_impl,
|
||||||
|
fake_impl=lambda x, label: x,
|
||||||
|
mutates_args=[],
|
||||||
|
dispatch_key="PrivateUse1")
|
||||||
|
|
||||||
|
direct_register_custom_op(op_name="maybe_pad_and_reduce",
|
||||||
|
op_func=_maybe_pad_and_reduce_impl,
|
||||||
|
fake_impl=lambda x: x,
|
||||||
|
mutates_args=[],
|
||||||
|
dispatch_key="PrivateUse1")
|
||||||
@@ -493,6 +493,7 @@ def register_ascend_customop():
|
|||||||
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
||||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||||
AscendMergedColumnParallelLinear,
|
AscendMergedColumnParallelLinear,
|
||||||
|
AscendQKVParallelLinear,
|
||||||
AscendRowParallelLinear)
|
AscendRowParallelLinear)
|
||||||
from vllm_ascend.ops.rotary_embedding import (
|
from vllm_ascend.ops.rotary_embedding import (
|
||||||
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
|
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
|
||||||
@@ -510,6 +511,8 @@ def register_ascend_customop():
|
|||||||
name="RowParallelLinear")
|
name="RowParallelLinear")
|
||||||
CustomOp.register_oot(_decorated_op_cls=AscendMergedColumnParallelLinear,
|
CustomOp.register_oot(_decorated_op_cls=AscendMergedColumnParallelLinear,
|
||||||
name="MergedColumnParallelLinear")
|
name="MergedColumnParallelLinear")
|
||||||
|
CustomOp.register_oot(_decorated_op_cls=AscendQKVParallelLinear,
|
||||||
|
name="QKVParallelLinear")
|
||||||
CustomOp.register_oot(
|
CustomOp.register_oot(
|
||||||
_decorated_op_cls=AscendDeepseekScalingRotaryEmbedding,
|
_decorated_op_cls=AscendDeepseekScalingRotaryEmbedding,
|
||||||
name="DeepseekScalingRotaryEmbedding")
|
name="DeepseekScalingRotaryEmbedding")
|
||||||
@@ -572,3 +575,7 @@ def mlp_tp_enable() -> bool:
|
|||||||
|
|
||||||
def matmul_allreduce_enable() -> bool:
|
def matmul_allreduce_enable() -> bool:
|
||||||
return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
|
return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
|
||||||
|
|
||||||
|
|
||||||
|
def dense_optim_enable() -> bool:
|
||||||
|
return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from vllm.attention.layer import Attention
|
|||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||||
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
|
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
|
||||||
|
from vllm.distributed import tensor_model_parallel_all_gather
|
||||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||||
has_kv_transfer_group)
|
has_kv_transfer_group)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||||
@@ -1182,6 +1183,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
|
if get_forward_context().flashcomm_v1_enabled:
|
||||||
|
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
||||||
|
pad_size = get_forward_context().pad_size
|
||||||
|
if pad_size > 0:
|
||||||
|
hidden_states = hidden_states[:-pad_size, :]
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
|
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
|
||||||
|
|||||||
Reference in New Issue
Block a user