sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct
224  benchmark/kernels/all_reduce/benchmark_mscclpp.py  Normal file
@@ -0,0 +1,224 @@
"""For now, MSCCL is only supported in the TP16 and TP8 cases.

export WORLD_SIZE=1
export RANK=0
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=12345

torchrun --nproc_per_node gpu \
    --nnodes $WORLD_SIZE \
    --node_rank $RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT benchmark/kernels/all_reduce/benchmark_mscclpp.py
"""
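# A minimal single-node launch sketch (assuming 8 local GPUs and the default
# env values from the docstring above):
#   torchrun --nproc_per_node 8 benchmark/kernels/all_reduce/benchmark_mscclpp.py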

import os
from contextlib import nullcontext
from typing import List

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
    initialize_model_parallel,
    set_mscclpp_all_reduce,
)


def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
    dist.all_reduce(torch_input, group=group)
    return torch_input


def msccl_allreduce(
    msccl_input: torch.Tensor, msccl_comm: PyMscclppCommunicator
) -> torch.Tensor:
    return msccl_comm.all_reduce(msccl_input)


def pynccl_allreduce(
    msccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator
) -> torch.Tensor:
    pynccl_comm.all_reduce(msccl_input)
    return msccl_input


def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):
    graph_input = inp_randn.clone()
    with graph_capture() as graph_capture_context:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=graph_capture_context.stream):
            for _ in range(graph_loop):
                graph_out = func(graph_input)

    graph.replay()
    func_output = graph_out.clone()

    for _ in range(warmup_loop):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: List[float] = []
    for _ in range(test_loop):
        torch.cuda.synchronize()
        dist.barrier()
        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000
    graph.reset()
    return func_output, func_cost_us


def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):
    eager_input = inp_randn.clone()
    eager_output = func(eager_input)
    func_output = eager_output.clone()

    for _ in range(warmup_loop):
        func(eager_input)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    for _ in range(test_loop):
        func(eager_input)
    end_event.record()
    torch.cuda.synchronize()
    func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000

    return func_output, func_cost_us


def get_torch_prof_ctx(do_prof: bool):
    ctx = (
        torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            record_shapes=True,
            with_stack=True,
        )
        if do_prof
        else nullcontext()
    )
    return ctx


def human_readable_size(size, decimal_places=1):
    for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
        if size < 1024.0 or unit == "PiB":
            break
        size /= 1024.0
    return f"{size:.{decimal_places}f} {unit}"


try:
    from tabulate import tabulate
except ImportError:
    print("tabulate not installed, skipping table printing")
    tabulate = None


def print_markdown_table(data):
    if tabulate is not None:
        print(tabulate(data, headers="keys", tablefmt="github"))
        return
    headers = data[0].keys()
    header_row = "| " + " | ".join(headers) + " |"
    separator = "| " + " | ".join(["---"] * len(headers)) + " |"
    rows = []
    for item in data:
        row = "| " + " | ".join(str(item[key]) for key in headers) + " |"
        rows.append(row)
    markdown_table = "\n".join([header_row, separator] + rows)
    print(markdown_table)


if __name__ == "__main__":
    import logging

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    world, world_size = dist.group.WORLD, dist.get_world_size()
    rank = dist.get_rank()
    torch.cuda.set_device(rank % 8)
    device = torch.cuda.current_device()
    set_mscclpp_all_reduce(True)
    init_distributed_environment(
        world_size=world_size,
        rank=rank,
        local_rank=rank % 8,
    )
    initialize_model_parallel(tensor_model_parallel_size=world_size)
    group = get_tensor_model_parallel_group().device_group
    cpu_group = get_tensor_model_parallel_group().cpu_group
    pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
    pymscclpp_comm = get_tensor_model_parallel_group().pymscclpp_comm
    dist.barrier()
    profile = False
    dtype = torch.bfloat16
    ctx = get_torch_prof_ctx(profile)
    result = []

    with ctx:
        for i in range(10, 20):
            sz = 2**i
            if sz * dtype.itemsize > 2**20:
                break
            inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)

            memory = torch.empty_like(inp_randn)
            memory_out = torch.empty_like(memory)
            torch_eager_output, torch_eager_time = _bench_eager_time(
                lambda inp: torch_allreduce(inp, group), inp_randn
            )
            msccl_eager_output, msccl_eager_time = _bench_eager_time(
                lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
            )
            msccl_graph_output, msccl_graph_time = _bench_graph_time(
                lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
            )
            # Since pynccl all_reduce is an in-place op, the returned result is not
            # correct if graph_loop > 1; only its timing is used here.
            _, pynccl_graph_time = _bench_graph_time(
                lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn
            )
            torch.testing.assert_close(torch_eager_output, msccl_graph_output)
            torch.testing.assert_close(torch_eager_output, msccl_eager_output)
            result.append(
                {
                    "msg_size": human_readable_size(inp_randn.nbytes),
                    "torch eager time": torch_eager_time,
                    "msccl eager time": msccl_eager_time,
                    "msccl graph time": msccl_graph_time,
                    "pynccl graph time": pynccl_graph_time,
                }
            )
            if rank == 0:
                print(f"sz={sz}, dtype={dtype}: correctness check PASS!")
    if rank == 0:
        print_markdown_table(result)
    if profile:
        prof_dir = "prof/msccl"
        os.makedirs(prof_dir, exist_ok=True)
        ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
@@ -0,0 +1,403 @@
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import cudnn
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
||||
|
||||
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
|
||||
from sglang.srt.utils import should_use_tensor_core
|
||||
|
||||
|
||||
def benchmark_forward(
|
||||
fn,
|
||||
*inputs,
|
||||
repeats=10,
|
||||
amp=False,
|
||||
amp_dtype=torch.float16,
|
||||
**kwinputs,
|
||||
):
|
||||
def amp_wrapper(*inputs, **kwinputs):
|
||||
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
|
||||
fn(*inputs, **kwinputs)
|
||||
|
||||
t = benchmark.Timer(
|
||||
stmt="fn_amp(*inputs, **kwinputs)",
|
||||
globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
m = t.timeit(repeats)
|
||||
return t, m
|
||||
|
||||
|
||||
def time_fwd(func, *args, **kwargs):
|
||||
time_f = benchmark_forward(func, *args, **kwargs)
|
||||
return time_f[1].mean * 1e6
|
||||
|
||||
|
||||
def decode_attention_sglang(
|
||||
q,
|
||||
kv_data,
|
||||
batch_size,
|
||||
kv_len,
|
||||
head_num_q,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
num_kv_splits,
|
||||
warmup=10,
|
||||
):
|
||||
|
||||
k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)
|
||||
v_buffer = kv_data[1].view(-1, head_num_kv, head_dim)
|
||||
o = torch.empty_like(q)
|
||||
total_tokens = batch_size * kv_len
|
||||
req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
|
||||
b_req_idx = torch.arange(0, batch_size).to(0).int()
|
||||
b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
|
||||
max_len_in_batch = kv_len
|
||||
sm_scale = 1.0 / (head_dim**0.5)
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(batch_size, head_num_q, num_kv_splits, head_dim + 1),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
for _ in range(warmup):
|
||||
decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
f = time_fwd(
|
||||
decode_attention_fwd,
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
return f, o
|
||||
|
||||
|
||||
def decode_attention_flashinfer(dtype, head_num_q, head_num_kv):
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
||||
use_tensor_cores = should_use_tensor_core(
|
||||
kv_cache_dtype=dtype,
|
||||
num_attention_heads=head_num_q,
|
||||
num_kv_heads=head_num_kv,
|
||||
)
|
||||
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
|
||||
)
|
||||
|
||||
class FlashinferAttention(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
q,
|
||||
kv_data,
|
||||
batch_size,
|
||||
kv_len,
|
||||
head_num_q,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
dtype,
|
||||
warmup=10,
|
||||
):
|
||||
total_tokens = batch_size * kv_len
|
||||
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
|
||||
kv_indices = torch.arange(0, total_tokens).to(0).int()
|
||||
kv_last_page_len = torch.full(
|
||||
(batch_size,), 1, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
flashinfer_decode_wrapper.end_forward()
|
||||
flashinfer_decode_wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_len,
|
||||
head_num_q,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
1,
|
||||
pos_encoding_mode="NONE",
|
||||
data_type=dtype,
|
||||
)
|
||||
|
||||
for _ in range(warmup):
|
||||
o = flashinfer_decode_wrapper.forward(
|
||||
q.contiguous().view(-1, head_num_q, head_dim), kv_data
|
||||
)
|
||||
|
||||
f = time_fwd(
|
||||
flashinfer_decode_wrapper.forward,
|
||||
q.contiguous().view(-1, head_num_q, head_dim),
|
||||
kv_data,
|
||||
)
|
||||
|
||||
return f, o
|
||||
|
||||
return FlashinferAttention
|
||||
|
||||
|
||||
def convert_to_cudnn_type(torch_type):
|
||||
if torch_type == torch.float16:
|
||||
return cudnn.data_type.HALF
|
||||
elif torch_type == torch.bfloat16:
|
||||
return cudnn.data_type.BFLOAT16
|
||||
elif torch_type == torch.float32:
|
||||
return cudnn.data_type.FLOAT
|
||||
elif torch_type == torch.int32:
|
||||
return cudnn.data_type.INT32
|
||||
elif torch_type == torch.int64:
|
||||
return cudnn.data_type.INT64
|
||||
else:
|
||||
raise ValueError("Unsupported tensor data type.")
|
||||
|
||||
|
||||
def decode_attention_cudnn(
|
||||
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10
|
||||
):
|
||||
# Prepare data: continuous q,k,v
|
||||
dims_q = (batch_size, head_num_q, 1, head_dim)
|
||||
strides_q = (head_num_q * head_dim, head_dim, head_num_q * head_dim, 1)
|
||||
q_gpu = q.as_strided(dims_q, strides_q)
|
||||
o_gpu = (
|
||||
torch.empty(batch_size * head_num_q * head_dim)
|
||||
.half()
|
||||
.cuda()
|
||||
.as_strided(dims_q, strides_q)
|
||||
)
|
||||
|
||||
dims_kv = (batch_size, head_num_kv, kv_len, head_dim)
|
||||
strides_kv = (
|
||||
kv_len * head_num_kv * head_dim,
|
||||
head_dim,
|
||||
head_num_kv * head_dim,
|
||||
1,
|
||||
)
|
||||
k_gpu = kv_data[0].as_strided(dims_kv, strides_kv)
|
||||
v_gpu = kv_data[1].as_strided(dims_kv, strides_kv)
|
||||
|
||||
seq_len_q_gpu = torch.full((batch_size, 1, 1, 1), 1, device="cuda")
|
||||
seq_len_kv_gpu = torch.full((batch_size, 1, 1, 1), kv_len, device="cuda")
|
||||
attn_scale = 1.0 / (head_dim**0.5)
|
||||
|
||||
# Prepare data: paged k,v
|
||||
block_size = 1
|
||||
blocks_per_batch = math.ceil(kv_len / block_size)
|
||||
# [num_blocks, head_num_kv, block_size, head_dim], num_blocks = batch_size * blocks_per_batch
|
||||
container_k_gpu = torch.cat(k_gpu.chunk(blocks_per_batch, dim=2), dim=0)
|
||||
container_v_gpu = torch.cat(v_gpu.chunk(blocks_per_batch, dim=2), dim=0)
|
||||
page_table_k_gpu = (
|
||||
torch.linspace(
|
||||
0,
|
||||
batch_size * blocks_per_batch - 1,
|
||||
batch_size * blocks_per_batch,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
.reshape(blocks_per_batch, 1, batch_size, 1)
|
||||
.transpose(0, 2)
|
||||
)
|
||||
page_table_v_gpu = page_table_k_gpu.clone()
|
||||
|
||||
graph = cudnn.pygraph(
|
||||
io_data_type=convert_to_cudnn_type(dtype),
|
||||
intermediate_data_type=cudnn.data_type.FLOAT,
|
||||
compute_data_type=cudnn.data_type.FLOAT,
|
||||
)
|
||||
|
||||
q = graph.tensor_like(q_gpu)
|
||||
container_k = graph.tensor_like(container_k_gpu)
|
||||
container_v = graph.tensor_like(container_v_gpu)
|
||||
page_table_k = graph.tensor_like(page_table_k_gpu)
|
||||
page_table_v = graph.tensor_like(page_table_v_gpu)
|
||||
|
||||
seq_len_q = graph.tensor_like(seq_len_q_gpu)
|
||||
seq_len_kv = graph.tensor_like(seq_len_kv_gpu)
|
||||
|
||||
o, _ = graph.sdpa(
|
||||
name="sdpa",
|
||||
q=q,
|
||||
k=container_k, # Container K: non contiguous container with K blocks
|
||||
v=container_v, # Container V: non contiguous container with V blocks
|
||||
is_inference=True,
|
||||
attn_scale=attn_scale,
|
||||
use_causal_mask=False,
|
||||
use_padding_mask=True,
|
||||
seq_len_q=seq_len_q,
|
||||
seq_len_kv=seq_len_kv,
|
||||
paged_attention_k_table=page_table_k, # Page Table K: Tensor containing offsets to the container with K blocks
|
||||
paged_attention_v_table=page_table_v, # Page Table V: Tensor containing offsets to the container with V blocks
|
||||
paged_attention_max_seq_len_kv=kv_len, # The maximum sequence length for K caches (this is optional, but recommended)
|
||||
)
|
||||
|
||||
o.set_output(True).set_dim(dims_q).set_stride(strides_q)
|
||||
|
||||
graph.validate()
|
||||
graph.build_operation_graph()
|
||||
graph.create_execution_plans([cudnn.heur_mode.A])
|
||||
graph.check_support()
|
||||
graph.build_plans()
|
||||
|
||||
workspace = torch.empty(
|
||||
graph.get_workspace_size(), device="cuda", dtype=torch.uint8
|
||||
)
|
||||
|
||||
variant_pack = {
|
||||
q: q_gpu,
|
||||
container_k: container_k_gpu,
|
||||
container_v: container_v_gpu,
|
||||
page_table_k: page_table_k_gpu,
|
||||
page_table_v: page_table_v_gpu,
|
||||
seq_len_q: seq_len_q_gpu,
|
||||
seq_len_kv: seq_len_kv_gpu,
|
||||
o: o_gpu,
|
||||
}
|
||||
|
||||
for _ in range(warmup):
|
||||
graph.execute(variant_pack, workspace)
|
||||
|
||||
f = time_fwd(
|
||||
graph.execute,
|
||||
variant_pack,
|
||||
workspace,
|
||||
)
|
||||
|
||||
return f, o_gpu.squeeze(dim=2)
|
||||
|
||||
|
||||
def calculate_diff():
|
||||
|
||||
dtype = torch.float16
|
||||
batch_size = 64
|
||||
kv_len = 4096
|
||||
head_num_q = 64
|
||||
head_num_kv = 8
|
||||
head_dim = 128
|
||||
|
||||
q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
|
||||
kv_data = (
|
||||
torch.randn(
|
||||
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
|
||||
),
|
||||
torch.randn(
|
||||
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
|
||||
),
|
||||
)
|
||||
|
||||
_, output_sglang = decode_attention_sglang(
|
||||
q,
|
||||
kv_data,
|
||||
batch_size,
|
||||
kv_len,
|
||||
head_num_q,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
num_kv_splits=8,
|
||||
)
|
||||
|
||||
attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply
|
||||
_, output_flashinfer = attn_flashinfer(
|
||||
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
||||
)
|
||||
|
||||
_, output_cudnn = decode_attention_cudnn(
|
||||
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
||||
)
|
||||
|
||||
print(f"SGLang output={output_sglang}")
|
||||
print(f"FlashInfer output={output_flashinfer}")
|
||||
print(f"cuDNN output={output_cudnn}")
|
||||
if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2):
|
||||
print("✅ SGLang[Triton] and FlashInfer match")
|
||||
else:
|
||||
print("❌ SGLang[Triton] and FlashInfer differ")
|
||||
|
||||
if torch.allclose(output_sglang, output_cudnn, atol=1e-2, rtol=1e-2):
|
||||
print("✅ SGLang[Triton] and cuDNN match")
|
||||
else:
|
||||
print("❌ SGLang[Triton] and cuDNN differ")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
calculate_diff()
|
||||
|
||||
head_dim = 128
|
||||
dtype = torch.float16
|
||||
batch_size_range = [2**i for i in range(0, 8, 2)]
|
||||
kv_len_range = [2**i for i in range(6, 13, 1)]
|
||||
configs = list(itertools.product(batch_size_range, kv_len_range))
|
||||
|
||||
for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
|
||||
attn_flashinfer = decode_attention_flashinfer(
|
||||
dtype, head_num_q, head_num_kv
|
||||
).apply
|
||||
for batch_size, kv_len in configs:
|
||||
q = torch.randn(
|
||||
batch_size, head_num_q, head_dim, dtype=dtype, device="cuda"
|
||||
)
|
||||
kv_data = (
|
||||
torch.randn(
|
||||
batch_size * kv_len,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
),
|
||||
torch.randn(
|
||||
batch_size * kv_len,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
),
|
||||
)
|
||||
us_cudnn, output_cudnn = decode_attention_cudnn(
|
||||
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
||||
)
|
||||
us_sglang, output_sglang = decode_attention_sglang(
|
||||
q,
|
||||
kv_data,
|
||||
batch_size,
|
||||
kv_len,
|
||||
head_num_q,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
num_kv_splits=8,
|
||||
)
|
||||
us_flashinfer, _ = attn_flashinfer(
|
||||
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
||||
)
|
||||
print(
|
||||
head_num_q,
|
||||
" ",
|
||||
head_num_kv,
|
||||
" ",
|
||||
batch_size,
|
||||
" ",
|
||||
kv_len,
|
||||
" ",
|
||||
us_cudnn,
|
||||
" ",
|
||||
us_sglang,
|
||||
" ",
|
||||
us_flashinfer,
|
||||
)
|
||||
218  benchmark/kernels/deepep/deepep_utils.py  Normal file
@@ -0,0 +1,218 @@
# ADAPTED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/utils.py

import os
import sys
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist


def init_dist(local_rank: int, num_local_ranks: int, args):
    ip = args.master_addr
    port = args.master_port
    num_nodes = args.nnodes
    node_rank = args.node_rank
    assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{ip}:{port}",
        world_size=num_nodes * num_local_ranks,
        rank=node_rank * num_local_ranks + local_rank,
    )
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.cuda.set_device(local_rank)

    return (
        dist.get_rank(),
        dist.get_world_size(),
        dist.new_group(list(range(num_local_ranks * num_nodes))),
    )


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    x, y = x.double() + 1, y.double() + 1
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return (1 - sim).item()


def per_token_cast_to_fp8(x: torch.Tensor):
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
        m, n
    ), (x_amax / 448.0).view(m, -1)


def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
    x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
    x_scales = x_scales.view(x_fp8.size(0), -1, 1)
    return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)


def inplace_unique(x: torch.Tensor, num_slots: int):
    assert x.dim() == 2
    mask = x < 0
    x_padded = x.masked_fill(mask, num_slots)
    bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
    bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
    bin_count = bin_count[:, :num_slots]
    sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
    sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
    x[:, :].fill_(-1)
    valid_len = min(num_slots, x.size(1))
    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]


def create_grouped_scores(
    scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int
):
    num_tokens, num_experts = scores.shape
    scores = scores.view(num_tokens, num_groups, -1)
    mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
    mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
    return (scores * mask).view(num_tokens, num_experts)


def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Flush L2
    cache.zero_()

    # Testing
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    for i in range(num_tests):
        # Record
        start_events[i].record()
        fn()
        end_events[i].record()
        if post_fn is not None:
            post_fn()
    torch.cuda.synchronize()

    times = np.array(
        [s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)]
    )[1:]
    return np.average(times), np.min(times), np.max(times)


class empty_suppress:
    def __enter__(self):
        return self

    def __exit__(self, *_):
        pass


class suppress_stdout_stderr:
    def __enter__(self):
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()


def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: Optional[str] = None,
    barrier_comm_profiling: bool = False,
):
    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
        ) as prof:
            for i in range(2):
                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                if barrier_comm_profiling:
                    lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    lhs @ rhs
                    dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
                for _ in range(num_tests):
                    fn()
                prof.step()

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tupled = isinstance(kernel_names, tuple)
    prof_lines = (
        prof.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    for name in kernel_names:
        assert (
            sum([name in line for line in prof_lines]) == 1
        ), f"Errors of the kernel {name} in the profiling table"

    # Save chrome traces
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {"ms": 1e3, "us": 1e6}
    kernel_times = []
    for name in kernel_names:
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                for unit, scale in units.items():
                    if unit in time_str:
                        kernel_times.append(float(time_str.replace(unit, "")) / scale)
                        break
                break
    return tuple(kernel_times) if is_tupled else kernel_times[0]


def hash_tensor(t: torch.Tensor):
    return t.view(torch.int64).sum().item()
476  benchmark/kernels/deepep/tuning_deepep.py  Normal file
@@ -0,0 +1,476 @@
# MODIFIED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py

"""
Example usage:
python tuning_deepep.py --nnodes 4 --node-rank $MY_NODE_RANK --master-addr 1.2.3.4
Then check `deepep_tuned.json`
"""
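# For reference, the tuned output written to `deepep_tuned.json` roughly has the
# shape sketched below (keys taken from the tuning loops in this file; values are
# hardware-dependent and shown only as placeholders):
#   {
#     "normal_dispatch": {
#       "num_sms": ..., "num_max_nvl_chunked_send_tokens": ...,
#       "num_max_nvl_chunked_recv_tokens": ..., "num_max_rdma_chunked_send_tokens": ...,
#       "num_max_rdma_chunked_recv_tokens": ...
#     },
#     "normal_combine": { ...same keys... }
#   }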
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import deep_ep
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from deepep_utils import (
|
||||
bench,
|
||||
calc_diff,
|
||||
create_grouped_scores,
|
||||
init_dist,
|
||||
inplace_unique,
|
||||
per_token_cast_back,
|
||||
per_token_cast_to_fp8,
|
||||
)
|
||||
|
||||
|
||||
def test_main(
|
||||
num_sms: int,
|
||||
local_rank: int,
|
||||
num_local_ranks: int,
|
||||
num_ranks: int,
|
||||
num_nodes: int,
|
||||
rank: int,
|
||||
buffer: deep_ep.Buffer,
|
||||
group: dist.ProcessGroup,
|
||||
args,
|
||||
):
|
||||
# Settings
|
||||
num_tokens, hidden, num_topk_groups, num_topk, num_experts = (
|
||||
4096,
|
||||
7168,
|
||||
min(num_nodes, 4),
|
||||
8,
|
||||
(256 // num_ranks) * num_ranks,
|
||||
)
|
||||
assert num_experts % num_ranks == 0 and num_local_ranks == 8
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Random data
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
|
||||
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
|
||||
x_e4m3 = per_token_cast_to_fp8(x)
|
||||
scores = (
|
||||
torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
|
||||
+ 1
|
||||
)
|
||||
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
|
||||
group_idx = torch.topk(
|
||||
group_scores, k=num_topk_groups, dim=-1, sorted=False
|
||||
).indices
|
||||
masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
|
||||
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[
|
||||
1
|
||||
]
|
||||
topk_weights = (
|
||||
torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
|
||||
)
|
||||
topk_weights_pure_rand = torch.randn(
|
||||
(num_tokens, num_topk), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
rank_idx = topk_idx // (num_experts // num_ranks)
|
||||
rank_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rank_idx, num_ranks)
|
||||
rdma_rank_idx = rank_idx // num_local_ranks
|
||||
rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
|
||||
inplace_unique(rdma_rank_idx, num_nodes)
|
||||
|
||||
# RDMA dispatch counts
|
||||
rdma_idx = topk_idx // (num_experts // num_nodes)
|
||||
rdma_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rdma_idx, num_nodes)
|
||||
num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
|
||||
|
||||
# Expert meta
|
||||
num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
|
||||
for i in range(num_experts):
|
||||
num_tokens_per_expert[i] = (topk_idx == i).sum()
|
||||
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
|
||||
|
||||
# Rank layout meta
|
||||
num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
|
||||
num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
|
||||
token_idx_in_rank = torch.full(
|
||||
(num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
|
||||
)
|
||||
for i in range(num_ranks):
|
||||
num_tokens_per_rank[i] = (rank_idx == i).sum()
|
||||
token_sel = (rank_idx == i).max(dim=-1)[0]
|
||||
count = token_sel.sum().item()
|
||||
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
|
||||
tokens[:count] = torch.sort(tokens[:count])[0]
|
||||
token_idx_in_rank[i][tokens[:count]] = torch.arange(
|
||||
count, dtype=torch.long, device="cuda"
|
||||
)
|
||||
for i in range(num_nodes):
|
||||
num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
|
||||
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
|
||||
is_token_in_rank = token_idx_in_rank >= 0
|
||||
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
|
||||
|
||||
(
|
||||
ref_num_tokens_per_rank,
|
||||
ref_num_tokens_per_rdma_rank,
|
||||
ref_num_tokens_per_expert,
|
||||
ref_is_token_in_rank,
|
||||
_,
|
||||
) = buffer.get_dispatch_layout(topk_idx, num_experts)
|
||||
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
|
||||
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
|
||||
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
|
||||
if local_rank == 0:
|
||||
print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
|
||||
print("", flush=True)
|
||||
group.barrier()
|
||||
time.sleep(1)
|
||||
|
||||
# Config
|
||||
rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
|
||||
config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)
|
||||
|
||||
# Test dispatch
|
||||
# noinspection PyShadowingNames
|
||||
def check_data(check_x, recv_gbl_rank_prefix_sum):
|
||||
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
|
||||
check_start = 0
|
||||
for i in range(num_ranks):
|
||||
check_end = recv_gbl_rank_prefix_sum[i].item()
|
||||
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
|
||||
check_start = check_end
|
||||
|
||||
for previous_mode in (False, True):
|
||||
for async_mode in (False, True):
|
||||
for current_x in (x_pure_rand, x, x_e4m3):
|
||||
for with_topk in (False, True):
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
|
||||
flush=True,
|
||||
end="",
|
||||
)
|
||||
dispatch_args = {
|
||||
"x": current_x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if with_topk:
|
||||
dispatch_args.update(
|
||||
{
|
||||
"topk_idx": topk_idx,
|
||||
"topk_weights": (
|
||||
topk_weights_pure_rand
|
||||
if current_x is x_pure_rand
|
||||
else topk_weights
|
||||
),
|
||||
}
|
||||
)
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
recv_num_tokens_per_expert_list,
|
||||
handle,
|
||||
event,
|
||||
) = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = (
|
||||
per_token_cast_back(*recv_x)
|
||||
if isinstance(recv_x, tuple)
|
||||
else recv_x
|
||||
)
|
||||
|
||||
# Checks
|
||||
recv_gbl_rank_prefix_sum = handle[-4]
|
||||
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
|
||||
0
|
||||
), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
|
||||
assert (
|
||||
gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
|
||||
== recv_num_tokens_per_expert_list
|
||||
)
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, recv_gbl_rank_prefix_sum)
|
||||
if with_topk:
|
||||
# Check `topk_idx`
|
||||
assert (
|
||||
recv_topk_idx.eq(-1)
|
||||
| (
|
||||
(recv_topk_idx >= 0)
|
||||
& (recv_topk_idx < (num_experts // num_ranks))
|
||||
)
|
||||
).sum().item() == recv_topk_idx.numel()
|
||||
for i, count in enumerate(recv_num_tokens_per_expert_list):
|
||||
assert recv_topk_idx.eq(i).sum().item() == count
|
||||
|
||||
# Check `topk_weights`
|
||||
if current_x is not x_pure_rand:
|
||||
recv_topk_weights[recv_topk_idx.eq(-1)] = (
|
||||
recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
|
||||
recv_topk_weights
|
||||
)[recv_topk_idx.eq(-1)]
|
||||
)
|
||||
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
|
||||
|
||||
# Test cached dispatch (must be run without the top-k inputs)
|
||||
if not with_topk:
|
||||
dispatch_args = {
|
||||
"x": current_x,
|
||||
"handle": handle,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = (
|
||||
per_token_cast_back(*recv_x)
|
||||
if isinstance(recv_x, tuple)
|
||||
else recv_x
|
||||
)
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, recv_gbl_rank_prefix_sum)
|
||||
|
||||
# Test combine
|
||||
combine_args = {
|
||||
"x": recv_x,
|
||||
"handle": handle,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if with_topk:
|
||||
combine_args.update({"topk_weights": recv_topk_weights})
|
||||
if previous_mode:
|
||||
combine_args.update({"previous_event": buffer.capture()})
|
||||
combined_x, combined_topk_weights, event = buffer.combine(
|
||||
**combine_args
|
||||
)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
check_x = combined_x.float() / is_token_in_rank.sum(
|
||||
dim=1
|
||||
).unsqueeze(1)
|
||||
ref_x = x_pure_rand if current_x is x_pure_rand else x
|
||||
assert calc_diff(check_x, ref_x) < 5e-6
|
||||
if with_topk:
|
||||
check_topk_weights = (
|
||||
combined_topk_weights
|
||||
if (current_x is x_pure_rand)
|
||||
else (
|
||||
combined_topk_weights
|
||||
/ is_token_in_rank.sum(dim=1).unsqueeze(1)
|
||||
)
|
||||
)
|
||||
ref_topk_weights = (
|
||||
topk_weights_pure_rand
|
||||
if current_x is x_pure_rand
|
||||
else topk_weights
|
||||
)
|
||||
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
|
||||
|
||||
# For later tuning
|
||||
dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
|
||||
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
|
||||
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
|
||||
combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes
|
||||
|
||||
if local_rank == 0:
|
||||
print(" passed", flush=True)
|
||||
if local_rank == 0:
|
||||
print("", flush=True)
|
||||
|
||||
output_data = {}
|
||||
|
||||
# Tune dispatch performance
|
||||
best_dispatch_results = None
|
||||
fp8_factor = (1 + 4 / 128) / 2
|
||||
for current_x in (x_e4m3, x):
|
||||
best_time, best_results = 1e10, None
|
||||
rdma_send_bytes = (
|
||||
(dispatch_bf16_rdma_send_bytes * fp8_factor)
|
||||
if isinstance(current_x, tuple)
|
||||
else dispatch_bf16_rdma_send_bytes
|
||||
)
|
||||
nvl_recv_bytes = (
|
||||
(dispatch_bf16_nvl_recv_bytes * fp8_factor)
|
||||
if isinstance(current_x, tuple)
|
||||
else dispatch_bf16_nvl_recv_bytes
|
||||
)
|
||||
for nvl_chunk_size in range(4, 33, 4):
|
||||
for rdma_chunk_size in range(4, 33, 4):
|
||||
config_kwargs = {
|
||||
"num_sms": num_sms,
|
||||
"num_max_nvl_chunked_send_tokens": nvl_chunk_size,
|
||||
"num_max_nvl_chunked_recv_tokens": nvl_buffer_size,
|
||||
"num_max_rdma_chunked_send_tokens": rdma_chunk_size,
|
||||
"num_max_rdma_chunked_recv_tokens": rdma_buffer_size,
|
||||
}
|
||||
config = deep_ep.Config(**config_kwargs)
|
||||
tune_args = {"x": current_x, "handle": handle, "config": config}
|
||||
t = bench(lambda: buffer.dispatch(**tune_args))[0]
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (
|
||||
num_sms,
|
||||
nvl_chunk_size,
|
||||
rdma_chunk_size,
|
||||
config_kwargs,
|
||||
)
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
|
||||
flush=True,
|
||||
)
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
|
||||
flush=True,
|
||||
)
|
||||
print("", flush=True)
|
||||
is_fp8 = isinstance(current_x, tuple)
|
||||
if is_fp8:
|
||||
output_data["normal_dispatch"] = deepcopy(best_results[3])
|
||||
|
||||
if isinstance(current_x, tuple):
|
||||
# Gather the best FP8 config from rank 0
|
||||
best_dispatch_results = torch.tensor(
|
||||
[best_results[0], best_results[1], best_results[2]],
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
all_best_fp8_results_list = [
|
||||
torch.zeros_like(best_dispatch_results)
|
||||
for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
dist.all_gather(
|
||||
all_best_fp8_results_list, best_dispatch_results, group=group
|
||||
)
|
||||
best_dispatch_results = all_best_fp8_results_list[0].tolist()
|
||||
dispatch_config = deep_ep.Config(
|
||||
best_dispatch_results[0],
|
||||
best_dispatch_results[1],
|
||||
nvl_buffer_size,
|
||||
best_dispatch_results[2],
|
||||
rdma_buffer_size,
|
||||
)
|
||||
|
||||
dispatch_args = {
|
||||
"x": x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": dispatch_config if dispatch_config is not None else config,
|
||||
}
|
||||
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
|
||||
|
||||
# Tune combine performance
|
||||
best_time, best_results = 1e10, None
|
||||
for nvl_chunk_size in range(1, 5, 1):
|
||||
for rdma_chunk_size in range(8, 33, 4):
|
||||
config_kwargs = {
|
||||
"num_sms": num_sms,
|
||||
"num_max_nvl_chunked_send_tokens": nvl_chunk_size,
|
||||
"num_max_nvl_chunked_recv_tokens": nvl_buffer_size,
|
||||
"num_max_rdma_chunked_send_tokens": rdma_chunk_size,
|
||||
"num_max_rdma_chunked_recv_tokens": rdma_buffer_size,
|
||||
}
|
||||
config = deep_ep.Config(**config_kwargs)
|
||||
tune_args = {"x": recv_x, "handle": handle, "config": config}
|
||||
t = bench(lambda: buffer.combine(**tune_args))[0]
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
|
||||
flush=True,
|
||||
)
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (
|
||||
num_sms,
|
||||
nvl_chunk_size,
|
||||
rdma_chunk_size,
|
||||
config_kwargs,
|
||||
)
|
||||
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
|
||||
flush=True,
|
||||
)
|
||||
print("", flush=True)
|
||||
output_data["normal_combine"] = deepcopy(best_results[3])
|
||||
|
||||
if rank == 0 and local_rank == 0:
|
||||
_write_output(args, output_data)
|
||||
|
||||
|
||||
def _write_output(args, output_data):
|
||||
text = json.dumps(output_data, indent=4)
|
||||
output_path = args.output_path
|
||||
print(f"Write to {output_path} with {text}")
|
||||
Path(output_path).write_text(text)
|
||||
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
def test_loop(local_rank: int, num_local_ranks: int, args):
|
||||
num_nodes = args.nnodes
|
||||
rank, num_ranks, group = init_dist(local_rank, num_local_ranks, args)
|
||||
|
||||
num_sms = args.num_sms
|
||||
num_qps_per_rank = num_sms // 2
|
||||
|
||||
buffer = deep_ep.Buffer(
|
||||
group,
|
||||
int(1e9),
|
||||
int(1e9),
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
)
|
||||
assert num_local_ranks == 8 and num_ranks > 8
|
||||
torch.manual_seed(rank)
|
||||
|
||||
for i in (num_sms,):
|
||||
test_main(
|
||||
i,
|
||||
local_rank,
|
||||
num_local_ranks,
|
||||
num_ranks,
|
||||
num_nodes,
|
||||
rank,
|
||||
buffer,
|
||||
group,
|
||||
args,
|
||||
)
|
||||
if local_rank == 0:
|
||||
print("", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-sms", type=int, default=24)
|
||||
parser.add_argument("--output-path", type=str, default="deepep_tuned.json")
|
||||
parser.add_argument("--nnodes", type=int, default=1)
|
||||
parser.add_argument("--node-rank", type=int, default=0)
|
||||
parser.add_argument("--master-addr", type=str, default="127.0.0.1")
|
||||
parser.add_argument("--master-port", type=int, default=8361)
|
||||
args = parser.parse_args()
|
||||
print(f"Start system with {args=}")
|
||||
|
||||
num_processes = 8
|
||||
torch.multiprocessing.spawn(
|
||||
test_loop, args=(num_processes, args), nprocs=num_processes
|
||||
)
|
||||
19  benchmark/kernels/deepseek/README.md  Normal file
@@ -0,0 +1,19 @@
## DeepSeek kernels benchmark

### Prerequisites
- You should install [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) from source before running `benchmark_deepgemm_fp8_gemm.py` and `benchmark_deepgemm_fp8_group_gemm.py`.

### Benchmark
- `benchmark_deepgemm_fp8_gemm.py`
```bash
python benchmark_deepgemm_fp8_gemm.py --run_correctness --tp_size 1
```

- `benchmark_deepgemm_fp8_group_gemm.py`
```bash
python benchmark_deepgemm_fp8_group_gemm.py --run_correctness --tp_size 1
```

- You can use the `--run_correctness` flag to verify the correctness of all kernel results.
- You can use the `--tp_size` parameter to benchmark all FP8 W8A8 block-wise matrix multiplications involved in DeepSeek V3/R1 under the current tensor parallelism (TP) setting. This benchmark compares DeepSeek's open-source [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) implementation with SGLang's and vLLM's Triton implementations.
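For example, a simple sweep over several TP settings could be scripted as in the sketch below (the TP values are only illustrative; use the ones matching your deployment):

```bash
# Illustrative sweep; adjust the TP list to your setup.
for tp in 1 2 4 8; do
    python benchmark_deepgemm_fp8_gemm.py --run_correctness --tp_size "$tp"
done
```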
400  benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py  Normal file
@@ -0,0 +1,400 @@
|
||||
from typing import Tuple
|
||||
|
||||
import deep_gemm
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,
|
||||
)
|
||||
|
||||
|
||||
# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
|
||||
def tl_gemm(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
in_dtype,
|
||||
out_dtype,
|
||||
accum_dtype,
|
||||
):
|
||||
assert in_dtype in [
|
||||
"e4m3_float8",
|
||||
], "Currently only e4m3_float8 is supported"
|
||||
assert out_dtype in [
|
||||
"bfloat16",
|
||||
"float16",
|
||||
], "Currently only bfloat16 and float16 are supported"
|
||||
|
||||
TILE_SIZE = (128, 128, 128)
|
||||
block_M = TILE_SIZE[0]
|
||||
block_N = TILE_SIZE[1]
|
||||
block_K = TILE_SIZE[2]
|
||||
|
||||
A_shape = (M, K)
|
||||
Scales_A_shape = (M, T.ceildiv(K, block_K))
|
||||
B_shape = (N, K)
|
||||
Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
|
||||
A_shared_shape = (block_M, block_K)
|
||||
B_shared_shape = (block_N, block_K)
|
||||
C_shared_shape = (block_M, block_N)
|
||||
|
||||
@T.prim_func
|
||||
def main(
|
||||
A: T.Buffer(A_shape, in_dtype),
|
||||
scales_a: T.Buffer(Scales_A_shape, "float32"),
|
||||
B: T.Buffer(B_shape, in_dtype),
|
||||
scales_b: T.Buffer(Scales_B_shape, "float32"),
|
||||
C: T.Buffer((M, N), out_dtype),
|
||||
):
|
||||
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
|
||||
bx,
|
||||
by,
|
||||
):
|
||||
|
||||
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
|
||||
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
|
||||
C_shared = T.alloc_shared(C_shared_shape, out_dtype)
|
||||
Scale_C_shared = T.alloc_shared((block_M), "float32")
|
||||
C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
|
||||
C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
|
||||
|
||||
# Improve L2 Cache
|
||||
T.use_swizzle(panel_size=10)
|
||||
|
||||
T.clear(C_local)
|
||||
T.clear(C_local_accum)
|
||||
K_iters = T.ceildiv(K, block_K)
|
||||
for k in T.Pipelined(K_iters, num_stages=4):
|
||||
# Load A into shared memory
|
||||
T.copy(A[by * block_M, k * block_K], A_shared)
|
||||
# Load B into shared memory
|
||||
T.copy(B[bx * block_N, k * block_K], B_shared)
|
||||
# Load scale into shared memory
|
||||
Scale_B = scales_b[bx, k]
|
||||
for i in T.Parallel(block_M):
|
||||
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
|
||||
|
||||
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
|
||||
# Promote to enable 2xAcc
|
||||
for i, j in T.Parallel(block_M, block_N):
|
||||
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
|
||||
T.clear(C_local)
|
||||
# TMA store
|
||||
T.copy(C_local_accum, C_shared)
|
||||
T.copy(C_shared, C[by * block_M, bx * block_N])
|
||||
|
||||
return main
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
|
||||
m, n
|
||||
), (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
|
||||
)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
|
||||
x_view.size(0), x_view.size(2)
|
||||
)
|
||||
|
||||
|
||||
def fp8_gemm_deepgemm(
|
||||
x_fp8: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
y_fp8: torch.Tensor,
|
||||
y_scale: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
):
|
||||
"""DeepGEMM implementation of FP8 GEMM"""
|
||||
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Run DeepGEMM kernel
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
||||
return out
|
||||
|
||||
|
||||
def fp8_gemm_sglang(
|
||||
x_fp8: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
y_fp8: torch.Tensor,
|
||||
y_scale: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
):
|
||||
"""SGLang implementation of FP8 GEMM"""
|
||||
block_size = [128, 128] # Matches the block size in per_block_cast_to_fp8
|
||||
|
||||
# Run SGLang kernel
|
||||
out = w8a8_block_fp8_matmul(
|
||||
x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def fp8_gemm_vllm(
|
||||
x_fp8: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
y_fp8: torch.Tensor,
|
||||
y_scale: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
):
|
||||
"""vLLM implementation of FP8 GEMM"""
|
||||
block_size = [128, 128] # Matches the block size in per_block_cast_to_fp8
|
||||
|
||||
# Run vLLM kernel
|
||||
out = vllm_w8a8_block_fp8_matmul(
|
||||
x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def calculate_diff(m: int, n: int, k: int):
|
||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
|
||||
y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
|
||||
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
|
||||
|
||||
out_deepgemm = fp8_gemm_deepgemm(
|
||||
x_fp8.clone(),
|
||||
x_scale_col_major.clone(),
|
||||
y_fp8.clone(),
|
||||
y_scale.clone(),
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
)
|
||||
out_sglang = fp8_gemm_sglang(
|
||||
x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
|
||||
)
|
||||
|
||||
tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
|
||||
tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
|
||||
out_tilelang = tilelang_kernel(
|
||||
x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone()
|
||||
)
|
||||
|
||||
diff_sglang_deepgemm = torch.abs(out_deepgemm - out_sglang).mean().item()
|
||||
diff_tilelang_deepgemm = torch.abs(out_deepgemm - out_tilelang).mean().item()
|
||||
diff_tilelang_sglang = torch.abs(out_tilelang - out_sglang).mean().item()
|
||||
|
||||
print(f"Shape m={m}, n={n}, k={k}:")
|
||||
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
|
||||
print(f"SGLang output: {out_sglang[0, 0:5]}")
|
||||
print(f"TileLang output: {out_tilelang[0, 0:5]}")
|
||||
print(f"Mean absolute difference (SGLang-DeepGEMM): {diff_sglang_deepgemm}")
|
||||
print(f"Mean absolute difference (TileLang-DeepGEMM): {diff_tilelang_deepgemm}")
|
||||
print(f"Mean absolute difference (TileLang-SGLang): {diff_tilelang_sglang}")
|
||||
|
||||
sglang_deepgemm_match = torch.allclose(
|
||||
out_deepgemm, out_sglang, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
tilelang_deepgemm_match = torch.allclose(
|
||||
out_deepgemm, out_tilelang, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
tilelang_sglang_match = torch.allclose(
|
||||
out_tilelang, out_sglang, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
|
||||
if sglang_deepgemm_match and tilelang_deepgemm_match and tilelang_sglang_match:
|
||||
print("✅ All implementations match\n")
|
||||
else:
|
||||
print("❌ Some implementations differ:")
|
||||
print(f" - SGLang vs DeepGEMM: {'✅' if sglang_deepgemm_match else '❌'}")
|
||||
print(f" - TileLang vs DeepGEMM: {'✅' if tilelang_deepgemm_match else '❌'}")
|
||||
print(f" - TileLang vs SGLang: {'✅' if tilelang_sglang_match else '❌'}\n")
|
||||
|
||||
|
||||
def get_weight_shapes(tp_size):
|
||||
# cannot TP
|
||||
total = [
|
||||
(512 + 64, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(7168, 16384),
|
||||
(7168, 18432),
|
||||
]
|
||||
# N can TP
|
||||
n_tp = [
|
||||
(18432 * 2, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(24576, 1536),
|
||||
(4096, 7168),
|
||||
]
|
||||
# K can TP
|
||||
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||
|
||||
weight_shapes = []
|
||||
for t in total:
|
||||
weight_shapes.append(t)
|
||||
for n_t in n_tp:
|
||||
new_t = (n_t[0] // tp_size, n_t[1])
|
||||
weight_shapes.append(new_t)
|
||||
for k_t in k_tp:
|
||||
new_t = (k_t[0], k_t[1] // tp_size)
|
||||
weight_shapes.append(new_t)
|
||||
|
||||
return weight_shapes
|
||||
|
||||
|
||||
def create_benchmark_configs(tp_size):
|
||||
configs = []
|
||||
weight_shapes = get_weight_shapes(tp_size)
|
||||
batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]
|
||||
|
||||
for n, k in weight_shapes:
|
||||
for m in batch_sizes:
|
||||
configs.append((m, n, k, tp_size))
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def get_benchmark(tp_size):
|
||||
all_configs = create_benchmark_configs(tp_size)
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["m", "n", "k", "tp_size"],
|
||||
x_vals=[list(config) for config in all_configs],
|
||||
line_arg="provider",
|
||||
line_vals=["deepgemm", "sglang", "tilelang"],
|
||||
line_names=["DeepGEMM", "SGLang", "TileLang"],
|
||||
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(m, n, k, tp_size, provider):
|
||||
print(f"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}")
|
||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Preprocess data before benchmarking
|
||||
x_fp8, x_scale = per_token_cast_to_fp8(x)
|
||||
y_fp8, y_scale = per_block_cast_to_fp8(y)
|
||||
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "deepgemm":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: fp8_gemm_deepgemm(
|
||||
x_fp8.clone(),
|
||||
x_scale_col_major.clone(),
|
||||
y_fp8.clone(),
|
||||
y_scale.clone(),
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "sglang":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: fp8_gemm_sglang(
|
||||
x_fp8.clone(),
|
||||
x_scale.clone(),
|
||||
y_fp8.clone(),
|
||||
y_scale.clone(),
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
else: # tilelang
|
||||
tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
|
||||
tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: tilelang_kernel(
|
||||
x_fp8.clone(),
|
||||
x_scale.clone(),
|
||||
y_fp8.clone(),
|
||||
y_scale.clone(),
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
# Calculate TFLOPS
|
||||
flops = 2 * m * n * k # multiply-adds
|
||||
tflops = flops / (ms * 1e-3) / 1e12
|
||||
|
||||
# Print shape-specific results with TFLOPS
|
||||
print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
|
||||
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/fp8_gemm/",
|
||||
help="Path to save fp8 gemm benchmark results",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_correctness",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Whether to run correctness test",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallelism size to benchmark (default: 1)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
# Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# Run correctness tests on a few examples
|
||||
if args.run_correctness:
|
||||
print("Running correctness tests...")
|
||||
calculate_diff(64, 512, 7168) # Small test
|
||||
calculate_diff(64, 7168, 16384) # Medium test
|
||||
calculate_diff(64, 18432, 7168) # Large test
|
||||
|
||||
# Get the benchmark function with the specified tp_size
|
||||
benchmark = get_benchmark(args.tp_size)
|
||||
|
||||
print(f"Running performance benchmark for TP size = {args.tp_size}...")
|
||||
benchmark.run(print_data=True, save_path=args.save_path)
|
||||
486
benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py
Normal file
486
benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py
Normal file
@@ -0,0 +1,486 @@
|
||||
from typing import Tuple
|
||||
|
||||
import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
|
||||
|
||||
# Import shared functionality from the regular GEMM benchmark
|
||||
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
|
||||
per_block_cast_to_fp8,
|
||||
per_token_cast_to_fp8,
|
||||
)
|
||||
|
||||
|
||||
def construct_grouped_and_flat_fp8(
|
||||
x: torch.Tensor, y: torch.Tensor, num_groups: int, is_masked: bool
|
||||
) -> Tuple[
|
||||
Tuple[torch.Tensor, torch.Tensor], # grouped x_fp8
|
||||
Tuple[torch.Tensor, torch.Tensor], # grouped y_fp8
|
||||
Tuple[torch.Tensor, torch.Tensor], # flat x_fp8
|
||||
Tuple[torch.Tensor, torch.Tensor], # flat y_fp8
|
||||
torch.Tensor, # output
|
||||
torch.Tensor, # reference output
|
||||
]:
|
||||
# Verify input shapes
|
||||
m, k = x.shape
|
||||
n, k_y = y.shape
|
||||
assert k == k_y, f"Incompatible shapes: x({m}, {k}), y({n}, {k_y})"
|
||||
assert m % num_groups == 0, f"m({m}) must be divisible by num_groups({num_groups})"
|
||||
assert m % 4 == 0, f"TMA alignment error: {m}"
|
||||
|
||||
# Reshape inputs for grouped processing
|
||||
m_per_group = m // num_groups
|
||||
x_grouped = x.view(num_groups, m_per_group, k)
|
||||
y_grouped = y.unsqueeze(0).expand(num_groups, n, k)
|
||||
|
||||
# Initialize output tensors
|
||||
out = torch.empty((num_groups, m_per_group, n), device="cuda", dtype=torch.bfloat16)
|
||||
ref_out = torch.einsum("gmk,gnk->gmn", x_grouped, y_grouped)
|
||||
|
||||
# Quantize grouped tensors
|
||||
x_fp8_grouped = (
|
||||
torch.empty_like(x_grouped, dtype=torch.float8_e4m3fn),
|
||||
torch.empty(
|
||||
(num_groups, m_per_group, k // 128), device="cuda", dtype=torch.float
|
||||
),
|
||||
)
|
||||
y_fp8_grouped = (
|
||||
torch.empty_like(y_grouped, dtype=torch.float8_e4m3fn),
|
||||
torch.empty(
|
||||
(num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float
|
||||
),
|
||||
)
|
||||
for i in range(num_groups):
|
||||
x_fp8_grouped[0][i], x_fp8_grouped[1][i] = per_token_cast_to_fp8(x_grouped[i])
|
||||
y_fp8_grouped[0][i], y_fp8_grouped[1][i] = per_block_cast_to_fp8(y_grouped[i])
|
||||
|
||||
# Quantize flat tensors
|
||||
x_fp8_flat = per_token_cast_to_fp8(x)
|
||||
y_fp8_flat = per_block_cast_to_fp8(y)
|
||||
|
||||
# For non-masked input, merge the group and M dims in output
|
||||
if not is_masked:
|
||||
x_fp8_grouped = (
|
||||
x_fp8_grouped[0].view(-1, k),
|
||||
per_token_cast_to_fp8(x_grouped.view(-1, k))[1],
|
||||
)
|
||||
out, ref_out = out.view(-1, n), ref_out.view(-1, n)
|
||||
|
||||
# Transpose earlier for testing
|
||||
x_fp8_grouped = (
|
||||
x_fp8_grouped[0],
|
||||
get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
|
||||
)
|
||||
x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))
|
||||
|
||||
return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
|
||||
|
||||
|
||||
# Since we don't have a group gemm kernel in SGLang/vLLM, we implemented a
|
||||
# custom kernel based on the Triton tutorial.
|
||||
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
|
||||
@triton.jit
|
||||
def fp8_gemm_group_triton_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
# Pointers to scaling factors
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
# Matrix dimensions
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
# The stride variables represent how much to increase the ptr by when moving by 1
|
||||
# element in a particular dimension.
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
# Strides for scaling factors
|
||||
stride_a_scale_m,
|
||||
stride_a_scale_k,
|
||||
stride_b_scale_n,
|
||||
stride_b_scale_k,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
|
||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||
|
||||
Note: Block sizes must be multiples of 32 for optimal TMA performance.
|
||||
"""
|
||||
# Map program ids to the block of C it should compute
|
||||
pid_group = tl.program_id(axis=0) # Group ID
|
||||
pid_n = tl.program_id(axis=1) # N dimension ID
|
||||
|
||||
# Compute the M block ID within this group
|
||||
group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
|
||||
pid_m_within_group = tl.program_id(axis=2) % group_size_m
|
||||
pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group
|
||||
|
||||
# Create pointers for the first blocks of A and B
|
||||
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
# Initialize accumulator
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
# Main loop
|
||||
for k_block in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
k_offset = k_block * BLOCK_SIZE_K
|
||||
|
||||
# Load the next block of A and B, with masks
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k_offset, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_offset, other=0.0)
|
||||
|
||||
# Calculate indices for scaling factors for this K block
|
||||
a_scale_ptrs = a_scale_ptr + (
|
||||
offs_am * stride_a_scale_m + k_block * stride_a_scale_k
|
||||
)
|
||||
b_scale_ptrs = b_scale_ptr + (
|
||||
pid_n * stride_b_scale_n + k_block * stride_b_scale_k
|
||||
)
|
||||
|
||||
# Perform matrix multiplication in FP8
|
||||
res = tl.dot(a, b)
|
||||
|
||||
# Load scaling factors for the current block
|
||||
a_scale = tl.load(a_scale_ptrs)[:, None] # [BLOCK_SIZE_M, 1]
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
|
||||
# Apply scaling factors to the accumulated result
|
||||
accumulator += res * a_scale * b_scale
|
||||
|
||||
# Advance pointers
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
# Convert to bfloat16 for output
|
||||
c = accumulator.to(tl.bfloat16)
|
||||
|
||||
# Write back the result
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
|
||||
"""
|
||||
Perform matrix multiplication with FP8 inputs and proper scaling.
|
||||
|
||||
Args:
|
||||
a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
|
||||
b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
|
||||
c: Output tensor in BF16 format
|
||||
num_groups: Number of groups for grouped GEMM
|
||||
|
||||
Returns:
|
||||
Result tensor in BF16 format
|
||||
"""
|
||||
# Unpack the tuples
|
||||
a, a_scale = a_tuple
|
||||
b, b_scale = b_tuple
|
||||
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
|
||||
# Configure block sizes - must be multiples of 32 for TMA alignment
|
||||
BLOCK_SIZE_M = 128
|
||||
BLOCK_SIZE_N = 128
|
||||
BLOCK_SIZE_K = 128
|
||||
|
||||
# Calculate grid dimensions
|
||||
num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
|
||||
num_groups_grid = triton.cdiv(num_pid_m, num_groups)
|
||||
|
||||
# 3D grid launch - (group, n_blocks, m_blocks_per_group)
|
||||
grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))
|
||||
|
||||
fp8_gemm_group_triton_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
a_scale,
|
||||
b_scale,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a.stride(0),
|
||||
a.stride(1),
|
||||
b.stride(0),
|
||||
b.stride(1),
|
||||
c.stride(0),
|
||||
c.stride(1),
|
||||
a_scale.stride(0),
|
||||
1, # Stride in the K dimension may be 1
|
||||
b_scale.stride(0),
|
||||
1 if b_scale.dim() > 1 else 0,
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
GROUP_SIZE_M=num_groups,
|
||||
)
|
||||
|
||||
return c
|
||||
|
||||
|
||||
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
x_fp8_grouped,
|
||||
y_fp8_grouped,
|
||||
out,
|
||||
m_indices,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def calculate_diff(m: int, n: int, k: int, num_groups: int):
|
||||
print(f"Shape (m={m}, n={n}, k={k}")
|
||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
|
||||
construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
|
||||
)
|
||||
m_per_group = m // num_groups
|
||||
out_deepgemm = out.clone()
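# Build per-row group indices for DeepGEMM's contiguous grouped GEMM: row i of the
# flattened (num_groups * m_per_group, k) input belongs to group i // m_per_group.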
|
||||
m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
|
||||
m_indices = (
|
||||
m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1)
|
||||
)
|
||||
|
||||
fp8_gemm_group_deepgemm(
|
||||
x_fp8_grouped,
|
||||
y_fp8_grouped,
|
||||
out_deepgemm,
|
||||
m_indices,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Prepare inputs for Triton
|
||||
a, a_scale = x_fp8_flat
|
||||
b, b_scale = y_fp8_flat
|
||||
b = b.T.contiguous()
|
||||
# Ensure scales are in the right format and contiguous
|
||||
a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
|
||||
M, _ = a.shape
|
||||
_, N = b.shape
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
|
||||
out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
|
||||
diff_torch_triton = torch.abs(out_torch - out_triton).mean().item()
|
||||
diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item()
|
||||
|
||||
print(f"Shape m={m}, n={n}, k={k}:")
|
||||
print(f"Torch output: {out_torch[0, 0:5]}")
|
||||
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
|
||||
print(f"Triton output: {out_triton[0, 0:5]}")
|
||||
print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}")
|
||||
print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}")
|
||||
print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}")
|
||||
|
||||
deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch)
|
||||
triton_torch_diff = calc_diff(out_triton, out_torch)
|
||||
deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton)
|
||||
|
||||
DIFF_THRESHOLD = 0.001
|
||||
all_match = (
|
||||
deepgemm_torch_diff < DIFF_THRESHOLD
|
||||
and triton_torch_diff < DIFF_THRESHOLD
|
||||
and deepgemm_triton_diff < DIFF_THRESHOLD
|
||||
)
|
||||
if all_match:
|
||||
print("✅ All implementations match\n")
|
||||
else:
|
||||
print("❌ Some implementations differ:")
|
||||
print(
|
||||
f" - Torch vs DeepGEMM: {'✅' if deepgemm_torch_diff < DIFF_THRESHOLD else '❌'}\n"
|
||||
f" - Torch vs Triton: {'✅' if triton_torch_diff < DIFF_THRESHOLD else '❌'}\n"
|
||||
f" - DeepGEMM vs Triton: {'✅' if deepgemm_triton_diff < DIFF_THRESHOLD else '❌'}"
|
||||
)
|
||||
|
||||
|
||||
def get_weight_shapes(tp_size):
|
||||
# cannot TP
|
||||
total = [
|
||||
(512 + 64, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(7168, 16384),
|
||||
(7168, 18432),
|
||||
]
|
||||
# N can TP
|
||||
n_tp = [
|
||||
(18432 * 2, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(24576, 1536),
|
||||
(4096, 7168),
|
||||
]
|
||||
# K can TP
|
||||
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||
|
||||
weight_shapes = []
|
||||
for t in total:
|
||||
weight_shapes.append(t)
|
||||
for n_t in n_tp:
|
||||
new_t = (n_t[0] // tp_size, n_t[1])
|
||||
weight_shapes.append(new_t)
|
||||
for k_t in k_tp:
|
||||
new_t = (k_t[0], k_t[1] // tp_size)
|
||||
weight_shapes.append(new_t)
|
||||
|
||||
return weight_shapes
|
||||
|
||||
|
||||
def create_benchmark_configs(tp_size):
|
||||
configs = []
|
||||
weight_shapes = get_weight_shapes(tp_size)
|
||||
batch_sizes = [2048, 4096]
|
||||
group_sizes = [4, 8]
|
||||
for n, k in weight_shapes:
|
||||
for m in batch_sizes:
|
||||
for num_groups in group_sizes:
|
||||
configs.append((m, n, k, num_groups, tp_size))
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def get_benchmark(tp_size):
|
||||
all_configs = create_benchmark_configs(tp_size)
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["m", "n", "k", "num_groups", "tp_size"],
|
||||
x_vals=[config for config in all_configs],
|
||||
line_arg="provider",
|
||||
line_vals=["deepgemm", "triton"],
|
||||
line_names=["DeepGEMM", "Triton"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"fp8-group-gemm-performance-comparison-tp{tp_size}",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(m, n, k, num_groups, tp_size, provider):
|
||||
print(
|
||||
f"Shape (m={m}, n={n}, k={k}, tp={tp_size}, num_groups={num_groups}, Provider: {provider}"
|
||||
)
|
||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
|
||||
construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
|
||||
)
|
||||
m_per_group = m // num_groups
|
||||
m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
|
||||
m_indices = (
|
||||
m_indices.unsqueeze(-1)
|
||||
.expand(num_groups, m_per_group)
|
||||
.contiguous()
|
||||
.view(-1)
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "deepgemm":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: fp8_gemm_group_deepgemm(
|
||||
x_fp8_grouped,
|
||||
y_fp8_grouped,
|
||||
out,
|
||||
m_indices,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "triton":
|
||||
# Prepare inputs for Triton
|
||||
# Prepared outside the lambda so the comparison with the deepgemm path stays fair
|
||||
a, a_scale = x_fp8_flat
|
||||
b, b_scale = y_fp8_flat
|
||||
b = b.T.contiguous()
|
||||
# Ensure scales are in the right format and contiguous
|
||||
a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
|
||||
M, _ = a.shape
|
||||
_, N = b.shape
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: fp8_gemm_group_triton(
|
||||
(a, a_scale),
|
||||
(b, b_scale),
|
||||
c,
|
||||
num_groups,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
# Calculate TFLOPS
|
||||
flops = 2 * m * n * k # multiply-adds
|
||||
tflops = flops / (ms * 1e-3) / 1e12
|
||||
|
||||
print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
|
||||
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/fp8_group_gemm/",
|
||||
help="Path to save deepgemm fp8 group gemm benchmark results",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_correctness",
|
||||
action="store_true",
|
||||
help="Whether to run correctness test",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallelism size to benchmark (default: 1)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
# Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# Run correctness tests on a few examples
|
||||
if args.run_correctness:
|
||||
print("Running correctness tests...")
|
||||
calculate_diff(8192, 7168, 4096, 4)
|
||||
calculate_diff(8192, 2048, 7168, 4)
|
||||
calculate_diff(4096, 7168, 4096, 8)
|
||||
calculate_diff(4096, 2048, 7168, 8)
|
||||
calculate_diff(4096, 576, 7168, 8)
|
||||
|
||||
# Get the benchmark function with the specified tp_size
|
||||
benchmark = get_benchmark(args.tp_size)
|
||||
|
||||
print(f"Running performance benchmark for TP size = {args.tp_size}...")
|
||||
benchmark.run(print_data=True, save_path=args.save_path)
|
||||
29
benchmark/kernels/fbgemm/README.md
Normal file
29
benchmark/kernels/fbgemm/README.md
Normal file
@@ -0,0 +1,29 @@
|
||||
## Benchmark FBGEMM Grouped GEMM
|
||||
|
||||
Benchmark the FBGEMM Grouped GEMM (both the Triton and CUDA versions) against the SGLang Triton Grouped GEMM; the results are used to compare the achieved bandwidth of the different implementations.
|
||||
|
||||
### Requirements
|
||||
|
||||
```shell
|
||||
pip install fbgemm-gpu-genai
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
python3 benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
|
||||
```
|
||||
|
||||
For example, on an H200, the Qwen2-57B-A14B-Instruct TP4 fp8 w8a8 grouped GEMM bandwidth results are as follows:
|
||||
|
||||
```shell
|
||||
grouped-gemm-performance:
|
||||
batch_size FBGEMM Triton Grouped GEMM FP8 FBGEMM CUTLASS F8F8BF16 Rowwise SGLang Grouped GEMM FP8
|
||||
0 256.0 3704.841339 3042.626402 2254.725030
|
||||
1 512.0 3691.426346 3029.065684 2269.504543
|
||||
2 1024.0 3653.938629 2258.471467 2358.319020
|
||||
3 2048.0 3596.644313 2271.611904 2476.895397
|
||||
4 4096.0 3468.496435 2231.283986 2179.473910
|
||||
```
|
||||
|
||||
The theoretical peak bandwidth of the H200 is 4.8 TB/s. Taking batch_size 256 as an example, FBGEMM Triton Grouped GEMM FP8 reaches 3704.841339 GB/s, FBGEMM CUTLASS F8F8BF16 Rowwise reaches 3042.626402 GB/s, and SGLang Grouped GEMM FP8 reaches 2254.725030 GB/s, i.e. roughly 77.2%, 63.4%, and 46.9% of the H200's theoretical peak bandwidth, respectively.
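As a sanity check, these utilization figures can be reproduced directly from the numbers in the table; the sketch below hardcodes the measured bandwidths and the 4.8 TB/s peak quoted above, so it is illustrative only:

```python
# Reproduce the utilization percentages from the example H200 run above.
H200_PEAK_GBPS = 4800.0  # 4.8 TB/s theoretical peak, as quoted above

measured_gbps = {
    "FBGEMM Triton Grouped GEMM FP8": 3704.841339,
    "FBGEMM CUTLASS F8F8BF16 Rowwise": 3042.626402,
    "SGLang Grouped GEMM FP8": 2254.725030,
}

for name, gbps in measured_gbps.items():
    print(f"{name}: {gbps / H200_PEAK_GBPS:.1%} of theoretical peak")
```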
|
||||
516
benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py
Normal file
516
benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py
Normal file
@@ -0,0 +1,516 @@
|
||||
# python3 benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
|
||||
quantize_fp8_row,
|
||||
triton_quantize_fp8_row,
|
||||
)
|
||||
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
|
||||
grouped_gemm as fbgemm_grouped_gemm,
|
||||
)
|
||||
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
|
||||
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
|
||||
)
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
grouped_gemm_triton as sglang_grouped_gemm,
|
||||
)
|
||||
|
||||
|
||||
def get_model_config(model_name: str, tp_size: int):
|
||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
num_groups = config.ffn_config.moe_num_experts
|
||||
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||
elif config.architectures[0] == "JambaForCausalLM":
|
||||
num_groups = config.num_experts
|
||||
intermediate_size = config.intermediate_size
|
||||
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
num_groups = config.num_experts
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
elif config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||
num_groups = config.num_experts
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
elif config.architectures[0] in [
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
]:
|
||||
num_groups = config.n_routed_experts
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
elif config.architectures[0] == "Llama4ForConditionalGeneration":
|
||||
num_groups = config.text_config.num_local_experts
|
||||
intermediate_size = config.text_config.intermediate_size
|
||||
elif config.architectures[0] in [
|
||||
"Grok1ForCausalLM",
|
||||
"Grok1ImgGen",
|
||||
"Grok1AForCausalLM",
|
||||
]:
|
||||
num_groups = config.num_local_experts
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
else:
|
||||
num_groups = config.num_local_experts
|
||||
intermediate_size = config.intermediate_size
|
||||
|
||||
shape_configs = {
|
||||
"num_groups": num_groups,
|
||||
"hidden_size": config.hidden_size,
|
||||
"intermediate_size": intermediate_size,
|
||||
"dtype": config.torch_dtype,
|
||||
}
|
||||
print(f"{shape_configs=}")
|
||||
return shape_configs
|
||||
|
||||
|
||||
def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
|
||||
torch.manual_seed(42)
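# Evenly split the batch across groups (assumes batch_size is divisible by num_groups).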
|
||||
|
||||
tokens_per_group = batch_size // num_groups
|
||||
m_sizes = torch.full(
|
||||
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
base_weights = torch.randn(
|
||||
num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
|
||||
w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
|
||||
w_sglang = base_weights
|
||||
|
||||
c_fbgemm = torch.empty(
|
||||
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
c_sglang = torch.empty(
|
||||
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
|
||||
seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda")
|
||||
for i in range(1, num_groups + 1):
|
||||
seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group
|
||||
|
||||
weight_indices = torch.arange(num_groups, dtype=torch.int32, device="cuda")
|
||||
|
||||
return (
|
||||
x,
|
||||
w_fbgemm,
|
||||
w_sglang,
|
||||
c_fbgemm,
|
||||
c_sglang,
|
||||
m_sizes,
|
||||
seg_indptr,
|
||||
weight_indices,
|
||||
)
|
||||
|
||||
|
||||
def create_fp8_test_data(
|
||||
batch_size, num_groups, hidden_size, intermediate_size, backend="triton"
|
||||
):
|
||||
"""
|
||||
Create test data for FP8 grouped GEMM operations.
|
||||
|
||||
Args:
|
||||
batch_size: Total batch size
|
||||
num_groups: Number of groups
|
||||
hidden_size: Hidden dimension size
|
||||
intermediate_size: Intermediate dimension size
|
||||
backend: "triton" for Triton GEMM, "cutlass" for CUTLASS GEMM
|
||||
|
||||
Returns:
|
||||
For triton: (x_fp8, w_fp8, m_sizes, x_scale, w_scale)
|
||||
For cutlass: (x, wq, w_scale, m_sizes)
|
||||
"""
|
||||
torch.manual_seed(42)
|
||||
|
||||
tokens_per_group = batch_size // num_groups
|
||||
|
||||
# Create weight matrices for each group
|
||||
w_list = []
|
||||
for _ in range(num_groups):
|
||||
w = torch.randn(
|
||||
intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
|
||||
)
|
||||
w_list.append(w)
|
||||
|
||||
# Quantize weights using quantize_fp8_row for each group
|
||||
wq_list, w_scale_list = zip(*[quantize_fp8_row(w) for w in w_list])
|
||||
|
||||
if backend == "triton":
|
||||
# Triton format: concatenated weights
|
||||
w_fp8 = torch.concat(wq_list, dim=0).contiguous()
|
||||
w_scale = torch.concat(w_scale_list, dim=0).contiguous()
|
||||
|
||||
# Create m_sizes as int32 for triton
|
||||
m_sizes = torch.full(
|
||||
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
# Create and quantize input
|
||||
x_fp16 = torch.randn(
|
||||
batch_size, hidden_size, dtype=torch.float16, device="cuda"
|
||||
)
|
||||
x_fp8, x_scale = triton_quantize_fp8_row(x_fp16)
|
||||
x_scale = x_scale.view(batch_size, -1)
|
||||
|
||||
return x_fp8, w_fp8, m_sizes, x_scale, w_scale
|
||||
|
||||
elif backend == "cutlass":
|
||||
# CUTLASS format: stacked weights
|
||||
wq = torch.stack(wq_list, dim=0).contiguous()
|
||||
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
|
||||
|
||||
# Create m_sizes as int64 for cutlass
|
||||
m_values = [tokens_per_group] * num_groups
|
||||
m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device="cuda")
|
||||
|
||||
# Create input data - separate for each group then concat
|
||||
x_list = []
|
||||
for _ in range(num_groups):
|
||||
x = torch.randn(
|
||||
tokens_per_group, hidden_size, dtype=torch.float16, device="cuda"
|
||||
)
|
||||
x_list.append(x)
|
||||
|
||||
# Concatenate inputs into single tensor
|
||||
x = torch.concat(x_list, dim=0).contiguous()
|
||||
|
||||
return x, wq, w_scale, m_sizes
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {backend}")
|
||||
|
||||
|
||||
def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype):
|
||||
"""
|
||||
Calculate memory bandwidth based on accessed expert weights.
|
||||
|
||||
Args:
|
||||
m_sizes: Tensor containing batch sizes for each group
|
||||
hidden_size: Hidden dimension size
|
||||
intermediate_size: Intermediate dimension size
|
||||
dtype: Data type of weights
|
||||
|
||||
Returns:
|
||||
Memory size in bytes for accessed expert weights
|
||||
"""
|
||||
# Count non-zero groups (active experts)
|
||||
if hasattr(m_sizes, "cpu"):
|
||||
active_experts = torch.count_nonzero(m_sizes).item()
|
||||
else:
|
||||
active_experts = sum(1 for m in m_sizes if m > 0)
|
||||
|
||||
# Calculate bytes per element based on dtype
|
||||
if dtype in [torch.float16, torch.bfloat16]:
|
||||
bytes_per_element = 2
|
||||
elif dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
bytes_per_element = 1
|
||||
elif dtype == torch.float32:
|
||||
bytes_per_element = 4
|
||||
else:
|
||||
# Default to 2 bytes for unknown dtypes
|
||||
bytes_per_element = 2
|
||||
|
||||
# Memory per expert weight matrix
|
||||
memory_per_expert = hidden_size * intermediate_size * bytes_per_element
|
||||
|
||||
# Total memory for active experts
|
||||
total_memory_bytes = active_experts * memory_per_expert
|
||||
|
||||
return total_memory_bytes
|
||||
|
||||
|
||||
def get_benchmark_config(use_fp8_w8a8=False):
|
||||
if use_fp8_w8a8:
|
||||
return {
|
||||
"line_vals": [
|
||||
"fbgemm_triton_grouped_gemm_fp8",
|
||||
"fbgemm_cutlass_f8f8bf16_rowwise",
|
||||
"sglang_grouped_gemm",
|
||||
],
|
||||
"line_names": [
|
||||
"FBGEMM Triton Grouped GEMM FP8",
|
||||
"FBGEMM CUTLASS F8F8BF16 Rowwise",
|
||||
"SGLang Grouped GEMM FP8",
|
||||
],
|
||||
"styles": [("blue", "-"), ("orange", "-"), ("red", "-")],
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"line_vals": ["fbgemm_triton_grouped_gemm", "sglang_grouped_gemm"],
|
||||
"line_names": [
|
||||
"FBGEMM Triton Grouped GEMM BF16",
|
||||
"SGLang Grouped GEMM BF16",
|
||||
],
|
||||
"styles": [("blue", "-"), ("green", "-")],
|
||||
}
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
|
||||
):
|
||||
config = get_benchmark_config(use_fp8_w8a8)
|
||||
|
||||
benchmark_config = triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[256, 512, 1024, 2048, 4096],
|
||||
line_arg="provider",
|
||||
line_vals=config["line_vals"],
|
||||
line_names=config["line_names"],
|
||||
styles=config["styles"],
|
||||
ylabel="Bandwidth (GB/s)",
|
||||
plot_name="grouped-gemm-performance",
|
||||
args={},
|
||||
)
|
||||
|
||||
@triton.testing.perf_report(benchmark_config)
|
||||
def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
|
||||
print(f"Benchmarking {provider} with batch_size={batch_size}")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
num_groups = model_config["num_groups"]
|
||||
hidden_size = model_config["hidden_size"]
|
||||
intermediate_size = model_config["intermediate_size"]
|
||||
|
||||
if provider == "fbgemm_triton_grouped_gemm_fp8":
|
||||
try:
|
||||
test_data = create_fp8_test_data(
|
||||
batch_size,
|
||||
num_groups,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
backend="triton",
|
||||
)
|
||||
x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data
|
||||
|
||||
# Calculate memory bandwidth
|
||||
memory_bytes = calculate_memory_bandwidth(
|
||||
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
def run_func():
|
||||
return fbgemm_grouped_gemm_fp8_rowwise(
|
||||
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"FP8 not supported, skipping: {e}")
|
||||
return float("inf"), float("inf"), float("inf")
|
||||
|
||||
elif provider == "fbgemm_cutlass_f8f8bf16_rowwise":
|
||||
try:
|
||||
test_data = create_fp8_test_data(
|
||||
batch_size,
|
||||
num_groups,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
backend="cutlass",
|
||||
)
|
||||
x, wq, w_scale, m_sizes = test_data
|
||||
|
||||
# Calculate memory bandwidth
|
||||
memory_bytes = calculate_memory_bandwidth(
|
||||
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
# Quantize input using triton_quantize_fp8_row
|
||||
xq, x_scale = triton_quantize_fp8_row(x)
|
||||
x_scale = x_scale.view(batch_size, -1)
|
||||
|
||||
def run_func():
|
||||
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
|
||||
xq, wq, x_scale, w_scale, m_sizes
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, "
|
||||
f"skipping: {e}"
|
||||
)
|
||||
return float("inf"), float("inf"), float("inf")
|
||||
else:
|
||||
test_data = create_test_data(
|
||||
batch_size, num_groups, hidden_size, intermediate_size
|
||||
)
|
||||
(
|
||||
x,
|
||||
w_fbgemm,
|
||||
w_sglang,
|
||||
c_fbgemm,
|
||||
c_sglang,
|
||||
m_sizes,
|
||||
seg_indptr,
|
||||
weight_indices,
|
||||
) = test_data
|
||||
|
||||
# Calculate memory bandwidth for BF16 operations
|
||||
memory_bytes = calculate_memory_bandwidth(
|
||||
m_sizes, hidden_size, intermediate_size, torch.bfloat16
|
||||
)
|
||||
|
||||
if provider == "fbgemm_triton_grouped_gemm":
|
||||
|
||||
def run_func():
|
||||
return fbgemm_grouped_gemm(
|
||||
x, w_fbgemm, m_sizes, use_fast_accum=True
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
def run_func():
|
||||
return sglang_grouped_gemm(
|
||||
x,
|
||||
w_sglang,
|
||||
c_sglang,
|
||||
num_groups,
|
||||
weight_column_major=True,
|
||||
seg_indptr=seg_indptr,
|
||||
weight_indices=weight_indices,
|
||||
c_dtype=c_sglang.dtype,
|
||||
)
|
||||
|
||||
for _ in range(10):
|
||||
try:
|
||||
run_func()
|
||||
except Exception as e:
|
||||
print(f"Error during warmup for {provider}: {e}")
|
||||
return float("inf"), float("inf"), float("inf")
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
try:
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)
|
||||
|
||||
# Convert time (ms) to bandwidth (GB/s)
|
||||
# Bandwidth = Memory (bytes) / Time (seconds)
|
||||
# Convert ms to seconds and bytes to GB (1e9)
|
||||
gb_per_s = (memory_bytes / 1e9) / (ms / 1000)
|
||||
# min bandwidth = max time, max bandwidth = min time
|
||||
min_gb_per_s = (memory_bytes / 1e9) / (max_ms / 1000)
|
||||
max_gb_per_s = (memory_bytes / 1e9) / (min_ms / 1000)
|
||||
|
||||
return gb_per_s, min_gb_per_s, max_gb_per_s
|
||||
except Exception as e:
|
||||
print(f"Error during benchmarking for {provider}: {e}")
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
dynamic_benchmark.run(
|
||||
show_plots=True,
|
||||
print_data=True,
|
||||
save_path=save_path,
|
||||
model_config=model_config,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
)
|
||||
|
||||
|
||||
def verify_correctness(model_config):
|
||||
print("Verifying correctness...")
|
||||
batch_size = 128
|
||||
num_groups = model_config["num_groups"]
|
||||
hidden_size = model_config["hidden_size"]
|
||||
intermediate_size = model_config["intermediate_size"]
|
||||
|
||||
test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)
|
||||
(
|
||||
x,
|
||||
w_fbgemm,
|
||||
w_sglang,
|
||||
c_fbgemm,
|
||||
c_sglang,
|
||||
m_sizes,
|
||||
seg_indptr,
|
||||
weight_indices,
|
||||
) = test_data
|
||||
|
||||
result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)
|
||||
|
||||
result_sglang = sglang_grouped_gemm(
|
||||
x,
|
||||
w_sglang,
|
||||
c_sglang,
|
||||
num_groups,
|
||||
weight_column_major=True,
|
||||
seg_indptr=seg_indptr,
|
||||
weight_indices=weight_indices,
|
||||
c_dtype=c_sglang.dtype,
|
||||
)
|
||||
|
||||
if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
|
||||
print("✓ BF16 Correctness verification passed!")
|
||||
else:
|
||||
max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
|
||||
print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark FBGEMM vs SGLang Grouped GEMM"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
help="Model name to get configuration from",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-size", type=int, default=1, help="Tensor parallelism size"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./benchmark_grouped_gemm/",
|
||||
help="Path to save benchmark results",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verify-correctness",
|
||||
action="store_true",
|
||||
help="Verify correctness before benchmarking",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
model_config = get_model_config(args.model, args.tp_size)
|
||||
except Exception as e:
|
||||
print(f"Failed to get model config: {e}")
|
||||
print("Using default configuration...")
|
||||
model_config = {
|
||||
"num_groups": 8,
|
||||
"hidden_size": 4096,
|
||||
"intermediate_size": 14336,
|
||||
"dtype": torch.bfloat16,
|
||||
}
|
||||
|
||||
print("Running benchmark with:")
|
||||
print(f" num_groups: {model_config['num_groups']}")
|
||||
print(f" hidden_size: {model_config['hidden_size']}")
|
||||
print(f" intermediate_size: {model_config['intermediate_size']}")
|
||||
print(f" use_fp8_w8a8: {args.use_fp8_w8a8}")
|
||||
|
||||
if args.verify_correctness:
|
||||
if not verify_correctness(model_config):
|
||||
print("Correctness verification failed. Exiting...")
|
||||
return
|
||||
|
||||
try:
|
||||
run_benchmark(
|
||||
model_config=model_config,
|
||||
use_fp8_w8a8=args.use_fp8_w8a8,
|
||||
save_path=args.save_path,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Benchmark failed: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
102
benchmark/kernels/flashinfer_allreduce_fusion/README.md
Normal file
102
benchmark/kernels/flashinfer_allreduce_fusion/README.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# FlashInfer Fused AllReduce + RMSNorm Benchmark
|
||||
|
||||
This benchmark script is modified from the [original implementation](https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py) by the vLLM community. It aims to compare the performance differences between FlashInfer fused operators in SGLang (trtllm_allreduce_fusion: AllReduce + Residual Add + RMSNorm + optional quantization) and conventional implementations (standard `tensor_model_parallel_all_reduce` + separate RMSNorm/quantization). Specifically, this script tests the timing performance of two implementation paths: 1) Standard AllReduce and RMSNorm executed separately; 2) FlashInfer's fused operator combining AllReduce, Residual Add, RMSNorm, and optional quantization operations.
|
||||
|
||||
This benchmark script helps us tune the ipc workspace size of the `flashinfer_allreduce_residual_rmsnorm` operator in SGLang and prepare for applications with FP8/FP4 quantized fused operators.
|
||||
|
||||
Script path: `benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py`
|
||||
|
||||
## Feature Overview
|
||||
|
||||
- Compare average execution time (ms) and calculate speedup ratios for the following paths:
|
||||
- standard_allreduce_rmsnorm (Standard AllReduce + RMSNorm)
|
||||
- flashinfer_fused_allreduce_rmsnorm (Fused AllReduce + RMSNorm), including oneshot and twoshot modes
|
||||
- Optionally compare FP8/FP4 quantized fused paths with standard paths
|
||||
- Use CUDA Graph capture and batch replay to reduce measurement noise
|
||||
- Automatically select the faster "standard baseline" (native or compiled version) as the reference for the speedup calculation
|
||||
- Optionally export results in Markdown format
|
||||
|
||||
## Runtime Environment and Prerequisites
|
||||
|
||||
- At least 2 GPUs; the benchmark is launched as a multi-process distributed run with `torchrun` (NCCL backend)
|
||||
- Properly install/compile sglang along with sgl-kernel and custom operators
|
||||
|
||||
## Quick Start (Command Examples)
|
||||
|
||||
The following examples use world_size=2. You can modify `--nproc_per_node` and parameters according to your machine:
|
||||
|
||||
- Regular paths only (no quantization):
|
||||
```
|
||||
torchrun --nproc_per_node=2 \
|
||||
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
|
||||
--no-quant --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
|
||||
```
|
||||
|
||||
- FP8 quantization paths only:
|
||||
```
|
||||
torchrun --nproc_per_node=2 \
|
||||
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
|
||||
--quant-fp8 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
|
||||
```
|
||||
|
||||
- FP4 quantization paths only:
|
||||
```
|
||||
torchrun --nproc_per_node=2 \
|
||||
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
|
||||
--quant-fp4 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
|
||||
```
|
||||
|
||||
- Larger hidden dimensions:
|
||||
```
|
||||
torchrun --nproc_per_node=2 \
|
||||
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
|
||||
--no-quant --hidden-dim 4096 --seq-lens 512 1024 2048 4096 --trials 100
|
||||
```
|
||||
|
||||
## Parameter Description
|
||||
- `--seq-lens`: List of sequence lengths to test (default: 128 512 1024 2048)
|
||||
- `--hidden-dim`: Hidden dimension (default: 8192)
|
||||
- `--dtypes`: Data type list, `float16|bfloat16|float32` (default: bfloat16)
|
||||
- `--no-residual`: Only test "no residual" scenarios (default tests both "with/without residual")
|
||||
- Mutually exclusive quantization options:
|
||||
- `--no-quant`: No quantization testing
|
||||
- `--quant-fp8`: Only FP8 quantization testing
|
||||
- `--quant-fp4`: Only FP4 quantization testing
|
||||
- `--quant-all`: Test all (default)
|
||||
- FlashInfer related:
|
||||
- `--disable-oneshot`: Disable oneshot mode (default enables oneshot and tests twoshot simultaneously)
|
||||
- Runtime configuration:
|
||||
- `--warmup`: Warmup count before graph capture and before graph replay (default 5)
|
||||
- `--trials`: Benchmark iteration count (default 20; internally each `graph.replay()` will batch replay multiple times)
|
||||
- `--output-file`: Save results as Markdown file (only rank0 takes effect)
|
||||
|
||||
## Output Example
|
||||
|
||||
Each configuration group prints a table showing average execution time and relative speedup ratios (baseline is the faster standard implementation). For example:
|
||||
```
|
||||
================================================================================
|
||||
Results: seq_len=1024, hidden_dim=1024
|
||||
dtype=torch.bfloat16, residual=yes, quant_mode=none
|
||||
================================================================================
|
||||
Operation Time (ms) Speedup
|
||||
--------------------------------------------------------------------------------
|
||||
standard_allreduce_rmsnorm 0.024 0.98x
|
||||
standard_allreduce_rmsnorm_native_compiled 0.023 baseline
|
||||
flashinfer_fused_allreduce_rmsnorm_oneshot 0.011 2.19x
|
||||
flashinfer_fused_allreduce_rmsnorm_twoshot 0.041 0.57x
|
||||
```
|
||||
|
||||
If `--output-file` is specified, all configurations will be summarized in Markdown tables in that file.
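
For reference, the speedup column is the baseline time divided by each operation's time (assuming that convention, which is consistent with the table above). A minimal sketch using the rounded example numbers, so the ratios differ slightly from the table:

```python
# Recompute the speedup column from the example table above (rounded values).
baseline_ms = 0.023  # standard_allreduce_rmsnorm_native_compiled, the faster standard path

times_ms = {
    "standard_allreduce_rmsnorm": 0.024,
    "flashinfer_fused_allreduce_rmsnorm_oneshot": 0.011,
    "flashinfer_fused_allreduce_rmsnorm_twoshot": 0.041,
}

for op, t in times_ms.items():
    print(f"{op}: {baseline_ms / t:.2f}x")
```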
|
||||
|
||||
## Important Notes and Recommendations
|
||||
|
||||
- Distributed: The script uses the `torchrun` environment variables to initialize the distributed environment and binds tensors/communication groups to the current rank's device.
|
||||
- World size: Requires `WORLD_SIZE > 1` to benchmark the communication operators; otherwise, the script exits with an error message.
|
||||
- FlashInfer:
|
||||
- If FlashInfer is not installed or the required interfaces are missing, the script only runs the standard paths and notes this in the logs.
|
||||
- The fused operator supports two trigger modes, "oneshot" and "twoshot"; oneshot is enabled by default and twoshot is benchmarked as well.
|
||||
- FP8/FP4:
|
||||
- FP8 uses sglang's FP8 utilities and dtype; the underlying platform selects `e4m3`/`e4m3fnuz` as appropriate.
|
||||
- FP4 uses sgl-kernel's `scaled_fp4_quant`, requiring corresponding platform support.
|
||||
- CUDA Graph:
|
||||
- Uses sglang's `graph_capture()` to prepare capture-ready state for communication, then uses `torch.cuda.graph` to capture kernels, reducing measurement jitter.
|
||||
File diff suppressed because it is too large
76
benchmark/kernels/fused_moe_triton/README.md
Normal file
76
benchmark/kernels/fused_moe_triton/README.md
Normal file
@@ -0,0 +1,76 @@
|
||||
## Tuning Triton MoE Kernels
|
||||
|
||||
This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
|
||||
|
||||
### Tuning Tool
|
||||
|
||||
- `tuning_fused_moe_triton.py`: A tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with added support for various model architectures.
|
||||
|
||||
Example usage:
|
||||
```bash
|
||||
# Tune Mixtral-8x7B with default settings
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
|
||||
--tune
|
||||
|
||||
# Tune Qwen2-57B with FP8 and TP=4
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model Qwen/Qwen2-57B-A14B-Instruct \
|
||||
--tp-size 4 \
|
||||
--dtype fp8_w8a8 \
|
||||
--tune
|
||||
|
||||
# Tune Qwen3-235B-A22B-FP8 and TP=4
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model Qwen/Qwen3-235B-A22B-FP8 \
|
||||
--tp-size 4 \
|
||||
--dtype fp8_w8a8 \
|
||||
--tune
|
||||
|
||||
# Tune DeepSeek-V3 with FP8 and TP=8
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model deepseek-ai/DeepSeek-V3-0324 \
|
||||
--tp-size 8 \
|
||||
--dtype fp8_w8a8 \
|
||||
--tune
|
||||
|
||||
# Tune DeepSeek-R1 with channel-wise INT8 and TP=16
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model meituan/DeepSeek-R1-Channel-INT8 \
|
||||
--tp-size 16 \
|
||||
--dtype int8_w8a8 \
|
||||
--tune
|
||||
```
|
||||
|
||||
After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to the `sglang/srt/layers/fused_moe_triton/configs/triton_version` directory so that `sglang` can use it.
|
||||
|
||||
### Performance Comparison Tool
|
||||
|
||||
- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between vllm and sglang implementations. Supports various model architectures and data types.
|
||||
|
||||
Example usage:
|
||||
```bash
|
||||
# Compare with default settings (Mixtral model)
|
||||
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
|
||||
|
||||
# Compare with FP8 mode for Qwen2-57B
|
||||
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||
--model Qwen/Qwen2-57B-A14B-Instruct \
|
||||
--use-fp8-w8a8
|
||||
|
||||
# Compare with custom TP size
|
||||
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||
--model deepseek-ai/DeepSeek-V3-0324 \
|
||||
--tp-size 8
|
||||
```
|
||||
|
||||
The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
|
||||
|
||||
- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel.
|
||||
|
||||
Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`; note that `torch.compile` does not support the `fp8_w8a8` and `int8_w8a8` fused_moe_kernel variants.
|
||||
@@ -0,0 +1,292 @@
|
||||
# python3 benchmark/kernels/fused_moe_triton/sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
destroy_distributed_environment,
|
||||
destroy_model_parallel,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe as fused_moe_sglang,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
|
||||
|
||||
|
||||
def get_model_config(model_name: str, tp_size: int):
|
||||
"""Get model configuration parameters"""
|
||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
if config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] in [
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"Glm4MoeForCausalLM",
|
||||
]:
|
||||
E = (
|
||||
config.n_routed_experts + 1
|
||||
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
|
||||
else config.n_routed_experts
|
||||
)
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
else:
|
||||
# Default: Mixtral
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
|
||||
block_shape = None
|
||||
if (
|
||||
hasattr(config, "quantization_config")
|
||||
and "weight_block_size" in config.quantization_config
|
||||
):
|
||||
block_shape = config.quantization_config["weight_block_size"]
|
||||
assert len(block_shape) == 2
|
||||
|
||||
shape_configs = {
|
||||
"num_experts": E,
|
||||
"topk": topk,
|
||||
"hidden_size": config.hidden_size,
|
||||
"shard_intermediate_size": shard_intermediate_size,
|
||||
"dtype": config.torch_dtype,
|
||||
"block_shape": block_shape,
|
||||
}
|
||||
print(f"{shape_configs=}")
|
||||
return shape_configs
|
||||
|
||||
|
||||
def fused_moe_triton_api(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
):
|
||||
topk_op = TopK(
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
use_grouped_topk=False,
|
||||
)
|
||||
topk_op.use_triton_kernels = True
|
||||
triton_topk_output = topk_op.forward_cuda(
|
||||
hidden_states=x,
|
||||
router_logits=input_gating,
|
||||
)
|
||||
|
||||
moe_runner_config = MoeRunnerConfig(
|
||||
inplace=False,
|
||||
)
|
||||
return triton_kernel_moe_forward(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
triton_topk_output,
|
||||
moe_runner_config,
|
||||
)
|
||||
|
||||
|
||||
def fused_moe_sglang_api(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale=None,
|
||||
w2_scale=None,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
block_shape=None,
|
||||
):
|
||||
topk_output = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=input_gating,
|
||||
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||
)
|
||||
return fused_moe_sglang(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
topk_output,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=list([128, 256, 512, 1024, 2048, 4096, 8192]),
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"sglang_fused_moe_triton_v340",
|
||||
"sglang_fused_moe_triton",
|
||||
],
|
||||
line_names=[
|
||||
"sglang_fused_moe_triton_v340",
|
||||
"sglang_fused_moe_triton",
|
||||
],
|
||||
styles=[
|
||||
("blue", "-"),
|
||||
("green", "-"),
|
||||
],
|
||||
ylabel="Time (ms)",
|
||||
plot_name="fused-moe-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(
|
||||
batch_size,
|
||||
provider,
|
||||
model_config,
|
||||
use_fp8_w8a8=False,
|
||||
use_cuda_graph: bool = False,
|
||||
):
|
||||
print(f"benchmark {provider} with batch_size={batch_size}")
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
num_tokens = batch_size
|
||||
num_experts = model_config["num_experts"]
|
||||
hidden_size = model_config["hidden_size"]
|
||||
shard_intermediate_size = model_config["shard_intermediate_size"]
|
||||
topk = model_config["topk"]
|
||||
dtype = model_config["dtype"]
|
||||
block_shape = model_config["block_shape"]
|
||||
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
|
||||
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
|
||||
)
|
||||
|
||||
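# Prepare transposed, contiguous copies of the expert weights for the triton-kernels path (fused_moe_triton_api).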
w1_tri = w1.clone()
|
||||
w2_tri = w2.clone()
|
||||
w1_tri = w1_tri.transpose(-2, -1).contiguous()
|
||||
w2_tri = w2_tri.transpose(-2, -1).contiguous()
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
if provider == "sglang_fused_moe_triton_v340":
|
||||
api_func = fused_moe_triton_api
|
||||
api_kwargs = {
|
||||
"x": x,
|
||||
"w1": w1_tri,
|
||||
"w2": w2_tri,
|
||||
"input_gating": input_gating,
|
||||
"topk": topk,
|
||||
}
|
||||
else:
|
||||
api_func = fused_moe_sglang_api
|
||||
api_kwargs = {
|
||||
"x": x,
|
||||
"w1": w1,
|
||||
"w2": w2,
|
||||
"input_gating": input_gating,
|
||||
"topk": topk,
|
||||
"use_fp8_w8a8": use_fp8_w8a8,
|
||||
"block_shape": block_shape,
|
||||
}
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
_ = api_func(**api_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if use_cuda_graph:
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
api_func(**api_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
bench_lambda = lambda: graph.replay()
|
||||
else:
|
||||
bench_lambda = lambda: api_func(**api_kwargs)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(bench_lambda, quantiles=quantiles)
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
parser.add_argument("--tp-size", type=int, default=2)
|
||||
parser.add_argument("--use-fp8-w8a8", action="store_true")
|
||||
parser.add_argument(
|
||||
"--use-cuda-graph", action="store_true", help="Enable CUDA Graph capture/replay"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/sglang_fused_moe/",
|
||||
)
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl" if torch.cuda.is_available() else "gloo",
|
||||
init_method="tcp://127.0.0.1:23456",
|
||||
world_size=1,
|
||||
rank=0,
|
||||
)
|
||||
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method="tcp://127.0.0.1:23456",
|
||||
local_rank=0,
|
||||
backend="nccl" if torch.cuda.is_available() else "gloo",
|
||||
)
|
||||
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size=1,
|
||||
pipeline_model_parallel_size=1,
|
||||
)
|
||||
|
||||
model_config = get_model_config(args.model, args.tp_size)
|
||||
benchmark.run(
|
||||
show_plots=True,
|
||||
print_data=True,
|
||||
save_path=args.save_path,
|
||||
model_config=model_config,
|
||||
use_fp8_w8a8=args.use_fp8_w8a8,
|
||||
use_cuda_graph=args.use_cuda_graph,
|
||||
)
|
||||
finally:
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()

|
||||
202
benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.testing import do_bench
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _moe_sum_reduce_kernel(
|
||||
input_ptr,
|
||||
input_stride_0,
|
||||
input_stride_1,
|
||||
input_stride_2,
|
||||
output_ptr,
|
||||
output_stride_0,
|
||||
output_stride_1,
|
||||
token_num: int,
|
||||
topk_num: int,
|
||||
hidden_dim: int,
|
||||
routed_scaling_factor: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DIM: tl.constexpr,
|
||||
NUM_STAGE: tl.constexpr,
|
||||
):
|
||||
input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
|
||||
input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
|
||||
output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
|
||||
|
||||
token_block_id = tl.program_id(0)
|
||||
dim_block_id = tl.program_id(1)
|
||||
|
||||
offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
|
||||
|
||||
mask_token = offs_token < token_num
|
||||
mask_dim = offs_dim < hidden_dim
|
||||
|
||||
base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
|
||||
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
|
||||
|
||||
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
|
||||
tile = tl.load(
|
||||
base_ptrs + i * input_stride_1,
|
||||
mask=mask_token[:, None] & mask_dim[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
accumulator += tile.to(tl.float32)
|
||||
accumulator *= routed_scaling_factor
|
||||
|
||||
# -------- Write back --------
|
||||
store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
|
||||
tl.store(
|
||||
store_ptrs,
|
||||
accumulator.to(input_ptr.dtype.element_ty),
|
||||
mask=mask_token[:, None] & mask_dim[None, :],
|
||||
)
|
||||
|
||||
|
||||
# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
|
||||
def moe_sum_reduce(
|
||||
input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
|
||||
):
|
||||
assert input.is_contiguous()
|
||||
assert output.is_contiguous()
|
||||
|
||||
token_num, topk_num, hidden_dim = input.shape
|
||||
assert output.shape[0] == token_num and output.shape[1] == hidden_dim
|
||||
|
||||
BLOCK_M = 1
|
||||
BLOCK_DIM = 2048
|
||||
NUM_STAGE = 1
|
||||
num_warps = 16
|
||||
|
||||
grid = (
|
||||
triton.cdiv(token_num, BLOCK_M),
|
||||
triton.cdiv(hidden_dim, BLOCK_DIM),
|
||||
)
|
||||
|
||||
_moe_sum_reduce_kernel[grid](
|
||||
input,
|
||||
*input.stride(),
|
||||
output,
|
||||
*output.stride(),
|
||||
token_num=token_num,
|
||||
topk_num=topk_num,
|
||||
hidden_dim=hidden_dim,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_DIM=BLOCK_DIM,
|
||||
NUM_STAGE=NUM_STAGE,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def compute_sum_scaled_baseline(
|
||||
x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
|
||||
) -> torch.Tensor:
|
||||
torch.sum(x, dim=1, out=out)
|
||||
out.mul_(routed_scaling_factor)
|
||||
return out
|
||||
|
||||
|
||||
@torch.compile
|
||||
def compute_sum_scaled_compiled(
|
||||
x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
|
||||
) -> torch.Tensor:
|
||||
torch.sum(x * routed_scaling_factor, dim=1, out=out)
|
||||
return out
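# Sanity sketch (assumed shapes, not used by the benchmark): the Triton kernel
# in this file scales the fp32 accumulator once after the top-k loop, while
# the compiled variant scales every addend before the sum; the two orderings
# are mathematically identical and differ only in floating-point rounding.
def _example_scale_orderings_match():
    x = torch.randn(4, 9, 16, dtype=torch.float32)
    scale = 0.3
    post = x.sum(dim=1) * scale   # scale once, after the reduction
    pre = (x * scale).sum(dim=1)  # scale each addend, then reduce
    return torch.allclose(post, pre, atol=1e-6)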
|
||||
|
||||
|
||||
def get_benchmark():
|
||||
num_tokens_range = [2**i for i in range(0, 13)]
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"],
|
||||
x_vals=num_tokens_range,
|
||||
line_arg="version",
|
||||
line_vals=["baseline", "compiled", "triton"],
|
||||
line_names=["Original", "TorchCompile", "TritonKernel"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name="sum_scaled_performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(num_tokens, version):
|
||||
topk = 9
|
||||
hidden_size = 4096
|
||||
dtype = torch.bfloat16
|
||||
scaling_factor = 0.3
|
||||
|
||||
x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
|
||||
out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
|
||||
# Warmup
|
||||
for _ in range(3):
|
||||
if version == "baseline":
|
||||
compute_sum_scaled_baseline(x, out, scaling_factor)
|
||||
elif version == "compiled":
|
||||
compute_sum_scaled_compiled(x, out, scaling_factor)
|
||||
else:
|
||||
moe_sum_reduce(x, out, scaling_factor)
|
||||
|
||||
# Benchmark
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if version == "baseline":
|
||||
ms, min_ms, max_ms = do_bench(
|
||||
lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif version == "compiled":
|
||||
ms, min_ms, max_ms = do_bench(
|
||||
lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
else:
|
||||
ms, min_ms, max_ms = do_bench(
|
||||
lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
def verify_correctness(num_tokens=1024):
|
||||
x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
|
||||
scaling_factor = 0.3
|
||||
|
||||
out_baseline = torch.empty_like(x[:, 0])
|
||||
compute_sum_scaled_baseline(x, out_baseline, scaling_factor)
|
||||
|
||||
out_compiled = torch.empty_like(out_baseline)
|
||||
compute_sum_scaled_compiled(x, out_compiled, scaling_factor)
|
||||
|
||||
out_triton = torch.empty_like(out_baseline)
|
||||
moe_sum_reduce(x, out_triton, scaling_factor)
|
||||
|
||||
if torch.allclose(
|
||||
out_baseline, out_compiled, atol=1e-2, rtol=1e-2
|
||||
) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
print(
|
||||
f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
|
||||
)
|
||||
print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Running correctness verification...")
|
||||
verify_correctness()
|
||||
|
||||
print("\nRunning performance benchmark...")
|
||||
benchmark = get_benchmark()
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
# save_path="./configs/benchmark_ops/sum_scaled/"
|
||||
)
|
||||
@@ -0,0 +1,305 @@
|
||||
# python3 benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from torch.nn import functional as F
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe as fused_moe_triton,
|
||||
)
|
||||
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
|
||||
|
||||
|
||||
def get_model_config(model_name: str, tp_size: int):
|
||||
"""Get model configuration parameters"""
|
||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
E = config.ffn_config.moe_num_experts
|
||||
topk = config.ffn_config.moe_top_k
|
||||
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "JambaForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "Llama4ForConditionalGeneration":
|
||||
E = config.text_config.num_local_experts
|
||||
topk = config.text_config.num_experts_per_tok
|
||||
intermediate_size = config.text_config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] in [
|
||||
"Grok1ForCausalLM",
|
||||
"Grok1ImgGen",
|
||||
"Grok1AForCausalLM",
|
||||
]:
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
else:
|
||||
# Default: Mixtral
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
|
||||
shape_configs = {
|
||||
"num_experts": E,
|
||||
"topk": topk,
|
||||
"hidden_size": config.hidden_size,
|
||||
"shard_intermediate_size": shard_intermediate_size,
|
||||
"dtype": config.torch_dtype,
|
||||
}
|
||||
print(f"{shape_configs=}")
|
||||
return shape_configs
|
||||
|
||||
|
||||
def fused_topk_native(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
M, _ = hidden_states.shape
|
||||
topk_weights = torch.empty(
|
||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
||||
topk_weights = F.softmax(gating_output.float(), dim=-1)
|
||||
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
return topk_weights, topk_ids
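# Toy routing example (hypothetical values, independent of the benchmark):
# softmax over the gating logits, top-k selection, then renormalization so the
# selected weights sum to 1 per token, mirroring fused_topk_native above.
def _example_topk_routing():
    gating = torch.tensor([[2.0, 1.0, 0.1, -1.0], [0.0, 0.0, 3.0, 3.0]])
    weights = F.softmax(gating, dim=-1)
    weights, ids = torch.topk(weights, k=2, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, ids  # per-token expert weights and expert indices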
|
||||
|
||||
|
||||
@torch.compile(dynamic=False)
|
||||
def fused_moe_torch(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale=None,
|
||||
w2_scale=None,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
) -> torch.Tensor:
|
||||
assert not use_fp8_w8a8, "Fp8_w8a8 fused_moe is not supported for torch compile"
|
||||
|
||||
topk_weights, topk_ids = fused_topk_native(
|
||||
hidden_states=x,
|
||||
gating_output=input_gating,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
)
|
||||
w13_weights = w1[topk_ids]
|
||||
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
||||
w2_weights = w2[topk_ids]
|
||||
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
|
||||
x1 = F.silu(x1)
|
||||
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
||||
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
||||
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
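# Shape walk-through of the naive MoE above (illustrative letters):
# t = tokens, a = top-k experts per token, i = hidden_size,
# o = shard_intermediate_size // 2.
# x: [t, i]; w1_weights/w3_weights: [t, a, o, i]; w2_weights: [t, a, i, o];
# x1, x3: [t, a, o]; expert_outs: [t, a, i]; final output: [t, i].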
|
||||
|
||||
|
||||
def fused_moe_torch_compile(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale=None,
|
||||
w2_scale=None,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
):
|
||||
return fused_moe_torch(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
|
||||
def fused_moe_sglang_api(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale=None,
|
||||
w2_scale=None,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
):
|
||||
return fused_moe_triton(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=list(range(1, 5)),
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"fused_moe_triton",
|
||||
"fused_moe_torch_compile",
|
||||
],
|
||||
line_names=[
|
||||
"fused_moe_triton",
|
||||
"fused_moe_torch_compile",
|
||||
],
|
||||
styles=[
|
||||
("blue", "-"),
|
||||
("green", "-"),
|
||||
],
|
||||
ylabel="Time (ms)",
|
||||
plot_name="fused-moe-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
|
||||
print(f"benchmark {provider} with batch_size={batch_size}")
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
set_torch_compile_config()
|
||||
|
||||
num_tokens = batch_size
|
||||
num_experts = model_config["num_experts"]
|
||||
hidden_size = model_config["hidden_size"]
|
||||
shard_intermediate_size = model_config["shard_intermediate_size"]
|
||||
topk = model_config["topk"]
|
||||
dtype = model_config["dtype"]
|
||||
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
|
||||
if use_fp8_w8a8:
|
||||
init_dtype = dtype
|
||||
w1 = torch.randn(
|
||||
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
|
||||
)
|
||||
w1 = w1.to(torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fn)
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||
a2_scale = torch.randn(1, dtype=torch.float32)
|
||||
else:
|
||||
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
|
||||
)
|
||||
w1_scale = w2_scale = a1_scale = a2_scale = None
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
# Warmup
|
||||
api_func = (
|
||||
fused_moe_torch_compile
|
||||
if provider == "fused_moe_torch_compile"
|
||||
else fused_moe_sglang_api
|
||||
)
|
||||
for _ in range(10):
|
||||
y = api_func(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: api_func(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)[0],
|
||||
quantiles=quantiles,
|
||||
)
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
parser.add_argument("--tp-size", type=int, default=2)
|
||||
parser.add_argument("--use-fp8-w8a8", action="store_true")
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/fused_moe_torch_compile/",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model_config = get_model_config(args.model, args.tp_size)
|
||||
benchmark.run(
|
||||
show_plots=True,
|
||||
print_data=True,
|
||||
save_path=args.save_path,
|
||||
model_config=model_config,
|
||||
use_fp8_w8a8=args.use_fp8_w8a8,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,349 @@
|
||||
# python3 benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import vllm
|
||||
from transformers import AutoConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
|
||||
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
destroy_distributed_environment,
|
||||
destroy_model_parallel,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe as fused_moe_sglang,
|
||||
)
|
||||
|
||||
|
||||
def get_model_config(model_name: str, tp_size: int):
|
||||
"""Get model configuration parameters"""
|
||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
E = config.ffn_config.moe_num_experts
|
||||
topk = config.ffn_config.moe_top_k
|
||||
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "JambaForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] in [
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"Glm4MoeForCausalLM",
|
||||
]:
|
||||
E = (
|
||||
config.n_routed_experts + 1
|
||||
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
|
||||
else config.n_routed_experts
|
||||
)
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "Llama4ForConditionalGeneration":
|
||||
E = config.text_config.num_local_experts
|
||||
topk = config.text_config.num_experts_per_tok
|
||||
intermediate_size = config.text_config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] in [
|
||||
"Grok1ForCausalLM",
|
||||
"Grok1ImgGen",
|
||||
"Grok1AForCausalLM",
|
||||
]:
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
else:
|
||||
# Default: Mixtral
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
|
||||
vllm_version_num = (
|
||||
vllm.__version_tuple__[0] * 100
|
||||
+ vllm.__version_tuple__[1] * 10
|
||||
+ vllm.__version_tuple__[2]
|
||||
)
|
||||
block_shape = None
|
||||
if (
|
||||
hasattr(config, "quantization_config")
|
||||
and "weight_block_size" in config.quantization_config
|
||||
):
|
||||
block_shape = config.quantization_config["weight_block_size"]
|
||||
assert len(block_shape) == 2
|
||||
assert (
|
||||
vllm_version_num >= 66
|
||||
), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
|
||||
|
||||
shape_configs = {
|
||||
"num_experts": E,
|
||||
"topk": topk,
|
||||
"hidden_size": config.hidden_size,
|
||||
"shard_intermediate_size": shard_intermediate_size,
|
||||
"dtype": config.torch_dtype,
|
||||
"block_shape": block_shape,
|
||||
}
|
||||
print(f"{shape_configs=}")
|
||||
return shape_configs
|
||||
|
||||
|
||||
def fused_moe_vllm_api(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale=None,
|
||||
w2_scale=None,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
block_shape=None,
|
||||
):
|
||||
if block_shape is not None:
|
||||
return fused_moe_vllm(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
else:
|
||||
return fused_moe_vllm(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
|
||||
def fused_moe_sglang_api(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale=None,
|
||||
w2_scale=None,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
block_shape=None,
|
||||
):
|
||||
return fused_moe_sglang(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=list(range(1, 513)),
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"vllm_fused_moe_triton",
|
||||
"sglang_fused_moe_triton",
|
||||
],
|
||||
line_names=[
|
||||
"vllm_fused_moe_triton",
|
||||
"sglang_fused_moe_triton",
|
||||
],
|
||||
styles=[
|
||||
("blue", "-"),
|
||||
("green", "-"),
|
||||
],
|
||||
ylabel="Time (ms)",
|
||||
plot_name="fused-moe-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
|
||||
print(f"benchmark {provider} with batch_size={batch_size}")
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
num_tokens = batch_size
|
||||
num_experts = model_config["num_experts"]
|
||||
hidden_size = model_config["hidden_size"]
|
||||
shard_intermediate_size = model_config["shard_intermediate_size"]
|
||||
topk = model_config["topk"]
|
||||
dtype = model_config["dtype"]
|
||||
block_shape = model_config["block_shape"]
|
||||
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
w1_scale = w2_scale = a1_scale = a2_scale = None
|
||||
|
||||
if use_fp8_w8a8:
|
||||
init_dtype = dtype
|
||||
w1 = torch.randn(
|
||||
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
|
||||
)
|
||||
w1 = w1.to(torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fn)
|
||||
|
||||
if block_shape is None:
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||
a2_scale = torch.randn(1, dtype=torch.float32)
|
||||
else:
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
|
||||
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
|
||||
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
|
||||
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
|
||||
w1_scale = torch.rand(
|
||||
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.rand(
|
||||
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
|
||||
)
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
# Warmup
|
||||
api_func = (
|
||||
fused_moe_vllm_api
|
||||
if provider == "vllm_fused_moe_triton"
|
||||
else fused_moe_sglang_api
|
||||
)
|
||||
for _ in range(10):
|
||||
y = api_func(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: api_func(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)[0],
|
||||
quantiles=quantiles,
|
||||
)
|
||||
return ms, min_ms, max_ms
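# Scale-shape sketch for the block-wise fp8 branch above (assumed sizes, for
# orientation only): mirrors the n_tiles/k_tiles arithmetic used to build
# w1_scale and w2_scale when block_shape is set.
def _example_blockwise_scale_shapes(
    num_experts=8, hidden_size=7168, shard_intermediate_size=4096, block_shape=(128, 128)
):
    block_n, block_k = block_shape
    w1_scale_shape = (
        num_experts,
        (shard_intermediate_size + block_n - 1) // block_n,  # n_tiles_w1
        (hidden_size + block_k - 1) // block_k,  # k_tiles_w1
    )
    w2_scale_shape = (
        num_experts,
        (hidden_size + block_n - 1) // block_n,  # n_tiles_w2
        (shard_intermediate_size // 2 + block_k - 1) // block_k,  # k_tiles_w2
    )
    return w1_scale_shape, w2_scale_shape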
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
parser.add_argument("--tp-size", type=int, default=2)
|
||||
parser.add_argument("--use-fp8-w8a8", action="store_true")
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/vllm_sglang_fused_moe/",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl" if torch.cuda.is_available() else "gloo",
|
||||
init_method="tcp://127.0.0.1:23456",
|
||||
world_size=1,
|
||||
rank=0,
|
||||
)
|
||||
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method="tcp://127.0.0.1:23456",
|
||||
local_rank=0,
|
||||
backend="nccl" if torch.cuda.is_available() else "gloo",
|
||||
)
|
||||
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size=1,
|
||||
pipeline_model_parallel_size=1,
|
||||
)
|
||||
|
||||
model_config = get_model_config(args.model, args.tp_size)
|
||||
benchmark.run(
|
||||
show_plots=True,
|
||||
print_data=True,
|
||||
save_path=args.save_path,
|
||||
model_config=model_config,
|
||||
use_fp8_w8a8=args.use_fp8_w8a8,
|
||||
)
|
||||
finally:
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
599
benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
Normal file
@@ -0,0 +1,599 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple, TypedDict
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import triton
|
||||
from ray.experimental.tqdm_ray import tqdm
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton import override_config
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe,
|
||||
get_config_dtype_str,
|
||||
get_config_file_name,
|
||||
get_default_config,
|
||||
get_moe_configs,
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
BLOCK_SIZE_M: int
|
||||
BLOCK_SIZE_N: int
|
||||
BLOCK_SIZE_K: int
|
||||
GROUP_SIZE_M: int
|
||||
num_warps: int
|
||||
num_stages: int
|
||||
|
||||
|
||||
def benchmark_config(
|
||||
config: BenchmarkConfig,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_shape: List[int] = None,
|
||||
num_iters: int = 100,
|
||||
) -> float:
|
||||
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
if use_int8_w8a16 or use_int8_w8a8:
|
||||
w1 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
),
|
||||
dtype=torch.int8,
|
||||
)
|
||||
w2 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size // 2,
|
||||
),
|
||||
dtype=torch.int8,
|
||||
)
|
||||
else:
|
||||
w1 = torch.randn(
|
||||
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
|
||||
)
|
||||
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
w1_scale = None
|
||||
w2_scale = None
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
if use_int8_w8a16:
|
||||
w1_scale = torch.randn(
|
||||
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
if use_int8_w8a8 and block_shape is None:
|
||||
w1_scale = torch.randn(
|
||||
num_experts, shard_intermediate_size, dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
|
||||
elif block_shape is None:
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||
a2_scale = torch.randn(1, dtype=torch.float32)
|
||||
else:
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
|
||||
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
|
||||
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
|
||||
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
|
||||
w1_scale = torch.rand(
|
||||
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.rand(
|
||||
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
|
||||
)
|
||||
|
||||
if use_fp8_w8a8:
|
||||
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
topk_config = TopKConfig(
|
||||
top_k=topk,
|
||||
renormalize=True,
|
||||
)
|
||||
topk_output = select_experts(x, input_gating, topk_config)
|
||||
|
||||
def prepare(i: int):
|
||||
input_gating = gating_output[i]
|
||||
new_topk_output = select_experts(x, input_gating, topk_config)
|
||||
topk_output.topk_weights.copy_(new_topk_output.topk_weights)
|
||||
topk_output.topk_ids.copy_(new_topk_output.topk_ids)
|
||||
topk_output.router_logits.copy_(new_topk_output.router_logits)
|
||||
|
||||
def run():
|
||||
moe_runner_config = MoeRunnerConfig(
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
with override_config(config):
|
||||
fused_moe(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture 10 invocations with CUDA graph
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
for _ in range(10):
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: List[float] = []
|
||||
for i in range(num_iters):
|
||||
prepare(i)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event.record()
|
||||
graph.replay()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
latencies.append(start_event.elapsed_time(end_event))
|
||||
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||
graph.reset()
|
||||
return avg
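# Timing sketch (illustrative helper, not called by the tuner): each CUDA-graph
# replay above contains 10 fused_moe launches and elapsed_time() reports
# milliseconds, so the per-launch average in microseconds is
# sum(latencies) / (num_iters * 10) * 1000, exactly as computed in benchmark_config.
def _example_latency_to_us(latencies_ms, num_iters, launches_per_replay=10):
    return sum(latencies_ms) / (num_iters * launches_per_replay) * 1000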
|
||||
|
||||
|
||||
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
|
||||
configs: List[BenchmarkConfig] = []
|
||||
waves_per_eu_range = 0
|
||||
for num_stages in [2]:
|
||||
for block_m in [32, 64, 128, 256]:
|
||||
for block_k in [32, 64, 128, 256]:
|
||||
for block_n in [16, 32, 64, 128, 256]:
|
||||
for num_warps in [1, 2, 4, 8]:
|
||||
for group_size in [1, 4, 8, 16, 32]:
|
||||
configs.append(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
"waves_per_eu": waves_per_eu_range,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def get_configs_compute_bound() -> List[Dict[str, int]]:
|
||||
# Reduced search space for faster tuning.
|
||||
# TODO(woosuk): Increase the search space and use a performance model to
|
||||
# prune the search space.
|
||||
configs: List[BenchmarkConfig] = []
|
||||
if _is_hip:
|
||||
configs = get_rocm_configs_compute_bound()
|
||||
else:
|
||||
for num_stages in [2, 3, 4, 5]:
|
||||
for block_m in [16, 32, 64, 128, 256]:
|
||||
for block_k in [64, 128, 256]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
for num_warps in [4, 8]:
|
||||
for group_size in [1, 16, 32, 64]:
|
||||
configs.append(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
}
|
||||
)
|
||||
return configs
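# Size sketch of the non-HIP search space defined above, for orientation:
# 4 num_stages x 5 block_m x 3 block_k x 4 block_n x 2 num_warps x 4 group_size
# = 1920 candidate configs before the optional block_shape divisibility filter
# applied later in main().
def _example_cuda_search_space_size():
    return 4 * 5 * 3 * 4 * 2 * 4  # 1920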
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class BenchmarkWorker:
|
||||
|
||||
def __init__(self, seed: int) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
self.seed = seed
|
||||
# Get the device ID to allocate tensors and kernels
|
||||
# on the respective GPU.
|
||||
self.device_id = int(ray.get_gpu_ids()[0])
|
||||
|
||||
def benchmark(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_shape: List[int],
|
||||
) -> Tuple[Dict[str, int], float]:
|
||||
torch.cuda.manual_seed_all(0)
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||
)
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
# is the intermediate size after silu_and_mul.
|
||||
block_n = block_shape[0] if block_shape else 0
|
||||
block_k = block_shape[1] if block_shape else 0
|
||||
op_config = get_moe_configs(
|
||||
num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
|
||||
)
|
||||
if op_config is None:
|
||||
config = get_default_config(
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype_str,
|
||||
False,
|
||||
block_shape,
|
||||
)
|
||||
else:
|
||||
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
|
||||
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
|
||||
kernel_time = benchmark_config(
|
||||
config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_shape,
|
||||
)
|
||||
return config, kernel_time
|
||||
|
||||
def tune(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_shape: List[int],
|
||||
search_space: List[Dict[str, int]],
|
||||
) -> Dict[str, int]:
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_shape,
|
||||
num_iters=10,
|
||||
)
|
||||
except (triton.runtime.autotuner.OutOfResources, RuntimeError):
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
|
||||
if kernel_time < best_time:
|
||||
best_time = kernel_time
|
||||
best_config = config
|
||||
now = datetime.now()
|
||||
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
|
||||
assert best_config is not None
|
||||
return best_config
|
||||
|
||||
|
||||
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||
return {
|
||||
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
|
||||
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
|
||||
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
|
||||
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
|
||||
"num_warps": config["num_warps"],
|
||||
"num_stages": config["num_stages"],
|
||||
**(
|
||||
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def save_configs(
|
||||
configs: Dict[int, BenchmarkConfig],
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_shape: List[int],
|
||||
) -> None:
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
# is the intermediate size after silu_and_mul.
|
||||
filename = get_config_file_name(
|
||||
num_experts,
|
||||
shard_intermediate_size // 2,
|
||||
dtype_str,
|
||||
block_shape,
|
||||
)
|
||||
|
||||
print(f"Writing best config to {filename}...")
|
||||
with open(filename, "w") as f:
|
||||
json.dump(configs, f, indent=4)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
E = config.ffn_config.moe_num_experts
|
||||
topk = config.ffn_config.moe_top_k
|
||||
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] == "JambaForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
||||
E = (
|
||||
config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
|
||||
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
|
||||
else config.n_routed_experts
|
||||
)
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] == "Llama4ForConditionalGeneration":
|
||||
E = config.text_config.num_local_experts + (
|
||||
0 if args.disable_shared_experts_fusion else 1
|
||||
)
|
||||
topk = config.text_config.num_experts_per_tok
|
||||
intermediate_size = config.text_config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in [
|
||||
"Grok1ForCausalLM",
|
||||
"Grok1ImgGen",
|
||||
"Grok1AForCausalLM",
|
||||
]:
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
# Default: Mixtral
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
|
||||
hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
|
||||
dtype = config.torch_dtype
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a8 = args.dtype == "int8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
block_shape = None
|
||||
if (
|
||||
hasattr(config, "quantization_config")
|
||||
and "weight_block_size" in config.quantization_config
|
||||
):
|
||||
block_shape = config.quantization_config["weight_block_size"]
|
||||
assert len(block_shape) == 2
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
24,
|
||||
32,
|
||||
48,
|
||||
64,
|
||||
96,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
1536,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
]
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
|
||||
ray.init()
|
||||
num_gpus = int(ray.available_resources()["GPU"])
|
||||
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
|
||||
|
||||
def _distribute(method: str, inputs: List[Any]) -> List[Any]:
|
||||
outputs = []
|
||||
worker_idx = 0
|
||||
for input_args in inputs:
|
||||
worker = workers[worker_idx]
|
||||
worker_method = getattr(worker, method)
|
||||
output = worker_method.remote(*input_args)
|
||||
outputs.append(output)
|
||||
worker_idx = (worker_idx + 1) % num_gpus
|
||||
return ray.get(outputs)
|
||||
|
||||
if args.tune:
|
||||
search_space = get_configs_compute_bound()
|
||||
if block_shape is not None:
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
search_space = [
|
||||
config
|
||||
for config in search_space
|
||||
if block_k % config["BLOCK_SIZE_K"] == 0
|
||||
]
|
||||
print(f"Start tuning over {len(search_space)} configurations...")
|
||||
|
||||
start = time.perf_counter()
|
||||
configs = _distribute(
|
||||
"tune",
|
||||
[
|
||||
(
|
||||
batch_size,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_shape,
|
||||
search_space,
|
||||
)
|
||||
for batch_size in batch_sizes
|
||||
],
|
||||
)
|
||||
best_configs = {
|
||||
M: sort_config(config) for M, config in zip(batch_sizes, configs)
|
||||
}
|
||||
save_configs(
|
||||
best_configs,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_shape,
|
||||
)
|
||||
end = time.perf_counter()
|
||||
print(f"Tuning took {end - start:.2f} seconds")
|
||||
else:
|
||||
outputs = _distribute(
|
||||
"benchmark",
|
||||
[
|
||||
(
|
||||
batch_size,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_shape,
|
||||
)
|
||||
for batch_size in batch_sizes
|
||||
],
|
||||
)
|
||||
|
||||
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
|
||||
print(f"Batch size: {batch_size}, config: {config}")
|
||||
print(f"Kernel time: {kernel_time:.2f} us")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
parser.add_argument("--tp-size", "--tp", type=int, default=2)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
|
||||
default="auto",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--tune", action="store_true")
|
||||
parser.add_argument("--disable-shared-experts-fusion", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@@ -0,0 +1,576 @@
|
||||
import itertools
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange
|
||||
from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _decode_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
KV,
|
||||
Out,
|
||||
S,
|
||||
b: tl.constexpr,
|
||||
h: tl.constexpr,
|
||||
n: tl.constexpr,
|
||||
d: tl.constexpr,
|
||||
d_original: tl.constexpr,
|
||||
e: tl.constexpr,
|
||||
e_original: tl.constexpr,
|
||||
):
|
||||
off_bh = tl.program_id(0)
|
||||
off_h = off_bh % h
|
||||
|
||||
qk_offset = off_bh * n * d
|
||||
v_offset = off_bh * n * e
|
||||
o_offset = off_bh * n * e
|
||||
kv_offset = off_bh * d * e
|
||||
|
||||
s = tl.load(S + off_h)
|
||||
ratio = tl.exp(-s)
|
||||
|
||||
d_idx = tl.arange(0, d)
|
||||
e_idx = tl.arange(0, e)
|
||||
|
||||
# Create masks for original dimensions
|
||||
d_mask = d_idx < d_original
|
||||
e_mask = e_idx < e_original
|
||||
|
||||
# Load with masking
|
||||
q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
|
||||
k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
|
||||
v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
|
||||
|
||||
# Load KV with 2D masking
|
||||
kv = tl.load(
|
||||
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
|
||||
mask=(d_mask[:, None] & e_mask[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
# Compute outer product using element-wise operations
|
||||
k_v_prod = k[:, None] * v[None, :]
|
||||
kv = ratio * kv + k_v_prod
|
||||
|
||||
# Store KV with 2D masking
|
||||
tl.store(
|
||||
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
|
||||
kv.to(KV.dtype.element_ty),
|
||||
mask=(d_mask[:, None] & e_mask[None, :]),
|
||||
)
|
||||
|
||||
# Compute matrix-vector multiplication using element-wise operations and reduction
|
||||
o = tl.sum(q[:, None] * kv, axis=0)
|
||||
|
||||
# Store output with masking
|
||||
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
|
||||
|
||||
|
||||
def lightning_attn_decode(q, k, v, kv, s):
|
||||
"""Triton implementation of Lightning Attention decode operation"""
|
||||
b, h, n, d = q.shape
|
||||
e = v.shape[-1]
|
||||
assert n == 1, "Sequence length must be 1 in decode mode"
|
||||
|
||||
# Get padded dimensions (power of 2)
|
||||
d_padded = next_power_of_2(d)
|
||||
e_padded = next_power_of_2(e)
|
||||
|
||||
# Create output tensor (padded)
|
||||
o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
|
||||
|
||||
# Create padded tensors without actually padding the data
|
||||
q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
|
||||
k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
|
||||
v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
|
||||
kv_padded = torch.empty(
|
||||
b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
|
||||
)
|
||||
|
||||
# Copy data to padded tensors
|
||||
q_padded[..., :d] = q
|
||||
k_padded[..., :d] = k
|
||||
v_padded[..., :e] = v
|
||||
kv_padded[..., :d, :e] = kv
|
||||
|
||||
# Launch kernel
|
||||
grid = (b * h, 1)
|
||||
_decode_kernel[grid](
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
kv_padded,
|
||||
o_padded,
|
||||
s,
|
||||
b=b,
|
||||
h=h,
|
||||
n=n,
|
||||
d=d_padded,
|
||||
d_original=d,
|
||||
e=e_padded,
|
||||
e_original=e,
|
||||
)
|
||||
|
||||
# Get unpadded outputs
|
||||
o = o_padded[..., :e]
|
||||
kv_out = kv_padded[..., :d, :e]
|
||||
|
||||
return o, kv_out
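# Reference sketch of the recurrence computed by the kernel above (plain
# PyTorch, illustrative only; not used by the benchmark): per (batch, head)
# with a single query token, kv_new = exp(-s) * kv + k^T v and o = q @ kv_new.
def _example_decode_reference(q, k, v, kv, s):
    # q, k: [b, h, 1, d]; v: [b, h, 1, e]; kv: [b, h, d, e]; s: [h]
    ratio = torch.exp(-s).view(1, -1, 1, 1)
    kv_new = ratio * kv + torch.einsum("bhnd,bhne->bhde", k, v)
    o = torch.einsum("bhnd,bhde->bhne", q, kv_new.to(q.dtype))
    return o, kv_new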
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
return 2 ** (int(math.ceil(math.log(n, 2))))
|
||||
|
||||
|
||||
class MiniMaxText01LightningAttention(nn.Module):
|
||||
def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = type("Config", (), kwargs)
|
||||
|
||||
bias = False
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
|
||||
|
||||
self.out_proj = nn.Linear(
|
||||
self.head_dim * self.num_heads, self.hidden_size, bias=bias
|
||||
)
|
||||
self.act = get_activation_fn(config.hidden_act)
|
||||
self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)
|
||||
|
||||
self.qkv_proj = nn.Linear(
|
||||
self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
|
||||
)
|
||||
self.output_gate = nn.Linear(
|
||||
self.hidden_size, self.head_dim * self.num_heads, bias=bias
|
||||
)
|
||||
|
||||
# for inference only
|
||||
self.offset = 0
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
|
||||
output_attentions: bool = False,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
slope_rate: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if (not self.training) and (not do_eval):
|
||||
return self.inference(
|
||||
hidden_states,
|
||||
attn_mask,
|
||||
output_attentions,
|
||||
past_key_value,
|
||||
use_cache,
|
||||
slope_rate,
|
||||
)
|
||||
|
||||
def inference(
|
||||
self,
|
||||
x,
|
||||
attn_mask: Optional[torch.Tensor] = None, # (b, n)
|
||||
output_attentions: bool = False,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
|
||||
):
|
||||
# x: b n d
|
||||
b, n, d = x.shape
|
||||
# linear map
|
||||
qkv = self.act(self.qkv_proj(x))
|
||||
new_shape = qkv.size()[:-1] + (self.num_heads, -1)
|
||||
qkv = qkv.view(*new_shape)
|
||||
q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
|
||||
q = q.transpose(1, 2) # [b, n, h, d] -> [b, h, n, d]
|
||||
k = k.transpose(1, 2) # [b, n, h, d] -> [b, h, n, d]
|
||||
v = v.transpose(1, 2)  # [b, n, h, e] -> [b, h, n, e]
|
||||
|
||||
self.offset += 1
|
||||
ratio = torch.exp(-slope_rate) # [h, 1, 1]
|
||||
|
||||
# decode mode
|
||||
kv = past_key_value # [b, h, d, e]
|
||||
output = []
|
||||
for i in range(n):
|
||||
# kv: [b, h, d, e]
|
||||
# ratio: [h, 1, 1]
|
||||
# k: [b, h, n, d]
|
||||
# v: [b, h, n, e]
|
||||
# k[:, :, i : i + 1]: [b, h, 1, d]
|
||||
# v[:, :, i : i + 1]: [b, h, 1, e]
|
||||
# ratio * kv: [b, h, d, e]
|
||||
# torch.einsum(
|
||||
# "... n d, ... n e -> ... d e",
|
||||
# k[:, :, i : i + 1],
|
||||
# v[:, :, i : i + 1],
|
||||
# )
|
||||
# [b, h, d, e] + [b, h, d, e] -> [b, h, d, e]
|
||||
kv = ratio * kv + torch.einsum(
|
||||
"... n d, ... n e -> ... d e",
|
||||
k[:, :, i : i + 1],
|
||||
v[:, :, i : i + 1],
|
||||
)
|
||||
# q[:, :, i : i + 1]: [b, h, 1, d]
|
||||
# kv.to(q.dtype): [b, h, d, e]
|
||||
# torch.einsum(
|
||||
# "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
|
||||
# )
|
||||
# [b, h, 1, d] * [b, h, d, e] -> [b, h, 1, e]
|
||||
qkv = torch.einsum(
|
||||
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
|
||||
)
|
||||
output.append(qkv)
|
||||
output = torch.cat(output, dim=-2)
|
||||
|
||||
# reshape
|
||||
output = rearrange(output, "b h n d -> b n (h d)")
|
||||
# normalize
|
||||
output = self.norm(output)
|
||||
# gate
|
||||
output = F.sigmoid(self.output_gate(x)) * output
|
||||
# outproj
|
||||
output = self.out_proj(output)
|
||||
|
||||
attn_weights = None
|
||||
|
||||
return output, attn_weights, kv
|
||||
|
||||
|
||||
def get_activation_fn(activation):
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
elif activation == "relu":
|
||||
return F.relu
|
||||
elif activation == "elu":
|
||||
return F.elu
|
||||
elif activation == "sigmoid":
|
||||
return F.sigmoid
|
||||
elif activation == "exp":
|
||||
|
||||
def f(x):
|
||||
with torch.no_grad():
|
||||
x_max = torch.max(x, dim=-1, keepdims=True).values
|
||||
y = torch.exp(x - x_max)
|
||||
return y
|
||||
|
||||
return f
|
||||
elif activation == "leak":
|
||||
return F.leaky_relu
|
||||
elif activation == "1+elu":
|
||||
|
||||
def f(x):
|
||||
return 1 + F.elu(x)
|
||||
|
||||
return f
|
||||
elif activation == "2+elu":
|
||||
|
||||
def f(x):
|
||||
return 2 + F.elu(x)
|
||||
|
||||
return f
|
||||
elif activation == "silu" or activation == "swish":
|
||||
return F.silu
|
||||
elif activation == "sine":
|
||||
return torch.sin
|
||||
else:
|
||||
return lambda x: x
|
||||
|
||||
|
||||
class MiniMaxText01RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
MiniMaxText01RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
def test_lightning_attention_implementations(model_params):
|
||||
torch.manual_seed(42)
|
||||
|
||||
batch_size = 64
|
||||
seq_len = 1
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
hidden_states = torch.randn(
|
||||
batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device
|
||||
)
|
||||
|
||||
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
|
||||
|
||||
slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)
|
||||
|
||||
model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
|
||||
model_attn.eval()
|
||||
|
||||
d = model_params["head_dim"]
|
||||
past_kv = torch.randn(
|
||||
batch_size,
|
||||
model_params["num_attention_heads"],
|
||||
d,
|
||||
d,
|
||||
device=device,
|
||||
)
|
||||
with torch.no_grad():
|
||||
model_output, _, new_kv = model_attn.inference(
|
||||
hidden_states,
|
||||
attn_mask=attention_mask,
|
||||
slope_rate=slope_rate,
|
||||
past_key_value=past_kv,
|
||||
)
|
||||
|
||||
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
|
||||
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
|
||||
qkv = qkv.view(*new_shape)
|
||||
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
past_kv = past_kv.contiguous()
|
||||
slope_rate = slope_rate.contiguous()
|
||||
|
||||
# Test Triton implementation
|
||||
triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
|
||||
triton_output = triton_output.transpose(1, 2).contiguous()
|
||||
triton_output = triton_output.view(batch_size, seq_len, -1)
|
||||
triton_output = model_attn.norm(triton_output)
|
||||
triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
|
||||
triton_output = model_attn.out_proj(triton_output)
|
||||
|
||||
# Test SGL implementation
|
||||
sgl_output = torch.empty_like(v)
|
||||
sgl_new_kv = torch.empty_like(past_kv)
|
||||
sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)
|
||||
|
||||
sgl_output = sgl_output.transpose(1, 2).contiguous()
|
||||
sgl_output = sgl_output.view(batch_size, seq_len, -1)
|
||||
sgl_output = model_attn.norm(sgl_output)
|
||||
sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
|
||||
sgl_output = model_attn.out_proj(sgl_output)
|
||||
|
||||
# Verify Triton implementation results
|
||||
torch.testing.assert_close(
|
||||
model_output,
|
||||
triton_output,
|
||||
rtol=1e-3,
|
||||
atol=1e-2,
|
||||
msg="Triton lightning attention implementation produces different output results",
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
new_kv,
|
||||
triton_new_kv,
|
||||
rtol=1e-3,
|
||||
atol=1e-2,
|
||||
msg="Triton lightning attention implementation produces different kv results",
|
||||
)
|
||||
|
||||
# Verify SGL implementation results
|
||||
torch.testing.assert_close(
|
||||
model_output,
|
||||
sgl_output,
|
||||
rtol=1e-3,
|
||||
atol=1e-2,
|
||||
msg="SGL lightning attention implementation produces different output results",
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
new_kv,
|
||||
sgl_new_kv,
|
||||
rtol=1e-3,
|
||||
atol=1e-2,
|
||||
msg="SGL lightning attention implementation produces different kv results",
|
||||
)
|
||||
|
||||
print("✅ All implementations match")
|
||||
|
||||
|
||||
def _build_slope_tensor(n_attention_heads: int):
|
||||
def get_slopes(n):
|
||||
def get_slopes_power_of_2(n):
|
||||
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
||||
ratio = start
|
||||
return [start * ratio**i for i in range(n)]
|
||||
|
||||
if math.log2(n).is_integer():
|
||||
return get_slopes_power_of_2(n)
|
||||
else:
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(n))
|
||||
return (
|
||||
get_slopes_power_of_2(closest_power_of_2)
|
||||
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
|
||||
)
|
||||
|
||||
slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
|
||||
n_attention_heads, 1, 1
|
||||
)
|
||||
return slopes
|
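# Worked example: for 8 heads, start = 2 ** -(2 ** -(log2(8) - 3)) = 0.5 and ratio = 0.5,
# so get_slopes(8) = [2**-1, 2**-2, ..., 2**-8] and the returned tensor has shape (8, 1, 1).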
||||
|
||||
|
||||
def get_benchmark():
|
||||
batch_size_range = [i for i in range(1, 33)] # max 32
|
||||
seq_length_range = [1] # decode mode sequence length is fixed to 1
|
||||
configs = list(itertools.product(batch_size_range, seq_length_range))
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["Original", "Triton", "SGL"],
|
||||
line_names=[
|
||||
"Original PyTorch Implementation",
|
||||
"Triton Implementation",
|
||||
"SGL Implementation",
|
||||
],
|
||||
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name="lightning-attention-decode-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, seq_len, provider):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
|
||||
params = {
|
||||
"hidden_size": 6144,
|
||||
"num_attention_heads": 64,
|
||||
"head_dim": 96,
|
||||
"hidden_act": "gelu",
|
||||
}
|
||||
|
||||
hidden_states = torch.randn(
|
||||
batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
|
||||
)
|
||||
|
||||
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
|
||||
|
||||
slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
|
||||
model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
|
||||
model_attn.eval()
|
||||
|
||||
d = params["head_dim"]
|
||||
past_kv = torch.randn(
|
||||
batch_size,
|
||||
params["num_attention_heads"],
|
||||
d,
|
||||
d,
|
||||
device=device,
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "Original":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: model_attn.inference(
|
||||
hidden_states,
|
||||
attn_mask=attention_mask,
|
||||
slope_rate=slope_rate,
|
||||
past_key_value=past_kv,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "Triton":
|
||||
|
||||
def run_triton():
|
||||
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
|
||||
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
|
||||
qkv = qkv.view(*new_shape)
|
||||
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
output, new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
|
||||
output = output.transpose(1, 2).contiguous()
|
||||
output = output.view(batch_size, seq_len, -1)
|
||||
output = model_attn.norm(output)
|
||||
output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
|
||||
return model_attn.out_proj(output)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
run_triton,
|
||||
quantiles=quantiles,
|
||||
)
|
||||
else: # SGL
|
||||
|
||||
def run_sgl():
|
||||
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
|
||||
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
|
||||
qkv = qkv.view(*new_shape)
|
||||
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
|
||||
q = q.transpose(1, 2).contiguous()
|
||||
k = k.transpose(1, 2).contiguous()
|
||||
v = v.transpose(1, 2).contiguous()
|
||||
|
||||
output = torch.empty_like(v)
|
||||
new_kv = torch.empty_like(past_kv)
|
||||
sgl_lightning_attention_decode(
|
||||
q, k, v, past_kv, slope_rate, output, new_kv
|
||||
)
|
||||
|
||||
output = output.transpose(1, 2).contiguous()
|
||||
output = output.view(batch_size, seq_len, -1)
|
||||
output = model_attn.norm(output)
|
||||
output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
|
||||
return model_attn.out_proj(output)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
run_sgl,
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/lightning_attention_decode/",
|
||||
help="Path to save lightning attention decode benchmark results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
params = {
|
||||
"hidden_size": 6144,
|
||||
"num_attention_heads": 64,
|
||||
"head_dim": 96,
|
||||
"hidden_act": "silu",
|
||||
}
|
||||
# Run correctness test first
|
||||
# Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
|
||||
test_lightning_attention_implementations(params)
|
||||
|
||||
# Run performance benchmark
|
||||
benchmark = get_benchmark()
|
||||
benchmark.run(print_data=True, save_path=args.save_path)
|
||||
@@ -0,0 +1,603 @@
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
Out,
|
||||
S, # log lambda
|
||||
b: tl.constexpr,
|
||||
h: tl.constexpr,
|
||||
n: tl.constexpr,
|
||||
d: tl.constexpr,
|
||||
e: tl.constexpr,
|
||||
BLOCK: tl.constexpr,
|
||||
NUM_BLOCK: tl.constexpr,
|
||||
BLOCK_MODEL: tl.constexpr,
|
||||
):
|
||||
##### get offset
|
||||
off_bh = tl.program_id(0)
|
||||
off_h = off_bh % h
|
||||
off_e = tl.program_id(1)
|
||||
qk_offset = off_bh * n * d
|
||||
v_offset = off_bh * n * e
|
||||
o_offset = off_bh * n * e
|
||||
# channel offset
|
||||
e_offset = off_e * BLOCK_MODEL
|
||||
|
||||
##### get block ptr
|
||||
Q_block_ptr = Q + qk_offset + tl.arange(0, d)[None, :]
|
||||
K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None]
|
||||
V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
|
||||
O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
|
||||
S_block_ptr = S + off_h
|
||||
|
||||
##### init diag decay(Lambda); q, k decay; kv
|
||||
s = tl.load(S_block_ptr)
|
||||
# q, k decay
|
||||
off_block = tl.arange(
|
||||
0, BLOCK
|
||||
)  # Not a bug: this is a bit different from Algorithm 1, but it is mathematically equivalent
|
||||
q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None])
|
||||
k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - off_block[None, :]))
|
||||
block_decay = tl.exp(-s.to(tl.float32) * BLOCK)
|
||||
# diag decay
|
||||
index = off_block[:, None] - off_block[None, :]
|
||||
s_index = s * index
|
||||
s_index = tl.where(index >= 0, -s_index, float("-inf"))
|
||||
diag_decay = tl.exp(s_index)
|
||||
kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32)
|
||||
|
||||
##### compute
|
||||
for i in range(NUM_BLOCK):
|
||||
# load
|
||||
q = tl.load(
|
||||
Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0
|
||||
).to(tl.float32)
|
||||
k_trans = tl.load(
|
||||
K_trans_block_ptr + off_block[None, :] * d,
|
||||
mask=off_block[None, :] < n,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
v = tl.load(
|
||||
V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0
|
||||
).to(tl.float32)
|
||||
|
||||
# compute
|
||||
qk = tl.dot(q, k_trans) * diag_decay
|
||||
o_intra = tl.dot(qk, v)
|
||||
o_inter = tl.dot(q, kv) * q_decay
|
||||
o = o_intra + o_inter
|
||||
|
||||
# save and update
|
||||
tl.store(
|
||||
O_block_ptr + off_block[:, None] * e,
|
||||
o.to(O_block_ptr.dtype.element_ty),
|
||||
mask=off_block[:, None] < n,
|
||||
)
|
||||
kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v)
|
||||
off_block += BLOCK
|
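# Decay bookkeeping for the loop above: inside a block, position i attends to j <= i
# with weight exp(-s * (i - j)) via diag_decay; contributions from earlier blocks live
# in kv and are scaled by exp(-s * i) on the query side (q_decay), while kv itself is
# updated as kv_new = exp(-s * BLOCK) * kv_old + sum_j exp(-s * (BLOCK - j)) * k_j^T v_j.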
||||
|
||||
|
||||
def lightning_attn2(q, k, v, s):
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
s = s.contiguous()
|
||||
|
||||
b, h, n, d = q.shape
|
||||
e = v.shape[-1]
|
||||
|
||||
# Pad d to next power of 2
|
||||
d_padded = next_power_of_2(d)
|
||||
if d_padded != d:
|
||||
q_padded = F.pad(q, (0, d_padded - d))
|
||||
k_padded = F.pad(k, (0, d_padded - d))
|
||||
else:
|
||||
q_padded = q
|
||||
k_padded = k
|
||||
|
||||
# Pad e to next power of 2
|
||||
e_padded = next_power_of_2(e)
|
||||
if e_padded != e:
|
||||
v_padded = F.pad(v, (0, e_padded - e))
|
||||
else:
|
||||
v_padded = v
|
||||
|
||||
o_padded = torch.empty((b, h, n, e_padded), dtype=q.dtype, device=q.device)
|
||||
|
||||
BLOCK = 64
|
||||
NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK)
|
||||
# parallel over channel
|
||||
BLOCK_MODEL = min(triton.next_power_of_2(e_padded), 32)
|
||||
grid = (b * h, triton.cdiv(e_padded, BLOCK_MODEL))
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
o_padded,
|
||||
s,
|
||||
b,
|
||||
h,
|
||||
n,
|
||||
d_padded,
|
||||
e_padded,
|
||||
BLOCK=BLOCK,
|
||||
NUM_BLOCK=NUM_BLOCK,
|
||||
BLOCK_MODEL=BLOCK_MODEL,
|
||||
)
|
||||
|
||||
# Remove padding from output
|
||||
if e_padded != e:
|
||||
o = o_padded[..., :e]
|
||||
else:
|
||||
o = o_padded
|
||||
|
||||
return o
|
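# Usage sketch (assumes a CUDA device; shapes follow the kernel contract above):
#   q, k: [b, h, n, d], v: [b, h, n, e], s: per-head slopes, e.g. from _build_slope_tensor(h)
#   q = torch.randn(2, 8, 1024, 96, device="cuda", dtype=torch.bfloat16)
#   k, v = torch.randn_like(q), torch.randn_like(q)
#   o = lightning_attn2(q, k, v, _build_slope_tensor(8).to(q.device))  # o: [2, 8, 1024, 96]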
||||
|
||||
|
||||
def is_support(dim):
|
||||
return 16 % dim
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
return 2 ** (int(math.ceil(math.log(n, 2))))
|
||||
|
||||
|
||||
def lightning_attn_func(q, k, v, s):
|
||||
b, h, n, d = q.shape
|
||||
e = v.shape[-1]
|
||||
assert is_support(d) and is_support(e)
|
||||
|
||||
# pad v's feature dim to power of 2
|
||||
e_pad = next_power_of_2(e)
|
||||
need_pad = e_pad != e
|
||||
if need_pad:
|
||||
v = F.pad(v, (0, e_pad - e))
|
||||
|
||||
if d > 128:
|
||||
# split over head
|
||||
if 64 % d:
|
||||
m = 64
|
||||
elif 32 % d:
|
||||
m = 32
|
||||
elif 16 % d:
|
||||
m = 16
|
||||
arr = [m * i for i in range(d // m + 1)]
|
||||
if arr[-1] != d:
|
||||
arr.append(d)
|
||||
n = len(arr)
|
||||
o = 0
|
||||
for i in range(n - 1):
|
||||
start = arr[i]
|
||||
end = arr[i + 1]
|
||||
q1 = q[..., start:end]
|
||||
k1 = k[..., start:end]
|
||||
o += lightning_attn2(q1, k1, v, s)
|
||||
else:
|
||||
o = lightning_attn2(q, k, v, s)
|
||||
|
||||
if need_pad:
|
||||
o = o[:, :, :, :e]
|
||||
|
||||
return o
|
||||
|
||||
|
||||
import logging

logger = logging.getLogger(__name__)

# flags read from the environment; `do_eval` and `logger` are referenced by forward()
# and get_activation_fn() below
do_eval = eval(os.environ.get("do_eval", default="False"))
debug = eval(os.environ.get("debug", default="False"))
|
||||
|
||||
BLOCK = 256
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01
|
||||
class MiniMaxText01RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
MiniMaxText01RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
|
||||
def get_activation_fn(activation):
|
||||
if debug:
|
||||
logger.info(f"activation: {activation}")
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
elif activation == "relu":
|
||||
return F.relu
|
||||
elif activation == "elu":
|
||||
return F.elu
|
||||
elif activation == "sigmoid":
|
||||
return F.sigmoid
|
||||
elif activation == "exp":
|
||||
|
||||
def f(x):
|
||||
with torch.no_grad():
|
||||
x_max = torch.max(x, dim=-1, keepdims=True).values
|
||||
y = torch.exp(x - x_max)
|
||||
|
||||
return y
|
||||
|
||||
return f
|
||||
elif activation == "leak":
|
||||
return F.leaky_relu
|
||||
elif activation == "1+elu":
|
||||
|
||||
def f(x):
|
||||
return 1 + F.elu(x)
|
||||
|
||||
return f
|
||||
elif activation == "2+elu":
|
||||
|
||||
def f(x):
|
||||
return 2 + F.elu(x)
|
||||
|
||||
return f
|
||||
elif activation == "silu" or activation == "swish":
|
||||
return F.silu
|
||||
elif activation == "sine":
|
||||
return torch.sin
|
||||
else:
|
||||
logger.info(f"activation: does not support {activation}, use Identity!!!")
|
||||
return lambda x: x
|
||||
|
||||
|
||||
# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
|
||||
class MiniMaxText01LightningAttention(nn.Module):
|
||||
def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = type("Config", (), kwargs)
|
||||
|
||||
bias = False
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
|
||||
|
||||
self.out_proj = nn.Linear(
|
||||
self.head_dim * self.num_heads, self.hidden_size, bias=bias
|
||||
)
|
||||
self.act = get_activation_fn(config.hidden_act)
|
||||
self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)
|
||||
|
||||
self.qkv_proj = nn.Linear(
|
||||
self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
|
||||
)
|
||||
self.output_gate = nn.Linear(
|
||||
self.hidden_size, self.head_dim * self.num_heads, bias=bias
|
||||
)
|
||||
|
||||
# for inference only
|
||||
self.offset = 0
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
|
||||
output_attentions: bool = False,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
slope_rate: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if (not self.training) and (not do_eval):
|
||||
return self.inference(
|
||||
hidden_states,
|
||||
attn_mask,
|
||||
output_attentions,
|
||||
past_key_value,
|
||||
use_cache,
|
||||
slope_rate,
|
||||
)
|
||||
|
||||
def inference(
|
||||
self,
|
||||
x,
|
||||
attn_mask: Optional[torch.Tensor] = None, # (b, n)
|
||||
output_attentions: bool = False,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
|
||||
):
|
||||
# x: b n d
|
||||
b, n, d = x.shape
|
||||
# linear map
|
||||
qkv = self.act(self.qkv_proj(x))
|
||||
new_shape = qkv.size()[:-1] + (self.num_heads, -1)
|
||||
qkv = qkv.view(*new_shape)
|
||||
q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
if past_key_value is None:
|
||||
self.offset = q.shape[-2]
|
||||
else:
|
||||
self.offset += 1
|
||||
|
||||
# to align with metaseq
|
||||
ratio = torch.exp(-slope_rate)
|
||||
|
||||
# only use for the first time
|
||||
if past_key_value is None:
|
||||
slope_rate = slope_rate.to(torch.float32)
|
||||
if attn_mask is not None:
|
||||
v = v.masked_fill(
|
||||
(1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0
|
||||
)
|
||||
NUM_BLOCK = (n + BLOCK - 1) // BLOCK
|
||||
b, h, n, d = q.shape
|
||||
e = v.shape[-1]
|
||||
# other
|
||||
array = torch.arange(BLOCK).to(q) + 1
|
||||
q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
|
||||
k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
|
||||
index = array[:, None] - array[None, :]
|
||||
s_index = (
|
||||
slope_rate
|
||||
* index[
|
||||
None,
|
||||
None,
|
||||
]
|
||||
)
|
||||
s_index = torch.where(index >= 0, -s_index, float("-inf"))
|
||||
diag_decay = torch.exp(s_index)
|
||||
|
||||
kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
|
||||
output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
|
||||
for i in range(NUM_BLOCK):
|
||||
si = i * BLOCK
|
||||
ei = min(si + BLOCK, n)
|
||||
m = ei - si
|
||||
qi = q[:, :, si:ei].contiguous()
|
||||
ki = k[:, :, si:ei].contiguous()
|
||||
vi = v[:, :, si:ei].contiguous()
|
||||
qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32)
|
||||
|
||||
# diag
|
||||
qk = (
|
||||
torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32)
|
||||
* diag_decay[:, :, :m, :m]
|
||||
)
|
||||
qkv_diag = torch.matmul(qk, vi.to(torch.float32))
|
||||
block_decay = torch.exp(-slope_rate * m)
|
||||
output[:, :, si:ei] = qkv_none_diag + qkv_diag
|
||||
kv = block_decay * kv + torch.matmul(
|
||||
(ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi
|
||||
)
|
||||
|
||||
else:
|
||||
kv = past_key_value
|
||||
output = []
|
||||
for i in range(n):
|
||||
kv = ratio * kv + torch.einsum(
|
||||
"... n d, ... n e -> ... d e",
|
||||
k[:, :, i : i + 1],
|
||||
v[:, :, i : i + 1],
|
||||
)
|
||||
qkv = torch.einsum(
|
||||
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
|
||||
)
|
||||
output.append(qkv)
|
||||
output = torch.cat(output, dim=-2)
|
||||
# reshape
|
||||
output = rearrange(output, "b h n d -> b n (h d)")
|
||||
# normalize
|
||||
output = self.norm(output)
|
||||
# gate
|
||||
output = F.sigmoid(self.output_gate(x)) * output
|
||||
# outproj
|
||||
output = self.out_proj(output)
|
||||
|
||||
attn_weights = None
|
||||
|
||||
return output, attn_weights, kv
|
||||
|
||||
|
||||
def _build_slope_tensor(n_attention_heads: int):
|
||||
def get_slopes(n):
|
||||
def get_slopes_power_of_2(n):
|
||||
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
||||
ratio = start
|
||||
return [start * ratio**i for i in range(n)]
|
||||
|
||||
if math.log2(n).is_integer():
|
||||
return get_slopes_power_of_2(
|
||||
n
|
||||
) # In the paper, we only train models that have 2^a heads for some a. This function has
|
||||
else: # some good properties that only occur when the input is a power of 2. To maintain that even
|
||||
closest_power_of_2 = 2 ** math.floor(
|
||||
math.log2(n)
|
||||
) # when the number of heads is not a power of 2, we use this workaround.
|
||||
return (
|
||||
get_slopes_power_of_2(closest_power_of_2)
|
||||
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
|
||||
)
|
||||
|
||||
# h, 1, 1
|
||||
slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
|
||||
n_attention_heads, 1, 1
|
||||
)
|
||||
|
||||
return slopes
|
||||
|
||||
|
||||
def test_lightning_attention_implementations(model_params):
|
||||
torch.manual_seed(42)
|
||||
|
||||
batch_size = 2
|
||||
seq_len = 1024
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
hidden_states = torch.randn(
|
||||
batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device
|
||||
)
|
||||
|
||||
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
|
||||
|
||||
slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)
|
||||
|
||||
model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
|
||||
model_attn.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
model_output, _, _ = model_attn.inference(
|
||||
hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
|
||||
)
|
||||
|
||||
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
|
||||
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
|
||||
qkv = qkv.view(*new_shape)
|
||||
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
lib_output = lightning_attn_func(q, k, v, slope_rate)
|
||||
lib_output = lib_output.transpose(1, 2).contiguous()
|
||||
lib_output = lib_output.view(batch_size, seq_len, -1)
|
||||
lib_output = model_attn.norm(lib_output)
|
||||
lib_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
|
||||
lib_output = model_attn.out_proj(lib_output)
|
||||
|
||||
torch.testing.assert_close(
|
||||
model_output,
|
||||
lib_output,
|
||||
rtol=1e-3,
|
||||
atol=1e-2,
|
||||
msg="Lightning attention implementations produce different results",
|
||||
)
|
||||
|
||||
print("✅ Two implementations match")
|
||||
|
||||
|
||||
def get_benchmark():
|
||||
batch_size_range = [2**i for i in range(0, 7)] # max 64
|
||||
seq_length_range = [256, 512, 1024, 2048, 4096] # max 4096
|
||||
configs = list(itertools.product(batch_size_range, seq_length_range))
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["MiniMax-Text-01", "OpenNLPLab"],
|
||||
line_names=[
|
||||
"MiniMax-Text-01 Model Implementation",
|
||||
"OpenNLPLab Library Implementation",
|
||||
],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="lightning-attention-prefill-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, seq_len, provider):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
|
||||
params = {
|
||||
"hidden_size": 6144,
|
||||
"num_attention_heads": 64,
|
||||
"head_dim": 96,
|
||||
"hidden_act": "gelu",
|
||||
}
|
||||
|
||||
hidden_states = torch.randn(
|
||||
batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
|
||||
)
|
||||
|
||||
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
|
||||
|
||||
slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
|
||||
model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
|
||||
model_attn.eval()
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "MiniMax-Text-01":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: model_attn.inference(
|
||||
hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
else:
|
||||
|
||||
def run_lib():
|
||||
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
|
||||
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
|
||||
qkv = qkv.view(*new_shape)
|
||||
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
lib_output = lightning_attn_func(q, k, v, slope_rate)
|
||||
lib_output = lib_output.transpose(1, 2).contiguous()
|
||||
lib_output = lib_output.view(batch_size, seq_len, -1)
|
||||
lib_output = model_attn.norm(lib_output)
|
||||
lib_output = (
|
||||
torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
|
||||
)
|
||||
return model_attn.out_proj(lib_output)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
run_lib,
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/lightning_attention_prefill/",
|
||||
help="Path to save lightning attention prefill benchmark results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run correctness test first
|
||||
# Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
|
||||
params = {
|
||||
"hidden_size": 6144,
|
||||
"num_attention_heads": 64,
|
||||
"head_dim": 96,
|
||||
"hidden_act": "silu",
|
||||
}
|
||||
test_lightning_attention_implementations(params)
|
||||
|
||||
# Run performance benchmark
|
||||
benchmark = get_benchmark()
|
||||
benchmark.run(print_data=True, save_path=args.save_path)
|
||||
133
benchmark/kernels/quantization/bench_fp4_quant.py
Normal file
133
benchmark/kernels/quantization/bench_fp4_quant.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import argparse
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
|
||||
from sgl_kernel.elementwise import silu_and_mul
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
|
||||
|
||||
def _test_accuracy_once(E, M, K, input_dtype, device):
|
||||
x = torch.randn(E, M, K, device=device, dtype=input_dtype)
|
||||
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
|
||||
masks = torch.full((E,), M, dtype=torch.int32, device=device)
|
||||
out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
|
||||
out1, blk_scales1 = scaled_fp4_grouped_quant(
|
||||
silu_and_mul(x),
|
||||
glb_scales,
|
||||
masks,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, out1)
|
||||
torch.testing.assert_close(blk_scales, blk_scales1)
|
||||
print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK")
|
||||
|
||||
|
||||
NUM_RANKS = 48
|
||||
M_PER_RANKs = [128, 256, 512, 1024]
|
||||
Ms = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs]
|
||||
Ks = [2048, 4096, 7168]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["M", "K"],
|
||||
x_vals=list(itertools.product(Ms, Ks)),
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
|
||||
line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("green", "-")],
|
||||
ylabel="ms",
|
||||
plot_name="fp4 quant",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(M, K, provider):
|
||||
E = 6
|
||||
device = "cuda"
|
||||
x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16)
|
||||
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
|
||||
masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device)
|
||||
fp8_out = torch.empty(
|
||||
(
|
||||
x.shape[0],
|
||||
x.shape[1],
|
||||
x.shape[2] // 2,
|
||||
),
|
||||
device=x.device,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
)
|
||||
scale_block_size = 128
|
||||
fp8_scales = torch.empty(
|
||||
(
|
||||
x.shape[0],
|
||||
x.shape[1],
|
||||
x.shape[2] // 2 // scale_block_size,
|
||||
),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "triton_fp8":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: silu_and_mul_masked_post_quant_fwd(
|
||||
x,
|
||||
fp8_out,
|
||||
fp8_scales,
|
||||
scale_block_size,
|
||||
masks,
|
||||
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "cuda_unfused_fp4":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: scaled_fp4_grouped_quant(
|
||||
silu_and_mul(x),
|
||||
glb_scales,
|
||||
masks,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "cuda_fused_fp4":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: silu_and_mul_scaled_fp4_grouped_quant(
|
||||
x,
|
||||
glb_scales,
|
||||
masks,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def test_accuracy():
|
||||
E = 6
|
||||
N_RANKS = 48
|
||||
Ms = [128, 256, 512, 1024]
|
||||
Ks = [2048, 4096, 7168]
|
||||
input_dtype = torch.bfloat16
|
||||
for M in Ms:
|
||||
for K in Ks:
|
||||
_test_accuracy_once(E, N_RANKS * M, K, input_dtype, "cuda")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./bench_fp4_quant_res",
|
||||
help="Path to save fp4 quant benchmark results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
test_accuracy()
|
||||
|
||||
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
|
||||
94
benchmark/kernels/quantization/bench_int8_quant.py
Normal file
94
benchmark/kernels/quantization/bench_int8_quant.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
|
||||
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
|
||||
|
||||
@torch.compile(backend="inductor")
|
||||
def torch_int8_quant(x):
|
||||
int8_max = torch.iinfo(torch.int8).max
|
||||
|
||||
abs_max = x.abs().max(dim=-1, keepdim=True).values
|
||||
scales = abs_max.to(torch.float32) / float(int8_max)
|
||||
|
||||
q_x = (x / scales).round().to(torch.int8)
|
||||
|
||||
return q_x, scales
|
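# Worked example: for a row [-254.0, 127.0, 63.5], abs_max = 254.0, so the per-token
# scale is 254 / 127 = 2.0 and the quantized row is [-127, 64, 32]; scales has shape [M, 1].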
||||
|
||||
|
||||
def _test_accuracy_once(M, K, input_dtype, device):
|
||||
x = torch.randn(M, K, dtype=input_dtype, device=device) * 5000
|
||||
out, scales, _ = vllm_scaled_int8_quant(x, symmetric=True)
|
||||
out1, scales1 = per_token_quant_int8(x)
|
||||
out2, scales2 = torch_int8_quant(x)
|
||||
torch.testing.assert_close(out, out2, atol=1, rtol=0)
|
||||
torch.testing.assert_close(out, out1, atol=1, rtol=0)
|
||||
torch.testing.assert_close(scales, scales2)
|
||||
torch.testing.assert_close(scales1, scales2)
|
||||
print(f"M: {M}, K: {K}, type: {input_dtype} OK")
|
||||
|
||||
|
||||
def test_accuracy():
|
||||
Ms = [1, 13, 128, 1024, 2048, 4096]
|
||||
Ks = [512, 1024, 2048, 8192]
|
||||
input_dtypes = [torch.float16, torch.bfloat16]
|
||||
for M in Ms:
|
||||
for K in Ks:
|
||||
for input_dtype in input_dtypes:
|
||||
_test_accuracy_once(M, K, input_dtype, "cuda")
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm op", "triton", "torch.compile"],
|
||||
line_names=["vllm op", "triton", "torch.compile"],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("red", "-")],
|
||||
ylabel="ms",
|
||||
plot_name="int8 per token quant",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider):
|
||||
M, K = batch_size, 16384
|
||||
x = torch.randn(M, K, dtype=torch.float16, device="cuda") * 1000
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "vllm op":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: vllm_scaled_int8_quant(x, symmetric=True),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: per_token_quant_int8(x),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "torch.compile":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: torch_int8_quant(x),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./bench_int8_quant_res",
|
||||
help="Path to save int8 quant benchmark results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
test_accuracy()
|
||||
|
||||
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
|
||||
474
benchmark/kernels/quantization/tuning_block_wise_kernel.py
Normal file
474
benchmark/kernels/quantization/tuning_block_wise_kernel.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# Copyright 2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from tqdm import tqdm
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
_w8a8_block_fp8_matmul,
|
||||
_w8a8_block_fp8_matmul_unrolledx4,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
|
||||
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
DTYPE_MAP = {
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
def w8a8_block_matmul(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: List[int],
|
||||
config: Dict[str, Any],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with block-wise quantization.
|
||||
|
||||
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||
The output is returned in the specified `output_dtype`.
|
||||
|
||||
Args:
|
||||
A: The input tensor, e.g., activation.
|
||||
B: The input tensor, e.g., weight.
|
||||
As: The per-token-group quantization scale for `A`.
|
||||
Bs: The per-block quantization scale for `B`.
|
||||
block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
|
||||
output_dtype: The dtype of the returned tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The result of matmul.
|
||||
"""
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||
M = A.numel() // A.shape[-1]
|
||||
|
||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||
N, K = B.shape
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
C_shape = A.shape[:-1] + (N,)
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
|
||||
# Empirical testing shows the sweet spot lies when it's less than the # of
|
||||
# compute units available on the device.
|
||||
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
|
||||
N, config["BLOCK_SIZE_N"]
|
||||
)
|
||||
|
||||
if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
|
||||
kernel = (
|
||||
_w8a8_block_fp8_matmul_unrolledx4
|
||||
if (_is_hip and num_workgroups <= get_device_core_count())
|
||||
else _w8a8_block_fp8_matmul
|
||||
)
|
||||
else:
|
||||
kernel = _w8a8_block_int8_matmul
|
||||
|
||||
kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
As,
|
||||
Bs,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_n,
|
||||
block_k,
|
||||
A.stride(-2),
|
||||
A.stride(-1),
|
||||
B.stride(1),
|
||||
B.stride(0),
|
||||
C.stride(-2),
|
||||
C.stride(-1),
|
||||
As.stride(-2),
|
||||
As.stride(-1),
|
||||
Bs.stride(1),
|
||||
Bs.stride(0),
|
||||
**config,
|
||||
)
|
||||
|
||||
return C
|
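# Shape sketch for the call above (illustrative sizes): with M=4, K=256, N=512 and
# block_size=[128, 128], A is [4, 256] with per-token-group scales As of shape [4, 2],
# B is [512, 256] with per-block scales Bs of shape [4, 2] (cdiv(N, 128) x cdiv(K, 128)),
# and C comes back as [4, 512] in output_dtype. `config` carries the Triton launch
# parameters, e.g. {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128,
# "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}.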
||||
|
||||
|
||||
def get_rocm_configs_compute_bound():
|
||||
configs = []
|
||||
waves_per_eu_range = 0
|
||||
for num_stages in [2]:
|
||||
for block_m in [32, 64, 128, 256]:
|
||||
for block_k in [32, 64, 128, 256]:
|
||||
for block_n in [16, 32, 64, 128, 256]:
|
||||
for num_warps in [4, 8]:
|
||||
for group_size in [1, 4, 8, 16, 32]:
|
||||
configs.append(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
"waves_per_eu": waves_per_eu_range,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def get_configs_compute_bound():
|
||||
configs = []
|
||||
if _is_hip:
|
||||
configs = get_rocm_configs_compute_bound()
|
||||
else:
|
||||
for num_stages in [2, 3, 4, 5]:
|
||||
for block_m in [16, 32, 64, 128, 256]:
|
||||
for block_k in [64, 128]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
for num_warps in [4, 8]:
|
||||
for group_size in [1, 16, 32, 64]:
|
||||
configs.append(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def get_weight_shapes(tp_size):
|
||||
# NOTE(HandH1998): The weight shapes only work for DeepSeek-V3. Modify them if you tune a different model.
|
||||
# cannot TP
|
||||
total = [
|
||||
(512 + 64, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(7168, 16384),
|
||||
(7168, 18432),
|
||||
]
|
||||
# N can TP
|
||||
n_tp = [
|
||||
(18432 * 2, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(24576, 1536),
|
||||
(4096, 7168),
|
||||
]
|
||||
# K can TP
|
||||
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||
|
||||
weight_shapes = []
|
||||
for t in total:
|
||||
weight_shapes.append(t)
|
||||
for n_t in n_tp:
|
||||
new_t = (n_t[0] // tp_size, n_t[1])
|
||||
weight_shapes.append(new_t)
|
||||
for k_t in k_tp:
|
||||
new_t = (k_t[0], k_t[1] // tp_size)
|
||||
weight_shapes.append(new_t)
|
||||
return weight_shapes
|
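# Example: with tp_size=8 the N-parallel shape (18432 * 2, 7168) becomes (4608, 7168)
# and the K-parallel shape (7168, 16384) becomes (7168, 2048); the "cannot TP" shapes
# are returned unchanged.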
||||
|
||||
|
||||
def benchmark_config(
|
||||
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
|
||||
):
|
||||
def run():
|
||||
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# JIT compilation & warmup
|
||||
for _ in range(5):
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: List[float] = []
|
||||
for i in range(num_iters):
|
||||
torch.cuda.synchronize()
|
||||
start_event.record()
|
||||
run()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
latencies.append(start_event.elapsed_time(end_event))
|
||||
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||
return avg
|
||||
|
||||
|
||||
def tune(M, N, K, block_size, out_dtype, search_space, input_type):
|
||||
factor_for_scale = 1e-2
|
||||
|
||||
if input_type == "fp8":
|
||||
fp8_info = torch.finfo(
|
||||
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
A_fp32 = (
|
||||
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||
)
|
||||
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
||||
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
B_fp32 = (
|
||||
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||
)
|
||||
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
||||
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
)
|
||||
else:
|
||||
int8_info = torch.iinfo(torch.int8)
|
||||
int8_max, int8_min = int8_info.max, int8_info.min
|
||||
|
||||
A_fp32 = (
|
||||
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
|
||||
)
|
||||
A = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
B_fp32 = (
|
||||
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
|
||||
)
|
||||
B = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
|
||||
Bs = (
|
||||
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
|
||||
* factor_for_scale
|
||||
)
|
||||
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
A,
|
||||
B,
|
||||
As,
|
||||
Bs,
|
||||
block_size,
|
||||
config,
|
||||
out_dtype,
|
||||
num_iters=10,
|
||||
)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
|
||||
if kernel_time < best_time:
|
||||
best_time = kernel_time
|
||||
best_config = config
|
||||
now = datetime.now()
|
||||
print(f"{now.ctime()}] Completed tuning for batch_size={M}")
|
||||
assert best_config is not None
|
||||
return best_config
|
||||
|
||||
|
||||
def save_configs(
|
||||
N,
|
||||
K,
|
||||
block_n,
|
||||
block_k,
|
||||
configs,
|
||||
save_path,
|
||||
input_type="fp8",
|
||||
) -> None:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
device_name = get_device_name().replace(" ", "_")
|
||||
json_file_name = f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,block_shape=[{block_n}, {block_k}].json"
|
||||
|
||||
config_file_path = os.path.join(save_path, json_file_name)
|
||||
print(f"Writing best config to {config_file_path}...")
|
||||
|
||||
with open(config_file_path, "w") as f:
|
||||
json.dump(configs, f, indent=4)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def get_available_gpu_count():
|
||||
"""Get the number of available GPUs."""
|
||||
return torch.cuda.device_count()
|
||||
|
||||
|
||||
def tune_on_gpu(args_dict):
|
||||
"""Run tuning on a specific GPU."""
|
||||
gpu_id = args_dict["gpu_id"]
|
||||
batch_sizes = args_dict["batch_sizes"]
|
||||
weight_shapes = args_dict["weight_shapes"]
|
||||
args = args_dict["args"]
|
||||
|
||||
torch.cuda.set_device(gpu_id)
|
||||
print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")
|
||||
|
||||
block_n = args.block_n
|
||||
block_k = args.block_k
|
||||
out_dtype = DTYPE_MAP[args.out_dtype]
|
||||
save_path = args.save_path
|
||||
input_type = args.input_type
|
||||
|
||||
search_space = get_configs_compute_bound()
|
||||
search_space = [
|
||||
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
|
||||
]
|
||||
|
||||
start = time.perf_counter()
|
||||
results = {}
|
||||
for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"):
|
||||
N, K = shape[0], shape[1]
|
||||
print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
|
||||
benchmark_results = [
|
||||
tune(
|
||||
batch_size,
|
||||
N,
|
||||
K,
|
||||
[block_n, block_k],
|
||||
out_dtype,
|
||||
search_space,
|
||||
input_type,
|
||||
)
|
||||
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
|
||||
]
|
||||
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
|
||||
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
|
||||
|
||||
end = time.perf_counter()
|
||||
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
|
||||
|
||||
|
||||
def distribute_batch_sizes(batch_sizes, num_gpus):
|
||||
"""Distribute batch sizes across available GPUs."""
|
||||
batches_per_gpu = []
|
||||
for i in range(num_gpus):
|
||||
start_idx = i * len(batch_sizes) // num_gpus
|
||||
end_idx = (i + 1) * len(batch_sizes) // num_gpus
|
||||
batches_per_gpu.append(batch_sizes[start_idx:end_idx])
|
||||
return batches_per_gpu
|
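# Example: 18 batch sizes over 4 GPUs are split into chunks of length [4, 5, 4, 5],
# preserving order and covering every batch size exactly once.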
||||
|
||||
|
||||
def main(args):
|
||||
print(args)
|
||||
|
||||
num_gpus = get_available_gpu_count()
|
||||
if num_gpus == 0:
|
||||
raise RuntimeError("No GPU available for tuning")
|
||||
print(f"Found {num_gpus} GPUs for parallel tuning")
|
||||
|
||||
torch.cuda.init()
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
24,
|
||||
32,
|
||||
48,
|
||||
64,
|
||||
96,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
1536,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
]
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
num_gpus = 1 # If only one batch size, use only one GPU
|
||||
|
||||
weight_shapes = get_weight_shapes(args.tp_size)
|
||||
|
||||
batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)
|
||||
|
||||
process_args = []
|
||||
for gpu_id in range(num_gpus):
|
||||
process_args.append(
|
||||
{
|
||||
"gpu_id": gpu_id,
|
||||
"batch_sizes": batches_per_gpu[gpu_id],
|
||||
"weight_shapes": weight_shapes, # Each GPU processes all weight shapes
|
||||
"args": args,
|
||||
}
|
||||
)
|
||||
|
||||
ctx = mp.get_context("spawn")
|
||||
with ctx.Pool(num_gpus) as pool:
|
||||
pool.map(tune_on_gpu, process_args)
|
||||
|
||||
print("Multi-GPU tuning completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--tp-size", "-tp", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--input-type", type=str, choices=["fp8", "int8"], default="fp8"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-dtype",
|
||||
type=str,
|
||||
choices=["float32", "float16", "bfloat16", "half"],
|
||||
default="float16",
|
||||
)
|
||||
parser.add_argument("--block-n", type=int, default=128)
|
||||
parser.add_argument("--block-k", type=int, default=128)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument(
|
||||
"--save-path", type=str, default="python/sglang/srt/layers/quantization/configs"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
230
benchmark/kernels/rmsnorm/benchmark_rmsnorm.py
Normal file
230
benchmark/kernels/rmsnorm/benchmark_rmsnorm.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import itertools
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
||||
from torch import nn
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
|
||||
class HuggingFaceRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
if residual is not None:
|
||||
x = x + residual.to(torch.float32)
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x.to(orig_dtype) * self.weight
|
||||
if residual is None:
|
||||
return x
|
||||
else:
|
||||
return x, residual
|
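# Residual semantics shared by the three backends benchmarked below: when `residual`
# is given, the pre-norm sum h = x + residual is returned as the new residual and
# RMSNorm is applied to h, i.e. out = weight * h / sqrt(mean(h ** 2) + eps).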
||||
|
||||
|
||||
def rmsnorm_naive(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
|
||||
naive_norm.weight = nn.Parameter(weight)
|
||||
naive_norm = naive_norm.to(x.device)
|
||||
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
|
||||
output = naive_norm(x, residual)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||
else:
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def rmsnorm_flashinfer(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
|
||||
if residual is not None:
|
||||
fused_add_rmsnorm(x, residual, weight, eps)
|
||||
output = (x, residual)
|
||||
else:
|
||||
output = rmsnorm(x, weight, eps)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||
else:
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def rmsnorm_vllm(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
|
||||
if residual is not None:
|
||||
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
|
||||
output = (x, residual)
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
vllm_ops.rms_norm(out, x, weight, eps)
|
||||
output = out
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||
else:
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||
dtype = torch.bfloat16
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
||||
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||
residual = torch.randn_like(x) if use_residual else None
|
||||
|
||||
output_naive = rmsnorm_naive(
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
output_flashinfer = rmsnorm_flashinfer(
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
output_vllm = rmsnorm_vllm(
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
|
||||
if use_residual:
|
||||
output_naive = output_naive[0]
|
||||
output_flashinfer = output_flashinfer[0]
|
||||
output_vllm = output_vllm[0]
|
||||
|
||||
print(f"Naive output={output_naive}")
|
||||
print(f"FlashInfer output={output_flashinfer}")
|
||||
print(f"VLLM output={output_vllm}")
|
||||
|
||||
if torch.allclose(
|
||||
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
|
||||
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [2**i for i in range(0, 7, 2)]
|
||||
seq_length_range = [2**i for i in range(6, 11, 1)]
|
||||
head_num_range = [32, 48]
|
||||
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
|
||||
|
||||
|
||||


def get_benchmark(use_residual):
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["head_num", "batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["huggingface", "flashinfer", "vllm"],
            line_names=["HuggingFace", "FlashInfer", "vLLM"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
            args={},
        )
    )
    def benchmark(head_num, batch_size, seq_len, provider):
        dtype = torch.bfloat16
        hidden_size = head_num * 128  # assuming head_dim = 128

        x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
        weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
        residual = torch.randn_like(x) if use_residual else None

        quantiles = [0.5, 0.2, 0.8]

        if provider == "huggingface":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_naive(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )
        elif provider == "flashinfer":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_flashinfer(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_vllm(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_residual", action="store_true", help="Whether to use residual connection"
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/rmsnorm/",
        help="Path to save rmsnorm benchmark results",
    )
    args = parser.parse_args()

    # Run correctness test
    calculate_diff(
        batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
    )

    # Get the benchmark function with proper use_residual setting
    benchmark = get_benchmark(args.use_residual)
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)
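# Usage sketch (the file name below is an assumption; use the actual path of
# this benchmark script in the repo):
#   python benchmark_rmsnorm.py                  # correctness check + sweep without residual
#   python benchmark_rmsnorm.py --use_residual   # same sweep with the fused residual add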
@@ -0,0 +1,169 @@
import os

import torch
import triton
import triton.language as tl


@torch.compile(dynamic=True)
def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )
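# For each request this picks the cache location of the last prefix token, or
# -1 when there is no prefix. Hypothetical example: with
# prefix_lens_tensor = [0, 3] and req_pool_indices_tensor = [5, 7], the result
# is [-1, req_to_token[7, 2]].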


@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)
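    # Each program covers BLOCK_SIZE consecutive requests; lanes past
    # num_tokens are masked off, and requests without a prefix fall back to -1
    # through the `other=-1` default on the gather from req_to_token.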


def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result


def test_get_last_loc():
    max_batch = 4097
    max_context_len = 6148
    batch_size = 20

    # Initialize input tensors
    req_to_token = torch.zeros(
        (max_batch, max_context_len), dtype=torch.int32, device="cuda"
    )
    req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
    pre_lens = torch.randint(
        -max_context_len // 2,
        max_context_len,
        (batch_size,),
        dtype=torch.int64,
        device="cuda",
    )

    last_loc_res = get_last_loc_triton(req_to_token, req_pool_indices, pre_lens)
    last_loc_ref = get_last_loc_torch(req_to_token, req_pool_indices, pre_lens)

    # Compare results
    torch.testing.assert_close(last_loc_res, last_loc_ref)


def get_benchmark():
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=batch_sizes,
            line_arg="provider",
            line_vals=["reference", "triton"],
            line_names=["PyTorch", "Triton"],
            styles=[("blue", "-"), ("green", "-")],
            ylabel="us",
            plot_name="get-last-loc-performance",
            args={},
        )
    )
    def benchmark(batch_size, provider):
        max_batch = 2048
        max_context_len = 16384

        req_to_token = torch.zeros(
            (max_batch, max_context_len), dtype=torch.int32, device="cuda"
        )
        req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
        pre_lens = torch.randint(
            -max_context_len // 2,
            max_context_len,
            (batch_size,),
            dtype=torch.int64,
            device="cuda",
        )

        quantiles = [0.5, 0.2, 0.8]

        if provider == "reference":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: get_last_loc_torch(req_to_token, req_pool_indices, pre_lens),
                quantiles=quantiles,
            )
        elif provider == "triton":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: get_last_loc_triton(req_to_token, req_pool_indices, pre_lens),
                quantiles=quantiles,
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


def run_benchmark(save_path: str = "./configs/benchmark_ops/get_last_loc/"):
    """Run benchmark and save results"""

    # Ensure save path exists
    os.makedirs(save_path, exist_ok=True)

    # Run correctness test
    test_get_last_loc()
    print("Correctness test passed!")

    # Run performance test
    benchmark = get_benchmark()
    benchmark.run(print_data=True, save_path=save_path)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/get_last_loc/",
        help="Path to save benchmark results",
    )
    args = parser.parse_args()

    run_benchmark(args.save_path)
@@ -0,0 +1,342 @@
import itertools
import os

import torch
import triton
import triton.language as tl


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
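    # One program per request: it first re-derives its write offset by serially
    # summing extend_lens[0:pid], then copies its slice of out_cache_loc into
    # req_to_token in BLOCK_SIZE chunks.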


@triton.jit
def write_req_to_token_pool_triton_optimize(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    pid_token = tl.program_id(1)

    req_pool_index = tl.load(req_pool_indices + pid_batch)
    pre_len = tl.load(pre_lens + pid_batch)
    seq_len = tl.load(seq_lens + pid_batch)
    extend_len = seq_len - pre_len

    cumsum_start = 0
    for i in range(pid_batch):
        cumsum_start += tl.load(extend_lens + i)

    token_start = pid_token * BLOCK_SIZE

    offset = tl.arange(0, BLOCK_SIZE)
    actual_offset = token_start + offset
    mask = actual_offset < extend_len

    src_ptr = out_cache_loc + cumsum_start + actual_offset
    src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE)
    value = tl.load(src_ptr, mask=mask)
    dst_ptr = (
        req_to_token_ptr
        + req_pool_index * req_to_token_ptr_stride
        + actual_offset
        + pre_len
    )
    dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE)

    tl.store(dst_ptr, value, mask=mask)
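    # Launched on a 2D grid of (batch, cdiv(extend_len, BLOCK_SIZE)) programs,
    # so one request's copy is spread over many programs instead of a serial
    # loop; max_contiguous/multiple_of hint to the compiler that the accessed
    # addresses form contiguous blocks of BLOCK_SIZE, enabling wider memory
    # transactions.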


def write_req_to_token_pool_reference(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    pre_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    extend_lens: torch.Tensor,
    out_cache_loc: torch.Tensor,
) -> None:
    """Reference implementation using PyTorch"""
    for i in range(len(req_pool_indices)):
        req_pool_idx = req_pool_indices[i].item()
        pre_len = pre_lens[i].item()
        seq_len = seq_lens[i].item()
        extend_len = extend_lens[i].item()

        cumsum_start = sum(extend_lens[:i].tolist())

        # Copy values from out_cache_loc to req_to_token
        req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[
            cumsum_start : cumsum_start + extend_len
        ]
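    # Hypothetical example: with pre_lens = [8, 10], seq_lens = [22, 30] and
    # extend_lens = [14, 20], request 0 copies out_cache_loc[0:14] into
    # req_to_token[idx0, 8:22] and request 1 copies out_cache_loc[14:34] into
    # req_to_token[idx1, 10:30].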


def test_write_req_to_token_pool():
    max_batch = 4097
    max_context_len = 6148
    batch_size = 1
    extend_len = 14

    # Initialize input tensors
    req_to_token = torch.zeros(
        (max_batch, max_context_len), dtype=torch.int32, device="cuda"
    )
    req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda")
    pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda")
    seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda")
    extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda")
    out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda")

    # Create copies for reference implementation
    req_to_token_ref = req_to_token.clone()
    req_to_token_opt = req_to_token.clone()

    # Run original triton kernel
    write_req_to_token_pool_triton[(batch_size,)](
        req_to_token,
        req_pool_indices,
        pre_lens,
        seq_lens,
        extend_lens,
        out_cache_loc,
        max_context_len,
    )

    # Run optimized triton kernel
    def grid(batch_size, extend_len):
        num_token_blocks = triton.cdiv(extend_len, 512)
        return (batch_size, num_token_blocks)

    write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)](
        req_to_token_opt,
        req_pool_indices,
        pre_lens,
        seq_lens,
        extend_lens,
        out_cache_loc,
        max_context_len,
        BLOCK_SIZE=512,
    )

    # Run reference implementation
    write_req_to_token_pool_reference(
        req_to_token_ref,
        req_pool_indices,
        pre_lens,
        seq_lens,
        extend_lens,
        out_cache_loc,
    )

    # Compare results
    torch.testing.assert_close(req_to_token, req_to_token_ref)
    torch.testing.assert_close(req_to_token_opt, req_to_token_ref)

    # Test case 2: batch size > 1
    batch_size = 3
    extend_lens_list = [14, 20, 30]
    total_extend_len = sum(extend_lens_list)

    req_to_token = torch.zeros(
        (max_batch, max_context_len), dtype=torch.int32, device="cuda"
    )
    req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda")
    pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda")
    seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda")
    extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
    out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")

    req_to_token_ref = req_to_token.clone()
    req_to_token_opt = req_to_token.clone()

    # Run original triton kernel
    write_req_to_token_pool_triton[(batch_size,)](
        req_to_token,
        req_pool_indices,
        pre_lens,
        seq_lens,
        extend_lens,
        out_cache_loc,
        max_context_len,
    )

    # Run optimized triton kernel
    max_extend_len = max(extend_lens_list)
    write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)](
        req_to_token_opt,
        req_pool_indices,
        pre_lens,
        seq_lens,
        extend_lens,
        out_cache_loc,
        max_context_len,
        BLOCK_SIZE=512,
    )

    # Run reference implementation
    write_req_to_token_pool_reference(
        req_to_token_ref,
        req_pool_indices,
        pre_lens,
        seq_lens,
        extend_lens,
        out_cache_loc,
    )

    # Compare results
    torch.testing.assert_close(req_to_token, req_to_token_ref)
    torch.testing.assert_close(req_to_token_opt, req_to_token_ref)


def get_benchmark():
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
    extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
    configs = list(itertools.product(batch_sizes, extend_lens))

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "extend_len"],
            x_vals=configs,
            line_arg="provider",
            line_vals=["reference", "triton", "triton_optimize"],
            line_names=["PyTorch", "Triton", "Triton Optimized"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name="write-req-to-token-pool-performance",
            args={},
        )
    )
    def benchmark(batch_size, extend_len, provider):
        max_batch = 256
        max_context_len = 16384

        extend_lens_list = [extend_len] * batch_size
        total_extend_len = sum(extend_lens_list)

        req_to_token = torch.zeros(
            (max_batch, max_context_len), dtype=torch.int32, device="cuda"
        )
        req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
        pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8
        seq_lens = pre_lens + extend_len
        extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
        out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")

        quantiles = [0.5, 0.2, 0.8]

        if provider == "reference":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: write_req_to_token_pool_reference(
                    req_to_token.clone(),
                    req_pool_indices,
                    pre_lens,
                    seq_lens,
                    extend_lens,
                    out_cache_loc,
                ),
                quantiles=quantiles,
            )
        elif provider == "triton":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: write_req_to_token_pool_triton[(batch_size,)](
                    req_to_token.clone(),
                    req_pool_indices,
                    pre_lens,
                    seq_lens,
                    extend_lens,
                    out_cache_loc,
                    max_context_len,
                ),
                quantiles=quantiles,
            )
        else:

            def run_optimized():
                block_size = 128 if extend_len <= 1024 else 512
                grid_config = (batch_size, triton.cdiv(extend_len, block_size))
                write_req_to_token_pool_triton_optimize[grid_config](
                    req_to_token.clone(),
                    req_pool_indices,
                    pre_lens,
                    seq_lens,
                    extend_lens,
                    out_cache_loc,
                    max_context_len,
                    BLOCK_SIZE=block_size,
                )

            ms, min_ms, max_ms = triton.testing.do_bench(
                run_optimized, quantiles=quantiles
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"):
    """Run benchmark and save results"""

    # Ensure save path exists
    os.makedirs(save_path, exist_ok=True)

    # Run correctness test
    test_write_req_to_token_pool()
    print("Correctness test passed!")

    # Run performance test
    benchmark = get_benchmark()
    benchmark.run(print_data=True, save_path=save_path)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/write_req_to_token_pool/",
        help="Path to save benchmark results",
    )
    args = parser.parse_args()

    run_benchmark(args.save_path)
@@ -0,0 +1,283 @@
import itertools

import torch
import torch.nn.functional as F
import triton.testing as tt

from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd


def extend_attention_fwd_torch(
    q: torch.Tensor,  # [extend_tokens, H_Q, D]
    k: torch.Tensor,  # [extend_tokens, H_KV, D]
    v: torch.Tensor,  # [extend_tokens, H_KV, D]
    o: torch.Tensor,  # [extend_tokens, H_Q, D]
    k_cache: torch.Tensor,  # [total_tokens, H_KV, D]
    v_cache: torch.Tensor,  # [total_tokens, H_KV, D]
    qo_indptr: torch.Tensor,  # [B+1]
    kv_indptr: torch.Tensor,  # [B+1]
    kv_indices: torch.Tensor,  # [prefix_tokens]
    sliding_window_size: int,
):
    B = qo_indptr.size(0) - 1
    _, H_Q, D = q.shape
    _, H_KV, _ = k.shape

    group_size = H_Q // H_KV
    scale = 1.0 / D**0.5

    for i in range(B):
        q_start = int(qo_indptr[i].item())
        q_end = int(qo_indptr[i + 1].item())
        kv_start = int(kv_indptr[i].item())
        kv_end = int(kv_indptr[i + 1].item())

        prefix_indices = kv_indices[kv_start:kv_end]
        k_prefix = k_cache[prefix_indices]  # [prefix_len, H_KV, D]
        v_prefix = v_cache[prefix_indices]  # [prefix_len, H_KV, D]

        k_extend = k[q_start:q_end]  # [extend_len, H_KV, D]
        v_extend = v[q_start:q_end]  # [extend_len, H_KV, D]
        q_extend = q[q_start:q_end]  # [extend_len, H_Q, D]

        k_full = torch.cat([k_prefix, k_extend], dim=0)  # [total_len, H_KV, D]
        v_full = torch.cat([v_prefix, v_extend], dim=0)  # [total_len, H_KV, D]

        if group_size != 1:
            k_full_hq = k_full.repeat_interleave(
                group_size, dim=1
            )  # [total_len, H_Q, D]
            v_full_hq = v_full.repeat_interleave(
                group_size, dim=1
            )  # [total_len, H_Q, D]
        else:
            k_full_hq = k_full
            v_full_hq = v_full

        prefix_len = k_prefix.size(0)
        extend_len = k_extend.size(0)
        total_len = prefix_len + extend_len

        # causal
        pos_keys = torch.arange(total_len, device=q.device)
        t = prefix_len + torch.arange(extend_len, device=q.device)  # [extend_len]
        causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)

        # sliding window
        if sliding_window_size is not None and sliding_window_size > 0:
            start = (t - (sliding_window_size)).clamp_min(0)  # [extend_len]
        else:
            start = torch.zeros_like(t)
        window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)

        final_mask = causal_mask & window_mask

        attn_scores = (
            torch.einsum("qhd,khd->qhk", q_extend, k_full_hq) * scale
        )  # [extend_len, H_Q, total_len]
        attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        o[q_start:q_end] = torch.einsum("qhk,khd->qhd", attn_weights, v_full_hq)
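    # A query at extended position t (prefix_len <= t < total_len) attends to
    # keys from max(0, t - sliding_window_size) through t inclusive; with
    # sliding_window_size <= 0 (or None) this reduces to plain causal
    # attention. KV heads are repeated group_size times so grouped-query
    # attention becomes an ordinary per-head product.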


def _build_batch(
    B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=torch.bfloat16, device="cuda"
):
    b_seq_len_prefix = torch.randint(
        1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
    )
    b_seq_len_extend = torch.randint(
        1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
    )
    b_seq_len = b_seq_len_prefix + b_seq_len_extend

    b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
    b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)
    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

    kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
    kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)

    kv_indices = torch.zeros(
        (int(b_seq_len_prefix.sum().item()),), dtype=torch.int32, device=device
    )
    for i in range(B):
        s = kv_indptr[i].item()
        e = kv_indptr[i + 1].item()
        kv_indices[s:e] = torch.arange(
            b_start_loc[i],
            b_start_loc[i] + b_seq_len_prefix[i],
            dtype=torch.int32,
            device=device,
        )

    total_token_num = int(torch.sum(b_seq_len).item())
    extend_token_num = int(torch.sum(b_seq_len_extend).item())

    k_buffer = torch.empty(
        (total_token_num, H_KV, D), dtype=dtype, device=device
    ).normal_(mean=0.1, std=0.2)
    v_buffer = torch.empty(
        (total_token_num, H_KV, D), dtype=dtype, device=device
    ).normal_(mean=0.1, std=0.2)

    k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
    v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
    q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)

    for i in range(B):
        extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
        extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
        extend_start = b_start_loc_extend[i]
        extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]

        k_extend[extend_start:extend_end] = k_buffer[
            extend_start_in_buffer:extend_end_in_buffer
        ]
        v_extend[extend_start:extend_end] = v_buffer[
            extend_start_in_buffer:extend_end_in_buffer
        ]
        q_extend[extend_start:extend_end] = torch.empty(
            (int(b_seq_len_extend[i].item()), H_Q, D), dtype=dtype, device=device
        ).normal_(mean=0.1, std=0.2)

    o_extend_triton = torch.empty(
        (extend_token_num, H_Q, D), dtype=dtype, device=device
    )
    o_extend_torch = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)

    b_seq_len_extend = b_seq_len - b_seq_len_prefix
    max_len_extend = int(torch.max(b_seq_len_extend, 0)[0].item())
    qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
    qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)

    inputs = dict(
        q_extend=q_extend,
        k_extend=k_extend,
        v_extend=v_extend,
        k_buffer=k_buffer,
        v_buffer=v_buffer,
        o_extend_triton=o_extend_triton,
        o_extend_torch=o_extend_torch,
        qo_indptr=qo_indptr,
        kv_indptr=kv_indptr,
        kv_indices=kv_indices,
        max_len_extend=max_len_extend,
        WINDOW_SIZE=WINDOW_SIZE,
    )
    meta = dict(
        B=B, N_CTX=N_CTX, H_Q=H_Q, H_KV=H_KV, D=D, extend_token_num=extend_token_num
    )
    return inputs, meta


def _run_triton(inputs):
    extend_attention_fwd(
        inputs["q_extend"],
        inputs["k_extend"],
        inputs["v_extend"],
        inputs["o_extend_triton"],
        inputs["k_buffer"],
        inputs["v_buffer"],
        inputs["qo_indptr"],
        inputs["kv_indptr"],
        inputs["kv_indices"],
        custom_mask=None,
        is_causal=True,
        mask_indptr=None,
        max_len_extend=inputs["max_len_extend"],
        sliding_window_size=inputs["WINDOW_SIZE"],
    )


def _run_torch_ref(inputs):
    extend_attention_fwd_torch(
        inputs["q_extend"],
        inputs["k_extend"],
        inputs["v_extend"],
        inputs["o_extend_torch"],
        inputs["k_buffer"],
        inputs["v_buffer"],
        inputs["qo_indptr"],
        inputs["kv_indptr"],
        inputs["kv_indices"],
        inputs["WINDOW_SIZE"],
    )


N_CTXS = [1024, 2048, 4096, 8192]
WINDOW_SIZES = [-1, 127, 256, 512]

CONFIGS = list(itertools.product(N_CTXS, WINDOW_SIZES))

PROVIDERS = ["torch", "triton"]


@tt.perf_report(
    tt.Benchmark(
        x_names=["N_CTX", "WINDOW_SIZE"],
        x_vals=CONFIGS,
        line_arg="provider",
        line_vals=PROVIDERS,
        line_names=PROVIDERS,
        ylabel="Runtime (ms)",
        plot_name="extend_attention_triton_vs_torch",
        args={
            "B": 32,
            "H_Q": 64,
            "H_KV": 8,
            "D": 128,
            "dtype": "bf16",
            "device": "cuda",
            "check_correctness": False,
            "warmup": 25,
            "rep": 100,
        },
    )
)
def bench(
    N_CTX,
    provider,
    B,
    H_Q,
    H_KV,
    D,
    dtype,
    device,
    WINDOW_SIZE,
    check_correctness,
    warmup,
    rep,
):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
    dt = dtype_map[dtype]

    inputs, _ = _build_batch(
        B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=dt, device=device
    )

    if check_correctness and provider == "triton":
        _run_triton(inputs)
        _run_torch_ref(inputs)
        torch.cuda.synchronize()
        if not torch.allclose(
            inputs["o_extend_triton"], inputs["o_extend_torch"], rtol=1e-3, atol=1e-3
        ):
            raise AssertionError("Mismatch between triton and torch reference.")

    if provider == "triton":
        ms = tt.do_bench(lambda: _run_triton(inputs), warmup=warmup, rep=rep)
    elif provider == "torch":
        ms = tt.do_bench(lambda: _run_torch_ref(inputs), warmup=warmup, rep=rep)
    else:
        raise ValueError(provider)

    return ms


if __name__ == "__main__":
    bench.run(print_data=True, show_plots=False)