Upgrade to vLLM 0.17.0 with the corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
+from collections.abc import Callable
 import torch
 import torch.nn.functional as F
@@ -43,7 +44,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
         self.weight.weight_loader = self.weight_loader
         self.variance_epsilon = eps
-        return
     @staticmethod
     def weight_loader(
@@ -56,7 +56,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
         shard_size = loaded_weight.shape[0] // tp_world
         shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
         param.data.copy_(loaded_weight[shard])
-        return
     def _forward(
         self,
@@ -102,6 +101,101 @@ class MiniMaxText01RMSNormTP(CustomOp):
         return q, k
+def clear_linear_attention_cache_for_new_sequences(
+    kv_cache: torch.Tensor,
+    state_indices_tensor: torch.Tensor,
+    attn_metadata: LinearAttentionMetadata,
+) -> None:
+    num_prefills = getattr(attn_metadata, "num_prefills", 0)
+    if num_prefills <= 0:
+        return
+    num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
+    for prefill_idx in range(num_prefills):
+        q_start = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx]
+        q_end = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx + 1]
+        query_len = q_end - q_start
+        context_len = (
+            attn_metadata.seq_lens[num_decode_tokens + prefill_idx] - query_len
+        )
+        if context_len == 0:
+            block_to_clear = state_indices_tensor[num_decode_tokens + prefill_idx]
+            kv_cache[block_to_clear, ...] = 0
+def linear_attention_decode(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slope_rate: torch.Tensor,
+    state_indices_tensor: torch.Tensor,
+    q_start: int = 0,
+    q_end: int | None = None,
+    slot_start: int = 0,
+    slot_end: int | None = None,
+    block_size: int = 32,
+) -> torch.Tensor:
+    q = q[q_start:q_end].unsqueeze(2).contiguous()
+    k = k[q_start:q_end].unsqueeze(2).contiguous()
+    v = v[q_start:q_end].unsqueeze(2).contiguous()
+    slot_id = state_indices_tensor[slot_start:slot_end]
+    return linear_decode_forward_triton(
+        q, k, v, kv_cache, slope_rate, slot_id, block_size
+    )
+def linear_attention_prefill_and_mix(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    kv_cache: torch.Tensor,
+    state_indices_tensor: torch.Tensor,
+    attn_metadata: LinearAttentionMetadata,
+    slope_rate: torch.Tensor,
+    block_size: int,
+    decode_fn: Callable[..., torch.Tensor],
+    prefix_fn: Callable[..., torch.Tensor],
+    layer_idx: int | None = None,
+) -> torch.Tensor:
+    hidden = []
+    for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
+        if _prefill_idx >= len(attn_metadata.query_start_loc):
+            break
+        if _prefill_idx >= len(state_indices_tensor):
+            break
+        offset = attn_metadata.num_decode_tokens
+        _start = attn_metadata.query_start_loc[offset + _prefill_idx]
+        _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
+        slot_id = state_indices_tensor[offset + _prefill_idx]
+        qs = q[_start:_end].transpose(0, 1).contiguous()
+        ks = k[_start:_end].transpose(0, 1).contiguous()
+        vs = v[_start:_end].transpose(0, 1).contiguous()
+        slice_layer_cache = kv_cache[slot_id, ...]
+        out_slice = prefix_fn(
+            qs,
+            ks,
+            vs,
+            slice_layer_cache,
+            slope_rate,
+            block_size,
+            layer_idx=layer_idx,
+        )
+        hidden.append(out_slice.contiguous())
+    if attn_metadata.num_decode_tokens > 0:
+        hidden_decode = decode_fn(
+            q, k, v, kv_cache, state_indices_tensor, attn_metadata
+        )
+        hidden.insert(0, hidden_decode)
+    if not hidden:
+        return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
+    hidden = torch.concat(hidden, dim=0).contiguous()
+    return hidden
 class MiniMaxText01LinearKernel:
     @staticmethod
     def jit_linear_forward_prefix(
@@ -258,50 +352,33 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
     def _prefill_and_mix_infer(
         self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
     ):
-        hidden = []
-        for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
-            if _prefill_idx >= len(attn_metadata.query_start_loc):
-                break
-            if _prefill_idx >= len(state_indices_tensor):
-                break
-            offset = attn_metadata.num_decode_tokens
-            _start = attn_metadata.query_start_loc[offset + _prefill_idx]
-            _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
-            slot_id = state_indices_tensor[offset + _prefill_idx]
-            qs = q[_start:_end].transpose(0, 1).contiguous()
-            ks = k[_start:_end].transpose(0, 1).contiguous()
-            vs = v[_start:_end].transpose(0, 1).contiguous()
-            slice_layer_cache = kv_cache[slot_id, ...]
-            out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
-                qs,
-                ks,
-                vs,
-                slice_layer_cache,
-                self.tp_slope,
-                self.BLOCK,
-                layer_idx=self.layer_idx,
-            )
-            hidden.append(out_slice.contiguous())
-        if attn_metadata.num_decode_tokens > 0:
-            hidden_decode = self._decode_infer(
-                q, k, v, kv_cache, state_indices_tensor, attn_metadata
-            )
-            hidden.insert(0, hidden_decode)
-        if not hidden:
-            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
-        hidden = torch.concat(hidden, dim=0).contiguous()
-        return hidden
+        return linear_attention_prefill_and_mix(
+            q=q,
+            k=k,
+            v=v,
+            kv_cache=kv_cache,
+            state_indices_tensor=state_indices_tensor,
+            attn_metadata=attn_metadata,
+            slope_rate=self.tp_slope,
+            block_size=self.BLOCK,
+            decode_fn=self._decode_infer,
+            prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix,
+            layer_idx=self.layer_idx,
+        )
     def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
-        q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-        k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-        v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-        slot_id = state_indices_tensor[: attn_metadata.num_decodes]
-        hidden = linear_decode_forward_triton(
-            q, k, v, kv_cache, self.tp_slope, slot_id, 32
+        hidden = linear_attention_decode(
+            q,
+            k,
+            v,
+            kv_cache,
+            self.tp_slope,
+            state_indices_tensor,
+            q_start=0,
+            q_end=attn_metadata.num_decode_tokens,
+            slot_start=0,
+            slot_end=attn_metadata.num_decodes,
+            block_size=32,
         )
         return hidden
@@ -338,27 +415,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         if attn_metadata is not None:
             kv_cache = self.kv_cache[forward_context.virtual_engine][0]
             state_indices_tensor = attn_metadata.state_indices_tensor
-            num_prefills = getattr(attn_metadata, "num_prefills", 0)
-            if num_prefills > 0:
-                num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
-                for prefill_idx in range(num_prefills):
-                    q_start = attn_metadata.query_start_loc[
-                        num_decode_tokens + prefill_idx
-                    ]
-                    q_end = attn_metadata.query_start_loc[
-                        num_decode_tokens + prefill_idx + 1
-                    ]
-                    query_len = q_end - q_start
-                    context_len = (
-                        attn_metadata.seq_lens[num_decode_tokens + prefill_idx]
-                        - query_len
-                    )
-                    if context_len == 0:
-                        block_to_clear = state_indices_tensor[
-                            num_decode_tokens + prefill_idx
-                        ]
-                        kv_cache[block_to_clear, ...] = 0
+            clear_linear_attention_cache_for_new_sequences(
+                kv_cache, state_indices_tensor, attn_metadata
+            )
         decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
         if attn_metadata is None:

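The three helpers added above (clear_linear_attention_cache_for_new_sequences, linear_attention_decode, linear_attention_prefill_and_mix) lift logic that previously lived inline in MiniMaxText01LinearAttention, so the class methods now just delegate to them. Below is a minimal, self-contained sketch of the cache-reset rule the first helper implements: a prefill request whose context length is zero is a brand-new sequence, so its linear-attention state block is zeroed before use. The SimpleNamespace stands in for vLLM's LinearAttentionMetadata, and the shapes and field values are illustrative assumptions, not the real objects.

# Sketch only: a stand-in metadata object, not vLLM's LinearAttentionMetadata.
from types import SimpleNamespace

import torch

kv_cache = torch.ones(4, 2, 8, 8)             # (num_blocks, heads, d, d) toy state
state_indices = torch.tensor([2, 3])          # cache block assigned to each prefill request
meta = SimpleNamespace(
    num_prefills=2,
    num_decode_tokens=0,
    query_start_loc=torch.tensor([0, 5, 9]),  # request 0: 5 new tokens, request 1: 4 new tokens
    seq_lens=torch.tensor([5, 12]),           # request 1 already has 8 tokens of prior context
)

for i in range(meta.num_prefills):
    q_start = meta.query_start_loc[meta.num_decode_tokens + i]
    q_end = meta.query_start_loc[meta.num_decode_tokens + i + 1]
    context_len = meta.seq_lens[meta.num_decode_tokens + i] - (q_end - q_start)
    if context_len == 0:                      # brand-new sequence: drop any stale state
        kv_cache[state_indices[meta.num_decode_tokens + i], ...] = 0

print(kv_cache[2].abs().sum().item(), kv_cache[3].abs().sum().item())  # 0.0 128.0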
View File

@@ -271,6 +271,8 @@ class MambaMixer(MambaBase, PluggableLayer):
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         has_initial_states_p = attn_metadata.has_initial_states_p
+        cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
+        last_chunk_indices_p = attn_metadata.last_chunk_indices_p
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -376,6 +378,8 @@ class MambaMixer(MambaBase, PluggableLayer):
                 block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
                 block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
                 initial_state_idx=block_idx_last_computed_token_p,
+                cu_chunk_seqlen=cu_chunk_seqlen_p,
+                last_chunk_indices=last_chunk_indices_p,
             )
             ssm_outputs.append(scan_out_p)

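The MambaMixer change threads two more pieces of prefill metadata, cu_chunk_seqlen_p and last_chunk_indices_p, from attn_metadata into selective_scan_fn. Assuming these carry the cumulative token count at each chunk boundary and the index of each request's last chunk (an assumption modeled on vLLM's chunked Mamba2-style metadata, not something this diff states), a toy construction could look like the sketch below; in the real path the values come from the attention metadata builder, not from a helper like this.

# Illustrative only: assumed semantics of cu_chunk_seqlen / last_chunk_indices.
import torch

def build_chunk_metadata(seq_lens: list[int], chunk_size: int):
    cu_chunk_seqlen = [0]          # running token count at every chunk boundary
    last_chunk_indices = []        # index of the final chunk of each request
    for seq_len in seq_lens:
        n_chunks = -(-seq_len // chunk_size)              # ceil division
        for c in range(n_chunks):
            piece = min(chunk_size, seq_len - c * chunk_size)
            cu_chunk_seqlen.append(cu_chunk_seqlen[-1] + piece)
        last_chunk_indices.append(len(cu_chunk_seqlen) - 2)
    return torch.tensor(cu_chunk_seqlen), torch.tensor(last_chunk_indices)

cu, last = build_chunk_metadata([5, 12], chunk_size=8)
print(cu.tolist(), last.tolist())   # [0, 5, 13, 17] [0, 2]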
View File

@@ -289,9 +289,6 @@ def get_temporal_copy_spec(
)
get_full_copy_spec = get_temporal_copy_spec
class MambaStateCopyFuncCalculator:
@classmethod
def linear_attention_state_copy_func(cls):

View File

@@ -1159,7 +1159,7 @@ def causal_conv1d_update(
f"ERROR: conv_state_indices should have shape ({batch},*) but got {conv_state_indices.shape}"
)
# assert num_cache_lines >= batch
assert num_cache_lines >= batch
assert weight.stride(1) == 1 # Need this
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'

View File

@@ -497,6 +497,8 @@ def selective_scan_fn(
     block_idx_first_scheduled_token=None,
     block_idx_last_scheduled_token=None,
     initial_state_idx=None,
+    cu_chunk_seqlen=None,
+    last_chunk_indices=None,
 ) -> torch.Tensor:
     """
     u: (dim, total_length) for varlen or (batch, dim, seqlen)
@@ -588,6 +590,8 @@ def selective_scan_fn(
         block_idx_first_scheduled_token,
         block_idx_last_scheduled_token,
         initial_state_idx,
+        cu_chunk_seqlen,
+        last_chunk_indices,
     )
     if z is None:
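selective_scan_fn gains two optional keyword arguments, cu_chunk_seqlen and last_chunk_indices, both defaulting to None so existing call sites keep working; they are simply forwarded to the underlying scan op. The stub below only illustrates that backward-compatible signature extension. It is not the real Triton-backed kernel, and the fallback shown for the None case is an assumption made for the sake of the example.

# Signature sketch only; the body is a placeholder, not the real selective scan.
import torch

def selective_scan_fn_stub(
    u: torch.Tensor,
    *,
    cu_chunk_seqlen: torch.Tensor | None = None,
    last_chunk_indices: torch.Tensor | None = None,
) -> torch.Tensor:
    if cu_chunk_seqlen is None:
        # assumed legacy behavior: treat the whole varlen batch as a single chunk
        cu_chunk_seqlen = torch.tensor([0, u.shape[-1]])
        last_chunk_indices = torch.tensor([0])
    return u  # a real implementation would run the chunked scan here

out_old = selective_scan_fn_stub(torch.randn(16, 10))               # legacy call site
out_new = selective_scan_fn_stub(
    torch.randn(16, 10),
    cu_chunk_seqlen=torch.tensor([0, 5, 10]),
    last_chunk_indices=torch.tensor([1]),
)
print(out_old.shape, out_new.shape)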