Fix illegal memory access in overlap mode & Use more fused triton kernels for building meta data (#2051)
@@ -36,6 +36,8 @@ from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 
@@ -236,25 +238,16 @@ class ForwardBatch:
 
         # Init position information
         if not ret.forward_mode.is_decode():
-            ret.positions = torch.concat(
-                [
-                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
-                    for prefix_len, extend_len in zip(
-                        batch.extend_prefix_lens, batch.extend_seq_lens
-                    )
-                ],
-                axis=0,
-            )
-            ret.extend_num_tokens = batch.extend_num_tokens
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
-            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_num_tokens = batch.extend_num_tokens
+            ret.positions, ret.extend_start_loc = compute_position_triton(
+                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+            )
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
 
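Not part of the diff: for orientation, the metadata that `compute_position_triton` now produces in one fused launch matches the removed per-sequence torch code above. `positions` is the concatenation of `prefix_len + arange(extend_len)` over the batch, and `extend_start_loc` is the exclusive cumulative sum of `extend_seq_lens`. A minimal sketch with hypothetical toy lengths:

    import torch

    # Hypothetical batch: prefix lengths 2 and 0, extend lengths 3 and 2.
    extend_prefix_lens = torch.tensor([2, 0], dtype=torch.int32)
    extend_seq_lens = torch.tensor([3, 2], dtype=torch.int32)

    # Same logic as the removed lines, written out for illustration only.
    positions = torch.concat(
        [
            torch.arange(p, p + l)
            for p, l in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
        ]
    )
    extend_start_loc = torch.zeros_like(extend_seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)

    print(positions)         # tensor([2, 3, 4, 0, 1])
    print(extend_start_loc)  # tensor([0, 3], dtype=torch.int32)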
@@ -271,3 +264,72 @@ class ForwardBatch:
             model_runner.lora_manager.prepare_lora_batch(ret)
 
         return ret
+
+
+def compute_position_triton(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
+):
+    """Compute positions. It is a fused version of `compute_position_torch`."""
+    batch_size = extend_seq_lens.shape[0]
+    positions = torch.empty(
+        extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
+    )
+    extend_start_loc = torch.empty(
+        batch_size, dtype=torch.int32, device=extend_seq_lens.device
+    )
+
+    # Launch kernel
+    compute_position_kernel[(batch_size,)](
+        positions,
+        extend_start_loc,
+        extend_prefix_lens,
+        extend_seq_lens,
+    )
+
+    return positions, extend_start_loc
+
+
+@triton.jit
+def compute_position_kernel(
+    positions,
+    extend_start_loc,
+    extend_prefix_lens,
+    extend_seq_lens,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    prefix_len = tl.load(extend_prefix_lens + pid)
+    seq_len = tl.load(extend_seq_lens + pid)
+
+    # TODO: optimize this?
+    cumsum_start = 0
+    for i in range(pid):
+        cumsum_start += tl.load(extend_seq_lens + i)
+
+    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        tl.store(
+            positions + cumsum_start + offset,
+            prefix_len + offset,
+            mask=offset < seq_len,
+        )
+    tl.store(extend_start_loc + pid, cumsum_start)
+
+
+def compute_position_torch(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
+):
+    positions = torch.concat(
+        [
+            torch.arange(
+                prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
+            )
+            for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
+        ],
+        axis=0,
+    )
+    extend_start_loc = torch.zeros_like(extend_seq_lens)
+    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+    return positions.to(torch.int64), extend_start_loc
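
As a quick follow-up outside the commit, the torch fallback can be used to sanity-check the fused kernel. This is only a sketch: it assumes a CUDA device is available and that both helpers are importable from `sglang.srt.model_executor.forward_batch_info` (the file patched above).

    import torch

    # Assumed import path for the file patched above.
    from sglang.srt.model_executor.forward_batch_info import (
        compute_position_torch,
        compute_position_triton,
    )

    # Hypothetical extend metadata for a three-request batch; the Triton path needs GPU tensors.
    extend_prefix_lens = torch.tensor([0, 5, 2], dtype=torch.int32, device="cuda")
    extend_seq_lens = torch.tensor([4, 1, 7], dtype=torch.int32, device="cuda")
    extend_num_tokens = int(extend_seq_lens.sum().item())

    pos_triton, loc_triton = compute_position_triton(
        extend_prefix_lens, extend_seq_lens, extend_num_tokens
    )
    pos_torch, loc_torch = compute_position_torch(extend_prefix_lens, extend_seq_lens)

    # Both paths should build identical positions and start locations.
    assert torch.equal(pos_triton, pos_torch)
    assert torch.equal(loc_triton, loc_torch.to(loc_triton.dtype))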