[refactor] refactor model runner capture model (#5230)
### What this PR does / why we need it?
Refactor the `capture_model` method in model_runner to directly reuse
the method from vLLM.
Currently, most of the logic in the capture_model method is similar to
that in the vllm code. Directly using the vllm method can reduce the
maintenance cost of the vllm-ascend code. Modify as follows:
1、refactor capture_model function, directly inheriting community methods
2、refactor initialize_aclgraph_capture function, move to
initialize_attn_backend
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -54,10 +54,10 @@ def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
|
|||||||
|
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
original_capture = NPUModelRunner._capture_model
|
original_capture = NPUModelRunner.capture_model
|
||||||
|
|
||||||
with patch.object(NPUModelRunner,
|
with patch.object(NPUModelRunner,
|
||||||
'_capture_model',
|
'capture_model',
|
||||||
new=capture_model_wrapper(original_capture)):
|
new=capture_model_wrapper(original_capture)):
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, my name is", "The president of the United States is",
|
"Hello, my name is", "The president of the United States is",
|
||||||
@@ -73,7 +73,7 @@ def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
|
|||||||
vllm_model = VllmRunner(snapshot_download(model))
|
vllm_model = VllmRunner(snapshot_download(model))
|
||||||
_ = vllm_model.generate(prompts, sampling_params)
|
_ = vllm_model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
assert capture_called.value == 1, "_capture_model was not called during test"
|
assert capture_called.value == 1, "capture_model was not called during test"
|
||||||
assert capture_mem_before.value != -1, "capture_mem_before not set"
|
assert capture_mem_before.value != -1, "capture_mem_before not set"
|
||||||
assert capture_mem_after.value != -1, "capture_mem_after not set"
|
assert capture_mem_after.value != -1, "capture_mem_after not set"
|
||||||
|
|
||||||
@@ -93,7 +93,7 @@ def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
|
|||||||
max_capture_mem_gib = baseline_capture_mem * capture_mem_tolerance
|
max_capture_mem_gib = baseline_capture_mem * capture_mem_tolerance
|
||||||
max_mem_expected = max_capture_mem_gib * (1024**3)
|
max_mem_expected = max_capture_mem_gib * (1024**3)
|
||||||
assert mem_used_by_capture < max_mem_expected, (
|
assert mem_used_by_capture < max_mem_expected, (
|
||||||
f"_capture_model used more memory than expected. "
|
f"capture_model used more memory than expected. "
|
||||||
f"Used: {mem_used_by_capture / (1024**3):.2f} GiB, "
|
f"Used: {mem_used_by_capture / (1024**3):.2f} GiB, "
|
||||||
f"Expected: < {max_capture_mem_gib:.2f} GiB")
|
f"Expected: < {max_capture_mem_gib:.2f} GiB")
|
||||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn'
|
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn'
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import sys
|
|||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
@@ -98,7 +97,6 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
|||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
device=device)
|
device=device)
|
||||||
|
|
||||||
assert builder.aclgraph_support == AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
|
||||||
assert builder.device == device
|
assert builder.device == device
|
||||||
assert builder.vllm_config == vllm_config
|
assert builder.vllm_config == vllm_config
|
||||||
|
|
||||||
|
|||||||
@@ -44,9 +44,6 @@ from vllm_ascend.utils import weak_ref_tensors
|
|||||||
|
|
||||||
|
|
||||||
class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
|
||||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
|
||||||
AttentionCGSupport.ALWAYS
|
|
||||||
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
# Does this backend/builder reorder the batch?
|
# Does this backend/builder reorder the batch?
|
||||||
# If not, set this to None. Otherwise set it to the query
|
# If not, set this to None. Otherwise set it to the query
|
||||||
@@ -72,6 +69,16 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
|||||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||||
) if self.dcp_size > 1 else 0
|
) if self.dcp_size > 1 else 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cudagraph_support(
|
||||||
|
cls: type["AscendAttentionCPMetadataBuilder"],
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
kv_cache_spec: AttentionSpec,
|
||||||
|
) -> AttentionCGSupport:
|
||||||
|
# Explicit override in case the underlying builder specialized this getter.
|
||||||
|
# @override omitted only because of mypy limitation due to type variable.
|
||||||
|
return AttentionCGSupport.ALWAYS
|
||||||
|
|
||||||
def _get_chunked_req_mask(self, local_context_lens_allranks) -> List[bool]:
|
def _get_chunked_req_mask(self, local_context_lens_allranks) -> List[bool]:
|
||||||
"""
|
"""
|
||||||
given 4-d list [req][pcp][dcp], return:
|
given 4-d list [req][pcp][dcp], return:
|
||||||
|
|||||||
@@ -182,9 +182,6 @@ class AscendMetadata:
|
|||||||
|
|
||||||
|
|
||||||
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
|
||||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
|
||||||
AttentionCGSupport.ALWAYS
|
|
||||||
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
# Does this backend/builder reorder the batch?
|
# Does this backend/builder reorder the batch?
|
||||||
# If not, set this to None. Otherwise set it to the query
|
# If not, set this to None. Otherwise set it to the query
|
||||||
@@ -220,6 +217,16 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
|||||||
scheduler_config = vllm_config.scheduler_config
|
scheduler_config = vllm_config.scheduler_config
|
||||||
self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
|
self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cudagraph_support(
|
||||||
|
cls: type["AscendAttentionMetadataBuilder"],
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
kv_cache_spec: AttentionSpec,
|
||||||
|
) -> AttentionCGSupport:
|
||||||
|
# Explicit override in case the underlying builder specialized this getter.
|
||||||
|
# @override omitted only because of mypy limitation due to type variable.
|
||||||
|
return AttentionCGSupport.ALWAYS
|
||||||
|
|
||||||
def reorder_batch(self, input_batch,
|
def reorder_batch(self, input_batch,
|
||||||
scheduler_output: "SchedulerOutput") -> bool:
|
scheduler_output: "SchedulerOutput") -> bool:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import ClassVar, Optional, Tuple, TypeVar
|
from typing import Optional, Tuple, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -12,7 +12,7 @@ from vllm.distributed import (get_dcp_group,
|
|||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
||||||
|
|
||||||
# isort: off
|
# isort: off
|
||||||
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
|
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
|
||||||
@@ -37,9 +37,6 @@ M = TypeVar("M", bound=AscendMLAMetadata)
|
|||||||
|
|
||||||
|
|
||||||
class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
||||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
|
||||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
|
||||||
AttentionCGSupport.UNIFORM_BATCH
|
|
||||||
"""
|
"""
|
||||||
NOTE: Please read the comment at the top of the file before trying to
|
NOTE: Please read the comment at the top of the file before trying to
|
||||||
understand this class
|
understand this class
|
||||||
@@ -74,6 +71,16 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
|||||||
dtype=torch.uint8,
|
dtype=torch.uint8,
|
||||||
device=device)
|
device=device)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cudagraph_support(
|
||||||
|
cls: type["AscendMlaCPMetadataBuilder"],
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
kv_cache_spec: AttentionSpec,
|
||||||
|
) -> AttentionCGSupport:
|
||||||
|
# Explicit override in case the underlying builder specialized this getter.
|
||||||
|
# @override omitted only because of mypy limitation due to type variable.
|
||||||
|
return AttentionCGSupport.UNIFORM_BATCH
|
||||||
|
|
||||||
def set_num_actual_tokens(
|
def set_num_actual_tokens(
|
||||||
self,
|
self,
|
||||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
|
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
|
||||||
TypeVar)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -15,7 +14,7 @@ from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
|||||||
from vllm.utils.math_utils import cdiv, round_down
|
from vllm.utils.math_utils import cdiv, round_down
|
||||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
||||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
||||||
|
|
||||||
from vllm_ascend import envs
|
from vllm_ascend import envs
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
@@ -182,9 +181,6 @@ M = TypeVar("M", bound=AscendMLAMetadata)
|
|||||||
|
|
||||||
|
|
||||||
class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
|
||||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
|
||||||
AttentionCGSupport.UNIFORM_BATCH
|
|
||||||
"""
|
"""
|
||||||
NOTE: Please read the comment at the top of the file before trying to
|
NOTE: Please read the comment at the top of the file before trying to
|
||||||
understand this class
|
understand this class
|
||||||
@@ -263,6 +259,16 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
|||||||
self.query_lens: torch.Tensor = None
|
self.query_lens: torch.Tensor = None
|
||||||
self.seq_lens: torch.Tensor = None
|
self.seq_lens: torch.Tensor = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cudagraph_support(
|
||||||
|
cls: type["AscendMLAMetadataBuilder"],
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
kv_cache_spec: AttentionSpec,
|
||||||
|
) -> AttentionCGSupport:
|
||||||
|
# Explicit override in case the underlying builder specialized this getter.
|
||||||
|
# @override omitted only because of mypy limitation due to type variable.
|
||||||
|
return AttentionCGSupport.UNIFORM_BATCH
|
||||||
|
|
||||||
def reorder_batch(self, input_batch: "NPUInputBatch",
|
def reorder_batch(self, input_batch: "NPUInputBatch",
|
||||||
scheduler_output: "SchedulerOutput") -> bool:
|
scheduler_output: "SchedulerOutput") -> bool:
|
||||||
# We now want to reorder the batch so that the "decode" requests are at
|
# We now want to reorder the batch so that the "decode" requests are at
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
|
from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.linear import (ReplicatedLinear,
|
|||||||
from vllm.triton_utils import HAS_TRITON
|
from vllm.triton_utils import HAS_TRITON
|
||||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
||||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
from vllm_ascend import envs
|
from vllm_ascend import envs
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
@@ -113,9 +114,6 @@ M = TypeVar("M", bound=AscendSFAMetadata)
|
|||||||
|
|
||||||
|
|
||||||
class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
|
||||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
|
||||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
|
||||||
"""
|
"""
|
||||||
NOTE: Please read the comment at the top of the file before trying to
|
NOTE: Please read the comment at the top of the file before trying to
|
||||||
understand this class
|
understand this class
|
||||||
@@ -159,6 +157,16 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
== CUDAGraphMode.FULL_DECODE_ONLY
|
== CUDAGraphMode.FULL_DECODE_ONLY
|
||||||
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
|
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cudagraph_support(
|
||||||
|
cls: type["AscendSFAMetadataBuilder"],
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
kv_cache_spec: AttentionSpec,
|
||||||
|
) -> AttentionCGSupport:
|
||||||
|
# Explicit override in case the underlying builder specialized this getter.
|
||||||
|
# @override omitted only because of mypy limitation due to type variable.
|
||||||
|
return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
|
|
||||||
def reorder_batch(self, input_batch: "NPUInputBatch",
|
def reorder_batch(self, input_batch: "NPUInputBatch",
|
||||||
scheduler_output: "SchedulerOutput") -> bool:
|
scheduler_output: "SchedulerOutput") -> bool:
|
||||||
# No need to reorder for Ascend SFA
|
# No need to reorder for Ascend SFA
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ from vllm.platforms import Platform, PlatformEnum
|
|||||||
|
|
||||||
# todo: please remove it when solve cuda hard code in vllm
|
# todo: please remove it when solve cuda hard code in vllm
|
||||||
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
|
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
|
||||||
|
# todo: please remove it when support controls garbage collection during CUDA graph capture.
|
||||||
|
os.environ["VLLM_ENABLE_CUDAGRAPH_GC"] = "1"
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import init_ascend_config
|
from vllm_ascend.ascend_config import init_ascend_config
|
||||||
from vllm_ascend.utils import refresh_block_size
|
from vllm_ascend.utils import refresh_block_size
|
||||||
@@ -244,6 +246,12 @@ class NPUPlatform(Platform):
|
|||||||
data_parallel_size,
|
data_parallel_size,
|
||||||
)
|
)
|
||||||
compilation_config.use_inductor = False
|
compilation_config.use_inductor = False
|
||||||
|
# NOTE: Theoretically, we should also add vllm::mla_forward in the attention ops.
|
||||||
|
# Since the process is created in the spawn mode, the value of the class attribute
|
||||||
|
# attention ops transmitted is still the one before modification, so it has not been modified.
|
||||||
|
# This will cause in scenarios where both piecewise and splitting ops are configured simultaneously,
|
||||||
|
# If splitting ops does not contain the vllm::mla forward value, this configuration issue will
|
||||||
|
# not be detected in advance assert.
|
||||||
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
|
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
|
||||||
update_aclgraph_sizes(vllm_config)
|
update_aclgraph_sizes(vllm_config)
|
||||||
ascend_config.enable_npugraph_ex = False
|
ascend_config.enable_npugraph_ex = False
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import time
|
import sys
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
@@ -27,16 +27,12 @@ from multiprocessing import Manager
|
|||||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import regex as re
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from tqdm import tqdm # type: ignore
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
|
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
|
||||||
from vllm.attention.layer import Attention, MLAAttention
|
from vllm.attention.layer import Attention, MLAAttention
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
from vllm.compilation.counter import compilation_counter
|
|
||||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
|
||||||
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
|
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
|
||||||
get_layers_from_vllm_config)
|
get_layers_from_vllm_config)
|
||||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||||
@@ -46,8 +42,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
|||||||
has_kv_transfer_group)
|
has_kv_transfer_group)
|
||||||
from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group,
|
from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group,
|
||||||
get_pcp_group, get_pp_group,
|
get_pcp_group, get_pp_group,
|
||||||
get_tp_group,
|
get_tp_group)
|
||||||
is_global_first_rank)
|
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
@@ -58,8 +53,7 @@ from vllm.utils.import_utils import LazyLoader
|
|||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.utils.mem_utils import DeviceMemoryProfiler
|
from vllm.utils.mem_utils import DeviceMemoryProfiler
|
||||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
CommonAttentionMetadata)
|
|
||||||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||||
EncoderOnlyAttentionSpec,
|
EncoderOnlyAttentionSpec,
|
||||||
FullAttentionSpec, KVCacheConfig,
|
FullAttentionSpec, KVCacheConfig,
|
||||||
@@ -1972,13 +1966,18 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
self,
|
self,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
with_prefill: bool = False,
|
with_prefill: bool = False,
|
||||||
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
||||||
force_attention: bool = False,
|
force_attention: bool = False,
|
||||||
uniform_decode: bool = False,
|
uniform_decode: bool = False,
|
||||||
is_profile: bool = False,
|
is_profile: bool = False,
|
||||||
|
allow_microbatching: bool = True,
|
||||||
|
skip_eplb: bool = False,
|
||||||
|
remove_lora: bool = True,
|
||||||
|
activate_lora: bool = False,
|
||||||
|
is_graph_capturing: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# only support eager mode and piecewise graph now
|
# only support eager mode and piecewise graph now
|
||||||
assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
|
assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
|
||||||
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
||||||
}
|
}
|
||||||
# In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
|
# In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
|
||||||
@@ -2054,15 +2053,15 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded)
|
num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded)
|
||||||
|
|
||||||
# filter out the valid batch descriptor
|
# filter out the valid batch descriptor
|
||||||
if aclgraph_runtime_mode is not None:
|
if cudagraph_runtime_mode is not None:
|
||||||
# we allow forcing NONE when the dispatcher disagrees to support
|
# we allow forcing NONE when the dispatcher disagrees to support
|
||||||
# warm ups for aclgraph capture
|
# warm ups for aclgraph capture
|
||||||
if aclgraph_runtime_mode != CUDAGraphMode.NONE and aclgraph_runtime_mode != _ag_mode:
|
if cudagraph_runtime_mode != CUDAGraphMode.NONE and cudagraph_runtime_mode != _ag_mode:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Aclgraph runtime mode mismatch at dummy_run. "
|
f"Aclgraph runtime mode mismatch at dummy_run. "
|
||||||
f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.")
|
f"Expected {_ag_mode}, but got {cudagraph_runtime_mode}.")
|
||||||
else:
|
else:
|
||||||
aclgraph_runtime_mode = _ag_mode
|
cudagraph_runtime_mode = _ag_mode
|
||||||
|
|
||||||
# TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup
|
# TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup
|
||||||
# and not supported in ASCEND now. We could remove it in the future.
|
# and not supported in ASCEND now. We could remove it in the future.
|
||||||
@@ -2071,7 +2070,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
num_reqs=num_reqs_padded,
|
num_reqs=num_reqs_padded,
|
||||||
num_tokens=num_tokens_padded,
|
num_tokens=num_tokens_padded,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
aclgraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
force_attention=force_attention,
|
force_attention=force_attention,
|
||||||
num_scheduled_tokens=num_scheduled_tokens,
|
num_scheduled_tokens=num_scheduled_tokens,
|
||||||
)
|
)
|
||||||
@@ -2147,7 +2146,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
in_profile_run=is_profile,
|
in_profile_run=is_profile,
|
||||||
num_actual_tokens=0,
|
num_actual_tokens=0,
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
aclgraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
batch_descriptor=batch_descriptor,
|
batch_descriptor=batch_descriptor,
|
||||||
model_instance=self.model):
|
model_instance=self.model):
|
||||||
hidden_states = self._generate_dummy_run_hidden_states(
|
hidden_states = self._generate_dummy_run_hidden_states(
|
||||||
@@ -2161,7 +2160,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
with_prefill=with_prefill,
|
with_prefill=with_prefill,
|
||||||
num_reqs=num_reqs_padded,
|
num_reqs=num_reqs_padded,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
aclgraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
batch_descriptor=batch_descriptor,
|
batch_descriptor=batch_descriptor,
|
||||||
dummy_compute_logits=dummy_drafter_compute_logits,
|
dummy_compute_logits=dummy_drafter_compute_logits,
|
||||||
in_graph_capturing=not force_attention,
|
in_graph_capturing=not force_attention,
|
||||||
@@ -2677,7 +2676,8 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
|
|
||||||
def get_attn_backends_for_group(
|
def get_attn_backends_for_group(
|
||||||
kv_cache_group_spec: KVCacheGroupSpec,
|
kv_cache_group_spec: KVCacheGroupSpec,
|
||||||
) -> dict[AttentionGroupKey, list[str]]:
|
) -> tuple[dict[AttentionGroupKey, list[str]],
|
||||||
|
set[type[AttentionBackend]]]:
|
||||||
layers = get_layers_from_vllm_config(
|
layers = get_layers_from_vllm_config(
|
||||||
self.vllm_config, AttentionLayerBase,
|
self.vllm_config, AttentionLayerBase,
|
||||||
kv_cache_group_spec.layer_names)
|
kv_cache_group_spec.layer_names)
|
||||||
@@ -2699,10 +2699,14 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
attn_backends[key] = AttentionGroupKey(attn_backend,
|
attn_backends[key] = AttentionGroupKey(attn_backend,
|
||||||
layer_kv_cache_spec)
|
layer_kv_cache_spec)
|
||||||
attn_backend_layers[key].append(layer_name)
|
attn_backend_layers[key].append(layer_name)
|
||||||
return {
|
return (
|
||||||
attn_backends[k]: v
|
{
|
||||||
for k, v in attn_backend_layers.items()
|
attn_backends[k]: v
|
||||||
}
|
for k, v in attn_backend_layers.items()
|
||||||
|
},
|
||||||
|
set(group_key.attn_backend
|
||||||
|
for group_key in attn_backends.values()),
|
||||||
|
)
|
||||||
|
|
||||||
def create_attn_groups(attn_backends_map: dict[AttentionBackend,
|
def create_attn_groups(attn_backends_map: dict[AttentionBackend,
|
||||||
list[str]],
|
list[str]],
|
||||||
@@ -2723,11 +2727,21 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
attn_groups.append(attn_group)
|
attn_groups.append(attn_group)
|
||||||
return attn_groups
|
return attn_groups
|
||||||
|
|
||||||
|
attention_backend_maps = []
|
||||||
|
attention_backend_list = []
|
||||||
|
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||||
|
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
|
||||||
|
attention_backend_maps.append(attn_backends[0])
|
||||||
|
attention_backend_list.append(attn_backends[1])
|
||||||
|
|
||||||
|
self._check_and_update_cudagraph_mode(attention_backend_list,
|
||||||
|
kv_cache_config.kv_cache_groups)
|
||||||
|
|
||||||
for i, kv_cache_group_spec in enumerate(
|
for i, kv_cache_group_spec in enumerate(
|
||||||
kv_cache_config.kv_cache_groups):
|
kv_cache_config.kv_cache_groups):
|
||||||
attn_backends = get_attn_backends_for_group( # type: ignore
|
attn_backends = get_attn_backends_for_group( # type: ignore
|
||||||
kv_cache_group_spec)
|
kv_cache_group_spec)
|
||||||
self.attn_groups.append(create_attn_groups(attn_backends, i))
|
self.attn_groups.append(create_attn_groups(attn_backends[0], i))
|
||||||
|
|
||||||
# Calculate reorder batch threshold (if needed)
|
# Calculate reorder batch threshold (if needed)
|
||||||
self.calculate_reorder_batch_threshold()
|
self.calculate_reorder_batch_threshold()
|
||||||
@@ -2855,214 +2869,26 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
|
|
||||||
return kv_cache_spec
|
return kv_cache_spec
|
||||||
|
|
||||||
def initialize_aclgraph_capture(self) -> None:
|
def _check_and_update_cudagraph_mode(
|
||||||
min_ag_support = AttentionCGSupport.ALWAYS
|
self,
|
||||||
min_ag_builder_name = None
|
attention_backends: list[set[type[AttentionBackend]]],
|
||||||
|
kv_cache_groups: list[KVCacheGroupSpec],
|
||||||
for attn_group in self._attn_group_iterator():
|
) -> None:
|
||||||
builder = attn_group.get_metadata_builder()
|
super()._check_and_update_cudagraph_mode(attention_backends,
|
||||||
graph_support = None
|
kv_cache_groups)
|
||||||
if hasattr(builder, 'aclgraph_support'):
|
|
||||||
graph_support = builder.aclgraph_support.value
|
|
||||||
builder_aclgraph = builder.aclgraph_support
|
|
||||||
else:
|
|
||||||
graph_support = builder._cudagraph_support.value
|
|
||||||
builder_aclgraph = builder._cudagraph_support
|
|
||||||
if graph_support < min_ag_support.value:
|
|
||||||
min_ag_support = builder_aclgraph
|
|
||||||
min_ag_builder_name = builder.__class__.__name__
|
|
||||||
|
|
||||||
# This is an imitation of compilation_config.splitting_ops_contain_attention()
|
|
||||||
splitting_ops_contain_attention = (
|
|
||||||
self.compilation_config.splitting_ops is not None
|
|
||||||
and all(op in self.compilation_config.splitting_ops for op in [
|
|
||||||
"vllm.mla_forward",
|
|
||||||
]))
|
|
||||||
|
|
||||||
# Flexible resolve the aclgraph mode
|
|
||||||
aclgraph_mode = self.compilation_config.cudagraph_mode
|
|
||||||
# check graph for mixed batch is supported
|
|
||||||
if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \
|
|
||||||
and min_ag_support != AttentionCGSupport.ALWAYS:
|
|
||||||
msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported "
|
|
||||||
f"with {min_ag_builder_name} backend (support: "
|
|
||||||
f"{min_ag_support})")
|
|
||||||
if min_ag_support == AttentionCGSupport.NEVER:
|
|
||||||
# if not supported any full graphs, just raise it.
|
|
||||||
msg += "; please try cudagraph_mode=PIECEWISE, and "\
|
|
||||||
"make sure compilation level is piecewise"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
# attempt to resolve the full graph related mode
|
|
||||||
if splitting_ops_contain_attention:
|
|
||||||
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
|
|
||||||
aclgraph_mode = self.compilation_config.cudagraph_mode = (
|
|
||||||
CUDAGraphMode.FULL_AND_PIECEWISE)
|
|
||||||
else:
|
|
||||||
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
|
|
||||||
aclgraph_mode = self.compilation_config.cudagraph_mode = (
|
|
||||||
CUDAGraphMode.FULL_DECODE_ONLY)
|
|
||||||
logger.warning(msg)
|
|
||||||
|
|
||||||
# double check that we can support full graph if they are requested
|
|
||||||
# even after automatic downgrades
|
|
||||||
if aclgraph_mode.has_full_cudagraphs() \
|
|
||||||
and min_ag_support == AttentionCGSupport.NEVER:
|
|
||||||
raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not "
|
|
||||||
f"supported with {min_ag_builder_name} backend ("
|
|
||||||
f"support:{min_ag_support}) "
|
|
||||||
"; please try cudagraph_mode=PIECEWISE, "
|
|
||||||
"and make sure compilation level is piecewise")
|
|
||||||
|
|
||||||
if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
|
|
||||||
and aclgraph_mode.separate_routine()
|
|
||||||
and self.uniform_decode_query_len > 1):
|
|
||||||
self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
|
|
||||||
self.uniform_decode_query_len,
|
|
||||||
self.parallel_config.tensor_parallel_size)
|
|
||||||
capture_sizes = self.compilation_config.cudagraph_capture_sizes
|
|
||||||
self.cudagraph_batch_sizes = (capture_sizes
|
|
||||||
if capture_sizes is not None else [])
|
|
||||||
|
|
||||||
# NOTE: Since aclgraph_batch_sizes cannot be determined until here,
|
# NOTE: Since aclgraph_batch_sizes cannot be determined until here,
|
||||||
# we set the graph params right before initializing the keys.
|
# we set the graph params right before initializing the keys.
|
||||||
set_graph_params(self.cudagraph_batch_sizes)
|
if self.use_aclgraph:
|
||||||
if self.speculative_config:
|
set_graph_params(self.cudagraph_batch_sizes)
|
||||||
set_draft_graph_params(self.cudagraph_batch_sizes)
|
if self.speculative_config:
|
||||||
|
set_draft_graph_params(self.cudagraph_batch_sizes)
|
||||||
self.cudagraph_dispatcher.initialize_cudagraph_keys(
|
|
||||||
self.compilation_config.cudagraph_mode,
|
|
||||||
self.uniform_decode_query_len)
|
|
||||||
|
|
||||||
def _capture_aclgraphs(self, compilation_cases: list[int],
|
|
||||||
aclgraph_runtime_mode: CUDAGraphMode,
|
|
||||||
uniform_decode: bool):
|
|
||||||
assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
|
|
||||||
aclgraph_runtime_mode in [CUDAGraphMode.FULL,
|
|
||||||
CUDAGraphMode.PIECEWISE]
|
|
||||||
|
|
||||||
# Only rank 0 should print progress bar during capture
|
|
||||||
if is_global_first_rank():
|
|
||||||
logger.info(
|
|
||||||
"Starting to capture ACL graphs for cases: %s, "
|
|
||||||
"mode: %s, uniform_decode: %s", compilation_cases,
|
|
||||||
aclgraph_runtime_mode.name, uniform_decode)
|
|
||||||
compilation_cases = tqdm(
|
|
||||||
compilation_cases,
|
|
||||||
disable=not self.load_config.use_tqdm_on_load,
|
|
||||||
desc="Capturing ACL graphs ({}, {})".format(
|
|
||||||
"decode" if uniform_decode else "mixed prefill-decode",
|
|
||||||
aclgraph_runtime_mode.name))
|
|
||||||
|
|
||||||
force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
|
|
||||||
# When the kv cache spec is empty, PiecewiseBackend is not initialized, and
|
|
||||||
# compilation_case=1 will cause the dynamic shape position to be incorrectly derived.
|
|
||||||
if not self.get_kv_cache_spec():
|
|
||||||
self._dummy_run(2,
|
|
||||||
aclgraph_runtime_mode=CUDAGraphMode.NONE,
|
|
||||||
force_attention=force_attention,
|
|
||||||
uniform_decode=uniform_decode)
|
|
||||||
# We skip EPLB here since we don't want to record dummy metrics
|
|
||||||
for num_tokens in compilation_cases:
|
|
||||||
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
|
|
||||||
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
|
|
||||||
# But be careful, warm up with `NONE`is orthogonal to
|
|
||||||
# if we want to warm up attention or not. This is
|
|
||||||
# different from the case where `FULL` implies capture
|
|
||||||
# attention while `PIECEWISE` implies no attention.
|
|
||||||
self._dummy_run(num_tokens,
|
|
||||||
aclgraph_runtime_mode=CUDAGraphMode.NONE,
|
|
||||||
force_attention=force_attention,
|
|
||||||
uniform_decode=uniform_decode)
|
|
||||||
self._dummy_run(num_tokens,
|
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
|
||||||
force_attention=force_attention,
|
|
||||||
uniform_decode=uniform_decode)
|
|
||||||
|
|
||||||
def _capture_model(self):
|
|
||||||
if not self.use_aclgraph:
|
|
||||||
logger.warning(
|
|
||||||
"Skipping ACL graph capture. To turn on ACL graph capture, "
|
|
||||||
"ensure `aclraph_mode` was not manually set to `NONE`")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self.initialize_aclgraph_capture()
|
|
||||||
|
|
||||||
set_cudagraph_capturing_enabled(True)
|
|
||||||
# Trigger ACL graph capture for specific shapes.
|
|
||||||
# Capture the large shapes first so that the smaller shapes
|
|
||||||
# can reuse the memory pool allocated for the large shapes.
|
|
||||||
with graph_capture(device=self.device):
|
|
||||||
aclgraph_mode = self.compilation_config.cudagraph_mode
|
|
||||||
if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
|
||||||
aclgraph_runtime_mode = aclgraph_mode.mixed_mode()
|
|
||||||
|
|
||||||
# make sure we capture the largest batch size first
|
|
||||||
compilation_cases = list(reversed(self.cudagraph_batch_sizes))
|
|
||||||
|
|
||||||
try:
|
|
||||||
self._capture_aclgraphs(
|
|
||||||
compilation_cases,
|
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
|
||||||
uniform_decode=False)
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = str(e)
|
|
||||||
error_code = '0x7020023'
|
|
||||||
pattern = r'retCode=([^,\s\.]+)'
|
|
||||||
match = re.search(pattern, error_msg)
|
|
||||||
if match:
|
|
||||||
retCode = match.group(1)
|
|
||||||
# Determine whether the error message is caused by stream capture failure.
|
|
||||||
if match and retCode == error_code:
|
|
||||||
logger.error(
|
|
||||||
f"ACLgraph sizes capture fail: {type(e).__name__}:\n"
|
|
||||||
"ACLgraph has insufficient available streams to capture the configured number of sizes. "
|
|
||||||
"Please verify both the availability of adequate streams and the appropriateness of the configured size count.\n\n"
|
|
||||||
"Recommended solutions:\n"
|
|
||||||
"1. Manually configure the compilation_config parameter "
|
|
||||||
"with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n"
|
|
||||||
"2. Utilize ACLgraph's full graph mode as an alternative to the piece-wise approach.\n\n"
|
|
||||||
f"{str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
|
|
||||||
aclgraph_mode.separate_routine():
|
|
||||||
max_num_tokens = self.scheduler_config.max_num_seqs * \
|
|
||||||
self.uniform_decode_query_len
|
|
||||||
decode_cudagraph_batch_sizes = [
|
|
||||||
x for x in self.cudagraph_batch_sizes if
|
|
||||||
x <= max_num_tokens and x >= self.uniform_decode_query_len
|
|
||||||
]
|
|
||||||
compilation_cases_decode = list(
|
|
||||||
reversed(decode_cudagraph_batch_sizes))
|
|
||||||
self._capture_aclgraphs(
|
|
||||||
compilation_cases=compilation_cases_decode,
|
|
||||||
aclgraph_runtime_mode=CUDAGraphMode.FULL,
|
|
||||||
uniform_decode=True)
|
|
||||||
|
|
||||||
# Disable aclgraph capturing globally, so any unexpected aclgraph
|
|
||||||
# capturing will be detected and raise an error after here.
|
|
||||||
# Note: We don't put it into graph_capture context manager because
|
|
||||||
# we may doing lazy capturing in future that still allows capturing
|
|
||||||
# after here.
|
|
||||||
set_cudagraph_capturing_enabled(False)
|
|
||||||
|
|
||||||
def capture_model(self) -> None:
|
def capture_model(self) -> None:
|
||||||
|
parent_module_name = self.__class__.__base__.__module__
|
||||||
compilation_counter.num_gpu_runner_capture_triggers += 1
|
with _torch_cuda_wrapper(), _replace_gpu_model_runner_function_wrapper(
|
||||||
|
parent_module_name):
|
||||||
start_time = time.perf_counter()
|
super().capture_model()
|
||||||
start_free_npu_memory = torch.npu.mem_get_info()[0]
|
|
||||||
|
|
||||||
self._capture_model()
|
|
||||||
|
|
||||||
end_time = time.perf_counter()
|
|
||||||
end_free_npu_memory = torch.npu.mem_get_info()[0]
|
|
||||||
elapsed_time = end_time - start_time
|
|
||||||
npu_graph_size = start_free_npu_memory - end_free_npu_memory
|
|
||||||
# This usually takes 5~20 seconds.
|
|
||||||
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
|
|
||||||
elapsed_time, npu_graph_size / (1 << 30))
|
|
||||||
|
|
||||||
def _update_tokens_for_pcp(self, tokens):
|
def _update_tokens_for_pcp(self, tokens):
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
@@ -3473,6 +3299,8 @@ def _torch_cuda_wrapper():
|
|||||||
torch.cuda.default_stream = torch.npu.default_stream
|
torch.cuda.default_stream = torch.npu.default_stream
|
||||||
torch.cuda.current_stream = torch.npu.current_stream
|
torch.cuda.current_stream = torch.npu.current_stream
|
||||||
torch.cuda.stream = torch.npu.stream
|
torch.cuda.stream = torch.npu.stream
|
||||||
|
torch.cuda.synchronize = torch.npu.synchronize
|
||||||
|
torch.cuda.mem_get_info = torch.npu.mem_get_info
|
||||||
yield
|
yield
|
||||||
except Exception:
|
except Exception:
|
||||||
torch.cuda.Event = _EventPlaceholder
|
torch.cuda.Event = _EventPlaceholder
|
||||||
@@ -3480,6 +3308,8 @@ def _torch_cuda_wrapper():
|
|||||||
torch.cuda.default_stream = _StreamPlaceholder
|
torch.cuda.default_stream = _StreamPlaceholder
|
||||||
torch.cuda.current_stream = _StreamPlaceholder
|
torch.cuda.current_stream = _StreamPlaceholder
|
||||||
torch.cuda.stream = _StreamPlaceholder
|
torch.cuda.stream = _StreamPlaceholder
|
||||||
|
torch.cuda.synchronize = _StreamPlaceholder
|
||||||
|
torch.cuda.mem_get_info = _StreamPlaceholder
|
||||||
finally:
|
finally:
|
||||||
# if anything goes wrong, just patch it with a placeholder
|
# if anything goes wrong, just patch it with a placeholder
|
||||||
torch.cuda.Event = _EventPlaceholder
|
torch.cuda.Event = _EventPlaceholder
|
||||||
@@ -3487,3 +3317,16 @@ def _torch_cuda_wrapper():
|
|||||||
torch.cuda.default_stream = torch.npu.default_stream
|
torch.cuda.default_stream = torch.npu.default_stream
|
||||||
torch.cuda.current_stream = torch.npu.current_stream
|
torch.cuda.current_stream = torch.npu.current_stream
|
||||||
torch.cuda.stream = torch.npu.stream
|
torch.cuda.stream = torch.npu.stream
|
||||||
|
torch.cuda.synchronize = torch.npu.synchronize
|
||||||
|
torch.cuda.mem_get_info = torch.npu.mem_get_info
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: This method will be removed subsequently and implemented in platform.
|
||||||
|
@contextmanager
|
||||||
|
def _replace_gpu_model_runner_function_wrapper(target_module_name):
|
||||||
|
try:
|
||||||
|
target_module = sys.modules[target_module_name]
|
||||||
|
setattr(target_module, "graph_capture", graph_capture)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
setattr(target_module, "graph_capture", graph_capture)
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
|||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||||
DraftTokenIds, ModelRunnerOutput)
|
DraftTokenIds, ModelRunnerOutput)
|
||||||
from vllm.v1.worker.worker_base import WorkerBase
|
from vllm.v1.worker.worker_base import WorkerBase
|
||||||
|
from vllm.v1.worker.workspace import init_workspace_manager
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
||||||
@@ -231,6 +232,9 @@ class NPUWorker(WorkerBase):
|
|||||||
# in ray scenario. see https://github.com/vllm-project/vllm/pull/26845
|
# in ray scenario. see https://github.com/vllm-project/vllm/pull/26845
|
||||||
# for more details
|
# for more details
|
||||||
self.device = self._init_device()
|
self.device = self._init_device()
|
||||||
|
# Initialize workspace manager
|
||||||
|
num_ubatches = 1
|
||||||
|
init_workspace_manager(self.device, num_ubatches)
|
||||||
# Init ModelRunner here, so that we have access to self.device.
|
# Init ModelRunner here, so that we have access to self.device.
|
||||||
if self.use_v2_model_runner:
|
if self.use_v2_model_runner:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
Reference in New Issue
Block a user