Suppport qwen model and solve some problems (#75)

This commit is contained in:
Arcmoon
2024-01-23 12:14:51 +08:00
committed by GitHub
parent e08bca2840
commit 63e97e5e4c
7 changed files with 274 additions and 4 deletions

View File

@@ -61,7 +61,6 @@ class RadixAttention(nn.Module):
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
extend_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous(),