Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -37,7 +37,7 @@ def fused_recurrent_kda_fwd(
     scale: float,
     initial_state: torch.Tensor,
     inplace_final_state: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     ssm_state_indices: torch.Tensor | None = None,
     num_accepted_tokens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
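The same one-line change repeats through every hunk below: the cu_seqlens annotation is loosened from torch.LongTensor to torch.Tensor. A minimal sketch of mine (not part of the commit) showing why the broader hint is the accurate one:

import torch

# Hedged illustration: torch.LongTensor only matches CPU int64 tensors,
# while cu_seqlens may in practice arrive as e.g. int32. The broader
# torch.Tensor hint covers both.
cu_int64 = torch.tensor([0, 5, 8, 15], dtype=torch.int64)
cu_int32 = torch.tensor([0, 5, 8, 15], dtype=torch.int32)

assert isinstance(cu_int64, torch.LongTensor)      # CPU int64 matches
assert not isinstance(cu_int32, torch.LongTensor)  # int32 does not
assert isinstance(cu_int32, torch.Tensor)          # but it is a Tensor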
@@ -115,7 +115,7 @@ def fused_recurrent_kda(
     initial_state: torch.Tensor = None,
     inplace_final_state: bool = True,
     use_qk_l2norm_in_kernel: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     ssm_state_indices: torch.LongTensor | None = None,
     **kwargs,
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -692,7 +692,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
     gk: torch.Tensor | None = None,
     beta: torch.Tensor | None = None,
     scale: float | None = None,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     chunk_size: int = 64,
     output_dtype: torch.dtype = torch.float32,
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -706,7 +706,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
             The beta tensor of shape `[B, T, H]`.
         gk (torch.Tensor):
             The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
-        cu_seqlens (torch.LongTensor):
+        cu_seqlens (torch.Tensor):
             The cumulative sequence lengths of the input tensor.
             Default: None
         chunk_size (int):
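As the docstring says, cu_seqlens holds cumulative sequence lengths for the packed input. A hedged sketch of the usual construction, assuming the common convention of a leading zero plus one boundary per sequence:

import torch

# Hedged sketch of building cu_seqlens for a varlen batch of three
# sequences with lengths 5, 3 and 7.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
# cu_seqlens == tensor([0, 5, 8, 15]); sequence i spans tokens
# cu_seqlens[i]:cu_seqlens[i + 1] of the packed input.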
@@ -936,7 +936,7 @@ def recompute_w_u_fwd(
     A: torch.Tensor,
     q: torch.Tensor | None = None,
     gk: torch.Tensor | None = None,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
     BT = A.shape[-1]
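The body line above uses a starred unpack to read all five dimensions in one statement; a minimal self-contained illustration:

import torch

# With k of shape [B, T, H, K] and v of shape [B, T, H, V], the starred
# unpack recovers all five dimensions at once.
k = torch.empty(2, 16, 4, 64)
v = torch.empty(2, 16, 4, 128)
B, T, H, K, V = *k.shape, v.shape[-1]
assert (B, T, H, K, V) == (2, 16, 4, 64, 128)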
@@ -1104,7 +1104,7 @@ def chunk_gla_fwd_o_gk(
     h: torch.Tensor,
     o: torch.Tensor,
     scale: float,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     chunk_size: int = 64,
 ):
     B, T, H, K, V = *q.shape, v.shape[-1]
@@ -1148,7 +1148,7 @@ def chunk_kda_fwd(
     scale: float,
     initial_state: torch.Tensor,
     output_final_state: bool,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
 ):
     chunk_size = 64
     g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
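chunk_local_cumsum produces the per-chunk cumulative gate sums referenced in the docstring above. A semantic sketch only, assuming a dense [B, T, H, K] gate with T divisible by chunk_size (the real kernel is fused and also handles the cu_seqlens varlen path):

import torch

# Reference semantics: cumulative sum along time, restarting at every
# chunk boundary.
def chunk_local_cumsum_ref(g: torch.Tensor, chunk_size: int) -> torch.Tensor:
    B, T, H, K = g.shape
    g = g.view(B, T // chunk_size, chunk_size, H, K)
    return torch.cumsum(g, dim=2).view(B, T, H, K)  # restarts per chunk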
@@ -1208,7 +1208,7 @@ def chunk_kda(
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
     use_qk_l2norm_in_kernel: bool = False,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     **kwargs,
 ):
     if scale is None: