fix multistep bug,remove uselesscodes (#355)
1. remove useluss code in attention.py 2. multistep now using StatefulModelInputForNPU and do not use StatefulModelInput Signed-off-by: new-TonyWang <wangtonyyu222@gmail.com>
This commit is contained in:
@@ -13,7 +13,7 @@
|
|||||||
| LogProbs | ✅ | | | Basic functions available | Need fully test |
|
| LogProbs | ✅ | | | Basic functions available | Need fully test |
|
||||||
| Prompt logProbs | ✅ | | | Basic functions available | Need fully test |
|
| Prompt logProbs | ✅ | | | Basic functions available | Need fully test |
|
||||||
| Async output | ✅ | | | Basic functions available | Need fully test |
|
| Async output | ✅ | | | Basic functions available | Need fully test |
|
||||||
| Multi step scheduler | ✅ | | | Basic functions available | Need fully test |
|
| Multi step scheduler | ✅ | | | Basic functions available | Need fully test, Find more details at [<u> Blog </u>](https://blog.vllm.ai/2024/09/05/perf-update.html#batch-scheduling-multiple-steps-ahead-pr-7000), [<u> RFC </u>](https://github.com/vllm-project/vllm/issues/6854) and [<u>issue</u>](https://github.com/vllm-project/vllm/pull/7000) |
|
||||||
| Best of | ✅ | | | Basic functions available | Need fully test |
|
| Best of | ✅ | | | Basic functions available | Need fully test |
|
||||||
| Beam search | ✅ | | | Basic functions available | Need fully test |
|
| Beam search | ✅ | | | Basic functions available | Need fully test |
|
||||||
| Guided Decoding | ✅ | | | Basic functions available | Find more details at the [<u>issue</u>](https://github.com/vllm-project/vllm-ascend/issues/177) |
|
| Guided Decoding | ✅ | | | Basic functions available | Find more details at the [<u>issue</u>](https://github.com/vllm-project/vllm-ascend/issues/177) |
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from itertools import accumulate
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -216,9 +215,6 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||||
# requests only.
|
# requests only.
|
||||||
max_decode_seq_len: int
|
max_decode_seq_len: int
|
||||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
|
||||||
# so far).
|
|
||||||
context_lens_tensor: Optional[torch.Tensor]
|
|
||||||
|
|
||||||
# (batch_size, max_blocks_per_seq).
|
# (batch_size, max_blocks_per_seq).
|
||||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||||
@@ -234,18 +230,6 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
# Maximum query length in the batch. None for decoding.
|
# Maximum query length in the batch. None for decoding.
|
||||||
max_query_len: Optional[int] = None
|
max_query_len: Optional[int] = None
|
||||||
|
|
||||||
# Max number of query tokens among request in the batch.
|
|
||||||
max_decode_query_len: Optional[int] = None
|
|
||||||
|
|
||||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
|
||||||
# the batch, used to index into subquery. E.g., if the subquery length
|
|
||||||
# is [4, 6], it is [0, 4, 10].
|
|
||||||
query_start_loc: Optional[torch.Tensor] = None
|
|
||||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
|
||||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
|
||||||
# [4, 6], it is [0, 4, 10].
|
|
||||||
seq_start_loc: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
# Self-attention prefill/decode metadata cache
|
# Self-attention prefill/decode metadata cache
|
||||||
_cached_prefill_metadata: Optional["AscendMetadata"] = None
|
_cached_prefill_metadata: Optional["AscendMetadata"] = None
|
||||||
_cached_decode_metadata: Optional["AscendMetadata"] = None
|
_cached_decode_metadata: Optional["AscendMetadata"] = None
|
||||||
@@ -283,18 +267,10 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
or (self.encoder_seq_lens is not None))
|
or (self.encoder_seq_lens is not None))
|
||||||
|
|
||||||
# Compute some attn_metadata fields which default to None.
|
# Compute some attn_metadata fields which default to None.
|
||||||
query_start_loc = (None if self.query_start_loc is None else
|
|
||||||
self.query_start_loc[:self.num_prefills + 1])
|
|
||||||
slot_mapping = (None if self.slot_mapping is None else
|
slot_mapping = (None if self.slot_mapping is None else
|
||||||
self.slot_mapping[:self.num_prefill_tokens])
|
self.slot_mapping[:self.num_prefill_tokens])
|
||||||
seq_lens = (None if self.seq_lens is None else
|
seq_lens = (None if self.seq_lens is None else
|
||||||
self.seq_lens[:self.num_prefills])
|
self.seq_lens[:self.num_prefills])
|
||||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
|
||||||
self.seq_lens_tensor[:self.num_prefills])
|
|
||||||
seq_start_loc = (None if self.seq_start_loc is None else
|
|
||||||
self.seq_start_loc[:self.num_prefills + 1])
|
|
||||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
|
||||||
self.context_lens_tensor[:self.num_prefills])
|
|
||||||
block_tables = (None if self.block_tables is None else
|
block_tables = (None if self.block_tables is None else
|
||||||
self.block_tables[:self.num_prefills])
|
self.block_tables[:self.num_prefills])
|
||||||
|
|
||||||
@@ -311,11 +287,7 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
seq_lens_tensor=seq_lens_tensor,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_query_len=self.max_query_len,
|
max_query_len=self.max_query_len,
|
||||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||||
max_decode_query_len=0,
|
|
||||||
max_decode_seq_len=0,
|
max_decode_seq_len=0,
|
||||||
query_start_loc=query_start_loc,
|
|
||||||
seq_start_loc=seq_start_loc,
|
|
||||||
context_lens_tensor=context_lens_tensor,
|
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
# Begin encoder & cross attn fields below...
|
# Begin encoder & cross attn fields below...
|
||||||
encoder_seq_lens=self.encoder_seq_lens,
|
encoder_seq_lens=self.encoder_seq_lens,
|
||||||
@@ -343,8 +315,6 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
self.slot_mapping[self.num_prefill_tokens:])
|
self.slot_mapping[self.num_prefill_tokens:])
|
||||||
seq_lens = (None if self.seq_lens is None else
|
seq_lens = (None if self.seq_lens is None else
|
||||||
self.seq_lens[self.num_prefills:])
|
self.seq_lens[self.num_prefills:])
|
||||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
|
||||||
self.seq_lens_tensor[self.num_prefills:])
|
|
||||||
block_tables = (None if self.block_tables is None else
|
block_tables = (None if self.block_tables is None else
|
||||||
self.block_tables[self.num_prefills:])
|
self.block_tables[self.num_prefills:])
|
||||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||||
@@ -357,19 +327,9 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_decode_query_len=self.max_decode_query_len,
|
|
||||||
max_query_len=self.max_query_len,
|
max_query_len=self.max_query_len,
|
||||||
max_prefill_seq_len=0,
|
max_prefill_seq_len=0,
|
||||||
max_decode_seq_len=self.max_decode_seq_len,
|
max_decode_seq_len=self.max_decode_seq_len,
|
||||||
# Batch may be composed of prefill|decodes, adjust query start
|
|
||||||
# indices to refer to the start of decodes. E.g.
|
|
||||||
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
|
||||||
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
|
||||||
self.query_start_loc[self.num_prefills])
|
|
||||||
if self.query_start_loc is not None else None,
|
|
||||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
|
||||||
if self.seq_start_loc is not None else None,
|
|
||||||
context_lens_tensor=None,
|
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
# Begin encoder & cross attn fields below...
|
# Begin encoder & cross attn fields below...
|
||||||
encoder_seq_lens=self.encoder_seq_lens,
|
encoder_seq_lens=self.encoder_seq_lens,
|
||||||
@@ -427,14 +387,6 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
assert self.max_query_len == 1
|
assert self.max_query_len == 1
|
||||||
assert self.max_prefill_seq_len == 0
|
assert self.max_prefill_seq_len == 0
|
||||||
|
|
||||||
assert self.query_start_loc is not None
|
|
||||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
|
||||||
assert self.seq_start_loc is not None
|
|
||||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
|
||||||
|
|
||||||
assert self.context_lens_tensor is not None
|
|
||||||
assert self.context_lens_tensor.shape == (num_queries, )
|
|
||||||
|
|
||||||
assert self.block_tables is not None
|
assert self.block_tables is not None
|
||||||
assert self.block_tables.shape[0] == num_seqs
|
assert self.block_tables.shape[0] == num_seqs
|
||||||
|
|
||||||
@@ -576,11 +528,6 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
|
|
||||||
max_query_len = max(query_lens)
|
max_query_len = max(query_lens)
|
||||||
decode_query_lens = query_lens[self.num_prefills:]
|
|
||||||
if len(decode_query_lens) > 0:
|
|
||||||
max_decode_query_len = max(decode_query_lens)
|
|
||||||
else:
|
|
||||||
max_decode_query_len = 1
|
|
||||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||||
|
|
||||||
@@ -592,8 +539,6 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|||||||
else:
|
else:
|
||||||
self.attn_mask = None
|
self.attn_mask = None
|
||||||
num_decode_tokens = self.num_decode_tokens
|
num_decode_tokens = self.num_decode_tokens
|
||||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
|
||||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
|
||||||
|
|
||||||
block_tables = make_tensor_with_pad(
|
block_tables = make_tensor_with_pad(
|
||||||
self.block_tables,
|
self.block_tables,
|
||||||
@@ -604,27 +549,16 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|||||||
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
||||||
|
|
||||||
assert device is not None
|
assert device is not None
|
||||||
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
|
||||||
device, self.runner.pin_memory)
|
|
||||||
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
|
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
|
||||||
device, self.runner.pin_memory)
|
device, self.runner.pin_memory)
|
||||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
||||||
self.runner.pin_memory)
|
self.runner.pin_memory)
|
||||||
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
|
||||||
device,
|
|
||||||
self.runner.pin_memory)
|
|
||||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
|
||||||
device, self.runner.pin_memory)
|
|
||||||
placeholder_index_maps = {
|
placeholder_index_maps = {
|
||||||
modality: placeholder_map.index_map()
|
modality: placeholder_map.index_map()
|
||||||
for modality, placeholder_map in
|
for modality, placeholder_map in
|
||||||
self.multimodal_placeholder_maps.items()
|
self.multimodal_placeholder_maps.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
seq_lens_tensor = torch.tensor(seq_lens,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
return AscendMetadata(
|
return AscendMetadata(
|
||||||
num_prefills=self.num_prefills,
|
num_prefills=self.num_prefills,
|
||||||
slot_mapping=slot_mapping_tensor,
|
slot_mapping=slot_mapping_tensor,
|
||||||
@@ -635,12 +569,8 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|||||||
enable_kv_scales_calculation=True,
|
enable_kv_scales_calculation=True,
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
max_decode_query_len=max_decode_query_len,
|
|
||||||
max_prefill_seq_len=max_prefill_seq_len,
|
max_prefill_seq_len=max_prefill_seq_len,
|
||||||
max_decode_seq_len=max_decode_seq_len,
|
max_decode_seq_len=max_decode_seq_len,
|
||||||
query_start_loc=query_start_loc_tensor,
|
|
||||||
seq_start_loc=seq_start_loc_tensor,
|
|
||||||
context_lens_tensor=context_lens_tensor,
|
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
attn_mask=self.attn_mask,
|
attn_mask=self.attn_mask,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import functools
|
import functools
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
|
||||||
|
Union)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -14,19 +15,26 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
|
|||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||||
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
||||||
Logprob, SequenceGroupMetadata, SequenceOutput)
|
Logprob, SequenceGroupMetadata, SequenceOutput)
|
||||||
|
from vllm.utils import current_stream
|
||||||
|
from vllm.worker.model_runner_base import (
|
||||||
|
_init_attn_metadata_from_tensor_dict,
|
||||||
|
_init_frozen_model_input_from_tensor_dict,
|
||||||
|
_init_sampling_metadata_from_tensor_dict)
|
||||||
from vllm.worker.multi_step_model_runner import (ModelOutput,
|
from vllm.worker.multi_step_model_runner import (ModelOutput,
|
||||||
PythonizationCache,
|
PythonizationCache,
|
||||||
StatefulModelInput)
|
StatefulModelInput)
|
||||||
|
|
||||||
from vllm_ascend.utils import current_stream
|
|
||||||
from vllm_ascend.worker.model_runner import (
|
from vllm_ascend.worker.model_runner import (
|
||||||
ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
|
ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=False)
|
@dataclass(frozen=False)
|
||||||
class NPUStatefulModelInput(StatefulModelInput):
|
class StatefulModelInputForNPU(StatefulModelInput):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@@ -40,6 +48,66 @@ class NPUStatefulModelInput(StatefulModelInput):
|
|||||||
torch.npu.Event(blocking=True)
|
torch.npu.Event(blocking=True)
|
||||||
self.step_cuda_events[self.current_step & 1].record(current_stream)
|
self.step_cuda_events[self.current_step & 1].record(current_stream)
|
||||||
|
|
||||||
|
# actual frozen model input dataclass passed to _base_model_runner
|
||||||
|
frozen_model_input: Optional[ModelInputForNPUWithSamplingMetadata] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_broadcasted_tensor_dict(
|
||||||
|
cls,
|
||||||
|
tensor_dict: Dict[str, Any],
|
||||||
|
attn_backend: Optional["AttentionBackend"] = None,
|
||||||
|
) -> "StatefulModelInputForNPU":
|
||||||
|
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
|
||||||
|
if attn_backend is not None:
|
||||||
|
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||||
|
attn_backend, tensor_dict)
|
||||||
|
tensor_dict = _init_frozen_model_input_from_tensor_dict(
|
||||||
|
ModelInputForNPUWithSamplingMetadata, tensor_dict)
|
||||||
|
return cls(**tensor_dict)
|
||||||
|
|
||||||
|
def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool):
|
||||||
|
"""
|
||||||
|
Advancing the datastructures of StatefulModelInput::frozen_model_input
|
||||||
|
is only required when prefills are scheduled with decodes to run in
|
||||||
|
multi-step. This advancement/correction is required to account for
|
||||||
|
the conversion of Prefills to Decodes after the first multi-step.
|
||||||
|
"""
|
||||||
|
if self.current_step != 1 or self.num_single_step_prefills == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
assert self.frozen_model_input is not None
|
||||||
|
fmi = self.frozen_model_input
|
||||||
|
|
||||||
|
# Truncate input_tokens
|
||||||
|
assert fmi.input_tokens is not None
|
||||||
|
assert fmi.input_tokens.shape[0] >= self.num_seqs
|
||||||
|
fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs]
|
||||||
|
|
||||||
|
# Update frozen_model_input::input_positons.
|
||||||
|
assert fmi.input_positions is not None
|
||||||
|
assert fmi.input_positions.shape[0] >= self.num_seqs
|
||||||
|
fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self.
|
||||||
|
num_seqs]
|
||||||
|
|
||||||
|
# Assert unsupported
|
||||||
|
# TODO Uncomment the following codes when NPU supported
|
||||||
|
# assert fmi.lora_mapping is None
|
||||||
|
# assert fmi.lora_requests is not None
|
||||||
|
# assert len(fmi.lora_requests) == 0
|
||||||
|
# assert fmi.prompt_adapter_mapping is None
|
||||||
|
# assert fmi.prompt_adapter_requests is not None
|
||||||
|
# assert len(fmi.prompt_adapter_requests) == 0
|
||||||
|
assert fmi.attn_metadata is not None
|
||||||
|
assert fmi.multi_modal_kwargs is not None
|
||||||
|
assert len(fmi.multi_modal_kwargs) == 0
|
||||||
|
|
||||||
|
self.frozen_model_input = dataclasses.replace(
|
||||||
|
self.frozen_model_input,
|
||||||
|
input_tokens=fmi_new_input_tokens,
|
||||||
|
input_positions=fmi_new_input_positions)
|
||||||
|
|
||||||
|
self.maybe_advance_sampling_metadata(device, pin_memory)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=False)
|
@dataclass(frozen=False)
|
||||||
class NPUModelOutput(ModelOutput):
|
class NPUModelOutput(ModelOutput):
|
||||||
@@ -78,7 +146,7 @@ class NPUModelOutput(ModelOutput):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
|
class MultiStepModelNPURunner(NPUModelRunnerBase[StatefulModelInputForNPU]):
|
||||||
# mypy: enable-error-code=type-var
|
# mypy: enable-error-code=type-var
|
||||||
|
|
||||||
def __init__(self, base_model_runner: NPUModelRunnerBase, *args, **kwargs):
|
def __init__(self, base_model_runner: NPUModelRunnerBase, *args, **kwargs):
|
||||||
@@ -101,7 +169,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
|
|||||||
if self.parallel_config.pipeline_parallel_size == 1 else None
|
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.model
|
return self._base_model_runner.get_model()
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def _copy_stream(self):
|
def _copy_stream(self):
|
||||||
@@ -109,8 +177,8 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
|
|||||||
return torch.npu.Stream()
|
return torch.npu.Stream()
|
||||||
|
|
||||||
def make_model_input_from_broadcasted_tensor_dict(
|
def make_model_input_from_broadcasted_tensor_dict(
|
||||||
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
|
self, tensor_dict: Dict[str, Any]) -> StatefulModelInputForNPU:
|
||||||
model_input = (NPUStatefulModelInput.from_broadcasted_tensor_dict(
|
model_input = (StatefulModelInputForNPU.from_broadcasted_tensor_dict(
|
||||||
tensor_dict,
|
tensor_dict,
|
||||||
attn_backend=self.attn_backend,
|
attn_backend=self.attn_backend,
|
||||||
))
|
))
|
||||||
@@ -121,7 +189,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
|
|||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
virtual_engine: int = 0,
|
virtual_engine: int = 0,
|
||||||
finished_requests_ids: Optional[List[str]] = None
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
) -> StatefulModelInput:
|
) -> StatefulModelInputForNPU:
|
||||||
frozen_model_input: ModelInputForNPUWithSamplingMetadata = \
|
frozen_model_input: ModelInputForNPUWithSamplingMetadata = \
|
||||||
self._base_model_runner.prepare_model_input(
|
self._base_model_runner.prepare_model_input(
|
||||||
seq_group_metadata_list,
|
seq_group_metadata_list,
|
||||||
@@ -135,7 +203,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
|
|||||||
num_seqs = len(frozen_model_input.seq_lens)
|
num_seqs = len(frozen_model_input.seq_lens)
|
||||||
num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills
|
num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills
|
||||||
|
|
||||||
model_input = NPUStatefulModelInput(
|
model_input = StatefulModelInputForNPU(
|
||||||
frozen_model_input=frozen_model_input,
|
frozen_model_input=frozen_model_input,
|
||||||
num_seqs=num_seqs,
|
num_seqs=num_seqs,
|
||||||
num_queries=num_queries,
|
num_queries=num_queries,
|
||||||
@@ -145,7 +213,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
|
|||||||
|
|
||||||
return model_input
|
return model_input
|
||||||
|
|
||||||
def _async_process_outputs(self, model_input: StatefulModelInput,
|
def _async_process_outputs(self, model_input: StatefulModelInputForNPU,
|
||||||
output_proc_callback: Callable):
|
output_proc_callback: Callable):
|
||||||
# Proceed with pythonization and output_proc in order.
|
# Proceed with pythonization and output_proc in order.
|
||||||
# Stop on the first one that fails to pythonize
|
# Stop on the first one that fails to pythonize
|
||||||
@@ -174,7 +242,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
|
|||||||
break
|
break
|
||||||
|
|
||||||
def _final_process_outputs(
|
def _final_process_outputs(
|
||||||
self, model_input: StatefulModelInput,
|
self, model_input: StatefulModelInputForNPU,
|
||||||
output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
|
output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
|
||||||
assert model_input.frozen_model_input is not None
|
assert model_input.frozen_model_input is not None
|
||||||
|
|
||||||
@@ -225,7 +293,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
model_input: StatefulModelInput,
|
model_input: StatefulModelInputForNPU,
|
||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
num_steps: int = 1,
|
num_steps: int = 1,
|
||||||
@@ -382,8 +450,8 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
|
|||||||
assert seq_group.seq_len is None # Decode
|
assert seq_group.seq_len is None # Decode
|
||||||
assert seq_group.query_len is None # Decode
|
assert seq_group.query_len is None # Decode
|
||||||
|
|
||||||
def _advance_step(self, model_input: StatefulModelInput,
|
def _advance_step(self, model_input: StatefulModelInputForNPU,
|
||||||
out: SamplerOutput) -> StatefulModelInput:
|
out: SamplerOutput) -> StatefulModelInputForNPU:
|
||||||
|
|
||||||
model_input.maybe_advance_frozen_model_input(self.device,
|
model_input.maybe_advance_frozen_model_input(self.device,
|
||||||
self.pin_memory)
|
self.pin_memory)
|
||||||
@@ -491,7 +559,7 @@ def deferred_pythonize_logprobs(
|
|||||||
|
|
||||||
|
|
||||||
def _pythonize_sampler_output(
|
def _pythonize_sampler_output(
|
||||||
model_input: StatefulModelInput,
|
model_input: StatefulModelInputForNPU,
|
||||||
output: SamplerOutput,
|
output: SamplerOutput,
|
||||||
pinned_sampled_token_buffer: torch.Tensor,
|
pinned_sampled_token_buffer: torch.Tensor,
|
||||||
sampled_token_ids: torch.Tensor,
|
sampled_token_ids: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user