Fix illegal memory access in overlap mode & use more fused Triton kernels for building metadata (#2051)

This commit is contained in:
Lianmin Zheng
2024-11-16 16:14:23 -08:00
committed by GitHub
parent 976bc302e5
commit edad373135
7 changed files with 198 additions and 83 deletions

View File

@@ -34,6 +34,8 @@ import logging
from typing import List, Optional, Tuple, Union
import torch
import triton
import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
@@ -615,12 +617,12 @@ class ScheduleBatch:
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
extend_num_tokens = sum(len(ids) for ids in input_ids)
seq_lens = []
pre_lens = []
# Allocate memory
req_pool_indices = self.alloc_req_slots(bs)
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
pt = 0
for i, req in enumerate(reqs):
already_computed = (
req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -638,10 +640,6 @@ class ScheduleBatch:
self.req_to_token_pool.write(
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)
self.req_to_token_pool.write(
(req.req_pool_idx, slice(pre_len, seq_len)),
out_cache_loc[pt : pt + req.extend_input_len],
)
# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
@@ -652,8 +650,8 @@ class ScheduleBatch:
extend_logprob_start_len = req.extend_input_len - 1
req.extend_logprob_start_len = extend_logprob_start_len
pt += req.extend_input_len
req.is_retracted = False
pre_lens.append(pre_len)
# Set fields
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -665,7 +663,6 @@ class ScheduleBatch:
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.out_cache_loc = out_cache_loc
self.seq_lens_sum = sum(seq_lens)
@@ -676,9 +673,33 @@ class ScheduleBatch:
self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
# Write to req_to_token_pool
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
write_req_to_token_pool_triton[(bs,)](
self.req_to_token_pool.req_to_token,
self.req_pool_indices,
pre_lens,
self.seq_lens,
extend_lens,
self.out_cache_loc,
self.req_to_token_pool.req_to_token.shape[1],
)
# The triton kernel is equivalent to the following python code.
# self.req_to_token_pool.write(
# (req.req_pool_idx, slice(pre_len, seq_len)),
# out_cache_loc[pt : pt + req.extend_input_len],
# )
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
if self.model_config.is_encoder_decoder:
self.prepare_encoder_info_extend(input_ids, seq_lens)
# Build sampling info
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self,
self.model_config.vocab_size,
@@ -1025,6 +1046,9 @@ class ScheduleBatch:
)
def copy(self):
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
_ = self.seq_lens[0].item()
# Only contain fields that will be used by process_batch_result
return ScheduleBatch(
reqs=self.reqs,
@@ -1104,3 +1128,40 @@ class ModelWorkerBatch:
for x, y in self.req_to_token_pool_records
]
self.sampling_info.to(device)
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    """Scatter newly allocated cache locations into the req-to-token table.

    One program instance per request (grid = (batch_size,)): program ``pid``
    copies its request's slice of ``out_cache_loc`` into
    ``req_to_token_ptr[req_pool_indices[pid], pre_len:seq_len]``.
    """
    CHUNK: tl.constexpr = 512
    batch_id = tl.program_id(0)

    row = tl.load(req_pool_indices + batch_id)
    start = tl.load(pre_lens + batch_id)
    end = tl.load(seq_lens + batch_id)

    # Start of this request's slice in out_cache_loc: the sum of the
    # extend lengths of all preceding requests.
    # TODO: optimize this?
    src_base = 0
    for j in range(batch_id):
        src_base += tl.load(extend_lens + j)

    # Copy the (end - start) new token locations in CHUNK-sized pieces.
    n_chunks = tl.cdiv(end - start, CHUNK)
    for j in range(n_chunks):
        offs = j * CHUNK + tl.arange(0, CHUNK)
        valid = offs < (end - start)
        tokens = tl.load(out_cache_loc + src_base + offs, mask=valid)
        dst = req_to_token_ptr + row * req_to_token_ptr_stride + start + offs
        tl.store(dst, tokens, mask=valid)