Fix illegal memory access in overlap mode & Use more fused triton kernels for building meta data (#2051)

Lianmin Zheng
2024-11-16 16:14:23 -08:00
committed by GitHub
parent 976bc302e5
commit edad373135
7 changed files with 198 additions and 83 deletions


@@ -36,6 +36,8 @@ from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
@@ -236,25 +238,16 @@ class ForwardBatch:
         # Init position information
         if not ret.forward_mode.is_decode():
-            ret.positions = torch.concat(
-                [
-                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
-                    for prefix_len, extend_len in zip(
-                        batch.extend_prefix_lens, batch.extend_seq_lens
-                    )
-                ],
-                axis=0,
-            )
-            ret.extend_num_tokens = batch.extend_num_tokens
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
-            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_num_tokens = batch.extend_num_tokens
+            ret.positions, ret.extend_start_loc = compute_position_triton(
+                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+            )
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
@@ -271,3 +264,72 @@ class ForwardBatch:
             model_runner.lora_manager.prepare_lora_batch(ret)
 
         return ret
+
+
+def compute_position_triton(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
+):
+    """Compute positions. It is a fused version of `compute_position_torch`."""
+    batch_size = extend_seq_lens.shape[0]
+    positions = torch.empty(
+        extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
+    )
+    extend_start_loc = torch.empty(
+        batch_size, dtype=torch.int32, device=extend_seq_lens.device
+    )
+
+    # Launch kernel
+    compute_position_kernel[(batch_size,)](
+        positions,
+        extend_start_loc,
+        extend_prefix_lens,
+        extend_seq_lens,
+    )
+
+    return positions, extend_start_loc
+
+
+@triton.jit
+def compute_position_kernel(
+    positions,
+    extend_start_loc,
+    extend_prefix_lens,
+    extend_seq_lens,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    prefix_len = tl.load(extend_prefix_lens + pid)
+    seq_len = tl.load(extend_seq_lens + pid)
+
+    # TODO: optimize this?
+    cumsum_start = 0
+    for i in range(pid):
+        cumsum_start += tl.load(extend_seq_lens + i)
+
+    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        tl.store(
+            positions + cumsum_start + offset,
+            prefix_len + offset,
+            mask=offset < seq_len,
+        )
+
+    tl.store(extend_start_loc + pid, cumsum_start)
+
+
+def compute_position_torch(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
+):
+    positions = torch.concat(
+        [
+            torch.arange(
+                prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
+            )
+            for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
+        ],
+        axis=0,
+    )
+    extend_start_loc = torch.zeros_like(extend_seq_lens)
+    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+    return positions.to(torch.int64), extend_start_loc
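
The added kernel assigns one Triton program per request: each program reads its prefix length and extend length, writes the positions prefix_len .. prefix_len + seq_len - 1 in blocks of 512, and records the request's start offset, replacing the per-request torch.arange loop at the call site with a single launch. The snippet below is not part of commit #2051; it is a minimal sanity-check sketch, assuming a CUDA device with triton available and assuming the import path sglang.srt.model_executor.forward_batch_info (inferred from the module contents shown above, not stated in the diff), comparing the fused kernel against the compute_position_torch reference the commit keeps in the same file.

# Not part of commit #2051: a hypothetical sanity check of the fused kernel
# against the torch reference. Assumes CUDA + triton and the import path below.
import torch

from sglang.srt.model_executor.forward_batch_info import (
    compute_position_torch,
    compute_position_triton,
)


def check_fused_positions():
    device = "cuda"  # the Triton kernel requires a GPU
    # One short request, one with a cached prefix, and one longer than
    # BLOCK_SIZE (512) so the kernel's multi-block store loop is exercised.
    extend_prefix_lens = torch.tensor([0, 16, 3], dtype=torch.int32, device=device)
    extend_seq_lens = torch.tensor([8, 5, 700], dtype=torch.int32, device=device)
    extend_seq_lens_sum = int(extend_seq_lens.sum().item())

    pos_fused, loc_fused = compute_position_triton(
        extend_prefix_lens, extend_seq_lens, extend_seq_lens_sum
    )
    pos_ref, loc_ref = compute_position_torch(extend_prefix_lens, extend_seq_lens)

    assert torch.equal(pos_fused, pos_ref)
    assert torch.equal(loc_fused, loc_ref.to(loc_fused.dtype))
    print("fused Triton positions match the torch reference")


if __name__ == "__main__":
    check_fused_positions()

Both paths return int64 positions while the start locations stay int32, matching how ForwardBatch.init_new consumes them in the second hunk above.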