Support MLA for DeepSeek-V2 with Triton - step 1 (#905)
@@ -38,16 +38,22 @@ class RadixAttention(nn.Module):
         num_kv_heads: int,
         layer_id: int,
         logit_cap: int = -1,
+        v_head_dim: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.qk_head_dim = head_dim
+        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
 
-        if not global_server_args_dict.get("disable_flashinfer", False):
+        if (
+            not global_server_args_dict.get("disable_flashinfer", False)
+            and self.qk_head_dim == self.v_head_dim
+        ):
             self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
         else:
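The extra condition in the guard reflects that the FlashInfer path is only taken when the q/k and v head dims match, while MLA in DeepSeek-V2 concatenates a RoPE slice onto each q/k head, making `qk_head_dim` wider than `v_head_dim`. A minimal usage sketch, not part of this diff: the 192/128 split follows DeepSeek-V2's published config (`qk_nope_head_dim=128`, `qk_rope_head_dim=64`, `v_head_dim=128`), `num_local_heads` and `layer_id` are placeholders, and the scale shown is the plain 1/sqrt(d), ignoring any mscale correction.

```python
qk_nope_head_dim = 128   # non-RoPE part of each q/k head
qk_rope_head_dim = 64    # RoPE part concatenated onto each q/k head
v_head_dim = 128         # value heads stay at the base width
num_local_heads = 16     # placeholder: heads on this tensor-parallel rank

attn = RadixAttention(
    num_heads=num_local_heads,
    head_dim=qk_nope_head_dim + qk_rope_head_dim,    # 192 -> self.qk_head_dim
    scaling=(qk_nope_head_dim + qk_rope_head_dim) ** -0.5,
    num_kv_heads=num_local_heads,
    layer_id=0,
    v_head_dim=v_head_dim,    # 128 != 192, so the Triton path is selected
)
```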
@@ -57,13 +63,17 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
         self.store_kv_cache(k, v, input_metadata)
         extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             k.contiguous(),
             v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
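Once the dims diverge, `torch.empty_like(q)` would allocate `qk_head_dim` columns per head, but each head's attention output is only `v_head_dim` wide. A standalone shape check, plain PyTorch with arbitrary sizes:

```python
import torch

num_tokens, num_heads = 4, 16
qk_head_dim, v_head_dim = 192, 128

q = torch.randn(num_tokens, num_heads * qk_head_dim)   # flattened, as passed in

# Each head emits a v_head_dim-wide vector, so o is narrower than q:
o = q.new_empty((q.shape[0], num_heads * v_head_dim))

assert q.view(-1, num_heads, qk_head_dim).shape == (num_tokens, num_heads, qk_head_dim)
assert o.view(-1, num_heads, v_head_dim).shape == (num_tokens, num_heads, v_head_dim)
# torch.empty_like(q) would instead have num_heads * qk_head_dim columns.
```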
@@ -82,14 +92,17 @@ class RadixAttention(nn.Module):
         return o
 
     def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)
 
         token_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
             input_metadata.triton_start_loc,
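The decode path mirrors the extend path. One consequence: the per-layer buffers returned by `get_key_buffer` / `get_value_buffer` can now have different last dimensions, so the Triton kernels need separate k/v head dims rather than one shared `head_dim`. A toy illustration; the `(num_slots, num_heads, head_dim)` layout is an assumption about the KV pool, not something shown in this diff.

```python
import torch

num_slots, num_kv_heads = 1024, 16
qk_head_dim, v_head_dim = 192, 128

key_buffer = torch.empty(num_slots, num_kv_heads, qk_head_dim)    # keys cached wide
value_buffer = torch.empty(num_slots, num_kv_heads, v_head_dim)   # values cached narrow
```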
@@ -160,8 +173,8 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.head_dim)
+        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
         if input_metadata.forward_mode == ForwardMode.EXTEND:
            return self.extend_forward(q, k, v, input_metadata)
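In `forward()`, k and v arrive flattened, and since v is now narrower than k, reusing the old shared `head_dim` in both `view()` calls would fail on v. A quick standalone check with illustrative dims:

```python
import torch

num_tokens, num_kv_heads = 4, 16
qk_head_dim, v_head_dim = 192, 128

k = torch.randn(num_tokens, num_kv_heads * qk_head_dim)
v = torch.randn(num_tokens, num_kv_heads * v_head_dim)

k = k.view(-1, num_kv_heads, qk_head_dim)   # ok
v = v.view(-1, num_kv_heads, v_head_dim)    # ok
# v.view(-1, num_kv_heads, qk_head_dim) would raise: 4*16*128 elements
# cannot be reshaped into (-1, 16, 192).
```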