[Feature] Support for cross-attention and whisper model (#5592)
### What this PR does / why we need it?
This PR resolves the following
issue: https://github.com/vllm-project/vllm-ascend/issues/2262
- Add support for cross-attention when the model is an encoder-decoder model
- Add support for the Whisper model
- vLLM version: v0.13.0
- vLLM main:
7157596103
Signed-off-by: gh924 <guihao2@huawei.com>
Co-authored-by: Aoxuan Chen <43376869+chenaoxuan@users.noreply.github.com>
This commit is contained in:
@@ -21,6 +21,8 @@ import os
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from modelscope import snapshot_download # type: ignore
|
from modelscope import snapshot_download # type: ignore
|
||||||
|
from vllm import SamplingParams
|
||||||
|
from vllm.assets.audio import AudioAsset
|
||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
from tests.e2e.conftest import VllmRunner
|
||||||
|
|
||||||
@@ -32,6 +34,10 @@ MINICPM_MODELS = [
|
|||||||
"OpenBMB/MiniCPM4-0.5B",
|
"OpenBMB/MiniCPM4-0.5B",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
WHISPER_MODELS = [
|
||||||
|
"openai-mirror/whisper-large-v3-turbo",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", MINICPM_MODELS)
|
@pytest.mark.parametrize("model", MINICPM_MODELS)
|
||||||
def test_minicpm(model) -> None:
|
def test_minicpm(model) -> None:
|
||||||
@@ -44,3 +50,26 @@ def test_minicpm(model) -> None:
|
|||||||
max_model_len=512,
|
max_model_len=512,
|
||||||
gpu_memory_utilization=0.7) as runner:
|
gpu_memory_utilization=0.7) as runner:
|
||||||
runner.generate_greedy(example_prompts, max_tokens)
|
runner.generate_greedy(example_prompts, max_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", WHISPER_MODELS)
|
||||||
|
def test_whisper(model) -> None:
|
||||||
|
prompts = ["<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"]
|
||||||
|
audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0.2,
|
||||||
|
max_tokens=10,
|
||||||
|
stop_token_ids=None)
|
||||||
|
|
||||||
|
with VllmRunner(snapshot_download(model),
|
||||||
|
max_model_len=448,
|
||||||
|
max_num_seqs=5,
|
||||||
|
dtype="bfloat16",
|
||||||
|
block_size=128,
|
||||||
|
gpu_memory_utilization=0.9) as runner:
|
||||||
|
outputs = runner.generate(prompts=prompts,
|
||||||
|
audios=audios,
|
||||||
|
sampling_params=sampling_params)
|
||||||
|
|
||||||
|
assert outputs is not None, "Generated outputs should not be None."
|
||||||
|
assert len(outputs) > 0, "Generated outputs should not be empty."
|
||||||
|
|||||||
@@ -320,26 +320,3 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
mock_fused_infer_attention_score.assert_called_once()
|
mock_fused_infer_attention_score.assert_called_once()
|
||||||
|
|
||||||
assert output.shape == (10, 8, 64)
|
assert output.shape == (10, 8, 64)
|
||||||
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
|
||||||
def test_forward_raise_error(self, mock_paged_attention):
|
|
||||||
query = torch.randn(10, 8 * 64)
|
|
||||||
key = torch.randn(10, 8 * 64)
|
|
||||||
value = torch.randn(10, 8 * 64)
|
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
|
||||||
output = torch.empty_like(query)
|
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
|
||||||
metadata.query_lens = torch.tensor([10])
|
|
||||||
metadata.seq_lens = torch.tensor([10])
|
|
||||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
|
||||||
metadata.num_actual_tokens = 10
|
|
||||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
|
||||||
metadata.num_decodes = 0
|
|
||||||
metadata.num_prefills = 10
|
|
||||||
layer = self.layer_no_quant
|
|
||||||
|
|
||||||
with self.assertRaises(NotImplementedError):
|
|
||||||
self.impl_error.forward(layer, query, key, value, kv_cache,
|
|
||||||
metadata, output)
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from vllm.utils.math_utils import cdiv
|
|||||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||||
AttentionMetadataBuilder)
|
AttentionMetadataBuilder)
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
|
||||||
|
|
||||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||||
from vllm_ascend.attention.context_parallel.common_cp import (
|
from vllm_ascend.attention.context_parallel.common_cp import (
|
||||||
@@ -256,6 +256,9 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
|||||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||||
|
|
||||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||||
|
if isinstance(self.kv_cache_spec, CrossAttentionSpec):
|
||||||
|
seq_lens = common_attn_metadata.seq_lens
|
||||||
|
slot_mapping = common_attn_metadata.slot_mapping.to(torch.int32)
|
||||||
attn_state = common_attn_metadata.attn_state
|
attn_state = common_attn_metadata.attn_state
|
||||||
|
|
||||||
# Get attn_mask and swa_mask from singleton AttentionMaskBuilder
|
# Get attn_mask and swa_mask from singleton AttentionMaskBuilder
|
||||||
@@ -502,6 +505,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
block_size = 128
|
block_size = 128
|
||||||
block_table = None
|
block_table = None
|
||||||
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
|
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
|
||||||
|
if self.attn_type == AttentionType.ENCODER_DECODER:
|
||||||
|
actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens,
|
||||||
|
dim=0).tolist()
|
||||||
elif attn_metadata.attn_state == \
|
elif attn_metadata.attn_state == \
|
||||||
AscendAttentionState.PrefillCacheHit:
|
AscendAttentionState.PrefillCacheHit:
|
||||||
batch_size = attn_metadata.seq_lens.shape[0]
|
batch_size = attn_metadata.seq_lens.shape[0]
|
||||||
@@ -583,7 +589,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
= self._get_fia_params(key, value, attn_metadata)
|
= self._get_fia_params(key, value, attn_metadata)
|
||||||
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
|
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
|
||||||
query = query[:num_tokens]
|
query = query[:num_tokens]
|
||||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and self.attn_type != AttentionType.ENCODER_DECODER:
|
||||||
key = key[:num_tokens]
|
key = key[:num_tokens]
|
||||||
value = value[:num_tokens]
|
value = value[:num_tokens]
|
||||||
# Get workspace from cache or calculate it if not present.
|
# Get workspace from cache or calculate it if not present.
|
||||||
@@ -675,23 +681,29 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
if self.key_cache is None:
|
if self.key_cache is None:
|
||||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||||
slots = attn_metadata.slot_mapping
|
slots = attn_metadata.slot_mapping
|
||||||
|
encoder_decoder = (self.attn_type == AttentionType.ENCODER_DECODER)
|
||||||
if get_ascend_device_type() == AscendDeviceType.A5:
|
if get_ascend_device_type() == AscendDeviceType.A5:
|
||||||
# TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping.
|
# TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping.
|
||||||
# Should check if the 0 dim of slot_mapping must equal to the 0 dim of key.
|
# Should check if the 0 dim of slot_mapping must equal to the 0 dim of key.
|
||||||
# If it's necessary, the slots should be sliced.
|
# If it's necessary, the slots should be sliced.
|
||||||
torch_npu.npu_scatter_pa_kv_cache(
|
torch_npu.npu_scatter_pa_kv_cache(
|
||||||
key=key[:attn_metadata.num_actual_tokens],
|
key=key[:attn_metadata.num_actual_tokens]
|
||||||
value=value[:attn_metadata.num_actual_tokens].contiguous(),
|
if not encoder_decoder else key,
|
||||||
|
value=value[:attn_metadata.num_actual_tokens].contiguous()
|
||||||
|
if not encoder_decoder else value,
|
||||||
key_cache=self.key_cache,
|
key_cache=self.key_cache,
|
||||||
value_cache=self.value_cache,
|
value_cache=self.value_cache,
|
||||||
slot_mapping=slots)
|
slot_mapping=slots)
|
||||||
else:
|
else:
|
||||||
torch_npu._npu_reshape_and_cache(
|
torch_npu._npu_reshape_and_cache(
|
||||||
key=key[:attn_metadata.num_actual_tokens],
|
key=key[:attn_metadata.num_actual_tokens]
|
||||||
value=value[:attn_metadata.num_actual_tokens],
|
if not encoder_decoder else key,
|
||||||
|
value=value[:attn_metadata.num_actual_tokens]
|
||||||
|
if not encoder_decoder else value,
|
||||||
key_cache=self.key_cache,
|
key_cache=self.key_cache,
|
||||||
value_cache=self.value_cache,
|
value_cache=self.value_cache,
|
||||||
slot_indices=slots[:attn_metadata.num_actual_tokens])
|
slot_indices=slots[:attn_metadata.num_actual_tokens]
|
||||||
|
if not encoder_decoder else slots)
|
||||||
if self.is_kv_producer:
|
if self.is_kv_producer:
|
||||||
attn_metadata.reshape_cache_event.record()
|
attn_metadata.reshape_cache_event.record()
|
||||||
return key, value
|
return key, value
|
||||||
@@ -747,18 +759,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
" for AscendAttentionBackendImpl")
|
" for AscendAttentionBackendImpl")
|
||||||
|
|
||||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||||
attn_type = self.attn_type
|
|
||||||
if attn_type not in [
|
|
||||||
AttentionType.DECODER, AttentionType.ENCODER_ONLY
|
|
||||||
]:
|
|
||||||
raise NotImplementedError("Encoder/Decoder cross-attention "
|
|
||||||
"is not implemented for "
|
|
||||||
"PallasAttentionBackendImpl")
|
|
||||||
num_tokens = query.shape[0]
|
num_tokens = query.shape[0]
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
return output.fill_(0)
|
return output.fill_(0)
|
||||||
key, value = self.reshape_and_cache(key, value, kv_cache,
|
if key is not None and value is not None:
|
||||||
attn_metadata)
|
key, value = self.reshape_and_cache(key, value, kv_cache,
|
||||||
|
attn_metadata)
|
||||||
# pooling model branch
|
# pooling model branch
|
||||||
if attn_metadata.model_runner_type == "pooling":
|
if attn_metadata.model_runner_type == "pooling":
|
||||||
attn_output = self._forward_encoder_attention(
|
attn_output = self._forward_encoder_attention(
|
||||||
|
|||||||
@@ -238,6 +238,14 @@ class NPUPlatform(Platform):
|
|||||||
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
|
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
|
||||||
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||||
|
|
||||||
|
# encoder-decoder models currently only support piecewise mode
|
||||||
|
if model_config and model_config.is_encoder_decoder is True:
|
||||||
|
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
||||||
|
logger.warning(
|
||||||
|
"encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE "
|
||||||
|
)
|
||||||
|
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||||
|
|
||||||
# get custom compile backend for graph fusion
|
# get custom compile backend for graph fusion
|
||||||
compilation_config.oot_compiler = cls.get_compile_backend()
|
compilation_config.oot_compiler = cls.get_compile_backend()
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler
|
|||||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
from vllm.v1.kv_cache_interface import (AttentionSpec, CrossAttentionSpec,
|
||||||
EncoderOnlyAttentionSpec,
|
EncoderOnlyAttentionSpec,
|
||||||
FullAttentionSpec, KVCacheConfig,
|
FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheGroupSpec, KVCacheSpec,
|
KVCacheGroupSpec, KVCacheSpec,
|
||||||
@@ -315,7 +315,8 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
# the block_sizes in the kv cache config.
|
# the block_sizes in the kv cache config.
|
||||||
self.input_batch = NPUInputBatch(
|
self.input_batch = NPUInputBatch(
|
||||||
max_num_reqs=self.max_num_reqs,
|
max_num_reqs=self.max_num_reqs,
|
||||||
max_model_len=self.model_config.max_model_len,
|
max_model_len=max(self.model_config.max_model_len,
|
||||||
|
self.max_encoder_len),
|
||||||
max_num_batched_tokens=self.max_num_tokens,
|
max_num_batched_tokens=self.max_num_tokens,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
@@ -485,7 +486,8 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
|
) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
|
||||||
int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
|
int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor], Optional[torch.Tensor], int]:
|
Optional[torch.Tensor], Optional[torch.Tensor], int, dict[str,
|
||||||
|
Any]]:
|
||||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
assert total_num_scheduled_tokens > 0
|
assert total_num_scheduled_tokens > 0
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
@@ -729,7 +731,11 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
|
|
||||||
# _prepare_inputs may reorder the batch, so we must gather
|
# _prepare_inputs may reorder the batch, so we must gather
|
||||||
# multi-modal outputs after that to ensure the correct order
|
# multi-modal outputs after that to ensure the correct order
|
||||||
if self.is_multimodal_model:
|
if vllm_version_is('0.13.0'):
|
||||||
|
model_kwargs = self._init_model_kwargs(num_input_tokens)
|
||||||
|
else:
|
||||||
|
model_kwargs = self._init_model_kwargs()
|
||||||
|
if self.is_multimodal_model and not self.model_config.is_encoder_decoder:
|
||||||
self.multimodal_cpu_fields = ["grid_thw"]
|
self.multimodal_cpu_fields = ["grid_thw"]
|
||||||
self._prepare_multimodal_fields()
|
self._prepare_multimodal_fields()
|
||||||
with self.maybe_get_ec_connector_output(
|
with self.maybe_get_ec_connector_output(
|
||||||
@@ -796,6 +802,13 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
else:
|
else:
|
||||||
positions = self.positions.gpu[:num_input_tokens]
|
positions = self.positions.gpu[:num_input_tokens]
|
||||||
|
|
||||||
|
# Run the encoder, just like we do with other multimodal inputs.
|
||||||
|
if self.model_config.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs:
|
||||||
|
input_ids = self.input_ids.gpu[:total_num_scheduled_tokens]
|
||||||
|
positions = self.positions.gpu[:total_num_scheduled_tokens]
|
||||||
|
encoder_outputs = self._execute_mm_encoder(scheduler_output)
|
||||||
|
model_kwargs.update({"encoder_outputs": encoder_outputs})
|
||||||
|
|
||||||
# type: ignore
|
# type: ignore
|
||||||
if get_pp_group().is_first_rank:
|
if get_pp_group().is_first_rank:
|
||||||
intermediate_tensors = None
|
intermediate_tensors = None
|
||||||
@@ -880,6 +893,11 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
# in the same group share the same metadata.
|
# in the same group share the same metadata.
|
||||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||||
self.kv_cache_config.kv_cache_groups):
|
self.kv_cache_config.kv_cache_groups):
|
||||||
|
encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
|
||||||
|
scheduler_output.num_scheduled_tokens or {},
|
||||||
|
kv_cache_group_spec.kv_cache_spec,
|
||||||
|
self.input_batch.num_reqs,
|
||||||
|
)
|
||||||
if isinstance(kv_cache_group_spec.kv_cache_spec,
|
if isinstance(kv_cache_group_spec.kv_cache_spec,
|
||||||
EncoderOnlyAttentionSpec):
|
EncoderOnlyAttentionSpec):
|
||||||
# Encoder-only layers do not have KV cache, so we need to
|
# Encoder-only layers do not have KV cache, so we need to
|
||||||
@@ -977,7 +995,8 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
decode_token_per_req=self.decode_token_per_req,
|
decode_token_per_req=self.decode_token_per_req,
|
||||||
prefill_context_parallel_metadata=self.long_seq_metadata,
|
prefill_context_parallel_metadata=self.long_seq_metadata,
|
||||||
max_seq_len=0,
|
max_seq_len=0,
|
||||||
)
|
encoder_seq_lens=encoder_seq_lens,
|
||||||
|
encoder_seq_lens_cpu=encoder_seq_lens_cpu)
|
||||||
|
|
||||||
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
||||||
# For pcp + spec decode, we flatten block_table
|
# For pcp + spec decode, we flatten block_table
|
||||||
@@ -1059,7 +1078,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
num_input_tokens, num_tokens_across_dp,
|
num_input_tokens, num_tokens_across_dp,
|
||||||
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
|
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
|
||||||
input_ids, inputs_embeds, intermediate_tensors,
|
input_ids, inputs_embeds, intermediate_tensors,
|
||||||
max_num_scheduled_tokens)
|
max_num_scheduled_tokens, model_kwargs)
|
||||||
|
|
||||||
# all-gather one hidden-states in sp scene
|
# all-gather one hidden-states in sp scene
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -1091,22 +1110,13 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens,
|
def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens,
|
||||||
input_ids, positions,
|
input_ids, positions,
|
||||||
intermediate_tensors,
|
intermediate_tensors,
|
||||||
inputs_embeds):
|
inputs_embeds, model_kwargs):
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
if vllm_version_is('0.13.0'):
|
hidden_states = self.model(input_ids=input_ids,
|
||||||
hidden_states = self.model(
|
positions=positions,
|
||||||
input_ids=input_ids,
|
intermediate_tensors=intermediate_tensors,
|
||||||
positions=positions,
|
inputs_embeds=inputs_embeds,
|
||||||
intermediate_tensors=intermediate_tensors,
|
**model_kwargs)
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
**self._init_model_kwargs(maybe_padded_num_tokens))
|
|
||||||
else:
|
|
||||||
hidden_states = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
positions=positions,
|
|
||||||
intermediate_tensors=intermediate_tensors,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
**self._init_model_kwargs())
|
|
||||||
|
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
|
||||||
@@ -1465,9 +1475,9 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
(attn_metadata, positions, num_scheduled_tokens_np,
|
(attn_metadata, positions, num_scheduled_tokens_np,
|
||||||
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
|
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
|
||||||
logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
|
logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
|
||||||
intermediate_tensors,
|
intermediate_tensors, max_query_len,
|
||||||
max_query_len) = (self._prepare_inputs(scheduler_output,
|
model_kwargs) = (self._prepare_inputs(scheduler_output,
|
||||||
intermediate_tensors))
|
intermediate_tensors))
|
||||||
|
|
||||||
if self.dynamic_eplb:
|
if self.dynamic_eplb:
|
||||||
self.eplb_updator.take_update_info_from_eplb_process()
|
self.eplb_updator.take_update_info_from_eplb_process()
|
||||||
@@ -1512,7 +1522,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
|
|
||||||
hidden_states = self._generate_process_reqs_hidden_states(
|
hidden_states = self._generate_process_reqs_hidden_states(
|
||||||
maybe_padded_num_tokens, input_ids, positions,
|
maybe_padded_num_tokens, input_ids, positions,
|
||||||
intermediate_tensors, inputs_embeds)
|
intermediate_tensors, inputs_embeds, model_kwargs)
|
||||||
|
|
||||||
self.maybe_wait_for_kv_save()
|
self.maybe_wait_for_kv_save()
|
||||||
finished_sending, finished_recving = self.get_finished_kv_transfer(
|
finished_sending, finished_recving = self.get_finished_kv_transfer(
|
||||||
@@ -2152,7 +2162,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
num_sampled_tokens):
|
num_sampled_tokens):
|
||||||
# Make sure padding doesn't exceed max_num_tokens
|
# Make sure padding doesn't exceed max_num_tokens
|
||||||
assert num_tokens_padded <= self.max_num_tokens
|
assert num_tokens_padded <= self.max_num_tokens
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model and not self.model_config.is_encoder_decoder:
|
||||||
input_ids = None
|
input_ids = None
|
||||||
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
|
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
|
||||||
elif self.enable_prompt_embeds:
|
elif self.enable_prompt_embeds:
|
||||||
@@ -2546,7 +2556,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
|
|
||||||
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
|
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
|
||||||
# encounter OOM issue
|
# encounter OOM issue
|
||||||
if isinstance(kv_cache_spec, FullAttentionSpec):
|
if isinstance(kv_cache_spec, AttentionSpec):
|
||||||
raw_dsa_k_tensor = None
|
raw_dsa_k_tensor = None
|
||||||
if self.use_sparse:
|
if self.use_sparse:
|
||||||
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
|
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
|
||||||
@@ -2721,7 +2731,8 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
"for more details.")
|
"for more details.")
|
||||||
self.input_batch = NPUInputBatch(
|
self.input_batch = NPUInputBatch(
|
||||||
max_num_reqs=self.max_num_reqs,
|
max_num_reqs=self.max_num_reqs,
|
||||||
max_model_len=self.model_config.max_model_len,
|
max_model_len=max(self.model_config.max_model_len,
|
||||||
|
self.max_encoder_len),
|
||||||
max_num_batched_tokens=self.max_num_tokens,
|
max_num_batched_tokens=self.max_num_tokens,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
@@ -2889,7 +2900,11 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
# encoder-only attention does not need KV cache.
|
# encoder-only attention does not need KV cache.
|
||||||
continue
|
continue
|
||||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||||
raise NotImplementedError
|
kv_cache_spec[layer_name] = CrossAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
|
head_size=attn_module.head_size,
|
||||||
|
dtype=self.kv_cache_dtype)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown attention type: {attn_module.attn_type}")
|
f"Unknown attention type: {attn_module.attn_type}")
|
||||||
|
|||||||
Reference in New Issue
Block a user