diff --git a/vllm_kunlun/ops/mamba/causal_conv1d.py b/vllm_kunlun/ops/mamba/causal_conv1d.py index 1e48e36..9080507 100644 --- a/vllm_kunlun/ops/mamba/causal_conv1d.py +++ b/vllm_kunlun/ops/mamba/causal_conv1d.py @@ -8,9 +8,11 @@ from typing import Optional, Union import numpy as np import torch +import torch.nn.functional as F from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.triton_utils import tl, triton +import xtorch_ops @triton.jit() @@ -357,7 +359,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching tl.store(o_ptrs, acc, mask=mask_1d) -def causal_conv1d_fn( +def causal_conv1d_fn_triton( x: torch.Tensor, weight: torch.Tensor, bias: Union[torch.Tensor, None], @@ -623,6 +625,124 @@ def causal_conv1d_fn( ) return out + +def causal_conv1d_single( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + + if initial_states is None: + out = F.conv1d(x, + weight.unsqueeze(1), + bias, + padding=width - 1, + groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out[..., :(width - 1)].copy_(final_states) + else: + final_states_out = final_states + out = (out if activation 
is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + out_ref = [] + out_ref_b = [] + seqlens = query_start_loc[1:] - query_start_loc[:-1] + seqlens = seqlens.tolist() + splits = torch.split(x, seqlens, dim=-1) + + 
for i in range(len(seqlens)): x_s = splits[i] if cache_indices[i] == pad_slot_id: continue out_ref_b.append( causal_conv1d_single( x_s, weight, bias, activation=activation, return_final_states=True, final_states_out=conv_states[cache_indices[i]].unsqueeze(0), initial_states=conv_states[cache_indices[i]] if has_initial_state[i] else None)) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1)) out_ref_tensor = torch.cat(out_ref, dim=0) return out_ref_tensor + + @triton.jit() def _causal_conv1d_update_kernel_xpu( # Pointers to matrices @@ -1075,6 +1195,35 @@ def _causal_conv1d_update_kernel( tl.store(o_ptrs, acc, mask=mask_1d) +def torch_causal_conv1d_update( + hidden_states, + conv_state, + weight, + bias=None, + activation=None, + conv_state_indices=None +): + _, hidden_size, seq_len = hidden_states.shape + tmp_conv_state = conv_state[conv_state_indices] + state_len = tmp_conv_state.shape[-1] + + hidden_states_new = torch.cat([tmp_conv_state, hidden_states], dim=-1).to(weight.dtype) + cast_conv_state = conv_state.unsqueeze(0) + tmp_hidden_states = hidden_states_new[:, :, -state_len:] + ori_shape = tmp_hidden_states.shape + tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(ori_shape) + xtorch_ops.reshape_and_cache_flash( + tmp_hidden_states, + tmp_hidden_states, + cast_conv_state, + cast_conv_state, + conv_state_indices) + out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size) + out = F.silu(out[:, :, -seq_len:]) if activation is not None else out[:, :, -seq_len:] + out = out.to(hidden_states.dtype).squeeze(-1) + return out + + def causal_conv1d_update( x: torch.Tensor, conv_state: torch.Tensor, @@ -1146,6 +1295,16 @@ def causal_conv1d_update( assert weight.stride(1) == 1 # Need this assert cache_seqlens is None # not needed for vLLM - circular buffer + if batch > 1: + return torch_causal_conv1d_update( + x, + conv_state, + weight, + bias, + activation, + conv_state_indices=conv_state_indices + ) + # adopt the strategy in vLLM that 
overwrite on 'x' directly, rather than creating a new tensor 'o' out = x stride_w_dim, stride_w_width = weight.stride()