Fix illegal memory access in overlap mode & Use more fused triton kernels for building meta data (#2051)
This commit is contained in:
@@ -56,6 +56,7 @@ class BenchArgs:
|
||||
gen_output_len: int = 256
|
||||
disable_ignore_eos: bool = False
|
||||
seed: int = 1
|
||||
do_not_exit: bool = False
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
@@ -143,6 +144,11 @@ class BenchArgs:
|
||||
help="Disable ignore EOS token",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
||||
parser.add_argument(
|
||||
"--do-not-exit",
|
||||
action="store_true",
|
||||
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
@@ -309,3 +315,6 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
throughput_test(server_args, bench_args)
|
||||
|
||||
while bench_args.do_not_exit:
|
||||
pass
|
||||
|
||||
@@ -314,7 +314,6 @@ class FlashInferIndicesUpdaterDecode:
|
||||
self.head_dim = model_runner.model_config.head_dim
|
||||
self.data_type = model_runner.kv_cache_dtype
|
||||
self.q_data_type = model_runner.dtype
|
||||
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
|
||||
self.sliding_window_size = model_runner.sliding_window_size
|
||||
|
||||
self.attn_backend = attn_backend
|
||||
@@ -445,7 +444,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
kv_indptr,
|
||||
kv_start_idx,
|
||||
kv_indices,
|
||||
self.max_context_len,
|
||||
self.req_to_token.shape[1],
|
||||
)
|
||||
|
||||
wrapper.end_forward()
|
||||
@@ -474,7 +473,6 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
self.head_dim = model_runner.model_config.head_dim
|
||||
self.data_type = model_runner.kv_cache_dtype
|
||||
self.q_data_type = model_runner.dtype
|
||||
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
|
||||
self.sliding_window_size = model_runner.sliding_window_size
|
||||
|
||||
self.attn_backend = attn_backend
|
||||
@@ -599,7 +597,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
kv_indptr,
|
||||
kv_start_idx,
|
||||
kv_indices,
|
||||
self.max_context_len,
|
||||
self.req_to_token.shape[1],
|
||||
)
|
||||
|
||||
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||
@@ -638,10 +636,11 @@ def create_flashinfer_kv_indices_triton(
|
||||
kv_indptr,
|
||||
kv_start_idx,
|
||||
kv_indices_ptr,
|
||||
max_context_len: tl.constexpr,
|
||||
req_to_token_ptr_stride: tl.constexpr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 512
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
req_pool_index = tl.load(req_pool_indices_ptr + pid)
|
||||
kv_indices_offset = tl.load(kv_indptr + pid)
|
||||
|
||||
@@ -652,15 +651,15 @@ def create_flashinfer_kv_indices_triton(
|
||||
kv_end = kv_start
|
||||
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
|
||||
|
||||
req_to_token_ptr += req_pool_index * max_context_len
|
||||
kv_indices_ptr += kv_indices_offset
|
||||
|
||||
ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
|
||||
st_offset = tl.arange(0, BLOCK_SIZE)
|
||||
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||
for _ in range(num_loop):
|
||||
mask = ld_offset < kv_end
|
||||
data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
|
||||
tl.store(kv_indices_ptr + st_offset, data, mask=mask)
|
||||
ld_offset += BLOCK_SIZE
|
||||
st_offset += BLOCK_SIZE
|
||||
for i in range(num_loop):
|
||||
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||
mask = offset < kv_end - kv_start
|
||||
data = tl.load(
|
||||
req_to_token_ptr
|
||||
+ req_pool_index * req_to_token_ptr_stride
|
||||
+ kv_start
|
||||
+ offset,
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
|
||||
|
||||
@@ -62,21 +62,21 @@ class LogitsMetadata:
|
||||
|
||||
@classmethod
|
||||
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
||||
extend_logprob_pruned_lens_cpu = None
|
||||
|
||||
if forward_batch.return_logprob:
|
||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
extend_logprob_pruned_lens_cpu = [
|
||||
extend_len - start_len
|
||||
for extend_len, start_len in zip(
|
||||
forward_batch.extend_seq_lens_cpu,
|
||||
forward_batch.extend_logprob_start_lens_cpu,
|
||||
)
|
||||
]
|
||||
else:
|
||||
return_top_logprob = False
|
||||
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
extend_logprob_pruned_lens_cpu = [
|
||||
extend_len - start_len
|
||||
for extend_len, start_len in zip(
|
||||
forward_batch.extend_seq_lens,
|
||||
forward_batch.extend_logprob_start_lens_cpu,
|
||||
)
|
||||
]
|
||||
else:
|
||||
extend_logprob_pruned_lens_cpu = None
|
||||
return cls(
|
||||
forward_mode=forward_batch.forward_mode,
|
||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
||||
|
||||
@@ -34,6 +34,8 @@ import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
@@ -615,12 +617,12 @@ class ScheduleBatch:
|
||||
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
||||
seq_lens = []
|
||||
pre_lens = []
|
||||
|
||||
# Allocate memory
|
||||
req_pool_indices = self.alloc_req_slots(bs)
|
||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||
|
||||
pt = 0
|
||||
for i, req in enumerate(reqs):
|
||||
already_computed = (
|
||||
req.extend_logprob_start_len + 1 + req.cached_tokens
|
||||
@@ -638,10 +640,6 @@ class ScheduleBatch:
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
|
||||
)
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(pre_len, seq_len)),
|
||||
out_cache_loc[pt : pt + req.extend_input_len],
|
||||
)
|
||||
|
||||
# Compute the relative logprob_start_len in an extend batch
|
||||
if req.logprob_start_len >= pre_len:
|
||||
@@ -652,8 +650,8 @@ class ScheduleBatch:
|
||||
extend_logprob_start_len = req.extend_input_len - 1
|
||||
|
||||
req.extend_logprob_start_len = extend_logprob_start_len
|
||||
pt += req.extend_input_len
|
||||
req.is_retracted = False
|
||||
pre_lens.append(pre_len)
|
||||
|
||||
# Set fields
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
@@ -665,7 +663,6 @@ class ScheduleBatch:
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
self.out_cache_loc = out_cache_loc
|
||||
|
||||
self.seq_lens_sum = sum(seq_lens)
|
||||
@@ -676,9 +673,33 @@ class ScheduleBatch:
|
||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||
|
||||
# Write to req_to_token_pool
|
||||
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
write_req_to_token_pool_triton[(bs,)](
|
||||
self.req_to_token_pool.req_to_token,
|
||||
self.req_pool_indices,
|
||||
pre_lens,
|
||||
self.seq_lens,
|
||||
extend_lens,
|
||||
self.out_cache_loc,
|
||||
self.req_to_token_pool.req_to_token.shape[1],
|
||||
)
|
||||
# The triton kernel is equivalent to the following python code.
|
||||
# self.req_to_token_pool.write(
|
||||
# (req.req_pool_idx, slice(pre_len, seq_len)),
|
||||
# out_cache_loc[pt : pt + req.extend_input_len],
|
||||
# )
|
||||
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.prepare_encoder_info_extend(input_ids, seq_lens)
|
||||
|
||||
# Build sampling info
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
self,
|
||||
self.model_config.vocab_size,
|
||||
@@ -1025,6 +1046,9 @@ class ScheduleBatch:
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
|
||||
_ = self.seq_lens[0].item()
|
||||
|
||||
# Only contain fields that will be used by process_batch_result
|
||||
return ScheduleBatch(
|
||||
reqs=self.reqs,
|
||||
@@ -1104,3 +1128,40 @@ class ModelWorkerBatch:
|
||||
for x, y in self.req_to_token_pool_records
|
||||
]
|
||||
self.sampling_info.to(device)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def write_req_to_token_pool_triton(
|
||||
req_to_token_ptr, # [max_batch, max_context_len]
|
||||
req_pool_indices,
|
||||
pre_lens,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
out_cache_loc,
|
||||
req_to_token_ptr_stride: tl.constexpr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 512
|
||||
pid = tl.program_id(0)
|
||||
|
||||
req_pool_index = tl.load(req_pool_indices + pid)
|
||||
pre_len = tl.load(pre_lens + pid)
|
||||
seq_len = tl.load(seq_lens + pid)
|
||||
|
||||
# TODO: optimize this?
|
||||
cumsum_start = 0
|
||||
for i in range(pid):
|
||||
cumsum_start += tl.load(extend_lens + i)
|
||||
|
||||
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
|
||||
for i in range(num_loop):
|
||||
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||
mask = offset < (seq_len - pre_len)
|
||||
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
|
||||
tl.store(
|
||||
req_to_token_ptr
|
||||
+ req_pool_index * req_to_token_ptr_stride
|
||||
+ offset
|
||||
+ pre_len,
|
||||
value,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
@@ -56,6 +56,7 @@ class TpModelWorkerClient:
|
||||
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
|
||||
self.max_running_requests = self.worker.max_running_requests
|
||||
self.device = self.worker.device
|
||||
self.gpu_id = gpu_id
|
||||
|
||||
# Init future mappings
|
||||
self.future_token_ids_ct = 0
|
||||
@@ -73,12 +74,6 @@ class TpModelWorkerClient:
|
||||
)
|
||||
self.forward_thread.start()
|
||||
|
||||
self.copy_queue = Queue()
|
||||
self.copy_thread = threading.Thread(
|
||||
target=self.copy_thread_func,
|
||||
)
|
||||
self.copy_thread.start()
|
||||
|
||||
def get_worker_info(self):
|
||||
return self.worker.get_worker_info()
|
||||
|
||||
@@ -104,12 +99,11 @@ class TpModelWorkerClient:
|
||||
@torch.inference_mode()
|
||||
def forward_thread_func_(self):
|
||||
while True:
|
||||
self.has_inflight_batch = False
|
||||
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
||||
if not model_worker_batch:
|
||||
break
|
||||
self.has_inflight_batch = True
|
||||
self.launch_event = threading.Event()
|
||||
copy_event = torch.cuda.Event()
|
||||
|
||||
# Resolve future tokens in the input
|
||||
input_ids = model_worker_batch.input_ids
|
||||
@@ -142,39 +136,29 @@ class TpModelWorkerClient:
|
||||
)
|
||||
)
|
||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||
copy_event = torch.cuda.Event(blocking=True)
|
||||
copy_event.record()
|
||||
|
||||
self.launch_event.set()
|
||||
self.copy_queue.put((copy_event, logits_output, next_token_ids))
|
||||
|
||||
def copy_thread_func(self):
|
||||
while True:
|
||||
copy_event, logits_output, next_token_ids = self.copy_queue.get()
|
||||
if not copy_event:
|
||||
break
|
||||
while not copy_event.query():
|
||||
time.sleep(1e-5)
|
||||
|
||||
if logits_output.next_token_logprobs is not None:
|
||||
logits_output.next_token_logprobs = (
|
||||
logits_output.next_token_logprobs.tolist()
|
||||
)
|
||||
if logits_output.input_token_logprobs is not None:
|
||||
logits_output.input_token_logprobs = (
|
||||
logits_output.input_token_logprobs.tolist()
|
||||
)
|
||||
logits_output.normalized_prompt_logprobs = (
|
||||
logits_output.normalized_prompt_logprobs.tolist()
|
||||
)
|
||||
|
||||
self.output_queue.put((logits_output, next_token_ids.tolist()))
|
||||
self.output_queue.put((copy_event, logits_output, next_token_ids))
|
||||
|
||||
def resulve_batch_result(self, bid: int):
|
||||
logits_output, next_token_ids = self.output_queue.get()
|
||||
if self.has_inflight_batch:
|
||||
# Wait until the batch is launched
|
||||
self.launch_event.wait()
|
||||
copy_event, logits_output, next_token_ids = self.output_queue.get()
|
||||
while not copy_event.query():
|
||||
time.sleep(1e-5)
|
||||
self.launch_event.wait()
|
||||
|
||||
if logits_output.next_token_logprobs is not None:
|
||||
logits_output.next_token_logprobs = (
|
||||
logits_output.next_token_logprobs.tolist()
|
||||
)
|
||||
if logits_output.input_token_logprobs is not None:
|
||||
logits_output.input_token_logprobs = (
|
||||
logits_output.input_token_logprobs.tolist()
|
||||
)
|
||||
logits_output.normalized_prompt_logprobs = (
|
||||
logits_output.normalized_prompt_logprobs.tolist()
|
||||
)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
return logits_output, next_token_ids
|
||||
|
||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||
|
||||
@@ -36,6 +36,8 @@ from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
|
||||
@@ -236,25 +238,16 @@ class ForwardBatch:
|
||||
|
||||
# Init position information
|
||||
if not ret.forward_mode.is_decode():
|
||||
ret.positions = torch.concat(
|
||||
[
|
||||
torch.arange(prefix_len, prefix_len + extend_len, device=device)
|
||||
for prefix_len, extend_len in zip(
|
||||
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||
)
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
ret.extend_seq_lens = torch.tensor(
|
||||
batch.extend_seq_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
ret.extend_prefix_lens = torch.tensor(
|
||||
batch.extend_prefix_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
|
||||
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
ret.positions, ret.extend_start_loc = compute_position_triton(
|
||||
ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
|
||||
)
|
||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||
|
||||
@@ -271,3 +264,72 @@ class ForwardBatch:
|
||||
model_runner.lora_manager.prepare_lora_batch(ret)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def compute_position_triton(
|
||||
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
|
||||
):
|
||||
"""Compute positions. It is a fused version of `compute_position_torch`."""
|
||||
batch_size = extend_seq_lens.shape[0]
|
||||
positions = torch.empty(
|
||||
extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
|
||||
)
|
||||
extend_start_loc = torch.empty(
|
||||
batch_size, dtype=torch.int32, device=extend_seq_lens.device
|
||||
)
|
||||
|
||||
# Launch kernel
|
||||
compute_position_kernel[(batch_size,)](
|
||||
positions,
|
||||
extend_start_loc,
|
||||
extend_prefix_lens,
|
||||
extend_seq_lens,
|
||||
)
|
||||
|
||||
return positions, extend_start_loc
|
||||
|
||||
|
||||
@triton.jit
|
||||
def compute_position_kernel(
|
||||
positions,
|
||||
extend_start_loc,
|
||||
extend_prefix_lens,
|
||||
extend_seq_lens,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 512
|
||||
pid = tl.program_id(0)
|
||||
|
||||
prefix_len = tl.load(extend_prefix_lens + pid)
|
||||
seq_len = tl.load(extend_seq_lens + pid)
|
||||
|
||||
# TODO: optimize this?
|
||||
cumsum_start = 0
|
||||
for i in range(pid):
|
||||
cumsum_start += tl.load(extend_seq_lens + i)
|
||||
|
||||
num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
|
||||
for i in range(num_loop):
|
||||
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||
tl.store(
|
||||
positions + cumsum_start + offset,
|
||||
prefix_len + offset,
|
||||
mask=offset < seq_len,
|
||||
)
|
||||
tl.store(extend_start_loc + pid, cumsum_start)
|
||||
|
||||
|
||||
def compute_position_torch(
|
||||
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
|
||||
):
|
||||
positions = torch.concat(
|
||||
[
|
||||
torch.arange(
|
||||
prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
|
||||
)
|
||||
for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
extend_start_loc = torch.zeros_like(extend_seq_lens)
|
||||
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
|
||||
return positions.to(torch.int64), extend_start_loc
|
||||
|
||||
@@ -73,7 +73,7 @@ class SamplingBatchInfo:
|
||||
top_ks=top_ks,
|
||||
min_ps=min_ps,
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
is_all_greedy=top_ks.max().item() <= 1,
|
||||
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
||||
vocab_size=vocab_size,
|
||||
device=device,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user