[Feat] Make full graph mode compatible with MTP (#3276)
### What this PR does / why we need it? Enable Full Graph mode to run together with MTP (multi-token prediction). ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
from vllm.config import CompilationConfig, CUDAGraphMode
|
||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
from tests.e2e.conftest import VllmRunner
|
||||||
|
|
||||||
@@ -20,6 +21,7 @@ def mtp_correctness(
|
|||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
num_speculative_tokens: int,
|
num_speculative_tokens: int,
|
||||||
|
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
|
||||||
):
|
):
|
||||||
example_prompts = [
|
example_prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
@@ -38,6 +40,10 @@ def mtp_correctness(
|
|||||||
enforce_eager=False) as ref_llm:
|
enforce_eager=False) as ref_llm:
|
||||||
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
||||||
|
|
||||||
|
graph_mode_str = "PIECEWISE"
|
||||||
|
if graph_mode == CUDAGraphMode.FULL:
|
||||||
|
graph_mode_str = "FULL"
|
||||||
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
model_name,
|
model_name,
|
||||||
tensor_parallel_size=1,
|
tensor_parallel_size=1,
|
||||||
@@ -51,6 +57,8 @@ def mtp_correctness(
|
|||||||
},
|
},
|
||||||
enforce_eager=False,
|
enforce_eager=False,
|
||||||
max_model_len=2000,
|
max_model_len=2000,
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
|
cudagraph_mode=graph_mode_str),
|
||||||
additional_config={"ascend_scheduler_config": {
|
additional_config={"ascend_scheduler_config": {
|
||||||
"enabled": False
|
"enabled": False
|
||||||
}}) as spec_llm:
|
}}) as spec_llm:
|
||||||
@@ -74,15 +82,29 @@ def mtp_correctness(
|
|||||||
del spec_llm
|
del spec_llm
|
||||||
|
|
||||||
|
|
||||||
def test_mtp1_correctness(
|
def test_mtp1_correctness_piecewise_graph(
|
||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
mtp_correctness(sampling_config, model_name, 1)
|
mtp_correctness(sampling_config, model_name, 1)
|
||||||
|
|
||||||
|
|
||||||
def test_mtp2_correctness(
|
def test_mtp2_correctness_piecewise_graph(
|
||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
mtp_correctness(sampling_config, model_name, 2)
|
mtp_correctness(sampling_config, model_name, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mtp1_correctness_full_graph(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mtp2_correctness_full_graph(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
from vllm.config import CompilationConfig, CUDAGraphMode
|
||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
from tests.e2e.conftest import VllmRunner
|
||||||
|
|
||||||
@@ -16,9 +17,10 @@ def model_name():
|
|||||||
return "wemaster/deepseek_mtp_main_random_bf16"
|
return "wemaster/deepseek_mtp_main_random_bf16"
|
||||||
|
|
||||||
|
|
||||||
def test_mtp_torchair_correctness(
|
def mtp_torchair_correctness(
|
||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
|
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
|
||||||
):
|
):
|
||||||
example_prompts = [
|
example_prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
@@ -44,6 +46,11 @@ def test_mtp_torchair_correctness(
|
|||||||
"multistream_overlap_shared_expert": "True"
|
"multistream_overlap_shared_expert": "True"
|
||||||
}) as ref_llm:
|
}) as ref_llm:
|
||||||
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
||||||
|
|
||||||
|
graph_mode_str = "PIECEWISE"
|
||||||
|
if graph_mode == CUDAGraphMode.FULL:
|
||||||
|
graph_mode_str = "FULL"
|
||||||
|
|
||||||
with VllmRunner(model_name,
|
with VllmRunner(model_name,
|
||||||
tensor_parallel_size=1,
|
tensor_parallel_size=1,
|
||||||
max_num_seqs=256,
|
max_num_seqs=256,
|
||||||
@@ -56,6 +63,8 @@ def test_mtp_torchair_correctness(
|
|||||||
},
|
},
|
||||||
enforce_eager=False,
|
enforce_eager=False,
|
||||||
max_model_len=2000,
|
max_model_len=2000,
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
|
cudagraph_mode=graph_mode_str),
|
||||||
additional_config={
|
additional_config={
|
||||||
"torchair_graph_config": {
|
"torchair_graph_config": {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
@@ -81,3 +90,17 @@ def test_mtp_torchair_correctness(
|
|||||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||||
assert matches > int(0.66 * len(ref_outputs))
|
assert matches > int(0.66 * len(ref_outputs))
|
||||||
|
|
||||||
|
|
||||||
|
def test_mtp_torchair_correctness_piecewise(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_torchair_correctness(sampling_config, model_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mtp_torchair_correctness_full(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_torchair_correctness(sampling_config, model_name, CUDAGraphMode.FULL)
|
||||||
|
|||||||
@@ -448,6 +448,7 @@ class TestNPUWorker(TestBase):
|
|||||||
worker.compilation_config = MagicMock()
|
worker.compilation_config = MagicMock()
|
||||||
worker.compilation_config.cudagraph_mode = MagicMock()
|
worker.compilation_config.cudagraph_mode = MagicMock()
|
||||||
mock_model_runner = MagicMock()
|
mock_model_runner = MagicMock()
|
||||||
|
mock_decode_token_per_req = mock_model_runner.decode_token_per_req
|
||||||
worker.model_runner = mock_model_runner
|
worker.model_runner = mock_model_runner
|
||||||
|
|
||||||
# Test execute_dummy_batch
|
# Test execute_dummy_batch
|
||||||
@@ -455,7 +456,9 @@ class TestNPUWorker(TestBase):
|
|||||||
|
|
||||||
# Verify call
|
# Verify call
|
||||||
mock_model_runner._dummy_run.assert_called_once_with(
|
mock_model_runner._dummy_run.assert_called_once_with(
|
||||||
num_tokens=1, uniform_decode=True, force_attention=False)
|
num_tokens=mock_decode_token_per_req,
|
||||||
|
uniform_decode=True,
|
||||||
|
force_attention=False)
|
||||||
|
|
||||||
@patch("vllm_ascend.worker.worker_v1.envs_vllm")
|
@patch("vllm_ascend.worker.worker_v1.envs_vllm")
|
||||||
@patch("vllm_ascend.worker.worker_v1.logger")
|
@patch("vllm_ascend.worker.worker_v1.logger")
|
||||||
|
|||||||
@@ -175,7 +175,7 @@ M = TypeVar("M", bound=AscendMLAMetadata)
|
|||||||
class AscendMLAMetadataBuilder:
|
class AscendMLAMetadataBuilder:
|
||||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
AttentionCGSupport.UNIFORM_BATCH
|
||||||
"""
|
"""
|
||||||
NOTE: Please read the comment at the top of the file before trying to
|
NOTE: Please read the comment at the top of the file before trying to
|
||||||
understand this class
|
understand this class
|
||||||
@@ -209,7 +209,6 @@ class AscendMLAMetadataBuilder:
|
|||||||
got {self.decode_threshold}"
|
got {self.decode_threshold}"
|
||||||
|
|
||||||
self.reorder_batch_threshold = self.decode_threshold
|
self.reorder_batch_threshold = self.decode_threshold
|
||||||
|
|
||||||
if self.chunked_prefill_enabled:
|
if self.chunked_prefill_enabled:
|
||||||
self.chunked_prefill_workspace_size = min(
|
self.chunked_prefill_workspace_size = min(
|
||||||
# Make sure there is enough for 8 full length request or at least
|
# Make sure there is enough for 8 full length request or at least
|
||||||
@@ -427,10 +426,10 @@ class AscendMLAMetadataBuilder:
|
|||||||
sin=sin,
|
sin=sin,
|
||||||
cos=cos)
|
cos=cos)
|
||||||
else:
|
else:
|
||||||
cos[:num_decodes,
|
cos[:num_decode_tokens,
|
||||||
...] = self.cos_cache[input_positions].unsqueeze(
|
...] = self.cos_cache[input_positions].unsqueeze(
|
||||||
1).unsqueeze(2)
|
1).unsqueeze(2)
|
||||||
sin[:num_decodes,
|
sin[:num_decode_tokens,
|
||||||
...] = self.sin_cache[input_positions].unsqueeze(
|
...] = self.sin_cache[input_positions].unsqueeze(
|
||||||
1).unsqueeze(2)
|
1).unsqueeze(2)
|
||||||
|
|
||||||
@@ -442,8 +441,8 @@ class AscendMLAMetadataBuilder:
|
|||||||
max_seq_lens=max_seq_lens,
|
max_seq_lens=max_seq_lens,
|
||||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||||
sin=sin[:num_decodes, ...],
|
sin=sin[:num_decode_tokens, ...],
|
||||||
cos=cos[:num_decodes, ...])
|
cos=cos[:num_decode_tokens, ...])
|
||||||
|
|
||||||
return self.metadata_cls( # type: ignore
|
return self.metadata_cls( # type: ignore
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_actual_tokens,
|
||||||
@@ -469,7 +468,10 @@ class AscendMLAMetadataBuilder:
|
|||||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||||
model: Optional[nn.Module] = None,
|
model: Optional[nn.Module] = None,
|
||||||
):
|
):
|
||||||
if attn_state == AscendAttentionState.DecodeOnly:
|
if attn_state in {
|
||||||
|
AscendAttentionState.DecodeOnly,
|
||||||
|
AscendAttentionState.SpecDecoding
|
||||||
|
}:
|
||||||
attn_metadata = self.build(
|
attn_metadata = self.build(
|
||||||
common_prefix_len=0,
|
common_prefix_len=0,
|
||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata,
|
||||||
@@ -477,7 +479,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Currently we only support building dummy metadata for DecodeOnly state"
|
"Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state"
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_metadata.attn_state = attn_state
|
attn_metadata.attn_state = attn_state
|
||||||
@@ -955,7 +957,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
if attn_metadata.attn_state in [
|
if attn_metadata.attn_state in [
|
||||||
AscendAttentionState.SpecDecoding,
|
AscendAttentionState.SpecDecoding,
|
||||||
AscendAttentionState.ChunkedPrefill
|
AscendAttentionState.ChunkedPrefill,
|
||||||
|
AscendAttentionState.DecodeOnly,
|
||||||
] and self.speculative_config is not None:
|
] and self.speculative_config is not None:
|
||||||
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
|
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
|
||||||
input_layout = "TND"
|
input_layout = "TND"
|
||||||
|
|||||||
@@ -245,7 +245,8 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
|||||||
event.record(update_stream)
|
event.record(update_stream)
|
||||||
|
|
||||||
|
|
||||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape):
|
def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||||
|
speculative_config):
|
||||||
graph_params = get_graph_params()
|
graph_params = get_graph_params()
|
||||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||||
# for each layer's attention op in the graph.
|
# for each layer's attention op in the graph.
|
||||||
@@ -260,9 +261,19 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape):
|
|||||||
seq_lens_list, actual_seq_lengths, workspace, attn_output,
|
seq_lens_list, actual_seq_lengths, workspace, attn_output,
|
||||||
softmax_lse) = param
|
softmax_lse) = param
|
||||||
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
|
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
|
||||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
|
if speculative_config and speculative_config.method == "deepseek_mtp":
|
||||||
len(seq_lens_list))
|
actual_seq_lengths = forward_context.attn_metadata[
|
||||||
|
key].decode.actual_seq_lengths_q
|
||||||
|
spec_multiple = speculative_config.num_speculative_tokens + 1
|
||||||
|
seq_lens_list = seq_lens_list + [0] * (
|
||||||
|
runtime_shape // spec_multiple - len(seq_lens_list))
|
||||||
|
actual_seq_lengths = [
|
||||||
|
spec_multiple * (i + 1)
|
||||||
|
for i in range(runtime_shape // spec_multiple)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
|
||||||
|
len(seq_lens_list))
|
||||||
with torch.npu.stream(update_stream):
|
with torch.npu.stream(update_stream):
|
||||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||||
|
|
||||||
|
|||||||
@@ -345,6 +345,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.speculative_config.method, self.vllm_config,
|
self.speculative_config.method, self.vllm_config,
|
||||||
self.device, self)
|
self.device, self)
|
||||||
self.rejection_sampler = AscendRejectionSampler()
|
self.rejection_sampler = AscendRejectionSampler()
|
||||||
|
self.actual_seq_lengths_q = list(
|
||||||
|
range(self.decode_token_per_req, self.max_num_tokens + 1,
|
||||||
|
self.decode_token_per_req))
|
||||||
|
|
||||||
# Persistent batch.
|
# Persistent batch.
|
||||||
self.input_ids = torch.zeros(self.max_num_tokens,
|
self.input_ids = torch.zeros(self.max_num_tokens,
|
||||||
@@ -366,13 +369,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.vllm_config.model_config.use_mla and \
|
if self.vllm_config.model_config.use_mla and \
|
||||||
self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
||||||
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||||
self.cos = torch.ones(self.max_num_reqs,
|
self.cos = torch.ones(self.max_num_reqs *
|
||||||
|
self.decode_token_per_req,
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
rope_dim,
|
rope_dim,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
self.sin = torch.zeros(self.max_num_reqs,
|
self.sin = torch.zeros(self.max_num_reqs *
|
||||||
|
self.decode_token_per_req,
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
rope_dim,
|
rope_dim,
|
||||||
@@ -1554,7 +1559,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla:
|
||||||
# FIXME: Try using `auto_dispatch_capture=True`
|
# FIXME: Try using `auto_dispatch_capture=True`
|
||||||
update_mla_attn_params(self.update_stream, forward_context,
|
update_mla_attn_params(self.update_stream, forward_context,
|
||||||
maybe_padded_num_tokens)
|
maybe_padded_num_tokens,
|
||||||
|
self.speculative_config)
|
||||||
else:
|
else:
|
||||||
update_attn_params(self.update_stream, forward_context,
|
update_attn_params(self.update_stream, forward_context,
|
||||||
maybe_padded_num_tokens)
|
maybe_padded_num_tokens)
|
||||||
@@ -2255,7 +2261,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
block_table_tensor = self.input_batch.block_table[
|
block_table_tensor = self.input_batch.block_table[
|
||||||
kv_cache_group_id].get_device_tensor()
|
kv_cache_group_id].get_device_tensor()
|
||||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||||
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
query_start_loc=torch.tensor(
|
||||||
|
[0] + self.actual_seq_lengths_q[:num_reqs],
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.int32),
|
||||||
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
|
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
|
||||||
1],
|
1],
|
||||||
seq_lens_cpu=self.seq_lens_cpu,
|
seq_lens_cpu=self.seq_lens_cpu,
|
||||||
@@ -2275,12 +2284,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
cos=self.cos,
|
cos=self.cos,
|
||||||
sin=self.sin,
|
sin=self.sin,
|
||||||
)
|
)
|
||||||
|
attn_state = AscendAttentionState.DecodeOnly
|
||||||
|
if self.speculative_config and \
|
||||||
|
self.speculative_config.method == "deepseek_mtp":
|
||||||
|
attn_state = AscendAttentionState.SpecDecoding
|
||||||
|
|
||||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||||
builder = attn_group.get_metadata_builder()
|
builder = attn_group.get_metadata_builder()
|
||||||
attn_metadata_i = builder.build_for_graph_capture(
|
attn_metadata_i = builder.build_for_graph_capture(
|
||||||
common_attn_metadata, AscendAttentionState.DecodeOnly,
|
common_attn_metadata, attn_state, self.get_model())
|
||||||
self.get_model())
|
|
||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
attn_metadata[layer_name] = attn_metadata_i
|
attn_metadata[layer_name] = attn_metadata_i
|
||||||
|
|
||||||
@@ -2301,7 +2313,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla:
|
||||||
# FIXME: Try using `auto_dispatch_capture=True`
|
# FIXME: Try using `auto_dispatch_capture=True`
|
||||||
update_mla_attn_params(self.update_stream, forward_context,
|
update_mla_attn_params(self.update_stream, forward_context,
|
||||||
positions.shape[0])
|
positions.shape[0],
|
||||||
|
self.speculative_config)
|
||||||
else:
|
else:
|
||||||
update_attn_params(self.update_stream, forward_context,
|
update_attn_params(self.update_stream, forward_context,
|
||||||
positions.shape[0])
|
positions.shape[0])
|
||||||
@@ -3388,23 +3401,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
CUDAGraphMode.FULL_DECODE_ONLY)
|
CUDAGraphMode.FULL_DECODE_ONLY)
|
||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
|
|
||||||
# check that if spec-decode + decode full-graphs is supported
|
|
||||||
if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
|
|
||||||
and self.uniform_decode_query_len > 1 and min_ag_support.value
|
|
||||||
< AttentionCGSupport.UNIFORM_BATCH.value):
|
|
||||||
msg = (f"CUDAGraphMode.{aclgraph_mode.name} is not supported"
|
|
||||||
f" with spec-decode for attention backend "
|
|
||||||
f"{min_ag_builder_name} (support: {min_ag_support})")
|
|
||||||
if splitting_ops_contain_attention:
|
|
||||||
msg += "; setting cudagraph_mode=PIECEWISE"
|
|
||||||
aclgraph_mode = self.compilation_config.cudagraph_mode = \
|
|
||||||
CUDAGraphMode.PIECEWISE
|
|
||||||
else:
|
|
||||||
msg += "; setting cudagraph_mode=NONE"
|
|
||||||
aclgraph_mode = self.compilation_config.cudagraph_mode = \
|
|
||||||
CUDAGraphMode.NONE
|
|
||||||
logger.warning(msg)
|
|
||||||
|
|
||||||
# double check that we can support full graph if they are requested
|
# double check that we can support full graph if they are requested
|
||||||
# even after automatic downgrades
|
# even after automatic downgrades
|
||||||
if aclgraph_mode.has_full_cudagraphs() \
|
if aclgraph_mode.has_full_cudagraphs() \
|
||||||
|
|||||||
@@ -361,9 +361,10 @@ class NPUWorker(WorkerBase):
|
|||||||
|
|
||||||
def execute_dummy_batch(self) -> None:
|
def execute_dummy_batch(self) -> None:
|
||||||
force_attention = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
|
force_attention = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
|
||||||
self.model_runner._dummy_run(num_tokens=1,
|
self.model_runner._dummy_run(
|
||||||
uniform_decode=True,
|
num_tokens=self.model_runner.decode_token_per_req,
|
||||||
force_attention=force_attention)
|
uniform_decode=True,
|
||||||
|
force_attention=force_attention)
|
||||||
|
|
||||||
def _init_worker_distributed_environment(self) -> None:
|
def _init_worker_distributed_environment(self) -> None:
|
||||||
"""Initialize the distributed environment."""
|
"""Initialize the distributed environment."""
|
||||||
|
|||||||
Reference in New Issue
Block a user