[refactor] Remove unnecessary attributes from set_ascend_forward_context (#5204)
### What this PR does / why we need it?
Remove unnecessary attributes from set_ascend_forward_context
1.prefetch_stream
2.weight_prefetch_method
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
@@ -286,8 +286,8 @@ def test_select_experts(
|
|||||||
|
|
||||||
with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
|
with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
|
||||||
) as mock_native_grouped_topk, \
|
) as mock_native_grouped_topk, \
|
||||||
patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
|
patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
|
||||||
return_value=MagicMock(weight_prefetch_method=MagicMock())):
|
return_value=MagicMock()):
|
||||||
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
||||||
x)
|
x)
|
||||||
|
|
||||||
@@ -323,8 +323,8 @@ def test_select_experts(
|
|||||||
|
|
||||||
@pytest.mark.parametrize("device", DEVICE)
|
@pytest.mark.parametrize("device", DEVICE)
|
||||||
def test_select_experts_invalid_scoring_func(device: str):
|
def test_select_experts_invalid_scoring_func(device: str):
|
||||||
with patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
|
with patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
|
||||||
return_value=MagicMock(weight_prefetch_method=MagicMock())), \
|
return_value=MagicMock()), \
|
||||||
pytest.raises(ValueError,
|
pytest.raises(ValueError,
|
||||||
match="Unsupported scoring function: invalid"):
|
match="Unsupported scoring function: invalid"):
|
||||||
select_experts(hidden_states=torch.randn(1, 128, device=device),
|
select_experts(hidden_states=torch.randn(1, 128, device=device),
|
||||||
@@ -336,17 +336,3 @@ def test_select_experts_invalid_scoring_func(device: str):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.npu.empty_cache()
|
torch.npu.empty_cache()
|
||||||
torch.npu.reset_peak_memory_stats()
|
torch.npu.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", DEVICE)
|
|
||||||
def test_select_experts_missing_group_params(device: str):
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
select_experts(hidden_states=torch.randn(1, 128, device=device),
|
|
||||||
router_logits=torch.randn(1, 64, device=device),
|
|
||||||
top_k=2,
|
|
||||||
use_grouped_topk=True,
|
|
||||||
renormalize=False,
|
|
||||||
scoring_func="softmax")
|
|
||||||
gc.collect()
|
|
||||||
torch.npu.empty_cache()
|
|
||||||
torch.npu.reset_peak_memory_stats()
|
|
||||||
|
|||||||
@@ -92,15 +92,14 @@ def mock_dist_env(mocker: MockerFixture):
|
|||||||
mock_moe_comm_method.finalize.side_effect = mock_finalize
|
mock_moe_comm_method.finalize.side_effect = mock_finalize
|
||||||
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
|
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
|
||||||
mock_weight_prefetch_method = MagicMock()
|
mock_weight_prefetch_method = MagicMock()
|
||||||
mock_forward_context_obj = MagicMock(
|
mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
|
||||||
moe_comm_method=mock_moe_comm_method,
|
|
||||||
moe_comm_type=MoECommType.MC2,
|
moe_comm_type=MoECommType.MC2,
|
||||||
max_tokens_across_dp=10,
|
max_tokens_across_dp=10,
|
||||||
dp_metadata=dp_metadata,
|
dp_metadata=dp_metadata,
|
||||||
mc2_mask=torch.zeros(16, dtype=torch.bool),
|
mc2_mask=torch.zeros(
|
||||||
|
16, dtype=torch.bool),
|
||||||
padded_num_tokens=16,
|
padded_num_tokens=16,
|
||||||
with_quant=False,
|
with_quant=False)
|
||||||
weight_prefetch_method=mock_weight_prefetch_method)
|
|
||||||
|
|
||||||
with patch('torch.distributed.get_rank', return_value=0), \
|
with patch('torch.distributed.get_rank', return_value=0), \
|
||||||
patch('torch.distributed.get_world_size', return_value=4), \
|
patch('torch.distributed.get_world_size', return_value=4), \
|
||||||
@@ -133,8 +132,8 @@ def mock_dist_env(mocker: MockerFixture):
|
|||||||
return_value=None), \
|
return_value=None), \
|
||||||
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
||||||
return_value=None), \
|
return_value=None), \
|
||||||
patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
|
patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
|
||||||
return_value=mock_forward_context_obj):
|
return_value=mock_weight_prefetch_method):
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
'mock_forward_context_obj': mock_forward_context_obj,
|
'mock_forward_context_obj': mock_forward_context_obj,
|
||||||
|
|||||||
@@ -63,21 +63,18 @@ class TestAscendW8A8LinearMethod(TestBase):
|
|||||||
self.assertEqual(params['weight_scale'].shape, (10, 1))
|
self.assertEqual(params['weight_scale'].shape, (10, 1))
|
||||||
self.assertEqual(params['weight_offset'].shape, (10, 1))
|
self.assertEqual(params['weight_offset'].shape, (10, 1))
|
||||||
|
|
||||||
@patch("vllm_ascend.quantization.w8a8.get_forward_context")
|
@patch("vllm_ascend.quantization.w8a8.get_weight_prefetch_method")
|
||||||
@patch("torch.ops.vllm.quantize")
|
@patch("torch.ops.vllm.quantize")
|
||||||
@patch("torch_npu.npu_quant_matmul")
|
@patch("torch_npu.npu_quant_matmul")
|
||||||
def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize,
|
def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize,
|
||||||
mock_get_forward_context):
|
mock_get_weight_prefetch_method):
|
||||||
layer = MagicMock()
|
layer = MagicMock()
|
||||||
layer.aclnn_input_scale = 0.1
|
layer.aclnn_input_scale = 0.1
|
||||||
layer.aclnn_input_offset = 0.2
|
layer.aclnn_input_offset = 0.2
|
||||||
layer.weight = torch.randn(128, 256)
|
layer.weight = torch.randn(128, 256)
|
||||||
layer.deq_scale = 0.3
|
layer.deq_scale = 0.3
|
||||||
|
|
||||||
mock_forward_context = MagicMock()
|
mock_get_weight_prefetch_method.return_value = MagicMock()
|
||||||
mock_get_forward_context.return_value = mock_forward_context
|
|
||||||
mock_weight_prefetch_method = MagicMock()
|
|
||||||
mock_forward_context.weight_prefetch_method = mock_weight_prefetch_method
|
|
||||||
|
|
||||||
x = torch.randn(32, 128)
|
x = torch.randn(32, 128)
|
||||||
bias = torch.randn(256)
|
bias = torch.randn(256)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.config import CUDAGraphMode, VllmConfig
|
from vllm.config import CUDAGraphMode, VllmConfig
|
||||||
@@ -16,11 +16,6 @@ from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
|
|||||||
get_ascend_device_type, has_layer_idx,
|
get_ascend_device_type, has_layer_idx,
|
||||||
is_moe_model)
|
is_moe_model)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
|
||||||
else:
|
|
||||||
WeightPrefetchMethod = None
|
|
||||||
|
|
||||||
|
|
||||||
class MoECommType(Enum):
|
class MoECommType(Enum):
|
||||||
ALLGATHER = 0
|
ALLGATHER = 0
|
||||||
@@ -41,9 +36,7 @@ def set_ascend_forward_context(
|
|||||||
num_actual_tokens: Optional[int] = None,
|
num_actual_tokens: Optional[int] = None,
|
||||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
batch_descriptor: Optional[BatchDescriptor] = None,
|
batch_descriptor: Optional[BatchDescriptor] = None,
|
||||||
prefetch_stream: torch.npu.Stream = None,
|
|
||||||
model_instance: torch.nn.Module = None,
|
model_instance: torch.nn.Module = None,
|
||||||
weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
|
|
||||||
is_mtp_model=False):
|
is_mtp_model=False):
|
||||||
"""A context manager that stores the current forward context,
|
"""A context manager that stores the current forward context,
|
||||||
can be attention metadata, etc.
|
can be attention metadata, etc.
|
||||||
@@ -116,13 +109,10 @@ def set_ascend_forward_context(
|
|||||||
forward_context.layer_idx is not None and \
|
forward_context.layer_idx is not None and \
|
||||||
num_tokens is not None and num_tokens < 500
|
num_tokens is not None and num_tokens < 500
|
||||||
if prefetch_mlp_enabled:
|
if prefetch_mlp_enabled:
|
||||||
forward_context.prefetch_stream = prefetch_stream
|
|
||||||
forward_context.model_instance = model_instance
|
|
||||||
forward_context.prefetch_mlp_gate_up_proj = False
|
forward_context.prefetch_mlp_gate_up_proj = False
|
||||||
forward_context.prefetch_mlp_down_proj = False
|
forward_context.prefetch_mlp_down_proj = False
|
||||||
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
|
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
|
||||||
forward_context.model_instance = model_instance
|
forward_context.model_instance = model_instance
|
||||||
forward_context.weight_prefetch_method = weight_prefetch_method
|
|
||||||
forward_context.is_mtp_model = is_mtp_model
|
forward_context.is_mtp_model = is_mtp_model
|
||||||
|
|
||||||
if num_tokens is None and attn_metadata is not None:
|
if num_tokens is None and attn_metadata is not None:
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ from typing import Callable, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.forward_context import get_forward_context
|
|
||||||
|
from vllm_ascend.utils import get_weight_prefetch_method
|
||||||
|
|
||||||
|
|
||||||
def select_experts(hidden_states: torch.Tensor,
|
def select_experts(hidden_states: torch.Tensor,
|
||||||
@@ -56,7 +57,7 @@ def select_experts(hidden_states: torch.Tensor,
|
|||||||
topk_ids: selected expert IDs of shape (num_tokens, top_k).
|
topk_ids: selected expert IDs of shape (num_tokens, top_k).
|
||||||
"""
|
"""
|
||||||
# prefetch w1_w3_proj.weight preprocess
|
# prefetch w1_w3_proj.weight preprocess
|
||||||
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
weight_prefetch_method = get_weight_prefetch_method()
|
||||||
if weight_prefetch_method:
|
if weight_prefetch_method:
|
||||||
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
|
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
|
||||||
hidden_states, "gate_up")
|
hidden_states, "gate_up")
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ from vllm.triton_utils import HAS_TRITON
|
|||||||
|
|
||||||
from vllm_ascend.ascend_forward_context import MoECommType
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
|
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
|
||||||
enable_custom_op, get_ascend_device_type)
|
enable_custom_op, get_ascend_device_type,
|
||||||
|
get_weight_prefetch_method)
|
||||||
|
|
||||||
|
|
||||||
def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
||||||
@@ -100,7 +101,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
bias1, bias2 = None, None
|
bias1, bias2 = None, None
|
||||||
_output_dtype = w2_scale[0].dtype
|
_output_dtype = w2_scale[0].dtype
|
||||||
|
|
||||||
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
weight_prefetch_method = get_weight_prefetch_method()
|
||||||
if weight_prefetch_method:
|
if weight_prefetch_method:
|
||||||
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
|
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
|
||||||
hidden_states)
|
hidden_states)
|
||||||
|
|||||||
@@ -119,16 +119,16 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
|
|||||||
if not getattr(forward_context, 'prefetch_mlp_enabled', False):
|
if not getattr(forward_context, 'prefetch_mlp_enabled', False):
|
||||||
return
|
return
|
||||||
model_instance = forward_context.model_instance
|
model_instance = forward_context.model_instance
|
||||||
prefetch_stream = forward_context.prefetch_stream
|
weight_prefetch_stream = prefetch_stream()
|
||||||
layer_idx = int(prefix.split('.')[2])
|
layer_idx = int(prefix.split('.')[2])
|
||||||
|
|
||||||
# start point of gate_up_proj weight prefetch
|
# start point of gate_up_proj weight prefetch
|
||||||
if prefix.split('.')[-2] == "self_attn":
|
if prefix.split('.')[-2] == "self_attn":
|
||||||
forward_context.prefetch_mlp_gate_up_proj = True
|
forward_context.prefetch_mlp_gate_up_proj = True
|
||||||
if forward_context.prefetch_mlp_gate_up_proj:
|
if forward_context.prefetch_mlp_gate_up_proj:
|
||||||
prefetch_stream.wait_stream(torch.npu.current_stream())
|
weight_prefetch_stream.wait_stream(torch.npu.current_stream())
|
||||||
|
|
||||||
with torch.npu.stream(prefetch_stream):
|
with torch.npu.stream(weight_prefetch_stream):
|
||||||
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
|
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
|
||||||
torch_npu.npu_prefetch(
|
torch_npu.npu_prefetch(
|
||||||
model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
|
model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
|
||||||
@@ -178,13 +178,13 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
|
|||||||
return
|
return
|
||||||
forward_context.prefetch_mlp_down_proj = True
|
forward_context.prefetch_mlp_down_proj = True
|
||||||
model_instance = forward_context.model_instance
|
model_instance = forward_context.model_instance
|
||||||
prefetch_stream = forward_context.prefetch_stream
|
weight_prefetch_stream = prefetch_stream()
|
||||||
layer_idx = forward_context.layer_idx
|
layer_idx = forward_context.layer_idx
|
||||||
|
|
||||||
# start point of down_proj weight prefetch
|
# start point of down_proj weight prefetch
|
||||||
prefetch_stream.wait_stream(torch.npu.current_stream())
|
weight_prefetch_stream.wait_stream(torch.npu.current_stream())
|
||||||
|
|
||||||
with torch.npu.stream(prefetch_stream):
|
with torch.npu.stream(weight_prefetch_stream):
|
||||||
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
|
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
|
||||||
torch_npu.npu_prefetch(
|
torch_npu.npu_prefetch(
|
||||||
model_instance.model.layers[layer_idx].mlp.down_proj.weight,
|
model_instance.model.layers[layer_idx].mlp.down_proj.weight,
|
||||||
@@ -208,9 +208,9 @@ def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
|
|||||||
return
|
return
|
||||||
if forward_context.prefetch_mlp_gate_up_proj or \
|
if forward_context.prefetch_mlp_gate_up_proj or \
|
||||||
forward_context.prefetch_mlp_down_proj:
|
forward_context.prefetch_mlp_down_proj:
|
||||||
prefetch_stream = forward_context.prefetch_stream
|
weight_prefetch_stream = prefetch_stream()
|
||||||
# wait until prefetch done
|
# wait until prefetch done
|
||||||
torch.npu.current_stream().wait_stream(prefetch_stream)
|
torch.npu.current_stream().wait_stream(weight_prefetch_stream)
|
||||||
forward_context.prefetch_mlp_gate_up_proj = False
|
forward_context.prefetch_mlp_gate_up_proj = False
|
||||||
forward_context.prefetch_mlp_down_proj = False
|
forward_context.prefetch_mlp_down_proj = False
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -19,10 +19,10 @@ from typing import Any, Dict, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.forward_context import get_forward_context
|
|
||||||
|
|
||||||
from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
|
from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
|
||||||
get_ascend_device_type, maybe_trans_nz)
|
get_ascend_device_type,
|
||||||
|
get_weight_prefetch_method, maybe_trans_nz)
|
||||||
|
|
||||||
|
|
||||||
def quant_per_tensor(in_tensor: torch.Tensor,
|
def quant_per_tensor(in_tensor: torch.Tensor,
|
||||||
@@ -98,12 +98,7 @@ class AscendW8A8LinearMethod:
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if x.dtype != torch.int8:
|
if x.dtype != torch.int8:
|
||||||
layer_cls_name = layer.__class__.__name__
|
layer_cls_name = layer.__class__.__name__
|
||||||
try:
|
weight_prefetch_method = get_weight_prefetch_method()
|
||||||
weight_prefetch_method = get_forward_context(
|
|
||||||
).weight_prefetch_method
|
|
||||||
except AssertionError:
|
|
||||||
weight_prefetch_method = None
|
|
||||||
|
|
||||||
# prefetch qkvo_proj.weight preprocess
|
# prefetch qkvo_proj.weight preprocess
|
||||||
if weight_prefetch_method:
|
if weight_prefetch_method:
|
||||||
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
|
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from vllm.logger import logger
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import WeightPrefetchConfig, get_ascend_config
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@@ -52,6 +52,7 @@ ACL_FORMAT_FRACTAL_NZ = 29
|
|||||||
_CUSTOM_OP_ENABLED = None
|
_CUSTOM_OP_ENABLED = None
|
||||||
_CURRENT_STREAM = None
|
_CURRENT_STREAM = None
|
||||||
_PREFETCH_STREAM = None
|
_PREFETCH_STREAM = None
|
||||||
|
_WEIGHT_PREFETCH_METHOD = None
|
||||||
_GLOBAL_STREAM = None
|
_GLOBAL_STREAM = None
|
||||||
_SHARED_EXPERTS_CALCULATION_STREAM = None
|
_SHARED_EXPERTS_CALCULATION_STREAM = None
|
||||||
_ASCEND_CUSTOMOP_IS_REIGISTERED = False
|
_ASCEND_CUSTOMOP_IS_REIGISTERED = False
|
||||||
@@ -309,6 +310,18 @@ def prefetch_stream() -> torch.npu.Stream:
|
|||||||
return _PREFETCH_STREAM
|
return _PREFETCH_STREAM
|
||||||
|
|
||||||
|
|
||||||
|
def set_weight_prefetch_method(weight_prefetch_config: WeightPrefetchConfig):
|
||||||
|
global _WEIGHT_PREFETCH_METHOD
|
||||||
|
if _WEIGHT_PREFETCH_METHOD is None:
|
||||||
|
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
||||||
|
_WEIGHT_PREFETCH_METHOD = WeightPrefetchMethod(weight_prefetch_config)
|
||||||
|
return _WEIGHT_PREFETCH_METHOD
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight_prefetch_method():
|
||||||
|
return _WEIGHT_PREFETCH_METHOD
|
||||||
|
|
||||||
|
|
||||||
def global_stream() -> torch.npu.Stream:
|
def global_stream() -> torch.npu.Stream:
|
||||||
global _GLOBAL_STREAM
|
global _GLOBAL_STREAM
|
||||||
if _GLOBAL_STREAM is None:
|
if _GLOBAL_STREAM is None:
|
||||||
|
|||||||
@@ -82,7 +82,6 @@ from vllm.v1.worker.gpu_model_runner import (AsyncGPUModelRunnerOutput,
|
|||||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
|
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
|
||||||
from vllm.v1.worker.utils import AttentionGroup
|
from vllm.v1.worker.utils import AttentionGroup
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
@@ -106,7 +105,6 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
|
|||||||
from vllm_ascend.eplb.eplb_updator import EplbUpdator
|
from vllm_ascend.eplb.eplb_updator import EplbUpdator
|
||||||
from vllm_ascend.eplb.utils import model_register
|
from vllm_ascend.eplb.utils import model_register
|
||||||
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
|
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
|
||||||
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
|
||||||
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
||||||
from vllm_ascend.sample.logits_processor import build_logitsprocs
|
from vllm_ascend.sample.logits_processor import build_logitsprocs
|
||||||
from vllm_ascend.sample.sampler import AscendSampler
|
from vllm_ascend.sample.sampler import AscendSampler
|
||||||
@@ -115,7 +113,8 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
|||||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||||
from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
|
from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
|
||||||
enable_sp, get_ascend_device_type, is_moe_model,
|
enable_sp, get_ascend_device_type, is_moe_model,
|
||||||
lmhead_tp_enable, maybe_trans_nz)
|
lmhead_tp_enable, maybe_trans_nz,
|
||||||
|
set_weight_prefetch_method)
|
||||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||||
|
|
||||||
from vllm_ascend.ascend_forward_context import ( # isort: skip
|
from vllm_ascend.ascend_forward_context import ( # isort: skip
|
||||||
@@ -209,18 +208,13 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
self.pcp_rank = 0
|
self.pcp_rank = 0
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
|
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
|
||||||
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
|
|
||||||
self.prefetch_stream = torch.npu.Stream(device=device)
|
|
||||||
else:
|
|
||||||
self.prefetch_stream = None
|
|
||||||
self.sampler = AscendSampler()
|
self.sampler = AscendSampler()
|
||||||
self.attn_mask = None
|
self.attn_mask = None
|
||||||
self.attn_state = None
|
self.attn_state = None
|
||||||
|
|
||||||
# Ascend-specific configurations
|
# Ascend-specific configurations
|
||||||
self.ascend_config = get_ascend_config()
|
self.ascend_config = get_ascend_config()
|
||||||
self.weight_prefetch_method = WeightPrefetchMethod(
|
set_weight_prefetch_method(self.ascend_config.weight_prefetch_config)
|
||||||
self.ascend_config.weight_prefetch_config)
|
|
||||||
# Dump / PrecisionDebugger configuration now comes from AscendConfig
|
# Dump / PrecisionDebugger configuration now comes from AscendConfig
|
||||||
dump_cfg = self.ascend_config.dump_config
|
dump_cfg = self.ascend_config.dump_config
|
||||||
self.dump_enable = dump_cfg.enable_dump
|
self.dump_enable = dump_cfg.enable_dump
|
||||||
@@ -1420,9 +1414,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
batch_descriptor=batch_descriptor,
|
batch_descriptor=batch_descriptor,
|
||||||
num_actual_tokens=scheduler_output.
|
num_actual_tokens=scheduler_output.
|
||||||
total_num_scheduled_tokens,
|
total_num_scheduled_tokens,
|
||||||
prefetch_stream=self.prefetch_stream,
|
model_instance=self.model):
|
||||||
model_instance=self.model,
|
|
||||||
weight_prefetch_method=self.weight_prefetch_method):
|
|
||||||
self.maybe_setup_kv_connector(scheduler_output)
|
self.maybe_setup_kv_connector(scheduler_output)
|
||||||
|
|
||||||
hidden_states = self._generate_process_reqs_hidden_states(
|
hidden_states = self._generate_process_reqs_hidden_states(
|
||||||
@@ -2133,9 +2125,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
num_actual_tokens=0,
|
num_actual_tokens=0,
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||||
batch_descriptor=batch_descriptor,
|
batch_descriptor=batch_descriptor,
|
||||||
prefetch_stream=self.prefetch_stream,
|
model_instance=self.model):
|
||||||
model_instance=self.model,
|
|
||||||
weight_prefetch_method=self.weight_prefetch_method):
|
|
||||||
hidden_states = self._generate_dummy_run_hidden_states(
|
hidden_states = self._generate_dummy_run_hidden_states(
|
||||||
input_ids, positions, num_tokens_padded,
|
input_ids, positions, num_tokens_padded,
|
||||||
intermediate_tensors, inputs_embeds)
|
intermediate_tensors, inputs_embeds)
|
||||||
|
|||||||
Reference in New Issue
Block a user