init v0.11.0rc0
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -21,6 +25,13 @@ class AscendCommonAttentionMetadata:
|
||||
"""(batch_size,), the length of each request including both computed tokens
|
||||
and newly scheduled tokens"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
"""same to seq_lens_cpu, for compatibility with some new attn metadata
|
||||
(such as GDN)."""
|
||||
|
||||
num_computed_tokens_cpu: torch.Tensor
|
||||
"""(batch_size,), the number of computed tokens for each request"""
|
||||
|
||||
num_reqs: int
|
||||
"""Number of requests"""
|
||||
num_actual_tokens: int
|
||||
@@ -34,7 +45,7 @@ class AscendCommonAttentionMetadata:
|
||||
|
||||
block_table_tensor: torch.Tensor
|
||||
|
||||
slot_mapping_cpu: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
actual_seq_lengths_q: list[int]
|
||||
|
||||
@@ -93,3 +104,34 @@ def split_decodes_and_prefills(
|
||||
num_decode_tokens = query_start_loc[first_prefill].item()
|
||||
num_prefill_tokens = num_tokens - num_decode_tokens
|
||||
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
|
||||
|
||||
|
||||
def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return
|
||||
|
||||
connector = get_kv_transfer_group()
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
return
|
||||
# TODO: assert ascendMetadata
|
||||
connector.wait_for_layer_load(layer_name)
|
||||
|
||||
|
||||
def maybe_save_kv_layer_to_connector(
|
||||
layer_name: str,
|
||||
kv_cache_layer: List[torch.Tensor],
|
||||
):
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return
|
||||
|
||||
connector = get_kv_transfer_group()
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
return
|
||||
# TODO: assert ascendMetadata
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
|
||||
|
||||
Reference in New Issue
Block a user