Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -59,6 +59,15 @@ class BaseMambaAttentionMetadata:
|
||||
# The following tensor is only used for prefix caching in align mode
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
||||
# each chunk, its offsets into the varlen sequence dimension. It is defined
|
||||
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
||||
# cu_chunk_seqlen_p[i+1].
|
||||
cu_chunk_seqlen_p: torch.Tensor | None = None
|
||||
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||
# index of the last chunk for every sequence in the (prefill) batch.
|
||||
last_chunk_indices_p: torch.Tensor | None = None
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
@@ -185,6 +194,118 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
common_attn_metadata, num_accepted_tokens=num_accepted_tokens
|
||||
)
|
||||
|
||||
def _compute_chunk_metadata(
    self,
    chunk_size: int,
    num_prefills: int,
    num_computed_tokens_p_cpu: torch.Tensor,
    query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
    """
    Compute chunk-specific metadata for Mamba models.

    The code below carefully constructs the chunks such that:
    1. Chunks contain tokens from a *single* sequence only.
    2. For every sequence, we are guaranteed that we can
       retrieve the mamba state *every* chunk_size tokens.
    Constraint (1) dramatically simplifies the mamba kernels.
    Constraint (2) dramatically simplifies the implementation
    of prefix caching for mamba (wip). We need to take care
    of the interaction with chunked prefill in order to
    satisfy constraint (2).

    Args:
        chunk_size: Number of tokens per chunk (assumed to be positive).
        num_prefills: Number of prefill requests to process; only the
            first ``num_prefills`` entries of the inputs are read.
        num_computed_tokens_p_cpu: Per-request count of tokens that were
            already computed before this step.
        query_start_loc_p_cpu: Cumulative start offsets of the new tokens
            of each request; entry ``i + 1`` minus entry ``i`` is the
            number of new tokens of request ``i``.

    Returns:
        A tuple ``(cu_chunk_seqlen, seq_idx, last_chunk_indices)`` where
        ``cu_chunk_seqlen`` holds the start offset of every chunk followed
        by the total number of tokens (``nchunks + 1`` entries),
        ``seq_idx`` holds the request index owning each chunk, and
        ``last_chunk_indices`` holds, per request, the index of its last
        chunk.
    """
    # Convert to plain Python ints once, instead of indexing the tensors
    # and calling .item() several times per request inside the loop.
    computed_tokens = num_computed_tokens_p_cpu.tolist()
    query_start_loc = query_start_loc_p_cpu.tolist()

    cu_chunk_seqlen: list[int] = []
    seq_idx: list[int] = []
    last_chunk_indices: list[int] = []
    seqlen_pos = 0

    for req_idx in range(num_prefills):
        this_num_computed = computed_tokens[req_idx]
        this_new_tokens = query_start_loc[req_idx + 1] - query_start_loc[req_idx]

        # if computed tokens are not chunk-aligned, use the first
        # chunk to finish it off
        misaligned = this_num_computed % chunk_size
        if misaligned != 0:
            seq_idx.append(req_idx)
            cu_chunk_seqlen.append(seqlen_pos)
            # tokens needed to reach the next chunk boundary, capped by
            # the number of new tokens actually available
            chunk_len = min(chunk_size - misaligned, this_new_tokens)
            seqlen_pos += chunk_len
            this_new_tokens -= chunk_len

        # ceil(this_new_tokens / chunk_size) using integer arithmetic
        n_chunks = -(-this_new_tokens // chunk_size)
        for _ in range(n_chunks):
            seq_idx.append(req_idx)
            cu_chunk_seqlen.append(seqlen_pos)
            chunk_len = min(chunk_size, this_new_tokens)
            seqlen_pos += chunk_len
            this_new_tokens -= chunk_len

        # every new token of this request must belong to exactly one chunk
        assert this_new_tokens == 0
        last_chunk_indices.append(len(cu_chunk_seqlen) - 1)

    # closing offset so that chunk i spans cu_chunk_seqlen[i]..[i + 1]
    cu_chunk_seqlen.append(seqlen_pos)

    return cu_chunk_seqlen, seq_idx, last_chunk_indices
|
||||
|
||||
def _build_chunk_metadata_tensors(
    self,
    chunk_size: int,
    common: M,
    common_attn_metadata: CommonAttentionMetadata,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build the chunk metadata for the prefill requests as int32 tensors.

    The per-request inputs are sliced on the host, handed to
    ``_compute_chunk_metadata`` and the three resulting lists are turned
    into tensors placed on the device of
    ``common_attn_metadata.query_start_loc``.

    Returns (cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p).
    """
    total_reqs = common.num_reqs
    n_prefills = common.num_prefills
    n_decode_tokens = common.num_decode_tokens

    # The slicing below takes the trailing `n_prefills` entries, i.e. it
    # presumably relies on prefills being ordered last in the batch.
    computed_cpu = common_attn_metadata.compute_num_computed_tokens().cpu()
    computed_prefill_cpu = computed_cpu[total_reqs - n_prefills : total_reqs]

    # Shift the prefill offsets down by the number of decode tokens.
    prefill_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[
        -n_prefills - 1 :
    ]
    prefill_start_loc_cpu = prefill_start_loc_cpu - n_decode_tokens

    chunk_lists = self._compute_chunk_metadata(
        chunk_size,
        n_prefills,
        computed_prefill_cpu,
        prefill_start_loc_cpu,
    )

    target_device = common_attn_metadata.query_start_loc.device
    cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p = (
        torch.as_tensor(values, device=target_device, dtype=torch.int32)
        for values in chunk_lists
    )
    return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p
|
||||
|
||||
def _compute_prefix_caching_block_indices(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
|
||||
Reference in New Issue
Block a user