Support page size > 1 (#4356)
This commit is contained in:
@@ -49,6 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_compiler_backend, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
@@ -273,7 +274,6 @@ class Req:
|
||||
"__req__": self
|
||||
}
|
||||
self.sampling_params = sampling_params
|
||||
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
@@ -331,6 +331,8 @@ class Req:
|
||||
self.logprob_start_len = 0
|
||||
self.top_logprobs_num = top_logprobs_num
|
||||
self.token_ids_logprob = token_ids_logprob
|
||||
self.temp_scaled_logprobs = False
|
||||
self.top_p_normalized_logprobs = False
|
||||
|
||||
# Logprobs (return values)
|
||||
self.input_token_logprobs_val: Optional[List[float]] = None
|
||||
@@ -524,19 +526,23 @@ class ScheduleBatch:
|
||||
model_config: ModelConfig = None
|
||||
forward_mode: ForwardMode = None
|
||||
enable_overlap: bool = False
|
||||
# Tell whether the current running batch is full so that we can skip
|
||||
# the check of whether to prefill new requests.
|
||||
# This is an optimization to reduce the overhead of the prefill check.
|
||||
batch_is_full: bool = False
|
||||
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
next_batch_sampling_info: SamplingBatchInfo = None
|
||||
|
||||
# Batched arguments to model runner
|
||||
input_ids: torch.Tensor = None # shape: [b], int32
|
||||
input_ids: torch.Tensor = None # shape: [b], int64
|
||||
input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32
|
||||
req_pool_indices: torch.Tensor = None # shape: [b], int32
|
||||
req_pool_indices: torch.Tensor = None # shape: [b], int64
|
||||
seq_lens: torch.Tensor = None # shape: [b], int64
|
||||
# The output locations of the KV cache
|
||||
out_cache_loc: torch.Tensor = None # shape: [b], int32
|
||||
output_ids: torch.Tensor = None # shape: [b], int32
|
||||
out_cache_loc: torch.Tensor = None # shape: [b], int64
|
||||
output_ids: torch.Tensor = None # shape: [b], int64
|
||||
|
||||
# The sum of all sequence lengths
|
||||
seq_lens_sum: int = None
|
||||
@@ -551,6 +557,10 @@ class ScheduleBatch:
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
token_ids_logprobs: Optional[List[List[int]]] = None
|
||||
|
||||
# For logits and logprob post processing
|
||||
temp_scaled_logprobs: bool = False
|
||||
top_p_normalized_logprobs: bool = False
|
||||
|
||||
# For extend and mixed chunekd prefill
|
||||
prefix_lens: List[int] = None
|
||||
extend_lens: List[int] = None
|
||||
@@ -560,7 +570,7 @@ class ScheduleBatch:
|
||||
# It comes empty list if logprob is not required.
|
||||
extend_input_logprob_token_ids: Optional[torch.Tensor] = None
|
||||
|
||||
# For encoder-decoder
|
||||
# For encoder-decoder architectures
|
||||
encoder_cached: Optional[List[bool]] = None
|
||||
encoder_lens: Optional[torch.Tensor] = None
|
||||
encoder_lens_cpu: Optional[List[int]] = None
|
||||
@@ -597,6 +607,8 @@ class ScheduleBatch:
|
||||
spec_algorithm: SpeculativeAlgorithm,
|
||||
enable_custom_logit_processor: bool,
|
||||
):
|
||||
return_logprob = any(req.return_logprob for req in reqs)
|
||||
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
@@ -604,7 +616,7 @@ class ScheduleBatch:
|
||||
tree_cache=tree_cache,
|
||||
model_config=model_config,
|
||||
enable_overlap=enable_overlap,
|
||||
return_logprob=any(req.return_logprob for req in reqs),
|
||||
return_logprob=return_logprob,
|
||||
has_stream=any(req.stream for req in reqs),
|
||||
has_grammar=any(req.grammar for req in reqs),
|
||||
device=req_to_token_pool.device,
|
||||
@@ -631,24 +643,83 @@ class ScheduleBatch:
|
||||
return req_pool_indices
|
||||
|
||||
def alloc_token_slots(self, num_tokens: int):
|
||||
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(num_tokens)
|
||||
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||
if out_cache_loc is None:
|
||||
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
||||
error_msg = (
|
||||
f"{phase_str} out of memory. Try to lower your batch size.\n"
|
||||
f"Try to allocate {num_tokens} tokens.\n"
|
||||
f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.pretty_print()
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
return out_cache_loc
|
||||
|
||||
def alloc_paged_token_slots_extend(
|
||||
self,
|
||||
prefix_lens: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
extend_num_tokens: int,
|
||||
):
|
||||
if (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
< extend_num_tokens
|
||||
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
||||
):
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(
|
||||
extend_num_tokens
|
||||
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
||||
)
|
||||
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
||||
prefix_lens, seq_lens, last_loc, extend_num_tokens
|
||||
)
|
||||
if out_cache_loc is None:
|
||||
error_msg = (
|
||||
f"Prefill out of memory. Try to lower your batch size.\n"
|
||||
f"Try to allocate {extend_num_tokens} tokens.\n"
|
||||
f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return out_cache_loc
|
||||
|
||||
def alloc_paged_token_slots_decode(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
):
|
||||
if (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
||||
):
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(
|
||||
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
||||
)
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
|
||||
|
||||
if out_cache_loc is None:
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||
|
||||
if out_cache_loc is None:
|
||||
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
||||
logger.error(
|
||||
f"{phase_str} out of memory. Try to lower your batch size.\n"
|
||||
f"Try to allocate {num_tokens} tokens.\n"
|
||||
f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
)
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.pretty_print()
|
||||
exit(1)
|
||||
|
||||
error_msg = (
|
||||
f"Decode out of memory. Try to lower your batch size.\n"
|
||||
f"Try to allocate {len(seq_lens)} tokens.\n"
|
||||
f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return out_cache_loc
|
||||
|
||||
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
||||
@@ -699,7 +770,7 @@ class ScheduleBatch:
|
||||
pt += req.extend_input_len
|
||||
|
||||
# Reassign
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
@@ -707,14 +778,14 @@ class ScheduleBatch:
|
||||
)
|
||||
|
||||
if not decoder_out_cache_loc:
|
||||
self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
|
||||
self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
else:
|
||||
self.out_cache_loc = torch.cat(decoder_out_cache_loc)
|
||||
|
||||
if not encoder_out_cache_loc:
|
||||
self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
|
||||
self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
else:
|
||||
@@ -725,25 +796,38 @@ class ScheduleBatch:
|
||||
def prepare_for_extend(self):
|
||||
self.forward_mode = ForwardMode.EXTEND
|
||||
|
||||
# Allocate req slots
|
||||
bs = len(self.reqs)
|
||||
req_pool_indices = self.alloc_req_slots(bs)
|
||||
|
||||
# Init tensors
|
||||
reqs = self.reqs
|
||||
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
||||
seq_lens = []
|
||||
pre_lens = []
|
||||
seq_lens = [len(r.fill_ids) for r in reqs]
|
||||
prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||
extend_lens = [r.extend_input_len for r in reqs]
|
||||
|
||||
# Allocate memory
|
||||
req_pool_indices = self.alloc_req_slots(bs)
|
||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
prefix_lens_tensor = torch.tensor(
|
||||
prefix_lens, dtype=torch.int64, device=self.device
|
||||
)
|
||||
extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
|
||||
|
||||
# Copy prefix and do some basic check
|
||||
input_embeds = []
|
||||
extend_input_logprob_token_ids = []
|
||||
|
||||
pt = 0
|
||||
for i, req in enumerate(reqs):
|
||||
for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
|
||||
req.req_pool_idx = req_pool_indices[i]
|
||||
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
||||
seq_lens.append(seq_len)
|
||||
assert seq_len - pre_len == req.extend_input_len
|
||||
|
||||
if pre_len > 0:
|
||||
@@ -759,7 +843,7 @@ class ScheduleBatch:
|
||||
req.cached_tokens += pre_len - req.already_computed
|
||||
req.already_computed = seq_len
|
||||
req.is_retracted = False
|
||||
pre_lens.append(pre_len)
|
||||
|
||||
# Compute the relative logprob_start_len in an extend batch
|
||||
if req.logprob_start_len >= pre_len:
|
||||
req.extend_logprob_start_len = min(
|
||||
@@ -815,60 +899,62 @@ class ScheduleBatch:
|
||||
else:
|
||||
extend_input_logprob_token_ids = None
|
||||
|
||||
# Allocate memory
|
||||
if self.token_to_kv_pool_allocator.page_size == 1:
|
||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||
else:
|
||||
last_loc = get_last_loc(
|
||||
self.req_to_token_pool.req_to_token,
|
||||
req_pool_indices_tensor,
|
||||
prefix_lens_tensor,
|
||||
)
|
||||
out_cache_loc = self.alloc_paged_token_slots_extend(
|
||||
prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
|
||||
)
|
||||
|
||||
# Set fields
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.input_ids = input_ids_tensor
|
||||
self.req_pool_indices = req_pool_indices_tensor
|
||||
self.seq_lens = seq_lens_tensor
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.input_embeds = (
|
||||
torch.tensor(input_embeds).to(self.device, non_blocking=True)
|
||||
if input_embeds
|
||||
else None
|
||||
)
|
||||
|
||||
self.out_cache_loc = out_cache_loc
|
||||
|
||||
self.seq_lens_sum = sum(seq_lens)
|
||||
|
||||
if self.return_logprob:
|
||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||
|
||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.prefix_lens = prefix_lens
|
||||
self.extend_lens = extend_lens
|
||||
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
||||
|
||||
# Write to req_to_token_pool
|
||||
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
if global_server_args_dict["attention_backend"] != "torch_native":
|
||||
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
|
||||
|
||||
write_req_to_token_pool_triton[(bs,)](
|
||||
self.req_to_token_pool.req_to_token,
|
||||
self.req_pool_indices,
|
||||
pre_lens,
|
||||
self.seq_lens,
|
||||
extend_lens,
|
||||
self.out_cache_loc,
|
||||
req_pool_indices_tensor,
|
||||
prefix_lens_tensor,
|
||||
seq_lens_tensor,
|
||||
extend_lens_tensor,
|
||||
out_cache_loc,
|
||||
self.req_to_token_pool.req_to_token.shape[1],
|
||||
)
|
||||
else:
|
||||
pt = 0
|
||||
for i in range(bs):
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
|
||||
self.out_cache_loc[pt : pt + self.extend_lens[i]],
|
||||
(req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
|
||||
out_cache_loc[pt : pt + extend_lens[i]],
|
||||
)
|
||||
pt += self.extend_lens[i]
|
||||
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
|
||||
pt += extend_lens[i]
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.prepare_encoder_info_extend(input_ids, seq_lens)
|
||||
@@ -914,7 +1000,7 @@ class ScheduleBatch:
|
||||
if self.token_to_kv_pool_allocator.available_size() >= bs:
|
||||
return True
|
||||
|
||||
self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)
|
||||
self.tree_cache.evict(bs)
|
||||
|
||||
if self.token_to_kv_pool_allocator.available_size() >= bs:
|
||||
return True
|
||||
@@ -939,10 +1025,6 @@ class ScheduleBatch:
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
|
||||
def get_required_tokens(num_reqs: int):
|
||||
headroom_for_spec_decode = 0
|
||||
if server_args.speculative_algorithm:
|
||||
@@ -956,6 +1038,9 @@ class ScheduleBatch:
|
||||
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
while (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
< get_required_tokens(len(sorted_indices))
|
||||
@@ -980,7 +1065,6 @@ class ScheduleBatch:
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free(token_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
del self.tree_cache.entries[req.rid]
|
||||
else:
|
||||
# TODO: apply more fine-grained retraction
|
||||
last_uncached_pos = len(req.prefix_indices)
|
||||
@@ -999,9 +1083,7 @@ class ScheduleBatch:
|
||||
- self.token_to_kv_pool_allocator.available_size()
|
||||
)
|
||||
residual_size = max(0, residual_size)
|
||||
self.tree_cache.evict(
|
||||
residual_size, self.token_to_kv_pool_allocator.free
|
||||
)
|
||||
self.tree_cache.evict(residual_size)
|
||||
|
||||
req.reset_for_retract()
|
||||
|
||||
@@ -1024,9 +1106,9 @@ class ScheduleBatch:
|
||||
|
||||
def prepare_for_idle(self):
|
||||
self.forward_mode = ForwardMode.IDLE
|
||||
self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.seq_lens_sum = 0
|
||||
self.extend_num_tokens = 0
|
||||
@@ -1037,6 +1119,8 @@ class ScheduleBatch:
|
||||
|
||||
def prepare_for_decode(self):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
bs = len(self.reqs)
|
||||
|
||||
if self.spec_algorithm.is_eagle():
|
||||
# if spec decoding is used, the decode batch is prepared inside
|
||||
# `forward_batch_speculative_generation` after running draft models.
|
||||
@@ -1065,33 +1149,39 @@ class ScheduleBatch:
|
||||
self.output_ids.to(torch.int64)
|
||||
)
|
||||
|
||||
# Update fields
|
||||
self.input_ids = self.output_ids
|
||||
self.output_ids = None
|
||||
|
||||
# Alloc mem
|
||||
bs = len(self.reqs)
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
locs = self.encoder_lens + self.seq_lens
|
||||
self.prepare_encoder_info_decode()
|
||||
else:
|
||||
locs = self.seq_lens
|
||||
locs = self.seq_lens.clone()
|
||||
|
||||
if self.enable_overlap:
|
||||
# Do not use in-place operations in the overlap mode
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, locs), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens = self.seq_lens + 1
|
||||
else:
|
||||
# A faster in-place version
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, locs), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens.add_(1)
|
||||
self.seq_lens_sum += bs
|
||||
|
||||
# Allocate memory
|
||||
if self.token_to_kv_pool_allocator.page_size == 1:
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
else:
|
||||
last_loc = self.req_to_token_pool.req_to_token[
|
||||
self.req_pool_indices, self.seq_lens - 2
|
||||
]
|
||||
self.out_cache_loc = self.alloc_paged_token_slots_decode(
|
||||
self.seq_lens, last_loc
|
||||
)
|
||||
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
|
||||
)
|
||||
|
||||
def filter_batch(
|
||||
self,
|
||||
chunked_req_to_exclude: Optional[Req] = None,
|
||||
@@ -1345,8 +1435,8 @@ def write_req_to_token_pool_triton(
|
||||
pre_len = tl.load(pre_lens + pid)
|
||||
seq_len = tl.load(seq_lens + pid)
|
||||
|
||||
# TODO: optimize this?
|
||||
cumsum_start = 0
|
||||
# NOTE: This can be slow for large bs
|
||||
cumsum_start = tl.cast(0, tl.int64)
|
||||
for i in range(pid):
|
||||
cumsum_start += tl.load(extend_lens + i)
|
||||
|
||||
@@ -1363,3 +1453,12 @@ def write_req_to_token_pool_triton(
|
||||
value,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
|
||||
return torch.where(
|
||||
prefix_lens_tensor > 0,
|
||||
req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
|
||||
torch.full_like(prefix_lens_tensor, -1),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user