Support page size > 1 (#4356)
@@ -36,7 +36,7 @@ fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
import deep_gemm
|
||||
import deep_gemm # `pip install "sgl-kernel>=0.0.4.post3"`
|
||||
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,6 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_compiler_backend, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
@@ -273,7 +274,6 @@ class Req:
|
||||
"__req__": self
|
||||
}
|
||||
self.sampling_params = sampling_params
|
||||
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
@@ -331,6 +331,8 @@ class Req:
|
||||
self.logprob_start_len = 0
|
||||
self.top_logprobs_num = top_logprobs_num
|
||||
self.token_ids_logprob = token_ids_logprob
|
||||
self.temp_scaled_logprobs = False
|
||||
self.top_p_normalized_logprobs = False
|
||||
|
||||
# Logprobs (return values)
|
||||
self.input_token_logprobs_val: Optional[List[float]] = None
|
||||
@@ -524,19 +526,23 @@ class ScheduleBatch:
|
||||
model_config: ModelConfig = None
|
||||
forward_mode: ForwardMode = None
|
||||
enable_overlap: bool = False
|
||||
# Tell whether the current running batch is full so that we can skip
|
||||
# the check of whether to prefill new requests.
|
||||
# This is an optimization to reduce the overhead of the prefill check.
|
||||
batch_is_full: bool = False
|
||||
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
next_batch_sampling_info: SamplingBatchInfo = None
|
||||
|
||||
# Batched arguments to model runner
|
||||
input_ids: torch.Tensor = None # shape: [b], int32
|
||||
input_ids: torch.Tensor = None # shape: [b], int64
|
||||
input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32
|
||||
req_pool_indices: torch.Tensor = None # shape: [b], int32
|
||||
req_pool_indices: torch.Tensor = None # shape: [b], int64
|
||||
seq_lens: torch.Tensor = None # shape: [b], int64
|
||||
# The output locations of the KV cache
|
||||
out_cache_loc: torch.Tensor = None # shape: [b], int32
|
||||
output_ids: torch.Tensor = None # shape: [b], int32
|
||||
out_cache_loc: torch.Tensor = None # shape: [b], int64
|
||||
output_ids: torch.Tensor = None # shape: [b], int64
|
||||
|
||||
# The sum of all sequence lengths
|
||||
seq_lens_sum: int = None
|
||||
@@ -551,6 +557,10 @@ class ScheduleBatch:
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
token_ids_logprobs: Optional[List[List[int]]] = None
|
||||
|
||||
# For logits and logprob post processing
|
||||
temp_scaled_logprobs: bool = False
|
||||
top_p_normalized_logprobs: bool = False
|
||||
|
||||
# For extend and mixed chunked prefill
|
||||
prefix_lens: List[int] = None
|
||||
extend_lens: List[int] = None
|
||||
@@ -560,7 +570,7 @@ class ScheduleBatch:
|
||||
# It is an empty list if logprob is not required.
|
||||
extend_input_logprob_token_ids: Optional[torch.Tensor] = None
|
||||
|
||||
# For encoder-decoder
|
||||
# For encoder-decoder architectures
|
||||
encoder_cached: Optional[List[bool]] = None
|
||||
encoder_lens: Optional[torch.Tensor] = None
|
||||
encoder_lens_cpu: Optional[List[int]] = None
|
||||
@@ -597,6 +607,8 @@ class ScheduleBatch:
|
||||
spec_algorithm: SpeculativeAlgorithm,
|
||||
enable_custom_logit_processor: bool,
|
||||
):
|
||||
return_logprob = any(req.return_logprob for req in reqs)
|
||||
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
@@ -604,7 +616,7 @@ class ScheduleBatch:
|
||||
tree_cache=tree_cache,
|
||||
model_config=model_config,
|
||||
enable_overlap=enable_overlap,
|
||||
return_logprob=any(req.return_logprob for req in reqs),
|
||||
return_logprob=return_logprob,
|
||||
has_stream=any(req.stream for req in reqs),
|
||||
has_grammar=any(req.grammar for req in reqs),
|
||||
device=req_to_token_pool.device,
|
||||
@@ -631,24 +643,83 @@ class ScheduleBatch:
|
||||
return req_pool_indices
|
||||
|
||||
def alloc_token_slots(self, num_tokens: int):
|
||||
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(num_tokens)
|
||||
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||
if out_cache_loc is None:
|
||||
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
||||
error_msg = (
f"{phase_str} out of memory. Try to lower your batch size.\n"
f"Try to allocate {num_tokens} tokens.\n"
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.pretty_print()
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
return out_cache_loc
|
||||
|
||||
def alloc_paged_token_slots_extend(
|
||||
self,
|
||||
prefix_lens: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
extend_num_tokens: int,
|
||||
):
|
||||
if (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
< extend_num_tokens
|
||||
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
||||
):
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(
|
||||
extend_num_tokens
|
||||
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
||||
)
|
||||
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
||||
prefix_lens, seq_lens, last_loc, extend_num_tokens
|
||||
)
|
||||
if out_cache_loc is None:
|
||||
error_msg = (
f"Prefill out of memory. Try to lower your batch size.\n"
f"Try to allocate {extend_num_tokens} tokens.\n"
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return out_cache_loc
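# Editor's note: an illustrative sketch, not part of this commit. The eviction
# check above appears to budget `extend_num_tokens + len(seq_lens) * page_size`
# free slots because, in the worst case, every request's extension ends on a
# freshly opened page that is only partially filled. Per-request page math
# (hypothetical helper, assumed values):
import math

def new_pages_needed(prefix_len: int, seq_len: int, page_size: int) -> int:
    # Pages that must come off the free list to grow prefix_len -> seq_len.
    return math.ceil(seq_len / page_size) - math.ceil(prefix_len / page_size)

# A 20-token prefix extended to 45 tokens with page_size=16 opens one new page
# (ceil(45/16) - ceil(20/16) = 3 - 2), using only 13 of its 16 slots.
assert new_pages_needed(20, 45, 16) == 1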
def alloc_paged_token_slots_decode(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
):
|
||||
if (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
||||
):
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(
|
||||
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
||||
)
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
|
||||
|
||||
if out_cache_loc is None:
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||
|
||||
if out_cache_loc is None:
|
||||
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
||||
logger.error(
f"{phase_str} out of memory. Try to lower your batch size.\n"
f"Try to allocate {num_tokens} tokens.\n"
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
)
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.pretty_print()
|
||||
exit(1)
|
||||
|
||||
error_msg = (
f"Decode out of memory. Try to lower your batch size.\n"
f"Try to allocate {len(seq_lens)} tokens.\n"
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return out_cache_loc
|
||||
|
||||
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
||||
@@ -699,7 +770,7 @@ class ScheduleBatch:
|
||||
pt += req.extend_input_len
|
||||
|
||||
# Reassign
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
@@ -707,14 +778,14 @@ class ScheduleBatch:
|
||||
)
|
||||
|
||||
if not decoder_out_cache_loc:
|
||||
self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
|
||||
self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
else:
|
||||
self.out_cache_loc = torch.cat(decoder_out_cache_loc)
|
||||
|
||||
if not encoder_out_cache_loc:
|
||||
self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
|
||||
self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
else:
|
||||
@@ -725,25 +796,38 @@ class ScheduleBatch:
|
||||
def prepare_for_extend(self):
|
||||
self.forward_mode = ForwardMode.EXTEND
|
||||
|
||||
# Allocate req slots
|
||||
bs = len(self.reqs)
|
||||
req_pool_indices = self.alloc_req_slots(bs)
|
||||
|
||||
# Init tensors
|
||||
reqs = self.reqs
|
||||
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
||||
seq_lens = []
|
||||
pre_lens = []
|
||||
seq_lens = [len(r.fill_ids) for r in reqs]
|
||||
prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||
extend_lens = [r.extend_input_len for r in reqs]
|
||||
|
||||
# Allocate memory
|
||||
req_pool_indices = self.alloc_req_slots(bs)
|
||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
prefix_lens_tensor = torch.tensor(
|
||||
prefix_lens, dtype=torch.int64, device=self.device
|
||||
)
|
||||
extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
|
||||
|
||||
# Copy prefix and do some basic check
|
||||
input_embeds = []
|
||||
extend_input_logprob_token_ids = []
|
||||
|
||||
pt = 0
|
||||
for i, req in enumerate(reqs):
|
||||
for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
|
||||
req.req_pool_idx = req_pool_indices[i]
|
||||
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
||||
seq_lens.append(seq_len)
|
||||
assert seq_len - pre_len == req.extend_input_len
|
||||
|
||||
if pre_len > 0:
|
||||
@@ -759,7 +843,7 @@ class ScheduleBatch:
|
||||
req.cached_tokens += pre_len - req.already_computed
|
||||
req.already_computed = seq_len
|
||||
req.is_retracted = False
|
||||
pre_lens.append(pre_len)
|
||||
|
||||
# Compute the relative logprob_start_len in an extend batch
|
||||
if req.logprob_start_len >= pre_len:
|
||||
req.extend_logprob_start_len = min(
|
||||
@@ -815,60 +899,62 @@ class ScheduleBatch:
|
||||
else:
|
||||
extend_input_logprob_token_ids = None
|
||||
|
||||
# Allocate memory
|
||||
if self.token_to_kv_pool_allocator.page_size == 1:
|
||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||
else:
|
||||
last_loc = get_last_loc(
|
||||
self.req_to_token_pool.req_to_token,
|
||||
req_pool_indices_tensor,
|
||||
prefix_lens_tensor,
|
||||
)
|
||||
out_cache_loc = self.alloc_paged_token_slots_extend(
|
||||
prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
|
||||
)
|
||||
|
||||
# Set fields
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.input_ids = input_ids_tensor
|
||||
self.req_pool_indices = req_pool_indices_tensor
|
||||
self.seq_lens = seq_lens_tensor
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.input_embeds = (
|
||||
torch.tensor(input_embeds).to(self.device, non_blocking=True)
|
||||
if input_embeds
|
||||
else None
|
||||
)
|
||||
|
||||
self.out_cache_loc = out_cache_loc
|
||||
|
||||
self.seq_lens_sum = sum(seq_lens)
|
||||
|
||||
if self.return_logprob:
|
||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||
|
||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.prefix_lens = prefix_lens
|
||||
self.extend_lens = extend_lens
|
||||
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
||||
|
||||
# Write to req_to_token_pool
|
||||
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
if global_server_args_dict["attention_backend"] != "torch_native":
|
||||
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
|
||||
|
||||
write_req_to_token_pool_triton[(bs,)](
|
||||
self.req_to_token_pool.req_to_token,
|
||||
self.req_pool_indices,
|
||||
pre_lens,
|
||||
self.seq_lens,
|
||||
extend_lens,
|
||||
self.out_cache_loc,
|
||||
req_pool_indices_tensor,
|
||||
prefix_lens_tensor,
|
||||
seq_lens_tensor,
|
||||
extend_lens_tensor,
|
||||
out_cache_loc,
|
||||
self.req_to_token_pool.req_to_token.shape[1],
|
||||
)
|
||||
else:
|
||||
pt = 0
|
||||
for i in range(bs):
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
|
||||
self.out_cache_loc[pt : pt + self.extend_lens[i]],
|
||||
(req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
|
||||
out_cache_loc[pt : pt + extend_lens[i]],
|
||||
)
|
||||
pt += self.extend_lens[i]
|
||||
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
|
||||
pt += extend_lens[i]
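# Editor's note: an illustrative sketch with assumed toy sizes, not part of this
# commit. Both write_req_to_token_pool_triton and the fallback loop above do the
# same thing: scatter each request's newly allocated KV slots into its row of
# req_to_token, at the positions covered by the extension.
import torch

req_to_token = torch.zeros((2, 8), dtype=torch.int64)
req_pool_indices = [0, 1]
prefix_lens = [2, 0]
seq_lens = [5, 3]
extend_lens = [3, 3]
out_cache_loc = torch.tensor([10, 11, 12, 20, 21, 22])  # flat, in request order

pt = 0
for i in range(2):
    req_to_token[req_pool_indices[i], prefix_lens[i] : seq_lens[i]] = out_cache_loc[
        pt : pt + extend_lens[i]
    ]
    pt += extend_lens[i]

print(req_to_token[0])  # tensor([ 0,  0, 10, 11, 12,  0,  0,  0])
print(req_to_token[1])  # tensor([20, 21, 22,  0,  0,  0,  0,  0])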
if self.model_config.is_encoder_decoder:
|
||||
self.prepare_encoder_info_extend(input_ids, seq_lens)
|
||||
@@ -914,7 +1000,7 @@ class ScheduleBatch:
|
||||
if self.token_to_kv_pool_allocator.available_size() >= bs:
|
||||
return True
|
||||
|
||||
self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)
|
||||
self.tree_cache.evict(bs)
|
||||
|
||||
if self.token_to_kv_pool_allocator.available_size() >= bs:
|
||||
return True
|
||||
@@ -939,10 +1025,6 @@ class ScheduleBatch:
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
|
||||
def get_required_tokens(num_reqs: int):
|
||||
headroom_for_spec_decode = 0
|
||||
if server_args.speculative_algorithm:
|
||||
@@ -956,6 +1038,9 @@ class ScheduleBatch:
|
||||
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
while (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
< get_required_tokens(len(sorted_indices))
|
||||
@@ -980,7 +1065,6 @@ class ScheduleBatch:
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free(token_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
del self.tree_cache.entries[req.rid]
|
||||
else:
|
||||
# TODO: apply more fine-grained retraction
|
||||
last_uncached_pos = len(req.prefix_indices)
|
||||
@@ -999,9 +1083,7 @@ class ScheduleBatch:
|
||||
- self.token_to_kv_pool_allocator.available_size()
|
||||
)
|
||||
residual_size = max(0, residual_size)
|
||||
self.tree_cache.evict(
|
||||
residual_size, self.token_to_kv_pool_allocator.free
|
||||
)
|
||||
self.tree_cache.evict(residual_size)
|
||||
|
||||
req.reset_for_retract()
|
||||
|
||||
@@ -1024,9 +1106,9 @@ class ScheduleBatch:
|
||||
|
||||
def prepare_for_idle(self):
|
||||
self.forward_mode = ForwardMode.IDLE
|
||||
self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.seq_lens_sum = 0
|
||||
self.extend_num_tokens = 0
|
||||
@@ -1037,6 +1119,8 @@ class ScheduleBatch:
|
||||
|
||||
def prepare_for_decode(self):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
bs = len(self.reqs)
|
||||
|
||||
if self.spec_algorithm.is_eagle():
|
||||
# if spec decoding is used, the decode batch is prepared inside
|
||||
# `forward_batch_speculative_generation` after running draft models.
|
||||
@@ -1065,33 +1149,39 @@ class ScheduleBatch:
|
||||
self.output_ids.to(torch.int64)
|
||||
)
|
||||
|
||||
# Update fields
|
||||
self.input_ids = self.output_ids
|
||||
self.output_ids = None
|
||||
|
||||
# Alloc mem
|
||||
bs = len(self.reqs)
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
locs = self.encoder_lens + self.seq_lens
|
||||
self.prepare_encoder_info_decode()
|
||||
else:
|
||||
locs = self.seq_lens
|
||||
locs = self.seq_lens.clone()
|
||||
|
||||
if self.enable_overlap:
|
||||
# Do not use in-place operations in the overlap mode
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, locs), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens = self.seq_lens + 1
|
||||
else:
|
||||
# A faster in-place version
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, locs), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens.add_(1)
|
||||
self.seq_lens_sum += bs
|
||||
|
||||
# Allocate memory
|
||||
if self.token_to_kv_pool_allocator.page_size == 1:
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
else:
|
||||
last_loc = self.req_to_token_pool.req_to_token[
|
||||
self.req_pool_indices, self.seq_lens - 2
|
||||
]
|
||||
self.out_cache_loc = self.alloc_paged_token_slots_decode(
|
||||
self.seq_lens, last_loc
|
||||
)
|
||||
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
|
||||
)
|
||||
|
||||
def filter_batch(
|
||||
self,
|
||||
chunked_req_to_exclude: Optional[Req] = None,
|
||||
@@ -1345,8 +1435,8 @@ def write_req_to_token_pool_triton(
|
||||
pre_len = tl.load(pre_lens + pid)
|
||||
seq_len = tl.load(seq_lens + pid)
|
||||
|
||||
# TODO: optimize this?
|
||||
cumsum_start = 0
|
||||
# NOTE: This can be slow for large bs
|
||||
cumsum_start = tl.cast(0, tl.int64)
|
||||
for i in range(pid):
|
||||
cumsum_start += tl.load(extend_lens + i)
|
||||
|
||||
@@ -1363,3 +1453,12 @@ def write_req_to_token_pool_triton(
|
||||
value,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
|
||||
return torch.where(
|
||||
prefix_lens_tensor > 0,
|
||||
req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
|
||||
torch.full_like(prefix_lens_tensor, -1),
|
||||
)
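# Editor's note: an illustrative sketch with assumed toy values, not part of
# this commit. get_last_loc returns, per request, the KV slot of the last
# cached prefix token (or -1 when nothing is cached); the paged allocators use
# it to keep appending inside a partially filled page.
import torch

req_to_token = torch.tensor([[7, 8, 9, 0], [3, 4, 5, 6]])  # [req_slots, max_len]
req_pool_indices = torch.tensor([0, 1])
prefix_lens = torch.tensor([3, 0])  # the second request has no cached prefix

last_loc = torch.where(
    prefix_lens > 0,
    req_to_token[req_pool_indices, prefix_lens - 1],
    torch.full_like(prefix_lens, -1),
)
print(last_loc)  # tensor([ 9, -1])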
@@ -77,7 +77,7 @@ class SchedulePolicy:
|
||||
self,
|
||||
policy: str,
|
||||
tree_cache: BasePrefixCache,
|
||||
enable_hierarchical_cache: bool = False,
|
||||
enable_hierarchical_cache: bool,
|
||||
):
|
||||
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
|
||||
self.tree_cache = tree_cache
|
||||
@@ -85,10 +85,17 @@ class SchedulePolicy:
|
||||
|
||||
# It is used to find the matching prefix for in-batch prefix caching.
|
||||
self.waiting_queue_radix_tree = RadixCache(
|
||||
req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
|
||||
req_to_token_pool=None,
|
||||
token_to_kv_pool_allocator=None,
|
||||
page_size=1,
|
||||
disable=False,
|
||||
)
|
||||
|
||||
def calc_priority(self, waiting_queue: List[Req]) -> bool:
|
||||
if self.policy == CacheAgnosticPolicy.FCFS:
|
||||
# A shortcut for FCFS
|
||||
return
|
||||
|
||||
policy = self._determine_active_policy(waiting_queue)
|
||||
|
||||
prefix_computed = False
|
||||
@@ -118,7 +125,7 @@ class SchedulePolicy:
|
||||
return prefix_computed
|
||||
|
||||
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
|
||||
if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
|
||||
if self.policy == CacheAwarePolicy.LPM and len(waiting_queue) > 128:
|
||||
# Turn off the expensive prefix matching and sorting when the #queue is large.
|
||||
return CacheAgnosticPolicy.FCFS
|
||||
return self.policy
|
||||
@@ -442,7 +449,7 @@ class PrefillAdder:
|
||||
def add_one_req(
|
||||
self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
|
||||
):
|
||||
if req.sampling_params.ignore_eos and self.tree_cache.disable:
|
||||
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
|
||||
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
||||
|
||||
total_tokens = req.extend_input_len + min(
|
||||
|
||||
@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
|
||||
crash_on_warnings,
|
||||
get_bool_env_var,
|
||||
get_zmq_socket,
|
||||
kill_itself_when_parent_died,
|
||||
pyspy_dump_schedulers,
|
||||
set_gpu_proc_affinity,
|
||||
set_random_seed,
|
||||
@@ -159,6 +160,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
)
|
||||
self.gpu_id = gpu_id
|
||||
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
|
||||
self.page_size = server_args.page_size
|
||||
|
||||
# Distributed rank info
|
||||
self.dp_size = server_args.dp_size
|
||||
@@ -265,20 +267,23 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
f"context_len={self.model_config.context_len}"
|
||||
)
|
||||
|
||||
# Init memory pool and cache
|
||||
self.init_memory_pool_and_cache()
|
||||
|
||||
# Init running status
|
||||
self.waiting_queue: List[Req] = []
|
||||
# The running decoding batch for continuous batching
|
||||
self.running_batch: Optional[ScheduleBatch] = None
|
||||
self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
|
||||
# The current forward batch
|
||||
self.cur_batch: Optional[ScheduleBatch] = None
|
||||
# The current forward batch
|
||||
# The last forward batch
|
||||
self.last_batch: Optional[ScheduleBatch] = None
|
||||
self.forward_ct = 0
|
||||
self.forward_ct_decode = 0
|
||||
self.num_generated_tokens = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.last_decode_stats_tic = time.time()
|
||||
self.last_prefill_stats_tic = time.time()
|
||||
self.return_health_check_ct = 0
|
||||
self.current_stream = torch.get_device_module(self.device).current_stream()
|
||||
if self.device == "cpu":
|
||||
@@ -307,7 +312,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
|
||||
# Init schedule policy and new token estimation
|
||||
self.policy = SchedulePolicy(
|
||||
self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
|
||||
self.schedule_policy,
|
||||
self.tree_cache,
|
||||
self.enable_hierarchical_cache,
|
||||
)
|
||||
assert (
|
||||
server_args.schedule_conservativeness >= 0
|
||||
@@ -327,11 +334,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
) / global_config.default_new_token_ratio_decay_steps
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
# Tell whether the current running batch is full so that we can skip
|
||||
# the check of whether to prefill new requests.
|
||||
# This is an optimization to reduce the overhead of the prefill check.
|
||||
self.batch_is_full = False
|
||||
|
||||
# Init watchdog thread
|
||||
self.watchdog_timeout = server_args.watchdog_timeout
|
||||
t = threading.Thread(target=self.watchdog_thread, daemon=True)
|
||||
@@ -437,6 +439,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
self.tree_cache = RadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
page_size=self.page_size,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
|
||||
@@ -458,6 +461,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
# The largest context length (prefill + generation) of a single request
|
||||
self._largest_prefill_decode_len: int = 0
|
||||
self.last_gen_throughput: float = 0.0
|
||||
self.last_input_throughput: float = 0.0
|
||||
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
|
||||
self.spec_num_total_accepted_tokens = 0
|
||||
self.spec_num_total_forward_ct = 0
|
||||
@@ -487,7 +491,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
else:
|
||||
# When the server is idle, so self-check and re-init some states
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
@@ -527,7 +531,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
)
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
elif batch is None:
|
||||
# When the server is idle, so self-check and re-init some states
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
@@ -588,7 +592,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
for recv_req in recv_reqs:
|
||||
# If it is a health check generation request and there are running requests, ignore it.
|
||||
if is_health_check_generate_req(recv_req) and (
|
||||
self.chunked_req is not None or self.running_batch is not None
|
||||
self.chunked_req is not None or not self.running_batch.is_empty()
|
||||
):
|
||||
self.return_health_check_ct += 1
|
||||
continue
|
||||
@@ -812,6 +816,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
can_run_list: List[Req],
|
||||
running_bs: int,
|
||||
):
|
||||
gap_latency = time.time() - self.last_prefill_stats_tic
|
||||
self.last_prefill_stats_tic = time.time()
|
||||
self.last_input_throughput = self.num_prefill_tokens / gap_latency
|
||||
self.num_prefill_tokens = 0
|
||||
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
@@ -847,7 +856,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
self.last_decode_stats_tic = time.time()
|
||||
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
||||
self.num_generated_tokens = 0
|
||||
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||
num_running_reqs = len(self.running_batch.reqs)
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
@@ -911,8 +920,10 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
)
|
||||
if memory_leak:
|
||||
msg = (
|
||||
"KV cache pool leak detected!"
|
||||
"KV cache pool leak detected! "
|
||||
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
)
|
||||
warnings.warn(msg)
|
||||
if crash_on_warnings():
|
||||
@@ -938,7 +949,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||
num_running_reqs = len(self.running_batch.reqs)
|
||||
self.stats.num_running_reqs = num_running_reqs
|
||||
self.stats.num_used_tokens = num_used
|
||||
self.stats.token_usage = num_used / self.max_total_num_tokens
|
||||
@@ -956,20 +967,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
||||
# chunked request keeps its rid but will get a new req_pool_idx
|
||||
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
||||
self.batch_is_full = False
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
# Filter batch
|
||||
last_bs = self.last_batch.batch_size()
|
||||
self.last_batch.filter_batch()
|
||||
if self.last_batch.batch_size() < last_bs:
|
||||
self.batch_is_full = False
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
# Merge the new batch into the running batch
|
||||
if not self.last_batch.is_empty():
|
||||
if self.running_batch is None:
|
||||
if self.running_batch.is_empty():
|
||||
self.running_batch = self.last_batch
|
||||
else:
|
||||
# merge running_batch with prefill batch
|
||||
# Merge running_batch with prefill batch
|
||||
self.running_batch.merge_batch(self.last_batch)
|
||||
|
||||
new_batch = self.get_new_batch_prefill()
|
||||
@@ -978,11 +989,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
ret = new_batch
|
||||
else:
|
||||
# Run decode
|
||||
if self.running_batch is None:
|
||||
ret = None
|
||||
else:
|
||||
if not self.running_batch.is_empty():
|
||||
self.running_batch = self.update_running_batch(self.running_batch)
|
||||
ret = self.running_batch
|
||||
ret = self.running_batch if not self.running_batch.is_empty() else None
|
||||
else:
|
||||
ret = None
|
||||
|
||||
# Handle DP attention
|
||||
if self.server_args.enable_dp_attention:
|
||||
@@ -997,13 +1008,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
|
||||
# Handle the cases where prefill is not allowed
|
||||
if (
|
||||
self.batch_is_full or len(self.waiting_queue) == 0
|
||||
self.running_batch.batch_is_full or len(self.waiting_queue) == 0
|
||||
) and self.chunked_req is None:
|
||||
return None
|
||||
|
||||
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||
running_bs = len(self.running_batch.reqs)
|
||||
if running_bs >= self.max_running_requests:
|
||||
self.batch_is_full = True
|
||||
self.running_batch.batch_is_full = True
|
||||
return None
|
||||
|
||||
if self.enable_hierarchical_cache:
|
||||
@@ -1025,17 +1036,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
running_bs if self.is_mixed_chunk else 0,
|
||||
)
|
||||
|
||||
is_chunked = self.chunked_req is not None
|
||||
if is_chunked:
|
||||
if self.chunked_req is not None:
|
||||
self.chunked_req.init_next_round_input()
|
||||
self.chunked_req = adder.add_chunked_req(self.chunked_req)
|
||||
|
||||
if self.lora_paths:
|
||||
lora_set = (
|
||||
set([req.lora_path for req in self.running_batch.reqs])
|
||||
if self.running_batch is not None
|
||||
else set([])
|
||||
)
|
||||
lora_set = set([req.lora_path for req in self.running_batch.reqs])
|
||||
|
||||
# Get requests from the waiting queue to a new prefill batch
|
||||
for req in self.waiting_queue:
|
||||
if (
|
||||
@@ -1047,11 +1054,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
)
|
||||
> self.max_loras_per_batch
|
||||
):
|
||||
self.batch_is_full = True
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
if running_bs + len(adder.can_run_list) >= self.max_running_requests:
|
||||
self.batch_is_full = True
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
req.init_next_round_input(
|
||||
@@ -1066,12 +1073,14 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
if res == AddReqResult.NO_TOKEN:
|
||||
if self.enable_hierarchical_cache:
|
||||
# Set batch_is_full after making sure there are requests that can be served
|
||||
self.batch_is_full = len(adder.can_run_list) > 0 or (
|
||||
self.running_batch.batch_is_full = len(
|
||||
adder.can_run_list
|
||||
) > 0 or (
|
||||
self.running_batch is not None
|
||||
and not self.running_batch.is_empty()
|
||||
)
|
||||
else:
|
||||
self.batch_is_full = True
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
# Update waiting queue
|
||||
@@ -1112,7 +1121,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
# Mixed-style chunked prefill
|
||||
if (
|
||||
self.is_mixed_chunk
|
||||
and self.running_batch is not None
|
||||
and not self.running_batch.is_empty()
|
||||
and not (new_batch.return_logprob or self.running_batch.return_logprob)
|
||||
):
|
||||
# TODO (lianmin): support return_logprob + mixed chunked prefill
|
||||
@@ -1121,7 +1130,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
self.running_batch.prepare_for_decode()
|
||||
new_batch.mix_with_running(self.running_batch)
|
||||
new_batch.decoding_reqs = self.running_batch.reqs
|
||||
self.running_batch = None
|
||||
self.running_batch = ScheduleBatch(
|
||||
reqs=[], batch_is_full=self.running_batch.batch_is_full
|
||||
)
|
||||
else:
|
||||
new_batch.decoding_reqs = None
|
||||
|
||||
@@ -1133,8 +1144,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
|
||||
batch.filter_batch()
|
||||
if batch.is_empty():
|
||||
self.batch_is_full = False
|
||||
return None
|
||||
batch.batch_is_full = False
|
||||
return batch
|
||||
|
||||
# Check if decode out of memory
|
||||
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
|
||||
@@ -1158,7 +1169,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
)
|
||||
|
||||
if batch.batch_size() < initial_bs:
|
||||
self.batch_is_full = False
|
||||
batch.batch_is_full = False
|
||||
|
||||
# Update batch tensors
|
||||
batch.prepare_for_decode()
|
||||
@@ -1233,8 +1244,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
):
|
||||
if batch.forward_mode.is_decode():
|
||||
self.process_batch_result_decode(batch, result)
|
||||
if batch.is_empty():
|
||||
self.running_batch = None
|
||||
elif batch.forward_mode.is_extend():
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
elif batch.forward_mode.is_idle():
|
||||
@@ -1375,9 +1384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
|
||||
def flush_cache(self):
|
||||
"""Flush the memory pool and cache."""
|
||||
if len(self.waiting_queue) == 0 and (
|
||||
self.running_batch is None or len(self.running_batch.reqs) == 0
|
||||
):
|
||||
if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
|
||||
self.cur_batch = None
|
||||
self.last_batch = None
|
||||
self.tree_cache.reset()
|
||||
@@ -1403,7 +1410,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
logging.warning(
|
||||
f"Cache not flushed because there are pending requests. "
|
||||
f"#queue-req: {len(self.waiting_queue)}, "
|
||||
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
||||
f"#running-req: {len(self.running_batch.reqs)}"
|
||||
)
|
||||
if_success = False
|
||||
return if_success
|
||||
@@ -1453,24 +1460,24 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
|
||||
def abort_request(self, recv_req: AbortReq):
|
||||
# Delete requests in the waiting queue
|
||||
to_del = None
|
||||
to_del = []
|
||||
for i, req in enumerate(self.waiting_queue):
|
||||
if req.rid == recv_req.rid:
|
||||
to_del = i
|
||||
if req.rid.startswith(recv_req.rid):
|
||||
to_del.append(i)
|
||||
break
|
||||
|
||||
if to_del is not None:
|
||||
del self.waiting_queue[to_del]
|
||||
# Sort in reverse order to avoid index issues when deleting
|
||||
for i in sorted(to_del, reverse=True):
|
||||
req = self.waiting_queue.pop(i)
|
||||
logger.debug(f"Abort queued request. {req.rid=}")
|
||||
return
|
||||
|
||||
# Delete requests in the running batch
|
||||
if self.running_batch:
|
||||
for req in self.running_batch.reqs:
|
||||
if req.rid == recv_req.rid and not req.finished():
|
||||
logger.debug(f"Abort running request. {req.rid=}")
|
||||
req.to_abort = True
|
||||
break
|
||||
for req in self.running_batch.reqs:
|
||||
if req.rid.startswith(recv_req.rid) and not req.finished():
|
||||
logger.debug(f"Abort running request. {req.rid=}")
|
||||
req.to_abort = True
|
||||
return
|
||||
|
||||
def _pause_engine(self) -> Tuple[List[Req], int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -204,8 +204,17 @@ class SchedulerOutputProcessorMixin:
|
||||
continue
|
||||
|
||||
if self.enable_overlap and req.finished():
|
||||
# Free the one delayed token
|
||||
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
|
||||
# Free the one extra delayed token
|
||||
if self.page_size == 1:
|
||||
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
|
||||
else:
|
||||
# Only free when the extra token is in a new page
|
||||
if (
|
||||
len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||
) % self.page_size == 0:
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
batch.out_cache_loc[i : i + 1]
|
||||
)
|
||||
continue
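# Editor's note: an illustrative sketch, not part of this commit. In overlap
# mode one extra KV slot is allocated past the final token; with paging it is
# only returned to the allocator when it landed at the start of a new page,
# since otherwise it sits inside a page the finished request already owns.
# The condition above, in isolation (hypothetical helper, assumed values):
def delayed_slot_starts_new_page(num_input_tokens: int, num_output_tokens: int,
                                 page_size: int) -> bool:
    # True when the slot after the last committed token is the first of a page.
    return (num_input_tokens + num_output_tokens - 1) % page_size == 0

# 30 prompt tokens + 3 generated with page_size=16: (30 + 3 - 1) % 16 == 0, so
# the delayed slot opened a fresh page and can be freed immediately.
assert delayed_slot_starts_new_page(30, 3, 16) is True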
if batch.spec_algorithm.is_none():
|
||||
|
||||
@@ -103,6 +103,9 @@ class TpModelWorkerClient:
|
||||
self.worker.model_runner.token_to_kv_pool_allocator,
|
||||
)
|
||||
|
||||
def get_kv_cache(self):
|
||||
return self.worker.model_runner.token_to_kv_pool
|
||||
|
||||
def forward_thread_func(self):
|
||||
try:
|
||||
with torch.get_device_module(self.device).stream(self.forward_stream):
|
||||
@@ -203,7 +206,7 @@ class TpModelWorkerClient:
|
||||
-(self.future_token_ids_ct + 1),
|
||||
-(self.future_token_ids_ct + 1 + bs),
|
||||
-1,
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
self.future_token_ids_ct = (
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List, Tuple
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
|
||||
class BasePrefixCache(ABC):
|
||||
@@ -26,24 +26,22 @@ class BasePrefixCache(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evict(self, num_tokens: int, evict_callback: Callable):
|
||||
def evict(self, num_tokens: int):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def inc_lock_ref(self, node):
|
||||
def inc_lock_ref(self, node: Any):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def dec_lock_ref(self, node):
|
||||
def dec_lock_ref(self, node: Any):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evictable_size(self):
|
||||
pass
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
def protected_size(self):
|
||||
raise NotImplementedError()
|
||||
return 0
|
||||
|
||||
def total_size(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -24,73 +25,40 @@ class ChunkCache(BasePrefixCache):
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
):
|
||||
self.disable = True
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.entries: Dict[str, ChunkCacheEntry] = {}
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.entries = {}
|
||||
pass
|
||||
|
||||
def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
|
||||
if rid not in self.entries:
|
||||
return [], None
|
||||
|
||||
entry = self.entries[rid]
|
||||
max_prefix_len = len(key)
|
||||
return entry.value[:max_prefix_len], entry
|
||||
|
||||
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
||||
if token_ids is None:
|
||||
token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||
else:
|
||||
token_id_len = len(token_ids)
|
||||
def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
|
||||
return [], None
|
||||
|
||||
def cache_finished_req(self, req: Req):
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, :token_id_len
|
||||
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||
]
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||
|
||||
if req.rid in self.entries:
|
||||
del self.entries[req.rid]
|
||||
|
||||
def cache_unfinished_req(self, req: Req):
|
||||
token_id_len = len(req.fill_ids)
|
||||
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, :token_id_len
|
||||
req.req_pool_idx, : len(req.fill_ids)
|
||||
]
|
||||
|
||||
if req.rid not in self.entries:
|
||||
self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
|
||||
|
||||
entry = self.entries[req.rid]
|
||||
entry.value = kv_indices
|
||||
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
||||
req.prefix_indices = kv_indices
|
||||
req.last_node = entry
|
||||
|
||||
def insert(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def evict(self, num_tokens: int, evict_callback: Callable):
|
||||
def evict(self, num_tokens: int):
|
||||
pass
|
||||
|
||||
def inc_lock_ref(self, node):
|
||||
def inc_lock_ref(self, node: Any):
|
||||
return 0
|
||||
|
||||
def dec_lock_ref(self, node):
|
||||
return 0
|
||||
|
||||
def evictable_size(self):
|
||||
return 0
|
||||
|
||||
def pretty_print(self):
|
||||
return ""
|
||||
|
||||
def protected_size(self):
|
||||
def dec_lock_ref(self, node: Any):
|
||||
return 0
|
||||
|
||||
def pretty_print(self):
|
||||
|
||||
@@ -7,13 +7,13 @@ from typing import List, Optional
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPoolHost,
|
||||
ReqToTokenPool,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -122,7 +122,7 @@ class HiRadixCache(RadixCache):
|
||||
def evictable_size(self):
|
||||
return self.evictable_size_
|
||||
|
||||
def evict(self, num_tokens: int, evict_callback=None):
|
||||
def evict(self, num_tokens: int):
|
||||
leaves = self._collect_leaves_device()
|
||||
heapq.heapify(leaves)
|
||||
|
||||
|
||||
@@ -129,6 +129,7 @@ class TokenToKVPoolAllocator:
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.page_size = 1
|
||||
|
||||
self.free_slots = None
|
||||
self.is_not_in_free_group = True
|
||||
@@ -149,15 +150,14 @@ class TokenToKVPoolAllocator:
|
||||
|
||||
select_index = self.free_slots[:need_size]
|
||||
self.free_slots = self.free_slots[need_size:]
|
||||
|
||||
return select_index.to(self.device, non_blocking=True)
|
||||
return select_index
|
||||
|
||||
def free(self, free_index: torch.Tensor):
|
||||
if free_index.numel() == 0:
|
||||
return
|
||||
|
||||
if self.is_not_in_free_group:
|
||||
self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
|
||||
self.free_slots = torch.concat((self.free_slots, free_index))
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
@@ -172,7 +172,9 @@ class TokenToKVPoolAllocator:
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
|
||||
self.free_slots = torch.arange(
|
||||
1, self.size + 1, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.is_in_free_group = False
|
||||
self.free_group = []
|
||||
|
||||
@@ -182,6 +184,7 @@ class MHATokenToKVPool(KVCache):
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
page_size: int,
|
||||
dtype: torch.dtype,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
@@ -190,6 +193,7 @@ class MHATokenToKVPool(KVCache):
|
||||
enable_memory_saver: bool,
|
||||
):
|
||||
self.size = size
|
||||
self.page_size = page_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
@@ -207,6 +211,8 @@ class MHATokenToKVPool(KVCache):
|
||||
self._create_buffers()
|
||||
|
||||
self.layer_transfer_counter = None
|
||||
self.capture_mode = False
|
||||
self.alt_stream = torch.cuda.Stream()
|
||||
|
||||
k_size, v_size = self.get_kv_size_bytes()
|
||||
logger.info(
|
||||
@@ -218,16 +224,16 @@ class MHATokenToKVPool(KVCache):
|
||||
# [size, head_num, head_dim] for each layer
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.k_buffer = [
|
||||
torch.empty(
|
||||
(self.size + 1, self.head_num, self.head_dim),
|
||||
torch.zeros(
|
||||
(self.size + self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.empty(
|
||||
(self.size + 1, self.head_num, self.head_dim),
|
||||
torch.zeros(
|
||||
(self.size + self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
@@ -315,14 +321,44 @@ class MHATokenToKVPool(KVCache):
|
||||
cache_v.div_(v_scale)
|
||||
cache_k = cache_k.to(self.dtype)
|
||||
cache_v = cache_v.to(self.dtype)
|
||||
|
||||
if self.store_dtype != self.dtype:
|
||||
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
||||
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
|
||||
cache_k = cache_k.view(self.store_dtype)
|
||||
cache_v = cache_v.view(self.store_dtype)
|
||||
|
||||
if self.capture_mode:
|
||||
self.alt_stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(self.alt_stream):
|
||||
self.k_buffer[layer_id][loc] = cache_k
|
||||
self.v_buffer[layer_id][loc] = cache_v
|
||||
torch.cuda.current_stream().wait_stream(self.alt_stream)
|
||||
else:
|
||||
self.k_buffer[layer_id][loc] = cache_k
|
||||
self.v_buffer[layer_id][loc] = cache_v
|
||||
|
||||
|
||||
@torch.compile
|
||||
def fused_downcast(
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
store_dtype: torch.dtype,
|
||||
max_fp8: float,
|
||||
min_fp8: float,
|
||||
):
|
||||
cache_k = cache_k / k_scale
|
||||
cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
|
||||
cache_v = cache_v / v_scale
|
||||
cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
|
||||
cache_k = cache_k.to(dtype)
|
||||
cache_v = cache_v.to(dtype)
|
||||
cache_k = cache_k.view(store_dtype)
|
||||
cache_v = cache_v.view(store_dtype)
|
||||
return cache_k, cache_v
|
||||
|
||||
|
||||
# This compiled version is slower in the unit test
|
||||
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
@@ -335,6 +371,7 @@ class MLATokenToKVPool(KVCache):
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
page_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
@@ -359,8 +396,8 @@ class MLATokenToKVPool(KVCache):
|
||||
with memory_saver_adapter.region():
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.kv_buffer = [
|
||||
torch.empty(
|
||||
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
|
||||
torch.zeros(
|
||||
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=device,
|
||||
)
|
||||
@@ -400,6 +437,7 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
page_size: int,
|
||||
dtype: torch.dtype,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
@@ -409,6 +447,7 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
enable_memory_saver: bool,
|
||||
):
|
||||
self.size = size
|
||||
self.page_size = page_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
@@ -423,17 +462,21 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
with memory_saver_adapter.region():
|
||||
# [size, head_num, head_dim] for each layer
|
||||
self.k_buffer = [
|
||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
|
||||
torch.zeros(
|
||||
(size + page_size, head_num, head_dim), dtype=dtype, device=device
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
|
||||
torch.zeros(
|
||||
(size + page_size, head_num, head_dim), dtype=dtype, device=device
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
# [size, head_num, heavy_channel_num] for each layer
|
||||
self.label_buffer = [
|
||||
torch.empty(
|
||||
torch.zeros(
|
||||
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
@@ -528,7 +571,7 @@ class MHATokenToKVPoolHost:
|
||||
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
|
||||
)
|
||||
|
||||
self.kv_buffer = torch.empty(
|
||||
self.kv_buffer = torch.zeros(
|
||||
(2, self.layer_num, self.size, self.head_num, self.head_dim),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
@@ -548,9 +591,6 @@ class MHATokenToKVPoolHost:
|
||||
def get_flat_data(self, indices):
|
||||
return self.kv_buffer[:, :, indices]
|
||||
|
||||
def get_flat_data_by_layer(self, indices, layer_id):
|
||||
return self.kv_buffer[:, layer_id, indices]
|
||||
|
||||
def assign_flat_data(self, indices, flat_data):
|
||||
self.kv_buffer[:, :, indices] = flat_data
|
||||
python/sglang/srt/mem_cache/paged_allocator.py (new file, 283 lines)
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Copyright 2025 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
Page-aligned memory pool.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
||||
|
||||
|
||||
@triton.jit
|
||||
def alloc_extend_kernel(
|
||||
pre_lens_ptr,
|
||||
seq_lens_ptr,
|
||||
last_loc_ptr,
|
||||
free_page_ptr,
|
||||
out_indices,
|
||||
ret_values,
|
||||
bs_upper: tl.constexpr,
|
||||
page_size: tl.constexpr,
|
||||
max_num_extend_tokens: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
load_offset = tl.arange(0, bs_upper)
|
||||
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
|
||||
pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid)
|
||||
extend_lens = seq_lens - pre_lens
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + pid)
|
||||
pre_len = tl.load(pre_lens_ptr + pid)
|
||||
extend_len = seq_len - pre_len
|
||||
|
||||
sum_extend_lens = tl.sum(extend_lens)
|
||||
output_start_loc = sum_extend_lens - extend_len
|
||||
|
||||
num_pages_after = (seq_lens + page_size - 1) // page_size
|
||||
num_pages_before = (pre_lens + page_size - 1) // page_size
|
||||
num_new_pages = num_pages_after - num_pages_before
|
||||
|
||||
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
|
||||
pre_len + page_size - 1
|
||||
) // page_size
|
||||
sum_num_new_pages = tl.sum(num_new_pages)
|
||||
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
|
||||
|
||||
# Return value
|
||||
if pid == tl.num_programs(0) - 1:
|
||||
merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
|
||||
tl.int64
|
||||
)
|
||||
tl.store(ret_values, merged_value)
|
||||
|
||||
# Part 1: fill the old partial page
|
||||
last_loc = tl.load(last_loc_ptr + pid)
|
||||
num_part1 = (
|
||||
min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
|
||||
)
|
||||
offset_one_page = tl.arange(0, page_size)
|
||||
tl.store(
|
||||
out_indices + output_start_loc + offset_one_page,
|
||||
last_loc + 1 + offset_one_page,
|
||||
mask=offset_one_page < num_part1,
|
||||
)
|
||||
if pre_len + num_part1 == seq_len:
|
||||
return
|
||||
|
||||
# Part 2: fill the new full pages
|
||||
num_part2 = (
|
||||
seq_len // page_size * page_size
|
||||
- (pre_len + page_size - 1) // page_size * page_size
|
||||
)
|
||||
|
||||
offset_many_page = tl.arange(0, max_num_extend_tokens)
|
||||
page_start = tl.load(
|
||||
free_page_ptr + new_page_start_loc + offset_many_page // page_size,
|
||||
mask=offset_many_page < num_part2,
|
||||
)
|
||||
tl.store(
|
||||
out_indices + output_start_loc + num_part1 + offset_many_page,
|
||||
page_start * page_size + offset_many_page % page_size,
|
||||
mask=offset_many_page < num_part2,
|
||||
)
|
||||
if pre_len + num_part1 + num_part2 == seq_len:
|
||||
return
|
||||
|
||||
# Part 3: fill the new partial page
|
||||
num_part3 = seq_len - seq_len // page_size * page_size
|
||||
start_loc = tl.load(
|
||||
free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
|
||||
)
|
||||
tl.store(
|
||||
out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
|
||||
start_loc * page_size + offset_one_page,
|
||||
mask=offset_one_page < num_part3,
|
||||
)
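# Editor's note: an illustrative sketch, not part of this commit. A pure-Python
# reference for what alloc_extend_kernel computes for one request, given the
# ordered list of page indices granted to it (the kernel reads these from the
# shared free-page array at a per-request offset).
import math

def alloc_extend_one_request(pre_len, seq_len, last_loc, new_pages, page_size):
    out = []
    # Part 1: finish the partially filled last page of the cached prefix.
    part1 = min(seq_len, math.ceil(pre_len / page_size) * page_size) - pre_len
    out += [last_loc + 1 + i for i in range(part1)]
    if pre_len + part1 == seq_len:
        return out
    # Part 2: fill whole new pages.
    part2 = seq_len // page_size * page_size - math.ceil(pre_len / page_size) * page_size
    out += [new_pages[i // page_size] * page_size + i % page_size for i in range(part2)]
    if pre_len + part1 + part2 == seq_len:
        return out
    # Part 3: start a trailing partial page on the request's last new page.
    out += [new_pages[-1] * page_size + i for i in range(seq_len % page_size)]
    return out

# Assumed values: a 5-token prefix (last slot 20) extended to 15 tokens with
# page_size=4 and new pages 9 and 12.
assert alloc_extend_one_request(5, 15, 20, [9, 12], 4) == [
    21, 22, 23, 36, 37, 38, 39, 48, 49, 50
]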

@triton.jit
def alloc_decode_kernel(
    seq_lens_ptr,
    last_loc_ptr,
    free_page_ptr,
    out_indices,
    ret_values,
    bs_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    pid = tl.program_id(0)

    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
    pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = seq_len - 1

    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value
    if pid == tl.num_programs(0) - 1:
        tl.store(ret_values, sum_num_new_pages)

    if num_page_start_loc_self == 0:
        last_loc = tl.load(last_loc_ptr + pid)
        tl.store(out_indices + pid, last_loc + 1)
    else:
        page = tl.load(free_page_ptr + new_page_start_loc)
        tl.store(out_indices + pid, page * page_size)

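
# Quick check of the decode-allocation rule above (illustration only, not part
# of this commit): a request needs a fresh page exactly when its new length
# crosses a page boundary.
def needs_new_page(seq_len: int, page_size: int) -> bool:
    pre_len = seq_len - 1
    return (seq_len + page_size - 1) // page_size > (pre_len + page_size - 1) // page_size


assert needs_new_page(9, 4)       # growing from 8 to 9 tokens starts a new page
assert not needs_new_page(10, 4)  # growing from 9 to 10 tokens stays in the same page
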

class PagedTokenToKVPoolAllocator:
    """
    An allocator managing the indices to kv cache data.

    This class has the same interface as `TokenToKVPoolAllocator` but the output
    of one request is always page-aligned.

    TODO: fuse last_loc into the kernel.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        self.page_size = page_size
        self.num_pages = size // page_size

        self.free_pages = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()
        self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")

        self._kvcache = kvcache
        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)

    def available_size(self):
        return len(self.free_pages) * self.page_size

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        bs = len(prefix_lens)
        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int64, device=self.device
        )
        alloc_extend_kernel[(bs,)](
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.ret_values,
            next_power_of_2(bs),
            self.page_size,
            next_power_of_2(extend_num_tokens),
        )

        merged_value = self.ret_values.item()
        num_new_pages = merged_value >> 32
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        bs = len(seq_lens)
        out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
        alloc_decode_kernel[(bs,)](
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.ret_values,
            next_power_of_2(bs),
            self.page_size,
        )

        num_new_pages = self.ret_values.item()
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def free(self, free_index: torch.Tensor):
        if free_index.numel() == 0:
            return

        if self.is_not_in_free_group:
            free_page_indices = torch.unique(free_index // self.page_size)
            self.free_pages = torch.cat((free_page_indices, self.free_pages))
        else:
            self.free_group.append(free_index)

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.concat(self.free_group))

    def clear(self):
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_pages = torch.arange(
            1, self.num_pages + 1, dtype=torch.int64, device=self.device
        )
        self.is_in_free_group = False
        self.free_group = []
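
# alloc_extend() above reads back a single int64 scalar in which the extend
# kernel packed two 32-bit counters, so only one device-to-host sync is needed.
# Pure-Python sketch of that packing (illustration only, not part of this commit):
def pack_ret_value(num_new_pages: int, sum_extend_lens: int) -> int:
    return (num_new_pages << 32) | sum_extend_lens


def unpack_ret_value(merged_value: int):
    return merged_value >> 32, merged_value & 0xFFFFFFFF


assert unpack_ret_value(pack_ret_value(3, 70)) == (3, 70)
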
@@ -22,7 +22,8 @@ The radix tree data structure for managing the KV cache.
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch

@@ -67,7 +68,7 @@ class TreeNode:
        return self.last_access_time < other.last_access_time


def _key_match(key0: List, key1: List):
def _key_match_page_size1(key0: List, key1: List):
    i = 0
    for k0, k1 in zip(key0, key1):
        if k0 != k1:
@@ -76,16 +77,42 @@ def _key_match(key0: List, key1: List):
    return i


def _key_match_paged(key0: List, key1: List, page_size: int):
    min_len = min(len(key0), len(key1))

    i = 0
    while i < min_len:
        if key0[i : i + page_size] != key1[i : i + page_size]:
            break
        i += page_size

    return i

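
# Behavior sketch (illustration only): the paged matcher only credits whole
# pages, so a prefix that diverges inside a page is not reused.
#   _key_match_page_size1([1, 2, 3, 9], [1, 2, 3, 4])                -> 3
#   _key_match_paged([1, 2, 3, 9], [1, 2, 3, 4], page_size=4)        -> 0
#   _key_match_paged([1, 2, 3, 4, 5], [1, 2, 3, 4, 9], page_size=4)  -> 4
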

class RadixCache(BasePrefixCache):
    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
    ):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size
        self.disable = disable

        if self.token_to_kv_pool_allocator:
            self.device = self.token_to_kv_pool_allocator.device
        else:
            self.device = torch.device("cpu")

        if self.page_size == 1:
            self.key_match_fn = _key_match_page_size1
            self.get_child_key_fn = lambda key: key[0]
        else:
            self.key_match_fn = partial(_key_match_paged, page_size=page_size)
            self.get_child_key_fn = lambda key: tuple(key[:page_size])
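        # Child-key sketch (illustration only): with page_size = 4 the edge to a
        # child node is keyed by its first page of token ids, e.g.
        #     self.get_child_key_fn([5, 6, 7, 8, 9, 10]) -> (5, 6, 7, 8)
        # while with page_size = 1 it is keyed by the first token alone.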
        self.reset()

    ##### Public API #####
@@ -109,14 +136,25 @@ class RadixCache(BasePrefixCache):
        The last node creates a new child if the prefix is shorter
        than the last node's value.
        """
        if self.disable:
            return [], self.root_node
        if self.disable or len(key) == 0:
            return (
                torch.empty(
                    (0,),
                    dtype=torch.int32,
                    device=self.device,
                ),
                self.root_node,
            )

        if self.page_size != 1:
            page_aligned_len = len(key) // self.page_size * self.page_size
            key = key[:page_aligned_len]
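            # E.g. (illustration only): with page_size = 4, a 10-token key is
            # truncated to its first 8 tokens before matching, since only whole
            # cached pages can be reused as a shared prefix.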

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.concat(value)
        else:
            value = torch.tensor([], dtype=torch.int32)
            value = torch.empty((0,), dtype=torch.int32, device=self.device)
        return value, last_node

    def insert(self, key: List, value=None):
@@ -127,29 +165,33 @@ class RadixCache(BasePrefixCache):
            value = [x for x in key]
        return self._insert_helper(self.root_node, key, value)

    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
    def cache_finished_req(self, req: Req):
        """Cache request when it finishes."""
        if self.disable:
            if token_ids is None:
                token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
            else:
                token_ids_len = len(token_ids)

            kv_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, :token_ids_len
                req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
            ]
            self.token_to_kv_pool_allocator.free(kv_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
            return

        if token_ids is None:
            token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        if self.page_size != 1:
            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
            self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
        else:
            page_aligned_len = len(kv_indices)
            page_aligned_kv_indices = kv_indices.clone()

        # Radix Cache takes one ref in memory pool
        new_prefix_len = self.insert(token_ids, kv_indices.clone())
        new_prefix_len = self.insert(
            token_ids[:page_aligned_len], page_aligned_kv_indices
        )
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )
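        # E.g. (illustration only): with page_size = 4 and 10 finished tokens,
        # only the first 8 KV indices are inserted into the tree; the 2 trailing
        # (unaligned) indices are freed back to the allocator immediately.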
@@ -158,27 +200,32 @@ class RadixCache(BasePrefixCache):
        self.req_to_token_pool.free(req.req_pool_idx)
        self.dec_lock_ref(req.last_node)

    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
    def cache_unfinished_req(self, req: Req):
        """Cache request when it is unfinished."""
        if self.disable:
            return

        if token_ids is None:
            token_ids = req.fill_ids

        token_ids = req.fill_ids
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        if self.page_size != 1:
            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
        else:
            page_aligned_len = len(kv_indices)
            page_aligned_kv_indices = kv_indices.clone()
        page_aligned_token_ids = token_ids[:page_aligned_len]

        # Radix Cache takes one ref in memory pool
        new_prefix_len = self.insert(token_ids, kv_indices.clone())
        new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )

        # The prefix indices could be updated, reuse it
        new_indices, new_last_node = self.match_prefix(token_ids)
        assert len(new_indices) == len(token_ids)
        new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
        self.req_to_token_pool.write(
            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
            new_indices[len(req.prefix_indices) :],
@@ -186,7 +233,14 @@ class RadixCache(BasePrefixCache):

        self.dec_lock_ref(req.last_node)
        self.inc_lock_ref(new_last_node)
        req.prefix_indices = new_indices

        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        if self.page_size != 1:
            req.prefix_indices = torch.cat(
                [new_indices, kv_indices[len(new_indices) :]]
            )
        else:
            req.prefix_indices = new_indices
        req.last_node = new_last_node

    def pretty_print(self):
@@ -196,7 +250,7 @@ class RadixCache(BasePrefixCache):
    def total_size(self):
        return self._total_size_helper()

    def evict(self, num_tokens: int, evict_callback: Callable):
    def evict(self, num_tokens: int):
        if self.disable:
            return

@@ -212,7 +266,7 @@ class RadixCache(BasePrefixCache):
            if x.lock_ref > 0:
                continue

            evict_callback(x.value)
            self.token_to_kv_pool_allocator.free(x.value)
            num_evicted += len(x.value)
            self._delete_leaf(x)

@@ -254,15 +308,29 @@ class RadixCache(BasePrefixCache):
        # protected size refers to the size of the cache that is locked
        return self.protected_size_

    def all_values_flatten(self):
        values = []

        def _dfs_helper(node: TreeNode):
            for _, child in node.children.items():
                values.append(child.value)
                _dfs_helper(child)

        _dfs_helper(self.root_node)
        return torch.concat(values)

    ##### Internal Helper Functions #####

    def _match_prefix_helper(self, node: TreeNode, key: List):
        node.last_access_time = time.time()

        child_key = self.get_child_key_fn(key)

        value = []
        while len(key) > 0 and key[0] in node.children.keys():
            child = node.children[key[0]]
        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
            child.last_access_time = time.time()
            prefix_len = _key_match(child.key, key)
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                value.append(new_node.value)
@@ -272,12 +340,16 @@ class RadixCache(BasePrefixCache):
                value.append(child.value)
                node = child
            key = key[prefix_len:]

            if len(key):
                child_key = self.get_child_key_fn(key)

        return value, node

    def _split_node(self, key, child: TreeNode, split_len: int):
        # new_node -> child
        new_node = TreeNode()
        new_node.children = {key[split_len]: child}
        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
@@ -285,7 +357,7 @@ class RadixCache(BasePrefixCache):
        child.parent = new_node
        child.key = child.key[split_len:]
        child.value = child.value[split_len:]
        new_node.parent.children[key[0]] = new_node
        new_node.parent.children[self.get_child_key_fn(key)] = new_node
        return new_node

    def _insert_helper(self, node: TreeNode, key: List, value):
@@ -293,11 +365,13 @@ class RadixCache(BasePrefixCache):
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        total_prefix_length = 0
        while len(key) > 0 and key[0] in node.children.keys():
            node = node.children[key[0]]
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.time()
            prefix_len = _key_match(node.key, key)
            prefix_len = self.key_match_fn(node.key, key)
            total_prefix_length += prefix_len
            key = key[prefix_len:]
            value = value[prefix_len:]
@@ -306,12 +380,15 @@ class RadixCache(BasePrefixCache):
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[key[0]] = new_node
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)
        return total_prefix_length

@@ -326,9 +403,13 @@ class RadixCache(BasePrefixCache):
                current_node.key[:10],
                f"r={current_node.lock_ref}",
            )
            for _, child in current_node.children.items():
            for key, child in current_node.children.items():
                stack.append((child, current_indent + 2))

                assert key == self.get_child_key_fn(
                    child.key
                ), f"{key=}, {self.get_child_key_fn(child.key)=}"

    def _delete_leaf(self, node):
        for k, v in node.parent.children.items():
            if v == node:
@@ -363,7 +444,7 @@ class RadixCache(BasePrefixCache):


if __name__ == "__main__":
    tree = RadixCache(None, None, False)
    tree = RadixCache(None, None, page_size=1, disable=False)

    tree.insert("Hello")
    tree.insert("Hello")

@@ -264,11 +264,15 @@ class CudaGraphRunner:
    def model_capture_mode(self):
        if hasattr(self.model_runner.model, "capture_mode"):
            self.model_runner.model.capture_mode = True
        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
            self.model_runner.token_to_kv_pool.capture_mode = True

        yield

        if hasattr(self.model_runner.model, "capture_mode"):
            self.model_runner.model.capture_mode = False
        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
            self.model_runner.token_to_kv_pool.capture_mode = False

    def can_run(self, forward_batch: ForwardBatch):
        if self.enable_dp_attention:

@@ -38,12 +38,12 @@ import triton
import triton.language as tl

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend
from sglang.srt.utils import get_compiler_backend, next_power_of_2

if TYPE_CHECKING:
    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
    from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
    from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -51,9 +51,8 @@ if TYPE_CHECKING:


class ForwardMode(IntEnum):
    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
    PREFILL = auto()
    # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
    # It is also called "prefill" in common terminology.
    EXTEND = auto()
    # Decode one token.
    DECODE = auto()
@@ -153,6 +152,12 @@ class ForwardBatch:
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprobs post processing
    temp_scaled_logprobs: bool = False
    temperature: torch.Tensor = None
    top_p_normalized_logprobs: bool = False
    top_p: torch.Tensor = None

    # Position information
    positions: torch.Tensor = None

@@ -189,7 +194,7 @@ class ForwardBatch:

    # Attention backend
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    token_to_kv_pool: KVCache = None
    attn_backend: AttentionBackend = None

    # For DP attention
@@ -229,7 +234,6 @@ class ForwardBatch:
        extend_input_logprob_token_ids_gpu = (
            batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
        )

        ret = cls(
            forward_mode=batch.forward_mode,
            batch_size=len(batch.seq_lens),
@@ -417,8 +421,8 @@ def compute_position_kernel(
    prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
    seq_len = tl.load(extend_seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_seq_lens + i)


@@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
    ReqToTokenPool,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
@@ -430,7 +431,7 @@ class ModelRunner:
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support the DefaultModelLoader for now
        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
@@ -732,6 +733,7 @@ class ModelRunner:
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
@@ -742,6 +744,7 @@ class ModelRunner:
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
@@ -753,6 +756,7 @@ class ModelRunner:
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
@@ -762,12 +766,21 @@ class ModelRunner:
            )

        if self.token_to_kv_pool_allocator is None:
            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                device=self.device,
                kvcache=self.token_to_kv_pool,
            )
            if self.page_size == 1:
                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
            else:
                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
        else:
            assert self.is_draft_worker


@@ -220,6 +220,8 @@ class ServerArgs:
        else:
            self.chunked_prefill_size = 8192

        assert self.chunked_prefill_size % self.page_size == 0
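        # E.g. (illustration only): the default chunked_prefill_size of 8192 is
        # compatible with any power-of-two page size such as 16, 32, or 64.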

        # Set cuda graph max batch size
        if self.cuda_graph_max_bs is None:
            # Based on detailed statistics, when serving TP1/TP2 models on lower-end
            # GPUs with HBM < 25G, you can either disable cuda graph or set
            # `cuda_graph_max_bs` to a very small value to reduce the memory overhead
            # of creating cuda graphs, with almost no impact on performance. However,
            # when serving models with TP4 or TP8, we need to enable cuda graph to
            # maintain high performance. In this case, we can set `cuda_graph_max_bs`
            # to 80 (half of the default value 160) to reduce the memory overhead of
            # creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b,
            # a value of 80 is sufficient and can reduce the memory overhead of
            # creating cuda graphs on lower-end GPUs compared to the original 160,
            # avoiding OOM issues.

@@ -1554,6 +1554,13 @@ def set_cuda_arch():
    os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"


def next_power_of_2(n: int):
    return 1 << (n - 1).bit_length() if n > 0 else 1
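    # Examples (illustration only):
    #   next_power_of_2(1) == 1
    #   next_power_of_2(5) == 8
    #   next_power_of_2(64) == 64
    #   next_power_of_2(0) == 1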


setattr(triton, "next_power_of_2", next_power_of_2)


def add_prefix(name: str, prefix: str) -> str:
    """Add a weight path prefix to a module name.
