[Feat] Add window attention for gemma-2 (#1056)
This commit is contained in:
@@ -64,7 +64,7 @@ class BenchArgs:
|
|||||||
run_name: str = "before"
|
run_name: str = "before"
|
||||||
batch_size: Tuple[int] = (1,)
|
batch_size: Tuple[int] = (1,)
|
||||||
input_len: Tuple[int] = (1024,)
|
input_len: Tuple[int] = (1024,)
|
||||||
output_len: Tuple[int] = (4,)
|
output_len: Tuple[int] = (16,)
|
||||||
result_filename: str = ""
|
result_filename: str = ""
|
||||||
correctness_test: bool = False
|
correctness_test: bool = False
|
||||||
# This is only used for correctness test
|
# This is only used for correctness test
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
|
|||||||
scaling: float,
|
scaling: float,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
|
sliding_window_size: int = -1,
|
||||||
logit_cap: int = -1,
|
logit_cap: int = -1,
|
||||||
v_head_dim: int = -1,
|
v_head_dim: int = -1,
|
||||||
):
|
):
|
||||||
@@ -46,6 +47,7 @@ class RadixAttention(nn.Module):
|
|||||||
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
|
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
|
||||||
self.scaling = scaling
|
self.scaling = scaling
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
|
self.sliding_window_size = sliding_window_size
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not global_server_args_dict.get("disable_flashinfer", False)
|
not global_server_args_dict.get("disable_flashinfer", False)
|
||||||
@@ -113,40 +115,52 @@ class RadixAttention(nn.Module):
|
|||||||
return o
|
return o
|
||||||
|
|
||||||
def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
||||||
|
# using two wrappers is unnecessary in the current PR, but are prepared for future PRs
|
||||||
|
prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
|
||||||
|
prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
|
||||||
|
if self.sliding_window_size != -1:
|
||||||
|
prefill_wrapper_ragged = prefill_wrapper_ragged[0]
|
||||||
|
prefill_wrapper_paged = prefill_wrapper_paged[0]
|
||||||
|
else:
|
||||||
|
if isinstance(prefill_wrapper_ragged, list):
|
||||||
|
prefill_wrapper_ragged = prefill_wrapper_ragged[1]
|
||||||
|
if isinstance(prefill_wrapper_paged, list):
|
||||||
|
prefill_wrapper_paged = prefill_wrapper_paged[1]
|
||||||
|
|
||||||
if not input_metadata.flashinfer_use_ragged:
|
if not input_metadata.flashinfer_use_ragged:
|
||||||
self.store_kv_cache(k, v, input_metadata)
|
self.store_kv_cache(k, v, input_metadata)
|
||||||
|
|
||||||
o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
|
o = prefill_wrapper_paged.forward(
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||||
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
|
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
|
||||||
causal=True,
|
causal=True,
|
||||||
sm_scale=self.scaling,
|
sm_scale=self.scaling,
|
||||||
|
window_left=self.sliding_window_size,
|
||||||
logits_soft_cap=self.logit_cap,
|
logits_soft_cap=self.logit_cap,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
o1, s1 = (
|
o1, s1 = prefill_wrapper_ragged.forward_return_lse(
|
||||||
input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
|
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||||
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
|
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
|
||||||
v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
|
v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
|
||||||
causal=True,
|
causal=True,
|
||||||
sm_scale=self.scaling,
|
sm_scale=self.scaling,
|
||||||
|
window_left=self.sliding_window_size,
|
||||||
logits_soft_cap=self.logit_cap,
|
logits_soft_cap=self.logit_cap,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if input_metadata.extend_no_prefix:
|
if input_metadata.extend_no_prefix:
|
||||||
o = o1
|
o = o1
|
||||||
else:
|
else:
|
||||||
o2, s2 = (
|
# TODO window attention + radix attention will come up in next PR
|
||||||
input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
|
assert self.sliding_window_size == -1
|
||||||
|
o2, s2 = prefill_wrapper_paged.forward_return_lse(
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||||
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
|
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
|
||||||
causal=False,
|
causal=False,
|
||||||
sm_scale=self.scaling,
|
sm_scale=self.scaling,
|
||||||
logits_soft_cap=self.logit_cap,
|
logits_soft_cap=self.logit_cap,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
o, _ = merge_state(o1, s1, o2, s2)
|
o, _ = merge_state(o1, s1, o2, s2)
|
||||||
|
|
||||||
@@ -158,9 +172,16 @@ class RadixAttention(nn.Module):
|
|||||||
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
||||||
|
|
||||||
def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
||||||
|
decode_wrapper = input_metadata.flashinfer_decode_wrapper
|
||||||
|
if self.sliding_window_size != -1:
|
||||||
|
decode_wrapper = decode_wrapper[0]
|
||||||
|
else:
|
||||||
|
if isinstance(decode_wrapper, list):
|
||||||
|
decode_wrapper = decode_wrapper[1]
|
||||||
|
|
||||||
self.store_kv_cache(k, v, input_metadata)
|
self.store_kv_cache(k, v, input_metadata)
|
||||||
|
|
||||||
o = input_metadata.flashinfer_decode_wrapper.forward(
|
o = decode_wrapper.forward(
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||||
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
|
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
|
||||||
sm_scale=self.scaling,
|
sm_scale=self.scaling,
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
"""ModelRunner runs the forward passes of the models."""
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -154,6 +154,7 @@ class InputMetadata:
|
|||||||
model_runner: "ModelRunner",
|
model_runner: "ModelRunner",
|
||||||
batch: ScheduleBatch,
|
batch: ScheduleBatch,
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
|
sliding_window_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
ret = cls(
|
ret = cls(
|
||||||
forward_mode=forward_mode,
|
forward_mode=forward_mode,
|
||||||
@@ -197,7 +198,7 @@ class InputMetadata:
|
|||||||
):
|
):
|
||||||
flashinfer_use_ragged = True
|
flashinfer_use_ragged = True
|
||||||
ret.init_flashinfer_handlers(
|
ret.init_flashinfer_handlers(
|
||||||
model_runner, prefix_lens, flashinfer_use_ragged
|
model_runner, prefix_lens, flashinfer_use_ragged, sliding_window_size
|
||||||
)
|
)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
@@ -216,7 +217,11 @@ class InputMetadata:
|
|||||||
self.triton_max_extend_len = int(torch.max(extend_seq_lens))
|
self.triton_max_extend_len = int(torch.max(extend_seq_lens))
|
||||||
|
|
||||||
def init_flashinfer_handlers(
|
def init_flashinfer_handlers(
|
||||||
self, model_runner, prefix_lens, flashinfer_use_ragged
|
self,
|
||||||
|
model_runner,
|
||||||
|
prefix_lens,
|
||||||
|
flashinfer_use_ragged,
|
||||||
|
sliding_window_size=None,
|
||||||
):
|
):
|
||||||
update_flashinfer_indices(
|
update_flashinfer_indices(
|
||||||
self.forward_mode,
|
self.forward_mode,
|
||||||
@@ -225,6 +230,7 @@ class InputMetadata:
|
|||||||
self.seq_lens,
|
self.seq_lens,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
flashinfer_use_ragged=flashinfer_use_ragged,
|
flashinfer_use_ragged=flashinfer_use_ragged,
|
||||||
|
sliding_window_size=sliding_window_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
@@ -248,6 +254,7 @@ def update_flashinfer_indices(
|
|||||||
prefix_lens,
|
prefix_lens,
|
||||||
flashinfer_decode_wrapper=None,
|
flashinfer_decode_wrapper=None,
|
||||||
flashinfer_use_ragged=False,
|
flashinfer_use_ragged=False,
|
||||||
|
sliding_window_size=None,
|
||||||
):
|
):
|
||||||
"""Init auxiliary variables for FlashInfer attention backend."""
|
"""Init auxiliary variables for FlashInfer attention backend."""
|
||||||
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
|
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||||
@@ -255,6 +262,7 @@ def update_flashinfer_indices(
|
|||||||
head_dim = model_runner.model_config.head_dim
|
head_dim = model_runner.model_config.head_dim
|
||||||
batch_size = len(req_pool_indices)
|
batch_size = len(req_pool_indices)
|
||||||
|
|
||||||
|
if sliding_window_size is None:
|
||||||
if flashinfer_use_ragged:
|
if flashinfer_use_ragged:
|
||||||
paged_kernel_lens = prefix_lens
|
paged_kernel_lens = prefix_lens
|
||||||
else:
|
else:
|
||||||
@@ -317,3 +325,82 @@ def update_flashinfer_indices(
|
|||||||
head_dim,
|
head_dim,
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||||
|
for wrapper_id in range(2):
|
||||||
|
if flashinfer_use_ragged:
|
||||||
|
paged_kernel_lens = prefix_lens
|
||||||
|
else:
|
||||||
|
paged_kernel_lens = seq_lens
|
||||||
|
|
||||||
|
if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
|
||||||
|
paged_kernel_lens = torch.minimum(
|
||||||
|
paged_kernel_lens, torch.tensor(sliding_window_size)
|
||||||
|
)
|
||||||
|
kv_start_idx = seq_lens - paged_kernel_lens
|
||||||
|
else:
|
||||||
|
kv_start_idx = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
||||||
|
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
||||||
|
kv_indices = torch.cat(
|
||||||
|
[
|
||||||
|
model_runner.req_to_token_pool.req_to_token[
|
||||||
|
req_pool_indices_cpu[i],
|
||||||
|
kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
|
||||||
|
]
|
||||||
|
for i in range(batch_size)
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
).contiguous()
|
||||||
|
|
||||||
|
if forward_mode == ForwardMode.DECODE:
|
||||||
|
# CUDA graph uses different flashinfer_decode_wrapper
|
||||||
|
if flashinfer_decode_wrapper is None:
|
||||||
|
flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
|
||||||
|
|
||||||
|
flashinfer_decode_wrapper[wrapper_id].end_forward()
|
||||||
|
flashinfer_decode_wrapper[wrapper_id].begin_forward(
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_last_page_len,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# extend part
|
||||||
|
qo_indptr = torch.zeros(
|
||||||
|
(batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||||
|
|
||||||
|
if flashinfer_use_ragged:
|
||||||
|
model_runner.flashinfer_prefill_wrapper_ragged[
|
||||||
|
wrapper_id
|
||||||
|
].end_forward()
|
||||||
|
model_runner.flashinfer_prefill_wrapper_ragged[
|
||||||
|
wrapper_id
|
||||||
|
].begin_forward(
|
||||||
|
qo_indptr,
|
||||||
|
qo_indptr,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
# cached part
|
||||||
|
model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
|
||||||
|
model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
|
||||||
|
qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_last_page_len,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|||||||
@@ -295,7 +295,16 @@ class ModelRunner:
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
def init_flashinfer(self):
|
def init_flashinfer(self):
|
||||||
|
self.sliding_window_size = (
|
||||||
|
self.model.get_window_size()
|
||||||
|
if hasattr(self.model, "get_window_size")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
if self.server_args.disable_flashinfer:
|
if self.server_args.disable_flashinfer:
|
||||||
|
assert (
|
||||||
|
self.sliding_window_size is None
|
||||||
|
), "turn on flashinfer to support window attention"
|
||||||
self.flashinfer_prefill_wrapper_ragged = None
|
self.flashinfer_prefill_wrapper_ragged = None
|
||||||
self.flashinfer_prefill_wrapper_paged = None
|
self.flashinfer_prefill_wrapper_paged = None
|
||||||
self.flashinfer_decode_wrapper = None
|
self.flashinfer_decode_wrapper = None
|
||||||
@@ -309,12 +318,18 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
use_tensor_cores = False
|
use_tensor_cores = False
|
||||||
|
|
||||||
|
if self.sliding_window_size is None:
|
||||||
self.flashinfer_workspace_buffers = torch.empty(
|
self.flashinfer_workspace_buffers = torch.empty(
|
||||||
2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
|
2,
|
||||||
|
global_config.flashinfer_workspace_size,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device="cuda",
|
||||||
)
|
)
|
||||||
self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
self.flashinfer_prefill_wrapper_ragged = (
|
||||||
|
BatchPrefillWithRaggedKVCacheWrapper(
|
||||||
self.flashinfer_workspace_buffers[0], "NHD"
|
self.flashinfer_workspace_buffers[0], "NHD"
|
||||||
)
|
)
|
||||||
|
)
|
||||||
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
|
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
|
||||||
self.flashinfer_workspace_buffers[1], "NHD"
|
self.flashinfer_workspace_buffers[1], "NHD"
|
||||||
)
|
)
|
||||||
@@ -323,6 +338,34 @@ class ModelRunner:
|
|||||||
"NHD",
|
"NHD",
|
||||||
use_tensor_cores=use_tensor_cores,
|
use_tensor_cores=use_tensor_cores,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
workspace_buffers = torch.empty(
|
||||||
|
4,
|
||||||
|
global_config.flashinfer_workspace_size,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
self.flashinfer_prefill_wrapper_ragged = []
|
||||||
|
self.flashinfer_prefill_wrapper_paged = []
|
||||||
|
self.flashinfer_decode_wrapper = []
|
||||||
|
for i in range(2):
|
||||||
|
self.flashinfer_prefill_wrapper_ragged.append(
|
||||||
|
BatchPrefillWithRaggedKVCacheWrapper(
|
||||||
|
workspace_buffers[2 * i + 0], "NHD"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.flashinfer_prefill_wrapper_paged.append(
|
||||||
|
BatchPrefillWithPagedKVCacheWrapper(
|
||||||
|
workspace_buffers[2 * i + 1], "NHD"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.flashinfer_decode_wrapper.append(
|
||||||
|
BatchDecodeWithPagedKVCacheWrapper(
|
||||||
|
workspace_buffers[2 * i + 0],
|
||||||
|
"NHD",
|
||||||
|
use_tensor_cores=use_tensor_cores,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def init_cuda_graphs(self):
|
def init_cuda_graphs(self):
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
@@ -358,7 +401,10 @@ class ModelRunner:
|
|||||||
return self.cuda_graph_runner.replay(batch)
|
return self.cuda_graph_runner.replay(batch)
|
||||||
|
|
||||||
input_metadata = InputMetadata.from_schedule_batch(
|
input_metadata = InputMetadata.from_schedule_batch(
|
||||||
self, batch, ForwardMode.DECODE
|
self,
|
||||||
|
batch,
|
||||||
|
ForwardMode.DECODE,
|
||||||
|
sliding_window_size=self.sliding_window_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
@@ -368,7 +414,10 @@ class ModelRunner:
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend(self, batch: ScheduleBatch):
|
def forward_extend(self, batch: ScheduleBatch):
|
||||||
input_metadata = InputMetadata.from_schedule_batch(
|
input_metadata = InputMetadata.from_schedule_batch(
|
||||||
self, batch, forward_mode=ForwardMode.EXTEND
|
self,
|
||||||
|
batch,
|
||||||
|
forward_mode=ForwardMode.EXTEND,
|
||||||
|
sliding_window_size=self.sliding_window_size,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
batch.input_ids, input_metadata.positions, input_metadata
|
batch.input_ids, input_metadata.positions, input_metadata
|
||||||
@@ -377,7 +426,10 @@ class ModelRunner:
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
||||||
input_metadata = InputMetadata.from_schedule_batch(
|
input_metadata = InputMetadata.from_schedule_batch(
|
||||||
self, batch, forward_mode=ForwardMode.EXTEND
|
self,
|
||||||
|
batch,
|
||||||
|
forward_mode=ForwardMode.EXTEND,
|
||||||
|
sliding_window_size=self.sliding_window_size,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
batch.input_ids,
|
batch.input_ids,
|
||||||
|
|||||||
@@ -44,6 +44,12 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
|||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
|
# Aligned with HF's implementation, using sliding window inclusive with the last token
|
||||||
|
# SGLang assumes exclusive
|
||||||
|
def get_window_size(config):
|
||||||
|
return config.sliding_window - 1
|
||||||
|
|
||||||
|
|
||||||
class GemmaRMSNorm(CustomOp):
|
class GemmaRMSNorm(CustomOp):
|
||||||
"""RMS normalization for Gemma.
|
"""RMS normalization for Gemma.
|
||||||
|
|
||||||
@@ -200,17 +206,14 @@ class Gemma2Attention(nn.Module):
|
|||||||
dtype=torch.get_default_dtype(),
|
dtype=torch.get_default_dtype(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
|
use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
|
||||||
# odd layer, vLLM currently ignores it and uses global attention for
|
|
||||||
# all layers.
|
|
||||||
use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None
|
|
||||||
del use_sliding_window # Unused.
|
|
||||||
self.attn = RadixAttention(
|
self.attn = RadixAttention(
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_idx,
|
layer_id=layer_idx,
|
||||||
|
sliding_window_size=get_window_size(config) if use_sliding_window else -1,
|
||||||
logit_cap=self.config.attn_logit_softcapping,
|
logit_cap=self.config.attn_logit_softcapping,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -403,6 +406,9 @@ class Gemma2ForCausalLM(nn.Module):
|
|||||||
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_window_size(self):
|
||||||
|
return get_window_size(self.config)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
|
|||||||
@@ -17,9 +17,12 @@ limitations under the License.
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ServerArgs:
|
class ServerArgs:
|
||||||
@@ -446,6 +449,15 @@ class ServerArgs:
|
|||||||
assert not (
|
assert not (
|
||||||
self.dp_size > 1 and self.node_rank is not None
|
self.dp_size > 1 and self.node_rank is not None
|
||||||
), "multi-node data parallel is not supported"
|
), "multi-node data parallel is not supported"
|
||||||
|
if "gemma-2" in self.model_path.lower():
|
||||||
|
logger.info(
|
||||||
|
f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
|
||||||
|
)
|
||||||
|
self.disable_radix_cache = True
|
||||||
|
self.disable_regex_jump_forward = True
|
||||||
|
self.disable_flashinfer = False
|
||||||
|
self.disable_cuda_graph = True
|
||||||
|
self.chunked_prefill_size = None
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
|||||||
1
python/sglang/test/long_prompt
Normal file
1
python/sglang/test/long_prompt
Normal file
File diff suppressed because one or more lines are too long
@@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
@@ -31,8 +32,14 @@ DEFAULT_PROMPTS = [
|
|||||||
"The capital of the United Kindom is",
|
"The capital of the United Kindom is",
|
||||||
"Today is a sunny day and I like",
|
"Today is a sunny day and I like",
|
||||||
"AI is a field of computer science focused on",
|
"AI is a field of computer science focused on",
|
||||||
|
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
dirpath = os.path.dirname(__file__)
|
||||||
|
with open(os.path.join(dirpath, "long_prompt"), "r") as f:
|
||||||
|
long_prompt = f.read()
|
||||||
|
DEFAULT_PROMPTS.append(long_prompt)
|
||||||
|
|
||||||
NUM_TOP_LOGPROBS = 5
|
NUM_TOP_LOGPROBS = 5
|
||||||
|
|
||||||
|
|
||||||
@@ -125,16 +132,14 @@ class HFRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
logits = self.model.forward(input_ids).logits[0]
|
logits = self.model.forward(input_ids).logits[0]
|
||||||
logprobs = F.log_softmax(
|
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
|
||||||
logits, dim=-1, dtype=torch.float32
|
logprobs, top_indices = torch.topk(
|
||||||
).tolist()
|
logprobs, k=NUM_TOP_LOGPROBS, dim=-1
|
||||||
# index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
|
)
|
||||||
# print("index", index_of_max)
|
# print("index", top_indices)
|
||||||
logprobs = [
|
prefill_logprobs.append(logprobs.tolist())
|
||||||
sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
|
del logits
|
||||||
for token_logprobs in logprobs
|
del logprobs
|
||||||
]
|
|
||||||
prefill_logprobs.append(logprobs)
|
|
||||||
|
|
||||||
out_queue.put(
|
out_queue.put(
|
||||||
ModelOutput(
|
ModelOutput(
|
||||||
@@ -186,6 +191,7 @@ class SRTRunner:
|
|||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
dtype=get_dtype_str(torch_dtype),
|
dtype=get_dtype_str(torch_dtype),
|
||||||
port=port,
|
port=port,
|
||||||
|
mem_fraction_static=0.7,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -35,18 +35,17 @@ def normal_text(args):
|
|||||||
args.model_path,
|
args.model_path,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
|
device_map="auto",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
m.cuda()
|
m.cuda()
|
||||||
|
|
||||||
print(m)
|
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"The capital of France is",
|
"The capital of France is",
|
||||||
"The capital of the United Kindom is",
|
"The capital of the United Kindom is",
|
||||||
"Today is a sunny day and I like",
|
"Today is a sunny day and I like",
|
||||||
]
|
]
|
||||||
max_new_tokens = 32
|
max_new_tokens = 16
|
||||||
|
|
||||||
for p in prompts:
|
for p in prompts:
|
||||||
if isinstance(p, str):
|
if isinstance(p, str):
|
||||||
@@ -58,10 +57,11 @@ def normal_text(args):
|
|||||||
input_ids, do_sample=False, max_new_tokens=max_new_tokens
|
input_ids, do_sample=False, max_new_tokens=max_new_tokens
|
||||||
)
|
)
|
||||||
output_str = t.decode(output_ids[0])
|
output_str = t.decode(output_ids[0])
|
||||||
print(output_str)
|
|
||||||
|
|
||||||
prefill_logits = m.forward(input_ids).logits[0][-1]
|
prefill_logits = m.forward(input_ids).logits[0][-1]
|
||||||
|
|
||||||
print("prefill logits", prefill_logits)
|
print("prefill logits", prefill_logits)
|
||||||
|
print(output_str)
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
|||||||
@@ -53,7 +53,9 @@ class TestEmbeddingModels(unittest.TestCase):
|
|||||||
srt_logits = torch.Tensor(srt_outputs.embed_logits[i])
|
srt_logits = torch.Tensor(srt_outputs.embed_logits[i])
|
||||||
|
|
||||||
similarities = torch.tensor(get_similarities(hf_logits, srt_logits))
|
similarities = torch.tensor(get_similarities(hf_logits, srt_logits))
|
||||||
|
print("max similarity diff", torch.max(abs(similarities - 1)))
|
||||||
|
|
||||||
|
if hf_logits.shape[0] <= 100:
|
||||||
tolerance = 1e-2
|
tolerance = 1e-2
|
||||||
assert torch.all(
|
assert torch.all(
|
||||||
abs(similarities - 1) < tolerance
|
abs(similarities - 1) < tolerance
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ import torch
|
|||||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
|
("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1),
|
||||||
("google/gemma-2-2b", 1),
|
("google/gemma-2-2b", 1, 3),
|
||||||
]
|
]
|
||||||
TORCH_DTYPES = [torch.float16]
|
TORCH_DTYPES = [torch.float16]
|
||||||
|
|
||||||
@@ -35,6 +35,7 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
tp_size,
|
tp_size,
|
||||||
torch_dtype,
|
torch_dtype,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
|
long_context_tolerance,
|
||||||
) -> None:
|
) -> None:
|
||||||
with HFRunner(
|
with HFRunner(
|
||||||
model_path, torch_dtype=torch_dtype, is_generation_model=True
|
model_path, torch_dtype=torch_dtype, is_generation_model=True
|
||||||
@@ -53,15 +54,19 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
|
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
|
||||||
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
|
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
|
||||||
|
|
||||||
|
print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))
|
||||||
|
if hf_logprobs.shape[0] <= 100:
|
||||||
tolerance = 3e-2
|
tolerance = 3e-2
|
||||||
assert torch.all(
|
assert torch.all(
|
||||||
abs(hf_logprobs - srt_logprobs) < tolerance
|
abs(hf_logprobs - srt_logprobs) < tolerance
|
||||||
), f"prefill logprobs not all close"
|
), f"prefill logprobs not all close"
|
||||||
|
|
||||||
|
print(hf_outputs.output_strs)
|
||||||
|
print(srt_outputs.output_strs)
|
||||||
assert hf_outputs.output_strs == srt_outputs.output_strs
|
assert hf_outputs.output_strs == srt_outputs.output_strs
|
||||||
|
|
||||||
def test_prefill_logits(self):
|
def test_prefill_logits_and_output_strs(self):
|
||||||
for model, tp_size in MODELS:
|
for model, tp_size, long_context_tolerance in MODELS:
|
||||||
for torch_dtype in TORCH_DTYPES:
|
for torch_dtype in TORCH_DTYPES:
|
||||||
max_new_tokens = 8
|
max_new_tokens = 8
|
||||||
self.assert_close_prefill_logits_and_output_strs(
|
self.assert_close_prefill_logits_and_output_strs(
|
||||||
@@ -70,6 +75,7 @@ class TestGenerationModels(unittest.TestCase):
|
|||||||
tp_size,
|
tp_size,
|
||||||
torch_dtype,
|
torch_dtype,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
|
long_context_tolerance=long_context_tolerance,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user