fix multistep bug, remove useless code (#355)
1. Remove useless code in attention.py.
2. Multi-step now uses StatefulModelInputForNPU and no longer uses StatefulModelInput directly.

Signed-off-by: new-TonyWang <wangtonyyu222@gmail.com>
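For context, the renamed class keeps the same broadcast round-trip as upstream vLLM's `StatefulModelInput`: the driver flattens the stateful input into a tensor dict, and each worker rebuilds it. The sketch below is illustrative only (not code from this patch); `rebuild_on_worker` is a hypothetical helper, and it assumes the inherited `as_broadcastable_tensor_dict` method from vLLM plus the `from_broadcasted_tensor_dict` classmethod added in the diff below.

```python
from typing import Any, Dict

# Assumes StatefulModelInputForNPU (defined in the diff below) is importable.
def rebuild_on_worker(model_input: "StatefulModelInputForNPU",
                      attn_backend) -> "StatefulModelInputForNPU":
    # Driver side: flatten the stateful input into a broadcastable dict
    # (method inherited from vLLM's StatefulModelInput).
    tensor_dict: Dict[str, Any] = model_input.as_broadcastable_tensor_dict()
    # Worker side: rebuild the NPU-specific input; this re-initializes the
    # sampling metadata, attention metadata, and frozen model input, matching
    # from_broadcasted_tensor_dict in the diff.
    return StatefulModelInputForNPU.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=attn_backend)
```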
@@ -1,7 +1,8 @@
 import dataclasses
 import functools
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    Union)
 
 import torch
 from torch import nn
@@ -14,19 +15,26 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceGroupMetadata, SequenceOutput)
-from vllm.utils import current_stream
 from vllm.worker.model_runner_base import (
     _init_attn_metadata_from_tensor_dict,
+    _init_frozen_model_input_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
 from vllm.worker.multi_step_model_runner import (ModelOutput,
                                                  PythonizationCache,
                                                  StatefulModelInput)
 
+from vllm_ascend.utils import current_stream
 from vllm_ascend.worker.model_runner import (
     ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
 
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
 logger = init_logger(__name__)
 
 
 @dataclass(frozen=False)
-class NPUStatefulModelInput(StatefulModelInput):
+class StatefulModelInputForNPU(StatefulModelInput):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -40,6 +48,66 @@ class NPUStatefulModelInput(StatefulModelInput):
             torch.npu.Event(blocking=True)
         self.step_cuda_events[self.current_step & 1].record(current_stream)
+
+    # actual frozen model input dataclass passed to _base_model_runner
+    frozen_model_input: Optional[ModelInputForNPUWithSamplingMetadata] = None
+
+    @classmethod
+    def from_broadcasted_tensor_dict(
+        cls,
+        tensor_dict: Dict[str, Any],
+        attn_backend: Optional["AttentionBackend"] = None,
+    ) -> "StatefulModelInputForNPU":
+        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
+        if attn_backend is not None:
+            tensor_dict = _init_attn_metadata_from_tensor_dict(
+                attn_backend, tensor_dict)
+        tensor_dict = _init_frozen_model_input_from_tensor_dict(
+            ModelInputForNPUWithSamplingMetadata, tensor_dict)
+        return cls(**tensor_dict)
+
+    def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool):
+        """
+        Advancing the datastructures of StatefulModelInput::frozen_model_input
+        is only required when prefills are scheduled with decodes to run in
+        multi-step. This advancement/correction is required to account for
+        the conversion of Prefills to Decodes after the first multi-step.
+        """
+        if self.current_step != 1 or self.num_single_step_prefills == 0:
+            return
+
+        assert self.frozen_model_input is not None
+        fmi = self.frozen_model_input
+
+        # Truncate input_tokens
+        assert fmi.input_tokens is not None
+        assert fmi.input_tokens.shape[0] >= self.num_seqs
+        fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs]
+
+        # Update frozen_model_input::input_positions.
+        assert fmi.input_positions is not None
+        assert fmi.input_positions.shape[0] >= self.num_seqs
+        fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self.
+                                                                    num_seqs]
+
+        # Assert unsupported
+        # TODO: Uncomment the following checks once they are supported on NPU.
+        # assert fmi.lora_mapping is None
+        # assert fmi.lora_requests is not None
+        # assert len(fmi.lora_requests) == 0
+        # assert fmi.prompt_adapter_mapping is None
+        # assert fmi.prompt_adapter_requests is not None
+        # assert len(fmi.prompt_adapter_requests) == 0
+        assert fmi.attn_metadata is not None
+        assert fmi.multi_modal_kwargs is not None
+        assert len(fmi.multi_modal_kwargs) == 0
+
+        self.frozen_model_input = dataclasses.replace(
+            self.frozen_model_input,
+            input_tokens=fmi_new_input_tokens,
+            input_positions=fmi_new_input_positions)
+
+        self.maybe_advance_sampling_metadata(device, pin_memory)
 
 
 @dataclass(frozen=False)
 class NPUModelOutput(ModelOutput):
@@ -78,7 +146,7 @@ class NPUModelOutput(ModelOutput):
         return True
 
 
-class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
+class MultiStepModelNPURunner(NPUModelRunnerBase[StatefulModelInputForNPU]):
     # mypy: enable-error-code=type-var
 
     def __init__(self, base_model_runner: NPUModelRunnerBase, *args, **kwargs):
@@ -101,7 +169,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
            if self.parallel_config.pipeline_parallel_size == 1 else None
 
     def get_model(self) -> nn.Module:
-        return self.model
+        return self._base_model_runner.get_model()
 
     @functools.cached_property
     def _copy_stream(self):
@@ -109,8 +177,8 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
         return torch.npu.Stream()
 
     def make_model_input_from_broadcasted_tensor_dict(
-            self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
-        model_input = (NPUStatefulModelInput.from_broadcasted_tensor_dict(
+            self, tensor_dict: Dict[str, Any]) -> StatefulModelInputForNPU:
+        model_input = (StatefulModelInputForNPU.from_broadcasted_tensor_dict(
             tensor_dict,
             attn_backend=self.attn_backend,
         ))
@@ -121,7 +189,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
-    ) -> StatefulModelInput:
+    ) -> StatefulModelInputForNPU:
        frozen_model_input: ModelInputForNPUWithSamplingMetadata = \
            self._base_model_runner.prepare_model_input(
                seq_group_metadata_list,
@@ -135,7 +203,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
         num_seqs = len(frozen_model_input.seq_lens)
         num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills
 
-        model_input = NPUStatefulModelInput(
+        model_input = StatefulModelInputForNPU(
             frozen_model_input=frozen_model_input,
             num_seqs=num_seqs,
             num_queries=num_queries,
@@ -145,7 +213,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
 
         return model_input
 
-    def _async_process_outputs(self, model_input: StatefulModelInput,
+    def _async_process_outputs(self, model_input: StatefulModelInputForNPU,
                                output_proc_callback: Callable):
         # Proceed with pythonization and output_proc in order.
         # Stop on the first one that fails to pythonize
@@ -174,7 +242,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
                 break
 
     def _final_process_outputs(
-            self, model_input: StatefulModelInput,
+            self, model_input: StatefulModelInputForNPU,
             output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
         assert model_input.frozen_model_input is not None
 
@@ -225,7 +293,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
     @torch.inference_mode()
     def execute_model(
         self,
-        model_input: StatefulModelInput,
+        model_input: StatefulModelInputForNPU,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
@@ -382,8 +450,8 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
             assert seq_group.seq_len is None  # Decode
             assert seq_group.query_len is None  # Decode
 
-    def _advance_step(self, model_input: StatefulModelInput,
-                      out: SamplerOutput) -> StatefulModelInput:
+    def _advance_step(self, model_input: StatefulModelInputForNPU,
+                      out: SamplerOutput) -> StatefulModelInputForNPU:
 
         model_input.maybe_advance_frozen_model_input(self.device,
                                                      self.pin_memory)
@@ -491,7 +559,7 @@ def deferred_pythonize_logprobs(
 
 
 def _pythonize_sampler_output(
-    model_input: StatefulModelInput,
+    model_input: StatefulModelInputForNPU,
     output: SamplerOutput,
     pinned_sampled_token_buffer: torch.Tensor,
     sampled_token_ids: torch.Tensor,
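The prefill-to-decode correction in `maybe_advance_frozen_model_input` above boils down to truncating the flattened token and position buffers once every sequence has become a single-token decode. A self-contained toy illustration of that slicing (all values invented for demonstration; not code from this patch):

```python
import torch

# During the prefill step the flattened buffers hold one entry per prompt
# token; once those sequences convert to decodes, only one row per sequence
# (num_seqs rows) is consumed, so both buffers are sliced down in lockstep.
num_seqs = 3  # three sequences, all decodes after the first multi-step
input_tokens = torch.tensor([101, 7592, 2088, 102, 2054, 2003, 1996, 0])
input_positions = torch.arange(input_tokens.shape[0])

assert input_tokens.shape[0] >= num_seqs  # same guard as in the diff
new_input_tokens = input_tokens[:num_seqs]
new_input_positions = input_positions[:num_seqs]
print(new_input_tokens)     # tensor([ 101, 7592, 2088])
print(new_input_positions)  # tensor([0, 1, 2])
```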