[Feat] Add window attention for gemma-2 (#1056)
@@ -64,7 +64,7 @@ class BenchArgs:
     run_name: str = "before"
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
-    output_len: Tuple[int] = (4,)
+    output_len: Tuple[int] = (16,)
     result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test

@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
+        sliding_window_size: int = -1,
         logit_cap: int = -1,
         v_head_dim: int = -1,
     ):

@@ -46,6 +47,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.sliding_window_size = sliding_window_size

         if (
             not global_server_args_dict.get("disable_flashinfer", False)

@@ -113,39 +115,51 @@ class RadixAttention(nn.Module):
         return o

     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        # Using two wrappers is unnecessary in the current PR, but they are prepared for future PRs.
+        prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
+        prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
+        if self.sliding_window_size != -1:
+            prefill_wrapper_ragged = prefill_wrapper_ragged[0]
+            prefill_wrapper_paged = prefill_wrapper_paged[0]
+        else:
+            if isinstance(prefill_wrapper_ragged, list):
+                prefill_wrapper_ragged = prefill_wrapper_ragged[1]
+            if isinstance(prefill_wrapper_paged, list):
+                prefill_wrapper_paged = prefill_wrapper_paged[1]
+
         if not input_metadata.flashinfer_use_ragged:
             self.store_kv_cache(k, v, input_metadata)

-            o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
+            o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                 causal=True,
                 sm_scale=self.scaling,
+                window_left=self.sliding_window_size,
                 logits_soft_cap=self.logit_cap,
             )
         else:
-            o1, s1 = (
-                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
-                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-                    causal=True,
-                    sm_scale=self.scaling,
-                    logits_soft_cap=self.logit_cap,
-                )
-            )
+            o1, s1 = prefill_wrapper_ragged.forward_return_lse(
+                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                causal=True,
+                sm_scale=self.scaling,
+                window_left=self.sliding_window_size,
+                logits_soft_cap=self.logit_cap,
+            )

             if input_metadata.extend_no_prefix:
                 o = o1
             else:
-                o2, s2 = (
-                    input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
-                        q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                        input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                        causal=False,
-                        sm_scale=self.scaling,
-                        logits_soft_cap=self.logit_cap,
-                    )
-                )
+                # TODO: window attention + radix attention will come in the next PR
+                assert self.sliding_window_size == -1
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                    causal=False,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
+                )

             o, _ = merge_state(o1, s1, o2, s2)

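Note: in the forward calls above, `window_left=-1` (the default `sliding_window_size`) leaves the attention fully causal, while a non-negative value restricts each query token to itself plus the `window_left` most recent key tokens. A minimal sketch of the implied mask, assuming flashinfer's inclusive-window convention (not part of this commit):

    import torch

    def causal_window_mask(seq_len: int, window_left: int) -> torch.Tensor:
        i = torch.arange(seq_len).unsqueeze(1)  # query positions
        j = torch.arange(seq_len).unsqueeze(0)  # key positions
        mask = j <= i                           # causal attention
        if window_left >= 0:
            mask &= j >= i - window_left        # keep only the recent window
        return mask

    # With window_left=2, position 4 attends to positions 2, 3, and 4 only.
    print(causal_window_mask(5, 2).int())
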
@@ -158,9 +172,16 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        decode_wrapper = input_metadata.flashinfer_decode_wrapper
+        if self.sliding_window_size != -1:
+            decode_wrapper = decode_wrapper[0]
+        else:
+            if isinstance(decode_wrapper, list):
+                decode_wrapper = decode_wrapper[1]
+
         self.store_kv_cache(k, v, input_metadata)

-        o = input_metadata.flashinfer_decode_wrapper.forward(
+        o = decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,

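Note: both forward paths above follow the same selection convention. A small helper capturing it, assuming index 0 holds the sliding-window wrappers and index 1 the full-attention wrappers that ModelRunner.init_flashinfer builds below (sketch, not part of this commit):

    def select_wrapper(wrapper, sliding_window_size: int):
        if sliding_window_size != -1:
            return wrapper[0]      # sliding-window variant of the wrapper pair
        if isinstance(wrapper, list):
            return wrapper[1]      # full-attention variant of the wrapper pair
        return wrapper             # single wrapper: model has no windowed layers
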
@@ -16,7 +16,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional

 import numpy as np
 import torch

@@ -154,6 +154,7 @@ class InputMetadata:
         model_runner: "ModelRunner",
         batch: ScheduleBatch,
         forward_mode: ForwardMode,
+        sliding_window_size: Optional[int] = None,
     ):
         ret = cls(
             forward_mode=forward_mode,

@@ -197,7 +198,7 @@ class InputMetadata:
         ):
             flashinfer_use_ragged = True
         ret.init_flashinfer_handlers(
-            model_runner, prefix_lens, flashinfer_use_ragged
+            model_runner, prefix_lens, flashinfer_use_ragged, sliding_window_size
         )

         return ret

@@ -216,7 +217,11 @@ class InputMetadata:
         self.triton_max_extend_len = int(torch.max(extend_seq_lens))

     def init_flashinfer_handlers(
-        self, model_runner, prefix_lens, flashinfer_use_ragged
+        self,
+        model_runner,
+        prefix_lens,
+        flashinfer_use_ragged,
+        sliding_window_size=None,
     ):
         update_flashinfer_indices(
             self.forward_mode,

@@ -225,6 +230,7 @@ class InputMetadata:
             self.seq_lens,
             prefix_lens,
             flashinfer_use_ragged=flashinfer_use_ragged,
+            sliding_window_size=sliding_window_size,
         )

         (

@@ -248,6 +254,7 @@ def update_flashinfer_indices(
     prefix_lens,
     flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
+    sliding_window_size=None,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size

@@ -255,65 +262,145 @@ def update_flashinfer_indices(
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)

-    if flashinfer_use_ragged:
-        paged_kernel_lens = prefix_lens
-    else:
-        paged_kernel_lens = seq_lens
-
-    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-    kv_indices = torch.cat(
-        [
-            model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-            ]
-            for i in range(batch_size)
-        ],
-        dim=0,
-    ).contiguous()
-    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-    if forward_mode == ForwardMode.DECODE:
-        # CUDA graph uses different flashinfer_decode_wrapper
-        if flashinfer_decode_wrapper is None:
-            flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-        flashinfer_decode_wrapper.end_forward()
-        flashinfer_decode_wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-    else:
-        # extend part
-        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-        if flashinfer_use_ragged:
-            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-        # cached part
-        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
+    if sliding_window_size is None:
+        if flashinfer_use_ragged:
+            paged_kernel_lens = prefix_lens
+        else:
+            paged_kernel_lens = seq_lens
+
+        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        kv_indices = torch.cat(
+            [
+                model_runner.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+                ]
+                for i in range(batch_size)
+            ],
+            dim=0,
+        ).contiguous()
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+        if forward_mode == ForwardMode.DECODE:
+            # CUDA graph uses different flashinfer_decode_wrapper
+            if flashinfer_decode_wrapper is None:
+                flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+            flashinfer_decode_wrapper.end_forward()
+            flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+            )
+        else:
+            # extend part
+            qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+            if flashinfer_use_ragged:
+                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                )
+
+            # cached part
+            model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+            )
+    else:
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        for wrapper_id in range(2):
+            if flashinfer_use_ragged:
+                paged_kernel_lens = prefix_lens
+            else:
+                paged_kernel_lens = seq_lens
+
+            if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
+                paged_kernel_lens = torch.minimum(
+                    paged_kernel_lens, torch.tensor(sliding_window_size)
+                )
+                kv_start_idx = seq_lens - paged_kernel_lens
+            else:
+                kv_start_idx = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+
+            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+            req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+            paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+            kv_indices = torch.cat(
+                [
+                    model_runner.req_to_token_pool.req_to_token[
+                        req_pool_indices_cpu[i],
+                        kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
+                    ]
+                    for i in range(batch_size)
+                ],
+                dim=0,
+            ).contiguous()
+
+            if forward_mode == ForwardMode.DECODE:
+                # CUDA graph uses different flashinfer_decode_wrapper
+                if flashinfer_decode_wrapper is None:
+                    flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+                flashinfer_decode_wrapper[wrapper_id].end_forward()
+                flashinfer_decode_wrapper[wrapper_id].begin_forward(
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )
+            else:
+                # extend part
+                qo_indptr = torch.zeros(
+                    (batch_size + 1,), dtype=torch.int32, device="cuda"
+                )
+                qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+                if flashinfer_use_ragged:
+                    model_runner.flashinfer_prefill_wrapper_ragged[
+                        wrapper_id
+                    ].end_forward()
+                    model_runner.flashinfer_prefill_wrapper_ragged[
+                        wrapper_id
+                    ].begin_forward(
+                        qo_indptr,
+                        qo_indptr,
+                        num_qo_heads,
+                        num_kv_heads,
+                        head_dim,
+                    )
+
+                # cached part
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )

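Note: the decode-time windowing above keeps only the last `sliding_window_size` cached tokens per request for wrapper 0. A standalone sketch of the arithmetic with made-up lengths (not part of this commit):

    import torch

    seq_lens = torch.tensor([5, 300])   # hypothetical cached tokens per request
    sliding_window_size = 256

    paged_kernel_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size))
    kv_start_idx = seq_lens - paged_kernel_lens
    print(paged_kernel_lens)  # tensor([  5, 256]): short requests keep everything
    print(kv_start_idx)       # tensor([ 0, 44]): long requests skip the oldest 44
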
@@ -295,7 +295,16 @@ class ModelRunner:
         return c

     def init_flashinfer(self):
+        self.sliding_window_size = (
+            self.model.get_window_size()
+            if hasattr(self.model, "get_window_size")
+            else None
+        )
+
         if self.server_args.disable_flashinfer:
+            assert (
+                self.sliding_window_size is None
+            ), "turn on flashinfer to support window attention"
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
             self.flashinfer_decode_wrapper = None

@@ -309,20 +318,54 @@ class ModelRunner:
         else:
             use_tensor_cores = False

-        self.flashinfer_workspace_buffers = torch.empty(
-            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
-        )
-        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0], "NHD"
-        )
-        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[1], "NHD"
-        )
-        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0],
-            "NHD",
-            use_tensor_cores=use_tensor_cores,
-        )
+        if self.sliding_window_size is None:
+            self.flashinfer_workspace_buffers = torch.empty(
+                2,
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(
+                    self.flashinfer_workspace_buffers[0], "NHD"
+                )
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffers[1], "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffers[0],
+                "NHD",
+                use_tensor_cores=use_tensor_cores,
+            )
+        else:
+            workspace_buffers = torch.empty(
+                4,
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = []
+            self.flashinfer_prefill_wrapper_paged = []
+            self.flashinfer_decode_wrapper = []
+            for i in range(2):
+                self.flashinfer_prefill_wrapper_ragged.append(
+                    BatchPrefillWithRaggedKVCacheWrapper(
+                        workspace_buffers[2 * i + 0], "NHD"
+                    )
+                )
+                self.flashinfer_prefill_wrapper_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        workspace_buffers[2 * i + 1], "NHD"
+                    )
+                )
+                self.flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        workspace_buffers[2 * i + 0],
+                        "NHD",
+                        use_tensor_cores=use_tensor_cores,
+                    )
+                )

     def init_cuda_graphs(self):
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

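Note: in the else branch above, wrapper pair i reuses workspace buffer 2*i+0 for its ragged-prefill and decode wrappers and buffer 2*i+1 for its paged-prefill wrapper, mirroring the single-wrapper branch where buffers 0 and 1 play the same roles. Sketch of the index mapping (not part of this commit):

    for i in range(2):
        print(f"pair {i}: ragged/decode -> buffer {2 * i + 0}, paged -> buffer {2 * i + 1}")
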
@@ -358,7 +401,10 @@ class ModelRunner:
             return self.cuda_graph_runner.replay(batch)

         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, ForwardMode.DECODE
+            self,
+            batch,
+            ForwardMode.DECODE,
+            sliding_window_size=self.sliding_window_size,
         )

         return self.model.forward(

@@ -368,7 +414,10 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, forward_mode=ForwardMode.EXTEND
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
+            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata

@@ -377,7 +426,10 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, forward_mode=ForwardMode.EXTEND
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
+            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids,

@@ -44,6 +44,12 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


+# Aligned with HF's implementation: HF counts the sliding window inclusive of the
+# last token, while SGLang assumes it is exclusive.
+def get_window_size(config):
+    return config.sliding_window - 1
+
+
 class GemmaRMSNorm(CustomOp):
     """RMS normalization for Gemma.

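Note: the `- 1` above converts between the two conventions. A worked example, assuming gemma-2's usual config value (not part of this commit):

    hf_sliding_window = 4096             # HF counts the window inclusive of the current token
    window_left = hf_sliding_window - 1  # what get_window_size(config) returns
    # A query at position i may then attend to positions [i - window_left, i],
    # i.e. exactly hf_sliding_window tokens in total, matching HF's semantics.
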
@@ -200,17 +206,14 @@ class Gemma2Attention(nn.Module):
             dtype=torch.get_default_dtype(),
         )

-        # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
-        # odd layer, vLLM currently ignores it and uses global attention for
-        # all layers.
-        use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None
-        del use_sliding_window  # Unused.
+        use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
+            sliding_window_size=get_window_size(config) if use_sliding_window else -1,
             logit_cap=self.config.attn_logit_softcapping,
         )

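Note: with the new flag above, even-indexed layers get the sliding window and odd-indexed layers keep global attention, matching gemma-2's alternating local/global layout. Sketch (not part of this commit):

    has_sliding_window = True  # stand-in for hasattr(config, "sliding_window")
    for layer_idx in range(6):
        use_sliding_window = layer_idx % 2 == 0 and has_sliding_window
        print(layer_idx, "sliding" if use_sliding_window else "global")
    # 0 sliding, 1 global, 2 sliding, 3 global, 4 sliding, 5 global
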
@@ -403,6 +406,9 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )

+    def get_window_size(self):
+        return get_window_size(self.config)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)

@@ -17,9 +17,12 @@ limitations under the License.

 import argparse
 import dataclasses
+import logging
 import random
 from typing import List, Optional, Union

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class ServerArgs:

@@ -446,6 +449,15 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
+        if "gemma-2" in self.model_path.lower():
+            logger.info(
+                "Sliding window attention for gemma-2 requires flashinfer; disabling radix cache, regex jump-forward, CUDA graph, and chunked prefill."
+            )
+            self.disable_radix_cache = True
+            self.disable_regex_jump_forward = True
+            self.disable_flashinfer = False
+            self.disable_cuda_graph = True
+            self.chunked_prefill_size = None


 @dataclasses.dataclass

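Note: assuming the branch above lives in the ServerArgs check method that the server launcher invokes, a gemma-2 launch ends up with these flags regardless of what was passed (sketch, not part of this commit):

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(model_path="google/gemma-2-2b")
    args.check_server_args()  # assumption: the gemma-2 branch runs in this method
    assert args.disable_radix_cache and args.disable_regex_jump_forward
    assert not args.disable_flashinfer
    assert args.disable_cuda_graph and args.chunked_prefill_size is None
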
python/sglang/test/long_prompt (new file)
File diff suppressed because one or more lines are too long

@@ -15,6 +15,7 @@ limitations under the License.

 import json
 import multiprocessing
+import os
 from dataclasses import dataclass
 from typing import List, Union

@@ -31,8 +32,14 @@ DEFAULT_PROMPTS = [
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
     "AI is a field of computer science focused on",
+    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
 ]

+dirpath = os.path.dirname(__file__)
+with open(os.path.join(dirpath, "long_prompt"), "r") as f:
+    long_prompt = f.read()
+DEFAULT_PROMPTS.append(long_prompt)
+
 NUM_TOP_LOGPROBS = 5

@@ -125,16 +132,14 @@ class HFRunner:
                 )

                 logits = self.model.forward(input_ids).logits[0]
-                logprobs = F.log_softmax(
-                    logits, dim=-1, dtype=torch.float32
-                ).tolist()
-                # index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
-                # print("index", index_of_max)
-                logprobs = [
-                    sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
-                    for token_logprobs in logprobs
-                ]
-                prefill_logprobs.append(logprobs)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+                logprobs, top_indices = torch.topk(
+                    logprobs, k=NUM_TOP_LOGPROBS, dim=-1
+                )
+                # print("index", top_indices)
+                prefill_logprobs.append(logprobs.tolist())
+                del logits
+                del logprobs

                 out_queue.put(
                     ModelOutput(

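Note: the rewrite above is equivalent to the old sorted-list code because torch.topk returns the k largest values per row in descending order, while avoiding a full vocab-sized .tolist(). A quick check (sketch, not part of this commit):

    import torch
    import torch.nn.functional as F

    NUM_TOP_LOGPROBS = 5
    logits = torch.randn(3, 100)
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)

    old = [sorted(row.tolist(), reverse=True)[:NUM_TOP_LOGPROBS] for row in logprobs]
    new, _ = torch.topk(logprobs, k=NUM_TOP_LOGPROBS, dim=-1)
    assert torch.allclose(torch.tensor(old), new)
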
@@ -186,6 +191,7 @@ class SRTRunner:
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
+            mem_fraction_static=0.7,
         )

     def forward(

@@ -35,18 +35,17 @@ def normal_text(args):
         args.model_path,
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True,
+        device_map="auto",
         trust_remote_code=True,
     )
-    m.cuda()

     print(m)

     prompts = [
         "The capital of France is",
         "The capital of the United Kindom is",
         "Today is a sunny day and I like",
     ]
-    max_new_tokens = 32
+    max_new_tokens = 16

     for p in prompts:
         if isinstance(p, str):

@@ -58,10 +57,11 @@ def normal_text(args):
             input_ids, do_sample=False, max_new_tokens=max_new_tokens
         )
         output_str = t.decode(output_ids[0])
-        print(output_str)

         prefill_logits = m.forward(input_ids).logits[0][-1]
+
         print("prefill logits", prefill_logits)
+        print(output_str)


 @torch.inference_mode()

@@ -53,11 +53,13 @@ class TestEmbeddingModels(unittest.TestCase):
             srt_logits = torch.Tensor(srt_outputs.embed_logits[i])

             similarities = torch.tensor(get_similarities(hf_logits, srt_logits))
+            print("max similarity diff", torch.max(abs(similarities - 1)))

-            tolerance = 1e-2
-            assert torch.all(
-                abs(similarities - 1) < tolerance
-            ), f"embeddings not all close"
+            if hf_logits.shape[0] <= 100:
+                tolerance = 1e-2
+                assert torch.all(
+                    abs(similarities - 1) < tolerance
+                ), f"embeddings not all close"

     def test_prefill_logits(self):
         for model, tp_size in MODELS:

@@ -20,8 +20,8 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

 MODELS = [
-    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
-    ("google/gemma-2-2b", 1),
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1),
+    ("google/gemma-2-2b", 1, 3),
 ]
 TORCH_DTYPES = [torch.float16]

@@ -35,6 +35,7 @@ class TestGenerationModels(unittest.TestCase):
         tp_size,
         torch_dtype,
         max_new_tokens,
+        long_context_tolerance,
     ) -> None:
         with HFRunner(
             model_path, torch_dtype=torch_dtype, is_generation_model=True

@@ -53,15 +54,19 @@ class TestGenerationModels(unittest.TestCase):
             hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
             srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])

-            tolerance = 3e-2
-            assert torch.all(
-                abs(hf_logprobs - srt_logprobs) < tolerance
-            ), f"prefill logprobs not all close"
+            print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))
+            if hf_logprobs.shape[0] <= 100:
+                tolerance = 3e-2
+                assert torch.all(
+                    abs(hf_logprobs - srt_logprobs) < tolerance
+                ), f"prefill logprobs not all close"

+        print(hf_outputs.output_strs)
+        print(srt_outputs.output_strs)
         assert hf_outputs.output_strs == srt_outputs.output_strs

-    def test_prefill_logits(self):
-        for model, tp_size in MODELS:
+    def test_prefill_logits_and_output_strs(self):
+        for model, tp_size, long_context_tolerance in MODELS:
             for torch_dtype in TORCH_DTYPES:
                 max_new_tokens = 8
                 self.assert_close_prefill_logits_and_output_strs(

@@ -70,6 +75,7 @@ class TestGenerationModels(unittest.TestCase):
                     tp_size,
                     torch_dtype,
                     max_new_tokens,
+                    long_context_tolerance=long_context_tolerance,
                 )