feat: add mtp ut and fix some bugs (#2453)
### What this PR does / why we need it?
Fix mtp mode ut
### Does this PR introduce _any_ user-facing change?
Nothing
### How was this patch tested?
This can be tested in the same way as a unit test.
- vLLM version: v0.10.0
- vLLM main:
53415653ff
Signed-off-by: 赵江江 <zhaojiangjiang1@h-partners.com>
Co-authored-by: 赵江江 <zhaojiangjiang1@h-partners.com>
This commit is contained in:
@@ -1,43 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import random
|
import os
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import SamplingParams
|
||||||
|
|
||||||
|
from tests.e2e.conftest import VllmRunner
|
||||||
|
|
||||||
@pytest.fixture
|
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||||
def test_prompts():
|
|
||||||
prompt_types = ["repeat", "sentence"]
|
|
||||||
num_prompts = 10
|
|
||||||
prompts = []
|
|
||||||
|
|
||||||
random.seed(0)
|
|
||||||
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
|
|
||||||
|
|
||||||
# Generate a mixed batch of prompts, some of which can be easily
|
|
||||||
# predicted by n-gram matching and some which likely cannot.
|
|
||||||
for kind in random_prompt_type_choices:
|
|
||||||
word_choices = ["test", "temp", "hello", "where"]
|
|
||||||
word = random.choice(word_choices)
|
|
||||||
if kind == "repeat":
|
|
||||||
prompt = f"""
|
|
||||||
please repeat the word '{word}' 10 times.
|
|
||||||
give no other output than the word at least ten times in a row,
|
|
||||||
in lowercase with spaces between each word and without quotes.
|
|
||||||
"""
|
|
||||||
elif kind == "sentence":
|
|
||||||
prompt = f"""
|
|
||||||
please give a ten-word sentence that
|
|
||||||
uses the word {word} at least once.
|
|
||||||
give no other output than that simple sentence without quotes.
|
|
||||||
"""
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown prompt type: {kind}")
|
|
||||||
prompts.append([{"role": "user", "content": prompt}])
|
|
||||||
|
|
||||||
return prompts
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -50,39 +20,56 @@ def model_name():
|
|||||||
return "wemaster/deepseek_mtp_main_random_bf16"
|
return "wemaster/deepseek_mtp_main_random_bf16"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
True, reason="TODO: Enable me after test_mtp_correctness is fixed")
|
|
||||||
def test_mtp_correctness(
|
def test_mtp_correctness(
|
||||||
test_prompts: list[list[dict[str, Any]]],
|
|
||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
|
example_prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
'''
|
'''
|
||||||
Compare the outputs of a original LLM and a speculative LLM
|
Compare the outputs of a original LLM and a speculative LLM
|
||||||
should be the same when using mtp speculative decoding.
|
should be the same when using mtp speculative decoding.
|
||||||
'''
|
'''
|
||||||
ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True)
|
with VllmRunner(model_name,
|
||||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
tensor_parallel_size=1,
|
||||||
del ref_llm
|
gpu_memory_utilization=0.7,
|
||||||
|
max_model_len=256,
|
||||||
|
enforce_eager=True) as ref_llm:
|
||||||
|
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
||||||
|
|
||||||
|
with VllmRunner(
|
||||||
|
model_name,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
max_num_seqs=256,
|
||||||
|
gpu_memory_utilization=0.7,
|
||||||
|
distributed_executor_backend="mp",
|
||||||
|
enable_expert_parallel=True,
|
||||||
|
speculative_config={
|
||||||
|
"method": "deepseek_mtp",
|
||||||
|
"num_speculative_tokens": 1,
|
||||||
|
},
|
||||||
|
enforce_eager=True,
|
||||||
|
max_model_len=2000,
|
||||||
|
additional_config={"ascend_scheduler_config": {
|
||||||
|
"enabled": False
|
||||||
|
}}) as spec_llm:
|
||||||
|
spec_outputs = spec_llm.generate(example_prompts, sampling_config)
|
||||||
|
|
||||||
spec_llm = LLM(model=model_name,
|
|
||||||
trust_remote_code=True,
|
|
||||||
speculative_config={
|
|
||||||
"method": "deepseek_mtp",
|
|
||||||
"num_speculative_tokens": 1,
|
|
||||||
},
|
|
||||||
max_model_len=256,
|
|
||||||
enforce_eager=True)
|
|
||||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
|
||||||
matches = 0
|
matches = 0
|
||||||
misses = 0
|
misses = 0
|
||||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
ref_token_ids = ref_output[0][0]
|
||||||
|
spec_token_ids = spec_output[0][0]
|
||||||
|
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
|
||||||
matches += 1
|
matches += 1
|
||||||
else:
|
else:
|
||||||
misses += 1
|
misses += 1
|
||||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
print(f"ref_output: {ref_output[1][0]}")
|
||||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
print(f"spec_output: {spec_output[1][0]}")
|
||||||
|
|
||||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ class TestAscendQuantConfig(TestBase):
|
|||||||
def test_get_quant_method_for_fused_moe(self):
|
def test_get_quant_method_for_fused_moe(self):
|
||||||
fused_moe_layer = MagicMock(spec=FusedMoE)
|
fused_moe_layer = MagicMock(spec=FusedMoE)
|
||||||
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
|
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
|
||||||
|
fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
|
||||||
|
|
||||||
# Test skipped layer
|
# Test skipped layer
|
||||||
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
|
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
from vllm.distributed.parallel_state import GroupCoordinator
|
from vllm.distributed.parallel_state import GroupCoordinator
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
from vllm.model_executor.layers.linear import LinearBase
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
from vllm_ascend.torchair.torchair_mla import (
|
from vllm_ascend.torchair.torchair_mla import (
|
||||||
AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata,
|
AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata,
|
||||||
AscendMLATorchairImpl, AscendMLATorchairMetadata,
|
AscendMLATorchairImpl, AscendMLATorchairMetadata,
|
||||||
@@ -398,6 +400,68 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
|
|||||||
assert torch.equal(sin_golden, metadata.decode.sin)
|
assert torch.equal(sin_golden, metadata.decode.sin)
|
||||||
assert torch.equal(cos_golden, metadata.decode.cos)
|
assert torch.equal(cos_golden, metadata.decode.cos)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
|
||||||
|
def test_build_decode(self, mock_ascend_config):
|
||||||
|
ascend_config = MagicMock()
|
||||||
|
mock_ascend_config.return_value = ascend_config
|
||||||
|
ascend_config.torchair_graph_config.enabled = False
|
||||||
|
|
||||||
|
mock_vllm_config = MagicMock()
|
||||||
|
mock_vllm_config.model_config.max_model_len = 1024
|
||||||
|
mock_vllm_config.cache_config.block_size = 16
|
||||||
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
|
mock_vllm_config.get_head_size.return_value = 64
|
||||||
|
mock_vllm_config.model_config.dtype = torch.float16
|
||||||
|
mock_device = 'cpu'
|
||||||
|
model = MagicMock(spec=nn.Module)
|
||||||
|
model.model = MagicMock(spec=nn.Module)
|
||||||
|
|
||||||
|
builder = AscendMLATorchairMetadataBuilder(
|
||||||
|
mock_vllm_config,
|
||||||
|
mock_device,
|
||||||
|
metadata_cls=AscendMLATorchairMetadata)
|
||||||
|
builder.rope_dim = 64
|
||||||
|
|
||||||
|
builder.sin_cache = torch.tensor([10, 10])
|
||||||
|
builder.cos_cache = torch.tensor([10, 10])
|
||||||
|
|
||||||
|
with patch.object(builder,
|
||||||
|
"_get_graph_runner_block_tables",
|
||||||
|
side_effect=lambda x, y: y):
|
||||||
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||||
|
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
||||||
|
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
|
||||||
|
seq_lens_cpu=torch.tensor([1, 1, 1]),
|
||||||
|
num_reqs=3,
|
||||||
|
num_actual_tokens=3,
|
||||||
|
max_query_len=1,
|
||||||
|
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||||
|
block_table_tensor=torch.zeros((10, 10)),
|
||||||
|
slot_mapping_cpu=torch.tensor(range(20)),
|
||||||
|
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||||
|
positions=torch.tensor([1, 1]),
|
||||||
|
attn_mask=torch.ones((15, 15)),
|
||||||
|
spec_attn_mask=None,
|
||||||
|
attn_state=AscendAttentionState.ChunkedPrefill)
|
||||||
|
|
||||||
|
metadata = builder.build(common_attn_metadata, model)
|
||||||
|
|
||||||
|
self.assertIsInstance(metadata, AscendMLATorchairMetadata)
|
||||||
|
self.assertEqual(metadata.num_input_tokens, 0)
|
||||||
|
self.assertEqual(metadata.num_actual_tokens, 3)
|
||||||
|
self.assertEqual(metadata.num_decodes, 3)
|
||||||
|
self.assertEqual(metadata.num_decode_tokens, 3)
|
||||||
|
self.assertEqual(metadata.num_prefills, 0)
|
||||||
|
self.assertEqual(metadata.attn_state,
|
||||||
|
AscendAttentionState.ChunkedPrefill)
|
||||||
|
self.assertIsNone(metadata.prefill)
|
||||||
|
self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
|
||||||
|
self.assertEqual(metadata.block_tables.shape[0], 3)
|
||||||
|
self.assertEqual(metadata.block_tables.shape[1], 10)
|
||||||
|
self.assertEqual(metadata.seq_lens.shape[0], 3)
|
||||||
|
self.assertEqual(metadata.slot_mapping.shape[0], 3)
|
||||||
|
self.assertEqual(metadata.query_start_loc.shape[0], 4)
|
||||||
|
|
||||||
|
|
||||||
class TestAscendMLATorchairImpl(TestBase):
|
class TestAscendMLATorchairImpl(TestBase):
|
||||||
|
|
||||||
|
|||||||
@@ -374,18 +374,12 @@ class AscendMLAMetadataBuilder:
|
|||||||
|
|
||||||
decode_metadata = None
|
decode_metadata = None
|
||||||
if num_decodes > 0:
|
if num_decodes > 0:
|
||||||
actual_seq_lengths_q = query_start_loc[1:].tolist()
|
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
|
||||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||||
seq_lens = seq_lens[:num_decode_tokens]
|
seq_lens = seq_lens[:num_decode_tokens]
|
||||||
input_positions = input_positions[:num_decode_tokens]
|
input_positions = input_positions[:num_decode_tokens]
|
||||||
block_table = block_table[:num_decode_tokens, ...]
|
block_table = block_table[:num_decode_tokens, ...]
|
||||||
seq_lens_list = seq_lens.tolist()
|
seq_lens_list = seq_lens.tolist()
|
||||||
# TODO(xyx): whether this block is necessary without torchair
|
|
||||||
# mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
|
|
||||||
batch_size = slot_mapping.size(0)
|
|
||||||
if actual_seq_lengths_q[-1] != batch_size \
|
|
||||||
and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
|
||||||
actual_seq_lengths_q[-1] = batch_size
|
|
||||||
|
|
||||||
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
||||||
1).unsqueeze(2)
|
1).unsqueeze(2)
|
||||||
|
|||||||
@@ -1178,7 +1178,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||||
raise ValueError("Only softmax scoring function is supported for "
|
raise ValueError("Only softmax scoring function is supported for "
|
||||||
"non-grouped topk.")
|
"non-grouped topk.")
|
||||||
self.moe = FusedMoEConfig.make(
|
moe = FusedMoEConfig.make(
|
||||||
num_experts=self.global_num_experts,
|
num_experts=self.global_num_experts,
|
||||||
experts_per_token=top_k,
|
experts_per_token=top_k,
|
||||||
hidden_dim=hidden_size,
|
hidden_dim=hidden_size,
|
||||||
@@ -1188,8 +1188,10 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
in_dtype=params_dtype,
|
in_dtype=params_dtype,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
|
|
||||||
|
self.moe_config = moe
|
||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method = AscendUnquantizedFusedMoEMethod(self.moe)
|
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
|
||||||
else:
|
else:
|
||||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||||
|
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ class AscendQuantConfig(QuantizationConfig):
|
|||||||
elif isinstance(layer, FusedMoE):
|
elif isinstance(layer, FusedMoE):
|
||||||
if self.is_layer_skipped_ascend(prefix,
|
if self.is_layer_skipped_ascend(prefix,
|
||||||
self.packed_modules_mapping):
|
self.packed_modules_mapping):
|
||||||
return AscendUnquantizedFusedMoEMethod(layer.moe)
|
return AscendUnquantizedFusedMoEMethod(layer.moe_config)
|
||||||
return AscendFusedMoEMethod(self, prefix,
|
return AscendFusedMoEMethod(self, prefix,
|
||||||
self.packed_modules_mapping)
|
self.packed_modules_mapping)
|
||||||
elif isinstance(layer, VocabParallelEmbedding):
|
elif isinstance(layer, VocabParallelEmbedding):
|
||||||
|
|||||||
@@ -492,17 +492,17 @@ class AscendMLATorchairMetadataBuilder:
|
|||||||
graph_pad_size = common_attn_metadata.graph_pad_size
|
graph_pad_size = common_attn_metadata.graph_pad_size
|
||||||
use_torchair_graph = graph_pad_size != -1
|
use_torchair_graph = graph_pad_size != -1
|
||||||
if num_decodes > 0:
|
if num_decodes > 0:
|
||||||
actual_seq_lengths_q = query_start_loc[1:].tolist()
|
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
|
||||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||||
seq_lens = seq_lens[:num_decode_tokens]
|
seq_lens = seq_lens[:num_decode_tokens]
|
||||||
input_positions = input_positions[:num_decode_tokens]
|
input_positions = input_positions[:num_decode_tokens]
|
||||||
block_table = block_table[:num_decode_tokens, ...]
|
block_table = block_table[:num_decode_tokens, ...]
|
||||||
|
num_token_pad_size = 0
|
||||||
if use_torchair_graph and common_attn_metadata.attn_state in [
|
if use_torchair_graph and common_attn_metadata.attn_state in [
|
||||||
AscendAttentionState.DecodeOnly,
|
AscendAttentionState.DecodeOnly,
|
||||||
AscendAttentionState.SpecDecoding
|
AscendAttentionState.SpecDecoding
|
||||||
]:
|
]:
|
||||||
num_reqs_pad_size = 0
|
num_reqs_pad_size = 0
|
||||||
num_token_pad_size = 0
|
|
||||||
if graph_pad_size != 0:
|
if graph_pad_size != 0:
|
||||||
pad_value = 0
|
pad_value = 0
|
||||||
num_token_pad_size = graph_pad_size - num_decode_tokens
|
num_token_pad_size = graph_pad_size - num_decode_tokens
|
||||||
@@ -535,13 +535,14 @@ class AscendMLATorchairMetadataBuilder:
|
|||||||
device=input_positions.device)
|
device=input_positions.device)
|
||||||
input_positions = torch.cat(
|
input_positions = torch.cat(
|
||||||
[input_positions, position_padding])
|
[input_positions, position_padding])
|
||||||
actual_seq_lengths_q = query_start_loc[1:].tolist(
|
actual_seq_lengths_q = (
|
||||||
) + common_attn_metadata.actual_seq_lengths_q[
|
actual_seq_lengths_q + common_attn_metadata.
|
||||||
num_reqs:num_reqs + num_reqs_pad_size]
|
actual_seq_lengths_q[num_reqs:num_reqs +
|
||||||
|
num_reqs_pad_size])
|
||||||
else:
|
else:
|
||||||
seq_lens_list = seq_lens.tolist()
|
seq_lens_list = seq_lens.tolist()
|
||||||
# mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
|
# mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
|
||||||
batch_size = slot_mapping.size(0)
|
batch_size = num_decode_tokens + num_token_pad_size
|
||||||
if actual_seq_lengths_q[-1] != batch_size \
|
if actual_seq_lengths_q[-1] != batch_size \
|
||||||
and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
||||||
actual_seq_lengths_q[-1] = batch_size
|
actual_seq_lengths_q[-1] = batch_size
|
||||||
|
|||||||
@@ -190,11 +190,6 @@ class MtpProposer:
|
|||||||
self.positions[:num_tokens] = target_positions
|
self.positions[:num_tokens] = target_positions
|
||||||
self.hidden_states[:num_tokens] = target_hidden_states
|
self.hidden_states[:num_tokens] = target_hidden_states
|
||||||
|
|
||||||
if attn_metadata.prefill is not None:
|
|
||||||
attn_metadata.prefill.query_lens = query_lens.cpu()
|
|
||||||
attn_metadata.prefill.input_positions = target_positions
|
|
||||||
attn_metadata.prefill.seq_lens = seq_lens
|
|
||||||
|
|
||||||
if not self.torchair_graph_enabled:
|
if not self.torchair_graph_enabled:
|
||||||
# torch mode need to update num_tokens_across_dp
|
# torch mode need to update num_tokens_across_dp
|
||||||
# TODO: adapt enable_dbo later
|
# TODO: adapt enable_dbo later
|
||||||
@@ -213,6 +208,7 @@ class MtpProposer:
|
|||||||
num_tokens=num_input_tokens,
|
num_tokens=num_input_tokens,
|
||||||
with_prefill=with_prefill,
|
with_prefill=with_prefill,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
|
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||||
in_profile_run=self.runner.in_profile_run,
|
in_profile_run=self.runner.in_profile_run,
|
||||||
num_actual_tokens=num_tokens):
|
num_actual_tokens=num_tokens):
|
||||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||||
@@ -315,6 +311,7 @@ class MtpProposer:
|
|||||||
num_tokens=num_tokens,
|
num_tokens=num_tokens,
|
||||||
with_prefill=with_prefill,
|
with_prefill=with_prefill,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
|
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||||
in_profile_run=self.runner.in_profile_run,
|
in_profile_run=self.runner.in_profile_run,
|
||||||
num_actual_tokens=0):
|
num_actual_tokens=0):
|
||||||
if is_running_torchair:
|
if is_running_torchair:
|
||||||
|
|||||||
@@ -47,9 +47,14 @@ from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
|||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.utils import (init_ascend_soc_version,
|
from vllm_ascend.utils import (init_ascend_soc_version,
|
||||||
register_ascend_customop, sleep_mode_enabled,
|
register_ascend_customop, sleep_mode_enabled,
|
||||||
try_register_lib)
|
try_register_lib, vllm_version_is)
|
||||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||||
|
|
||||||
|
if not vllm_version_is("0.10.1.1"):
|
||||||
|
from vllm.v1.outputs import DraftTokenIds
|
||||||
|
else:
|
||||||
|
DraftTokenIds = None
|
||||||
|
|
||||||
|
|
||||||
class NPUWorker(WorkerBase):
|
class NPUWorker(WorkerBase):
|
||||||
|
|
||||||
@@ -343,3 +348,6 @@ class NPUWorker(WorkerBase):
|
|||||||
|
|
||||||
def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
|
def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
|
||||||
return self.model_runner.get_supported_tasks()
|
return self.model_runner.get_supported_tasks()
|
||||||
|
|
||||||
|
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
||||||
|
return self.model_runner.take_draft_token_ids()
|
||||||
Reference in New Issue
Block a user