Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import math
+from collections.abc import Callable

 import torch
 import torch.nn.functional as F
@@ -43,7 +44,6 @@ class MiniMaxText01RMSNormTP(CustomOp):

         self.weight.weight_loader = self.weight_loader
         self.variance_epsilon = eps
-        return

     @staticmethod
     def weight_loader(
@@ -56,7 +56,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
         shard_size = loaded_weight.shape[0] // tp_world
         shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
         param.data.copy_(loaded_weight[shard])
-        return

     def _forward(
         self,
@@ -102,6 +101,101 @@ class MiniMaxText01RMSNormTP(CustomOp):
         return q, k


+def clear_linear_attention_cache_for_new_sequences(
+    kv_cache: torch.Tensor,
+    state_indices_tensor: torch.Tensor,
+    attn_metadata: LinearAttentionMetadata,
+) -> None:
+    num_prefills = getattr(attn_metadata, "num_prefills", 0)
+    if num_prefills <= 0:
+        return
+
+    num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
+    for prefill_idx in range(num_prefills):
+        q_start = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx]
+        q_end = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx + 1]
+        query_len = q_end - q_start
+        context_len = (
+            attn_metadata.seq_lens[num_decode_tokens + prefill_idx] - query_len
+        )
+        if context_len == 0:
+            block_to_clear = state_indices_tensor[num_decode_tokens + prefill_idx]
+            kv_cache[block_to_clear, ...] = 0
+
+
+def linear_attention_decode(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slope_rate: torch.Tensor,
+    state_indices_tensor: torch.Tensor,
+    q_start: int = 0,
+    q_end: int | None = None,
+    slot_start: int = 0,
+    slot_end: int | None = None,
+    block_size: int = 32,
+) -> torch.Tensor:
+    q = q[q_start:q_end].unsqueeze(2).contiguous()
+    k = k[q_start:q_end].unsqueeze(2).contiguous()
+    v = v[q_start:q_end].unsqueeze(2).contiguous()
+    slot_id = state_indices_tensor[slot_start:slot_end]
+    return linear_decode_forward_triton(
+        q, k, v, kv_cache, slope_rate, slot_id, block_size
+    )
+
+
+def linear_attention_prefill_and_mix(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    kv_cache: torch.Tensor,
+    state_indices_tensor: torch.Tensor,
+    attn_metadata: LinearAttentionMetadata,
+    slope_rate: torch.Tensor,
+    block_size: int,
+    decode_fn: Callable[..., torch.Tensor],
+    prefix_fn: Callable[..., torch.Tensor],
+    layer_idx: int | None = None,
+) -> torch.Tensor:
+    hidden = []
+    for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
+        if _prefill_idx >= len(attn_metadata.query_start_loc):
+            break
+        if _prefill_idx >= len(state_indices_tensor):
+            break
+        offset = attn_metadata.num_decode_tokens
+        _start = attn_metadata.query_start_loc[offset + _prefill_idx]
+        _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
+        slot_id = state_indices_tensor[offset + _prefill_idx]
+        qs = q[_start:_end].transpose(0, 1).contiguous()
+        ks = k[_start:_end].transpose(0, 1).contiguous()
+        vs = v[_start:_end].transpose(0, 1).contiguous()
+        slice_layer_cache = kv_cache[slot_id, ...]
+        out_slice = prefix_fn(
+            qs,
+            ks,
+            vs,
+            slice_layer_cache,
+            slope_rate,
+            block_size,
+            layer_idx=layer_idx,
+        )
+        hidden.append(out_slice.contiguous())
+
+    if attn_metadata.num_decode_tokens > 0:
+        hidden_decode = decode_fn(
+            q, k, v, kv_cache, state_indices_tensor, attn_metadata
+        )
+        hidden.insert(0, hidden_decode)
+
+    if not hidden:
+        return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
+
+    hidden = torch.concat(hidden, dim=0).contiguous()
+    return hidden
+
+
 class MiniMaxText01LinearKernel:
     @staticmethod
     def jit_linear_forward_prefix(
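The hunk above lifts three module-level helpers out of MiniMaxText01LinearAttention so the prefill-mixing, decode, and cache-clearing logic can be reused and tested without an instance. As a quick illustration of the cache-clearing rule (a state slot is zeroed only when a prefill request has no prior context), here is a standalone sketch on dummy CPU tensors; SimpleNamespace and clear_new_sequence_states are illustrative stand-ins for LinearAttentionMetadata and the new helper, not part of this commit:

# Standalone sketch: mirrors clear_linear_attention_cache_for_new_sequences().
# SimpleNamespace stands in for vLLM's LinearAttentionMetadata here.
from types import SimpleNamespace

import torch


def clear_new_sequence_states(kv_cache, state_indices_tensor, attn_metadata):
    # Same rule as the diff: zero a request's state slot only when its
    # prefill has no prior context (context_len == 0, i.e. a new sequence).
    num_decode_tokens = attn_metadata.num_decode_tokens
    for prefill_idx in range(attn_metadata.num_prefills):
        q_start = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx]
        q_end = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx + 1]
        context_len = attn_metadata.seq_lens[num_decode_tokens + prefill_idx] - (
            q_end - q_start
        )
        if context_len == 0:
            kv_cache[state_indices_tensor[num_decode_tokens + prefill_idx], ...] = 0


# Two prefill requests: request 0 is brand new (seq_len == query_len),
# request 1 resumes with 5 tokens of prior context.
meta = SimpleNamespace(
    num_prefills=2,
    num_decode_tokens=0,
    query_start_loc=torch.tensor([0, 3, 6]),
    seq_lens=torch.tensor([3, 8]),
)
kv_cache = torch.ones(4, 2, 2)        # 4 state slots
state_indices = torch.tensor([1, 3])  # slot assigned to each request
clear_new_sequence_states(kv_cache, state_indices, meta)
assert kv_cache[1].sum() == 0  # new sequence: slot cleared
assert kv_cache[3].sum() != 0  # resumed sequence: slot kept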
@@ -258,50 +352,33 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
     def _prefill_and_mix_infer(
         self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
     ):
-        hidden = []
-        for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
-            if _prefill_idx >= len(attn_metadata.query_start_loc):
-                break
-            if _prefill_idx >= len(state_indices_tensor):
-                break
-            offset = attn_metadata.num_decode_tokens
-            _start = attn_metadata.query_start_loc[offset + _prefill_idx]
-            _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
-            slot_id = state_indices_tensor[offset + _prefill_idx]
-            qs = q[_start:_end].transpose(0, 1).contiguous()
-            ks = k[_start:_end].transpose(0, 1).contiguous()
-            vs = v[_start:_end].transpose(0, 1).contiguous()
-            slice_layer_cache = kv_cache[slot_id, ...]
-
-            out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
-                qs,
-                ks,
-                vs,
-                slice_layer_cache,
-                self.tp_slope,
-                self.BLOCK,
-                layer_idx=self.layer_idx,
-            )
-            hidden.append(out_slice.contiguous())
-        if attn_metadata.num_decode_tokens > 0:
-            hidden_decode = self._decode_infer(
-                q, k, v, kv_cache, state_indices_tensor, attn_metadata
-            )
-            hidden.insert(0, hidden_decode)
-
-        if not hidden:
-            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
-
-        hidden = torch.concat(hidden, dim=0).contiguous()
-        return hidden
+        return linear_attention_prefill_and_mix(
+            q=q,
+            k=k,
+            v=v,
+            kv_cache=kv_cache,
+            state_indices_tensor=state_indices_tensor,
+            attn_metadata=attn_metadata,
+            slope_rate=self.tp_slope,
+            block_size=self.BLOCK,
+            decode_fn=self._decode_infer,
+            prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix,
+            layer_idx=self.layer_idx,
+        )

     def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
-        q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-        k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-        v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-        slot_id = state_indices_tensor[: attn_metadata.num_decodes]
-        hidden = linear_decode_forward_triton(
-            q, k, v, kv_cache, self.tp_slope, slot_id, 32
+        hidden = linear_attention_decode(
+            q,
+            k,
+            v,
+            kv_cache,
+            self.tp_slope,
+            state_indices_tensor,
+            q_start=0,
+            q_end=attn_metadata.num_decode_tokens,
+            slot_start=0,
+            slot_end=attn_metadata.num_decodes,
+            block_size=32,
         )
         return hidden

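After this refactor, _prefill_and_mix_infer and _decode_infer are thin wrappers: the instance-bound pieces (self._decode_infer, the kernel's jit_linear_forward_prefix, self.tp_slope, self.BLOCK) travel into the shared helpers as plain arguments. Here is a minimal sketch of that callable-injection pattern, using hypothetical names (mix_outputs, TinyAttention) and assuming nothing beyond PyTorch:

# Illustrative only: shows why passing decode_fn into a module-level helper
# keeps the helper a pure function of its inputs.
from collections.abc import Callable

import torch


def mix_outputs(
    chunks: list[torch.Tensor],
    decode_fn: Callable[[], torch.Tensor] | None = None,
) -> torch.Tensor:
    # Shared helper: no self, so it can be unit-tested and reused
    # without instantiating the attention module.
    hidden = list(chunks)
    if decode_fn is not None:
        hidden.insert(0, decode_fn())  # decode rows go first, as in the diff
    if not hidden:
        return torch.empty(0)
    return torch.concat(hidden, dim=0).contiguous()


class TinyAttention:
    def _decode(self) -> torch.Tensor:
        return torch.zeros(2, 4)

    def forward(self, prefill_chunks: list[torch.Tensor]) -> torch.Tensor:
        # The method stays a thin wrapper that injects its bound method,
        # mirroring decode_fn=self._decode_infer above.
        return mix_outputs(prefill_chunks, decode_fn=self._decode)


out = TinyAttention().forward([torch.ones(3, 4)])
assert out.shape == (5, 4)  # 2 decode rows + 3 prefill rows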
@@ -338,27 +415,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         if attn_metadata is not None:
             kv_cache = self.kv_cache[forward_context.virtual_engine][0]
             state_indices_tensor = attn_metadata.state_indices_tensor
-
-            num_prefills = getattr(attn_metadata, "num_prefills", 0)
-            if num_prefills > 0:
-                num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
-                for prefill_idx in range(num_prefills):
-                    q_start = attn_metadata.query_start_loc[
-                        num_decode_tokens + prefill_idx
-                    ]
-                    q_end = attn_metadata.query_start_loc[
-                        num_decode_tokens + prefill_idx + 1
-                    ]
-                    query_len = q_end - q_start
-                    context_len = (
-                        attn_metadata.seq_lens[num_decode_tokens + prefill_idx]
-                        - query_len
-                    )
-                    if context_len == 0:
-                        block_to_clear = state_indices_tensor[
-                            num_decode_tokens + prefill_idx
-                        ]
-                        kv_cache[block_to_clear, ...] = 0
+            clear_linear_attention_cache_for_new_sequences(
+                kv_cache, state_indices_tensor, attn_metadata
+            )

         decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
         if attn_metadata is None:
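Note that linear_attention_decode generalizes the old hard-coded [: attn_metadata.num_decode_tokens] indexing: q_end and slot_end default to None, so the helper takes the full batch unless a decode window is given, and unsqueeze(2) restores the singleton sequence dimension the kernel expects. A hedged sketch of that shape handling, with stub_kernel standing in for linear_decode_forward_triton (the real Triton kernel needs a GPU):

# Sketch of the slice-then-delegate shape handling in linear_attention_decode().
import torch


def stub_kernel(q, k, v, kv_cache, slope_rate, slot_id, block_size):
    # The real kernel returns one hidden row per decode token.
    return torch.zeros(q.shape[0], q.shape[-1])


def decode_slice(q, k, v, kv_cache, slope_rate, state_indices,
                 q_start=0, q_end=None, slot_start=0, slot_end=None,
                 block_size=32):
    # As in the diff: pick the decode tokens, then add the singleton
    # sequence dim ([tokens, heads, dim] -> [tokens, heads, 1, dim]).
    q = q[q_start:q_end].unsqueeze(2).contiguous()
    k = k[q_start:q_end].unsqueeze(2).contiguous()
    v = v[q_start:q_end].unsqueeze(2).contiguous()
    slot_id = state_indices[slot_start:slot_end]
    return stub_kernel(q, k, v, kv_cache, slope_rate, slot_id, block_size)


# A 5-token batch whose first 2 tokens are decode tokens, one slot each.
q = k = v = torch.randn(5, 8, 16)
out = decode_slice(q, k, v, torch.zeros(4, 8, 16, 16), torch.ones(8),
                   torch.tensor([0, 2]), q_end=2, slot_end=2)
assert out.shape == (2, 16)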