Fix illegal memory access in overlap mode & Use more fused triton kernels for building meta data (#2051)
This commit is contained in:
@@ -34,6 +34,8 @@ import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
@@ -615,12 +617,12 @@ class ScheduleBatch:
|
||||
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
||||
seq_lens = []
|
||||
pre_lens = []
|
||||
|
||||
# Allocate memory
|
||||
req_pool_indices = self.alloc_req_slots(bs)
|
||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||
|
||||
pt = 0
|
||||
for i, req in enumerate(reqs):
|
||||
already_computed = (
|
||||
req.extend_logprob_start_len + 1 + req.cached_tokens
|
||||
@@ -638,10 +640,6 @@ class ScheduleBatch:
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
|
||||
)
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(pre_len, seq_len)),
|
||||
out_cache_loc[pt : pt + req.extend_input_len],
|
||||
)
|
||||
|
||||
# Compute the relative logprob_start_len in an extend batch
|
||||
if req.logprob_start_len >= pre_len:
|
||||
@@ -652,8 +650,8 @@ class ScheduleBatch:
|
||||
extend_logprob_start_len = req.extend_input_len - 1
|
||||
|
||||
req.extend_logprob_start_len = extend_logprob_start_len
|
||||
pt += req.extend_input_len
|
||||
req.is_retracted = False
|
||||
pre_lens.append(pre_len)
|
||||
|
||||
# Set fields
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
@@ -665,7 +663,6 @@ class ScheduleBatch:
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
self.out_cache_loc = out_cache_loc
|
||||
|
||||
self.seq_lens_sum = sum(seq_lens)
|
||||
@@ -676,9 +673,33 @@ class ScheduleBatch:
|
||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||
|
||||
# Write to req_to_token_pool
|
||||
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
write_req_to_token_pool_triton[(bs,)](
|
||||
self.req_to_token_pool.req_to_token,
|
||||
self.req_pool_indices,
|
||||
pre_lens,
|
||||
self.seq_lens,
|
||||
extend_lens,
|
||||
self.out_cache_loc,
|
||||
self.req_to_token_pool.req_to_token.shape[1],
|
||||
)
|
||||
# The triton kernel is equivalent to the following python code.
|
||||
# self.req_to_token_pool.write(
|
||||
# (req.req_pool_idx, slice(pre_len, seq_len)),
|
||||
# out_cache_loc[pt : pt + req.extend_input_len],
|
||||
# )
|
||||
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.prepare_encoder_info_extend(input_ids, seq_lens)
|
||||
|
||||
# Build sampling info
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
self,
|
||||
self.model_config.vocab_size,
|
||||
@@ -1025,6 +1046,9 @@ class ScheduleBatch:
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
|
||||
_ = self.seq_lens[0].item()
|
||||
|
||||
# Only contain fields that will be used by process_batch_result
|
||||
return ScheduleBatch(
|
||||
reqs=self.reqs,
|
||||
@@ -1104,3 +1128,40 @@ class ModelWorkerBatch:
|
||||
for x, y in self.req_to_token_pool_records
|
||||
]
|
||||
self.sampling_info.to(device)
|
||||
|
||||
|
||||
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,  # [bs] row index of each request in req_to_token_ptr
    pre_lens,  # [bs] per-request count of already-cached prefix tokens
    seq_lens,  # [bs] per-request total sequence length after this extend
    extend_lens,  # [bs] per-request number of newly extended tokens
    out_cache_loc,  # [sum(extend_lens)] newly allocated token-pool slots, flattened
    req_to_token_ptr_stride: tl.constexpr,  # row stride of req_to_token_ptr (max_context_len)
):
    """Write newly allocated KV-cache token slots into the req-to-token map.

    Launched with grid ``(bs,)``: program ``pid`` handles one request,
    copying that request's slice of ``out_cache_loc`` into
    ``req_to_token_ptr[req_pool_index, pre_len:seq_len]``. This fuses the
    per-request host-side loop of ``req_to_token_pool.write(...)`` calls
    into a single kernel launch (see the commented Python equivalent at
    the call site).
    """
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # Start offset of this request's slice within the flattened
    # out_cache_loc: the exclusive prefix sum of extend_lens. Each program
    # recomputes it serially, so the grid does O(bs^2) loads in total.
    # TODO: optimize this? (e.g., pass a precomputed exclusive cumsum)
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    # Copy the extend region (seq_len - pre_len tokens) in BLOCK_SIZE-wide
    # chunks; the mask guards the final partial chunk so no out-of-bounds
    # elements are loaded or stored.
    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
|
||||
|
||||
Reference in New Issue
Block a user