[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)
### What this PR does / why we need it?
A refactoring of forward_context and model_runner_v1, add some context
which is necessary in model inference into forward_context, and refactor
dummy_run logic, make it more reasonable.
Some details for this PR:
Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
refactor `dummy_run` logic;
Add soc_version for A2 and A3;
### Does this PR introduce _any_ user-facing change?
No change at user-facing.
### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main:
57c22e57f9
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -21,6 +21,7 @@ def main():
|
|||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
max_model_len=4096,
|
max_model_len=4096,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
|
enable_expert_parallel=True,
|
||||||
additional_config={
|
additional_config={
|
||||||
"torchair_graph_config": {
|
"torchair_graph_config": {
|
||||||
"enabled": False
|
"enabled": False
|
||||||
@@ -28,7 +29,6 @@ def main():
|
|||||||
"ascend_scheduler_config": {
|
"ascend_scheduler_config": {
|
||||||
"enabled": True
|
"enabled": True
|
||||||
},
|
},
|
||||||
"expert_tensor_parallel_size": 1
|
|
||||||
})
|
})
|
||||||
|
|
||||||
# Generate texts from the prompts. The output is a list of RequestOutput
|
# Generate texts from the prompts. The output is a list of RequestOutput
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
|
rm -rf ./.torchair_cache/
|
||||||
|
rm -rf ./dynamo_*
|
||||||
|
rm -rf /root/ascend/log/debug/plog/*
|
||||||
|
|
||||||
export HCCL_IF_IP=2.0.0.0
|
export HCCL_IF_IP=2.0.0.0
|
||||||
export GLOO_SOCKET_IFNAME="enp189s0f0"
|
export GLOO_SOCKET_IFNAME="enp189s0f0"
|
||||||
export TP_SOCKET_IFNAME="enp189s0f0"
|
export TP_SOCKET_IFNAME="enp189s0f0"
|
||||||
@@ -6,25 +10,24 @@ export HCCL_SOCKET_IFNAME="enp189s0f0"
|
|||||||
export OMP_PROC_BIND=false
|
export OMP_PROC_BIND=false
|
||||||
export OMP_NUM_THREADS=100
|
export OMP_NUM_THREADS=100
|
||||||
|
|
||||||
export VLLM_USE_V1=0
|
export VLLM_USE_V1=1
|
||||||
|
export ASCEND_LAUNCH_BLOCKING=0
|
||||||
export ASCEND_RT_VISIBLE_DEVICES=0,1
|
|
||||||
export VLLM_DP_SIZE=2
|
|
||||||
export VLLM_DP_RANK=0
|
|
||||||
export VLLM_DP_MASTER_IP="2.0.0.0"
|
|
||||||
export VLLM_DP_MASTER_PORT=40001
|
|
||||||
export VLLM_DP_PROXY_IP="2.0.0.0"
|
|
||||||
export VLLM_DP_PROXY_PORT=30002
|
|
||||||
export VLLM_DP_MONITOR_PORT=30003
|
|
||||||
export VLLM_HTTP_PORT=20001
|
|
||||||
|
|
||||||
vllm serve /data/weights/Qwen2.5-0.5B-Instruct \
|
vllm serve /data/weights/Qwen2.5-0.5B-Instruct \
|
||||||
--host 0.0.0.0 \
|
--host 0.0.0.0 \
|
||||||
--port 20001 \
|
--port 20002 \
|
||||||
--tensor-parallel-size 1 \
|
|
||||||
--seed 1024 \
|
|
||||||
--served-model-name Qwen \
|
--served-model-name Qwen \
|
||||||
--max-model-len 2000 \
|
--data-parallel-size 4 \
|
||||||
--max-num-batched-tokens 2000 \
|
--data-parallel-size-local 4 \
|
||||||
|
--data-parallel-address 2.0.0.0 \
|
||||||
|
--data-parallel-rpc-port 13389 \
|
||||||
|
--tensor-parallel-size 4 \
|
||||||
|
--enable-expert-parallel \
|
||||||
|
--no-enable-prefix-caching \
|
||||||
|
--max-num-seqs 16 \
|
||||||
|
--max-model-len 4096 \
|
||||||
|
--max-num-batched-tokens 4096 \
|
||||||
|
--gpu-memory-utilization 0.9 \
|
||||||
--trust-remote-code \
|
--trust-remote-code \
|
||||||
--gpu-memory-utilization 0.9 \
|
--enforce-eager \
|
||||||
|
--additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":false, "enable_multistream_moe":false, "use_cached_graph":false}}'
|
||||||
|
|||||||
@@ -114,7 +114,16 @@ def mock_distributed():
|
|||||||
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
|
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
|
||||||
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||||
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
||||||
_PP=pp_group):
|
_PP=pp_group), \
|
||||||
|
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_forward_context():
|
||||||
|
forward_context = Mock(in_profile_run=False, with_prefill=False)
|
||||||
|
with patch("vllm_ascend.models.deepseek_v2.get_forward_context",
|
||||||
|
return_value=forward_context):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@@ -205,7 +214,8 @@ def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
|
|||||||
quant_config=None)
|
quant_config=None)
|
||||||
|
|
||||||
|
|
||||||
def test_custom_deepseek_v2_moe(mock_distributed, base_config):
|
def test_custom_deepseek_v2_moe(mock_distributed, base_config,
|
||||||
|
mock_forward_context):
|
||||||
base_config.n_shared_experts = 1
|
base_config.n_shared_experts = 1
|
||||||
moe = CustomDeepseekV2MoE(config=base_config,
|
moe = CustomDeepseekV2MoE(config=base_config,
|
||||||
quant_config=None,
|
quant_config=None,
|
||||||
|
|||||||
@@ -18,16 +18,18 @@ from unittest.mock import MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch_npu
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_forward_context import get_fused_moe_state
|
||||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||||
AscendUnquantizedFusedMoEMethod)
|
AscendUnquantizedFusedMoEMethod)
|
||||||
from vllm_ascend.utils import adapt_patch # noqa E402
|
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
|
||||||
|
|
||||||
adapt_patch(True)
|
adapt_patch(True)
|
||||||
|
|
||||||
|
|
||||||
def mock_ep_group(mocker):
|
def mock_ep_and_mc2_group(mocker):
|
||||||
mock_group = mocker.MagicMock()
|
mock_group = mocker.MagicMock()
|
||||||
mock_group.rank_in_group = 0
|
mock_group.rank_in_group = 0
|
||||||
mock_group.rank = 0
|
mock_group.rank = 0
|
||||||
@@ -52,7 +54,8 @@ def mock_dist_env(mocker: MockerFixture):
|
|||||||
|
|
||||||
with patch('torch.distributed.get_rank', return_value=0), \
|
with patch('torch.distributed.get_rank', return_value=0), \
|
||||||
patch('torch.distributed.get_world_size', return_value=4), \
|
patch('torch.distributed.get_world_size', return_value=4), \
|
||||||
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_group(mocker)), \
|
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||||
|
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||||
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||||
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||||
@@ -73,7 +76,7 @@ def mock_dist_env(mocker: MockerFixture):
|
|||||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||||
patch('vllm_ascend.ops.fused_moe.get_forward_context',
|
patch('vllm_ascend.ops.fused_moe.get_forward_context',
|
||||||
return_value=MagicMock(
|
return_value=MagicMock(
|
||||||
attn_metadata=MagicMock(max_num_tokens_across_dp=10),
|
max_tokens_across_dp=10,
|
||||||
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
|
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
|
||||||
)), \
|
)), \
|
||||||
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||||
@@ -122,7 +125,14 @@ def mock_moe_env(mocker: MockerFixture):
|
|||||||
patch("torch_npu.npu_moe_finalize_routing", return_value=(
|
patch("torch_npu.npu_moe_finalize_routing", return_value=(
|
||||||
torch.randn(16, 2)
|
torch.randn(16, 2)
|
||||||
)):
|
)):
|
||||||
yield
|
if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
|
||||||
|
with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
|
||||||
|
torch.randn(16, 2))), \
|
||||||
|
patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
|
||||||
|
torch.randn(16, 2))):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -237,11 +247,16 @@ class TestAscendFusedMoe:
|
|||||||
moe.moe_parallel_config.ep_size = 1
|
moe.moe_parallel_config.ep_size = 1
|
||||||
|
|
||||||
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
|
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
|
||||||
output = moe.forward(inputs,
|
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
|
||||||
router_logits,
|
dtype=torch.bool),
|
||||||
is_prefill=is_prefill,
|
padded_num_tokens=num_tokens)
|
||||||
top_k=top_k,
|
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||||
shared_experts=shared_experts)
|
return_value=forward_context):
|
||||||
|
output = moe.forward(inputs,
|
||||||
|
router_logits,
|
||||||
|
is_prefill=is_prefill,
|
||||||
|
top_k=top_k,
|
||||||
|
shared_experts=shared_experts)
|
||||||
|
|
||||||
moe.quant_method.apply.assert_called_once()
|
moe.quant_method.apply.assert_called_once()
|
||||||
|
|
||||||
@@ -288,15 +303,20 @@ class TestAscendUnquantizedFusedMoEMethod:
|
|||||||
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
|
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
|
||||||
mock_moe_env, others_param):
|
mock_moe_env, others_param):
|
||||||
"""
|
"""
|
||||||
1 test is_deepseek_v3_r1=true and use fused_expters_with_all2all
|
1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
|
||||||
2 test use_select_experts and fused_experts
|
2 test use_select_experts and fused_experts
|
||||||
3 test use select_gating_topk_softmax_experts and fused_experts
|
3 test use select_gating_topk_softmax_experts and fused_experts
|
||||||
4 test use select_experts and fused_experts_with_all2all_buffer
|
4 test use select_experts and fused_experts_with_all2all_buffer
|
||||||
"""
|
"""
|
||||||
global_num_experts, ep_size, select_softmax = others_param
|
global_num_experts, ep_size, select_softmax = others_param
|
||||||
|
is_prefill = False
|
||||||
|
is_deepseek_v3_r1 = global_num_experts == 256
|
||||||
|
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
|
||||||
|
ep_size, is_prefill, is_deepseek_v3_r1))
|
||||||
with patch(
|
with patch(
|
||||||
"vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
|
"vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
|
||||||
select_softmax):
|
select_softmax), \
|
||||||
|
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context):
|
||||||
moe_method.ep_size = ep_size
|
moe_method.ep_size = ep_size
|
||||||
x = torch.randn(8, 2, 2)
|
x = torch.randn(8, 2, 2)
|
||||||
router_logits = torch.randn(8, 8)
|
router_logits = torch.randn(8, 8)
|
||||||
@@ -309,7 +329,7 @@ class TestAscendUnquantizedFusedMoEMethod:
|
|||||||
top_k=2,
|
top_k=2,
|
||||||
renormalize=True,
|
renormalize=True,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
is_prefill=False)
|
is_prefill=is_prefill)
|
||||||
|
|
||||||
if ep_size == 1:
|
if ep_size == 1:
|
||||||
assert result.shape == (16, 2)
|
assert result.shape == (16, 2)
|
||||||
@@ -327,8 +347,13 @@ class TestAscendUnquantizedFusedMoEMethod:
|
|||||||
4 test use_select_experts and fused_experts
|
4 test use_select_experts and fused_experts
|
||||||
"""
|
"""
|
||||||
ep_size, alltoall_buffer = others_param
|
ep_size, alltoall_buffer = others_param
|
||||||
|
is_prefill = False
|
||||||
|
forward_context = MagicMock(
|
||||||
|
fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
|
||||||
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
|
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
|
||||||
alltoall_buffer):
|
alltoall_buffer), \
|
||||||
|
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
||||||
|
patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
|
||||||
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
|
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
|
||||||
moe_method.ep_size = ep_size
|
moe_method.ep_size = ep_size
|
||||||
x = torch.randn(8, 2, 2)
|
x = torch.randn(8, 2, 2)
|
||||||
@@ -347,7 +372,7 @@ class TestAscendUnquantizedFusedMoEMethod:
|
|||||||
renormalize=True,
|
renormalize=True,
|
||||||
global_num_experts=128,
|
global_num_experts=128,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
is_prefill=False)
|
is_prefill=is_prefill)
|
||||||
|
|
||||||
if ep_size == 16 or ep_size == 1:
|
if ep_size == 16 or ep_size == 1:
|
||||||
assert result.shape == (16, 2)
|
assert result.shape == (16, 2)
|
||||||
|
|||||||
117
vllm_ascend/ascend_forward_context.py
Normal file
117
vllm_ascend/ascend_forward_context.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
import math
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
||||||
|
from vllm.forward_context import get_forward_context, set_forward_context
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMoEState(Enum):
|
||||||
|
AllGather = 0
|
||||||
|
All2All = 1
|
||||||
|
MC2 = 2
|
||||||
|
AllGatherEP = 3
|
||||||
|
NaiveMulticast = 4
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(zzzzwwjj): add soc_version to choose branch
|
||||||
|
def get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||||
|
is_deepseek_v3_r1: bool):
|
||||||
|
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||||
|
# only supports deepseek v3/r1
|
||||||
|
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||||
|
and is_deepseek_v3_r1):
|
||||||
|
return FusedMoEState.AllGatherEP
|
||||||
|
elif ep_size == 1:
|
||||||
|
if with_prefill:
|
||||||
|
return FusedMoEState.NaiveMulticast
|
||||||
|
else:
|
||||||
|
return FusedMoEState.AllGather
|
||||||
|
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
|
||||||
|
elif ep_size < 16 or with_prefill:
|
||||||
|
return FusedMoEState.All2All
|
||||||
|
else:
|
||||||
|
return FusedMoEState.MC2
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def set_ascend_forward_context(
|
||||||
|
attn_metadata: Any,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
virtual_engine: int = 0,
|
||||||
|
num_tokens: Optional[int] = None,
|
||||||
|
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||||
|
with_prefill: bool = True,
|
||||||
|
in_profile_run: bool = False,
|
||||||
|
num_actual_tokens: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""A context manager that stores the current forward context,
|
||||||
|
can be attention metadata, etc.
|
||||||
|
We add some additional param into forward_context.
|
||||||
|
"""
|
||||||
|
with set_forward_context(attn_metadata,
|
||||||
|
vllm_config,
|
||||||
|
virtual_engine=virtual_engine,
|
||||||
|
num_tokens=num_tokens,
|
||||||
|
num_tokens_across_dp=num_tokens_across_dp):
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
forward_context.with_prefill = with_prefill
|
||||||
|
ep_size = (get_ep_group().world_size if
|
||||||
|
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||||
|
|
||||||
|
is_deepseek_v3_r1 = hasattr(
|
||||||
|
vllm_config.model_config.hf_config, 'n_routed_experts'
|
||||||
|
) and vllm_config.model_config.hf_config.n_routed_experts == 256
|
||||||
|
fused_moe_state = get_fused_moe_state(ep_size, with_prefill,
|
||||||
|
is_deepseek_v3_r1)
|
||||||
|
|
||||||
|
forward_context.fused_moe_state = fused_moe_state
|
||||||
|
|
||||||
|
forward_context.in_profile_run = in_profile_run
|
||||||
|
|
||||||
|
# NOTE: This cannot be set using set_forward_context
|
||||||
|
# due to multiple warmups before actual capturing
|
||||||
|
forward_context.capturing = False
|
||||||
|
|
||||||
|
if num_tokens is None and attn_metadata is not None:
|
||||||
|
if hasattr(attn_metadata, 'num_actual_tokens'):
|
||||||
|
# for v1 engine
|
||||||
|
num_tokens = attn_metadata.num_actual_tokens
|
||||||
|
else:
|
||||||
|
# for v0 engine
|
||||||
|
num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
|
||||||
|
|
||||||
|
if num_actual_tokens is None:
|
||||||
|
num_actual_tokens = num_tokens
|
||||||
|
|
||||||
|
dp_world_size = get_dp_group().world_size
|
||||||
|
if dp_world_size > 1 and forward_context.dp_metadata is not None:
|
||||||
|
max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
max_tokens_across_dp = num_tokens
|
||||||
|
|
||||||
|
forward_context.max_tokens_across_dp = max_tokens_across_dp
|
||||||
|
|
||||||
|
if num_tokens is not None:
|
||||||
|
tp_world_size = get_tp_group().world_size
|
||||||
|
# NOTE: token num which need to pad to when mc2
|
||||||
|
forward_context.padded_num_tokens = math.ceil(
|
||||||
|
max_tokens_across_dp / tp_world_size) * tp_world_size
|
||||||
|
|
||||||
|
mc2_mask = torch.zeros(forward_context.padded_num_tokens,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=current_platform.device_type)
|
||||||
|
mc2_mask[:num_actual_tokens] = True
|
||||||
|
forward_context.mc2_mask = mc2_mask
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
@@ -119,6 +119,7 @@ class AscendAttentionState(Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AscendMetadata:
|
class AscendMetadata:
|
||||||
|
|
||||||
# **************************** Basic Properties ****************************
|
# **************************** Basic Properties ****************************
|
||||||
attn_mask: Optional[torch.Tensor] = None
|
attn_mask: Optional[torch.Tensor] = None
|
||||||
# Current state of this attention run.
|
# Current state of this attention run.
|
||||||
@@ -149,11 +150,6 @@ class AscendMetadata:
|
|||||||
# (num_tokens,)
|
# (num_tokens,)
|
||||||
slot_mapping: torch.Tensor = None
|
slot_mapping: torch.Tensor = None
|
||||||
|
|
||||||
# ************************* DP Related Properties **************************
|
|
||||||
with_prefill_across_dp: bool = False
|
|
||||||
# Maximum number of tokens across dp group
|
|
||||||
max_num_tokens_across_dp: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class AscendAttentionMetadataBuilder:
|
class AscendAttentionMetadataBuilder:
|
||||||
|
|
||||||
@@ -164,12 +160,7 @@ class AscendAttentionMetadataBuilder:
|
|||||||
scheduler_output: "SchedulerOutput") -> bool:
|
scheduler_output: "SchedulerOutput") -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def build(self,
|
def build(self, num_reqs, num_actual_tokens, max_query_len):
|
||||||
num_reqs,
|
|
||||||
num_actual_tokens,
|
|
||||||
max_query_len,
|
|
||||||
max_num_tokens_across_dp: int = 0,
|
|
||||||
with_prefill_across_dp: bool = False):
|
|
||||||
|
|
||||||
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
||||||
)
|
)
|
||||||
@@ -196,18 +187,15 @@ class AscendAttentionMetadataBuilder:
|
|||||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
||||||
ACL_FORMAT_FRACTAL_NZ)
|
ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
|
||||||
attn_metadata = AscendMetadata(
|
attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
|
||||||
num_actual_tokens=num_actual_tokens,
|
block_tables=block_table,
|
||||||
block_tables=block_table,
|
query_start_loc=query_start_loc,
|
||||||
query_start_loc=query_start_loc,
|
query_lens=query_lens,
|
||||||
query_lens=query_lens,
|
seq_lens=seq_lens,
|
||||||
seq_lens=seq_lens,
|
max_query_len=max_query_len,
|
||||||
max_query_len=max_query_len,
|
slot_mapping=slot_mapping,
|
||||||
slot_mapping=slot_mapping,
|
attn_mask=attn_mask,
|
||||||
attn_mask=attn_mask,
|
attn_state=attn_state)
|
||||||
attn_state=attn_state,
|
|
||||||
max_num_tokens_across_dp=max_num_tokens_across_dp,
|
|
||||||
with_prefill_across_dp=with_prefill_across_dp)
|
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -127,8 +127,6 @@ class AscendTorchairMetadata:
|
|||||||
query_start_loc: torch.Tensor
|
query_start_loc: torch.Tensor
|
||||||
query_lens: torch.Tensor
|
query_lens: torch.Tensor
|
||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
# max value of number of tokens across dp group
|
|
||||||
max_num_tokens_across_dp: int = 0
|
|
||||||
# Maximum query length in the batch. None for decoding.
|
# Maximum query length in the batch. None for decoding.
|
||||||
max_query_len: Optional[int] = None
|
max_query_len: Optional[int] = None
|
||||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||||
@@ -139,7 +137,7 @@ class AscendTorchairMetadata:
|
|||||||
# Current state of this attention run.
|
# Current state of this attention run.
|
||||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||||
attn_mask: Optional[torch.Tensor] = None
|
attn_mask: Optional[torch.Tensor] = None
|
||||||
with_prefill_across_dp: bool = False
|
|
||||||
decode: Optional[AscendDecodeMetadata] = None
|
decode: Optional[AscendDecodeMetadata] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -178,8 +176,9 @@ class AscendAttentionTorchairMetadataBuilder:
|
|||||||
|
|
||||||
return graph_block_tables[:num_seqs, :max_blocks]
|
return graph_block_tables[:num_seqs, :max_blocks]
|
||||||
|
|
||||||
def build_dummy(self, num_reqs: int,
|
def build_torchair_graph_dummy(
|
||||||
num_actual_tokens: int) -> AscendTorchairMetadata:
|
self, num_reqs: int,
|
||||||
|
num_actual_tokens: int) -> AscendTorchairMetadata:
|
||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
_, max_blocks = self.runner.graph_block_tables.shape
|
_, max_blocks = self.runner.graph_block_tables.shape
|
||||||
block_table = torch.zeros((num_reqs, max_blocks),
|
block_table = torch.zeros((num_reqs, max_blocks),
|
||||||
@@ -214,7 +213,6 @@ class AscendAttentionTorchairMetadataBuilder:
|
|||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
attn_state=AscendAttentionState.DecodeOnly,
|
attn_state=AscendAttentionState.DecodeOnly,
|
||||||
max_num_tokens_across_dp=num_reqs,
|
|
||||||
decode=decode_metadata)
|
decode=decode_metadata)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
@@ -222,9 +220,7 @@ class AscendAttentionTorchairMetadataBuilder:
|
|||||||
num_reqs,
|
num_reqs,
|
||||||
num_actual_tokens,
|
num_actual_tokens,
|
||||||
max_query_len,
|
max_query_len,
|
||||||
graph_pad_size: int = -1,
|
graph_pad_size: int = -1):
|
||||||
max_num_tokens_across_dp: int = 0,
|
|
||||||
with_prefill_across_dp: bool = False):
|
|
||||||
|
|
||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
|
|
||||||
@@ -263,7 +259,6 @@ class AscendAttentionTorchairMetadataBuilder:
|
|||||||
pad_value = 1
|
pad_value = 1
|
||||||
padded_seq_lens = seq_lens.tolist() + [pad_value
|
padded_seq_lens = seq_lens.tolist() + [pad_value
|
||||||
] * graph_pad_size
|
] * graph_pad_size
|
||||||
max_num_tokens_across_dp = len(padded_seq_lens)
|
|
||||||
|
|
||||||
seq_lens = torch.from_numpy(
|
seq_lens = torch.from_numpy(
|
||||||
np.array(padded_seq_lens).astype(np.int32))
|
np.array(padded_seq_lens).astype(np.int32))
|
||||||
@@ -303,9 +298,7 @@ class AscendAttentionTorchairMetadataBuilder:
|
|||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
attn_state=attn_state,
|
attn_state=attn_state)
|
||||||
max_num_tokens_across_dp=max_num_tokens_across_dp,
|
|
||||||
with_prefill_across_dp=with_prefill_across_dp)
|
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -126,9 +126,6 @@ class AscendMLAMetadata:
|
|||||||
# For logging.
|
# For logging.
|
||||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||||
|
|
||||||
max_num_tokens_across_dp: int = 0
|
|
||||||
with_prefill_across_dp: bool = False
|
|
||||||
|
|
||||||
query_lens: Optional[list[int]] = None
|
query_lens: Optional[list[int]] = None
|
||||||
# The dimension of the attention heads
|
# The dimension of the attention heads
|
||||||
head_dim: Optional[int] = None
|
head_dim: Optional[int] = None
|
||||||
@@ -302,8 +299,8 @@ class AscendMLAMetadataBuilder:
|
|||||||
|
|
||||||
return graph_block_tables[:num_seqs, :max_blocks]
|
return graph_block_tables[:num_seqs, :max_blocks]
|
||||||
|
|
||||||
def build_dummy(self, num_reqs: int,
|
def build_torchair_graph_dummy(
|
||||||
num_actual_tokens: int) -> AscendMLAMetadata:
|
self, num_reqs: int, num_actual_tokens: int) -> AscendMLAMetadata:
|
||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
_, max_blocks = self.runner.graph_block_tables.shape
|
_, max_blocks = self.runner.graph_block_tables.shape
|
||||||
block_table = torch.zeros((num_reqs, max_blocks),
|
block_table = torch.zeros((num_reqs, max_blocks),
|
||||||
@@ -353,8 +350,6 @@ class AscendMLAMetadataBuilder:
|
|||||||
num_actual_tokens: int,
|
num_actual_tokens: int,
|
||||||
max_query_len: int,
|
max_query_len: int,
|
||||||
graph_pad_size: int = -1,
|
graph_pad_size: int = -1,
|
||||||
max_num_tokens_across_dp: int = 0,
|
|
||||||
with_prefill_across_dp: bool = False,
|
|
||||||
query_start_loc: torch.Tensor = None,
|
query_start_loc: torch.Tensor = None,
|
||||||
) -> AscendMLAMetadata:
|
) -> AscendMLAMetadata:
|
||||||
assert self._num_decodes + self._num_prefills == num_reqs
|
assert self._num_decodes + self._num_prefills == num_reqs
|
||||||
@@ -498,8 +493,6 @@ class AscendMLAMetadataBuilder:
|
|||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
block_tables=block_table,
|
block_tables=block_table,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
max_num_tokens_across_dp=max_num_tokens_across_dp,
|
|
||||||
with_prefill_across_dp=with_prefill_across_dp,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
48
vllm_ascend/distributed/parallel_state.py
Normal file
48
vllm_ascend/distributed/parallel_state.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.config import ParallelConfig
|
||||||
|
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
|
||||||
|
init_model_parallel_group)
|
||||||
|
|
||||||
|
# Currently, mc2 op need their own group coordinator.
|
||||||
|
_MC2: Optional[GroupCoordinator] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_mc2_group() -> GroupCoordinator:
|
||||||
|
assert _MC2 is not None, ("mc2 group is not initialized")
|
||||||
|
return _MC2
|
||||||
|
|
||||||
|
|
||||||
|
def model_parallel_initialized():
|
||||||
|
return (_MC2 is not None)
|
||||||
|
|
||||||
|
|
||||||
|
def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
||||||
|
if model_parallel_initialized():
|
||||||
|
return
|
||||||
|
assert torch.distributed.is_initialized()
|
||||||
|
world_size = torch.distributed.get_world_size()
|
||||||
|
backend = torch.distributed.get_backend(get_world_group().device_group)
|
||||||
|
|
||||||
|
# The layout of all ranks: ExternalDP * EP
|
||||||
|
# ExternalDP is the data parallel group that is not part of the model,
|
||||||
|
# every dp rank can generate independently (in verl integration).
|
||||||
|
all_ranks = torch.arange(world_size).reshape(
|
||||||
|
-1, parallel_config.data_parallel_size *
|
||||||
|
parallel_config.tensor_parallel_size)
|
||||||
|
global _MC2
|
||||||
|
group_ranks = all_ranks.unbind(0)
|
||||||
|
group_ranks = [x.tolist() for x in group_ranks]
|
||||||
|
|
||||||
|
_MC2 = init_model_parallel_group(group_ranks,
|
||||||
|
get_world_group().local_rank,
|
||||||
|
backend,
|
||||||
|
group_name="mc2")
|
||||||
|
|
||||||
|
|
||||||
|
def destroy_ascend_model_parallel():
|
||||||
|
global _MC2
|
||||||
|
if _MC2:
|
||||||
|
_MC2.destroy()
|
||||||
|
_MC2 = None
|
||||||
@@ -174,20 +174,12 @@ class CustomDeepseekDBOMoE(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||||
if attn_metadata is None:
|
forward_context = get_forward_context()
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
|
||||||
# when profile runs, force experts to load balanced tokens
|
# when profile runs, force experts to load balanced tokens
|
||||||
# to avoid high memory consumption on a single rank.
|
# to avoid high memory consumption on a single rank.
|
||||||
# TODO: need a better flag to indicate whether in profile run or not.
|
enable_force_load_balance = forward_context.in_profile_run
|
||||||
if attn_metadata is None:
|
|
||||||
# for profile run
|
is_prefill = forward_context.with_prefill
|
||||||
is_prefill = True
|
|
||||||
enable_force_load_balance = True
|
|
||||||
else:
|
|
||||||
is_prefill = attn_metadata.num_prefills > 0
|
|
||||||
enable_force_load_balance = False
|
|
||||||
if hasattr(attn_metadata, 'with_prefill_across_dp'):
|
|
||||||
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
|
||||||
|
|
||||||
old_hidden_states = hidden_states.clone()
|
old_hidden_states = hidden_states.clone()
|
||||||
|
|
||||||
|
|||||||
@@ -377,20 +377,14 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
attn_metadata: Optional[AttentionMetadata] = None,
|
attn_metadata: Optional[AttentionMetadata] = None,
|
||||||
replace_allreduce: bool = False) -> torch.Tensor:
|
replace_allreduce: bool = False) -> torch.Tensor:
|
||||||
|
|
||||||
if attn_metadata is None:
|
forward_context = get_forward_context()
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
|
||||||
# when profile runs, force experts to load balanced tokens
|
# when profile runs, force experts to load balanced tokens
|
||||||
# to avoid high memory consumption on a single rank.
|
# to avoid high memory consumption on a single rank.
|
||||||
# TODO: need a better flag to indicate whether in profile run or not.
|
|
||||||
if attn_metadata is None:
|
enable_force_load_balance = forward_context.in_profile_run
|
||||||
# for profile run
|
|
||||||
is_prefill = True
|
is_prefill = forward_context.with_prefill
|
||||||
enable_force_load_balance = True
|
|
||||||
else:
|
|
||||||
is_prefill = attn_metadata.num_prefills > 0
|
|
||||||
enable_force_load_balance = False
|
|
||||||
if hasattr(attn_metadata, 'with_prefill_across_dp'):
|
|
||||||
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
|
||||||
# If this node is kv_consumer, we force the moe always runs in decode path to make sure
|
# If this node is kv_consumer, we force the moe always runs in decode path to make sure
|
||||||
# the behaviour aligned between dummy_run and normal model_execute.
|
# the behaviour aligned between dummy_run and normal model_execute.
|
||||||
if self.kv_consumer:
|
if self.kv_consumer:
|
||||||
@@ -572,9 +566,10 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
kv_cache: Optional[torch.Tensor] = None,
|
kv_cache: Optional[torch.Tensor] = None,
|
||||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||||
|
forward_context = get_forward_context()
|
||||||
enable_multistream_mla = (self.enable_multistream_mla
|
enable_multistream_mla = (self.enable_multistream_mla
|
||||||
and attn_metadata is not None
|
and attn_metadata is not None
|
||||||
and not attn_metadata.with_prefill_across_dp
|
and not forward_context.with_prefill
|
||||||
and attn_metadata.num_decodes > 0)
|
and attn_metadata.num_decodes > 0)
|
||||||
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
|
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
|
|||||||
@@ -837,12 +837,8 @@ class PanguProMoEModel(nn.Module):
|
|||||||
# if attn_meatadata is not passed, we try to get it from forward_context.
|
# if attn_meatadata is not passed, we try to get it from forward_context.
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
if attn_metadata is None:
|
|
||||||
# when attn_meatadata is None, it is in profile_run. num_tokens on all dp ranks
|
max_tokens_across_dp = get_forward_context().max_tokens_across_dp
|
||||||
# are same.
|
|
||||||
max_tokens_across_dp = hidden_states.shape[0]
|
|
||||||
else:
|
|
||||||
max_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
|
|
||||||
|
|
||||||
tp_size = get_tp_group().world_size
|
tp_size = get_tp_group().world_size
|
||||||
# reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks.
|
# reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks.
|
||||||
|
|||||||
@@ -223,7 +223,6 @@ def model_input_split_v1_mla_attn(
|
|||||||
attn_mask=attn_mask_pre,
|
attn_mask=attn_mask_pre,
|
||||||
prefill=prefill_pre,
|
prefill=prefill_pre,
|
||||||
decode=decode_pre,
|
decode=decode_pre,
|
||||||
with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
|
|
||||||
)
|
)
|
||||||
attention_metadata_post = _metadata_cls(
|
attention_metadata_post = _metadata_cls(
|
||||||
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
|
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
|
||||||
@@ -240,6 +239,5 @@ def model_input_split_v1_mla_attn(
|
|||||||
attn_state=attn_state_post,
|
attn_state=attn_state_post,
|
||||||
prefill=prefill_post,
|
prefill=prefill_post,
|
||||||
decode=decode_post,
|
decode=decode_post,
|
||||||
with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
|
|
||||||
)
|
)
|
||||||
return [attention_metadata_pre, attention_metadata_post]
|
return [attention_metadata_pre, attention_metadata_post]
|
||||||
|
|||||||
@@ -40,12 +40,15 @@ from vllm.model_executor.layers.quantization.base_config import \
|
|||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||||
from vllm_ascend.distributed.communication_op import \
|
from vllm_ascend.distributed.communication_op import \
|
||||||
data_parallel_reduce_scatter
|
data_parallel_reduce_scatter
|
||||||
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||||
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||||
get_all_reduce_merge_state, get_fused_moe_state,
|
get_all_reduce_merge_state,
|
||||||
|
get_ascend_soc_version,
|
||||||
get_rm_router_logits_state, is_310p)
|
get_rm_router_logits_state, is_310p)
|
||||||
|
|
||||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||||
@@ -127,9 +130,23 @@ def fused_experts_with_mc2(
|
|||||||
moe_parallel_config: FusedMoEParallelConfig,
|
moe_parallel_config: FusedMoEParallelConfig,
|
||||||
expert_map: torch.Tensor = None,
|
expert_map: torch.Tensor = None,
|
||||||
moe_all_to_all_group_name: Optional[str] = None,
|
moe_all_to_all_group_name: Optional[str] = None,
|
||||||
shared_experts: Optional[Any] = None
|
shared_experts: Optional[Any] = None,
|
||||||
|
is_torchair: bool = False,
|
||||||
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
global_bs = 0
|
quant_mode = 0
|
||||||
|
ep_rank_id = moe_parallel_config.ep_rank
|
||||||
|
ep_world_size = moe_parallel_config.ep_size
|
||||||
|
|
||||||
|
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
|
||||||
|
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
|
||||||
|
or is_torchair)
|
||||||
|
|
||||||
|
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||||
|
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
|
||||||
|
|
||||||
|
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
|
||||||
|
|
||||||
moe_expert_num = len(expert_map)
|
moe_expert_num = len(expert_map)
|
||||||
kwargs_mc2 = {
|
kwargs_mc2 = {
|
||||||
"x": hidden_states,
|
"x": hidden_states,
|
||||||
@@ -137,32 +154,35 @@ def fused_experts_with_mc2(
|
|||||||
"expert_shard_type": 0,
|
"expert_shard_type": 0,
|
||||||
"shared_expert_rank_num": 0,
|
"shared_expert_rank_num": 0,
|
||||||
"moe_expert_num": moe_expert_num,
|
"moe_expert_num": moe_expert_num,
|
||||||
"global_bs": global_bs,
|
"global_bs": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
rank = torch.distributed.get_rank()
|
|
||||||
|
|
||||||
quant_mode = 0
|
|
||||||
ep_rank_id = moe_parallel_config.ep_rank
|
|
||||||
ep_world_size = moe_parallel_config.ep_size
|
|
||||||
|
|
||||||
tp_world_size = moe_parallel_config.tp_size
|
|
||||||
tp_rank = rank % tp_world_size
|
|
||||||
|
|
||||||
stage1_kwargs = {
|
stage1_kwargs = {
|
||||||
"scales": None,
|
"scales": None,
|
||||||
"quant_mode": quant_mode,
|
"quant_mode": quant_mode,
|
||||||
"group_ep": moe_all_to_all_group_name,
|
"group_ep": moe_all_to_all_group_name,
|
||||||
"ep_world_size": ep_world_size,
|
"ep_world_size": ep_world_size,
|
||||||
"ep_rank_id": ep_rank_id,
|
"ep_rank_id": ep_rank_id,
|
||||||
"group_tp": moe_all_to_all_group_name,
|
|
||||||
"tp_world_size": tp_world_size,
|
|
||||||
"tp_rank_id": tp_rank,
|
|
||||||
}
|
}
|
||||||
|
if need_extra_args:
|
||||||
|
stage1_kwargs.update({
|
||||||
|
"group_tp": moe_all_to_all_group_name,
|
||||||
|
"tp_world_size": 1,
|
||||||
|
"tp_rank_id": 0,
|
||||||
|
})
|
||||||
|
if a3_need_extra_args and enable_dispatch_v2:
|
||||||
|
stage1_kwargs.update({
|
||||||
|
"x_active_mask": mc2_mask,
|
||||||
|
})
|
||||||
|
|
||||||
kwargs_mc2.update(stage1_kwargs)
|
kwargs_mc2.update(stage1_kwargs)
|
||||||
|
|
||||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
output = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||||
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
**kwargs_mc2
|
||||||
|
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||||
|
**kwargs_mc2)
|
||||||
|
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||||
|
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
|
||||||
0:5]
|
0:5]
|
||||||
|
|
||||||
if shared_experts is not None:
|
if shared_experts is not None:
|
||||||
@@ -205,7 +225,6 @@ def fused_experts_with_mc2(
|
|||||||
kwargs_mc2 = {
|
kwargs_mc2 = {
|
||||||
"expand_x": down_out_list,
|
"expand_x": down_out_list,
|
||||||
"expert_ids": topk_ids,
|
"expert_ids": topk_ids,
|
||||||
"expand_idx": expand_idx,
|
|
||||||
"expert_scales": topk_weights.to(torch.float32),
|
"expert_scales": topk_weights.to(torch.float32),
|
||||||
"expert_shard_type": 0,
|
"expert_shard_type": 0,
|
||||||
"shared_expert_rank_num": 0,
|
"shared_expert_rank_num": 0,
|
||||||
@@ -218,15 +237,33 @@ def fused_experts_with_mc2(
|
|||||||
"group_ep": moe_all_to_all_group_name,
|
"group_ep": moe_all_to_all_group_name,
|
||||||
"ep_world_size": ep_world_size,
|
"ep_world_size": ep_world_size,
|
||||||
"ep_rank_id": ep_rank_id,
|
"ep_rank_id": ep_rank_id,
|
||||||
"tp_send_counts": tp_recv_counts,
|
|
||||||
# "group_tp": self.moe_rs_group_name,
|
|
||||||
"group_tp": moe_all_to_all_group_name,
|
|
||||||
"tp_world_size": tp_world_size,
|
|
||||||
"tp_rank_id": tp_rank,
|
|
||||||
}
|
}
|
||||||
|
if enable_dispatch_v2:
|
||||||
|
stage3_kwargs.update({
|
||||||
|
"assist_info_for_combine":
|
||||||
|
assist_info_for_combine,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
stage3_kwargs.update({
|
||||||
|
"expand_idx": assist_info_for_combine,
|
||||||
|
})
|
||||||
|
if need_extra_args:
|
||||||
|
stage3_kwargs.update({
|
||||||
|
"tp_send_counts": tp_recv_counts,
|
||||||
|
"group_tp": moe_all_to_all_group_name,
|
||||||
|
"tp_world_size": 1,
|
||||||
|
"tp_rank_id": 0,
|
||||||
|
})
|
||||||
|
if a3_need_extra_args and enable_dispatch_v2:
|
||||||
|
stage3_kwargs.update({
|
||||||
|
"x_active_mask": mc2_mask,
|
||||||
|
})
|
||||||
kwargs_mc2.update(stage3_kwargs)
|
kwargs_mc2.update(stage3_kwargs)
|
||||||
|
|
||||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
|
||||||
|
**kwargs_mc2
|
||||||
|
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
|
||||||
|
**kwargs_mc2)
|
||||||
|
|
||||||
if shared_experts is None:
|
if shared_experts is None:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -981,17 +1018,14 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
super().__init__(moe=moe)
|
super().__init__(moe=moe)
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
|
|
||||||
self.ep_group = get_ep_group()
|
|
||||||
self.ep_size = self.moe.moe_parallel_config.ep_size
|
|
||||||
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||||
self.local_batch_size = self.global_batch_size // self.ep_size
|
|
||||||
self.max_model_len = vllm_config.model_config.max_model_len
|
self.max_model_len = vllm_config.model_config.max_model_len
|
||||||
|
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||||
|
|
||||||
try:
|
try:
|
||||||
device_group = self.ep_group.device_group
|
device_group = get_mc2_group().device_group
|
||||||
# TODO: Try local_rank = ep_group.rank_in_group
|
# TODO: Try local_rank = ep_group.rank_in_group
|
||||||
local_rank = torch.distributed.get_rank(group=device_group)
|
local_rank = torch.distributed.get_rank(group=device_group)
|
||||||
backend = device_group._get_backend(torch.device("npu"))
|
backend = device_group._get_backend(torch.device("npu"))
|
||||||
@@ -1074,8 +1108,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
if enable_force_load_balance:
|
if enable_force_load_balance:
|
||||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||||
|
|
||||||
fused_moe_state = get_fused_moe_state(self.ep_size, is_prefill,
|
fused_moe_state = get_forward_context().fused_moe_state
|
||||||
is_deepseek_v3_r1)
|
|
||||||
if fused_moe_state == FusedMoEState.MC2:
|
if fused_moe_state == FusedMoEState.MC2:
|
||||||
return fused_experts_with_mc2(
|
return fused_experts_with_mc2(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@@ -1087,7 +1121,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||||
shared_experts=shared_experts)
|
shared_experts=shared_experts,
|
||||||
|
mc2_mask=kwargs.get("mc2_mask", None))
|
||||||
elif fused_moe_state in [
|
elif fused_moe_state in [
|
||||||
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
||||||
]:
|
]:
|
||||||
@@ -1295,52 +1330,56 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
real_top_k = self.top_k
|
real_top_k = self.top_k
|
||||||
|
|
||||||
num_tokens, hidden_size = hidden_states.shape
|
num_tokens, hidden_size = hidden_states.shape
|
||||||
is_deepseek_v3_r1 = self.global_num_experts == 256
|
|
||||||
|
|
||||||
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
|
forward_context = get_forward_context()
|
||||||
is_prefill, is_deepseek_v3_r1)
|
fused_moe_state = forward_context.fused_moe_state
|
||||||
|
mc2_mask = forward_context.mc2_mask
|
||||||
if shared_experts:
|
if shared_experts:
|
||||||
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
|
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
|
||||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
||||||
shared_hidden_states = shared_experts(hidden_states)
|
shared_hidden_states = shared_experts(hidden_states)
|
||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
if (tp_size > 1 and fused_moe_state not in [
|
if (fused_moe_state not in [
|
||||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||||
FusedMoEState.NaiveMulticast
|
FusedMoEState.NaiveMulticast
|
||||||
] and not replace_allreduce):
|
] and not replace_allreduce):
|
||||||
if num_tokens < tp_size:
|
if fused_moe_state in {FusedMoEState.MC2}:
|
||||||
|
padding_size = forward_context.padded_num_tokens
|
||||||
|
else:
|
||||||
|
# TODO: Determine if we can remove the padding
|
||||||
|
padding_size = tp_size
|
||||||
|
if num_tokens < padding_size:
|
||||||
hidden_states = nn.functional.pad(
|
hidden_states = nn.functional.pad(
|
||||||
hidden_states, (0, 0, 0, tp_size - num_tokens))
|
hidden_states, (0, 0, 0, padding_size - num_tokens))
|
||||||
router_logits = nn.functional.pad(
|
router_logits = nn.functional.pad(
|
||||||
router_logits, (0, 0, 0, tp_size - num_tokens))
|
router_logits, (0, 0, 0, padding_size - num_tokens))
|
||||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
if tp_size > 1:
|
||||||
tp_size,
|
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||||
dim=0)
|
tp_size,
|
||||||
chunk_router_logits = torch.tensor_split(router_logits,
|
dim=0)
|
||||||
tp_size,
|
chunk_router_logits = torch.tensor_split(router_logits,
|
||||||
dim=0)
|
tp_size,
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
dim=0)
|
||||||
hidden_states = chunk_hidden_states[tp_rank]
|
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
|
||||||
router_logits = chunk_router_logits[tp_rank]
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
hidden_states = chunk_hidden_states[tp_rank]
|
||||||
|
router_logits = chunk_router_logits[tp_rank]
|
||||||
|
mc2_mask = chunk_mc2_mask[tp_rank]
|
||||||
|
|
||||||
if self.dp_size > 1:
|
if self.dp_size > 1:
|
||||||
if fused_moe_state == FusedMoEState.AllGather:
|
if fused_moe_state == FusedMoEState.AllGather:
|
||||||
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
||||||
if not self.torchair_graph_enabled:
|
if not self.torchair_graph_enabled:
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
max_tokens_across_dp = forward_context.max_tokens_across_dp
|
||||||
if attn_metadata is not None:
|
if num_tokens < max_tokens_across_dp:
|
||||||
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
|
hidden_states = nn.functional.pad(
|
||||||
if num_tokens < max_num_tokens_across_dp:
|
hidden_states,
|
||||||
hidden_states = nn.functional.pad(
|
(0, 0, 0, max_tokens_across_dp - num_tokens))
|
||||||
hidden_states,
|
if not self.rm_router_logits:
|
||||||
(0, 0, 0,
|
router_logits = nn.functional.pad(
|
||||||
max_num_tokens_across_dp - num_tokens))
|
router_logits,
|
||||||
if not self.rm_router_logits:
|
(0, 0, 0, max_tokens_across_dp - num_tokens))
|
||||||
router_logits = nn.functional.pad(
|
|
||||||
router_logits,
|
|
||||||
(0, 0, 0,
|
|
||||||
max_num_tokens_across_dp - num_tokens))
|
|
||||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||||
if self.rm_router_logits:
|
if self.rm_router_logits:
|
||||||
router_logits, _ = gate(hidden_states)
|
router_logits, _ = gate(hidden_states)
|
||||||
@@ -1379,20 +1418,24 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
global_redundant_expert_num=self.global_redundant_expert_num,
|
global_redundant_expert_num=self.global_redundant_expert_num,
|
||||||
shared_experts=shared_experts if self.torchair_graph_enabled
|
shared_experts=shared_experts if self.torchair_graph_enabled
|
||||||
and self.enable_multistream_moe and not is_prefill else None,
|
and self.enable_multistream_moe and not is_prefill else None,
|
||||||
|
mc2_mask=mc2_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
if shared_experts:
|
if shared_experts:
|
||||||
if isinstance(e_hidden_states, tuple):
|
if isinstance(e_hidden_states, tuple):
|
||||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
e_hidden_states, shared_hidden_states = e_hidden_states
|
||||||
|
|
||||||
if (tp_size > 1 and fused_moe_state not in [
|
if (fused_moe_state not in [
|
||||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||||
FusedMoEState.NaiveMulticast
|
FusedMoEState.NaiveMulticast
|
||||||
] and not replace_allreduce):
|
] and not replace_allreduce):
|
||||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
if tp_size > 1:
|
||||||
self.tp_group)
|
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
self.tp_group)
|
||||||
if num_tokens < tp_size:
|
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||||
|
else:
|
||||||
|
final_hidden_states = e_hidden_states
|
||||||
|
if num_tokens < padding_size:
|
||||||
final_hidden_states = final_hidden_states[:num_tokens]
|
final_hidden_states = final_hidden_states[:num_tokens]
|
||||||
dispose_tensor(e_hidden_states)
|
dispose_tensor(e_hidden_states)
|
||||||
elif self.dp_size > 1:
|
elif self.dp_size > 1:
|
||||||
|
|||||||
@@ -22,8 +22,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot,
|
from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
|
||||||
wrapper_rmsnorm_init)
|
|
||||||
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
||||||
AscendW8A8LinearMethod)
|
AscendW8A8LinearMethod)
|
||||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||||
@@ -81,9 +80,6 @@ class VLLMAscendQuantizer:
|
|||||||
VLLMAscendQuantizer.apply_patch(
|
VLLMAscendQuantizer.apply_patch(
|
||||||
"vllm.model_executor.layers.layernorm.RMSNorm",
|
"vllm.model_executor.layers.layernorm.RMSNorm",
|
||||||
"forward_oot", [wrapper_rmsnorm_forward_oot])
|
"forward_oot", [wrapper_rmsnorm_forward_oot])
|
||||||
VLLMAscendQuantizer.apply_patch(
|
|
||||||
"vllm_ascend.worker.model_runner.NPUModelRunnerBase",
|
|
||||||
"load_model", [wrapper_load_model])
|
|
||||||
break
|
break
|
||||||
VLLMAscendQuantizer.patched = True
|
VLLMAscendQuantizer.patched = True
|
||||||
logger.info("Using the vLLM Ascend Quantizer version now!")
|
logger.info("Using the vLLM Ascend Quantizer version now!")
|
||||||
|
|||||||
@@ -20,15 +20,17 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.distributed import GroupCoordinator
|
from vllm.distributed import GroupCoordinator, get_ep_group
|
||||||
from vllm.distributed.parallel_state import get_ep_group
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
import vllm_ascend.envs as envs
|
import vllm_ascend.envs as envs
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||||
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
from vllm_ascend.ops.fused_moe import select_experts
|
from vllm_ascend.ops.fused_moe import select_experts
|
||||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
|
||||||
dispose_tensor, get_fused_moe_state)
|
dispose_tensor, get_ascend_soc_version)
|
||||||
|
|
||||||
|
|
||||||
def apply_mlp(hidden_states: torch.Tensor,
|
def apply_mlp(hidden_states: torch.Tensor,
|
||||||
@@ -118,10 +120,29 @@ def fused_experts_with_mc2(
|
|||||||
log2phy: torch.Tensor = None,
|
log2phy: torch.Tensor = None,
|
||||||
global_redundant_expert_num: int = 0,
|
global_redundant_expert_num: int = 0,
|
||||||
shared_experts: Optional[Any] = None,
|
shared_experts: Optional[Any] = None,
|
||||||
|
is_torchair: bool = False,
|
||||||
|
quantized_x_for_share: Optional[Any] = None,
|
||||||
|
dynamic_scale_for_share: Optional[Any] = None,
|
||||||
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
assert mc2_mask is not None
|
||||||
if log2phy is not None:
|
if log2phy is not None:
|
||||||
topk_ids = log2phy[topk_ids]
|
topk_ids = log2phy[topk_ids]
|
||||||
global_bs = 0
|
|
||||||
|
quant_mode = 2
|
||||||
|
ep_group = get_mc2_group()
|
||||||
|
ep_rank_id = ep_group.rank_in_group
|
||||||
|
ep_world_size = ep_group.world_size
|
||||||
|
|
||||||
|
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
|
||||||
|
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
|
||||||
|
or is_torchair)
|
||||||
|
|
||||||
|
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||||
|
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
|
||||||
|
|
||||||
|
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
|
||||||
|
|
||||||
if (expert_map is not None):
|
if (expert_map is not None):
|
||||||
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||||
else:
|
else:
|
||||||
@@ -133,47 +154,43 @@ def fused_experts_with_mc2(
|
|||||||
"expert_shard_type": 0,
|
"expert_shard_type": 0,
|
||||||
"shared_expert_rank_num": 0,
|
"shared_expert_rank_num": 0,
|
||||||
"moe_expert_num": moe_expert_num,
|
"moe_expert_num": moe_expert_num,
|
||||||
"global_bs": global_bs,
|
"global_bs": 0,
|
||||||
"expert_scales": topk_weights.to(torch.float32),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rank = torch.distributed.get_rank()
|
|
||||||
|
|
||||||
quant_mode = 2
|
|
||||||
ep_group = get_ep_group().device_group
|
|
||||||
local_rank = torch.distributed.get_rank(group=ep_group)
|
|
||||||
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
|
|
||||||
|
|
||||||
world_size = torch.distributed.get_world_size()
|
|
||||||
tp_size = world_size // all_to_all_group_size
|
|
||||||
tp_rank = rank % tp_size
|
|
||||||
|
|
||||||
stage1_kwargs = {
|
stage1_kwargs = {
|
||||||
"scales": None,
|
"scales": None,
|
||||||
"quant_mode": quant_mode,
|
"quant_mode": quant_mode,
|
||||||
"group_ep": moe_all_to_all_group_name,
|
"group_ep": moe_all_to_all_group_name,
|
||||||
"ep_world_size": all_to_all_group_size,
|
"ep_world_size": ep_world_size,
|
||||||
"ep_rank_id": local_rank,
|
"ep_rank_id": ep_rank_id,
|
||||||
# "group_tp": self.moe_rs_group_name,
|
|
||||||
"group_tp": moe_all_to_all_group_name,
|
|
||||||
"tp_world_size": tp_size,
|
|
||||||
"tp_rank_id": tp_rank,
|
|
||||||
}
|
}
|
||||||
|
if need_extra_args:
|
||||||
|
stage1_kwargs.update({
|
||||||
|
"group_tp": moe_all_to_all_group_name,
|
||||||
|
"tp_world_size": 1,
|
||||||
|
"tp_rank_id": 0,
|
||||||
|
})
|
||||||
|
if a3_need_extra_args and enable_dispatch_v2:
|
||||||
|
stage1_kwargs.update({
|
||||||
|
"x_active_mask": mc2_mask,
|
||||||
|
})
|
||||||
kwargs_mc2.update(stage1_kwargs)
|
kwargs_mc2.update(stage1_kwargs)
|
||||||
|
|
||||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
output = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||||
|
**kwargs_mc2
|
||||||
|
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||||
|
**kwargs_mc2)
|
||||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||||
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
|
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
|
||||||
0:7]
|
0:5]
|
||||||
|
|
||||||
if shared_experts is not None:
|
if shared_experts is not None:
|
||||||
with npu_stream_switch("moe_secondary", 0):
|
with npu_stream_switch("moe_secondary", 0):
|
||||||
npu_wait_tensor(hidden_states, topk_weights)
|
npu_wait_tensor(quantized_x_for_share, expand_x)
|
||||||
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
shared_act_out = shared_experts.act_fn(
|
||||||
npu_wait_tensor(shared_gate_up[0], expand_x)
|
(quantized_x_for_share, dynamic_scale_for_share))
|
||||||
shared_act = shared_experts.act_fn(shared_gate_up)
|
shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
|
||||||
|
|
||||||
# `expand_x` will be disposed in the `apply_mlp` function
|
|
||||||
down_out_list = apply_mlp(expand_x,
|
down_out_list = apply_mlp(expand_x,
|
||||||
w1,
|
w1,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
@@ -186,13 +203,11 @@ def fused_experts_with_mc2(
|
|||||||
kwargs_mc2 = {
|
kwargs_mc2 = {
|
||||||
"expand_x": down_out_list,
|
"expand_x": down_out_list,
|
||||||
"expert_ids": topk_ids,
|
"expert_ids": topk_ids,
|
||||||
"expand_idx": expand_idx,
|
|
||||||
"expert_scales": topk_weights.to(torch.float32),
|
"expert_scales": topk_weights.to(torch.float32),
|
||||||
"expert_shard_type": 0,
|
"expert_shard_type": 0,
|
||||||
"shared_expert_rank_num": 0,
|
"shared_expert_rank_num": 0,
|
||||||
"moe_expert_num": moe_expert_num,
|
"moe_expert_num": moe_expert_num,
|
||||||
"global_bs": 0,
|
"global_bs": 0,
|
||||||
"expand_scales": expand_scales,
|
|
||||||
}
|
}
|
||||||
tp_recv_counts = torch.empty(1,
|
tp_recv_counts = torch.empty(1,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
@@ -200,24 +215,43 @@ def fused_experts_with_mc2(
|
|||||||
stage3_kwargs = {
|
stage3_kwargs = {
|
||||||
"ep_send_counts": ep_recv_counts,
|
"ep_send_counts": ep_recv_counts,
|
||||||
"group_ep": moe_all_to_all_group_name,
|
"group_ep": moe_all_to_all_group_name,
|
||||||
"ep_world_size": all_to_all_group_size,
|
"ep_world_size": ep_world_size,
|
||||||
"ep_rank_id": local_rank,
|
"ep_rank_id": ep_rank_id,
|
||||||
"tp_send_counts": tp_recv_counts,
|
|
||||||
# "group_tp": self.moe_rs_group_name,
|
|
||||||
"group_tp": moe_all_to_all_group_name,
|
|
||||||
"tp_world_size": tp_size,
|
|
||||||
"tp_rank_id": tp_rank,
|
|
||||||
}
|
}
|
||||||
|
if enable_dispatch_v2:
|
||||||
|
stage3_kwargs.update({
|
||||||
|
"assist_info_for_combine":
|
||||||
|
assist_info_for_combine,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
stage3_kwargs.update({
|
||||||
|
"expand_idx": assist_info_for_combine,
|
||||||
|
})
|
||||||
|
if need_extra_args:
|
||||||
|
stage3_kwargs.update({
|
||||||
|
"tp_send_counts": tp_recv_counts,
|
||||||
|
"group_tp": moe_all_to_all_group_name,
|
||||||
|
"tp_world_size": 1,
|
||||||
|
"tp_rank_id": 0,
|
||||||
|
})
|
||||||
|
if a3_need_extra_args and enable_dispatch_v2:
|
||||||
|
stage3_kwargs.update({
|
||||||
|
"x_active_mask": mc2_mask,
|
||||||
|
})
|
||||||
kwargs_mc2.update(stage3_kwargs)
|
kwargs_mc2.update(stage3_kwargs)
|
||||||
|
|
||||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
|
||||||
|
**kwargs_mc2
|
||||||
|
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
|
||||||
|
**kwargs_mc2)
|
||||||
|
|
||||||
if shared_experts is None:
|
if shared_experts is None:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
else:
|
else:
|
||||||
with npu_stream_switch("moe_secondary", 0):
|
with npu_stream_switch("moe_secondary", 0):
|
||||||
npu_wait_tensor(shared_act[0], down_out_list)
|
npu_wait_tensor(shared_act, down_out_list)
|
||||||
shared_output, _ = shared_experts.down_proj(shared_act)
|
shared_output, _ = shared_experts.down_proj(
|
||||||
|
(shared_act, swiglu_out_scale))
|
||||||
return hidden_states, shared_output
|
return hidden_states, shared_output
|
||||||
|
|
||||||
|
|
||||||
@@ -640,7 +674,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||||
|
|
||||||
try:
|
try:
|
||||||
device_group = self.ep_group.device_group
|
device_group = get_mc2_group().device_group
|
||||||
# TODO: Try local_rank = ep_group.rank_in_group
|
# TODO: Try local_rank = ep_group.rank_in_group
|
||||||
local_rank = torch.distributed.get_rank(group=device_group)
|
local_rank = torch.distributed.get_rank(group=device_group)
|
||||||
backend = device_group._get_backend(torch.device("npu"))
|
backend = device_group._get_backend(torch.device("npu"))
|
||||||
@@ -755,8 +789,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
|
|
||||||
topk_weights = topk_weights.to(x.dtype)
|
topk_weights = topk_weights.to(x.dtype)
|
||||||
|
|
||||||
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
fused_moe_state = get_forward_context().fused_moe_state
|
||||||
is_prefill, is_deepseek_v3_r1)
|
|
||||||
if fused_moe_state == FusedMoEState.AllGatherEP:
|
if fused_moe_state == FusedMoEState.AllGatherEP:
|
||||||
return fused_experts_with_allgather(
|
return fused_experts_with_allgather(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@@ -782,7 +815,9 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||||
log2phy=log2phy,
|
log2phy=log2phy,
|
||||||
global_redundant_expert_num=global_redundant_expert_num,
|
global_redundant_expert_num=global_redundant_expert_num,
|
||||||
shared_experts=shared_experts)
|
shared_experts=shared_experts,
|
||||||
|
is_torchair=self.torchair_graph_enabled,
|
||||||
|
mc2_mask=kwargs.get("mc2_mask", None))
|
||||||
elif fused_moe_state in [
|
elif fused_moe_state in [
|
||||||
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
||||||
]:
|
]:
|
||||||
|
|||||||
@@ -52,13 +52,3 @@ class NPUTorchairWorker(NPUWorker):
|
|||||||
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
|
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
|
||||||
|
|
||||||
return available_kv_cache_memory
|
return available_kv_cache_memory
|
||||||
|
|
||||||
def _get_max_num_tokens_and_with_prefill(self):
|
|
||||||
"""Override _get_max_num_tokens_and_with_prefill to update max_num_tokens."""
|
|
||||||
|
|
||||||
max_num_tokens, with_prefill = super(
|
|
||||||
)._get_max_num_tokens_and_with_prefill()
|
|
||||||
if not with_prefill:
|
|
||||||
max_num_tokens = self.model_runner.select_torchair_padded_batch_size(
|
|
||||||
max_num_tokens)
|
|
||||||
return max_num_tokens, with_prefill
|
|
||||||
|
|||||||
@@ -429,15 +429,6 @@ def npu_prefetch(input: torch.Tensor,
|
|||||||
torch_npu.npu_prefetch(input, dependency, max_size)
|
torch_npu.npu_prefetch(input, dependency, max_size)
|
||||||
|
|
||||||
|
|
||||||
# TODO(zzzzwwjj): move this into forward_context
|
|
||||||
class FusedMoEState(Enum):
|
|
||||||
AllGather = 0
|
|
||||||
All2All = 1
|
|
||||||
MC2 = 2
|
|
||||||
AllGatherEP = 3
|
|
||||||
NaiveMulticast = 4
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(ttanzhiqiang): rm_router_logits
|
# TODO(ttanzhiqiang): rm_router_logits
|
||||||
# dp>1 will trigger
|
# dp>1 will trigger
|
||||||
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
|
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
|
||||||
@@ -468,26 +459,6 @@ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
# TODO(zzzzwwjj): add soc_version to choose branch
|
|
||||||
def get_fused_moe_state(ep_size: int, with_prefill: bool,
|
|
||||||
is_deepseek_v3_r1: bool):
|
|
||||||
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
|
||||||
# only supports deepseek v3/r1
|
|
||||||
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
|
||||||
and is_deepseek_v3_r1):
|
|
||||||
return FusedMoEState.AllGatherEP
|
|
||||||
elif ep_size == 1:
|
|
||||||
if with_prefill:
|
|
||||||
return FusedMoEState.NaiveMulticast
|
|
||||||
else:
|
|
||||||
return FusedMoEState.AllGather
|
|
||||||
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
|
|
||||||
elif ep_size < 16 or with_prefill:
|
|
||||||
return FusedMoEState.All2All
|
|
||||||
else:
|
|
||||||
return FusedMoEState.MC2
|
|
||||||
|
|
||||||
|
|
||||||
def register_ascend_customop():
|
def register_ascend_customop():
|
||||||
"""Register Ascend CustomOP
|
"""Register Ascend CustomOP
|
||||||
|
|
||||||
@@ -506,3 +477,30 @@ def register_ascend_customop():
|
|||||||
|
|
||||||
# NOTE: Keep this at last to ensure all custom actions are registered
|
# NOTE: Keep this at last to ensure all custom actions are registered
|
||||||
_ASCEND_CUSTOMOP_IS_REIGISTERED = True
|
_ASCEND_CUSTOMOP_IS_REIGISTERED = True
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(zzzzwwjj): It will be judged with _build_info afterwards.
|
||||||
|
class AscendSocVersion(Enum):
|
||||||
|
A2 = 0
|
||||||
|
A3 = 1
|
||||||
|
UNDEFINED = 2
|
||||||
|
|
||||||
|
|
||||||
|
_ascend_soc_version = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_ascend_soc_version():
|
||||||
|
soc_version = torch_npu.npu.get_soc_version()
|
||||||
|
global _ascend_soc_version
|
||||||
|
if 220 <= soc_version <= 225:
|
||||||
|
_ascend_soc_version = AscendSocVersion.A2
|
||||||
|
elif 250 <= soc_version <= 255:
|
||||||
|
_ascend_soc_version = AscendSocVersion.A3
|
||||||
|
else:
|
||||||
|
_ascend_soc_version = AscendSocVersion.UNDEFINED
|
||||||
|
|
||||||
|
|
||||||
|
def get_ascend_soc_version():
|
||||||
|
global _ascend_soc_version
|
||||||
|
assert _ascend_soc_version is not None
|
||||||
|
return _ascend_soc_version
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ from vllm.attention.layer import Attention
|
|||||||
from vllm.config import (CompilationLevel, VllmConfig,
|
from vllm.config import (CompilationLevel, VllmConfig,
|
||||||
get_layers_from_vllm_config)
|
get_layers_from_vllm_config)
|
||||||
from vllm.distributed.parallel_state import get_pp_group
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
from vllm.forward_context import set_forward_context
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.models import supports_multimodal
|
from vllm.model_executor.models import supports_multimodal
|
||||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
|
|
||||||
@@ -142,9 +142,9 @@ class EagleProposer:
|
|||||||
self.positions[:num_tokens] = target_positions.to(device)
|
self.positions[:num_tokens] = target_positions.to(device)
|
||||||
self.hidden_states[:num_tokens] = target_hidden_states
|
self.hidden_states[:num_tokens] = target_hidden_states
|
||||||
attn_metadata.block_tables = block_table.to(device)
|
attn_metadata.block_tables = block_table.to(device)
|
||||||
with set_forward_context(attn_metadata,
|
with set_ascend_forward_context(attn_metadata,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
num_tokens=num_input_tokens):
|
num_tokens=num_input_tokens):
|
||||||
last_hidden_states, hidden_states = self.model(
|
last_hidden_states, hidden_states = self.model(
|
||||||
input_ids=self.input_ids[:num_input_tokens],
|
input_ids=self.input_ids[:num_input_tokens],
|
||||||
positions=self.positions[:num_input_tokens],
|
positions=self.positions[:num_input_tokens],
|
||||||
@@ -239,9 +239,9 @@ class EagleProposer:
|
|||||||
attn_metadata.attn_mask = attn_mask
|
attn_metadata.attn_mask = attn_mask
|
||||||
attn_metadata.block_tables = block_table.to(device)
|
attn_metadata.block_tables = block_table.to(device)
|
||||||
# Run the model.
|
# Run the model.
|
||||||
with set_forward_context(attn_metadata,
|
with set_ascend_forward_context(attn_metadata,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
num_tokens=input_batch_size):
|
num_tokens=input_batch_size):
|
||||||
|
|
||||||
last_hidden_states, hidden_states = self.model(
|
last_hidden_states, hidden_states = self.model(
|
||||||
input_ids=self.input_ids[:input_batch_size],
|
input_ids=self.input_ids[:input_batch_size],
|
||||||
@@ -344,8 +344,9 @@ class EagleProposer:
|
|||||||
self,
|
self,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
with set_forward_context(None, self.vllm_config,
|
with set_ascend_forward_context(None,
|
||||||
num_tokens=num_tokens):
|
self.vllm_config,
|
||||||
|
num_tokens=num_tokens):
|
||||||
self.model(
|
self.model(
|
||||||
input_ids=self.input_ids[:num_tokens],
|
input_ids=self.input_ids[:num_tokens],
|
||||||
positions=self.positions[:num_tokens],
|
positions=self.positions[:num_tokens],
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ import torch
|
|||||||
import torch._dynamo.cache_size
|
import torch._dynamo.cache_size
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.distributed import ReduceOp
|
|
||||||
from vllm.attention import AttentionType, get_attn_backend
|
from vllm.attention import AttentionType, get_attn_backend
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import CompilationLevel, VllmConfig
|
from vllm.config import CompilationLevel, VllmConfig
|
||||||
@@ -44,7 +43,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
|||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||||
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
|
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
|
||||||
get_tp_group)
|
get_tp_group)
|
||||||
from vllm.forward_context import get_forward_context, set_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.inputs import INPUT_REGISTRY
|
from vllm.inputs import INPUT_REGISTRY
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
@@ -77,6 +76,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
|
|||||||
scatter_mm_placeholders)
|
scatter_mm_placeholders)
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||||
AscendMetadata)
|
AscendMetadata)
|
||||||
@@ -347,6 +347,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
torch._logging.set_logs(
|
torch._logging.set_logs(
|
||||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||||
|
|
||||||
|
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
|
||||||
|
self.in_profile_run = False
|
||||||
|
|
||||||
# kv role
|
# kv role
|
||||||
self.is_kv_producer = False
|
self.is_kv_producer = False
|
||||||
if vllm_config.kv_transfer_config is not None:
|
if vllm_config.kv_transfer_config is not None:
|
||||||
@@ -566,16 +569,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.input_batch.refresh_sampling_metadata()
|
self.input_batch.refresh_sampling_metadata()
|
||||||
|
|
||||||
def _get_forward_metadata_across_dp(
|
def _get_forward_metadata_across_dp(
|
||||||
self, total_num_scheduled_tokens: int,
|
self,
|
||||||
with_prefill: bool) -> tuple[int, bool]:
|
maybe_padded_num_tokens: int,
|
||||||
forward_metadata = torch.tensor(
|
num_tokens: int,
|
||||||
[total_num_scheduled_tokens, with_prefill],
|
with_prefill: bool,
|
||||||
device="cpu",
|
enable_dbo: bool = False,
|
||||||
dtype=torch.int32)
|
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
||||||
dist.all_reduce(forward_metadata,
|
if self.dp_size == 1:
|
||||||
op=ReduceOp.MAX,
|
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
|
||||||
group=get_dp_group().cpu_group)
|
|
||||||
return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
|
num_tokens_across_dp = [0] * self.dp_size * 2
|
||||||
|
num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
|
||||||
|
num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
|
||||||
|
forward_metadata = torch.tensor(num_tokens_across_dp +
|
||||||
|
[with_prefill, not enable_dbo],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.int32)
|
||||||
|
dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
|
||||||
|
with_prefill = bool(forward_metadata[-2])
|
||||||
|
|
||||||
|
# NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
|
||||||
|
if with_prefill:
|
||||||
|
num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
|
||||||
|
2]
|
||||||
|
maybe_padded_num_tokens = num_tokens
|
||||||
|
else:
|
||||||
|
num_tokens_across_dp = forward_metadata[:self.dp_size]
|
||||||
|
|
||||||
|
# NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
|
||||||
|
# `max_tokens_across_dp`, in other situation it is not necessary.
|
||||||
|
if self.torchair_graph_enabled and not with_prefill:
|
||||||
|
maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
|
||||||
|
num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
|
||||||
|
self.dp_size,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
|
||||||
|
forward_metadata[-1])
|
||||||
|
|
||||||
def get_eagle_atten_dict(
|
def get_eagle_atten_dict(
|
||||||
self,
|
self,
|
||||||
@@ -1052,21 +1083,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.dp_size > 1:
|
maybe_padded_num_tokens = total_num_scheduled_tokens
|
||||||
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
|
|
||||||
total_num_scheduled_tokens, with_prefill)
|
|
||||||
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
|
|
||||||
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
|
|
||||||
|
|
||||||
# Add graph_pad_size here
|
|
||||||
if self.torchair_graph_enabled and not with_prefill:
|
if self.torchair_graph_enabled and not with_prefill:
|
||||||
if self.dp_size > 1:
|
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||||
padded_batch_size = self.select_torchair_padded_batch_size(
|
total_num_scheduled_tokens)
|
||||||
max_num_tokens)
|
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
|
||||||
else:
|
enable_dbo) = self._get_forward_metadata_across_dp(
|
||||||
padded_batch_size = self.select_torchair_padded_batch_size(
|
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
|
||||||
total_num_scheduled_tokens)
|
|
||||||
graph_pad_size = padded_batch_size - total_num_scheduled_tokens
|
if self.torchair_graph_enabled and not with_prefill:
|
||||||
|
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
|
||||||
|
|
||||||
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
||||||
|
|
||||||
@@ -1134,8 +1160,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
positions = self.mrope_positions[:, :num_input_tokens]
|
positions = self.mrope_positions[:, :num_input_tokens]
|
||||||
|
|
||||||
if self.torchair_graph_enabled and not with_prefill:
|
if self.torchair_graph_enabled and not with_prefill:
|
||||||
input_ids = self.input_ids[:padded_batch_size]
|
input_ids = self.input_ids[:padded_num_tokens_across_dp]
|
||||||
positions = self.positions[:padded_batch_size]
|
positions = self.positions[:padded_num_tokens_across_dp]
|
||||||
|
|
||||||
if get_pp_group().is_first_rank:
|
if get_pp_group().is_first_rank:
|
||||||
intermediate_tensors = None
|
intermediate_tensors = None
|
||||||
@@ -1151,9 +1177,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
})
|
})
|
||||||
|
|
||||||
# Run forward pass
|
# Run forward pass
|
||||||
with set_forward_context(attn_metadata,
|
with set_ascend_forward_context(
|
||||||
self.vllm_config,
|
attn_metadata,
|
||||||
num_tokens=num_input_tokens):
|
self.vllm_config,
|
||||||
|
num_tokens=padded_num_tokens_across_dp,
|
||||||
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
|
with_prefill=with_prefill,
|
||||||
|
num_actual_tokens=total_num_scheduled_tokens):
|
||||||
with ProfileExecuteDuration().capture_async("forward"):
|
with ProfileExecuteDuration().capture_async("forward"):
|
||||||
self.maybe_setup_kv_connector(scheduler_output)
|
self.maybe_setup_kv_connector(scheduler_output)
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
@@ -1165,7 +1195,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
ACL_FORMAT_FRACTAL_NZ)
|
ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
|
||||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||||
padded_batch_size)
|
padded_num_tokens_across_dp)
|
||||||
hidden_states = compiled_model(
|
hidden_states = compiled_model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
@@ -1643,7 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
def kv_connector_no_forward(
|
def kv_connector_no_forward(
|
||||||
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
||||||
with set_forward_context(None, self.vllm_config):
|
with set_ascend_forward_context(None, self.vllm_config):
|
||||||
self.maybe_setup_kv_connector(scheduler_output)
|
self.maybe_setup_kv_connector(scheduler_output)
|
||||||
finished_sending, finished_recving = (
|
finished_sending, finished_recving = (
|
||||||
self.get_finished_kv_transfer(scheduler_output))
|
self.get_finished_kv_transfer(scheduler_output))
|
||||||
@@ -1688,14 +1718,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
def _dummy_run(
|
def _dummy_run(
|
||||||
self,
|
self,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
is_compile: bool = False,
|
skip_attn: bool = True,
|
||||||
with_prefill: bool = True,
|
with_prefill: bool = False,
|
||||||
|
is_torchair_compile: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
maybe_padded_num_tokens = num_tokens
|
||||||
|
if self.torchair_graph_enabled and not with_prefill:
|
||||||
|
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||||
|
num_tokens)
|
||||||
|
|
||||||
|
# Padding for DP
|
||||||
|
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||||
|
enable_dbo) = self._get_forward_metadata_across_dp(
|
||||||
|
maybe_padded_num_tokens, num_tokens, with_prefill, False)
|
||||||
|
|
||||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||||
# for dummy run with LoRA so that the num_reqs collectively
|
# for dummy run with LoRA so that the num_reqs collectively
|
||||||
# has num_tokens in total.
|
# has num_tokens in total.
|
||||||
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
|
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
|
||||||
num_reqs = self.max_num_reqs if num_tokens >= self.max_num_reqs else num_tokens
|
max_num_reqs = self.scheduler_config.max_num_seqs
|
||||||
|
num_reqs = min(num_tokens, max_num_reqs)
|
||||||
min_tokens_per_req = num_tokens // num_reqs
|
min_tokens_per_req = num_tokens // num_reqs
|
||||||
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
||||||
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
||||||
@@ -1706,6 +1748,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.is_kv_producer:
|
if self.is_kv_producer:
|
||||||
with_prefill = True
|
with_prefill = True
|
||||||
|
|
||||||
|
# NOTE: If torchair graph mode and not with_prefill,
|
||||||
|
# we can't skip_attn, it will cause graph recompile.
|
||||||
|
if self.torchair_graph_enabled and not with_prefill:
|
||||||
|
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
|
||||||
|
num_reqs=num_tokens, num_actual_tokens=1)
|
||||||
|
elif skip_attn:
|
||||||
|
attn_metadata = None
|
||||||
|
else:
|
||||||
|
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
|
||||||
|
attn_metadata = None
|
||||||
|
|
||||||
with self.maybe_dummy_run_with_lora(self.lora_config,
|
with self.maybe_dummy_run_with_lora(self.lora_config,
|
||||||
num_scheduled_tokens):
|
num_scheduled_tokens):
|
||||||
model = self.model
|
model = self.model
|
||||||
@@ -1735,20 +1788,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
for k, v in self.intermediate_tensors.items()
|
for k, v in self.intermediate_tensors.items()
|
||||||
})
|
})
|
||||||
|
|
||||||
with set_forward_context(None,
|
with set_ascend_forward_context(
|
||||||
self.vllm_config,
|
attn_metadata,
|
||||||
num_tokens=num_tokens):
|
self.vllm_config,
|
||||||
|
num_tokens=num_tokens,
|
||||||
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
|
with_prefill=with_prefill,
|
||||||
|
in_profile_run=self.in_profile_run,
|
||||||
|
num_actual_tokens=0,
|
||||||
|
):
|
||||||
|
model_kwargs = {}
|
||||||
if self.torchair_graph_enabled and not with_prefill:
|
if self.torchair_graph_enabled and not with_prefill:
|
||||||
attn_metadata = self.attn_metadata_builder.build_dummy(
|
|
||||||
num_reqs=num_tokens, num_actual_tokens=1)
|
|
||||||
# Only mark static while compiling
|
# Only mark static while compiling
|
||||||
if is_compile:
|
if is_torchair_compile:
|
||||||
torch._dynamo.mark_static(input_ids)
|
torch._dynamo.mark_static(input_ids)
|
||||||
torch._dynamo.mark_static(positions)
|
torch._dynamo.mark_static(positions)
|
||||||
torch._dynamo.mark_static(
|
torch._dynamo.mark_static(
|
||||||
attn_metadata.decode.block_table)
|
attn_metadata.decode.block_table)
|
||||||
torch._dynamo.mark_static(
|
torch._dynamo.mark_static(
|
||||||
attn_metadata.decode.input_positions)
|
attn_metadata.decode.input_positions)
|
||||||
|
torch._dynamo.mark_static(
|
||||||
|
get_forward_context().mc2_mask)
|
||||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||||
for kv in self.kv_caches:
|
for kv in self.kv_caches:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
@@ -1761,13 +1821,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||||
num_tokens)
|
num_tokens)
|
||||||
|
model_kwargs["kv_caches"] = self.kv_caches
|
||||||
|
model_kwargs["attn_metadata"] = attn_metadata
|
||||||
hidden_states = compiled_model(
|
hidden_states = compiled_model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
kv_caches=self.kv_caches,
|
**model_kwargs,
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
maybe_converting_weight_acl_format(self.model,
|
maybe_converting_weight_acl_format(self.model,
|
||||||
@@ -1787,9 +1848,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.drafter.dummy_run(num_tokens)
|
self.drafter.dummy_run(num_tokens)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def set_in_profile_run(self):
|
||||||
|
self.in_profile_run = True
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
self.in_profile_run = False
|
||||||
|
|
||||||
def profile_run(self) -> None:
|
def profile_run(self) -> None:
|
||||||
# Trigger compilation for general shape.
|
# Trigger compilation for general shape.
|
||||||
hidden_states = self._dummy_run(self.max_num_tokens)
|
with self.set_in_profile_run():
|
||||||
|
hidden_states = self._dummy_run(self.max_num_tokens,
|
||||||
|
with_prefill=True)
|
||||||
output = None
|
output = None
|
||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
if self.is_pooling_model:
|
if self.is_pooling_model:
|
||||||
@@ -2159,10 +2230,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
|
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
|
||||||
for _ in range(self.vllm_config.compilation_config.
|
for _ in range(self.vllm_config.compilation_config.
|
||||||
cudagraph_num_of_warmups):
|
cudagraph_num_of_warmups):
|
||||||
self._dummy_run(num_tokens,
|
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||||
is_compile=True,
|
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||||
with_prefill=False)
|
|
||||||
self._dummy_run(num_tokens, is_compile=True, with_prefill=False)
|
|
||||||
logger.info("Batchsize %d is compiled successfully: %d/%d.",
|
logger.info("Batchsize %d is compiled successfully: %d/%d.",
|
||||||
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
|
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
|
||||||
|
|
||||||
@@ -2205,6 +2274,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Trigger ACL graph capture for specific shapes.
|
# Trigger ACL graph capture for specific shapes.
|
||||||
# Capture the large shapes first so that the smaller shapes
|
# Capture the large shapes first so that the smaller shapes
|
||||||
# can reuse the memory pool allocated for the large shapes.
|
# can reuse the memory pool allocated for the large shapes.
|
||||||
|
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
|
||||||
with graph_capture(device=self.device):
|
with graph_capture(device=self.device):
|
||||||
for num_tokens in reversed(self.aclgraph_batch_sizes):
|
for num_tokens in reversed(self.aclgraph_batch_sizes):
|
||||||
for _ in range(self.vllm_config.compilation_config.
|
for _ in range(self.vllm_config.compilation_config.
|
||||||
|
|||||||
@@ -2,12 +2,12 @@ import torch
|
|||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
|
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
|
||||||
set_current_vllm_config)
|
set_current_vllm_config)
|
||||||
from vllm.forward_context import set_forward_context
|
|
||||||
from vllm.model_executor.model_loader import get_model_loader
|
from vllm.model_executor.model_loader import get_model_loader
|
||||||
from vllm.model_executor.model_loader.utils import (
|
from vllm.model_executor.model_loader.utils import (
|
||||||
process_weights_after_loading, set_default_torch_dtype)
|
process_weights_after_loading, set_default_torch_dtype)
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||||
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
|
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
|
||||||
|
|
||||||
|
|
||||||
@@ -117,7 +117,7 @@ class MtpProposer:
|
|||||||
query_start_loc=cu_num_tokens,
|
query_start_loc=cu_num_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_ascend_forward_context(attn_metadata, self.vllm_config):
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=target_positions,
|
positions=target_positions,
|
||||||
|
|||||||
@@ -40,9 +40,10 @@ from vllm.v1.worker.worker_base import WorkerBase
|
|||||||
|
|
||||||
from vllm_ascend.ascend_config import init_ascend_config
|
from vllm_ascend.ascend_config import init_ascend_config
|
||||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||||
|
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.utils import (sleep_mode_enabled, try_register_lib,
|
from vllm_ascend.utils import (init_ascend_soc_version, sleep_mode_enabled,
|
||||||
vllm_version_is)
|
try_register_lib, vllm_version_is)
|
||||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||||
|
|
||||||
if not vllm_version_is("0.10.0"):
|
if not vllm_version_is("0.10.0"):
|
||||||
@@ -134,6 +135,7 @@ class NPUWorker(WorkerBase):
|
|||||||
NPUPlatform.empty_cache()
|
NPUPlatform.empty_cache()
|
||||||
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
|
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
|
||||||
|
|
||||||
|
init_ascend_soc_version()
|
||||||
# Initialize the distributed environment.
|
# Initialize the distributed environment.
|
||||||
self._init_worker_distributed_environment()
|
self._init_worker_distributed_environment()
|
||||||
# Set random seed.
|
# Set random seed.
|
||||||
@@ -272,20 +274,8 @@ class NPUWorker(WorkerBase):
|
|||||||
def pin_lora(self, lora_id: int) -> bool:
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
return self.model_runner.pin_lora(lora_id)
|
return self.model_runner.pin_lora(lora_id)
|
||||||
|
|
||||||
def _get_max_num_tokens_and_with_prefill(self):
|
|
||||||
max_num_tokens = 1
|
|
||||||
with_prefill = False
|
|
||||||
if self.model_runner.dp_size > 1:
|
|
||||||
max_num_tokens, with_prefill = self.model_runner._get_forward_metadata_across_dp(
|
|
||||||
max_num_tokens, with_prefill)
|
|
||||||
return max_num_tokens, with_prefill
|
|
||||||
|
|
||||||
def execute_dummy_batch(self) -> None:
|
def execute_dummy_batch(self) -> None:
|
||||||
max_num_tokens, with_prefill = self._get_max_num_tokens_and_with_prefill(
|
self.model_runner._dummy_run(1)
|
||||||
)
|
|
||||||
self.model_runner._dummy_run(max_num_tokens,
|
|
||||||
is_compile=False,
|
|
||||||
with_prefill=with_prefill)
|
|
||||||
|
|
||||||
def _init_worker_distributed_environment(self) -> None:
|
def _init_worker_distributed_environment(self) -> None:
|
||||||
"""Initialize the distributed environment."""
|
"""Initialize the distributed environment."""
|
||||||
@@ -295,6 +285,7 @@ class NPUWorker(WorkerBase):
|
|||||||
ensure_model_parallel_initialized(
|
ensure_model_parallel_initialized(
|
||||||
self.parallel_config.tensor_parallel_size,
|
self.parallel_config.tensor_parallel_size,
|
||||||
self.parallel_config.pipeline_parallel_size)
|
self.parallel_config.pipeline_parallel_size)
|
||||||
|
init_ascend_model_parallel(self.parallel_config)
|
||||||
ensure_kv_transfer_initialized(self.vllm_config)
|
ensure_kv_transfer_initialized(self.vllm_config)
|
||||||
|
|
||||||
def _init_profiler(self):
|
def _init_profiler(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user