[Feature] adapt to uva buffer and main2main (#6657)
### What this PR does / why we need it?
vllm model runner v2 use uva buffer to prepare input data, but npu
doesn't support uva yet, this pr implement a uvawrapper class to mimic
gpu's uva backend. what's more, this pr make some modifications to adapt
to the newer main branch.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM main:
13397841ab
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -42,17 +42,21 @@ def get_attn_mask_builder(device: torch.device):
|
||||
|
||||
|
||||
def build_attn_metadata(
|
||||
*,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
query_start_loc_gpu: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
max_query_len: int,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_np: np.ndarray,
|
||||
num_computed_tokens_cpu: torch.Tensor | None,
|
||||
max_seq_len: int,
|
||||
block_tables: Sequence[torch.Tensor],
|
||||
slot_mappings: torch.Tensor,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
# extra attributes for ascend npus.
|
||||
seq_lens_np: np.ndarray | None = None,
|
||||
num_computed_tokens_cpu: torch.Tensor | None = None,
|
||||
positions: torch.Tensor | None = None,
|
||||
attn_state: Any | None = None,
|
||||
graph_pad_size: int = -1,
|
||||
@@ -61,9 +65,13 @@ def build_attn_metadata(
|
||||
) -> dict[str, Any]:
|
||||
"""Build attention metadata for Ascend NPUs."""
|
||||
# TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
|
||||
max_query_len = int(query_start_loc_cpu.max())
|
||||
seq_lens_cpu = torch.from_numpy(seq_lens_np)
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
|
||||
# seq_lens_np is used for ascend npus, it maybe None in spec_decode case,
|
||||
# we fill it with max_seq_len in case `attn_metadata_builder.build` raise
|
||||
# an error.
|
||||
if seq_lens_np is None:
|
||||
seq_lens_np = np.full(num_reqs, max_seq_len, dtype=np.int32)
|
||||
seq_lens_cpu = torch.from_numpy(seq_lens_np)[:num_reqs]
|
||||
# torch_npu._reshape_and_cache operator requires slot_mappings to
|
||||
# be torch.int32.
|
||||
slot_mappings = slot_mappings.to(torch.int32)
|
||||
@@ -77,7 +85,7 @@ def build_attn_metadata(
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens_cpu=seq_lens_cpu[:num_reqs],
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
seq_lens=seq_lens[:num_reqs],
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
|
||||
Reference in New Issue
Block a user