Add Gemma2 (#592)

This commit is contained in:
Ying Sheng
2024-07-05 09:48:54 -07:00
committed by GitHub
parent d737da5f17
commit 5a57b8addd
7 changed files with 467 additions and 30 deletions

View File

@@ -1,11 +1,9 @@
"""Radix attention."""
import numpy as np
import torch
from torch import nn
from sglang.global_config import global_config
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
@@ -21,10 +19,9 @@ class RadixAttention(nn.Module):
self.tp_k_head_num = num_kv_heads
self.tp_v_head_num = num_kv_heads
self.head_dim = head_dim
self.scaling = scaling
self.layer_id = layer_id
assert np.allclose(scaling, 1.0 / (head_dim**0.5))
from sglang.srt.managers.controller.model_runner import global_server_args_dict
if not global_server_args_dict.get("disable_flashinfer", False):
@@ -32,29 +29,17 @@ class RadixAttention(nn.Module):
self.extend_forward = self.prefill_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
# flashinfer now accepts float logit_cap argument
self.logit_cap = logit_cap if logit_cap > 0 else 0
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
else:
self.prefill_forward = self.prefill_forward_triton
self.extend_forward = self.extend_forward_triton
self.decode_forward = self.decode_forward_triton
self.logit_cap = logit_cap
self.logit_cap = logit_cap if logit_cap is not None else 0
def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
context_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
k,
v,
o.view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.start_loc,
input_metadata.seq_lens,
input_metadata.max_seq_len,
self.logit_cap,
)
self.store_kv_cache(k, v, input_metadata)
return o
# In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
# See the extend_forward_xxx functions.
raise NotImplementedError()
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
@@ -75,7 +60,8 @@ class RadixAttention(nn.Module):
input_metadata.extend_seq_lens,
input_metadata.max_seq_len,
input_metadata.max_extend_len,
self.logit_cap,
sm_scale=self.scaling,
logit_cap=self.logit_cap,
)
return o
@@ -96,7 +82,8 @@ class RadixAttention(nn.Module):
input_metadata.max_seq_len,
input_metadata.other_kv_index,
input_metadata.total_num_tokens,
self.logit_cap,
sm_scale=self.scaling,
logit_cap=self.logit_cap,
)
return o
@@ -108,6 +95,8 @@ class RadixAttention(nn.Module):
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
causal=True,
sm_scale=self.scaling,
logits_soft_cap=self.logit_cap,
)
@@ -118,6 +107,7 @@ class RadixAttention(nn.Module):
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
causal=False,
sm_scale=self.scaling,
logits_soft_cap=self.logit_cap,
)
@@ -135,6 +125,7 @@ class RadixAttention(nn.Module):
o = input_metadata.flashinfer_decode_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
sm_scale=self.scaling,
logits_soft_cap=self.logit_cap,
)